import os
os.chdir('../')
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import gdown
# Configure scanpy/matplotlib plotting in ONE call. scanpy's set_figure_params
# resets every argument that is not passed back to its default (e.g. dpi=80,
# frameon=True), so spreading the settings over several calls silently
# discarded the earlier dpi/frameon values and only the last call's figsize
# took effect.
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
# Compact tensor printing: 3 decimals, no scientific notation, 7 edge items.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)
Here we use the CelSeq2 and SS2 studies as query data and the other 3 studies as the reference atlas. We strongly suggest using early stopping to avoid overfitting. The best early-stopping criteria are 'elbo' for SCVI pretraining and for unlabelled surgery training, and 'accuracy' for semi-supervised SCANVI training.
# ---- Experiment configuration ---------------------------------------------
# obs columns holding the study (batch) label and the cell-type annotation.
condition_key = 'study'
cell_type_key = 'cell_type'
# Studies held out as query data; every other study forms the reference.
target_conditions = ['Pancreas CelSeq2', 'Pancreas SS2']

# Upper bounds on epochs per training phase (early stopping may end sooner).
vae_epochs, scanvi_epochs, surgery_epochs = 500, 200, 500


def _early_stopping_config(metric, threshold, on=None):
    """Return early-stopping kwargs that monitor and checkpoint on ``metric``.

    All configurations share patience=10 and a reduce-LR-on-plateau schedule
    (lr_patience=8, lr_factor=0.1). The ``on`` entry is only included when
    given, and keys appear in the same order in every configuration.
    """
    config = {"early_stopping_metric": metric, "save_best_state_metric": metric}
    if on is not None:
        config["on"] = on
    config.update(
        patience=10,
        threshold=threshold,
        reduce_lr_on_plateau=True,
        lr_patience=8,
        lr_factor=0.1,
    )
    return config


# SCVI pretraining: ELBO criterion with zero tolerance.
early_stopping_kwargs = _early_stopping_config("elbo", 0)
# Semi-supervised SCANVI training: classifier accuracy on the full dataset.
early_stopping_kwargs_scanvi = _early_stopping_config("accuracy", 0.001, on="full_dataset")
# Unsupervised surgery: ELBO on the full dataset.
early_stopping_kwargs_surgery = _early_stopping_config("elbo", 0.001, on="full_dataset")
# Preprocessed pancreas dataset (~126 MB) hosted on Google Drive.
url = 'https://drive.google.com/uc?id=1ehxgfHTsMZXy6YzlFKGJOsBKQ5rrvMnd'
output = 'pancreas.h5ad'
# Skip the download when a previous run already saved the file, so re-running
# the notebook does not fetch the dataset again. Delete the file manually if
# it is incomplete or outdated to force a fresh download.
if not os.path.exists(output):
    gdown.download(url, output, quiet=False)
Downloading... From: https://drive.google.com/uc?id=1ehxgfHTsMZXy6YzlFKGJOsBKQ5rrvMnd To: C:\Users\sergei.rybakov\projects\notebooks\pancreas.h5ad 126MB [00:29, 4.31MB/s]
'pancreas.h5ad'
# Load the dataset and continue with the matrix stored in `.raw`, passed
# through remove_sparsity (presumably densifying X — confirm in scarches).
adata_all = sc.read('pancreas.h5ad')
adata = remove_sparsity(adata_all.raw.to_adata())

# Split cells by study: the held-out query studies vs. the reference atlas.
is_query = adata.obs[condition_key].isin(target_conditions)
source_adata = adata[~is_query].copy()
target_adata = adata[is_query].copy()
# Display a summary of the reference (source) AnnData in the notebook.
source_adata
AnnData object with n_obs × n_vars = 10294 × 1000 obs: 'batch', 'study', 'cell_type', 'size_factors'
# Display a summary of the query (target) AnnData in the notebook.
target_adata
AnnData object with n_obs × n_vars = 5387 × 1000 obs: 'batch', 'study', 'cell_type', 'size_factors'
# Register the reference data: batches come from the study column and labels
# from the cell-type column.
sca.dataset.setup_anndata(
    source_adata,
    batch_key=condition_key,
    labels_key=cell_type_key,
)
INFO Using batches from adata.obs["study"] INFO Using labels from adata.obs["cell_type"] INFO Using data from adata.X INFO Computing library size prior per batch INFO Successfully registered anndata object containing 10294 cells, 1000 vars, 3 batches, 8 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra continuous covariates. INFO Please do not further modify adata until model is trained.
The parameters chosen here proved to work best for surgery with SCANVI.
# Reference SCANVI model; "Unknown" is the label used for unlabelled cells.
# Layer norm instead of batch norm and the covariate-injection flags below
# are the settings chosen for scArches surgery.
vae = sca.models.SCANVI(
    source_adata,
    "Unknown",
    n_layers=2,
    encode_covariates=True,
    deeply_inject_covariates=False,
    use_layer_norm="both",
    use_batch_norm="none",
)

