4  Variational Sparse Coding

4.1 Motivation

We want a model with an interpretable latent space (achieved by introducing sparsity) and more general feature disentanglement than \(\beta\)-VAE, meaning that different combinations of features can be present in different data points.

What is Posterior Collapse?

  • Problem: In VAEs, some latent dimensions become “useless” – they encode no meaningful information.
  • Why it happens:
    • The KL divergence term in ELBO forces latent variables to match the prior.
    • If the decoder is too powerful, it ignores latent variables, leading to dimensions being permanently inactive.
    • Result: Model uses only a few dimensions, losing sparsity and disentanglement.

How VSC Fixes It:

  1. Spike-and-Slab Warm-Up
    • Phase 1 (\(\lambda=0\)):
      • Forces latent variables to behave like binary codes (spike dominates).
      • Model must “choose” which dimensions to activate (no continuous refinement).
    • Phase 2 (\(\lambda \rightarrow 1\)):
      • Gradually introduces continuous slab parameters (\(\mu_{i,j}, \sigma_{i,j}\)).
      • Prevents early over-reliance on a few dimensions.
  2. Sparsity Enforcement
    • KL Sparsity Term: Penalizes average activation rate \(\bar{\gamma}_u\) if it deviates from \(\alpha\) (e.g., \(\alpha=0.01\)).
    • Forces the model to use only essential dimensions, avoiding redundancy.
  3. Dynamic Prior
    • Prior \(p_s(z)\) adapts via pseudo-inputs \(x_u\) and classifier \(C_\omega(x_i)\).
    • Prevents trivial alignment with a fixed prior (e.g., \(\mathcal{N}(0,1)\)).

Result:

  • Latent dimensions stay sparse and interpretable.
  • No single dimension dominates; features are distributed across active variables.

Variational Sparse Coding (VSC) extends VAEs by inducing sparsity in the latent space using a Spike-and-Slab prior, enabling feature disentanglement and controlled generation when the number of latent factors is unknown.

4.2 Recognition Model

Spike-and-Slab Encoder Distribution
\[ q_\phi(z|x_i) = \prod_{j=1}^J \left[ \gamma_{i,j} \mathcal{N}(z_{i,j}; \mu_{i,j}, \sigma_{i,j}^2) + (1 - \gamma_{i,j}) \delta(z_{i,j}) \right] \]

Parameters: all of the following are outputs of the encoder network (a minimal encoder sketch follows the list).

  • \(\gamma_{i,j}\): Probability that latent dimension \(j\) is active for input \(x_i\).
  • \(\mu_{i,j}, \sigma_{i,j}^2\): Mean and variance of the Gaussian (slab) for active dimensions.
  • \(\delta(z_{i,j})\): Dirac delta function (spike) forces inactive dimensions to exactly 0.
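A minimal sketch of an encoder head that produces these three parameter sets (layer sizes and names here are illustrative assumptions, not taken from logic/model/vsc.py):

import torch
import torch.nn as nn

class SpikeSlabEncoder(nn.Module):
    """Maps x_i to the spike-and-slab parameters (mu, logvar, gamma)."""
    def __init__(self, input_dim: int, hidden_dim: int, latent_dim: int):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
        )
        self.mu_head = nn.Linear(hidden_dim, latent_dim)      # slab means
        self.logvar_head = nn.Linear(hidden_dim, latent_dim)  # slab log-variances
        self.gamma_head = nn.Linear(hidden_dim, latent_dim)   # spike logits

    def forward(self, x: torch.Tensor):
        h = self.backbone(x)
        mu = self.mu_head(h)
        logvar = self.logvar_head(h)
        # Sigmoid keeps gamma in (0, 1): the probability that dimension j is active.
        gamma = torch.sigmoid(self.gamma_head(h))
        return mu, logvar, gamma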

PyTorch implementation of the reparameterization in logic/model/vsc.py:

def reparameterize(self,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    gamma: torch.Tensor
    ) -> torch.Tensor:

    lam = self.lambda_val  # warm-up factor
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)

    # Interpolate between a fixed (zero-mean, unit-variance) slab
    # and the learned slab. Note: this interpolates standard deviations
    # linearly; the warm-up formula in Section 4.3.2 interpolates variances.
    slab = lam * mu + eps * (lam * std + (1 - lam))

    # Sample binary spike; torch.bernoulli is not differentiable, so no
    # gradient reaches gamma here (see the straight-through sketch below).
    spike = torch.bernoulli(gamma)

    return spike * slab
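Because torch.bernoulli blocks gradients to \(\gamma\), practical implementations often substitute a straight-through estimator or a continuous relaxation. A hedged sketch of the straight-through variant (this substitution is an assumption, not necessarily what logic/model/vsc.py does):

import torch

def sample_spike_st(gamma: torch.Tensor) -> torch.Tensor:
    # Forward pass: hard binary sample. Backward pass: gradients flow
    # to gamma as if the sampling were the identity (straight-through).
    hard = torch.bernoulli(gamma)
    return gamma + (hard - gamma).detach()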

4.3 Training Procedure

4.3.1 Prior Distribution & Objective

Prior
\[ p_s(z) = q_\phi(z|x_{u^*}), \quad u^* = C_\omega(x_i) \]

  • Pseudo-inputs: Learnable templates \(\{x_u\}\) represent common feature combinations.
  • Classifier: \(C_\omega(x_i)\) selects the best-matching template \(x_{u^*}\) for input \(x_i\).
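A hedged sketch of how the pseudo-inputs and classifier could produce the prior parameters (class, attribute, and argument names here are illustrative assumptions):

import torch
import torch.nn as nn

class PseudoInputPrior(nn.Module):
    """Prior p_s(z) = q_phi(z | x_{u*}) with u* = C_omega(x_i)."""
    def __init__(self, encoder: nn.Module, input_dim: int, num_pseudo_inputs: int):
        super().__init__()
        self.encoder = encoder
        # Learnable templates {x_u}, one per common feature combination.
        self.pseudo_inputs = nn.Parameter(torch.randn(num_pseudo_inputs, input_dim))
        # Classifier C_omega scores how well each template matches x_i.
        self.classifier = nn.Linear(input_dim, num_pseudo_inputs)

    def forward(self, x: torch.Tensor):
        # argmax is non-differentiable; shown here as a hard selection
        # for clarity. Training may require a soft assignment instead.
        u_star = self.classifier(x).argmax(dim=-1)  # best-matching template index
        x_u = self.pseudo_inputs[u_star]            # select x_{u*} per sample
        return self.encoder(x_u)                    # (mu_u, logvar_u, gamma_u)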

Objective (ELBO with Sparsity)
\[ \mathcal{L} = \sum_i \left[ -\text{KL}(q_\phi \| p_s) + \mathbb{E}_{q_\phi}[\log p_\theta(x_i|z)] \right] - J \cdot \text{KL}(\bar{\gamma}_u \| \alpha) \]

  • KL Term:
    • Aligns encoder (\(\mu_{i,j}, \sigma_{i,j}, \gamma_{i,j}\)) with prior (\(\mu_{u^*,j}, \sigma_{u^*,j}, \gamma_{u^*,j}\)).
    • A closed-form expression makes this term fast to compute (see the sketch after this list).
  • Sparsity Term:
    • Penalizes deviation from the target sparsity \(\alpha\) (e.g., \(\alpha=0.01\) keeps roughly 99% of dimensions inactive).
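The spike-and-slab KL decomposes, per dimension, into a Bernoulli KL between the spike probabilities plus the Gaussian slab KL weighted by the encoder's activation probability. A minimal sketch of this closed form (assuming both spikes are Dirac deltas at zero; the function name is illustrative):

import torch

def spike_slab_kl(mu_q, logvar_q, gamma_q, mu_p, logvar_p, gamma_p, eps=1e-6):
    # Gaussian (slab) KL, counted only where the encoder is active.
    var_q, var_p = logvar_q.exp(), logvar_p.exp()
    kl_slab = 0.5 * (logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1)
    # Bernoulli (spike) KL between activation probabilities.
    kl_spike = gamma_q * torch.log((gamma_q + eps) / (gamma_p + eps)) \
             + (1 - gamma_q) * torch.log((1 - gamma_q + eps) / (1 - gamma_p + eps))
    return (gamma_q * kl_slab + kl_spike).sum(dim=-1)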

PyTorch implementation of the sparsity term in logic/model/vsc.py:

def compute_sparsity_loss(self, gamma: torch.Tensor) -> torch.Tensor:
    # BCE against a constant target pushes each activation probability
    # gamma_{i,j} toward the target sparsity level alpha.
    target = torch.full_like(gamma, self.prior_sparsity)
    return nn.functional.binary_cross_entropy(gamma, target)
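Putting the pieces together, a hedged sketch of the per-batch objective (the method name compute_loss, the attribute latent_dim, and the Bernoulli reconstruction likelihood are assumptions; spike_slab_kl refers to the sketch above):

def compute_loss(self, x, x_recon, mu, logvar, gamma, prior_params):
    mu_p, logvar_p, gamma_p = prior_params
    # Reconstruction term: E_q[log p_theta(x|z)] under a Bernoulli decoder.
    recon = nn.functional.binary_cross_entropy(x_recon, x, reduction='sum')
    # KL between encoder and pseudo-input prior (closed form, see above).
    kl = spike_slab_kl(mu, logvar, gamma, mu_p, logvar_p, gamma_p).sum()
    # Sparsity penalty, scaled by J = latent_dim as in the objective.
    sparsity = self.latent_dim * self.compute_sparsity_loss(gamma)
    return recon + kl + sparsity

Note that compute_sparsity_loss applies the BCE element-wise; to match the \(\text{KL}(\bar{\gamma}_u \| \alpha)\) term literally, one would average \(\gamma\) over the batch before comparing it to \(\alpha\).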

4.3.2 Warm-Up Strategy

\[ q_{\phi,\lambda}(z|x_i) = \prod_{j=1}^J \left[ \gamma_{i,j} \mathcal{N}(z_{i,j}; \lambda \mu_{i,j}, \lambda \sigma_{i,j}^2 + (1-\lambda)) + (1 - \gamma_{i,j}) \delta(z_{i,j}) \right] \]

  • Phase 1 (\(\lambda=0\)):
    • Slab fixed to \(\mathcal{N}(0,1)\) (binary-like latent codes).
    • Focus: Which features to activate.
  • Phase 2 (\(\lambda \rightarrow 1\)):
    • Gradually learn \(\mu_{i,j}, \sigma_{i,j}\) (refine how to represent features).
  • Avoids collapse: Prevents premature “freezing” of latent dimensions.
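A minimal sketch of the warm-up schedule, assuming a linear ramp over a fixed number of epochs (warmup_epochs and the method name are assumptions; lambda_val matches the reparameterize snippet in Section 4.2):

def update_lambda(self, epoch: int, warmup_epochs: int) -> None:
    # Phase 1 (lambda = 0): slab fixed to N(0, 1); the model decides
    # *which* dimensions to activate. Phase 2 (lambda -> 1): the learned
    # slab parameters mu, sigma are gradually blended in.
    self.lambda_val = min(1.0, epoch / warmup_epochs)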