Negative Binomial Modeling for scRNA-seq
The negative binomial (NB) distribution is the canonical statistical model for single-cell RNA-seq (scRNA-seq) gene expression counts. It is favored over the Poisson distribution because real scRNA-seq data are overdispersed: their variance greatly exceeds their mean, owing to both biological and technical factors (cell heterogeneity, bursty transcription, varying sequencing depth, dropouts, etc.).
The Poisson's rigid assumption that the variance equals the mean fails in this context. The NB instead augments the Poisson with a dispersion parameter $\theta$, giving variance $\sigma^2 = \mu + \mu^2 / \theta$ for mean $\mu$. This flexibility means two genes with equal mean expression but different variances (a common situation) can both be modeled simply by adjusting $\theta$.
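As a quick illustration of the variance formula (numbers chosen purely for this example): a gene with mean $\mu = 10$ has Poisson variance $10$, whereas under the NB with $\theta = 2$,

$$
\sigma^2 = \mu + \frac{\mu^2}{\theta} = 10 + \frac{100}{2} = 60,
$$

six times the Poisson value; as $\theta \to \infty$ the NB collapses back to the Poisson.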
From a generative perspective, the NB arises as a Poisson–Gamma mixture: each cell/gene’s unknown expression rate $\lambda$ is drawn from a Gamma distribution before Poisson sampling.
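A minimal NumPy simulation (toy values for $\mu$ and $\theta$, not tied to any real dataset) makes this mixture construction concrete: drawing the rate from a Gamma distribution with mean $\mu$ and then Poisson-sampling reproduces the NB mean and variance.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, theta = 10.0, 2.0                     # illustrative mean and dispersion

# lambda ~ Gamma(shape=theta, scale=mu/theta): E[lambda] = mu, Var[lambda] = mu**2 / theta
lam = rng.gamma(theta, mu / theta, size=200_000)
counts = rng.poisson(lam)                 # Poisson draw given each sampled rate

print(counts.mean())                      # ~10  (= mu)
print(counts.var())                       # ~60  (= mu + mu**2 / theta)
```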
Interpretability is another reason for NB’s use. Its parameters have clear biological meaning:
- $\mu$: expected transcript count
- $\theta$: gene/cell-specific overdispersion
Parameterizations: Standard and Logit (Used in Deep Learning)
Standard NB:
Usually parameterized with a shape parameter ($\theta$, often called `total_count`) and a success probability ($p$):
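In one standard convention (the one NumPyro's implementation below follows, with mean $\theta\,p/(1-p)$), the probability mass function is:

$$
P(X = k) = \frac{\Gamma(k + \theta)}{k!\,\Gamma(\theta)}\,(1 - p)^{\theta}\,p^{k},
\qquad
\mathbb{E}[X] = \frac{\theta\,p}{1 - p},
\qquad
\operatorname{Var}[X] = \frac{\theta\,p}{(1 - p)^2}.
$$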
Logit Parameterization:
For deep learning, the NB can be more robustly parameterized via logits instead of $p$:
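Concretely, the logits are the log-odds of $p$:

$$
\text{logits} = \log\frac{p}{1 - p}
\quad\Longleftrightarrow\quad
p = \operatorname{sigmoid}(\text{logits}),
$$

so the mean becomes $\mu = \theta\, e^{\text{logits}}$, or equivalently $\text{logits} = \log(\mu / \theta)$, which is exactly the identity the VAE code below relies on.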
Working with logits stabilizes optimization, especially under automatic differentiation and when probabilities saturate near 0 or 1.
NumPyro’s Negative Binomial Logits Implementation
Below is the core class definition of `NegativeBinomialLogits` within NumPyro, showing how this parameterization is handled internally.
```python
class NegativeBinomialLogits(GammaPoisson):
    arg_constraints = {
        "total_count": constraints.positive,
        "logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(self, total_count, logits, *, validate_args=None):
        self.total_count, self.logits = promote_shapes(total_count, logits)
        # Gamma shape = theta; Gamma rate = (1 - p) / p = exp(-logits)
        concentration = total_count
        rate = jnp.exp(-logits)
        super().__init__(concentration, rate, validate_args=validate_args)

    @validate_sample
    def log_prob(self, value):
        # Numerically stable NB log-pmf: softplus terms give log(1 + exp(+/-logits)),
        # _log_beta_1 handles the Beta-function normalization.
        return -(
            self.total_count * nn.softplus(self.logits)
            + value * nn.softplus(-self.logits)
            + _log_beta_1(self.total_count, value)
        )
```
This `NegativeBinomialLogits` class directly implements the logit parameterization, inheriting from `GammaPoisson`, which underlines the Poisson-Gamma mixture interpretation.

- `__init__`: The constructor takes `total_count` (our $\theta$) and `logits` as input. Crucially, it converts them into the `concentration` and `rate` parameters of the underlying Gamma distribution (`jnp.exp(-logits)` derives the rate from the logits). This mirrors the Poisson-Gamma mixture, in which the Poisson mean is drawn from a Gamma distribution with exactly these concentration/rate parameters.
- `log_prob`: This is the method that enters the loss our VAE trains on (alongside the KL term). It computes the log-probability of observing a `value` (i.e., a specific count) given `total_count` and `logits`. The formula is a numerically stable form of the NB log-likelihood, using `nn.softplus` for the $\log(1 + e^{x})$ terms and `_log_beta_1` for the Beta-function term, which keeps the computation robust during deep learning optimization.
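As a quick sanity check (toy numbers, not part of the NumPyro source), the class can be exercised directly; note how the mean and variance recover the $\mu$ and $\mu + \mu^2/\theta$ relations from above.

```python
import jax.numpy as jnp
import numpyro.distributions as dist

theta = jnp.array(2.0)                    # dispersion (total_count)
mu = jnp.array(10.0)                      # target expected count
logits = jnp.log(mu / theta)              # logits = log(mu / theta)

nb = dist.NegativeBinomialLogits(total_count=theta, logits=logits)
print(nb.mean)                            # 10.0  (= mu)
print(nb.variance)                        # 60.0  (= mu + mu**2 / theta)
print(nb.log_prob(jnp.arange(5)))         # log-pmf at counts 0..4
```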
Practical Implementation in VAEs (e.g., scVI)
Typical scVI-like models predict expected gene counts using neural networks and model observed counts as NB. Two common implementations differ mainly in how they handle normalization and numerical stability:
1. Standard (Softmax) Version used by scVI
```python
rate = x_decoder(z)                            # Unconstrained decoder output
rate_softmax = jax.nn.softmax(rate)            # Gene proportions (sum to 1)
mu = jnp.log(l * rate_softmax + epsilon)       # Expected log-counts (scaled by library size l)
nb_logits = mu - jnp.log(theta + epsilon)      # logits = log(expected counts / theta)
x_dist = dist.NegativeBinomialLogits(total_count=theta + epsilon, logits=nb_logits)
```
| Variable | Description / Formula / Interpretation |
|---|---|
| `z` | Latent representation from the VAE encoder. |
| `x_decoder(z)` | Neural network output (unnormalized scores, one per gene). |
| `rate` | Unnormalized decoder output (log-space) → interpretable as relative log-expression. |
| `rate_softmax` | Softmax over genes → interpretable as cell-specific gene proportions. |
| `l` | Library size (total counts per cell, either learned or fixed). |
| `mu` | Log of the expected counts for each gene: $\mu = \log(l \cdot \text{softmax(rate)})$, which ensures $\sum_g \exp(\mu_g) \approx l$, matching the library size. |
| `theta` | Dispersion parameter (gene-specific). |
| `nb_logits` | Computed as `mu - log(theta)`, i.e. $\log\big(l \cdot \text{softmax(rate)} / \theta\big)$, the NB log-odds. |
| `epsilon` | Small constant for numerical stability. |
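To see why subtracting $\log\theta$ yields the correct logits, recall from the parameterization above that the NB mean is $\theta\, e^{\text{logits}}$; hence

$$
\text{nb\_logits} = \log\!\big(l \cdot \text{softmax(rate)}\big) - \log\theta
\;\Longrightarrow\;
\mathbb{E}[x] = \theta\, e^{\text{nb\_logits}} = l \cdot \text{softmax(rate)},
$$

so the decoder's gene proportions, scaled by the library size, are exactly the expected counts of the likelihood.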
2. Stable (Log-Space) Version
```python
rate = x_decoder(z)                                  # Unconstrained decoder output
# log-softmax computed stably: rate - logsumexp(rate), then shifted by log library size
mu = jnp.log(l) + rate - jax.scipy.special.logsumexp(rate, axis=-1, keepdims=True)
nb_logits = mu - logtheta                            # logits = log(expected counts / theta)
x_dist = dist.NegativeBinomialLogits(total_count=jnp.exp(logtheta), logits=nb_logits)
```
With the following changes:
| Variable | Description / Formula / Interpretation |
|---|---|
| `logsumexp(rate)` | Log of the sum of exponentials (numerically stable softmax denominator). |
| `mu` | Log expected counts, computed as $\mu = \log(l) + \text{rate} - \text{logsumexp(rate)}$. This is equivalent to the first implementation's $\mu = \log(l \cdot \text{softmax(rate)})$, but avoids an explicit softmax. |
| `l` | Library size (total counts per cell, either learned or fixed). |
| `logtheta` | Directly learned log of the dispersion parameter (more numerically stable). |
| `nb_logits` | Computed as `mu - logtheta`, i.e. $\log\big(l \cdot \text{softmax(rate)} / \theta\big)$ with $\theta = \exp(\texttt{logtheta})$. |
In this implementation, there is no need for `epsilon`!
Both approaches are mathematically equivalent. The second, log-space variant is generally preferred in practice, as it avoids over/underflow on large or sparse count matrices.
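A small numerical check (toy tensors and hypothetical shapes, purely for illustration) confirms that the two ways of computing `nb_logits` agree up to the `epsilon` jitter:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
rate = jax.random.normal(key, (3, 5))                 # stand-in for x_decoder(z): 3 cells x 5 genes
l = jnp.array([[1000.0], [2000.0], [500.0]])          # per-cell library sizes
logtheta = jnp.zeros((5,))                            # gene-wise log-dispersion
theta, epsilon = jnp.exp(logtheta), 1e-8

# Version 1: explicit softmax
mu1 = jnp.log(l * jax.nn.softmax(rate, axis=-1) + epsilon)
logits1 = mu1 - jnp.log(theta + epsilon)

# Version 2: log-space (logsumexp)
mu2 = jnp.log(l) + rate - jax.scipy.special.logsumexp(rate, axis=-1, keepdims=True)
logits2 = mu2 - logtheta

print(jnp.max(jnp.abs(logits1 - logits2)))            # ~0 (differences only from epsilon)
```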
| Aspect | Softmax Version | Log-Space Version |
|---|---|---|
| Normalization | Explicit softmax | Logsumexp in log-space |
| Stability | Good, but softmax can over/underflow | Excellent, everything stays in the log domain |
| Dispersion handling | Direct $\theta$ | Log-dispersion (`logtheta`) |
Summary:
- The negative binomial's extra dispersion parameter is essential for realistic scRNA-seq modeling.
- Deep learning models (e.g., scVI) compute the NB likelihood downstream of a softmax-normalized decoder; here we advocate the logits parameterization, combined with fully log-space calculations, for numerical stability and flexibility.