Fix explained_variance computing variance relative to zero instead of mean #665
Conversation
Two bugs in the variance computation for `explained_variance`:

1. `mean_act_per_dimension` accumulated `.pow(2).mean()` instead of `.mean()`, computing E[x²] per dimension instead of E[x] per dimension. This made the subtracted term in Var(X) = E[||X||²] - ||E[X]||² incorrect.
2. `torch.cat` on the per-batch mean vectors flattened them into one long vector, destroying per-dimension structure. Replaced with `torch.stack` and added `.sum()` to reduce across dimensions.

The combined effect was that `total_variance` was computed relative to zero (essentially E[||X||²]) instead of relative to the mean, inflating `explained_variance` for activations with large mean components.

Fixes decoderesearch#659
Pull request overview
Fixes the `explained_variance` calculation in `get_sparsity_and_variance_metrics` so it uses mean-centered variance (variance relative to the mean) rather than variance relative to zero, bringing it in line with the intended multidimensional identity Var(X) = E[||X||²] - ||E[X]||².
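The identity can be checked directly against `torch.var`; a minimal sketch (data and names here are illustrative, not taken from the PR's actual test):

```python
import torch

# Illustrative data: rows are samples, columns are dimensions.
# A large mean offset is what exposed the original bug.
torch.manual_seed(0)
x = torch.randn(1000, 8, dtype=torch.float64) + 5.0

# Multidimensional identity: Var(X) = E[||X||^2] - ||E[X]||^2
mean_sum_of_squares = x.pow(2).sum(dim=-1).mean()  # E[||X||^2], a scalar
mean_per_dimension = x.mean(dim=0)                 # E[X], shape (8,)
total_variance = mean_sum_of_squares - (mean_per_dimension**2).sum()

# Reference: per-dimension population variance, summed over dimensions.
reference = x.var(dim=0, correction=0).sum()
assert torch.allclose(total_variance, reference)
```

Summing `torch.var(..., correction=0)` over dimensions gives Σ_d (E[x_d²] - E[x_d]²), which is exactly E[||X||²] - ||E[X]||², so the two computations should agree up to floating-point error.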
Changes:
- Correct `mean_act_per_dimension` to accumulate `E[x_d]` (not `E[x_d²]`).
- Preserve per-dimension structure via `torch.stack(...).mean(dim=0)` and compute `||E[X]||²` via a dimension-wise sum.
- Add a targeted unit test validating the corrected variance identity and demonstrating the buggy behavior is materially different.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `sae_lens/evals.py` | Fixes mean and aggregation logic used to compute total variance for `explained_variance`. |
| `tests/test_evals.py` | Adds a regression test ensuring total variance is mean-centered and catches the prior buggy formula. |
sae_lens/evals.py
Outdated
```python
mean_act_per_dimension = torch.stack(mean_act_per_dimension).mean(dim=0)
total_variance = mean_sum_of_squares - (mean_act_per_dimension**2).sum()
residual_variance = torch.stack(mean_sum_of_resid_squared).mean(dim=0)
metrics["explained_variance"] = (1 - residual_variance / total_variance).item()
```
`total_variance` can legitimately be 0 (e.g., if only one unmasked token remains after `ignore_tokens`, or if activations are constant). In that case `1 - residual_variance / total_variance` will produce inf/NaN. Consider guarding for `total_variance <= eps` (and possibly small negative values from fp roundoff) similarly to `sae_lens/synthetic/evals.py`, returning 1.0 when both variances are ~0, else 0.0 (or another defined fallback).
Suggested change (replacing the `metrics["explained_variance"]` line):

```python
# Guard against zero / near-zero total variance to avoid inf/NaN.
# When both variances are ~0, treat explained variance as 1.0
# (perfect reconstruction of a constant signal); otherwise 0.0.
eps = 1e-12
if torch.abs(total_variance) <= eps:
    if torch.abs(residual_variance) <= eps:
        explained_variance = torch.tensor(1.0, device=total_variance.device)
    else:
        explained_variance = torch.tensor(0.0, device=total_variance.device)
else:
    explained_variance = 1 - residual_variance / total_variance
metrics["explained_variance"] = explained_variance.item()
```
tests/test_evals.py
Outdated
```python
# With the bug (.pow(2) on the mean term), the subtracted term captures E[x^2]^2
# instead of E[x]^2, making total_variance much larger than the true variance
# for data with a large mean.
buggy_mean_act = x.pow(2).mean(dim=0)  # bug: .pow(2) before mean
buggy_total_var = (
    mean_sum_of_squares - (buggy_mean_act**2).sum()
).item()
```
The explanatory comment says the buggy formula makes `total_variance` "much larger than the true variance", but with the specific `buggy_mean_act = x.pow(2).mean(dim=0)` used below, the resulting `buggy_total_var` will typically be very negative (because you subtract Σ_d E[x_d²]²). Consider rewording to avoid misleading readers (and optionally clarify that the inflated `explained_variance` in the original bug required both the `.pow(2)` mistake and the `cat` -> scalar collapse).
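The reviewer's point can be verified with a quick numerical check; a sketch with illustrative data (not the PR's actual test fixture):

```python
import torch

# With a large mean, E[x_d^2] is large per dimension, so subtracting
# the buggy term sum_d E[x_d^2]^2 drives the result far below zero,
# rather than making it "much larger" than the true variance.
torch.manual_seed(0)
x = torch.randn(1000, 8, dtype=torch.float64) + 5.0

mean_sum_of_squares = x.pow(2).sum(dim=-1).mean()  # E[||X||^2] ~ 8 * 26

correct_total_var = mean_sum_of_squares - (x.mean(dim=0) ** 2).sum()
buggy_total_var = mean_sum_of_squares - (x.pow(2).mean(dim=0) ** 2).sum()

# Buggy subtracted term is roughly 8 * 26^2, so the result is very negative.
assert buggy_total_var < 0 < correct_total_var
```

This is why the inflated `explained_variance` in the original bug needed both mistakes together: the `cat`-induced scalar collapse changed what was actually subtracted in the shipped code path.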
Thank you for this PR. This PR is probably correct, but I just want to make sure before merging to avoid needing to fix this calculation a third time.
chanind
left a comment
The main issue is that the tests do not actually test anything; I'll try working on this.
tests/test_evals.py
Outdated
```python
    assert metrics["mse"] == pytest.approx(0.0, abs=1e-5)


def test_explained_variance_uses_mean_centered_variance():
```
This test doesn't test any actual code
Pulls the explained variance logic into a standalone function called from `get_sparsity_and_variance_metrics`, then replaces the tautological tests with property-based tests that call the real function: single-batch vs `torch.var`, batched vs unbatched equivalence, translation invariance, zero-variance edge cases, and cross-check against `ExplainedVarianceCalculator`. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Uses a random SAE with `apply_b_dec_to_input=True` and real model activations. Shifts both inputs and `b_dec` by a constant, verifies `explained_variance` is unchanged (since the shift cancels in encoding and both input and output shift equally). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
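The translation-invariance property these tests rely on holds for the total-variance formula itself; a minimal sketch (the helper name `total_variance` is hypothetical, not the PR's standalone function):

```python
import torch

# Adding a constant offset to every sample leaves mean-centered total
# variance unchanged: the cross terms 2*c.E[X] + ||c||^2 appear in both
# E[||X + c||^2] and ||E[X + c]||^2, so they cancel in the difference.
def total_variance(x: torch.Tensor) -> torch.Tensor:
    mean_sum_of_squares = x.pow(2).sum(dim=-1).mean()  # E[||X||^2]
    mean_per_dimension = x.mean(dim=0)                 # E[X]
    return mean_sum_of_squares - (mean_per_dimension ** 2).sum()

torch.manual_seed(0)
x = torch.randn(500, 16, dtype=torch.float64)
shift = 10.0 * torch.ones(16, dtype=torch.float64)

assert torch.allclose(total_variance(x), total_variance(x + shift))
```

The variance-from-zero formula (the pre-fix behavior) fails this property, which is why translation invariance makes a good regression test here.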
Force-pushed 56391d5 to accaaf4
Summary
Fixes #659. Two bugs in `get_sparsity_and_variance_metrics` caused `explained_variance` to compute total variance relative to zero instead of relative to the mean:

Line 552: `mean_act_per_dimension` accumulated `.pow(2).mean(dim=0)` (i.e. `E[x_d^2]`) instead of `.mean(dim=0)` (i.e. `E[x_d]`). When squared on line 586, this produced `E[x_d^2]^2` instead of the correct `E[x_d]^2`.

Line 585: `torch.cat` on a list of `(d_model,)` tensors produced `(N_batches * d_model,)`, then `.mean(dim=0)` collapsed everything to a scalar, destroying per-dimension structure. Replaced with `torch.stack` and added `.sum()` to reduce across dimensions after squaring (matching how `mean_sum_of_squares` is already a scalar from `.sum(dim=-1)`).

The combined effect: `total_variance` was approximately `E[||X||^2]` (variance from zero) instead of `E[||X||^2] - ||E[X]||^2` (variance from mean). For activations with large mean components, this inflated `explained_variance`.

Changes
- `sae_lens/evals.py`: 3 line changes (remove `.pow(2)`, `cat` -> `stack`, add `.sum()`)
- `tests/test_evals.py`: New test that constructs high-mean data and verifies the variance formula matches `torch.var` with `correction=0`. Also verifies the buggy formula produces a materially different (incorrect) result.

Test plan
- `test_explained_variance_uses_mean_centered_variance` passes
- `test_get_sparsity_and_variance_metrics_identity_sae_perfect_reconstruction` still passes
- `ruff check` passes
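The `cat` -> `stack` part of the fix is easiest to see from tensor shapes; a small sketch with hypothetical dimensions:

```python
import torch

# Bug #2 in shape terms: torch.cat on per-batch (d_model,) mean vectors
# flattens them into one long vector, so a later .mean(dim=0) collapses
# everything to a scalar. torch.stack keeps the batch axis separate.
d_model, n_batches = 4, 3
per_batch_means = [torch.randn(d_model) for _ in range(n_batches)]

flattened = torch.cat(per_batch_means)   # shape (12,): structure lost
stacked = torch.stack(per_batch_means)   # shape (3, 4): structure kept

assert flattened.shape == (n_batches * d_model,)
assert flattened.mean(dim=0).shape == ()          # scalar, per-dim info gone
assert stacked.mean(dim=0).shape == (d_model,)    # per-dimension mean preserved
```

With `torch.stack`, `.mean(dim=0)` averages across batches per dimension, and the subsequent `.sum()` after squaring reduces the `(d_model,)` vector to the scalar `||E[X]||^2` the formula needs.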