# Report how many reference cells SCANVI treats as labelled vs. unlabelled.
for caption, indices in (
    ("Labelled Indices: ", vae._labeled_indices),
    ("Unlabelled Indices: ", vae._unlabeled_indices),
):
    print(caption, len(indices))
Labelled Indices: 10294 Unlabelled Indices: 0
# Two-phase reference training: unsupervised pretraining with ELBO-based early
# stopping, then semi-supervised training monitored on ELBO and accuracy with
# accuracy-based early stopping.
vae.train(
    n_epochs_unsupervised=vae_epochs,
    n_epochs_semisupervised=scanvi_epochs,
    unsupervised_trainer_kwargs={"early_stopping_kwargs": early_stopping_kwargs},
    semisupervised_trainer_kwargs={
        "metrics_to_monitor": ["elbo", "accuracy"],
        "early_stopping_kwargs": early_stopping_kwargs_scanvi,
    },
    frequency=1,
)
INFO Training Unsupervised Trainer for 500 epochs. INFO Training SemiSupervised Trainer for 200 epochs. INFO KL warmup for 400 epochs Training...: 20%|█████████████▍ | 99/500 [04:03<18:22, 2.75s/it]INFO Reducing LR on epoch 99. Training...: 25%|████████████████▊ | 125/500 [05:15<17:12, 2.75s/it]INFO Reducing LR on epoch 125. Training...: 25%|█████████████████ | 127/500 [05:20<17:06, 2.75s/it]INFO Stopping early: no improvement of more than 0 nats in 10 epochs INFO If the early stopping criterion is too strong, please instantiate it with different parameters in the train method. Training...: 25%|█████████████████ | 127/500 [05:23<15:50, 2.55s/it] INFO Training is still in warming up phase. If your applications rely on the posterior quality, consider training for more epochs or reducing the kl warmup. INFO Training time: 214 s. / 500 epochs INFO KL warmup phase exceeds overall training phaseIf your applications rely on the posterior quality, consider training for more epochs or reducing the kl warmup. INFO KL warmup for 400 epochs Training...: 19%|████████████▉ | 38/200 [05:51<25:02, 9.28s/it]INFO Reducing LR on epoch 38. Training...: 20%|█████████████▌ | 40/200 [06:10<24:43, 9.27s/it]INFO Stopping early: no improvement of more than 0.001 nats in 10 epochs INFO If the early stopping criterion is too strong, please instantiate it with different parameters in the train method. Training...: 20%|█████████████▌ | 40/200 [06:19<25:18, 9.49s/it] INFO Training is still in warming up phase. If your applications rely on the posterior quality, consider training for more epochs or reducing the kl warmup. INFO Training time: 228 s. / 200 epochs
# Embed the reference cells in the learned latent space and attach labels.
reference_latent = sc.AnnData(vae.get_latent_representation())
# .tolist() discards the source index, so values are assigned by position.
for obs_column, source_key in (("cell_type", cell_type_key), ("batch", condition_key)):
    reference_latent.obs[obs_column] = source_adata.obs[source_key].tolist()

