import ipywidgets as widgets
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import wot
# input paths
FULL_DS_PATH = 'data/ExprMatrix.h5ad'  # full expression matrix; read below to compute trajectory trends
CELL_DAYS_PATH = 'data/cell_days.txt'  # NOTE(review): not referenced in the visible part of this notebook
VAR_DS_PATH = 'data/ExprMatrix.var.genes.h5ad'  # dataset passed to trajectory_divergence (presumably variable genes only, per the name — confirm)
TMAP_PATH = 'tmaps/serum'  # directory loaded with TransportMapModel.from_directory
CELL_SETS_PATH = 'data/major_cell_sets.gmt'  # named cell sets, read with wot.io.read_sets
COORDS_PATH = 'data/fle_coords.txt'  # tab-separated embedding coordinates with 'x' and 'y' columns, indexed by cell
In their raw form, the transport maps show the descendants and ancestors of each individual cell. Here we explore several different techniques to analyze the ancestors, descendants and trajectories of sets of cells.
Given a set $C$ of cells at time $t_j$, we can compute the descendant distribution at time $t_{j+1}$ by pushing the cell set forward through the transport matrix. This push-forward operation is implemented as a matrix multiplication. We first form a probability vector $p_{t_j}$ to represent the cell set $C$:
$$p_{t_j}(x) = \begin{cases} \frac{1}{|C|} & \quad x \in C \\ 0 & \quad \text{otherwise}.\end{cases}$$ In the following code block we generate probability vectors for the cell sets IPS, Neural, Trophoblast, and Stromal on day 12 (set by the `at_time=12` argument). The $\texttt{populations}$ object is a list of these probability vectors, one for each cell set.
# Named cell sets from the GMT file, keyed by set name.
cell_sets = wot.io.read_sets(CELL_SETS_PATH, as_dict=True)
# Transport maps stored in TMAP_PATH.
tmap_model = wot.tmap.TransportMapModel.from_directory(TMAP_PATH)
# One probability vector per cell set, built from that set's cells at day 12.
populations = tmap_model.population_from_cell_sets(cell_sets, at_time=12)
We can then push forward each probability vector by multiplying it by the transport map on the right:
$$p_{t_{j+1}}^T = p_{t_j}^T \pi_{t_j,t_{j+1}}.$$ Multiplying $p_{t_{j+1}}^T$ by $\pi_{t_{j+1},t_{j+2}}$ pushes it forward again to give the descendant distribution at time $t_{j+2}$. Continuing in this way, we can compute the descendant distribution at any later time point $t_{\ell} > t_j$.
To compute the ancestors of $C$ at an earlier time point $t_i < t_j$, we pull the cell set back through the transport maps. This pull-back operation is also implemented as a matrix multiplication, $$p_{t_{j-1}} = \pi_{t_{j-1},t_j} p_{t_j},$$ and repeating it one step at a time back to $t_i$ gives the ancestor distribution at $t_i$. The trajectory of a cell set $C$ refers to the sequence of ancestor distributions at earlier time points and descendant distributions at later time points.
The method `trajectories` takes a list of probability vectors and returns the trajectories, i.e. the descendant distributions at later time points and the ancestor distributions at earlier time points.
# Compute the trajectory (ancestors and descendants) of every population.
# trajectory_ds.var.index holds the trajectory names (used as widget options below),
# and trajectory_ds[:, name].X gives that trajectory's value for each cell.
# NOTE(review): presumably obs spans cells from all time points — confirm.
trajectory_ds = tmap_model.trajectories(populations)
We can now visualize the trajectories on the force-directed layout embedding (FLE) coordinates.
# Load embedding coordinates and snap each axis onto an integer grid of
# nbins cells (0 .. nbins - 1), so nearby cells can be aggregated per bin.
coord_df = pd.read_csv(COORDS_PATH, sep='\t', index_col=0)
nbins = 500
xrange = coord_df['x'].min(), coord_df['x'].max()
yrange = coord_df['y'].min(), coord_df['y'].max()
for axis_name, axis_range in (('x', xrange), ('y', yrange)):
    # Linearly rescale [min, max] to [0, nbins - 1], then truncate to a bin index.
    rescaled = np.interp(coord_df[axis_name], [axis_range[0], axis_range[1]], [0, nbins - 1])
    coord_df[axis_name] = np.floor(rescaled).astype(int)
