4 Variational Sparse Coding
4.1 Motivation
We want a model with an interpretable latent space (obtained by introducing sparsity) and more general feature disentanglement than \(\beta\)-VAE, meaning that different combinations of features can be present in different data points.
What is Posterior Collapse?
- Problem: In VAEs, some latent dimensions become “useless” – they encode no meaningful information.
- Why it happens:
- The KL divergence term in ELBO forces latent variables to match the prior.
- If the decoder is too powerful, it ignores latent variables, leading to dimensions being permanently inactive.
- Result: Model uses only a few dimensions, losing sparsity and disentanglement.
How VSC Fixes It:
- Spike-and-Slab Warm-Up
- Phase 1 (\(\lambda=0\)):
- Forces latent variables to behave like binary codes (spike dominates).
- Model must “choose” which dimensions to activate (no continuous refinement).
- Phase 2 (\(\lambda \rightarrow 1\)):
- Gradually introduces continuous slab parameters (\(\mu_{i,j}, \sigma_{i,j}\)).
- Prevents early over-reliance on a few dimensions.
- Sparsity Enforcement
- KL Sparsity Term: Penalizes average activation rate \(\bar{\gamma}_u\) if it deviates from \(\alpha\) (e.g., \(\alpha=0.01\)).
- Forces the model to use only essential dimensions, avoiding redundancy.
- Dynamic Prior
- Prior \(p_s(z)\) adapts via pseudo-inputs \(x_u\) and classifier \(C_\omega(x_i)\).
- Prevents trivial alignment with a fixed prior (e.g., \(\mathcal{N}(0,1)\)).
Result:
- Latent dimensions stay sparse and interpretable.
- No single dimension dominates; features are distributed across active variables.

Variational Sparse Coding (VSC) extends VAEs by inducing sparsity in the latent space using a Spike-and-Slab prior, enabling feature disentanglement and controlled generation when the number of latent factors is unknown.
4.2 Recognition Model
Spike-and-Slab Encoder Distribution
\[ q_\phi(z|x_i) = \prod_{j=1}^J \left[ \gamma_{i,j} \mathcal{N}(z_{i,j}; \mu_{i,j}, \sigma_{i,j}^2) + (1 - \gamma_{i,j}) \delta(z_{i,j}) \right] \]
Parameters: All are outputs of a neural network (the encoder); a minimal illustrative head is sketched after the list below.
- \(\gamma_{i,j}\): Probability that latent dimension \(j\) is active for input \(x_i\).
- \(\mu_{i,j}, \sigma_{i,j}\): Mean and standard deviation of the Gaussian (slab) for active dimensions.
- \(\delta(z_{i,j})\): Dirac delta function (spike) forces inactive dimensions to exactly 0.
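As a rough sketch of how these three quantities might come out of the encoder, the head below maps a shared hidden representation to \(\mu\), \(\log\sigma^2\), and \(\gamma\) per latent dimension; the class and layer names are illustrative assumptions, not taken from logic/model/vsc.py.

import torch
import torch.nn as nn

class SpikeSlabEncoderHead(nn.Module):
    # Illustrative head mapping a shared hidden vector to the
    # spike-and-slab parameters (mu, logvar, gamma) per latent dimension.
    def __init__(self, hidden_dim: int, latent_dim: int):
        super().__init__()
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)      # slab mean
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)  # slab log-variance
        self.fc_logit = nn.Linear(hidden_dim, latent_dim)   # spike logits

    def forward(self, h: torch.Tensor):
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        gamma = torch.sigmoid(self.fc_logit(h))  # activation probability in (0, 1)
        return mu, logvar, gamma

A sigmoid keeps \(\gamma_{i,j}\) in \((0,1)\), so it can be used directly as a Bernoulli probability in the reparameterization below.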
PyTorch implementation of the reparameterization in logic/model/vsc.py:
def reparameterize(self,
                   mu: torch.Tensor,
                   logvar: torch.Tensor,
                   gamma: torch.Tensor
                   ) -> torch.Tensor:
    lam = self.lambda_val  # warm-up factor in [0, 1]
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    # Interpolate between a fixed (zero-mean, unit-variance) slab
    # and the learned slab, controlled by the warm-up factor lam.
    slab = lam * mu + eps * (lam * std + (1 - lam))
    # Sample the binary spike; note: torch.bernoulli is not differentiable,
    # so no gradient flows to gamma through this sample.
    spike = torch.bernoulli(gamma)
    return spike * slab

4.3 Training Procedure
4.3.1 Prior Distribution & Objective
Prior
\[ p_s(z) = q_\phi(z|x_{u^*}), \quad u^* = C_\omega(x_i) \]
- Pseudo-inputs: Learnable templates \(\{x_u\}\) represent common feature combinations.
- Classifier: \(C_\omega(x_i)\) selects the best-matching template \(x_{u^*}\) for input \(x_i\).
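A minimal sketch of how the pseudo-inputs and classifier could fit together, assuming the pseudo-inputs are stored as a learnable tensor and \(C_\omega\) returns logits over them; the class and attribute names are hypothetical, and a hard argmax is used for simplicity.

import torch
import torch.nn as nn

class PseudoInputPrior(nn.Module):
    # Illustrative dynamic prior: p_s(z) = q_phi(z | x_{u*}) with u* = C_omega(x_i).
    def __init__(self, num_pseudo: int, input_dim: int, classifier: nn.Module):
        super().__init__()
        # Learnable templates x_u, one row per pseudo-input.
        self.pseudo_inputs = nn.Parameter(torch.randn(num_pseudo, input_dim))
        self.classifier = classifier  # C_omega: logits over the pseudo-inputs

    def prior_params(self, x: torch.Tensor, encoder: nn.Module):
        # Select the best-matching template for each input (hard choice).
        u_star = self.classifier(x).argmax(dim=-1)   # shape: (batch,)
        x_u = self.pseudo_inputs[u_star]             # shape: (batch, input_dim)
        # The prior's spike-and-slab parameters are the encoder outputs on x_{u*}.
        return encoder(x_u)

Because the encoder is run on \(x_{u^*}\), the prior parameters move as the pseudo-inputs are learned, which is what prevents trivial alignment with a fixed \(\mathcal{N}(0,1)\).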
Objective (ELBO with Sparsity)
\[ \mathcal{L} = \sum_i \left[ -\text{KL}(q_\phi \| p_s) + \mathbb{E}_{q_\phi}[\log p_\theta(x_i|z)] \right] - J \cdot \text{KL}(\bar{\gamma}_u \| \alpha) \]
- KL Term:
- Aligns encoder (\(\mu_{i,j}, \sigma_{i,j}, \gamma_{i,j}\)) with prior (\(\mu_{u^*,j}, \sigma_{u^*,j}, \gamma_{u^*,j}\)).
- Closed-form formula ensures fast computation (a sketch follows after this list).
- Sparsity Term:
- Penalizes deviation of the average activation rate \(\bar{\gamma}_u\) from the target sparsity \(\alpha\) (e.g., \(\alpha=0.01\) keeps roughly 99% of dimensions inactive).
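As a sketch of the closed-form KL mentioned above, one standard per-dimension decomposition for two spike-and-slab distributions is a Bernoulli KL between the spike probabilities plus the \(\gamma\)-weighted Gaussian KL between the slabs; the function below is an illustrative assumption, not the repository's implementation.

import torch

def spike_slab_kl(mu_q, logvar_q, gamma_q,
                  mu_p, logvar_p, gamma_p,
                  eps: float = 1e-6) -> torch.Tensor:
    # Gaussian KL between the two slabs, element-wise.
    var_q, var_p = logvar_q.exp(), logvar_p.exp()
    kl_slab = 0.5 * (logvar_p - logvar_q
                     + (var_q + (mu_q - mu_p) ** 2) / var_p
                     - 1.0)
    # Bernoulli KL between the spike probabilities.
    g_q = gamma_q.clamp(eps, 1 - eps)
    g_p = gamma_p.clamp(eps, 1 - eps)
    kl_spike = (g_q * torch.log(g_q / g_p)
                + (1 - g_q) * torch.log((1 - g_q) / (1 - g_p)))
    # Inactive dimensions contribute only the spike term; active ones also
    # pay the slab KL, weighted by their activation probability.
    return (kl_spike + g_q * kl_slab).sum(dim=-1)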
PyTorch implementation of the sparsity term in logic/model/vsc.py:
def compute_sparsity_loss(self, gamma: torch.Tensor) -> torch.Tensor:
    # Penalize activation probabilities that drift from self.prior_sparsity.
    target = torch.full_like(gamma, self.prior_sparsity)
    return nn.functional.binary_cross_entropy(gamma, target)

4.3.2 Warm-Up Strategy
\[ q_{\phi,\lambda}(z|x_i) = \prod_{j=1}^J \left[ \gamma_{i,j} \mathcal{N}(z_{i,j}; \lambda \mu_{i,j}, \lambda \sigma_{i,j}^2 + (1-\lambda)) + (1 - \gamma_{i,j}) \delta(z_{i,j}) \right] \]
- Phase 1 (\(\lambda=0\)):
- Slab fixed to \(\mathcal{N}(0,1)\) (binary-like latent codes).
- Focus: Which features to activate.
- Phase 2 (\(\lambda \rightarrow 1\)):
- Gradually learn \(\mu_{i,j}, \sigma_{i,j}\) (refine how to represent features).
- Avoids collapse: Prevents premature “freezing” of latent dimensions.
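One simple way to realize the two phases is a linear ramp on \(\lambda\) over training epochs; the function below is a minimal sketch (the name warmup_lambda and the epoch counts are made up; only lambda_val matches the code above).

def warmup_lambda(epoch: int, start_epoch: int = 10, ramp_epochs: int = 20) -> float:
    # Phase 1: return 0 (fixed N(0, 1) slab, binary-like codes).
    # Phase 2: ramp linearly from 0 to 1, then stay at 1.
    progress = (epoch - start_epoch) / ramp_epochs
    return float(min(max(progress, 0.0), 1.0))

During training one would set model.lambda_val = warmup_lambda(epoch) at the start of each epoch, so the reparameterization gradually switches from the fixed slab to the learned one.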