# Neighbour graph (8 neighbours), Leiden clustering and UMAP, then plot.
sc.pp.neighbors(reference_latent, n_neighbors=8)
sc.tl.leiden(reference_latent)
sc.tl.umap(reference_latent)
sc.pl.umap(
    reference_latent,
    color=['batch', 'cell_type'],
    frameon=False,
    wspace=0.6,
)
... storing 'cell_type' as categorical ... storing 'batch' as categorical
One can also compute the accuracy of the learned classifier on the reference data:
# Classifier accuracy on the reference cells: fraction of predictions that
# match the annotated cell type.
reference_latent.obs['predictions'] = vae.predict()
is_correct = reference_latent.obs['predictions'] == reference_latent.obs['cell_type']
print(f"Acc: {np.mean(is_correct)}")
Acc: 0.9619195647950263
After pretraining, the model can be saved for later use.
# Persist the pretrained reference model; overwrite=True allows saving over
# an existing model directory.
ref_path = 'ref_model/'
vae.save(
    ref_path,
    overwrite=True,
)
If the cell types in 'target_adata' are the same as, or a subset of, the reference cell types, one can pass the adata without further preprocessing. In that case, semi-supervised training with scArches is also possible.
However, if 'target_adata' contains new cell types, or if its '.obs' has no cell-type label column (e.g. the data is unlabeled), scANVI can only be used in an unsupervised manner during surgery. This is because the classifier only knows the reference label set, and unlabeled data provides no labels to train it on.
In addition, 'target_adata' has to be preprocessed as follows:
If it contains new cell types, save the original labels in another column and replace all labels with the unlabeled category:
# Only needed when the query contains cell types that are absent from the
# reference. Guarded by a flag because running it unconditionally overwrites
# the query's true labels, which the accuracy and confusion-matrix evaluation
# further down compares the predictions against.
query_has_new_cell_types = False
if query_has_new_cell_types:
    # Keep the original annotation in a separate column, then mark every query
    # cell with the model's unlabeled category.
    target_adata.obs['orig_cell_types'] = target_adata.obs[cell_type_key].copy()
    target_adata.obs[cell_type_key] = vae.unlabeled_category_
If there is no cell-type column in '.obs', create one that contains only the unlabeled category:
# Only needed when target_adata.obs has no cell-type column at all (an
# unlabelled query). Guarded by a flag because this query IS labelled, and
# overwriting the column would destroy the labels used for evaluation below.
query_is_unlabeled = False
if query_is_unlabeled:
    target_adata.obs[cell_type_key] = vae.unlabeled_category_
If 'target_adata' is in the right format, one can proceed with the surgery pipeline. Here we perform unsupervised surgery. Because the query and reference cell types overlap, one could also perform supervised or semi-supervised surgery by setting the labelled/unlabelled indices accordingly.
# Build the query model from the saved reference model (freeze_dropout=True).
model = sca.models.SCANVI.load_query_data(
    target_adata,
    ref_path,
    freeze_dropout=True,
)

# Treat every query cell as unlabelled so surgery runs unsupervised, even
# though target_adata carries cell-type annotations.
model._unlabeled_indices = np.arange(target_adata.n_obs)
model._labeled_indices = []

for caption, indices in (
    ("Labelled Indices: ", model._labeled_indices),
    ("Unlabelled Indices: ", model._unlabeled_indices),
):
    print(caption, len(indices))
INFO Using data from adata.X INFO Computing library size prior per batch INFO Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels'] INFO Successfully registered anndata object containing 5387 cells, 1000 vars, 5 batches, 8 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra continuous covariates. WARNING Make sure the registered X field in anndata contains unnormalized count data. Labelled Indices: 0 Unlabelled Indices: 5387
# Surgery training on the query: train_base_model=False (presumably the base
# model is not retrained — verify against the scvi-tools version in use),
# no weight decay, ELBO-based early stopping.
model.train(
    n_epochs_semisupervised=surgery_epochs,
    train_base_model=False,
    semisupervised_trainer_kwargs={
        "metrics_to_monitor": ["accuracy", "elbo"],
        "weight_decay": 0,
        "early_stopping_kwargs": early_stopping_kwargs_surgery,
    },
    frequency=1,
)
INFO Training Unsupervised Trainer for 400 epochs. INFO Training SemiSupervised Trainer for 500 epochs. INFO KL warmup for 400 epochs Training...: 22%|██████████████▊ | 111/500 [07:08<25:03, 3.87s/it]INFO Reducing LR on epoch 111. Training...: 23%|███████████████▏ | 113/500 [07:16<24:55, 3.86s/it]INFO Stopping early: no improvement of more than 0.001 nats in 10 epochs INFO If the early stopping criterion is too strong, please instantiate it with different parameters in the train method. Training...: 23%|███████████████▏ | 113/500 [07:20<25:08, 3.90s/it] INFO Training is still in warming up phase. If your applications rely on the posterior quality, consider training for more epochs or reducing the kl warmup. INFO Training time: 217 s. / 500 epochs
# Latent representation of the query cells, annotated with label and study.
query_latent = sc.AnnData(model.get_latent_representation())
# .tolist() discards the source index, so values are assigned by position.
for obs_column, source_key in (("cell_type", cell_type_key), ("batch", condition_key)):
    query_latent.obs[obs_column] = target_adata.obs[source_key].tolist()