# Attach the binned coordinates to every cell in the trajectory dataset.
trajectory_ds.obs = trajectory_ds.obs.join(coord_df)
# Visualize trajectories
trajectory_dropdown = widgets.Dropdown(
    options=trajectory_ds.var.index,
    description='Trajectory:'
)


def update_trajectory_vis(name):
    """Draw trajectory ``name`` on the binned embedding.

    All cells are drawn in light grey as a background. On top, the values of
    trajectory ``name`` are summed within each (x, y) grid bin and coloured,
    with the colour scale capped at the 97.5th percentile of the binned sums.

    Parameters
    ----------
    name : str
        One of ``trajectory_ds.var.index``.
    """
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.title(name)
    plt.scatter(coord_df['x'], coord_df['y'], c='#f0f0f0',
                s=4, marker=',', edgecolors='none', alpha=0.8)
    # Keep only the grid coordinates; other obs metadata must not be aggregated.
    binned_df = trajectory_ds.obs[['x', 'y']].copy()
    # .X of a single-column slice can be 2-D (n_cells x 1); flatten it into a column.
    binned_df['values'] = np.asarray(trajectory_ds[:, name].X).ravel()
    binned_df = binned_df.groupby(['x', 'y'], as_index=False)['values'].sum()
    # Cap the colour scale so a few very dense bins do not wash out the rest.
    plt.scatter(binned_df['x'], binned_df['y'], c=binned_df['values'],
                s=6, marker=',', edgecolors='none',
                vmax=binned_df['values'].quantile(0.975))
    plt.colorbar().ax.set_title('Trajectory')
    # Compute the layout only after the title and colorbar exist.
    plt.tight_layout()


