scDRP package
Submodules
scDRP.data module
- class scDRP.data.CombinedDataset(X, a, t, c)
Bases: Dataset
Dataset for the combined data
- __init__(X, a, t, c)
- Parameters:
X – numpy array of shape (n_samples, n_features)
a – one-hot encoded numpy array of shape (n_samples, num_perturbations)
t – one-hot encoded numpy array of shape (n_samples, num_celltypes) or None
c – one-hot encoded numpy array of shape (n_samples, batch_dim) or None
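The shapes listed above can be illustrated with a small NumPy sketch that builds one-hot arrays from label lists. The helper `one_hot` and the toy labels are hypothetical and are not part of scDRP; the resulting arrays are simply of the form that `CombinedDataset(X, a, t, c)` is documented to accept.

```python
import numpy as np

def one_hot(labels):
    """Return (one-hot matrix, sorted category list) for a sequence of labels."""
    categories = sorted(set(labels))
    index = {cat: i for i, cat in enumerate(categories)}
    encoded = np.zeros((len(labels), len(categories)), dtype=np.float32)
    # Set a single 1 per row, in the column of that row's category.
    encoded[np.arange(len(labels)), [index[l] for l in labels]] = 1.0
    return encoded, categories

# Toy data: 4 cells, 3 genes (all values are made up for illustration).
X = np.array([[0, 2, 1], [3, 0, 0], [1, 1, 4], [0, 0, 2]], dtype=np.float32)
a, perturbations = one_hot(["control", "drugA", "drugA", "control"])
t, celltypes = one_hot(["T", "B", "T", "T"])
c = None  # no batch covariate in this toy example
```

Here `a` has shape (4, 2) and `t` has shape (4, 2), matching (n_samples, num_perturbations) and (n_samples, num_celltypes).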
scDRP.loss module
- class scDRP.loss.HSICLoss(sigma_x: float | None = None, sigma_y: float | None = None, sigma_method: Literal['median', 'scott'] = 'median', sigma_scale: float = 1.0)
Bases: Module
Hilbert-Schmidt Independence Criterion (HSIC) for Independence Testing.
Measures the dependence between X and Y. HSIC is zero if and only if X and Y are independent (assuming characteristic kernels like Gaussian RBF).
Based on: Gretton et al. “Measuring Statistical Dependence with Hilbert-Schmidt Norms” (2005)
- Parameters:
sigma_x (float, optional) – Fixed bandwidth for X. If None, uses automatic selection.
sigma_y (float, optional) – Fixed bandwidth for Y. If None, uses automatic selection.
sigma_method (str) – Method for automatic sigma selection when sigma_x/sigma_y is None. Options: ‘median’ (median heuristic; default, recommended) or ‘scott’ (Scott’s rule).
sigma_scale (float) – Additional scaling factor for computed sigma. Default: 1.0
Example
>>> import torch
>>> from scDRP.loss import HSICLoss
>>> # Automatic sigma selection (recommended)
>>> loss_fn = HSICLoss()
>>> X = torch.randn(100, 5)
>>> Y = torch.randn(100, 3)
>>> loss = loss_fn(X, Y)
>>>
>>> # Manual sigma specification
>>> loss_fn = HSICLoss(sigma_x=1.0, sigma_y=0.5)
>>> loss = loss_fn(X, Y)
- __init__(sigma_x: float | None = None, sigma_y: float | None = None, sigma_method: Literal['median', 'scott'] = 'median', sigma_scale: float = 1.0)
Initializes the HSIC loss.
- Parameters:
sigma_x – Fixed bandwidth for X. If None, auto-computed per forward pass.
sigma_y – Fixed bandwidth for Y. If None, auto-computed per forward pass.
sigma_method – Method for automatic bandwidth selection.
sigma_scale – Global scaling factor applied to all computed sigmas.
- _compute_sigma(x: Tensor, fixed_sigma: float | None) float
Computes or returns sigma for a given tensor.
- Parameters:
x – Input tensor.
fixed_sigma – Pre-specified sigma value. If None, auto-compute.
- Returns:
Sigma value to use.
- _gaussian_kernel(x: Tensor, y: Tensor, sigma: float) Tensor
Computes Gaussian RBF kernel matrix.
K(x, y) = exp(-||x - y||^2 / (2 * sigma^2))
- Parameters:
x – Tensor of shape [batch_size, dim_x].
y – Tensor of shape [batch_size, dim_y].
sigma – Bandwidth parameter.
- Returns:
Kernel matrix of shape [batch_size, batch_size].
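As an illustration of the kernel formula above, here is a minimal NumPy sketch of a Gaussian RBF Gram matrix. It is not the library's torch implementation; the function name `gaussian_kernel` is reused only for readability, and the example passes the same array for both arguments so that the pairwise distances are well defined.

```python
import numpy as np

def gaussian_kernel(x, y, sigma):
    """K[i, j] = exp(-||x_i - y_j||^2 / (2 * sigma^2))."""
    sq_dists = (
        np.sum(x**2, axis=1)[:, None]
        + np.sum(y**2, axis=1)[None, :]
        - 2.0 * x @ y.T
    )
    sq_dists = np.maximum(sq_dists, 0.0)  # clip tiny negatives from round-off
    return np.exp(-sq_dists / (2.0 * sigma**2))

x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])
K = gaussian_kernel(x, x, sigma=1.0)
```

The diagonal of `K` is 1 and the matrix is symmetric; larger `sigma` pushes all entries toward 1.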
- _median_heuristic(x: Tensor) float
Computes sigma using median heuristic.
sigma^2 = median(||x_i - x_j||^2) * sigma_scale
- Parameters:
x – Tensor of shape [batch_size, dim].
- Returns:
Optimal sigma value.
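The median heuristic formula above can be sketched in NumPy as follows. This assumes the median is taken over distinct pairs (i < j), which the documentation does not state; the function name and `sigma_scale` argument mirror the documented names but the code is illustrative only.

```python
import numpy as np

def median_heuristic(x, sigma_scale=1.0):
    """sigma^2 = median(||x_i - x_j||^2) * sigma_scale, over pairs i < j (assumed)."""
    diffs = x[:, None, :] - x[None, :, :]
    sq_dists = np.sum(diffs**2, axis=-1)
    iu = np.triu_indices(x.shape[0], k=1)  # indices of pairs with i < j
    sigma_sq = np.median(sq_dists[iu]) * sigma_scale
    return float(np.sqrt(sigma_sq))

x = np.array([[0.0], [1.0], [3.0]])
sigma = median_heuristic(x)  # squared distances are 1, 9, 4; median is 4
```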
- _scotts_rule(x: Tensor) float
Computes sigma using Scott’s rule.
sigma = n^(-1/(d+4)) * std(X) * sigma_scale
- Parameters:
x – Tensor of shape [batch_size, dim].
- Returns:
Optimal sigma value.
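A hedged NumPy sketch of Scott's rule as written above. The documentation does not say how std(X) is reduced for multi-dimensional input; this sketch assumes the mean of the per-feature sample standard deviations, which may differ from the library's choice.

```python
import numpy as np

def scotts_rule(x, sigma_scale=1.0):
    """sigma = n^(-1/(d+4)) * std(X) * sigma_scale."""
    n, d = x.shape
    # Assumption: summarize spread as the mean per-feature sample std.
    spread = float(np.mean(np.std(x, axis=0, ddof=1)))
    return n ** (-1.0 / (d + 4)) * spread * sigma_scale

x = np.arange(32, dtype=float).reshape(32, 1)
sigma = scotts_rule(x)  # n = 32, d = 1, so the factor is 32^(-1/5) = 0.5
```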
- forward(X: Tensor, Y: Tensor) Tensor
Computes HSIC between X and Y.
- Parameters:
X – Input tensor of shape [batch_size, dim_x].
Y – Input tensor of shape [batch_size, dim_y].
- Returns:
HSIC value (scalar). Zero indicates independence.
- Raises:
ValueError – If batch sizes don’t match.
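To make the quantity returned by forward concrete, the following NumPy sketch computes the biased empirical HSIC estimator of Gretton et al. (2005), trace(K H L H) / (n - 1)^2, with Gaussian kernels. The names `rbf_gram` and `hsic_biased` are hypothetical, and the normalization used by `HSICLoss` itself may differ.

```python
import numpy as np

def rbf_gram(x, sigma):
    """Gaussian RBF Gram matrix of the rows of x."""
    sq = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq / (2.0 * sigma**2))

def hsic_biased(x, y, sigma_x=1.0, sigma_y=1.0):
    """Biased empirical HSIC: trace(K H L H) / (n - 1)^2."""
    if x.shape[0] != y.shape[0]:
        raise ValueError("X and Y must have the same batch size")
    n = x.shape[0]
    K, L = rbf_gram(x, sigma_x), rbf_gram(y, sigma_y)
    H = np.eye(n) - np.ones((n, n)) / n  # centering matrix
    return float(np.trace(K @ H @ L @ H) / (n - 1) ** 2)

rng = np.random.default_rng(0)
x = rng.normal(size=(200, 1))
y_dep = x + 0.1 * rng.normal(size=(200, 1))  # strongly dependent on x
y_ind = rng.normal(size=(200, 1))            # independent of x
hsic_dep = hsic_biased(x, y_dep)
hsic_ind = hsic_biased(x, y_ind)
```

On finite samples the estimate for independent variables is small but generally not exactly zero; the "zero indicates independence" statement applies to the population quantity.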
- class scDRP.loss.NOCCOLoss(epsilon: float = 0.0001, sigma_x: float | None = None, sigma_y: float | None = None, sigma_z: float | None = None, sigma_method: Literal['median', 'scott'] = 'median', sigma_scale: float = 1.0)
Bases: Module
Normalized Cross-Covariance Operator (NOCCO) Loss for Conditional Independence.
Measures the conditional dependence between X and Y given Z using the Hilbert-Schmidt norm of the normalized conditional cross-covariance operator.
Based on: “Kernel Measures of Conditional Dependence” (Fukumizu et al., NIPS 2007)
- Parameters:
epsilon (float) – Regularization parameter for matrix inversion stability. Default: 1e-4
sigma_x (float, optional) – Fixed bandwidth for X. If None, uses automatic selection.
sigma_y (float, optional) – Fixed bandwidth for Y. If None, uses automatic selection.
sigma_z (float, optional) – Fixed bandwidth for Z. If None, uses automatic selection.
sigma_method (str) – Method for automatic sigma selection. Options: ‘median’ (median heuristic; default, recommended) or ‘scott’ (Scott’s rule).
sigma_scale (float) – Additional scaling factor for computed sigma. Default: 1.0
Example
>>> import torch
>>> from scDRP.loss import NOCCOLoss
>>> # Automatic sigma selection (recommended)
>>> loss_fn = NOCCOLoss()
>>> X, Y, Z = torch.randn(100, 5), torch.randn(100, 3), torch.randn(100, 2)
>>> loss = loss_fn(X, Y, Z)
>>>
>>> # Manual sigma specification
>>> loss_fn = NOCCOLoss(sigma_x=1.0, sigma_y=1.0, sigma_z=0.5)
>>> loss = loss_fn(X, Y, Z)
- __init__(epsilon: float = 0.0001, sigma_x: float | None = None, sigma_y: float | None = None, sigma_z: float | None = None, sigma_method: Literal['median', 'scott'] = 'median', sigma_scale: float = 1.0)
Initializes the NOCCO loss.
- _center_gram(K: Tensor) Tensor
Centers the Gram matrix in feature space.
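Centering a Gram matrix in feature space is conventionally done with the centering matrix H = I - (1/n) 1 1^T. The NumPy sketch below shows that standard operation; it is an assumption that `_center_gram` follows the same formula.

```python
import numpy as np

def center_gram(K):
    """Return H K H with H = I - (1/n) * ones((n, n))."""
    n = K.shape[0]
    H = np.eye(n) - np.ones((n, n)) / n
    return H @ K @ H

K = np.array([[1.0, 0.5, 0.2],
              [0.5, 1.0, 0.3],
              [0.2, 0.3, 1.0]])
Kc = center_gram(K)
# Equivalent form: subtract row and column means, add back the grand mean.
Kc_alt = K - K.mean(axis=0, keepdims=True) - K.mean(axis=1, keepdims=True) + K.mean()
```

After centering, every row and column of the matrix sums to zero.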
- _compute_normalized_cov(G_Y: Tensor, G_X: Tensor, n: int) Tensor
Computes normalized cross-covariance operator V_Y_X.
- _compute_sigma(x: Tensor, fixed_sigma: float | None) float
Computes or returns sigma for a given tensor.
- _gaussian_kernel(x: Tensor, y: Tensor, sigma: float) Tensor
Computes Gaussian RBF kernel matrix.
- _median_heuristic(x: Tensor) float
Computes sigma using median heuristic.
- _scotts_rule(x: Tensor) float
Computes sigma using Scott’s rule.
- forward(X: Tensor, Y: Tensor, Z: Tensor) Tensor
Computes the NOCCO loss for conditional independence testing.
- Parameters:
X – Input tensor of shape [batch_size, dim_x].
Y – Input tensor of shape [batch_size, dim_y].
Z – Conditioning tensor of shape [batch_size, dim_z].
- Returns:
Scalar loss value (non-negative). Zero indicates conditional independence.
- class scDRP.loss.UnnormalizedHSCICLoss(epsilon: float = 0.0001, sigma_x: float | None = None, sigma_y: float | None = None, sigma_z: float | None = None, sigma_method: Literal['median', 'scott'] = 'median', sigma_scale: float = 1.0)
Bases: Module
Unnormalized Hilbert-Schmidt Conditional Independence Criterion (HSCIC).
Computes the unnormalized squared Hilbert-Schmidt norm of the conditional cross-covariance operator: ||Σ_{YẌ|Z}||^2_HS.
Based on: Fukumizu et al. (2004) and Sheng & Sriperumbudur (2023)
- Parameters:
epsilon (float) – Regularization parameter for matrix inversion. Default: 1e-4
sigma_x (float, optional) – Fixed bandwidth for X. If None, uses automatic selection.
sigma_y (float, optional) – Fixed bandwidth for Y. If None, uses automatic selection.
sigma_z (float, optional) – Fixed bandwidth for Z. If None, uses automatic selection.
sigma_method (str) – Method for automatic sigma selection. Options: ‘median’ (median heuristic; default, recommended) or ‘scott’ (Scott’s rule).
sigma_scale (float) – Additional scaling factor for computed sigma. Default: 1.0
Example
>>> import torch
>>> from scDRP.loss import UnnormalizedHSCICLoss
>>> # Automatic sigma selection (recommended)
>>> loss_fn = UnnormalizedHSCICLoss()
>>> X, Y, Z = torch.randn(100, 5), torch.randn(100, 3), torch.randn(100, 2)
>>> loss = loss_fn(X, Y, Z)
>>>
>>> # Manual sigma specification
>>> loss_fn = UnnormalizedHSCICLoss(sigma_x=1.0, sigma_y=1.0, sigma_z=0.5)
>>> loss = loss_fn(X, Y, Z)
- __init__(epsilon: float = 0.0001, sigma_x: float | None = None, sigma_y: float | None = None, sigma_z: float | None = None, sigma_method: Literal['median', 'scott'] = 'median', sigma_scale: float = 1.0)
Initializes the Unnormalized HSCIC loss.
- _center_gram(K: Tensor) Tensor
Centers the Gram matrix in the RKHS feature space.
- _compute_sigma(x: Tensor, fixed_sigma: float | None) float
Computes or returns sigma for a given tensor.
- _gaussian_kernel(x: Tensor, y: Tensor, sigma: float) Tensor
Computes Gaussian RBF kernel matrix.
- _median_heuristic(x: Tensor) float
Computes sigma using median heuristic.
- _scotts_rule(x: Tensor) float
Computes sigma using Scott’s rule.
- forward(x: Tensor, y: Tensor, z: Tensor) Tensor
Computes the unnormalized HSCIC loss value.
- Parameters:
x – Input tensor X of shape [batch_size, dim_x].
y – Input tensor Y of shape [batch_size, dim_y].
z – Conditioning tensor Z of shape [batch_size, dim_z].
- Returns:
Scalar loss value (non-negative). Zero indicates X ⊥⊥ Y | Z.
- class scDRP.loss.ZINBLoss
Bases: Module
Zero-Inflated Negative Binomial Loss
- __init__()
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, rho, dispersion, pi, s, eps=1e-08)
- Parameters:
x – observed data
rho – mean gene expression (mean of the negative binomial distribution equals rho * s)
dispersion – dispersion of the negative binomial distribution
pi – zero-inflation parameter
s – scale parameter (library size)
eps – small value to prevent numerical instability
- Returns:
loss – negative log likelihood
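For orientation, here is a standard-library sketch of the zero-inflated negative binomial negative log-likelihood for a single count. It assumes `dispersion` is the inverse-dispersion (size) parameter of the negative binomial and that the mean is rho * s as stated above; the parameterization and the placement of `eps` in `ZINBLoss` may differ.

```python
import math

def nb_log_pmf(x, mean, theta):
    """log NB(x; mean, theta), with theta the inverse dispersion (assumed)."""
    return (
        math.lgamma(x + theta) - math.lgamma(theta) - math.lgamma(x + 1)
        + theta * math.log(theta / (theta + mean))
        + x * math.log(mean / (theta + mean))
    )

def zinb_nll(x, rho, dispersion, pi, s, eps=1e-8):
    """Negative log-likelihood of one count x under a ZINB model."""
    mean = rho * s
    log_nb = nb_log_pmf(x, mean + eps, dispersion + eps)
    if x == 0:
        # A zero is either a structural dropout or a genuine NB zero.
        return -math.log(pi + (1.0 - pi) * math.exp(log_nb) + eps)
    return -(math.log(1.0 - pi + eps) + log_nb)

# Sanity check: the implied probabilities over counts sum to about 1.
total = sum(math.exp(-zinb_nll(k, 0.002, 3.0, 0.3, 1000.0)) for k in range(200))
```

Increasing `pi` makes zeros cheaper and positive counts more expensive, which is the intended effect of zero inflation.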
- scDRP.loss.klLoss(mu, logvar)
KL divergence between the latent distribution and the prior
- Parameters:
mu – mean of the latent distribution
logvar – log variance of the latent distribution
- Returns:
kl – KL divergence
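A minimal sketch of the closed-form KL divergence for a diagonal Gaussian, assuming the prior is a standard normal N(0, I) and that the result is summed over latent dimensions. Neither the prior nor the reduction is stated in the documentation, so both are assumptions, and `kl_standard_normal` is a hypothetical name.

```python
import math

def kl_standard_normal(mu, logvar):
    """KL(N(mu, exp(logvar)) || N(0, 1)), summed over dimensions."""
    return -0.5 * sum(
        1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar)
    )

kl_zero = kl_standard_normal([0.0, 0.0], [0.0, 0.0])  # identical distributions
kl_shift = kl_standard_normal([1.0], [0.0])           # unit mean shift
```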
- scDRP.loss.klLoss_prior(mu_q, logvar_q, mu_p, logvar_p)
Compute KL(q || p) for two Gaussians q(z|x) ~ N(mu_q, exp(logvar_q)) and p(z) ~ N(mu_p, exp(logvar_p))
- Parameters:
mu_q – mean of q
logvar_q – log variance of q
mu_p – mean of p
logvar_p – log variance of p
- Returns:
kl – KL divergence
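The closed form for KL(q || p) between two diagonal Gaussians can be sketched with the standard library as below. The sum over dimensions is an assumed reduction, and `kl_diag_gaussians` is a hypothetical name rather than the library function.

```python
import math

def kl_diag_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """KL(N(mu_q, exp(logvar_q)) || N(mu_p, exp(logvar_p))), summed over dims."""
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        total += 0.5 * (
            lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0
        )
    return total

same = kl_diag_gaussians([0.3], [0.1], [0.3], [0.1])
forward_kl = kl_diag_gaussians([0.0], [0.0], [0.0], [math.log(4.0)])
reverse_kl = kl_diag_gaussians([0.0], [math.log(4.0)], [0.0], [0.0])
```

The divergence is not symmetric, so the order of the (mu_q, logvar_q) and (mu_p, logvar_p) arguments matters.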
scDRP.model module
- class scDRP.model.Decoder(device, input_dim=3000, covariate_dim=1, layer_dims=[500, 100], latent_dim=20, dropout_rate=0.2, library_size_strategy='observed')
Bases: Module
Decoder network
- __init__(device, input_dim=3000, covariate_dim=1, layer_dims=[500, 100], latent_dim=20, dropout_rate=0.2, library_size_strategy='observed')
- Parameters:
input_dim – int, input dimension
covariate_dim – int, covariate dimension
layer_dims – list of int, hidden layer dimensions
latent_dim – int, latent dimension
dropout_rate – float, dropout rate in MLP
- forward(z, c, dispersion_strategy='gene')
- Parameters:
z – torch.Tensor, latent variables (batch_size x latent_dim)
c – torch.Tensor, covariate data (batch_size x covariate_dim)
dispersion_strategy – str, how the dispersion factor is specified. Two options exist, gene-wise and gene-cell-wise, but only gene-wise is currently used.
- Returns:
rho – torch.Tensor, mean of the negative binomial distribution
dispersion – torch.Tensor, dispersion of the negative binomial distribution
pi – torch.Tensor, zero-inflation parameter
- class scDRP.model.Encoder(device, input_dim=3000, layer_dims=[500, 100], latent_dim=20, dropout_rate=0.2)
Bases: Module
Encoder network
- __init__(device, input_dim=3000, layer_dims=[500, 100], latent_dim=20, dropout_rate=0.2)
- Parameters:
input_dim – int, input dimension
layer_dims – list of int, hidden layer dimensions
latent_dim – int, latent dimension
dropout_rate – float, dropout rate in MLP
- forward(x)
- Parameters:
x – torch.Tensor, input data (batch_size x input_dim)
- Returns:
z – torch.Tensor, latent variable
mu – torch.Tensor, mean of the latent variable
logvar – torch.Tensor, log variance of the latent variable
- reparameterize(mu, logvar)
Reparameterization trick
- Parameters:
mu – torch.Tensor, mean of the latent variable
logvar – torch.Tensor, log variance of the latent variable
- Returns:
z – torch.Tensor, latent variable
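The reparameterization trick draws z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I), so that the sample is a deterministic function of mu and logvar plus external noise. The NumPy sketch below illustrates this; it is not the library's torch code, and the explicit `rng` argument is an addition for reproducibility.

```python
import numpy as np

def reparameterize(mu, logvar, rng):
    """Sample z ~ N(mu, exp(logvar)) as mu + std * eps, eps ~ N(0, 1)."""
    std = np.exp(0.5 * logvar)
    eps = rng.standard_normal(mu.shape)
    return mu + std * eps

rng = np.random.default_rng(0)
mu = np.full((100_000, 1), 2.0)
logvar = np.full((100_000, 1), np.log(4.0))  # variance 4, std 2
z = reparameterize(mu, logvar, rng)
```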
- class scDRP.model.HardConcreteGate(size, beta=0.6666666666666666, gamma=-0.1, zeta=1.1)
Bases: Module
Hard Concrete Gate for L0 regularization
- Parameters:
size – int, size of the gate
beta – float, temperature parameter. Default is 2/3
gamma – float, left stretch parameter. Default is -0.1
zeta – float, right stretch parameter. Default is 1.1
- __init__(size, beta=0.6666666666666666, gamma=-0.1, zeta=1.1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, training=True)
Forward pass with gating
- Parameters:
x – input tensor
training – whether the model is in training mode
- Returns:
output – gated output tensor
gate – sampled gate values
- regularization_loss()
Compute the expected L0 penalty
- Returns:
prob_nonzero – a tensor representing the expected L0 penalty (probability of being nonzero)
- sample_gate(training=True)
Sample gate values using Hard Concrete distribution
- Parameters:
training – bool, whether in training mode
- Returns:
gate – torch.Tensor, sampled gate values
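The gate sampling and the expected L0 penalty can be sketched with the hard concrete distribution of Louizos et al. (2018), using the same beta, gamma, and zeta defaults as above. The scalar parameter `log_alpha` is a hypothetical name for the learnable gate logit, and the evaluation-mode behaviour shown is the deterministic estimator from that paper; `HardConcreteGate` may differ in detail.

```python
import math
import random

def sigmoid(v):
    """Numerically stable logistic function."""
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    e = math.exp(v)
    return e / (1.0 + e)

def sample_hard_concrete(log_alpha, beta=2 / 3, gamma=-0.1, zeta=1.1,
                         training=True, rng=random):
    """Draw one gate value in [0, 1]."""
    if training:
        u = min(max(rng.random(), 1e-6), 1.0 - 1e-6)  # keep log() finite
        s = sigmoid((math.log(u) - math.log(1.0 - u) + log_alpha) / beta)
    else:
        s = sigmoid(log_alpha)  # deterministic gate at evaluation time
    s_bar = s * (zeta - gamma) + gamma   # stretch to (gamma, zeta)
    return min(1.0, max(0.0, s_bar))     # hard clip to [0, 1]

def prob_nonzero(log_alpha, beta=2 / 3, gamma=-0.1, zeta=1.1):
    """Probability that the gate is nonzero (expected L0 penalty per gate)."""
    return sigmoid(log_alpha - beta * math.log(-gamma / zeta))

rng = random.Random(0)
draws = [sample_hard_concrete(0.0, rng=rng) for _ in range(20000)]
frac_nonzero = sum(d > 0 for d in draws) / len(draws)
```

Because the stretched interval extends below 0 and above 1, the clipped gate takes the exact values 0 and 1 with positive probability, which is what allows exact sparsity.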
- class scDRP.model.MSEDecoder(device, input_dim=3000, covariate_dim=1, layer_dims=[500, 100], latent_dim=20, dropout_rate=0.2, distribution='Normal')
Bases: Module
Decoder network
- __init__(device, input_dim=3000, covariate_dim=1, layer_dims=[500, 100], latent_dim=20, dropout_rate=0.2, distribution='Normal')
- Parameters:
input_dim – int, input dimension
covariate_dim – int, covariate dimension
layer_dims – list of int, hidden layer dimensions
latent_dim – int, latent dimension
dropout_rate – float, dropout rate in MLP
- forward(z, c, dispersion_strategy='gene')
- Parameters:
z – torch.Tensor, latent variable (batch_size x latent_dim)
c – torch.Tensor, covariate data (batch_size x covariate_dim)
dispersion_strategy – str, how the dispersion factor is specified. Two options exist, gene-wise and gene-cell-wise, but only gene-wise is currently used.
- Returns:
rho – torch.Tensor, mean of the negative binomial distribution
dispersion – torch.Tensor, dispersion of the negative binomial distribution
pi – torch.Tensor, zero-inflation parameter
- class scDRP.model.NBDecoder(device, input_dim=3000, covariate_dim=1, layer_dims=[500, 100], latent_dim=20, dropout_rate=0.2, library_size_strategy='observed')
Bases: Module
Decoder network
- __init__(device, input_dim=3000, covariate_dim=1, layer_dims=[500, 100], latent_dim=20, dropout_rate=0.2, library_size_strategy='observed')
- Parameters:
input_dim – int, input dimension
covariate_dim – int, covariate dimension
layer_dims – list of int, hidden layer dimensions
latent_dim – int, latent dimension
dropout_rate – float, dropout rate in MLP
- forward(z, c, dispersion_strategy='gene')
- Parameters:
z – torch.Tensor, latent variables (batch_size x latent_dim)
c – torch.Tensor, covariate data (batch_size x covariate_dim)
dispersion_strategy – str, how the dispersion factor is specified. Two options exist, gene-wise and gene-cell-wise, but only gene-wise is currently used.
- Returns:
rho – torch.Tensor, mean of the negative binomial distribution
dispersion – torch.Tensor, dispersion of the negative binomial distribution
pi – torch.Tensor, zero-inflation parameter
- class scDRP.model.PerturbNet(device, input_dim, covariate_dim=0, celltype_num=0, perturbation_num=2, layer_dims=[500, 100], latent_dep_dim=50, latent_ind_dim=50, dropout_rate=0.2, lambda_sparse=0, l0_latent=0.001, beta=1, lambda_hsic=0.2, distribution='ZINB', encoder_covariates=False, library_size_strategy='observed', eps=1e-10)
Bases: Module
scPerturb model
- __init__(device, input_dim, covariate_dim=0, celltype_num=0, perturbation_num=2, layer_dims=[500, 100], latent_dep_dim=50, latent_ind_dim=50, dropout_rate=0.2, lambda_sparse=0, l0_latent=0.001, beta=1, lambda_hsic=0.2, distribution='ZINB', encoder_covariates=False, library_size_strategy='observed', eps=1e-10)
- Parameters:
input_dim – int, input dimension
covariate_dim – int, covariate dimension (default: 0)
celltype_num – int, number of cell types (default: 0)
perturbation_num – int, number of perturbations (default: 2)
layer_dims – list of int, hidden layer dimensions (default: [500, 100])
latent_dep_dim – int, latent dimension for dependent variable (default: 50)
latent_ind_dim – int, latent dimension for independent variable (default: 50)
dropout_rate – float, dropout rate in MLP (default: 0.2)
lambda_sparse – float, sparsity penalty (default: 0)
beta – float, KL divergence weight (default: 1)
encoder_covariates – boolean, whether to include covariates in encoders (default: False)
eps – float, small value to prevent numerical instability (default: 1e-10)
- forward(x, a, t, c, train=True)
- Parameters:
x – torch.Tensor, input data (batch_size, input_dim)
a – torch.Tensor, perturbation data (batch_size, perturbation_num)
t – torch.Tensor, cell type data (batch_size, celltype_num)
c – torch.Tensor, covariate data (batch_size, covariate_dim)
train – bool, whether the forward pass runs in training mode (default: True)
- Returns:
z_d – torch.Tensor, latent representation that depends on perturbation, (batch_size, latent_dep_dim)
z_u – torch.Tensor, latent representation that is independent from perturbation, (batch_size, latent_ind_dim)
mu_d – torch.Tensor, mean of latent representation that depends on perturbation, (batch_size, latent_dep_dim)
mu_u – torch.Tensor, mean of latent representation that is independent from perturbation, (batch_size, latent_ind_dim)
rho – torch.Tensor, mean of the negative binomial distribution, (batch_size, input_dim)
dispersion – torch.Tensor, dispersion of the negative binomial distribution, (input_dim,) when gene-wise
pi – torch.Tensor, zero-inflation parameter, (batch_size, input_dim)
s – torch.Tensor, library size, (batch_size, 1)
loss – torch.Tensor, total loss
loss_dict – dict, dictionary of losses
- reparameterize(mu, logvar)
Reparameterization trick
- Parameters:
mu – torch.Tensor, mean of the latent variable
logvar – torch.Tensor, log variance of the latent variable
- Returns:
z – torch.Tensor, latent variable
- sample_sequencing_depth(x, strategy='observed')
Sample sequencing depth
- Parameters:
x – torch.Tensor, observed data
strategy – str, strategy to sample sequencing depth. Two options exist, batch_sample and observed, but only observed is currently used.
- Returns:
s – torch.Tensor, library size
scDRP.module module
- class scDRP.module.DosageModel(models)
Bases: object
- __init__(models)
Initialize the DosageModel with given models for mu and sigma.
- Parameters:
models – a dictionary containing ‘mu’ and ‘sigma’ models.
- predict(v_new_log)
Predict the mu and sigma for new dosage values.
- Parameters:
v_new_log – a numpy array of log-transformed dosage values.
- Returns:
mu_pred – a numpy array of predicted means
sigma_pred – a numpy array of predicted standard deviations
- class scDRP.module.Perturb(adata, layer=None, perturbation_key='perturbation', celltype_key=None, batch_key=None, dose_key=None, distribution='ZINB')
Bases: object
- __init__(adata, layer=None, perturbation_key='perturbation', celltype_key=None, batch_key=None, dose_key=None, distribution='ZINB')
Initialize the Perturb object.
- Parameters:
adata – an AnnData object
layer – a string representing the layer name of count matrix in adata
perturbation_key – a string representing the key of perturbation in adata.obs
celltype_key – a string representing the key of cell type in adata.obs
batch_key – a string representing the key of batch in adata.obs
dose_key – a string representing the key of dose in adata.obs
distribution – a string representing the distribution of the data (default: “ZINB”). Options: “ZINB” (Zero-Inflated Negative Binomial), “NB” (Negative Binomial), “Normal” (Gaussian), or “Normal_positive” (Gaussian with positive output).
- _fit_dose_to_dist(v, zd, method='gpr')
Fit models to map dosage values to latent distributions.
- Parameters:
v – a numpy array representing the dosage values.
zd – a numpy array representing the latent embeddings.
method – a string representing the method to fit dosage to latent distribution. Options are ‘linear’, ‘spline’, or ‘gpr’. Defaults to ‘gpr’.
- Returns:
a DosageModel object containing the fitted models.
- _fit_gpr(doses, means, stds)
Fit Gaussian Process Regressors for the given doses, means, and standard deviations.
- Parameters:
doses – a numpy array representing the unique dosage values.
means – a numpy array representing the mean latent embeddings.
stds – a numpy array representing the standard deviation of latent embeddings.
- Returns:
a dictionary containing the fitted mu and sigma models.
- _fit_interpolator(doses, means, stds, method)
Fit interpolators for the given doses, means, and standard deviations.
- Parameters:
doses – a numpy array representing the unique dosage values.
means – a numpy array representing the mean latent embeddings.
stds – a numpy array representing the standard deviation of latent embeddings.
method – a string representing the interpolation method. Options are ‘linear’ or ‘spline’.
- Returns:
a dictionary containing the fitted mu and sigma models.
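One way to picture the ‘linear’ option is per-dimension linear interpolation of the per-dose latent mean and standard deviation over log-dose. The NumPy sketch below is an assumption about the general idea, not the library's code: the natural-log transform, the closure returned by the hypothetical `fit_linear_dose_model`, and the use of `np.interp` are all illustrative choices.

```python
import numpy as np

def fit_linear_dose_model(v, zd):
    """Fit per-dimension linear interpolators of latent mean/std over log-dose."""
    doses = np.unique(v)
    log_doses = np.log(doses)  # assumes strictly positive doses
    means = np.stack([zd[v == d].mean(axis=0) for d in doses])
    stds = np.stack([zd[v == d].std(axis=0) for d in doses])

    def predict(v_new_log):
        v_new_log = np.atleast_1d(v_new_log)
        dims = range(means.shape[1])
        mu = np.stack([np.interp(v_new_log, log_doses, means[:, k]) for k in dims], axis=1)
        sigma = np.stack([np.interp(v_new_log, log_doses, stds[:, k]) for k in dims], axis=1)
        return mu, sigma

    return predict

v = np.array([1.0, 1.0, 10.0, 10.0, 100.0, 100.0])
zd = np.array([[0, 0], [2, 0], [3, 1], [5, 1], [7, 2], [9, 2]], dtype=float)
predict = fit_linear_dose_model(v, zd)
mu_mid, sigma_mid = predict(0.5 * np.log(10.0))  # halfway between doses 1 and 10
```

Note that `np.interp` holds the boundary values outside the observed dose range, so this particular sketch does not extrapolate; a Gaussian process or spline fit behaves differently there.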
- _generate_from_latent(latent, covariates=None, library_size=None, n_samples=1)
Generate samples from latent embeddings
- Parameters:
latent – torch Tensor containing latent factors, (sample_size, latent_dimensions)
covariates – torch Tensor containing one-hot encoded covariates, (sample_size, covariate_dimensions). Default: None
library_size – torch Tensor containing library sizes for the newly generated samples, (sample_size, 1). The inferred library size of the original adata is used if None
n_samples – an integer representing the number of samples. Note: if latent is an n × p matrix, n_samples × n samples are generated in total.
- Returns:
a numpy array representing the generated samples
- _generate_latent(perturbations, celltype, batch_size=32)
Generate latent embeddings from perturbations and cell types.
- Parameters:
perturbations – a numpy array representing one-hot encoded perturbations
celltype – a numpy array representing one-hot encoded cell types
batch_size – an integer representing the batch size for DataLoader
- Returns:
a numpy array representing the generated latent embeddings
- _representative_latent(mu, logvar)
Generate representative latent embeddings by sampling from the learned distribution.
- Parameters:
mu – a numpy array representing the mean of the latent distribution
logvar – a numpy array representing the log variance of the latent distribution
- Returns:
a numpy array representing the sampled latent embeddings
- _sample_from_parameter(n_samples, inference=False)
Generate samples from inferred ZINB parameters.
- Parameters:
n_samples – an integer representing the number of samples
inference – a boolean representing whether to reconduct the inference process each time or simply use inferred parameters stored in the object
- Returns:
a numpy array representing the samples
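Sampling counts from ZINB parameters can be sketched as a Gamma-Poisson mixture followed by a dropout mask. The NumPy code below assumes `dispersion` is the inverse-dispersion (size) parameter and that the mean is rho * s; `sample_zinb` is a hypothetical helper, not the method documented above.

```python
import numpy as np

def sample_zinb(rho, dispersion, pi, s, rng):
    """Draw ZINB counts: Gamma-Poisson mixture, then zero out dropouts."""
    mean = rho * s                                   # (n, genes)
    rate = rng.gamma(shape=dispersion, scale=mean / dispersion)
    counts = rng.poisson(rate)                       # negative binomial draws
    dropout = rng.random(counts.shape) < pi          # structural zeros
    return np.where(dropout, 0, counts)

rng = np.random.default_rng(0)
n = 20000
rho = np.full((n, 2), 0.01)
s = np.full((n, 1), 500.0)          # library size, so the NB mean is 5
dispersion = np.array([2.0, 5.0])   # gene-wise
pi = np.full((n, 2), 0.3)
x = sample_zinb(rho, dispersion, pi, s, rng)
```

With these toy values the expected count is (1 - 0.3) * 5 = 3.5 per gene, and the fraction of zeros exceeds 0.3 because the negative binomial contributes zeros of its own.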
- counterfactual_samples(control, treatment, dose=None, strategy='ot', alpha=1, beta=1, projection_strategy='full', value='raw', method='sinkhorn', reg=0.01, reg_m=1.0)
Estimate the Individual Treatment Effect (ITE) for a pair of control and treatment.
- Parameters:
control – a string representing the name of control group (should be a value in adata.obs[perturbation_key]).
treatment – a string representing the name of treatment group (should be a value in adata.obs[perturbation_key]).
dose – a float representing the dose level (should be a value in adata.obs[dose_key]). If None, all doses will be considered.
strategy – a string representing the counterfactual generation strategy. Options: - “ot”: Optimal Transport - “average”: Average Effect Addition
alpha – a float representing the weight for zu embeddings in OT calculation. Default is 1.
beta – a float representing the weight for zd embeddings in OT calculation. Default is 1.
projection_strategy – a string representing the projection strategy for OT calculation. Options: - “full”: use full embeddings for OT calculation - “zd_only”: project zd embeddings while keeping zu embeddings unchanged
value – a string representing the value type for OT cost calculation. Options: - “raw”: use raw embeddings for OT cost calculation - “quantile”: use quantile normalized embeddings for OT cost calculation
method (str, optional) – Optimal transport method. Options: - “emd”: Exact Optimal Transport - “sinkhorn”: Sinkhorn Regularized OT - “unbalanced_sinkhorn”: Unbalanced Sinkhorn Regularized OT Defaults to “sinkhorn”.
reg (float, optional) – Entropic regularization parameter for Sinkhorn. Default is 0.01. Only useful when specifying method as “sinkhorn” or “unbalanced_sinkhorn”.
reg_m (float, optional) – Marginal relaxation parameter (higher allows more mass deviation). Default is 1.0. Only useful when specifying method as “unbalanced_sinkhorn”.
- Returns:
a numpy array representing the ITE
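To illustrate the “ot” strategy with the “sinkhorn” method, the sketch below computes an entropic OT plan between control and treated embeddings with plain Sinkhorn iterations and maps each control cell to the barycenter of its matched treated cells. This is a generic NumPy illustration under assumptions: uniform marginals, a squared Euclidean cost scaled to [0, 1], and barycentric projection as the mapping step. The helper names are hypothetical, and the toy uses reg=0.2 so that the unstabilized iterations converge quickly.

```python
import numpy as np

def sinkhorn_plan(cost, reg=0.2, n_iter=2000):
    """Entropic OT plan between uniform marginals via Sinkhorn iterations."""
    n, m = cost.shape
    a = np.full(n, 1.0 / n)
    b = np.full(m, 1.0 / m)
    K = np.exp(-cost / reg)
    u = np.ones(n)
    for _ in range(n_iter):
        v = b / (K.T @ u)   # match column marginals
        u = a / (K @ v)     # match row marginals
    return u[:, None] * K * v[None, :]

def barycentric_counterfactual(z_control, z_treated, reg=0.2):
    """Map each control embedding to the plan-weighted mean of treated ones."""
    diff = z_control[:, None, :] - z_treated[None, :, :]
    cost = np.sum(diff**2, axis=-1)
    cost = cost / cost.max()  # scale costs to [0, 1]
    plan = sinkhorn_plan(cost, reg)
    z_cf = (plan @ z_treated) / plan.sum(axis=1, keepdims=True)
    return z_cf, plan

rng = np.random.default_rng(0)
z_control = rng.normal(size=(30, 2))
z_treated = rng.normal(size=(40, 2)) + np.array([3.0, 0.0])
z_cf, plan = barycentric_counterfactual(z_control, z_treated)
```

Smaller `reg` values give sharper, less blurred matchings but usually need log-domain stabilization, which this sketch omits.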
- create_dataset(adata)
Create a dataset from the AnnData object.
- Parameters:
adata – an AnnData object.
- Returns:
a CombinedDataset object.
- dosage_extrapolate(control, treatment, dose_value=None, method='gpr')
Extrapolate latent embeddings for control group to new dosage levels based on treatment group data.
- Parameters:
control – a string representing the name of control group (should be a value in adata.obs[perturbation_key]).
treatment – a string representing the name of treatment group (should be a value in adata.obs[perturbation_key]).
dose_value – a scalar or array-like representing the dosage value(s) for extrapolation.
method – a string representing the method to fit dosage to latent distribution. Options are ‘linear’, ‘spline’, or ‘gpr’. Defaults to ‘gpr’.
- Returns:
a numpy array or a list of numpy arrays representing the counterfactual mean expressions at the specified dosage level(s).
- effect_estimate(control, treatment, dose=None, strategy='ot', alpha=1, beta=1, projection_strategy='full', value='raw', method='emd', reg=0.01, reg_m=1.0)
Estimate the Individual Treatment Effect (ITE) for a pair of control and treatment.
- Parameters:
control – a string representing the name of control group (should be a value in adata.obs[perturbation_key]).
treatment – a string representing the name of treatment group (should be a value in adata.obs[perturbation_key]).
method (str, optional) – Optimal transport method. Options: - “emd”: Exact Optimal Transport - “sinkhorn”: Sinkhorn Regularized OT - “unbalanced_sinkhorn”: Unbalanced Sinkhorn Regularized OT Defaults to “emd”.
reg (float, optional) – Entropic regularization parameter for Sinkhorn. Default is 0.01. Only useful when specifying method as “sinkhorn” or “unbalanced_sinkhorn”.
reg_m (float, optional) – Marginal relaxation parameter (higher allows more mass deviation). Default is 1.0. Only useful when specifying method as “unbalanced_sinkhorn”.
- Returns:
a numpy array representing the ITE
- get_adata()
Get the adata with latent variables and estimated generative parameters.
- Returns:
an AnnData object with latent embeddings and estimated ZINB generative parameters.
- get_counterfactual_adata(control, treatment, dose=None, covariates=None, strategy='ot', alpha=1, beta=1, projection_strategy='full', value='raw', method='sinkhorn', reg=0.01, reg_m=1.0)
Generate counterfactual AnnData for a pair of control and treatment.
- Parameters:
control – a string representing the name of control group (should be a value in adata.obs[perturbation_key]).
treatment – a string representing the name of treatment group (should be a value in adata.obs[perturbation_key]).
dose – a float representing the dose level (should be a value in adata.obs[dose_key]). If None, all doses will be considered.
covariates – a numpy array representing the one-hot encoded covariates for control group. If None, inferred covariates will be used.
strategy – a string representing the counterfactual generation strategy. Options: - “ot”: Optimal Transport - “average”: Average Effect Addition
alpha – a float representing the weight for zu embeddings in OT calculation. Default is 1.
beta – a float representing the weight for zd embeddings in OT calculation. Default is 1.
projection_strategy – a string representing the projection strategy for OT calculation. Options: - “full”: use full embeddings for OT calculation - “zd_only”: project zd embeddings while keeping zu embeddings unchanged
value – a string representing the value type for OT cost calculation. Options: - “raw”: use raw embeddings for OT cost calculation - “quantile”: use quantile normalized embeddings for OT cost calculation
method (str, optional) – Optimal transport method. Options: - “emd”: Exact Optimal Transport - “sinkhorn”: Sinkhorn Regularized OT - “unbalanced_sinkhorn”: Unbalanced Sinkhorn Regularized OT Defaults to “sinkhorn”.
reg (float, optional) – Entropic regularization parameter for Sinkhorn. Default is 0.01. Only useful when specifying method as “sinkhorn” or “unbalanced_sinkhorn”.
reg_m (float, optional) – Marginal relaxation parameter (higher allows more mass deviation). Default is 1.0. Only useful when specifying method as “unbalanced_sinkhorn”.
- Returns:
an AnnData object representing the counterfactual samples
- get_counterfactual_latent(control, treatment, dose=None, strategy='ot', alpha=1, beta=1, projection_strategy='full', value='raw', method='emd', reg=0.01, reg_m=1.0)
Generate counterfactual latent embeddings for a pair of control and treatment.
- Parameters:
control – a string representing the name of control group (should be a value in adata.obs[perturbation_key]).
treatment – a string representing the name of treatment group (should be a value in adata.obs[perturbation_key]).
dose – a float representing the dose level (should be a value in adata.obs[dose_key]). If None, all doses will be considered.
strategy – a string representing the strategy to generate counterfactual latent embeddings. Options: - “ot”: use optimal transport to estimate the latent shift - “average”: use average effect to estimate the latent shift
alpha – a float number representing the weight for zu in optimal transport. Default is 1.
beta – a float number representing the weight for zd in optimal transport. Default is 1.
projection_strategy – a string representing the projection strategy in optimal transport. Options: - “full”: use full latent embeddings for optimal transport - “zd_only”: transport only on dependent latent embeddings with independent latent embeddings fixed
value – a string representing the value type in optimal transport. Options: - “raw”: use raw latent embeddings for optimal transport - “quantile”: use quantile transformed latent embeddings for optimal transport
method – a string representing the optimal transport method. Options: - “emd”: Exact Optimal Transport - “sinkhorn”: Sinkhorn Regularized OT - “unbalanced_sinkhorn”: Unbalanced Sinkhorn Regularized OT Defaults to “emd”.
reg – a float number representing the entropic regularization parameter for Sinkhorn. Default is 0.01. Only useful when specifying method as “sinkhorn” or “unbalanced_sinkhorn”.
reg_m – a float number representing the marginal relaxation parameter (higher allows more mass deviation). Default is 1.0. Only useful when specifying method as “unbalanced_sinkhorn”.
- Returns:
a tuple of numpy arrays representing the counterfactual latent embeddings (z,zd,zu)
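The “average” strategy can be pictured as shifting every control embedding by the difference between the treated and control group means. The NumPy sketch below applies that shift to the perturbation-dependent block zd only and leaves zu untouched; whether the library shifts zd only, and how it assembles the combined z, are not stated above, so both are assumptions and `average_shift` is a hypothetical helper.

```python
import numpy as np

def average_shift(zd_control, zd_treated):
    """Shift control embeddings by the difference of group means."""
    shift = zd_treated.mean(axis=0) - zd_control.mean(axis=0)
    return zd_control + shift

rng = np.random.default_rng(0)
zd_control = rng.normal(size=(50, 3))
zd_treated = rng.normal(size=(80, 3)) + 2.0
zu_control = rng.normal(size=(50, 4))   # perturbation-independent part

zd_cf = average_shift(zd_control, zd_treated)
zu_cf = zu_control                      # unchanged under this assumption
```

Unlike the OT strategy, this preserves the within-group spread of the control cells exactly and only moves their mean.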
- get_latent()
Return the disentangled latent embeddings.
- get_parameter_from_latent(latent, covariates=None, batch_size=32)
Generate generative model (ZINB) parameters from latent embeddings
- Parameters:
latent – torch Tensor containing latent factors, (batch_size, latent_dimensions)
covariates – torch Tensor containing one-hot encoded covariates, (batch_size, covariate_dimensions). Default: None
- Returns:
a tuple of the ZINB parameters rho, dispersion, and pi of the generative model
- inference(n_samples=1, dataset=None, batch_size=None, update=False, returns=False)
Perform inference.
- Parameters:
n_samples – an integer representing the number of samples repeated for inference process
dataset – a CombinedDataset object. By default we use the training dataset
batch_size – an integer representing the batch size
update – a boolean representing whether to update the adata
returns – a boolean representing whether to return the results
- Returns:
a tuple of numpy arrays representing the latent variables (z_d,z_u,mu_d,mu_u,rho,dispersion,dropout_rate,library_size)
- joint_effect_estimate(control, treatment1, treatment2, dose1=None, dose2=None, strategy='ot', alpha=1, beta=1, projection_strategy='full', value='raw', method='sinkhorn', reg=0.01, reg_m=1.0)
Estimate the joint effect for a pair of treatments compared to a control.
- Parameters:
control – a string representing the name of control group (should be a value in adata.obs[perturbation_key]).
treatment1 – a string representing the name of first treatment group (should be a value in adata.obs[perturbation_key]).
treatment2 – a string representing the name of second treatment group (should be a value in adata.obs[perturbation_key]).
dose1 – a float representing the dose level for first treatment (should be a value in adata.obs[dose_key]). If None, all doses will be considered.
dose2 – a float representing the dose level for second treatment (should be a value in adata.obs[dose_key]). If None, all doses will be considered.
strategy – a string representing the counterfactual generation strategy. Options: - “ot”: Optimal Transport - “average”: Average Effect Addition
alpha – a float representing the weight for zu embeddings in OT calculation. Default is 1.
beta – a float representing the weight for zd embeddings in OT calculation. Default is 1.
projection_strategy – a string representing the projection strategy for OT calculation. Options: - “full”: use full embeddings for OT calculation - “zd_only”: project zd embeddings while keep zu embeddings unchanged
value – a string representing the value type for OT cost calculation. Options: - “raw”: use raw embeddings for OT cost calculation - “quantile”: use quantile normalized embeddings for OT cost calculation
method (str, optional) – Optimal transport method. Options: - “emd”: Exact Optimal Transport - “sinkhorn”: Sinkhorn Regularized OT - “unbalanced_sinkhorn”: Unbalanced Sinkhorn Regularized OT Defaults to “sinkhorn”.
reg (float, optional) – Entropic regularization parameter for Sinkhorn. Default is 0.01. Only useful when specifying method as “sinkhorn” or “unbalanced_sinkhorn”.
reg_m (float, optional) – Marginal relaxation parameter (higher allows more mass deviation). Default is 1.0. Only useful when specifying method as “unbalanced_sinkhorn”.
- Returns:
a numpy array representing the joint effect
- latent_adaption(control_zd, treatment_zd, control_zu, treatment_zu, control_celltype, treatment_celltype, strategy='ot', alpha=1, beta=1, projection_strategy='full', value='raw', method='emd', reg=0.01, reg_m=1.0)
Generate counterfactual latent embeddings for a pair of control and treatment.
- Parameters:
control_zd – a numpy array representing the dependent latent embeddings of control group
treatment_zd – a numpy array representing the dependent latent embeddings of treatment group
control_zu – a numpy array representing the independent latent embeddings of control group
treatment_zu – a numpy array representing the independent latent embeddings of treatment group
control_celltype – a numpy array representing the one-hot encoded cell types of control group
treatment_celltype – a numpy array representing the one-hot encoded cell types of treatment group
strategy – a string representing the strategy to generate counterfactual latent embeddings. Options: - “ot”: use optimal transport to estimate the latent shift - “average”: use average effect to estimate the latent shift
alpha – a float number representing the weight for zu in optimal transport. Default is 1.
beta – a float number representing the weight for zd in optimal transport. Default is 1.
projection_strategy – a string representing the projection strategy in optimal transport. Options: - “full”: use full latent embeddings for optimal transport - “zd_only”: transport only on dependent latent embeddings with independent latent embeddings fixed
value – a string representing the value type in optimal transport. Options: - “raw”: use raw latent embeddings for optimal transport - “quantile”: use quantile transformed latent embeddings for optimal transport
method – a string representing the optimal transport method. Options: - “emd”: Exact Optimal Transport - “sinkhorn”: Sinkhorn Regularized OT - “unbalanced_sinkhorn”: Unbalanced Sinkhorn Regularized OT Defaults to “emd”.
reg – a float number representing the entropic regularization parameter for Sinkhorn. Default is 0.01. Only useful when specifying method as “sinkhorn” or “unbalanced_sinkhorn”.
reg_m – a float number representing the marginal relaxation parameter (higher allows more mass deviation). Default is 1.0. Only useful when specifying method as “unbalanced_sinkhorn”.
- Returns:
counterfactual_z: a numpy array representing the counterfactual latent embeddings
counterfactual_zd: a numpy array representing the counterfactual dependent latent embeddings
counterfactual_zu: a numpy array representing the counterfactual independent latent embeddings
- load(path, device=None)
Load the trained model.
- Parameters:
path – a string representing the path to the model checkpoint.
device – the device to load the model onto. Default is None, which uses self.device.
- sample_posterior(n_samples=1, mu=None, theta=None, pi=None)
Sample from the posterior distribution.
- Parameters:
n_samples – an integer representing the number of samples (Note: if mu is an n*p matrix, then n_samples*n samples will be generated in total!)
mu – a numpy array representing the mean of the negative binomial distribution.
theta – a numpy array representing the dispersion of the negative binomial distribution
pi – a numpy array representing the dropout rate
- Returns:
a numpy array representing the samples
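The note on n_samples can be illustrated with a minimal NumPy sketch of zero-inflated negative binomial sampling. The helper name sample_zinb, the seed argument, and the reading of theta as an inverse dispersion are assumptions made for illustration; this is not scDRP's implementation.

```python
import numpy as np

def sample_zinb(mu, theta, pi, n_samples=1, seed=0):
    """Draw ZINB samples (hypothetical helper, not part of scDRP)."""
    rng = np.random.default_rng(seed)
    mu = np.asarray(mu, dtype=float)
    theta = np.broadcast_to(np.asarray(theta, dtype=float), mu.shape)
    pi = np.broadcast_to(np.asarray(pi, dtype=float), mu.shape)
    draws = []
    for _ in range(n_samples):
        # Negative binomial as a gamma-Poisson mixture with mean mu
        # (assumption: theta acts as an inverse dispersion).
        rate = rng.gamma(shape=theta, scale=mu / theta)
        counts = rng.poisson(rate)
        # Zero inflation: each entry is set to zero with probability pi.
        keep = rng.random(mu.shape) >= pi
        draws.append(counts * keep)
    # Stacking along rows turns an (n, p) input into (n_samples * n, p).
    return np.concatenate(draws, axis=0)

mu = np.full((4, 3), 5.0)
samples = sample_zinb(mu, theta=2.0, pi=0.3, n_samples=2)
print(samples.shape)  # (8, 3)
```

With a 4*3 mean matrix and n_samples=2, the output has 8 rows, matching the n_samples*n behaviour described above.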
- save(path)
Save the trained model.
- Parameters:
path – a string representing the path to the model checkpoint.
- setup(hidden_layers=[128, 128], latent_dependent=50, latent_independent=50, beta=1, sparse_coef=0, l0_latent=0.001, lambda_hsic=0.2, library_size_strategy='observed', device=None)
Setup the model for training.
- Parameters:
hidden_layers – a list of integers representing the number of neurons in each hidden layer
latent_dependent – an integer representing the dimensions of z_D
latent_independent – an integer representing the dimensions of z_I
beta – a float number representing the weight of the KL divergence term
sparse_coef – a float number representing the weight of the sparsity regularization on the Jacobian matrix
l0_latent – a float number representing the weight of the L0 regularization on the dimensions of latent variables
lambda_hsic – a float number representing the weight of the HSIC regularization
library_size_strategy – a string representing the library size normalization strategy. Options (default: “observed”): - “batch_sample”: sample from batch empirical distribution - “observed”: use the observed library size - “original”: set the library size as 1
device – the device to run the model on. Default is None, which uses ‘cuda’ if available, else ‘cpu’.
- train(epoch_num=200, batch_size=64, lr=1e-06, accumulation_steps=1, adaptlr=True, valid_prop=0, early_stopping=False, patience=10, tensorboard=False, savepath='./')
Train the model.
- Parameters:
epoch_num – an integer representing the number of epochs
batch_size – an integer representing the batch size
lr – a float number representing the learning rate
accumulation_steps – an integer representing the number of steps for gradient accumulation
adaptlr – a boolean representing whether to use adaptive learning rate
valid_prop – a float number representing the proportion of the dataset to use for validation
early_stopping – a boolean representing whether to use early stopping
patience – an integer representing the number of epochs to wait for improvement before stopping
tensorboard – a boolean representing whether to use tensorboard
savepath – a string representing the path to save the tensorboard logs
scDRP.train module
- class scDRP.train.EarlyStopping(patience=25, delta=0.0, verbose=False)
Bases:
object
Stops training early if the validation loss does not improve within the given patience.
- __init__(patience=25, delta=0.0, verbose=False)
Initialize the EarlyStopping object.
- Parameters:
patience – number of epochs with no improvement after which training will be stopped
delta – minimum change in the monitored quantity to qualify as an improvement
verbose – whether to print messages when stopping
- save_checkpoint(model, path='best_model.pt')
Saves the model when the validation loss decreases.
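How patience and delta interact can be sketched in plain Python. The class PatienceTracker and its step method are hypothetical names used only for illustration; the actual EarlyStopping class may track its state differently and also writes checkpoints.

```python
class PatienceTracker:
    """Hypothetical stand-in for an early-stopping helper (not scDRP's class)."""

    def __init__(self, patience=25, delta=0.0):
        self.patience = patience
        self.delta = delta
        self.best_loss = float("inf")
        self.counter = 0

    def step(self, val_loss):
        """Return True when training should stop."""
        if val_loss < self.best_loss - self.delta:
            # Improvement larger than delta: record it and reset the counter.
            self.best_loss = val_loss
            self.counter = 0
        else:
            # No sufficient improvement this epoch.
            self.counter += 1
        return self.counter >= self.patience


tracker = PatienceTracker(patience=3, delta=0.0)
losses = [1.0, 0.8, 0.81, 0.79, 0.80, 0.80, 0.80, 0.70]
stopped_at = None
for epoch, loss in enumerate(losses):
    if tracker.step(loss):
        stopped_at = epoch
        break
print(stopped_at)  # 6
```

In this run the best loss (0.79) is reached at epoch 3, and three consecutive epochs without improvement stop training at epoch 6, so the final value 0.70 is never seen.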
- scDRP.train.inference_model(device, infer_dataset, model, batch_size, count_data=False)
Do inference using the trained model
- Parameters:
device – device to run the model
infer_dataset – inference dataset
model – trained model used for inference
batch_size – batch size
count_data – whether the data is count data
- Returns:
zd: latent representation that depends on perturbation
zu: latent representation that is independent from perturbation
mu_d: mean of latent representation that depends on perturbation
mu_u: mean of latent representation that is independent from perturbation
rho: mean expression level
dispersion: dispersion parameter (only for count data)
pi: zero-inflation parameter (only for count data)
library_size: size of the library (only for count data)
- scDRP.train.train_model(device, writer, train_dataset, valid_dataset, model, epoch_num, batch_size, num_batch, lr=1e-06, accumulation_steps=1, adaptlr=True, count_data=False, early_stopping=True, patience=25)
Train the model
- Parameters:
device – device to run the model
writer – tensorboard writer
train_dataset – training dataset
valid_dataset – validation dataset
model – model to train
epoch_num – number of epochs
batch_size – batch size
num_batch – number of batches
lr – learning rate
accumulation_steps – number of steps to accumulate gradients
adaptlr – whether to adapt learning rate
count_data – whether the data is count data
early_stopping – whether to use early stopping
patience – patience for early stopping
- scDRP.train.validate_model(device, validate_dataset, model, batch_size, count_data)
Validate the model
- Parameters:
device – device to run the model
validate_dataset – validation dataset
model – model to validate
batch_size – batch size
count_data – whether the data is count data
scDRP.utils module
- scDRP.utils.add_effect(control_zd, treatment_zd, control_label, treatment_label)
Computes counterfactual latent representations by adding the average treatment effect within each group.
- Parameters:
control_zd (np.ndarray) – Source matrix of shape (N, D_d).
treatment_zd (np.ndarray) – Target matrix of shape (M, D_d).
control_label (np.ndarray) – Group labels of shape (N,), indicating which group each row in control_zd belongs to.
treatment_label (np.ndarray) – Group labels of shape (M,), indicating which group each row in treatment_zd belongs to.
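The idea of adding a per-group average effect can be sketched as follows. This is a minimal NumPy illustration under assumptions: the function name add_group_mean_shift is hypothetical, labels are taken to be 1-D arrays, and groups without treated rows are left unchanged, which may differ from what add_effect does.

```python
import numpy as np

def add_group_mean_shift(control_zd, treatment_zd, control_label, treatment_label):
    """Shift each control row by its group's mean treatment-minus-control difference."""
    counterfactual = control_zd.astype(float)  # astype returns a copy
    for group in np.unique(control_label):
        ctrl_rows = control_label == group
        trt_rows = treatment_label == group
        if not trt_rows.any():
            # Assumption: with no treated cells in a group, keep controls as they are.
            continue
        shift = treatment_zd[trt_rows].mean(axis=0) - control_zd[ctrl_rows].mean(axis=0)
        counterfactual[ctrl_rows] += shift
    return counterfactual

control_zd = np.array([[0.0, 0.0], [2.0, 2.0], [10.0, 10.0]])
treatment_zd = np.array([[3.0, 1.0], [11.0, 12.0]])
control_label = np.array([0, 0, 1])
treatment_label = np.array([0, 1])
cf = add_group_mean_shift(control_zd, treatment_zd, control_label, treatment_label)
print(cf)
```

Group 0 is shifted by (2, 0) and group 1 by (1, 2), so within-group variation among control cells is preserved while the group means move onto the treated means.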
- scDRP.utils.condition_quantile(X, y=None)
Calculate the quantile of each column in X conditioned on y.
- Parameters:
X – a numpy array
y – a numpy array
- Returns:
a numpy array representing the quantile of each column in X conditioned on y
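One plausible reading of a conditional quantile is the empirical rank of each value within its group, divided by the group size. The sketch below assumes that reading, assumes y is a 1-D label array, and uses the hypothetical name empirical_quantile_by_group; the exact convention used by condition_quantile (tie handling, offsets, one-hot y) is not stated in this documentation.

```python
import numpy as np

def empirical_quantile_by_group(X, y=None):
    """Per-column empirical quantiles computed separately within each group of y."""
    X = np.asarray(X, dtype=float)
    if y is None:
        # Without labels, treat all rows as a single group.
        y = np.zeros(X.shape[0], dtype=int)
    y = np.asarray(y)
    Q = np.empty_like(X)
    for group in np.unique(y):
        rows = np.flatnonzero(y == group)
        # A double argsort yields 0-based ranks per column (ties broken by order).
        ranks = X[rows].argsort(axis=0).argsort(axis=0)
        Q[rows] = (ranks + 1) / len(rows)
    return Q

X = np.array([[1.0], [3.0], [2.0], [10.0], [5.0]])
y = np.array([0, 0, 0, 1, 1])
Q = empirical_quantile_by_group(X, y)
print(Q.ravel())
```

Here the value 10.0 receives quantile 1.0 because it is the largest in its own group, even though the groups differ in scale.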
- scDRP.utils.conditional_OT_latent(control_zd, treatment_zd, control_zu, treatment_zu, control_label, treatment_label, alpha=1, beta=1, projection_strategy='full', value='raw', method='emd', reg=0.1, reg_m=1.0, eps=1e-08)
Computes a global optimal transport matching matrix (quantile matching of zd and value matching of zu), ensuring that matches occur only within groups.
- Parameters:
control_zd (np.ndarray) – Source matrix of shape (N, D_d).
control_zu (np.ndarray) – Source matrix of shape (N, D_i).
treatment_zd (np.ndarray) – Target matrix of shape (M, D_d).
treatment_zu (np.ndarray) – Target matrix of shape (M, D_i).
control_label (np.ndarray) – Group labels of shape (N,), indicating which group each row of the control matrices belongs to.
treatment_label (np.ndarray) – Group labels of shape (M,), indicating which group each row of the treatment matrices belongs to.
alpha (float, optional) – Weight for the distance in zd space. Default is 1.
beta (float, optional) – Weight for the distance in zu space. Default is 1.
projection_strategy (str, optional) – Strategy for constructing counterfactual latent representations. Options: - “full”: Use both zd and zu for counterfactual construction. - “zd_only”: Use only zd for counterfactual construction, keeping zu unchanged. Defaults to “full”.
value (str, optional) – Value type used for the OT cost. Options: - “raw”: use raw embeddings - “quantile”: use quantile-transformed embeddings Defaults to “raw”.
method (str, optional) – Optimal transport method. Options: - “emd”: Exact Optimal Transport - “sinkhorn”: Sinkhorn Regularized OT - “unbalanced_sinkhorn”: Unbalanced Sinkhorn Regularized OT Defaults to “emd”.
reg (float, optional) – Entropic regularization parameter for Sinkhorn. Default is 0.1. Only useful when specifying method as “sinkhorn” or “unbalanced_sinkhorn”.
reg_m (float, optional) – Marginal relaxation parameter (higher allows more mass deviation). Default is 1.0. Only useful when specifying method as “unbalanced_sinkhorn”.
- Returns:
A global matching matrix (N, M), where matches are restricted within the same group.
- Return type:
np.ndarray
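Restricting matches to the same group can be illustrated by building a weighted cost matrix and penalizing cross-group pairs before handing it to an OT solver. The sketch below is an assumption-laden illustration: within_group_cost is a hypothetical name, squared Euclidean distance and the large-penalty masking are choices made here, and the library may instead solve one transport problem per group.

```python
import numpy as np

def within_group_cost(control_zd, treatment_zd, control_zu, treatment_zu,
                      control_label, treatment_label, alpha=1.0, beta=1.0):
    """Weighted (N, M) cost matrix with cross-group pairs heavily penalized."""
    def sq_dist(A, B):
        # Pairwise squared Euclidean distances between rows of A and rows of B.
        return ((A[:, None, :] - B[None, :, :]) ** 2).sum(axis=-1)

    cost = alpha * sq_dist(control_zd, treatment_zd) + beta * sq_dist(control_zu, treatment_zu)
    same_group = control_label[:, None] == treatment_label[None, :]
    # A penalty far above any real cost keeps a solver away from cross-group pairs.
    penalty = cost.max() * 1e6 + 1.0
    return np.where(same_group, cost, penalty), same_group

control_zd = np.array([[0.0], [1.0]])
treatment_zd = np.array([[0.0], [3.0]])
control_zu = np.array([[0.0], [0.0]])
treatment_zu = np.array([[1.0], [0.0]])
control_label = np.array([0, 1])
treatment_label = np.array([0, 1])
masked_cost, same_group = within_group_cost(
    control_zd, treatment_zd, control_zu, treatment_zu, control_label, treatment_label
)
print(masked_cost)
```

The resulting matrix keeps the true costs on same-group entries (1 and 4 here) and the penalty elsewhere, so any transport plan computed from it concentrates its mass within groups.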
- scDRP.utils.to_dense_array(x)
Transform a potentially sparse array into a dense numpy array.
- Parameters:
x – input array, can be sparse or numpy array
- Returns:
a numpy array
Module contents
A package to learn disentangled representations and estimate individual treatment effects in single-cell perturbation data.