import scanpy as sc
import pandas as pd
import numpy as np
import trvae
sc.set_figure_params(dpi=100)
condition_key is the key for your batch or condition labels in adata.obs.
condition_key = "condition"
adata = sc.read("data/haber_count.h5ad")
adata
One can use more genes, but to train the network quickly we will keep only the top 2000 highly variable genes. This can be done with the normalize_hvg function in the tl module of the trVAE package. The function accepts the following arguments:
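- adata: the AnnData object to normalize; it should contain raw counts in its .X attribute.
- size_factors: whether to normalize adata and put the total counts per cell in the "size_factors" column of adata.obs (True is recommended).
- scale_input: whether to scale the input data (False is recommended).
- logtrans_input: whether to log-transform adata after normalization (True is recommended).
- n_top_genes: number of highly variable genes to keep after adata normalization.
adata = trvae.tl.normalize_hvg(adata,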
target_sum=1e4,
size_factors=True,
scale_input=False,
logtrans_input=True,
n_top_genes=2000)
adata
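To make the preprocessing steps above concrete, here is a simplified numpy sketch of what they amount to: scale each cell to target_sum total counts, record the per-cell totals as size factors, log-transform, and keep the most variable genes. This is not trVAE's implementation; normalize_hvg_sketch is a hypothetical helper, and ranking genes by plain variance is an assumption (scanpy's highly-variable-gene selection uses dispersion-based methods instead).

```python
import numpy as np


def normalize_hvg_sketch(counts, target_sum=1e4, logtrans_input=True, n_top_genes=2000):
    """Hypothetical, simplified stand-in for the normalization described above.

    counts: cells x genes array of raw counts (every cell is assumed to have
    at least one count, otherwise the division below is undefined).
    Returns the processed matrix, the per-cell size factors and the kept gene indices.
    """
    counts = np.asarray(counts, dtype=float)
    size_factors = counts.sum(axis=1)                   # total counts per cell
    x = counts / size_factors[:, None] * target_sum     # each cell now sums to target_sum
    if logtrans_input:
        x = np.log1p(x)                                 # log(1 + x)
    # Crude proxy for "highly variable": rank genes by variance across cells.
    n_keep = min(n_top_genes, x.shape[1])
    top = np.sort(np.argsort(x.var(axis=0))[::-1][:n_keep])  # keep original gene order
    return x[:, top], size_factors, top


# Toy example: 3 cells x 4 genes, keeping the 2 most variable genes.
toy = np.array([[1, 0, 3, 6], [2, 2, 2, 2], [0, 5, 1, 4]])
x, sf, genes = normalize_hvg_sketch(toy, n_top_genes=2)
```

On the real data the package function operates on the AnnData object directly; the sketch only mirrors the order of operations.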
sc.pp.neighbors(adata)
sc.tl.umap(adata)
# just for visualization; no cell type label is required for the model
cell_type_key = "cell_label"
sc.pl.umap(adata, color=[condition_key, cell_type_key],
wspace=0.6,
frameon=False)
conditions = adata.obs[condition_key].unique().tolist()
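The conditions list fixes an order for the condition labels. As a sketch of how such a list is typically used, the snippet below turns per-cell labels into one-hot rows whose columns follow that order. The helper one_hot_conditions and the toy labels are hypothetical; this is not the package's internal code, and it only assumes that the model conditions on a one-hot encoding of each cell's label.

```python
import numpy as np
import pandas as pd


def one_hot_conditions(labels, conditions):
    """Hypothetical helper: one row per cell, one column per entry of `conditions`."""
    index = {c: i for i, c in enumerate(conditions)}
    onehot = np.zeros((len(labels), len(conditions)), dtype=np.float32)
    for row, label in enumerate(labels):
        onehot[row, index[label]] = 1.0
    return onehot


# Toy stand-in for adata.obs with a "condition" column.
obs = pd.DataFrame({"condition": ["batch_a", "batch_b", "batch_a", "batch_c"]})
conditions = obs["condition"].unique().tolist()  # unique() keeps order of first appearance
encoded = one_hot_conditions(obs["condition"], conditions)
```

Here conditions is ['batch_a', 'batch_b', 'batch_c'] and each row of encoded contains a single 1 in the column of that cell's condition.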
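Some of the network parameters:
- x_dimension: number of input features (genes), here adata.shape[1].
- architecture: sizes of the hidden layers.
- z_dimension: dimensionality of the latent space z.
- gene_names: list of gene names (adata.var_names.tolist()).
- conditions: list of unique conditions (batch labels) in the data.
- model_path: directory where the model is saved.
- alpha, beta, eta: coefficients of the KL divergence, MMD and reconstruction loss terms, respectively.
- loss_fn: reconstruction loss function (mse or sse).
- output_activation: activation function of the last layer; can be relu, leaky_relu, linear, ...
network = trvae.models.trVAE(x_dimension=adata.shape[1],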
architecture=[256,64],
z_dimension=10,
gene_names=adata.var_names.tolist(),
conditions=conditions,
model_path='./models/trVAE/haber/',
alpha=0.0001,
beta=50,
eta=100,
loss_fn='sse',
output_activation='linear')
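To show how alpha, beta and eta trade off against each other, here is a minimal numpy sketch of a weighted objective of the form eta * reconstruction + alpha * KL + beta * MMD, using SSE to match loss_fn='sse'. It is an illustration under stated assumptions, not the package's code: the Gaussian-kernel MMD with a fixed bandwidth, summing MMD over all pairs of conditions, and the array h (standing in for the network activations the MMD term is computed on) are all assumptions of this sketch.

```python
import numpy as np


def sse(x, x_hat):
    # Sum of squared errors over genes, averaged over cells.
    return float(np.mean(np.sum((x - x_hat) ** 2, axis=1)))


def kl_diag_gaussian(mu, log_var):
    # KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dims, averaged over cells.
    return float(np.mean(-0.5 * np.sum(1 + log_var - mu ** 2 - np.exp(log_var), axis=1)))


def rbf_mmd(a, b, sigma=1.0):
    # Biased estimate of the squared MMD between samples a and b (Gaussian kernel).
    def k(u, v):
        d2 = np.sum((u[:, None, :] - v[None, :, :]) ** 2, axis=-1)
        return np.exp(-d2 / (2 * sigma ** 2))
    return float(k(a, a).mean() + k(b, b).mean() - 2 * k(a, b).mean())


def trvae_style_loss(x, x_hat, mu, log_var, h, cond, alpha=1e-4, beta=50, eta=100):
    """eta * SSE + alpha * KL + beta * (MMD summed over pairs of conditions)."""
    x, x_hat, h, cond = map(np.asarray, (x, x_hat, h, cond))
    groups = np.unique(cond)
    mmd = 0.0
    for i in range(len(groups)):
        for j in range(i + 1, len(groups)):
            mmd += rbf_mmd(h[cond == groups[i]], h[cond == groups[j]])
    return eta * sse(x, x_hat) + alpha * kl_diag_gaussian(mu, log_var) + beta * mmd
```

With the values used above (alpha=0.0001, beta=50, eta=100), the KL term is weighted very lightly compared with reconstruction and condition mixing.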
You can train trVAE with the train function, passing the following parameters:
network.train(adata,
condition_key,
train_size=0.8,
n_epochs=50,
batch_size=512,
early_stop_limit=10,
lr_reducer=20,
verbose=5,
save=True,
)
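The early_stop_limit and lr_reducer arguments read like patience values counted in epochs without validation improvement. The sketch below shows that common pattern in plain Python; patience_schedule, its halving factor and starting learning rate are hypothetical, and treating both arguments as patience counts is an assumption about their meaning, not the package's documented behavior.

```python
def patience_schedule(val_losses, early_stop_limit=10, lr_reducer=20, lr=1e-3, factor=0.5):
    """Hypothetical sketch of patience-based early stopping and learning-rate reduction.

    Walks through per-epoch validation losses and returns
    (number of epochs actually run, final learning rate).
    """
    best = float("inf")
    since_best = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, since_best = loss, 0
        else:
            since_best += 1
        # Reduce the learning rate every `lr_reducer` epochs without improvement.
        if since_best > 0 and since_best % lr_reducer == 0:
            lr *= factor
        # Stop once `early_stop_limit` epochs have passed without improvement.
        if since_best >= early_stop_limit:
            return epoch, lr
    return len(val_losses), lr


# Loss stops improving after epoch 3, so training halts 10 epochs later.
epochs_run, final_lr = patience_schedule([5, 4, 3] + [3] * 10)
```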
If you use trVAE for batch removal, we recommend using the latent space z computed with the get_latent function, which takes the data and the condition key:
latent_adata = network.get_latent(adata, condition_key)
latent_adata
sc.pp.neighbors(latent_adata)
sc.tl.umap(latent_adata)
sc.pl.umap(latent_adata, color=[condition_key, cell_type_key], wspace=0.5, frameon=False)
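The UMAP gives a visual check of whether the conditions mix in the latent space. As a rough numeric complement (not part of trVAE or this tutorial), one can compute, for each cell, the fraction of its k nearest latent neighbours that come from a different condition. knn_condition_mixing is a hypothetical helper; it builds a full distance matrix, so it is only meant for small arrays such as the toy example below.

```python
import numpy as np


def knn_condition_mixing(latent, cond, k=5):
    """Mean fraction of each cell's k nearest neighbours (Euclidean distance in
    latent space) whose condition differs from the cell's own. Higher = more mixed."""
    latent = np.asarray(latent, dtype=float)
    cond = np.asarray(cond)
    d2 = np.sum((latent[:, None, :] - latent[None, :, :]) ** 2, axis=-1)
    np.fill_diagonal(d2, np.inf)        # a cell is not its own neighbour
    nn = np.argsort(d2, axis=1)[:, :k]  # indices of the k nearest neighbours
    return float(np.mean(cond[nn] != cond[:, None]))


# Two well-separated clusters in a 2-D "latent space".
toy_latent = [[0, 0], [0, 0.1], [0, 0.2], [10, 10], [10, 10.1], [10, 10.2]]
separated = knn_condition_mixing(toy_latent, ["a", "a", "a", "b", "b", "b"], k=2)
mixed = knn_condition_mixing(toy_latent, ["a", "b", "a", "b", "a", "b"], k=2)
```

When each cluster contains a single condition the score is 0.0; interleaving the conditions within each cluster raises it.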