Questions
Assumption from the code:
mse_loss depends on W_enc and W_dec.
l1_loss depends on W_enc (and W_dec norm).
If we don't normalize W_dec, then wouldn't it be possible to make W_enc very small to make l1_loss small, and to scale W_dec up accordingly so that mse_loss stays constant?
In other words, the model could cheat to lower the l1_loss.
No code in the _train_step function indicates that W_dec is normalized:
https://github.com/jbloomAus/SAELens/blob/3432f0059bc1d90119fe48f88e17ac07aea3a620/sae_lens/training/sae_trainer.py#L230-L281
training_forward_pass function
https://github.com/jbloomAus/SAELens/blob/3432f0059bc1d90119fe48f88e17ac07aea3a620/sae_lens/saes/sae.py#L917-L963
calculate_aux_loss function
https://github.com/jbloomAus/SAELens/blob/3432f0059bc1d90119fe48f88e17ac07aea3a620/sae_lens/saes/gated_sae.py#L173-L202
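To make the concern concrete, here is a minimal numpy sketch of the rescaling argument, assuming a standard ReLU SAE forward pass (f = ReLU(x @ W_enc + b_enc), recon = f @ W_dec). The parameters are toy values, not SAELens code. Because ReLU is positively homogeneous, shrinking the encoder by a factor c > 0 and growing the decoder by 1/c leaves the reconstruction (and thus mse_loss) unchanged while shrinking the l1 penalty by the factor c:

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_sae = 8, 32

# Hypothetical toy SAE parameters (names follow the issue's W_enc / W_dec).
W_enc = rng.normal(size=(d_in, d_sae))
b_enc = rng.normal(size=d_sae)
W_dec = rng.normal(size=(d_sae, d_in))
x = rng.normal(size=(4, d_in))

def forward(W_enc, b_enc, W_dec):
    f = np.maximum(x @ W_enc + b_enc, 0.0)  # ReLU feature activations
    recon = f @ W_dec
    return f, recon

f, recon = forward(W_enc, b_enc, W_dec)

# "Cheat": shrink the encoder by c, grow the decoder by 1/c.
c = 1e-3
f2, recon2 = forward(c * W_enc, c * b_enc, W_dec / c)

# ReLU(c*z) = c*ReLU(z) for c > 0, so f2 = c*f and the reconstruction
# is identical, while the l1 norm of the activations shrinks by c.
print(np.allclose(recon, recon2))          # reconstruction unchanged
print(np.abs(f2).sum() / np.abs(f).sum())  # ≈ c, i.e. l1 shrunk by 1000x
```

If nothing constrains the decoder norm (or folds it into the sparsity term), gradient descent can exploit exactly this direction.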