Source code for sctour.predict

import torch
from torchdiffeq import odeint
from typing import Optional, Union
from typing_extensions import Literal
import numpy as np
from anndata import AnnData
from scipy import sparse
from scipy.sparse import spmatrix
import os

from ._utils import get_step_size
from .train import Trainer


def _check_data(
    adata1: AnnData,
    adata2: AnnData,
    loss_mode: str,
) -> np.ndarray:
    """
    Check the query data.

    Parameters
    ----------
    adata1
        An :class:`~anndata.AnnData` object for the query dataset.
    adata2
        An :class:`~anndata.AnnData` object for the training dataset.
    loss_mode
        The `loss_mode` used for model training.

    Returns
    ----------
    :class:`~numpy.ndarray`
        The expression matrix for the query dataset.
    """

    if len(adata1.var_names.intersection(adata2.var_names)) != adata2.n_vars:
        raise ValueError(
                "The query AnnData must contain all the genes that are used for training in the training dataset."
                )

    X = adata1[:, adata2.var_names].X
    if loss_mode == 'mse':
        if (X.min() < 0) or (X.max() > np.log1p(1e6)):
            raise ValueError(
                    "Invalid expression matrix in `.X`. Model trained from `mse` mode expects log1p(normalized expression) in `.X` of the query AnnData."
                    )
    else:
        data = X.data if sparse.issparse(X) else X
        if (data.min() < 0) or np.any(~np.equal(np.mod(data, 1), 0)):
            raise ValueError(
                    f"Invalid expression matrix in `.X`. Model trained from `{loss_mode}` mode expects raw UMI counts in `.X` of the query AnnData."
                    )
        else:
            X = np.log1p(X)

    return X