widgets.interact(update_trajectory_vis, name=trajectory_dropdown)
interactive(children=(Dropdown(description='Trajectory:', options=('IPS', 'Stromal', 'Neural', 'Trophoblast', …
<function __main__.update_trajectory_vis(name)>
We can also generate a movie of a trajectory across time. The code below is commented out; uncomment it to render the movie.
# from matplotlib import animation
# from IPython.display import Video
# name = 'IPS'
# movie_file_name = '{}_trajectory.mov'.format(name)
# figure = plt.figure(figsize=(10, 10))
# plt.axis('off')
# plt.tight_layout()
# unique_days = trajectory_ds.obs['day'].unique()
# binned_df = trajectory_ds.obs.copy()
# binned_df['values'] = trajectory_ds[:, name].X
# binned_df = binned_df.groupby(['x', 'y'], as_index=False).sum()
# vmax=binned_df['values'].quantile(0.975)
# def animate(i):
# _trajectory_ds = trajectory_ds[trajectory_ds.obs['day']==unique_days[i]]
# binned_df = _trajectory_ds.obs.copy()
# plt.suptitle('Day {}, {:,} cells'.format(unique_days[i], _trajectory_ds.shape[0]), y=0.95)
# plt.scatter(coord_df['x'], coord_df['y'], c='#f0f0f0', s=4, marker=',', edgecolors='none')
# binned_df['values'] = _trajectory_ds[:, name].X
# binned_df = binned_df.groupby(['x', 'y'], as_index=False).sum()
# plt.scatter(binned_df['x'], binned_df['y'], c=binned_df['values'],
# s=6, marker=',', edgecolors='none', vmin=0, vmax=vmax)
# anim = animation.FuncAnimation(figure, func=animate, frames=range(0, len(unique_days)), init_func=lambda **args:None, repeat=False, interval=400)
# anim.save(movie_file_name)
# plt.close(figure)
# Video(movie_file_name)
We now show how to compute trends in expression along trajectories. Mathematically, the trajectory is represented by a probability distribution over cells at each time point. Therefore we can easily compute the average expression according to this probability distribution.
We begin by computing expression trends for all genes and saving the trends of each trajectory (`trajectory_trends`) to disk, one file per trajectory.
# Full expression matrix used to average expression along each trajectory.
adata = wot.io.read_dataset(FULL_DS_PATH)
# Expression trends for every gene; one dataset per trajectory, in the
# same order as trajectory_ds.var.index.
trajectory_trends = wot.tmap.trajectory_trends_from_trajectory(trajectory_ds, adata)
# Write each trajectory's trends to "<trajectory name>_trends.txt".
for idx, trends in enumerate(trajectory_trends):
    trend_file = trajectory_ds.var.index[idx] + '_trends.txt'
    wot.io.write_dataset(trends, trend_file)
Transforming to str index. Transforming to str index. Transforming to str index. Transforming to str index. Transforming to str index.
Read the trajectory trends back in from disk.
# Trajectory names, in the same order as the columns of trajectory_ds.
trajectory_names = list(trajectory_ds.var.index)
# Per-trajectory trend datasets, aligned index-for-index with trajectory_names.
trajectory_trend_datasets = [
    wot.io.read_dataset(trajectory_name + '_trends.txt')
    for trajectory_name in trajectory_names
]
Transforming to str index. Transforming to str index. Transforming to str index. Transforming to str index. Transforming to str index.
Visualize trends
trajectory_dropdown = widgets.SelectMultiple(
    options=trajectory_ds.var.index,
    value=[trajectory_ds.var.index[0]],
    description='Trajectory:'
)
gene_input = widgets.Text(
    placeholder='',
    description='Genes:',
    value='Nanog',
    continuous_update=False
)


def update_trends_vis(selected_trajectories, gene_names):
    """Plot expression trends over time for the given genes and trajectories.

    Parameters
    ----------
    selected_trajectories : sequence of str
        Trajectory names; each must appear in ``trajectory_names``.
    gene_names : str
        Comma-separated gene names. Spaces are ignored and matching against
        the trend datasets' ``var.index`` is case-insensitive.
    """
    # Normalize the requested genes once rather than once per trajectory.
    wanted_genes = [gene.lower() for gene in gene_names.replace(' ', '').split(',')]
    plt.figure(figsize=(10, 10))
    for selected_trajectory in selected_trajectories:
        trajectory_index = trajectory_names.index(selected_trajectory)
        trends = trajectory_trend_datasets[trajectory_index]
        trends = trends[:, trends.var.index.str.lower().isin(wanted_genes)]
        # The obs index holds the time points; parse them as numbers for the x axis.
        timepoints = trends.obs.index.values.astype(float)
        for i, gene in enumerate(trends.var.index):  # each matching gene
            # .X of a single-column slice can be 2-D; flatten to one series.
            gene_trend = np.asarray(trends[:, i].X).ravel()
            plt.plot(timepoints, gene_trend, label=gene + ', ' + selected_trajectory)
    plt.xlabel("Day")
    plt.ylabel("Expression")
    # Skip the legend (and matplotlib's "no artists" warning) when nothing matched.
    if plt.gca().get_legend_handles_labels()[0]:
        plt.legend()


widgets.interact(update_trends_vis, selected_trajectories=trajectory_dropdown, gene_names=gene_input)
interactive(children=(SelectMultiple(description='Trajectory:', index=(0,), options=('IPS', 'Stromal', 'Neural…
<function __main__.update_trends_vis(selected_trajectories, gene_names)>
# Dataset used for the divergence computation
# (presumably restricted to variable genes, per the file name — TODO confirm).
adata_var = wot.io.read_dataset(VAR_DS_PATH)
# Distances between pairs of trajectories over time, using the total variation metric.
# The plot below reads the 'name1', 'name2', 'day2' and 'distance' columns of the result.
divergence_df = wot.tmap.trajectory_divergence(adata_var, trajectory_ds, distance_metric='total_variation')
Plot divergence
def _short_trajectory_name(names):
    """Return the part of each name in the Series ``names`` before the first '/'."""
    return names.str.split('/').str[0]


# Label each trajectory pair as "<first> vs. <second>".
divergence_df['name'] = (
    _short_trajectory_name(divergence_df['name1'])
    + ' vs. '
    + _short_trajectory_name(divergence_df['name2'])
)
plt.figure(figsize=(10, 10))
plt.xlabel("Day")
plt.ylabel("Distance")
# One line per trajectory pair: distance as a function of day.
for pair_label, pair_df in divergence_df.groupby('name'):
    plt.plot(pair_df['day2'], pair_df['distance'], '-o', label=pair_label)
plt.legend(loc='best')
<matplotlib.legend.Legend at 0x1a32900390>