Negative Binomial Modeling for scRNA-seq

The negative binomial (NB) distribution is the canonical statistical model for the integer gene expression counts produced by single-cell RNA-seq (scRNA-seq). It is favored over the Poisson distribution because real scRNA-seq data are overdispersed: the variance greatly exceeds the mean, due to both biological and technical factors (cell heterogeneity, bursty transcription, varying sequencing depth, dropouts, etc.).

The Poisson's core assumption (variance equals mean) fails in this setting. The NB instead augments the Poisson with a dispersion parameter ($\theta$), so that for mean $\mu$ the variance is $\sigma^2 = \mu + \mu^2 / \theta$. This flexibility means two genes with the same mean expression but different variances (a common situation) can both be modeled, simply by adjusting $\theta$; as $\theta \to \infty$ the NB reduces to the Poisson.
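
As a quick numerical illustration (a minimal sketch; the values are arbitrary), compare the Poisson variance with the NB variance at a fixed mean for a few dispersion values:

# NB variance vs. Poisson variance at a fixed mean (arbitrary toy values).
mu = 10.0
for theta in (0.5, 2.0, 100.0):
    nb_var = mu + mu**2 / theta          # overdispersion grows as theta shrinks
    print(f"theta={theta:6.1f}  Poisson var={mu:.1f}  NB var={nb_var:.1f}")
# theta=   0.5  Poisson var=10.0  NB var=210.0
# theta=   2.0  Poisson var=10.0  NB var=60.0
# theta= 100.0  Poisson var=10.0  NB var=11.0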

From a generative perspective, the NB arises as a Poisson–Gamma mixture: each cell/gene’s unknown expression rate $\lambda$ is drawn from a Gamma distribution before Poisson sampling.
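
A minimal simulation sketch of this mixture (using JAX and NumPyro, with arbitrary parameter values): drawing per-cell rates from a Gamma with mean $\mu$ and shape $\theta$, then Poisson counts, reproduces the NB moments.

import jax
import numpyro.distributions as dist

key_gamma, key_pois = jax.random.split(jax.random.PRNGKey(0))
theta, mu = 2.0, 10.0                     # arbitrary dispersion and mean
n = 200_000

# lambda ~ Gamma(shape=theta, rate=theta/mu)  =>  E[lambda] = mu
lam = dist.Gamma(theta, theta / mu).sample(key_gamma, (n,))
# x | lambda ~ Poisson(lambda)  =>  marginally, x follows an NB with mean mu
x = dist.Poisson(lam).sample(key_pois)

print(x.mean(), x.var())                  # approx. 10 and 10 + 10**2 / 2 = 60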

Interpretability is another reason for NB’s use. Its parameters have clear biological meaning:

  • $\mu$: expected transcript count
  • $\theta$: gene/cell-specific overdispersion

Parameterizations: Standard and Logit (Used in Deep Learning)

Standard NB:
Usually parameterized with a shape ($\theta$, often called total_count) and a success probability ($p$). Following the convention used by NumPyro and PyTorch (counting successes until $\theta$ failures occur), the mean and variance are:

$$\mu = \frac{\theta p}{1-p} \quad \text{and} \quad \sigma^2 = \mu + \frac{\mu^2}{\theta}$$
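
As a sanity check on this convention (a minimal sketch with arbitrary numbers), NumPyro's NegativeBinomialProbs reproduces these moments:

import numpyro.distributions as dist

theta, p = 2.0, 0.8                          # arbitrary shape and success probability
nb = dist.NegativeBinomialProbs(total_count=theta, probs=p)

mu = theta * p / (1 - p)                     # 8.0
print(nb.mean, mu)                           # both approx. 8.0
print(nb.variance, mu + mu**2 / theta)       # both approx. 40.0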

Logit Parameterization:
For deep learning, the NB can be more robustly parameterized via logits instead of $p$:

$$\text{logits} = \log\left(\frac{p}{1-p}\right), \quad \mu = \theta \cdot e^{\text{logits}}$$

Working with logits keeps $p$ implicitly in $(0, 1)$ without clamping and stabilizes optimization, which matters for automatic differentiation and large-scale likelihood evaluation.
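
The logits form is just a reparameterization of the same distribution. A minimal sketch (arbitrary values) checking that NumPyro's NegativeBinomialLogits agrees with NegativeBinomialProbs and that $\mu = \theta \cdot e^{\text{logits}}$:

import jax.numpy as jnp
import numpyro.distributions as dist

theta, p = 2.0, 0.8
logits = jnp.log(p) - jnp.log1p(-p)          # log(p / (1 - p))

nb_p = dist.NegativeBinomialProbs(total_count=theta, probs=p)
nb_l = dist.NegativeBinomialLogits(total_count=theta, logits=logits)

print(nb_p.mean, nb_l.mean, theta * jnp.exp(logits))   # all approx. 8.0
print(nb_p.log_prob(5), nb_l.log_prob(5))              # equal up to float error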


NumPyro’s Negative Binomial Logits Implementation

Below is the core class definition for NegativeBinomialLogits within NumPyro, showcasing how this parameterization is handled internally.

# Excerpt adapted from NumPyro's source. GammaPoisson, constraints,
# promote_shapes, validate_sample and _log_beta_1 are NumPyro internals
# that are not repeated here.
import jax.numpy as jnp
from jax import nn

class NegativeBinomialLogits(GammaPoisson):
    arg_constraints = {
        "total_count": constraints.positive,
        "logits": constraints.real,
    }
    support = constraints.nonnegative_integer

    def __init__(self, total_count, logits, *, validate_args=None):
        # Broadcast total_count and logits to a common shape.
        self.total_count, self.logits = promote_shapes(total_count, logits)
        # Gamma-Poisson view: concentration = theta, rate = exp(-logits),
        # so the NB mean is concentration / rate = theta * exp(logits).
        concentration = total_count
        rate = jnp.exp(-logits)
        super().__init__(concentration, rate, validate_args=validate_args)

    @validate_sample
    def log_prob(self, value):
        # Numerically stable NB log-pmf: with p = sigmoid(logits),
        # theta * log(1 - p) = -theta * softplus(logits) and
        # value * log(p)     = -value * softplus(-logits).
        return -(
            self.total_count * nn.softplus(self.logits)
            + value * nn.softplus(-self.logits)
            + _log_beta_1(self.total_count, value)
        )

This NegativeBinomialLogits class directly implements the logit parameterization and inherits from GammaPoisson, which reflects the Poisson-Gamma mixture interpretation.

  • __init__: The constructor takes total_count (our $\theta$) and logits as input and transforms them into the concentration and rate parameters of the underlying Gamma distribution (the rate is obtained as jnp.exp(-logits)). This reflects the Poisson-Gamma mixture, in which the Poisson mean is drawn from a Gamma distribution with exactly these concentration/rate parameters; the resulting NB mean is $\theta \cdot e^{\text{logits}}$.
  • log_prob: This is the method that enters the VAE training loss (alongside the KL term). It computes the log-probability of observing a value (a specific count) given total_count and logits. The formula is a numerically stable form of the NB log-likelihood, using nn.softplus for the $\log(1+e^{x})$ terms and _log_beta_1 for the beta-function term, which keeps optimization robust in deep learning settings; a small usage check follows below.
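
To see log_prob in action, here is a small sketch (arbitrary values, not taken from NumPyro's documentation) that compares the class against a direct evaluation of the NB log-pmf written with gammaln:

import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln
import numpyro.distributions as dist

theta, logits = 2.0, 0.5                   # arbitrary dispersion and logits
k = jnp.arange(5.0)                        # counts 0, 1, ..., 4

nb = dist.NegativeBinomialLogits(total_count=theta, logits=logits)

# Direct NB log-pmf with p = sigmoid(logits), so that mean = theta * exp(logits)
p = jax.nn.sigmoid(logits)
manual = (
    gammaln(k + theta) - gammaln(theta) - gammaln(k + 1.0)
    + theta * jnp.log1p(-p)
    + k * jnp.log(p)
)

print(nb.log_prob(k))                      # matches `manual` up to float error
print(manual)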

Practical Implementation in VAEs (e.g., scVI)

Typical scVI-like models predict expected gene counts using neural networks and model observed counts as NB. Two common implementations differ mainly in how they handle normalization and numerical stability:

1. Standard (Softmax) Version used by scVI

rate = x_decoder(z)                           # Unconstrained decoder output
rate_softmax = jax.nn.softmax(rate, axis=-1)  # Gene proportions per cell (sum to 1)
mu = jnp.log(l * rate_softmax + epsilon)      # Log of expected counts (scaled by library size l)
nb_logits = mu - jnp.log(theta + epsilon)     # Log-odds: log(mean / theta)
x_dist = dist.NegativeBinomialLogits(total_count=theta + epsilon, logits=nb_logits)

| Variable | Description / Formula / Interpretation |
|---|---|
| z | Latent representation from the VAE encoder. |
| x_decoder(z) | Neural network output (one unnormalized value per gene). |
| rate | Unnormalized decoder output (log-space), interpretable as relative log-expression. |
| rate_softmax | Softmax over genes, interpretable as cell-specific gene proportions. |
| l | Library size (total counts per cell, either learned or fixed). |
| mu | Log of the expected counts for each gene: $\mu = \log(l \cdot \text{softmax}(\text{rate}))$. This ensures $\sum \exp(\mu) \approx l$, matching the library size. |
| theta | Dispersion parameter (gene-specific). |
| nb_logits | Computed as mu - log(theta), i.e. the NB log-odds $\log(\text{mean}/\theta)$. |
| epsilon | Small constant for numerical stability. |
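
For context, here is a self-contained toy sketch of the softmax version. Random weights and Poisson-simulated counts stand in for a real encoder/decoder and data; none of the names below come from scVI itself.

import jax
import jax.numpy as jnp
import numpyro.distributions as dist

n_cells, n_latent, n_genes = 8, 10, 200                 # toy sizes
epsilon = 1e-8

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
W = 0.1 * jax.random.normal(k1, (n_latent, n_genes))    # toy linear "decoder" weights
z = jax.random.normal(k2, (n_cells, n_latent))          # stand-in for encoder output
x = jax.random.poisson(k3, 5.0, (n_cells, n_genes)).astype(jnp.float32)  # fake counts

l = x.sum(axis=-1, keepdims=True)                       # observed library size per cell
theta = jnp.ones(n_genes)                               # gene-wise dispersion

def x_decoder(z):
    return z @ W                                        # unconstrained output per gene

rate = x_decoder(z)
rate_softmax = jax.nn.softmax(rate, axis=-1)            # cell-specific gene proportions
mu = jnp.log(l * rate_softmax + epsilon)                # log expected counts
nb_logits = mu - jnp.log(theta + epsilon)

x_dist = dist.NegativeBinomialLogits(total_count=theta + epsilon, logits=nb_logits)
log_lik = x_dist.log_prob(x).sum(-1)                    # per-cell reconstruction term
print(log_lik.shape)                                    # (8,)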

2. Stable (Log-Space) Version

rate = x_decoder(z)                       # Unconstrained decoder output
mu = jnp.log(l) + rate - jax.scipy.special.logsumexp(rate, axis=-1, keepdims=True)  # Log expected counts
nb_logits = mu - logtheta                 # Log-odds: log(mean / theta)
x_dist = dist.NegativeBinomialLogits(total_count=jnp.exp(logtheta), logits=nb_logits)

With the following changes:

| Variable | Description / Formula / Interpretation |
|---|---|
| logsumexp(rate) | Log of the sum of exponentials (numerically stable softmax denominator). |
| mu | Log expected counts, computed as $\mu = \log(l) + \text{rate} - \text{logsumexp}(\text{rate})$. Equivalent to the first implementation, $\mu = \log(l \cdot \text{softmax}(\text{rate}))$, but avoids the explicit softmax. |
| l | Library size (total counts per cell, either learned or fixed). |
| logtheta | Directly learned log of the dispersion parameter (more numerically stable). |
| nb_logits | Computed as mu - logtheta, i.e. $\log(\text{mean}/\theta)$ with theta = exp(logtheta). |

Note that this implementation needs no epsilon term at all.

Both approaches are mathematically equivalent, as the sketch below verifies numerically. The log-space variant is generally preferred in practice because it avoids overflow/underflow issues on large or sparse data.
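
A minimal numerical check of this equivalence (toy shapes and values; the epsilon term only affects the softmax version):

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

rate = jax.random.normal(jax.random.PRNGKey(1), (4, 50))    # toy decoder output
l = jnp.full((4, 1), 1000.0)                                # toy library sizes
logtheta = jnp.zeros(50)                                    # log-dispersion (theta = 1)
epsilon = 1e-8

# Softmax version
mu_soft = jnp.log(l * jax.nn.softmax(rate, axis=-1) + epsilon)
logits_soft = mu_soft - jnp.log(jnp.exp(logtheta) + epsilon)

# Log-space version
mu_log = jnp.log(l) + rate - logsumexp(rate, axis=-1, keepdims=True)
logits_log = mu_log - logtheta

print(jnp.max(jnp.abs(logits_soft - logits_log)))           # tiny (float error + epsilon)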

| Aspect | Softmax Version | Log-Space Version |
|---|---|---|
| Normalization | Explicit softmax | Logsumexp in log-space |
| Stability | Good, but softmax can over/underflow | Excellent, all in the log domain |
| Dispersion handling | Direct $\theta$ | Log-dispersion (logtheta) |

Summary:

  • The negative binomial's extra dispersion parameter is essential for modeling scRNA-seq counts.
  • Deep learning models such as scVI parameterize the NB through a softmax-normalized decoder; combining the logit parameterization with log-space (logsumexp) computations adds numerical stability and flexibility.