Running trVAE on the Haber et al. dataset for batch removal and style transfer

In [2]:
import scanpy as sc
import pandas as pd
import numpy as np
import trvae
Using TensorFlow backend.
In [3]:
sc.set_figure_params(dpi=100)

Loading & preparing data

condition_key is the key for your batch or condition labels in adata.obs.

In [27]:
condition_key = "condition"
In [5]:
adata = sc.read("data/haber_count.h5ad")
adata
Out[5]:
AnnData object with n_obs × n_vars = 9842 × 15215
    obs: 'batch', 'barcode', 'condition', 'cell_label'
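
Before preprocessing, it can help to inspect the batch labels; a quick check with pandas:

In [ ]:
# Inspect how many cells each batch (condition) contains
adata.obs[condition_key].value_counts()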

Normalizing & Extracting Top 2000 Highly Variable Genes

One can use more genes, but in order to train the network quickly we will extract the top 2000 genes. This can be done with the normalize_hvg function in the tl module of the trVAE package. The function accepts the following arguments:

  • adata: an AnnData object containing raw counts in its .X attribute.
  • target_sum: total counts per cell after normalization.
  • size_factors: whether to normalize the adata and put total counts per cell in the "size_factors" column of adata.obs (True is recommended).
  • scale_input: whether to scale the dataset after normalization (False is recommended).
  • logtrans_input: whether to log-transform the adata after normalization (True is recommended).
  • n_top_genes: number of highly variable genes to be selected after normalization.
In [6]:
adata = trvae.tl.normalize_hvg(adata, 
                               target_sum=1e4,
                               size_factors=True, 
                               scale_input=False, 
                               logtrans_input=True, 
                               n_top_genes=2000)
adata
Out[6]:
AnnData object with n_obs × n_vars = 9842 × 2000
    obs: 'batch', 'barcode', 'condition', 'cell_label', 'size_factors'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'log1p'
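
Under the hood, normalize_hvg roughly corresponds to standard scanpy preprocessing. The following is a sketch of an approximate equivalent, not trVAE's exact implementation; in particular, the way size_factors is computed here is an assumption:

In [ ]:
# A rough scanpy-only approximation of normalize_hvg (sketch, not trVAE's exact code)
adata_raw = sc.read("data/haber_count.h5ad")
adata_raw.obs['size_factors'] = np.asarray(adata_raw.X.sum(axis=1)).ravel()  # total counts per cell (assumed definition)
sc.pp.normalize_total(adata_raw, target_sum=1e4)                             # library-size normalization
sc.pp.log1p(adata_raw)                                                       # log-transform
sc.pp.highly_variable_genes(adata_raw, n_top_genes=2000, subset=True)        # keep the top 2000 HVGs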

Original Data UMAP Visualization

In [7]:
sc.pp.neighbors(adata)
sc.tl.umap(adata)
WARNING: You’re trying to run this on 2000 dimensions of `.X`, if you really want this, set `use_rep='X'`.
         Falling back to preprocessing with `sc.pp.pca` and default params.
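
The warning is harmless: scanpy simply falls back to PCA before building the neighbor graph. To make that step explicit (and silence the warning), one could run the following instead:

In [ ]:
# Run PCA explicitly and build the neighbor graph on the PCA representation
sc.pp.pca(adata, n_comps=50)
sc.pp.neighbors(adata, use_rep='X_pca')
sc.tl.umap(adata)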

This is just for visualization; no cell type label is required by the model.

In [30]:
cell_type_key = "cell_label"
In [8]:
sc.pl.umap(adata, color=[condition_key, cell_type_key], 
           wspace=0.6, 
           frameon=False)

Calculating the number of batches

In [7]:
conditions = adata.obs[condition_key].unique().tolist()
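
The number of batches is then just the length of this list:

In [ ]:
# List the unique batch labels and count them
print(conditions)
print(f"Number of batches: {len(conditions)}")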

Create the network

Some of the network's parameters:

  • x_dimension: size of the input features (required)
  • conditions: list of unique batch (study) names
  • architecture: architecture of the network (optional)
  • output_activation: activation function of trVAE's last layer; can be one of relu, leaky_relu, linear, ...
  • alpha: coefficient of the KL divergence loss (optional)
  • beta: coefficient of the MMD loss (optional)
  • eta: coefficient of the reconstruction (MSE or SSE) loss (optional)
  • gene_names: list of gene names (adata.var_names.tolist())
  • loss_fn: trVAE's loss function (has to be one of mse or sse)
In [10]:
network = trvae.models.trVAE(x_dimension=adata.shape[1],
                             architecture=[256,64],
                             z_dimension=10,
                             gene_names=adata.var_names.tolist(),
                             conditions=conditions,
                             model_path='./models/trVAE/haber/',
                             alpha=0.0001,
                             beta=50,
                             eta=100,
                             loss_fn='sse',
                             output_activation='linear')
trVAE's network has been successfully constructed!
trVAE's network has been successfully compiled!

Training trVAE

You can train trVAE with the train function, which has the following parameters:

  • adata: Annotated dataset used for training and evaluating trVAE.
  • condition_key: name of the column in adata.obs which contains the batch id for each sample.
  • n_epochs: number of epochs used to train trVAE.
  • batch_size: number of samples in each mini-batch used to optimize trVAE. Please NOTE that for MSE loss with MMD regularization, batch sizes greater than 512 are highly recommended.
  • save: whether to save trVAE's model and configs after the training phase.
  • retrain: if False and a pretrained trVAE model exists in model_path, its weights will be restored; otherwise trVAE will be trained and validated on adata.
In [11]:
network.train(adata,
              condition_key,
              train_size=0.8,
              n_epochs=50,
              batch_size=512,
              early_stop_limit=10,
              lr_reducer=20,
              verbose=5,
              save=True,
              )
 |████████████████████| 100.0%  - loss: 251.2036 - mmd_loss: 0.4532 - recon_loss: 250.7504 - val_loss: 8249.9760 - val_mmd_loss: 70.4720 - val_recon_loss: 8179.50403

trVAE has been successfully saved in ./models/trVAE/haber/.
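
Because save=True, a later session can restore the trained weights instead of retraining, using the retrain flag documented above. A minimal sketch:

In [ ]:
# Restore previously saved weights from model_path instead of retraining
# (retrain=False loads the pretrained model if one exists there)
network.train(adata,
              condition_key,
              train_size=0.8,
              n_epochs=50,
              batch_size=512,
              retrain=False,
              save=False)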

Getting corrected latent adata

If you use trVAE for batch removal, we recommend using the z latent space computed with the get_latent function. This function has the following parameters:

  • adata: Annotated dataset to be transformed to the latent space.
  • batch_key: name of the column in adata.obs which contains the batch (study) label for each sample.
In [12]:
latent_adata = network.get_latent(adata, condition_key)
latent_adata
Out[12]:
AnnData object with n_obs × n_vars = 9842 × 10
    obs: 'batch', 'barcode', 'condition', 'cell_label', 'size_factors'
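
The corrected latent representation is a regular AnnData object, so it can be written to disk with standard AnnData I/O (the output path below is just an example):

In [ ]:
# Save the batch-corrected latent representation for downstream analysis
latent_adata.write("data/haber_trvae_latent.h5ad")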

UMAP visualization of latent space

In [13]:
sc.pp.neighbors(latent_adata)
sc.tl.umap(latent_adata)
In [14]:
sc.pl.umap(latent_adata, color=[condition_key, cell_type_key], wspace=0.5, frameon=False)