import scanpy as sc
import pandas as pd
import numpy as np
import trvae
Using TensorFlow backend.
sc.set_figure_params(dpi=100)
condition_key is the key for your batch or condition labels in adata.obs.
condition_key = "condition"
adata = sc.read("data/haber_count.h5ad")
adata
AnnData object with n_obs × n_vars = 9842 × 15215 obs: 'batch', 'barcode', 'condition', 'cell_label'
One can use more genes, but in order to train the network quickly we will extract the top 2000 genes. This can be done with the normalize_hvg function in the tl module of the trVAE package. The function accepts the following arguments:
- adata: the AnnData object to be normalized; counts are expected in its .X attribute.
- target_sum: total count that each cell is normalized to.
- size_factors: whether to compute and put total counts per cell in the "size_factors" column of adata.obs (True is recommended).
- scale_input: whether to scale the input data (False is recommended).
- logtrans_input: whether to log-transform adata after normalization (True is recommended).
- n_top_genes: number of highly variable genes to keep in adata after normalization.
normalization.adata = trvae.tl.normalize_hvg(adata,
target_sum=1e4,
size_factors=True,
scale_input=False,
logtrans_input=True,
n_top_genes=2000)
adata
AnnData object with n_obs × n_vars = 9842 × 2000 obs: 'batch', 'barcode', 'condition', 'cell_label', 'size_factors' var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm' uns: 'log1p'
sc.pp.neighbors(adata)
sc.tl.umap(adata)
WARNING: You’re trying to run this on 2000 dimensions of `.X`, if you really want this, set `use_rep='X'`. Falling back to preprocessing with `sc.pp.pca` and default params.
The cell type label is used just for visualization; no cell type label is required by the model.
cell_type_key = "cell_label"
sc.pl.umap(adata, color=[condition_key, cell_type_key],
wspace=0.6,
frameon=False)
conditions = adata.obs[condition_key].unique().tolist()
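As a quick illustration of what this line produces (with hypothetical toy labels, not the real data): .unique() returns the labels in order of first appearance, and .tolist() converts them to a plain Python list.

```python
import pandas as pd

# Toy condition column standing in for adata.obs[condition_key].
obs_condition = pd.Series(["Control", "Hpoly.Day10", "Control", "Salmonella"])

# unique() preserves order of first appearance; tolist() gives a plain list.
conditions = obs_condition.unique().tolist()
print(conditions)  # ['Control', 'Hpoly.Day10', 'Salmonella']
```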
Some of the network parameters:
- x_dimension: number of input features (here, the number of genes).
- architecture: sizes of the hidden layers.
- z_dimension: dimension of the latent space.
- gene_names: list of gene names (adata.var_names.tolist()).
- conditions: list of unique condition labels.
- model_path: directory where the trained model is saved.
- alpha, beta, eta: weights of the loss terms.
- loss_fn: reconstruction loss function (mse or sse).
- output_activation: activation of the output layer (relu, leaky_relu, linear, ...).
network = trvae.models.trVAE(x_dimension=adata.shape[1],
architecture=[256,64],
z_dimension=10,
gene_names=adata.var_names.tolist(),
conditions=conditions,
model_path='./models/trVAE/haber/',
alpha=0.0001,
beta=50,
eta=100,
loss_fn='sse',
output_activation='linear')
trVAE's network has been successfully constructed! trVAE's network has been successfully compiled!
You can train the network with the train function with the following parameters:
network.train(adata,
condition_key,
train_size=0.8,
n_epochs=50,
batch_size=512,
early_stop_limit=10,
lr_reducer=20,
verbose=5,
save=True,
)
|████████████████████| 100.0% - loss: 251.2036 - mmd_loss: 0.4532 - recon_loss: 250.7504 - val_loss: 8249.9760 - val_mmd_loss: 70.4720 - val_recon_loss: 8179.50403 trVAE has been successfully saved in ./models/trVAE/haber/.
If you use trVAE for batch removal, we recommend using the latent space (z) computed with the get_latent function:
latent_adata = network.get_latent(adata, condition_key)
latent_adata
AnnData object with n_obs × n_vars = 9842 × 10 obs: 'batch', 'barcode', 'condition', 'cell_label', 'size_factors'
sc.pp.neighbors(latent_adata)
sc.tl.umap(latent_adata)
sc.pl.umap(latent_adata, color=[condition_key, cell_type_key], wspace=0.5, frameon=False)
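Beyond visual inspection of the UMAP, batch mixing can also be checked quantitatively. A minimal sketch on synthetic data (not the real latent space): after good batch removal, batch labels should not separate in latent space, so the silhouette score with respect to the batch/condition labels should be close to 0.

```python
import numpy as np
from sklearn.metrics import silhouette_score

# Synthetic stand-ins for latent_adata.X and the condition labels.
rng = np.random.default_rng(0)
latent = rng.normal(size=(200, 10))      # stand-in for latent_adata.X
batches = rng.integers(0, 2, size=200)   # stand-in for the batch labels

# Near 0 -> batches overlap (well mixed); near 1 -> batches separate.
score = silhouette_score(latent, batches)
print(f"batch silhouette: {score:.3f}")
```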
To correct the gene expression data, we transfer all conditions to the batch label with the maximum number of samples. target_condition is the condition that you want your source adata to be transformed to.
adata.obs[condition_key].value_counts()
Control 3240 Hpoly.Day10 2711 Hpoly.Day3 2121 Salmonella 1770 Name: condition, dtype: int64
target_condition = adata.obs[condition_key].value_counts().index[0]
corrected_data = network.predict(adata, condition_key, target_condition=target_condition)
sc.pp.neighbors(corrected_data)
sc.tl.umap(corrected_data)
WARNING: You’re trying to run this on 2000 dimensions of `.X`, if you really want this, set `use_rep='X'`. Falling back to preprocessing with `sc.pp.pca` and default params.
As observed in the corrected gene expression data, all samples were mapped to control cells and are now mixed.
sc.pl.umap(corrected_data, color=[condition_key, cell_type_key], wspace=0.5, frameon=False)