feat: Subgraph scores and notebook #42

Open
hijohnnylin wants to merge 14 commits into decoderesearch:main from hijohnnylin:subgraph_scores_and_notebook

Conversation

@hijohnnylin (Collaborator)

This PR adds subgraph replacement and completeness score calculation, treating pruned features as errors: the pruned edge weights are merged into their error nodes.

The notebook demonstrates:

  • Graph Scores
  • Pruned Graph Scores (zeros out the pruned edges/nodes)
  • Subgraph Scores (merges the non-selected edges/nodes into error nodes)

The user should pass the unpruned graph for the most accurate subgraph score, but in cases where we don't have it (e.g. on Neuronpedia, where we only have the pruned graph), the notebook shows that the difference is not huge if we pass the pruned graph.
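To make the difference between the score variants concrete, here is a toy numpy sketch. The matrix layout, node ordering, and variable names are all made up for illustration; circuit-tracer's actual Graph object is organized differently.

```python
import numpy as np

# Toy 4-node graph: rows/cols = [feat0, feat1, error, logit].
# A[i, j] = direct effect of node j on node i (hypothetical layout).
A = np.array([
    [0.0, 0.0, 0.0, 0.0],  # feat0
    [0.5, 0.0, 0.2, 0.0],  # feat1 <- feat0, error
    [0.0, 0.0, 0.0, 0.0],  # error
    [0.1, 0.7, 0.3, 0.0],  # logit <- feat0, feat1, error
])
ERROR = 2
is_feature = np.array([True, True, False, False])
selected = np.array([False, True, False, False])  # subgraph keeps only feat1
pruned = is_feature & ~selected

# "Pruned graph scores": zero out non-selected features' rows and columns.
A_zeroed = A.copy()
A_zeroed[pruned, :] = 0.0
A_zeroed[:, pruned] = 0.0

# "Subgraph scores": fold non-selected features' outgoing weights into the
# error column before dropping them, so their influence counts as error.
A_merged = A.copy()
A_merged[:, ERROR] += A_merged[:, pruned].sum(axis=1)
A_merged[pruned, :] = 0.0
A_merged[:, pruned] = 0.0
```

The two treatments differ exactly where a pruned feature had outgoing weight: zeroing discards that influence, merging reattributes it to the error node.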

Important to review:

  • Check the logic in compute_subgraph_scores to make sure I'm not borking something important. This was Claude (4.5) assisted!
  • Check that my reasoning makes sense for pruned graph scores and subgraph scores.

I did a "sanity test" by checking 1 pinned node vs 22 pinned nodes in the Gemma Dallas Austin graph:

1 Pinned Node (node '21_5943_10')
Graph: replacement_score=0.7022231817245483, completeness_score=0.9187235832214355
Pruned graph: replacement_score=0.5541447997093201, completeness_score=0.8545156121253967
Subgraph: replacement_score=0.2105160355567932, completeness_score=0.6135198473930359

22 Pinned Nodes (the default subgraph in the demo)
Graph: replacement_score=0.7022683620452881, completeness_score=0.9187378287315369
Pruned graph: replacement_score=0.5542337894439697, completeness_score=0.8545501828193665
Subgraph: replacement_score=0.3058098256587982, completeness_score=0.7032272815704346

The results seem... fine? Is the replacement score expected to be so low?

Credit to @mntss, @hannamw, and @neverix for entertaining my various questions.

```python
selected_features=graph.selected_features,
activation_values=graph.activation_values,
scan=graph.scan,
)
```
Collaborator


I think that if we implement this functionality, it should probably go into another file, like graph.py - this file is used for the .json outputs for visualization.

I also think that it would make sense to simply remove pruned features from the rows / columns of the adjacency matrix, as well as selected_features.

That said, I'm not sure if we need to explicitly represent the pruned graph at all, though if people have been wanting the ability to do so, I'm happy with adding it.

Collaborator (Author)


Ah, thanks! I agree that the "pruned graph scores" don't fit very well in the circuit-tracer repo.

I think some of the weirdness in my implementation comes from restrictions on Neuronpedia. On Neuronpedia we only have access to the pruned json, but we'd still like to show replacement/completeness scores for both the pruned graph and the subgraph (which is also based on the pruned graph). Even though we don't use the Python score calculation implementation directly, the original intent was that the Python library would be able to output the scores we show on Neuronpedia, so that people can validate/repro them.

But I'm understanding now that I should consider that separately and remove this to simplify and cause less potential confusion. It could make sense to just put an asterisk + note on Neuronpedia that the scores will differ from the circuit-tracer library, since we base them on the pruned graph there rather than the whole graph.

For now I've moved it to graph.py as per your request, but I'm also open to removing this entirely - let me know if the above makes sense (or if you have some other proposal that satisfies both circuit-tracer and Neuronpedia needs) and I'll proceed.

Collaborator


That makes sense to me! I agree that it'd be nice if people can relatively easily reproduce scores from Neuronpedia, and don't mind keeping this function. I also think that this function is a pretty reasonable thing to want independently of Neuronpedia anyway.

That said, I think it would make sense to remove pruned features from the rows / columns of the adjacency matrix, as well as selected_features; otherwise, the Graph will have incorrect (zero) values for the outgoing edges / direct effects of the pruned features. If we remove pruned features entirely, pruning features is essentially the same thing as having not attributed to them.
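A minimal standalone sketch of what removing (rather than zeroing) pruned features could look like, using made-up data and fake node ids rather than the real Graph fields; the actual create_pruned_graph implementation lives in circuit_tracer/graph.py and differs:

```python
import numpy as np

# Made-up adjacency matrix and feature ids for illustration only.
adjacency = np.arange(16, dtype=float).reshape(4, 4)
selected_features = ["0_12_3", "1_40_3", "2_7_3", "3_99_3"]  # fake ids
keep = np.array([True, False, True, True])  # False = pruned

# Drop pruned features from rows, columns, and the feature list, so the
# remaining matrix has no stale zero rows/columns for pruned features.
adjacency = adjacency[keep][:, keep]
selected_features = [f for f, k in zip(selected_features, keep) if k]
```

After this, a pruned feature is indistinguishable from one that was never attributed to, which is the equivalence described above.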

In the future, I'd like to record the direct effects of nodes that we don't attribute to, which will be aggregated into an "unattributed / pruned" node, a second type of error node. After that, the behavior of pruning will change a little: we will merge the direct effects of pruned features into this "unattributed / pruned" node.

Collaborator (Author)


create_pruned_graph has been updated to remove the pruned features entirely. Please review the logic when convenient.
The same example notebook runs successfully, with a slight change in the pruned graph scores.

Comment thread circuit_tracer/graph.py Outdated
```
Format: "layer_featidx_pos" (e.g., "5_123_10" for layer 5, feature 123,
position 10). Features not in this list are treated as pruned/errors.
node_mask: Boolean tensor from prune_graph indicating which nodes survived pruning.
edge_mask: Boolean tensor from prune_graph indicating which edges survived pruning.
```
Collaborator


This implementation treats pruned nodes as nothing / zeros them out, but the original paper treats them as error nodes. I would probably follow their lead and do the same. If you do so, there's no need for the node / edge mask, since pruned nodes will naturally be excluded from your subgraph anyway.

As I mentioned over Slack, this leads to some weird behavior - nodes that were never attributed to don't get counted as error nodes, whereas nodes that are pruned do. This means that attributing to many nodes and then pruning could give you different results from attributing from fewer nodes, even if both processes left you with the same graph in the end. I think that the solution is to track unattributed nodes, rather than to zero out pruned nodes.

Collaborator (Author)


I've removed the edge and node masks, which I believe means that if we pass in the unpruned graph, it should correctly track unattributed nodes.

Comment thread circuit_tracer/graph.py
```python
# First apply masks, zeroing pruned nodes/edges
masked_adjacency = graph.adjacency_matrix.clone()
masked_adjacency[~node_mask] = 0
masked_adjacency[:, ~node_mask] = 0
```
Collaborator


I guess we don't need to zero these if we're going to remove them from the matrix and selected_features anyway.
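As a quick sanity check of that claim, on toy data, zeroing pruned rows/columns first and then dropping them yields exactly the same submatrix as dropping them directly, so the zeroing step is redundant once rows/columns are removed:

```python
import numpy as np

# Toy data: compare zero-then-drop against drop-directly.
A = np.random.default_rng(0).random((4, 4))
keep = np.array([True, False, True, True])

zeroed = A.copy()
zeroed[~keep, :] = 0.0
zeroed[:, ~keep] = 0.0

assert np.array_equal(zeroed[keep][:, keep], A[keep][:, keep])
```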
