sctour.train.Trainer
- class sctour.train.Trainer(adata: anndata._core.anndata.AnnData, percent: Optional[float] = None, n_latent: int = 5, n_ode_hidden: int = 25, n_vae_hidden: int = 128, batch_norm: bool = False, ode_method: str = 'euler', step_size: Optional[int] = None, alpha_recon_lec: float = 0.5, alpha_recon_lode: float = 0.5, alpha_kl: float = 1.0, loss_mode: Literal['mse', 'nb', 'zinb'] = 'nb', nepoch: Optional[int] = None, batch_size: int = 1024, drop_last: bool = False, lr: float = 0.001, wt_decay: float = 1e-06, eps: float = 0.01, random_state: int = 0, val_frac: float = 0.1, use_gpu: bool = True)
Class for implementing the scTour training process.
- Parameters
adata – An AnnData object for the training data.
percent – The fraction of cells used for model training. Defaults to 0.2 when the number of cells exceeds 10,000 and to 0.9 otherwise.
n_latent – The dimensionality of the latent space. (Default: 5)
n_ode_hidden – The dimensionality of the hidden layer for the latent ODE function. (Default: 25)
n_vae_hidden – The dimensionality of the hidden layer for the VAE. (Default: 128)
batch_norm – Whether to include a BatchNorm layer. (Default: False)
ode_method – The ODE solver. The list of available solvers can be found in torchdiffeq. (Default: ‘euler’)
step_size – The step size during integration.
alpha_recon_lec – The scaling factor for the reconstruction error from the encoder-derived latent space. (Default: 0.5)
alpha_recon_lode – The scaling factor for the reconstruction error from the ODE-solver-derived latent space. (Default: 0.5)
alpha_kl – The scaling factor for the KL divergence in the loss function. (Default: 1.0)
loss_mode – The mode for calculating the reconstruction error. (Default: ‘nb’) Three modes are included:
'mse': mean squared error;
'nb': negative binomial conditioned likelihood;
'zinb': zero-inflated negative binomial conditioned likelihood.
nepoch – Number of epochs.
batch_size – The batch size during training. (Default: 1024)
drop_last – Whether to drop the last batch when its size is smaller than batch_size. (Default: False)
lr – The learning rate. (Default: 1e-3)
wt_decay – The weight decay (L2 penalty) for the Adam optimizer. (Default: 1e-6)
eps – The eps parameter for the Adam optimizer. (Default: 0.01)
random_state – The seed for generating random numbers. (Default: 0)
val_frac – The fraction of data used for validation. (Default: 0.1)
use_gpu – Whether to use GPU when available. (Default: True)
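The sampling and batching parameters (percent, val_frac, batch_size, drop_last) together decide how many cells are seen per epoch and how many optimizer steps that takes. The stdlib sketch below works through that arithmetic. It assumes the validation cells are held out from the percent subsample, which is not stated here, and default_percent, batches_per_epoch and plan_split are hypothetical helpers, not part of scTour.

```python
import math
from typing import Optional


def default_percent(n_cells: int) -> float:
    # Documented default for `percent`: 0.2 above 10,000 cells, 0.9 otherwise.
    return 0.2 if n_cells > 10_000 else 0.9


def batches_per_epoch(n_train: int, batch_size: int = 1024,
                      drop_last: bool = False) -> int:
    # With drop_last=True an incomplete final batch is discarded;
    # otherwise it is kept as one extra, smaller batch.
    if drop_last:
        return n_train // batch_size
    return math.ceil(n_train / batch_size)


def plan_split(n_cells: int, percent: Optional[float] = None,
               val_frac: float = 0.1, batch_size: int = 1024,
               drop_last: bool = False) -> dict:
    # ASSUMPTION: subsample first, then hold out val_frac of the subsample.
    if percent is None:
        percent = default_percent(n_cells)
    n_used = int(n_cells * percent)  # truncate to a whole number of cells
    n_val = int(n_used * val_frac)
    n_train = n_used - n_val
    return {
        "percent": percent,
        "n_used": n_used,
        "n_val": n_val,
        "n_train": n_train,
        "n_batches": batches_per_epoch(n_train, batch_size, drop_last),
    }


print(plan_split(50_000))
# → {'percent': 0.2, 'n_used': 10000, 'n_val': 1000, 'n_train': 9000, 'n_batches': 9}
```

With drop_last=True the same 9,000 training cells would give 8 full batches of 1,024, and the remaining 808 cells would be skipped in that epoch.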
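With the default ode_method='euler', the latent trajectory is advanced in fixed steps, and step_size controls their length. The solver names come from torchdiffeq, whose fixed-grid solvers accept a step size through the options argument of odeint. The dependency-free sketch below only illustrates what explicit Euler integration computes; euler_integrate is a hypothetical helper, and the toy decay function stands in for the learned latent ODE.

```python
import math
from typing import Callable, List, Sequence

Vector = List[float]


def euler_integrate(f: Callable[[float, Vector], Vector], z0: Sequence[float],
                    t0: float, t1: float, step_size: float) -> Vector:
    """Integrate dz/dt = f(t, z) from t0 to t1 with explicit Euler steps.

    The final step is shortened if needed so the solver stops exactly at t1.
    """
    z = list(z0)
    t = t0
    while t < t1 - 1e-12:  # small tolerance absorbs float drift in t
        h = min(step_size, t1 - t)
        dz = f(t, z)
        z = [zi + h * dzi for zi, dzi in zip(z, dz)]
        t += h
    return z


# Toy stand-in for the learned latent ODE: exponential decay dz/dt = -z.
def decay(t: float, z: Vector) -> Vector:
    return [-zi for zi in z]


approx = euler_integrate(decay, [1.0], 0.0, 1.0, step_size=0.01)
print(approx[0], math.exp(-1.0))  # Euler result vs. exact solution e^-1
```

Smaller steps track the exact solution more closely at the cost of more function evaluations, which is the trade-off step_size exposes.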
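For loss_mode, 'nb' and 'zinb' score the observed counts by their negative log-likelihood under the predicted distribution, while 'mse' uses squared error. The sketch assumes the mean/inverse-dispersion (mu, theta) parameterization of the negative binomial that is common in single-cell VAEs, plus a dropout probability pi for the zero-inflated case; scTour's own parameterization and reduction (sum vs. mean over genes) may differ. log_nb, log_zinb and reconstruction_error are hypothetical names.

```python
import math
from typing import Optional, Sequence


def log_nb(x: float, mu: float, theta: float) -> float:
    """log NB(x | mean mu > 0, inverse dispersion theta > 0)."""
    log_theta_mu = math.log(theta + mu)
    return (math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1.0)
            + theta * (math.log(theta) - log_theta_mu)
            + x * (math.log(mu) - log_theta_mu))


def log_zinb(x: float, mu: float, theta: float, pi: float) -> float:
    """log ZINB: an extra zero with probability pi, otherwise an NB draw."""
    if x == 0:
        nb_zero = math.exp(theta * (math.log(theta) - math.log(theta + mu)))
        return math.log(pi + (1.0 - pi) * nb_zero)
    return math.log(1.0 - pi) + log_nb(x, mu, theta)


def reconstruction_error(x: Sequence[float], mode: str, mu: Sequence[float],
                         theta: Optional[Sequence[float]] = None,
                         pi: Optional[Sequence[float]] = None) -> float:
    """Reconstruction error for one cell's counts x under the given mode."""
    if mode == "mse":
        return sum((xi - mi) ** 2 for xi, mi in zip(x, mu)) / len(x)
    if mode == "nb":
        return -sum(log_nb(xi, mi, ti) for xi, mi, ti in zip(x, mu, theta))
    if mode == "zinb":
        return -sum(log_zinb(xi, mi, ti, p)
                    for xi, mi, ti, p in zip(x, mu, theta, pi))
    raise ValueError(f"unknown loss_mode: {mode!r}")


counts = [0.0, 3.0, 1.0]
mean = [0.5, 2.5, 1.0]
disp = [2.0, 2.0, 2.0]
print(reconstruction_error(counts, "nb", mean, disp))
```

The zero-inflated variant only changes how zeros are scored: a zero can come either from the extra dropout component or from the negative binomial itself.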
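The three alpha_* parameters weight the terms of the training objective. A minimal sketch, assuming the objective is a plain weighted sum of the two reconstruction errors and the KL divergence of a diagonal-Gaussian posterior from a standard-normal prior (the standard VAE choice); scTour's actual loss may include further terms, and gaussian_kl and total_loss are hypothetical helpers.

```python
import math
from typing import Sequence


def gaussian_kl(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ) for one cell's latent vector."""
    return 0.5 * sum(m * m + math.exp(lv) - lv - 1.0
                     for m, lv in zip(mu, logvar))


def total_loss(recon_lec: float, recon_lode: float, kl: float,
               alpha_recon_lec: float = 0.5, alpha_recon_lode: float = 0.5,
               alpha_kl: float = 1.0) -> float:
    """Weighted sum using the documented default scaling factors."""
    return (alpha_recon_lec * recon_lec
            + alpha_recon_lode * recon_lode
            + alpha_kl * kl)


# Example: a 2-D latent posterior slightly off the prior.
kl = gaussian_kl(mu=[0.3, -0.1], logvar=[-0.2, 0.1])
print(total_loss(recon_lec=120.0, recon_lode=130.0, kl=kl))
```

Under this assumption, raising alpha_kl pulls the latent space toward the prior, while shifting weight between the two reconstruction terms changes how strongly the ODE-derived latents must reproduce the data on their own.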
Methods
get_latentsp([alpha_z, alpha_predz, ...]) – Infer the latent space.
get_time() – Infer the developmental pseudotime.
get_vector_field(T, Z) – Infer the vector field.
save_model(save_dir, save_prefix) – Save the trained scTour model.
train() – Model training.
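get_latentsp exposes two weights, alpha_z and alpha_predz, and get_vector_field takes pseudotime T and latent coordinates Z. A minimal sketch of what these could compute, assuming the returned latent space is a weighted sum of the encoder-derived and ODE-solver-derived latent coordinates, and that the vector field is the latent ODE's derivative evaluated at each cell's (T, Z). mix_latent and vector_field are hypothetical helpers, not scTour methods, and the toy rotation stands in for the trained ODE function.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def mix_latent(z: Sequence[Vector], pred_z: Sequence[Vector],
               alpha_z: float = 0.5, alpha_predz: float = 0.5) -> List[Vector]:
    """Hypothetical: per-cell weighted sum of encoder-derived latents (z)
    and ODE-solver-derived latents (pred_z)."""
    return [[alpha_z * a + alpha_predz * b for a, b in zip(zi, pi)]
            for zi, pi in zip(z, pred_z)]


def vector_field(f: Callable[[float, Vector], Vector],
                 T: Sequence[float], Z: Sequence[Vector]) -> List[Vector]:
    """Hypothetical: evaluate dz/dt = f(t, z) at each cell's pseudotime t and
    latent position z, giving one velocity vector per cell."""
    return [f(t, list(z)) for t, z in zip(T, Z)]


# Toy stand-in for a trained latent ODE: rotation in a 2-D latent space.
def rotate(t: float, z: Vector) -> Vector:
    return [-z[1], z[0]]


z = [[1.0, 2.0], [0.0, -1.0]]
pred_z = [[3.0, 4.0], [2.0, 1.0]]
mixed = mix_latent(z, pred_z, alpha_z=0.5, alpha_predz=0.5)
# → [[2.0, 3.0], [1.0, 0.0]]
velocities = vector_field(rotate, [0.1, 0.4], mixed)
```

Each velocity has the same dimensionality as the latent space (n_latent), so it can be drawn as an arrow on top of any 2-D view of that space.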