[Question] Why does SAETrainer not normalize W_dec? #544

@jasonrichdarmawan

Description

Questions

Assumptions from the code:

  1. mse_loss depends on W_enc and W_dec.
  2. l1_loss depends on W_enc (and W_dec norm).

If W_dec is not normalized, couldn't the model shrink W_enc to make l1_loss small and scale W_dec up by the same factor so that mse_loss stays constant?

In other words, the model could cheat to lower l1_loss without learning sparser features.
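The rescaling argument above can be checked numerically. The following is a minimal sketch, assuming a toy ReLU autoencoder with no bias terms; it is not the SAELens implementation, and `toy_sae_losses` and the shapes are made up for illustration. It compares an L1 penalty on the raw feature activations with one weighted by the row norms of W_dec, before and after rescaling W_enc by 1/c and W_dec by c.

```python
import numpy as np

def toy_sae_losses(x, W_enc, W_dec):
    """Toy ReLU SAE without biases (an assumption, not the SAELens code)."""
    feats = np.maximum(x @ W_enc, 0.0)            # feature activations, shape (batch, d_sae)
    x_hat = feats @ W_dec                         # reconstruction, shape (batch, d_in)
    mse = float(np.mean((x - x_hat) ** 2))
    # L1 penalty on raw activations (no decoder-norm weighting)
    l1_plain = float(np.mean(np.sum(np.abs(feats), axis=1)))
    # L1 penalty with each activation weighted by its decoder row norm
    dec_norms = np.linalg.norm(W_dec, axis=1)     # one norm per feature
    l1_weighted = float(np.mean(np.sum(feats * dec_norms, axis=1)))
    return mse, l1_plain, l1_weighted

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))
W_enc = rng.normal(size=(4, 16))
W_dec = rng.normal(size=(16, 4))

c = 10.0
base = toy_sae_losses(x, W_enc, W_dec)
scaled = toy_sae_losses(x, W_enc / c, W_dec * c)  # shrink encoder, grow decoder

print("mse        :", base[0], scaled[0])
print("l1 plain   :", base[1], scaled[1])
print("l1 weighted:", base[2], scaled[2])
```

In this toy setup the reconstruction error is unchanged by the rescaling, the unweighted penalty shrinks by the factor c, and the norm-weighted penalty is unchanged. Whether the SAELens losses behave the same way depends on the linked code.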

No code in the _train_step function appears to normalize W_dec:

https://github.com/jbloomAus/SAELens/blob/3432f0059bc1d90119fe48f88e17ac07aea3a620/sae_lens/training/sae_trainer.py#L230-L281

training_forward_pass function

https://github.com/jbloomAus/SAELens/blob/3432f0059bc1d90119fe48f88e17ac07aea3a620/sae_lens/saes/sae.py#L917-L963

calculate_aux_loss function

https://github.com/jbloomAus/SAELens/blob/3432f0059bc1d90119fe48f88e17ac07aea3a620/sae_lens/saes/gated_sae.py#L173-L202
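For concreteness, this is the kind of normalization step the question has in mind. It is a minimal sketch, assuming W_dec has one row per feature; `normalize_decoder_rows` is a hypothetical helper written for this issue, not a function from SAELens.

```python
import numpy as np

def normalize_decoder_rows(W_dec, eps=1e-8):
    """Return a copy of W_dec with each row (one per feature) scaled to unit L2 norm."""
    norms = np.linalg.norm(W_dec, axis=1, keepdims=True)  # shape (d_sae, 1)
    return W_dec / np.maximum(norms, eps)                 # eps guards against zero rows

rng = np.random.default_rng(0)
W_dec = rng.normal(size=(16, 4)) * 10.0                   # deliberately large rows
W_unit = normalize_decoder_rows(W_dec)
row_norms = np.linalg.norm(W_unit, axis=1)
print(row_norms)
```

Applied after every optimizer step, a projection like this would fix the decoder scale, so the encoder could not be shrunk and compensated for by a larger W_dec.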