[docs]def load_model(model: str): """ Load the trained scTour model for prediction. Parameters ---------- model Filename for the scTour model trained and saved. Returns ---------- :class:`sctour.train.Trainer` The trained scTour model. """ if not os.path.isfile(model): raise FileNotFoundError( f'No such file: `{model}`.' ) checkpoint = torch.load(model, map_location=torch.device('cpu')) # checkpoint = torch.load(model) model_kwargs = checkpoint['model_kwargs'] del model_kwargs['device'] del model_kwargs['n_int'] tnode = Trainer( adata = checkpoint['adata'], percent = checkpoint['percent'], nepoch = checkpoint['nepoch'], batch_size = checkpoint['batch_size'], drop_last = checkpoint['drop_last'], lr = checkpoint['lr'], wt_decay = checkpoint['wt_decay'], eps = checkpoint['eps'], random_state = checkpoint['random_state'], val_frac = checkpoint['val_frac'], use_gpu = checkpoint['use_gpu'], **model_kwargs, ) tnode.model.load_state_dict(checkpoint['model_state_dict']) tnode.time_reverse = checkpoint['time_reverse'] return tnode
[docs]def predict_time( model: Trainer, adata: AnnData, reverse: bool = False, ) -> np.ndarray: """ Predict the pseudotime for query cells. Parameters ---------- model A :class:`sctour.train.Trainer` for trained scTour model. adata An :class:`~anndata.AnnData` object for the query dataset. reverse Whether to reverse the predicted pseudotime. When the pseudotime returned by `get_time()` function for the training data was in reverse order and you used the post-inference adjustment (`reverse_time()` function), please set this parameter to `True`. (Default: `False`) Returns ---------- :class:`~numpy.ndarray` The pseudotime predicted for the query cells. """ if model.time_reverse is None: raise RuntimeError( 'It seems you did not run `get_time()` function after model training. Please run `get_time()` first after training for the training data before you run `predict_time()` for the query data.' ) X = _check_data(adata, model.adata, model.loss_mode) ts = model._get_time(model = model.model, X = X) if model.time_reverse: ts = 1 - ts if reverse: ts = 1 - ts return ts.cpu().numpy()
[docs]def predict_vector_field( model: Trainer, T: np.ndarray, Z: np.ndarray, ) -> np.ndarray: """ Predict the vector field for query cells. Parameters ---------- model A :class:`sctour.train.Trainer` for trained scTour model. T The predicted pseudotime for query cells. Z The predicted latent representations for query cells. Returns ---------- :class:`~numpy.ndarray` The vector field predicted for query cells. """ vf = model._get_vector_field( model = model.model, T = T, Z = Z, time_reverse = model.time_reverse, ) return vf
[docs]def predict_latentsp( model: Trainer, adata: AnnData, mode: Literal['coarse', 'fine'] = 'fine', alpha_z: float = .5, alpha_predz: float = .5, step_size: Optional[int] = None, step_wise: bool = False, batch_size: Optional[int] = None, ) -> tuple: """ Predict the latent representations for query cells given their transcriptomes. Parameters ---------- model A :class:`sctour.train.Trainer` for trained scTour model. adata An :class:`~anndata.AnnData` object for the query dataset. mode The mode for deriving the latent space for the query dataset. Two modes are included: ``'fine'``: derive the latent space by taking the training data into consideration; ``'coarse'``: derive the latent space directly from the query data without involving the training data. alpha_z Scaling factor for encoder-derived latent space. (Default: 0.5) alpha_predz Scaling factor for ODE-solver-derived latent space. (Default: 0.5) step_size The step size during integration. step_wise Whether to perform step-wise integration by iteratively considering only two time points each time. (Default: `False`) batch_size Batch size when deriving the latent space. The default is no mini-batching. Returns ---------- tuple 3-tuple of weighted combined latent space, encoder-derived latent space, and ODE-solver-derived latent space. """ X = _check_data(adata, model.adata, model.loss_mode) if mode == 'coarse': mix_zs, zs, pred_zs = model._get_latentsp( model = model.model, X = X, alpha_z = alpha_z, alpha_predz = alpha_predz, step_size = step_size, step_wise = step_wise, batch_size = batch_size, ) if mode == 'fine': X2 = model.adata.X if model.loss_mode in ['nb', 'zinb']: X2 = np.log1p(X2) if sparse.issparse(X2): X2 = X2.A if sparse.issparse(X): X = X.A mix_zs, zs, pred_zs = model._get_latentsp( model = model.model, X = np.vstack((X, X2)), alpha_z = alpha_z, alpha_predz = alpha_predz, step_size = step_size, step_wise = step_wise, batch_size = batch_size, ) mix_zs = mix_zs[:len(X)] zs = zs[:len(X)] pred_zs = pred_zs[:len(X)] return mix_zs, zs, pred_zs
[docs]@torch.no_grad() def predict_ltsp_from_time( model: Trainer, T: np.ndarray, reverse: bool = False, step_wise: bool = True, step_size: Optional[int] = None, alpha_z: float = 0.5, alpha_predz: float = 0.5, k: int = 20, ) -> np.ndarray: """ Predict the transcriptomic latent space for query (unobserved) time intervals. Parameters ---------- model A :class:`sctour.train.Trainer` for trained scTour model. T A 1D numpy array containing the query time points (with values between 0 and 1). The latent space corresponding to these time points will be predicted. reverse When the pseudotime returned by `get_time()` function for the training data was in reverse order and you used the post-inference adjustment (`reverse_time()` function), please set this parameter to `True`. (Default: `False`) step_wise Whether to perform step-wise integration by iteratively considering only two time points when inferring the reference latent space from the training data. (Default: `True`) step_size The step size during integration. alpha_z Scaling factor for encoder-derived latent space. (Default: 0.5) alpha_predz Scaling factor for ODE-solver-derived latent space. (Default: 0.5) k The k nearest neighbors in the time space considered when predicting the latent representation for each query time point. (Default: 20) Returns ---------- :class:`~numpy.ndarray` Predicted latent space corresponding to the query time interval. """ mdl = model.model if not isinstance(T, np.ndarray): raise TypeError( "The input time interval must be a numpy array." ) if len(T.shape) > 1: raise TypeError( "The input time interval must be a 1D numpy array." ) if np.any(T < 0) or np.any(T > 1): raise ValueError( "The input time points must be in [0, 1]." ) ridx = np.random.permutation(len(T)) rT = torch.tensor(T[ridx]) ## get the reference time and latent space from the training data X = model.adata.X if model.loss_mode in ['nb', 'zinb']: X = np.log1p(X) mix_zs, zs, pred_zs = model._get_latentsp(model = mdl, X = X, alpha_z = alpha_z, alpha_predz = alpha_predz, step_wise = step_wise, step_size = step_size, ) ts = model._get_time(model = mdl, X = X) if model.time_reverse: ts = 1 - ts if reverse: ts = 1 - ts ts = ts.cpu() zs = torch.tensor(mix_zs) pred_T_zs = torch.empty((len(rT), mdl.n_latent)) for i, t in enumerate(rT): diff = torch.abs(t - ts) idxs = torch.argsort(diff) # n = (diff == 0).sum() # idxs = idxs[n:(k + n)] if (diff == 0).any(): pred_T_zs[i] = zs[idxs[0]].clone() else: idxs = idxs[:k] k_zs = torch.empty((k, mdl.n_latent)) for j, idx in enumerate(idxs): z0 = zs[idx].clone() t0 = ts[idx].clone() pred_t = torch.stack((t0, t)) if pred_t[0] < pred_t[1]: options = get_step_size(step_size, pred_t[0], pred_t[-1], len(pred_t)) else: options = get_step_size(step_size, pred_t[-1], pred_t[0], len(pred_t)) k_zs[j] = odeint( mdl.lode_func, z0, pred_t, method = mdl.ode_method, options = options )[1] k_zs = torch.mean(k_zs, dim = 0) pred_T_zs[i] = k_zs ts = torch.cat((ts, t.unsqueeze(0))) zs = torch.cat((zs, k_zs.unsqueeze(0))) pred_T_zs = pred_T_zs[np.argsort(ridx)] return pred_T_zs.numpy()