
Fix explained_variance computing variance relative to zero instead of mean#665

Open
amaljithkuttamath wants to merge 7 commits into decoderesearch:main from amaljithkuttamath:fix/explained-variance-computation

Conversation

@amaljithkuttamath

Summary

Fixes #659. Two bugs in get_sparsity_and_variance_metrics caused explained_variance to compute total variance relative to zero instead of relative to the mean:

  1. Line 552: mean_act_per_dimension accumulated .pow(2).mean(dim=0) (i.e. E[x_d^2]) instead of .mean(dim=0) (i.e. E[x_d]). When squared on line 586, this produced E[x_d^2]^2 instead of the correct E[x_d]^2.

  2. Line 585: torch.cat on a list of (d_model,) tensors produced (N_batches * d_model,), then .mean(dim=0) collapsed everything to a scalar, destroying per-dimension structure. Replaced with torch.stack and added .sum() to reduce across dimensions after squaring (matching how mean_sum_of_squares is already a scalar from .sum(dim=-1)).

The combined effect: total_variance was approximately E[||X||^2] (variance from zero) instead of E[||X||^2] - ||E[X]||^2 (variance from mean). For activations with large mean components, this inflated explained_variance.
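The identity behind the fix can be checked directly. This sketch (toy shapes, not the real accumulators from `evals.py`) shows that E[||X||²] - ||E[X]||² matches the sum of per-dimension population variances, while the variance-from-zero quantity is much larger for high-mean data:

```python
import torch

torch.manual_seed(0)

# Toy activations with a large mean component: (n_tokens, d_model).
x = torch.randn(1000, 8) + 5.0

mean_sum_of_squares = x.pow(2).sum(dim=-1).mean()  # E[||X||^2], scalar
mean_act_per_dimension = x.mean(dim=0)             # E[X], shape (d_model,)

# Corrected formula: variance relative to the mean.
total_variance = mean_sum_of_squares - (mean_act_per_dimension**2).sum()

# Reference: sum of per-dimension population variances.
reference = x.var(dim=0, correction=0).sum()
assert torch.allclose(total_variance, reference, atol=1e-3)

# The buggy pipeline effectively measured variance from zero,
# which is far larger here because of the +5.0 mean shift.
variance_from_zero = mean_sum_of_squares
assert variance_from_zero > 10 * total_variance
```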

Changes

  • sae_lens/evals.py: 3 line changes (remove .pow(2), cat -> stack, add .sum())
  • tests/test_evals.py: New test that constructs high-mean data and verifies the variance formula matches torch.var with correction=0. Also verifies the buggy formula produces a materially different (incorrect) result.
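A minimal sketch of the test's idea (function name and shapes are illustrative; the real test exercises the actual code in `sae_lens/evals.py`):

```python
import torch


def total_variance_mean_centered(x: torch.Tensor) -> torch.Tensor:
    # Var(X) = E[||X||^2] - ||E[X]||^2, reduced over dimensions.
    mean_sum_of_squares = x.pow(2).sum(dim=-1).mean()
    mean_act = x.mean(dim=0)
    return mean_sum_of_squares - (mean_act**2).sum()


def test_explained_variance_uses_mean_centered_variance():
    torch.manual_seed(0)
    x = torch.randn(500, 16) + 10.0  # high-mean data exposes the bug
    got = total_variance_mean_centered(x)
    want = x.var(dim=0, correction=0).sum()
    assert torch.allclose(got, want, atol=1e-2)
    # Variance measured from zero is materially larger for high-mean data.
    assert x.pow(2).sum(dim=-1).mean() > 10 * got
```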

Test plan

  • New test test_explained_variance_uses_mean_centered_variance passes
  • Existing test test_get_sparsity_and_variance_metrics_identity_sae_perfect_reconstruction still passes
  • ruff check passes

… mean

Two bugs in the variance computation for explained_variance:

1. mean_act_per_dimension accumulated .pow(2).mean() instead of .mean(),
   computing E[x^2] per dimension instead of E[x] per dimension. This
   made the subtracted term in Var = E[||X||^2] - ||E[X]||^2 incorrect.

2. torch.cat on the per-batch mean vectors flattened them into one long
   vector, destroying per-dimension structure. Replaced with torch.stack
   and added .sum() to reduce across dimensions.
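The cat-vs-stack shape difference can be seen directly (toy shapes, not the real accumulators):

```python
import torch

d_model, n_batches = 4, 3
batch_means = [torch.randn(d_model) for _ in range(n_batches)]

# torch.cat flattens into one long vector of shape (n_batches * d_model,),
# so .mean(dim=0) collapses everything to a scalar:
assert torch.cat(batch_means).mean(dim=0).shape == torch.Size([])

# torch.stack keeps shape (n_batches, d_model), so .mean(dim=0)
# yields one mean per dimension, as the formula requires:
assert torch.stack(batch_means).mean(dim=0).shape == torch.Size([d_model])
```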

The combined effect was that total_variance was computed relative to zero
(essentially E[||X||^2]) instead of relative to the mean, inflating
explained_variance for activations with large mean components.

Fixes decoderesearch#659
Contributor

Copilot AI left a comment


Pull request overview

Fixes the explained_variance calculation in get_sparsity_and_variance_metrics so it uses mean-centered variance (variance relative to the mean) rather than variance relative to zero, bringing it in line with the intended multidimensional identity Var(X) = E[||X||²] - ||E[X]||².

Changes:

  • Correct mean_act_per_dimension to accumulate E[x_d] (not E[x_d²]).
  • Preserve per-dimension structure via torch.stack(...).mean(dim=0) and compute ||E[X]||² via a dimension-wise sum.
  • Add a targeted unit test validating the corrected variance identity and demonstrating the buggy behavior is materially different.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| sae_lens/evals.py | Fixes mean and aggregation logic used to compute total variance for explained_variance. |
| tests/test_evals.py | Adds a regression test ensuring total variance is mean-centered and catches the prior buggy formula. |


```python
mean_act_per_dimension = torch.stack(mean_act_per_dimension).mean(dim=0)
total_variance = mean_sum_of_squares - (mean_act_per_dimension**2).sum()
residual_variance = torch.stack(mean_sum_of_resid_squared).mean(dim=0)
metrics["explained_variance"] = (1 - residual_variance / total_variance).item()
```

Copilot AI Mar 27, 2026


total_variance can legitimately be 0 (e.g., if only one unmasked token remains after ignore_tokens, or if activations are constant). In that case 1 - residual_variance / total_variance will produce inf/NaN. Consider guarding for total_variance <= eps (and possibly small negative values from fp roundoff) similarly to sae_lens/synthetic/evals.py, returning 1.0 when both variances are ~0, else 0.0 (or another defined fallback).

Suggested change

```diff
-metrics["explained_variance"] = (1 - residual_variance / total_variance).item()
+# Guard against zero / near-zero total variance to avoid inf/NaN.
+# When both variances are ~0, treat explained variance as 1.0
+# (perfect reconstruction of a constant signal); otherwise 0.0.
+eps = 1e-12
+if torch.abs(total_variance) <= eps:
+    if torch.abs(residual_variance) <= eps:
+        explained_variance = torch.tensor(1.0, device=total_variance.device)
+    else:
+        explained_variance = torch.tensor(0.0, device=total_variance.device)
+else:
+    explained_variance = 1 - residual_variance / total_variance
+metrics["explained_variance"] = explained_variance.item()
```

Comment on lines +669 to +675
```python
# With the bug (.pow(2) on the mean term), the subtracted term captures E[x^2]^2
# instead of E[x]^2, making total_variance much larger than the true variance
# for data with a large mean.
buggy_mean_act = x.pow(2).mean(dim=0)  # bug: .pow(2) before mean
buggy_total_var = (mean_sum_of_squares - (buggy_mean_act**2).sum()).item()
```

Copilot AI Mar 27, 2026


The explanatory comment says the buggy formula makes total_variance “much larger than the true variance”, but with the specific buggy_mean_act = x.pow(2).mean(dim=0) used below the resulting buggy_total_var will typically be very negative (because you subtract (\sum_d E[x_d^2]^2)). Consider rewording to avoid misleading readers (and optionally clarify that the inflated explained_variance in the original bug required both the .pow(2) mistake and the cat->scalar collapse).
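The reviewer's point can be confirmed numerically. With high-mean data, each E[x_d²] is large, so subtracting Σ_d E[x_d²]² drives the buggy total far below zero (toy data, assuming shapes similar to the test):

```python
import torch

torch.manual_seed(0)
x = torch.randn(500, 16) + 10.0  # per-dim E[x^2] ~ 1 + 100 = 101

mean_sum_of_squares = x.pow(2).sum(dim=-1).mean()  # ~ 16 * 101
buggy_mean_act = x.pow(2).mean(dim=0)              # E[x_d^2], not E[x_d]
buggy_total_var = mean_sum_of_squares - (buggy_mean_act**2).sum()

# Subtracting ~16 * 101^2 from ~16 * 101 gives a large negative value,
# not an inflated positive one.
assert buggy_total_var < 0
```

The inflated explained_variance in the original report required both mistakes together: the `.pow(2)` error plus the cat-induced collapse to a scalar.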

@chanind
Collaborator

chanind commented Mar 30, 2026

Thank you for this PR. This PR is probably correct, but I just want to make sure before merging to avoid needing to fix this calculation a third time.

Collaborator

@chanind chanind left a comment


The main issue is that the tests do not actually test anything; I'll try working on this.

```python
    assert metrics["mse"] == pytest.approx(0.0, abs=1e-5)


def test_explained_variance_uses_mean_centered_variance():
```
Collaborator


This test doesn't test any actual code

chanind and others added 2 commits March 31, 2026 15:30
Pulls the explained variance logic into a standalone function called
from get_sparsity_and_variance_metrics, then replaces the tautological
tests with property-based tests that call the real function: single-batch
vs torch.var, batched vs unbatched equivalence, translation invariance,
zero-variance edge cases, and cross-check against ExplainedVarianceCalculator.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Uses a random SAE with apply_b_dec_to_input=True and real model
activations. Shifts both inputs and b_dec by a constant, verifies
explained_variance is unchanged (since the shift cancels in encoding
and both input and output shift equally).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
chanind force-pushed the fix/explained-variance-computation branch from 56391d5 to accaaf4 (March 31, 2026 14:55)


Development

Successfully merging this pull request may close these issues.

Bug in explained_variance computation from PR #443