WARNING Make sure the registered X field in anndata contains unnormalized count data.
# Neighbour graph, Leiden clustering and UMAP on the query latent space.
sc.pp.neighbors(query_latent)
sc.tl.leiden(query_latent)
sc.tl.umap(query_latent)
# No plt.figure() beforehand: sc.pl.umap creates its own figure, so an
# explicit plt.figure() only leaves an empty, unused figure behind.
sc.pl.umap(
    query_latent,
    color=["batch", "cell_type"],
    frameon=False,
    wspace=0.6,
)
... storing 'cell_type' as categorical ... storing 'batch' as categorical
<Figure size 320x320 with 0 Axes>
# Persist the surgery (query) model.
surgery_path = 'surgery_model'
model.save(surgery_path, overwrite=True)

# Accuracy of the surgery model's predictions against the query annotations.
query_latent.obs['predictions'] = model.predict()
query_hits = query_latent.obs['predictions'] == query_latent.obs['cell_type']
print(f"Acc: {np.mean(query_hits)}")
WARNING Make sure the registered X field in anndata contains unnormalized count data. Acc: 0.8791535177278633
# Confusion matrix of observed (rows) vs. predicted (columns) query labels.
# Each column is divided by its total, so a column shows how the cells
# assigned to one predicted label are spread over the observed labels.
df = (
    query_latent.obs
    .groupby(["cell_type", "predictions"])
    .size()
    .unstack(fill_value=0)
)
norm_df = df.div(df.sum(axis=0), axis="columns")

plt.figure(figsize=(8, 8))
_ = plt.pcolor(norm_df)
# pcolor cell i spans [i, i + 1]; ticks at i + 0.5 sit in the middle of it.
column_ticks = np.arange(0.5, len(df.columns), 1)
row_ticks = np.arange(0.5, len(df.index), 1)
_ = plt.xticks(column_ticks, df.columns, rotation=90)
_ = plt.yticks(row_ticks, df.index)
plt.xlabel("Predicted")
plt.ylabel("Observed")
Text(0, 0.5, 'Observed')
# Embed reference and query cells jointly with the surgery model.
adata_full = source_adata.concatenate(target_adata)
full_latent = sc.AnnData(model.get_latent_representation(adata=adata_full))
# .tolist() discards the source index, so values are assigned by position.
for obs_column, source_key in (("cell_type", cell_type_key), ("batch", condition_key)):
    full_latent.obs[obs_column] = adata_full.obs[source_key].tolist()
INFO Input adata not setup with scvi. attempting to transfer anndata setup INFO Using data from adata.X INFO Computing library size prior per batch INFO Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels'] INFO Successfully registered anndata object containing 15681 cells, 1000 vars, 5 batches, 8 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra continuous covariates.
# Neighbour graph, Leiden clustering and UMAP on the joint latent space.
sc.pp.neighbors(full_latent)
sc.tl.leiden(full_latent)
sc.tl.umap(full_latent)
# No plt.figure() beforehand: sc.pl.umap creates its own figure, so an
# explicit plt.figure() only leaves an empty, unused figure behind.
sc.pl.umap(
    full_latent,
    color=["batch", "cell_type"],
    frameon=False,
    wspace=0.6,
)
... storing 'cell_type' as categorical ... storing 'batch' as categorical
<Figure size 320x320 with 0 Axes>
# Classifier accuracy over all cells (reference + query).
full_latent.obs['predictions'] = model.predict(adata=adata_full)
full_hits = full_latent.obs['predictions'] == full_latent.obs['cell_type']
print(f"Acc: {np.mean(full_hits)}")
Acc: 0.933486384796888
# Compare predicted and annotated cell types on the joint UMAP. The neighbour
# graph, Leiden clusters and UMAP of full_latent were already computed in the
# earlier sc.pp.neighbors/sc.tl.leiden/sc.tl.umap calls on full_latent. Its
# latent matrix has not changed since then (only the 'predictions' obs column
# was added), so the embedding is reused instead of being recomputed.
# sc.pl.umap creates its own figure, so no plt.figure() is needed.
sc.pl.umap(
    full_latent,
    color=["predictions", "cell_type"],
    frameon=False,
    wspace=0.6,
)
... storing 'predictions' as categorical
<Figure size 320x320 with 0 Axes>