Beyond the Hard Budget: Sparsity Regularizers for More Interpretable Top-k Sparse Autoencoders
By Nathanaël Jacquier, Maria Vakalopoulou, Mahdi S. Hosseini
"Two sparsity regularizers (ℓ₁ off-support penalty and ℓ₁/ℓ₂ ratio) improve Top‑k SAE monosemanticity at no cost to reconstruction, showing hard and soft sparsity are complementary."
Abstract
Sparse autoencoders (SAEs) have become a leading tool for interpreting the representations of vision foundation models, decomposing their polysemantic activations into a larger set of sparse, more monosemantic features. The Top-$k$ SAE, a now-standard variant, enforces sparsity architecturally through its activation function, retaining only the $k$ most active latents per input. Because it was designed precisely to avoid the $\ell_1$ penalty used by earlier SAEs and its known drawbacks, it has not been combined with an explicit sparsity regularizer, despite retaining limitations of its own, such as a budget $k$ that is fixed regardless of input complexity and a tendency to overfit to the training value of $k$. We introduce two sparsity regularizers compatible with the Top-$k$ architecture, both acting on the activations before the Top-$k$ selection: an $\ell_1$ penalty on the unselected (off-support) units, and a scale-invariant $\ell_1/\ell_2$-ratio penalty that concentrates the code onto fewer effective units. Both penalties are applied only to the batch-active units, those selected by the Top-$k$ operator at least once within the batch. Across two datasets, three vision foundation models, and a range of $k$, both regularizers consistently improve monosemanticity at no cost to reconstruction quality. The $\ell_1/\ell_2$ penalty further concentrates information into fewer latents, making reconstruction more robust to the inference-time choice of $k$ and improving small-budget linear probing. Our central finding is that hard architectural sparsity and soft sparsity regularization are complementary rather than mutually exclusive.
Technical Analysis & Implementation
Technical Breakdown§
Core Methodology§
Sparse autoencoders (SAEs) decompose the dense, polysemantic activations of vision foundation models into a larger set of sparse, more monosemantic latent features. The Top‑k SAE keeps only the $k$ largest activations per input, which enforces hard sparsity architecturally. Its budget $k$ is fixed regardless of input complexity, however, and the model tends to overfit to the value of $k$ used in training. This work introduces two soft sparsity regularizers, both applied to the pre‑Top‑k activations, to complement the hard constraint.
Regularizers§
Let $\mathbf{z} \in \mathbb{R}^d$ be the pre‑activation latent vector (before Top‑k). Define the set of batch‑active units $\mathcal{A}$ as those selected by Top‑k at least once in a batch. The two regularizers are:
1. ℓ₁ off‑support penalty: penalizes units that are batch‑active but fall outside the Top‑k support of the current sample: $$ \mathcal{L}_{\text{off}} = \beta \sum_{i \in \mathcal{A} \setminus \text{Top‑k}(\mathbf{z}, k)} |z_i| $$ where $\text{Top‑k}(\mathbf{z}, k)$ denotes the index set of the $k$ largest entries of $\mathbf{z}$. Only the unselected pre‑activations are shrunk toward zero; the selected units, and units that are never selected within the batch, are left unpenalized.
2. ℓ₁/ℓ₂ ratio penalty: a scale‑invariant penalty computed on the batch‑active entries of $\mathbf{z}$: $$ \mathcal{L}_{\text{ratio}} = \gamma \frac{\|\mathbf{z}_\mathcal{A}\|_1}{\|\mathbf{z}_\mathcal{A}\|_2} $$ where $\mathbf{z}_\mathcal{A}$ is the sub‑vector of $\mathbf{z}$ restricted to the batch‑active units. For a nonzero vector with $n$ entries, the ratio lies between $1$ (a single nonzero entry) and $\sqrt{n}$ (all entries of equal magnitude). Minimizing it therefore concentrates activation mass onto fewer latents and lowers the effective number of active units, while rescaling $\mathbf{z}$ leaves the penalty unchanged.
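As a concrete illustration of the two penalties, the following pure‑Python sketch evaluates them on a toy batch. The helper names (`topk_indices`, `batch_active`, `off_support_l1`, `l1_l2_ratio`), the toy values, and the `eps` stabilizer are assumptions made for this example and are not taken from the paper; the weights $\beta$ and $\gamma$ are left out.

```python
import math

def topk_indices(z, k):
    """Indices of the k largest entries of z (ties broken by lower index)."""
    order = sorted(range(len(z)), key=lambda i: (-z[i], i))
    return set(order[:k])

def batch_active(batch, k):
    """Union of the per-sample Top-k supports over the batch (the set A)."""
    active = set()
    for z in batch:
        active |= topk_indices(z, k)
    return active

def off_support_l1(z, k, active):
    """Sum of |z_i| over batch-active units not selected for this sample."""
    selected = topk_indices(z, k)
    return sum(abs(z[i]) for i in active - selected)

def l1_l2_ratio(z, active, eps=1e-8):
    """l1/l2 ratio of z restricted to the batch-active units."""
    sub = [z[i] for i in sorted(active)]
    l1 = sum(abs(v) for v in sub)
    l2 = math.sqrt(sum(v * v for v in sub))
    return l1 / (l2 + eps)  # eps (assumed) guards against an all-zero sub-vector

# Toy batch: 2 samples, d = 5 latents, k = 2 (hypothetical values).
batch = [
    [4.0, 3.0, 1.0, 0.5, 0.2],   # Top-2 support: {0, 1}
    [0.1, 2.0, 0.3, 5.0, 0.4],   # Top-2 support: {1, 3}
]
k = 2

A = batch_active(batch, k)                       # units 2 and 4 are never selected
off = [off_support_l1(z, k, A) for z in batch]   # per-sample off-support l1
ratios = [l1_l2_ratio(z, A) for z in batch]      # per-sample l1/l2 ratio

# Scale invariance: multiplying a code by 10 leaves the ratio (almost) unchanged.
scaled_ratio = l1_l2_ratio([10.0 * v for v in batch[0]], A)
```

In this toy batch, the first sample is penalized only for unit 3 and the second only for unit 0. Units 2 and 4 contribute to neither penalty because no sample selects them.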
Both regularizers are added to the standard reconstruction loss $\mathcal{L}_{\text{rec}}$ (e.g., MSE).
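Written out for a batch of $B$ samples, and assuming both penalties are enabled together and averaged over the batch, the objective can be sketched as follows. The batch size $B$ and the superscript $(b)$ indexing samples are notation introduced here for illustration:

$$ \mathcal{L} = \mathcal{L}_{\text{rec}} + \frac{\beta}{B} \sum_{b=1}^{B} \sum_{i \in \mathcal{A} \setminus \text{Top‑k}(\mathbf{z}^{(b)}, k)} |z^{(b)}_i| + \frac{\gamma}{B} \sum_{b=1}^{B} \frac{\|\mathbf{z}^{(b)}_\mathcal{A}\|_1}{\|\mathbf{z}^{(b)}_\mathcal{A}\|_2} $$

Setting $\beta = 0$ or $\gamma = 0$ leaves a single regularizer, and setting both to zero gives a plain Top‑k SAE trained on $\mathcal{L}_{\text{rec}}$ alone.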
Implementation Details§
The model follows the standard Top‑k SAE setup; the sketch below adds both regularizers to the reconstruction loss:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopkSAE(nn.Module):
    def __init__(self, input_dim, latent_dim, k, beta=0.1, gamma=0.1):
        super().__init__()
        self.encoder = nn.Linear(input_dim, latent_dim)
        self.decoder = nn.Linear(latent_dim, input_dim)
        self.k = k
        self.beta = beta    # weight of the l1 off-support penalty
        self.gamma = gamma  # weight of the l1/l2 ratio penalty

    def forward(self, x):
        z_pre = self.encoder(x)  # (batch, latent_dim), pre-Top-k activations
        # Top-k selection (hard sparsity): keep the k largest entries per sample
        topk_vals, topk_idx = torch.topk(z_pre, self.k, dim=1)
        z_hard = torch.zeros_like(z_pre)
        z_hard.scatter_(1, topk_idx, topk_vals)
        x_recon = self.decoder(z_hard)
        return x_recon, z_pre, topk_idx

    def loss(self, x, x_recon, z_pre, topk_idx):
        rec_loss = F.mse_loss(x_recon, x)
        # Per-sample selection mask: True where a unit is in that sample's Top-k
        selected = torch.zeros_like(z_pre).scatter_(1, topk_idx, 1.0).bool()
        # Batch-active units: selected by Top-k for at least one sample in the batch
        batch_active = selected.any(dim=0, keepdim=True)  # (1, latent_dim)
        # Off-support: batch-active units that this sample did not select
        off_mask = (batch_active & ~selected).to(z_pre.dtype)
        reg_off = self.beta * (z_pre.abs() * off_mask).sum(dim=1).mean()
        # l1/l2 ratio restricted to the batch-active units
        z_active = z_pre * batch_active.to(z_pre.dtype)
        l1 = z_active.abs().sum(dim=1)
        l2 = z_active.norm(p=2, dim=1)
        reg_ratio = self.gamma * (l1 / (l2 + 1e-8)).mean()
        return rec_loss + reg_off + reg_ratio
```
Results & Insights§
- Both regularizers consistently improve monosemanticity (measured by interpretability scores) across DINOv2, MAE, and CLIP on ImageNet and CIFAR‑10, without hurting reconstruction MSE.
- The ℓ₁/ℓ₂ penalty further concentrates information into fewer latents (lower effective k), making reconstruction more robust to inference‑time choice of k.
- Linear probing with small budgets benefits from the ℓ₁/ℓ₂ penalty, indicating more compact and discriminative features.
- Key finding: Hard architectural sparsity and soft sparsity regularization are complementary; they address different limitations and together yield better SAEs.
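A toy calculation can give some intuition for why a more concentrated code is less sensitive to the inference‑time budget. The sketch below assumes an identity decoder, so that the reconstruction error at budget $k$ is simply the squared norm of the dropped entries. The two codes and the helper names (`topk_truncate`, `relative_error`) are hypothetical, and the numbers are not results from the paper.

```python
def topk_truncate(z, k):
    """Keep the k largest entries of z and zero out the rest."""
    keep = set(sorted(range(len(z)), key=lambda i: (-z[i], i))[:k])
    return [v if i in keep else 0.0 for i, v in enumerate(z)]

def relative_error(z, k):
    """Fraction of the squared norm lost when only the k largest entries are kept."""
    kept = topk_truncate(z, k)
    lost = sum((a - b) ** 2 for a, b in zip(z, kept))
    return lost / sum(a * a for a in z)

# Two hypothetical codes over 4 latents (illustrative values only).
spread = [1.0, 1.0, 1.0, 1.0]         # mass spread evenly, l1/l2 = 2.0
concentrated = [3.0, 1.0, 0.5, 0.5]   # mass concentrated on one unit, lower l1/l2

# Sweep the inference-time budget and compare the two codes.
for k in (1, 2, 3, 4):
    print(k, round(relative_error(spread, k), 3),
          round(relative_error(concentrated, k), 3))
```

At $k = 1$ the evenly spread code loses 75% of its squared norm, while the concentrated code loses about 14%. With a learned, non‑identity decoder the relationship is less direct, so this is only an intuition for the reported robustness.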