Source code for sctour.train

import torch
from torch.utils.data import DataLoader
from torchdiffeq import odeint
from typing import Optional, Union
from typing_extensions import Literal
import numpy as np
from anndata import AnnData
from scipy import sparse
from scipy.sparse import spmatrix
from tqdm import tqdm
import os
from collections import defaultdict

from .model import TNODE
from ._utils import get_step_size
from .data import split_data, MakeDataset, BatchSampler
from . import logger


##reverse time
def reverse_time(
    T: np.ndarray,
) -> np.ndarray:
    """
    Post-inference adjustment to reverse the pseudotime.

    The result is computed as ``1 - T``, so the input is left unmodified.

    Parameters
    ----------
    T
        The pseudotime inferred for each cell.

    Returns
    ----------
    :class:`~numpy.ndarray`
        The reversed pseudotime.
    """

    return 1 - T
class Trainer:
    """
    Class for implementing the scTour training process.

    Parameters
    ----------
    adata
        An :class:`~anndata.AnnData` object for the training data.
    percent
        The percentage of cells used for model training. Default to 0.2 when the cell number > 10,000 and to 0.9 otherwise.
    n_latent
        The dimensionality of the latent space.
        (Default: 5)
    n_ode_hidden
        The dimensionality of the hidden layer for the latent ODE function.
        (Default: 25)
    n_vae_hidden
        The dimensionality of the hidden layer for the VAE.
        (Default: 128)
    batch_norm
        Whether to include a `BatchNorm` layer.
        (Default: `False`)
    ode_method
        The solver for ODE. List of ODE solvers can be found in `torchdiffeq`.
        (Default: `'euler'`)
    step_size
        The step size during integration.
    alpha_recon_lec
        The scaling factor for the reconstruction error from encoder-derived latent space.
        (Default: 0.5)
    alpha_recon_lode
        The scaling factor for the reconstruction error from ODE-solver-derived latent space.
        (Default: 0.5)
    alpha_kl
        The scaling factor for the KL divergence in the loss function.
        (Default: 1.0)
    loss_mode
        The mode for calculating the reconstruction error.
        (Default: `'nb'`)
        Three modes are included:
        ``'mse'``: mean squared error;
        ``'nb'``: negative binomial conditioned likelihood;
        ``'zinb'``: zero-inflated negative binomial conditioned likelihood.
    nepoch
        Number of epochs.
    batch_size
        The batch size during training.
        (Default: 1024)
    drop_last
        Whether or not drop the last batch when its size is smaller than `batch_size`.
        (Default: `False`)
    lr
        The learning rate.
        (Default: 1e-3)
    wt_decay
        The weight decay (L2 penalty) for Adam optimizer.
        (Default: 1e-6)
    eps
        The `eps` parameter for Adam optimizer.
        (Default: 0.01)
    random_state
        The seed for generating random numbers.
        (Default: 0)
    val_frac
        The percentage of data used for validation.
        (Default: 0.1)
    use_gpu
        Whether to use GPU when available.
        (Default: `True`)
    """

    def __init__(
        self,
        adata: AnnData,
        percent: Optional[float] = None,
        n_latent: int = 5,
        n_ode_hidden: int = 25,
        n_vae_hidden: int = 128,
        batch_norm: bool = False,
        ode_method: str = 'euler',
        step_size: Optional[int] = None,
        alpha_recon_lec: float = 0.5,
        alpha_recon_lode: float = 0.5,
        alpha_kl: float = 1.,
        loss_mode: Literal['mse', 'nb', 'zinb'] = 'nb',
        nepoch: Optional[int] = None,
        batch_size: int = 1024,
        drop_last: bool = False,
        lr: float = 1e-3,
        wt_decay: float = 1e-6,
        eps: float = 0.01,
        random_state: int = 0,
        val_frac: float = 0.1,
        use_gpu: bool = True,
    ):
        # ---- validate the loss configuration --------------------------------
        self.loss_mode = loss_mode
        if self.loss_mode not in ['mse', 'nb', 'zinb']:
            raise ValueError(
                f"`loss_mode` must be one of ['mse', 'nb', 'zinb'], but input was '{self.loss_mode}'."
            )

        if (alpha_recon_lec < 0) or (alpha_recon_lec > 1):
            raise ValueError(
                '`alpha_recon_lec` must be between 0 and 1.'
            )
        if (alpha_recon_lode < 0) or (alpha_recon_lode > 1):
            raise ValueError(
                '`alpha_recon_lode` must be between 0 and 1.'
            )
        # Compare with a tolerance: exact float equality can reject weights
        # whose sum differs from 1 only by rounding error.
        if not np.isclose(alpha_recon_lec + alpha_recon_lode, 1):
            raise ValueError(
                'The sum of `alpha_recon_lec` and `alpha_recon_lode` must be 1.'
            )

        # ---- validate the input data ----------------------------------------
        self.adata = adata
        if 'n_genes_by_counts' not in self.adata.obs:
            raise KeyError(
                "`n_genes_by_counts` not found in `.obs` of the AnnData. Please run `scanpy.pp.calculate_qc_metrics` first to calculate the number of genes detected in each cell."
            )
        if loss_mode == 'mse':
            if (self.adata.X.min() < 0) or (self.adata.X.max() > np.log1p(1e6)):
                raise ValueError(
                    "Invalid expression matrix in `.X`. `mse` mode expects log1p(normalized expression) in `.X` of the AnnData."
                )
        else:
            # For sparse input only the explicitly stored values need checking.
            X = self.adata.X.data if sparse.issparse(self.adata.X) else self.adata.X
            # `X.size` guard: a sparse matrix without stored entries is all
            # zeros (valid counts) and `.min()` would fail on an empty array.
            if X.size and ((X.min() < 0) or np.any(np.mod(X, 1) != 0)):
                raise ValueError(
                    f"Invalid expression matrix in `.X`. `{self.loss_mode}` mode expects raw UMI counts in `.X` of the AnnData."
                )

        # ---- training set-up ------------------------------------------------
        self.n_cells = adata.n_obs
        self.batch_size = batch_size
        self.drop_last = drop_last

        self.percent = percent
        if self.percent is None:
            self.percent = .2 if self.n_cells > 10000 else .9
        elif (self.percent < 0) or (self.percent > 1):
            raise ValueError(
                "`percent` must be between 0 and 1."
            )

        self.val_frac = val_frac
        if (self.val_frac < 0) or (self.val_frac > 1):
            raise ValueError(
                '`val_frac` must be between 0 and 1.'
            )

        if nepoch is None:
            # Scale the number of epochs inversely with the number of cells
            # used for training, capped at 400.
            ncells = round(self.n_cells * self.percent)
            self.nepoch = min(round((10000 / ncells) * 400), 400)
        else:
            self.nepoch = nepoch

        self.lr = lr
        self.wt_decay = wt_decay
        self.eps = eps
        # Unknown until `get_time()` decides the direction of the pseudotime.
        self.time_reverse = None

        # ---- reproducibility and device -------------------------------------
        self.random_state = random_state
        np.random.seed(random_state)
        torch.manual_seed(random_state)

        self.use_gpu = use_gpu
        gpu = torch.cuda.is_available() and use_gpu
        if gpu:
            torch.cuda.manual_seed(random_state)
            self.device = torch.device('cuda')
            logger.info('Running using GPU.')
        else:
            self.device = torch.device('cpu')
            logger.info('Running using CPU.')

        # ---- model ----------------------------------------------------------
        self.n_int = adata.n_vars
        # Kept as an attribute so `save_model()` can store the configuration.
        self.model_kwargs = dict(
            device = self.device,
            n_int = self.n_int,
            n_latent = n_latent,
            n_ode_hidden = n_ode_hidden,
            n_vae_hidden = n_vae_hidden,
            batch_norm = batch_norm,
            ode_method = ode_method,
            step_size = step_size,
            alpha_recon_lec = alpha_recon_lec,
            alpha_recon_lode = alpha_recon_lode,
            alpha_kl = alpha_kl,
            loss_mode = loss_mode,
        )
        self.model = TNODE(**self.model_kwargs)
        # Per-epoch losses, filled in by `train()`.
        self.log = defaultdict(list)

    def _get_data_loaders(self) -> None:
        """
        Generate Data Loaders for training and validation datasets.

        The loaders are stored in `self.train_dl` and `self.val_dl`.
        """

        train_data, val_data = split_data(self.adata, self.percent, self.val_frac)
        self.train_dataset = MakeDataset(train_data, self.loss_mode)
        self.val_dataset = MakeDataset(val_data, self.loss_mode)

        # `drop_last` is forwarded so that the documented option takes effect.
        self.train_dl = DataLoader(
            self.train_dataset,
            batch_size = self.batch_size,
            shuffle = True,
            drop_last = self.drop_last,
        )
        self.val_dl = DataLoader(self.val_dataset, batch_size = self.batch_size)

    def train(self):
        """
        Model training.

        Builds the data loaders and the Adam optimizer, then runs `self.nepoch`
        epochs. The training and validation losses of every epoch are appended
        to `self.log['train_loss']` and `self.log['validation_loss']`.
        """

        self._get_data_loaders()

        params = filter(lambda p: p.requires_grad, self.model.parameters())
        self.optimizer = torch.optim.Adam(params, lr = self.lr, weight_decay = self.wt_decay, eps = self.eps)

        with tqdm(total=self.nepoch, unit='epoch') as t:
            for tepoch in range(t.total):
                train_loss = self._on_epoch_train(self.train_dl)
                val_loss = self._on_epoch_val(self.val_dl)
                self.log['train_loss'].append(train_loss)
                self.log['validation_loss'].append(val_loss)
                t.set_description(f"Epoch {tepoch + 1}")
                t.set_postfix({'train_loss': train_loss, 'val_loss': val_loss}, refresh=False)
                t.update()

    def _on_epoch_train(self, DL) -> float:
        """
        Go through the model and update the model parameters.

        Parameters
        ----------
        DL
            DataLoader for training dataset.

        Returns
        ----------
        float
            Training loss for the current epoch.
        """

        self.model.train()
        total_loss = .0
        ss = 0
        for X, Y in DL:
            self.optimizer.zero_grad()
            X = X.to(self.device)
            Y = Y.to(self.device)
            # The model returns several loss components; only the first
            # (total) loss is needed here.
            loss, *_ = self.model(X, Y)
            loss.backward()
            self.optimizer.step()

            # Weight by batch size so the epoch loss is a per-sample average.
            total_loss += loss.item() * X.size(0)
            ss += X.size(0)

        return total_loss / ss

    @torch.no_grad()
    def _on_epoch_val(self, DL) -> float:
        """
        Validate using validation dataset.

        Parameters
        ----------
        DL
            DataLoader for validation dataset.

        Returns
        ----------
        float
            Validation loss for the current epoch.
        """

        self.model.eval()
        total_loss = .0
        ss = 0
        for X, Y in DL:
            X = X.to(self.device)
            Y = Y.to(self.device)
            loss, *_ = self.model(X, Y)
            total_loss += loss.item() * X.size(0)
            ss += X.size(0)

        return total_loss / ss

    def get_time(
        self,
    ) -> np.ndarray:
        """
        Infer the developmental pseudotime.

        On the first call the direction of the pseudotime is decided and cached
        in `self.time_reverse`; later calls reuse that decision.

        Returns
        ----------
        :class:`~numpy.ndarray`
            The pseudotime inferred for each cell.
        """

        X = self.adata.X
        if self.loss_mode in ['nb', 'zinb']:
            X = np.log1p(X)
        ts = self._get_time(self.model, X)

        ## The model might return pseudotime in reverse order. Check this based on number of genes expressed in each cell.
        if self.time_reverse is None:
            n_genes = torch.tensor(self.adata.obs['n_genes_by_counts'].values).float().log1p().to(self.device)
            m_ts = ts.mean()
            m_ngenes = n_genes.mean()
            # Sign of the (unnormalised) covariance between the pseudotime and
            # the log number of detected genes.
            beta_direction = (ts * n_genes).sum() - len(ts) * m_ts * m_ngenes
            self.time_reverse = bool(beta_direction > 0)
        if self.time_reverse:
            ts = 1 - ts

        return ts.cpu().numpy()

    def get_vector_field(
        self,
        T: np.ndarray,
        Z: np.ndarray,
    ) -> np.ndarray:
        """
        Infer the vector field.

        Parameters
        ----------
        T
            The pseudotime estimated for each cell.
        Z
            The latent representation for each cell.

        Returns
        ----------
        :class:`~numpy.ndarray`
            The estimated vector field.
        """

        return self._get_vector_field(
            self.model,
            T,
            Z,
            self.time_reverse,
        )

    def get_latentsp(
        self,
        alpha_z: float = .5,
        alpha_predz: float = .5,
        step_size: Optional[int] = None,
        step_wise: bool = False,
        batch_size: Optional[int] = None,
    ) -> tuple:
        """
        Infer the latent space.

        Parameters
        ----------
        alpha_z
            Scaling factor for encoder-derived latent space.
            (Default: 0.5)
        alpha_predz
            Scaling factor for ODE-solver-derived latent space.
            (Default: 0.5)
        step_size
            Step size during integration.
        step_wise
            Whether to perform step-wise integration by iteratively considering only two time points each time.
            (Default: `False`)
        batch_size
            Batch size when deriving the latent space. The default is no mini-batching.

        Returns
        ----------
        tuple
            3-tuple of weighted combined latent space, encoder-derived latent space, and ODE-solver-derived latent space.
        """

        X = self.adata.X
        if self.loss_mode in ['nb', 'zinb']:
            X = np.log1p(X)
        mix_zs, zs, pred_zs = self._get_latentsp(
            self.model,
            X,
            alpha_z,
            alpha_predz,
            step_size,
            step_wise,
            batch_size,
        )
        return mix_zs, zs, pred_zs

    def save_model(
        self,
        save_dir: str,
        save_prefix: str,
    ) -> None:
        """
        Save the trained scTour model.

        The optimizer state is stored as well, so `train()` must have been run
        before calling this method.

        Parameters
        ----------
        save_dir
            The directory where the model will be saved.
        save_prefix
            The prefix for model name. The model will be saved in 'save_dir/save_prefix.pth'.
        """

        save_path = os.path.abspath(os.path.join(save_dir, f'{save_prefix}.pth'))
        torch.save(
            {
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'model_kwargs': self.model_kwargs,
                'time_reverse': self.time_reverse,
                'adata': self.adata,
                'percent': self.percent,
                'nepoch': self.nepoch,
                'batch_size': self.batch_size,
                'random_state': self.random_state,
                'drop_last': self.drop_last,
                'lr': self.lr,
                'wt_decay': self.wt_decay,
                'eps': self.eps,
                'val_frac': self.val_frac,
                'use_gpu': self.use_gpu,
            },
            save_path
        )

    @staticmethod
    @torch.no_grad()
    def _get_time(
        model: TNODE,
        X: Union[np.ndarray, spmatrix],
    ) -> torch.Tensor:
        """
        Derive the developmental pseudotime for cells.

        Parameters
        ----------
        model
            The trained scTour model.
        X
            The data matrix.

        Returns
        ----------
        :class:`torch.Tensor`
            The pseudotime estimated for each cell.
        """

        model.eval()
        if sparse.issparse(X):
            # `toarray()` instead of the deprecated `.A` shortcut.
            X = X.toarray()
        X = torch.tensor(X).to(model.device)
        ts, _, _ = model.encoder(X)
        return ts.ravel()

    @staticmethod
    @torch.no_grad()
    def _get_vector_field(
        model: TNODE,
        T: np.ndarray,
        Z: np.ndarray,
        time_reverse: bool,
    ) -> np.ndarray:
        """
        Derive the vector field for cells.

        Parameters
        ----------
        model
            The trained scTour model.
        T
            The pseudotime for each cell.
        Z
            The latent representation for each cell.
        time_reverse
            Whether to reverse the vector field.

        Returns
        ----------
        :class:`~numpy.ndarray`
            The estimated vector field.

        Raises
        ----------
        TypeError
            If `T` or `Z` is not a numpy array.
        RuntimeError
            If `time_reverse` is `None`, i.e. the direction of the pseudotime has not been determined yet.
        """

        model.eval()
        if not (isinstance(T, np.ndarray) and isinstance(Z, np.ndarray)):
            raise TypeError(
                'The inputs must be numpy arrays.'
            )
        if time_reverse is None:
            raise RuntimeError(
                'It seems you did not run `get_time()` function first after model training.'
            )

        # Evaluate on the model's device and bring the result back to the CPU
        # before converting to numpy (required when the model is on a GPU).
        Z = torch.tensor(Z).to(model.device)
        T = torch.tensor(T).to(model.device)
        direction = -1 if time_reverse else 1
        return direction * model.lode_func(T, Z).cpu().numpy()

    @staticmethod
    @torch.no_grad()
    def _get_latentsp(
        model: TNODE,
        X: Union[np.ndarray, spmatrix],
        alpha_z: float = .5,
        alpha_predz: float = .5,
        step_size: Optional[int] = None,
        step_wise: bool = False,
        batch_size: Optional[int] = None,
    ):
        """
        Derive the latent representations of cells.

        Parameters
        ----------
        model
            The trained scTour model.
        X
            The data matrix.
        alpha_z
            Scaling factor for encoder-derived latent space.
            (Default: 0.5)
        alpha_predz
            Scaling factor for ODE-solver-derived latent space.
            (Default: 0.5)
        step_size
            Step size during integration.
        step_wise
            Whether to perform step-wise integration by iteratively considering only two time points each time.
            (Default: `False`)
        batch_size
            Batch size when deriving the latent space. The default is no mini-batching.

        Returns
        ----------
        tuple
            3-tuple of weighted combined latent space, encoder-derived latent space, and ODE-solver-derived latent space.

        Raises
        ----------
        ValueError
            If `alpha_z` or `alpha_predz` is outside [0, 1], or if they do not sum to 1.
        """

        model.eval()

        if (alpha_z < 0) or (alpha_z > 1):
            raise ValueError(
                '`alpha_z` must be between 0 and 1.'
            )
        if (alpha_predz < 0) or (alpha_predz > 1):
            raise ValueError(
                '`alpha_predz` must be between 0 and 1.'
            )
        # Tolerant comparison, see the equivalent check in `__init__`.
        if not np.isclose(alpha_z + alpha_predz, 1):
            raise ValueError(
                'The sum of `alpha_z` and `alpha_predz` must be 1.'
            )

        if sparse.issparse(X):
            X = X.toarray()
        X = torch.tensor(X).to(model.device)
        T, qz_mean, qz_logvar = model.encoder(X)
        T = T.ravel().cpu()
        # Reparameterisation: sample the encoder-derived latent state (on CPU).
        epsilon = torch.randn(qz_mean.size())
        zs = epsilon * torch.exp(.5 * qz_logvar.cpu()) + qz_mean.cpu()

        # Integrate over the sorted, de-duplicated time points; `sort_ridx`
        # maps the results back to the original cell order afterwards.
        sort_T, sort_idx, sort_ridx = np.unique(T.numpy(), return_index=True, return_inverse=True)
        sort_T = torch.tensor(sort_T)
        sort_zs = zs[sort_idx]

        pred_zs = []
        if batch_size is None:
            batch_size = len(sort_T)
        times = int(np.ceil(len(sort_T) / batch_size))
        for i in range(times):
            idx1 = i * batch_size
            idx2 = min((i + 1) * batch_size, len(sort_T))
            t = sort_T[idx1:idx2]
            z = sort_zs[idx1:idx2]
            # The ODE function lives on `model.device`, so the solver inputs
            # must be there too; predictions are moved back to the CPU.
            t_dev = t.to(model.device)
            z_dev = z.to(model.device)

            if not step_wise:
                options = get_step_size(step_size, t[0], t[-1], len(t))
                pred_z = odeint(
                    model.lode_func,
                    z_dev[0],
                    t_dev,
                    method = model.ode_method,
                    options = options
                ).view(-1, model.n_latent).cpu()
            else:
                pred_z = torch.empty((len(t), z.size(1)))
                pred_z[0] = z[0]
                for j in range(len(t) - 1):
                    t2 = t[j:(j + 2)]
                    options = get_step_size(step_size, t2[0], t2[-1], len(t2))
                    # Integrate one interval, starting from the encoder-derived
                    # state at its first time point; keep the end state.
                    pred_z[j + 1] = odeint(
                        model.lode_func,
                        z_dev[j],
                        t_dev[j:(j + 2)],
                        method = model.ode_method,
                        options = options
                    )[1].cpu()

            pred_zs += [pred_z]

        pred_zs = torch.cat(pred_zs)
        pred_zs = pred_zs[sort_ridx]
        mix_zs = alpha_z * zs + alpha_predz * pred_zs

        return mix_zs.numpy(), zs.numpy(), pred_zs.numpy()