Walkthrough for the readme example
sharanry committed Oct 29, 2024
1 parent a053838 commit 3d04814
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions README.md
@@ -7,12 +7,18 @@

Here is an example of how to use the ProbeLens framework to generate probe data and train a linear probe on a spelling task:


#### Configure probe experiment
We use the `FirstLetterSpelling` experiment as an example. It measures how well the model's activations encode the first letter of each word.
```python
from probe_lens.experiments.spelling import FirstLetterSpelling

words = ["example", "words", "to", "spell"]
spelling_task = FirstLetterSpelling(words)
```
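To make the task concrete, here is a minimal sketch of what first-letter labels could look like. The helper `first_letter_one_hot` and the 26-letter class list are assumptions made for illustration; they are not taken from ProbeLens, whose internal label format is not shown here.

```python
import string

# Assumed class list: one class per lowercase ASCII letter (illustrative only).
CLASSES = list(string.ascii_lowercase)


def first_letter_one_hot(word: str) -> list[float]:
    """Return a one-hot vector marking the word's first letter (hypothetical helper)."""
    index = CLASSES.index(word[0].lower())
    return [1.0 if i == index else 0.0 for i in range(len(CLASSES))]


words = ["example", "words", "to", "spell"]
labels = {word: first_letter_one_hot(word) for word in words}
```

Each word maps to a vector with a single non-zero entry, so a probe trained on such labels needs one output per letter.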

#### Configure hooked transformer model and SAE
We use the `HookedSAETransformer` class from `sae_lens` to load a hooked transformer model and the `SAE` class to load the sparse autoencoder (SAE). This package is designed to integrate tightly with `sae_lens` and `transformer_lens`.
```python
from sae_lens import HookedSAETransformer, SAE
DEVICE = "mps"
# ... (lines collapsed in the diff view) ...
sae, cfg_dict, sparsity = SAE.from_pretrained(
    # ... (arguments collapsed in the diff view) ...
)
```
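The snippet above hard-codes `DEVICE = "mps"`, which is only available on Apple Silicon. A possible fallback, assuming a recent PyTorch build, is sketched below; the helper name `pick_device` is hypothetical and not part of ProbeLens or `sae_lens`.

```python
import torch


def pick_device() -> str:
    """Pick the first available backend, falling back to CPU (hypothetical helper)."""
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


DEVICE = pick_device()
```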

#### Generate probe data
We use the `generate_probe_data` method to generate the probe data. This involves querying the model with various prompts and capturing the resulting activations.
```python
from torch.utils.data import DataLoader
dataset = spelling_task.generate_probe_data(model, sae, device=DEVICE)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
```
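As a rough picture of the data the probe consumes, the sketch below builds a stand-in dataset from random tensors. The sizes, the use of `TensorDataset`, and the one-hot label layout are all assumptions for illustration; the actual dataset returned by `generate_probe_data` may be structured differently.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
n_examples, n_features, n_classes = 8, 16, 26  # illustrative sizes only

# Stand-in for captured activations: one feature vector per prompt.
X = torch.randn(n_examples, n_features)

# Stand-in for labels: one one-hot row per prompt.
class_ids = torch.randint(0, n_classes, (n_examples,))
y = torch.nn.functional.one_hot(class_ids, num_classes=n_classes).float()

dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Each batch pairs activations with their labels.
X_batch, y_batch = next(iter(dataloader))
```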

#### Initialize probe
We initialize a linear probe whose input size matches the activation dimension and whose output size matches the number of classes in the experiment.
```python
from probe_lens.probes import LinearProbe
X, y = next(iter(dataloader))
probe = LinearProbe(X.shape[1], y.shape[1], class_names=spelling_task.get_classes())
```
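Conceptually, a linear probe is a single affine layer that maps activations to one logit per class. The class `TinyLinearProbe` below is a hypothetical stand-in written for illustration; it is not the ProbeLens `LinearProbe` implementation.

```python
import torch
from torch import nn


class TinyLinearProbe(nn.Module):
    """Hypothetical minimal probe: one affine map from activations to class logits."""

    def __init__(self, input_dim: int, num_classes: int):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)  # raw logits, one per class


# Illustrative shapes: a batch of 2 activations with 16 features, 26 classes.
X = torch.randn(2, 16)
y = torch.zeros(2, 26)
probe = TinyLinearProbe(X.shape[1], y.shape[1])
logits = probe(X)
```

Reading the dimensions off one batch, as the README does with `X.shape[1]` and `y.shape[1]`, avoids hard-coding them.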

#### Train probe
We train the probe with stochastic gradient descent (learning rate 0.01) for 1000 epochs, without a validation set.
```python
import torch.optim as optim
probe.train_probe(dataloader, optim.SGD(probe.parameters(), lr=0.01), val_dataloader=None, epochs=1000)
```
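For readers who want to see what such a training run might involve, here is a minimal sketch of mini-batch SGD on a linear classifier. The synthetic data, the cross-entropy loss, and the zero initialization are assumptions chosen for illustration; the internals of `train_probe` are not shown in this README and may differ.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic, linearly separable data: the class is the sign of the first feature.
X = torch.randn(256, 4)
targets = (X[:, 0] > 0).long()
loader = DataLoader(TensorDataset(X, targets), batch_size=32, shuffle=True)

probe = nn.Linear(4, 2)
nn.init.zeros_(probe.weight)  # start from zero so the run is easy to reason about
nn.init.zeros_(probe.bias)

optimizer = optim.SGD(probe.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

epoch_losses = []
for epoch in range(25):
    total = 0.0
    for X_batch, y_batch in loader:
        optimizer.zero_grad()
        loss = loss_fn(probe(X_batch), y_batch)
        loss.backward()
        optimizer.step()
        total += loss.item()
    epoch_losses.append(total / len(loader))  # mean batch loss for this epoch

with torch.no_grad():
    accuracy = (probe(X).argmax(dim=1) == targets).float().mean().item()
```

Tracking the mean loss per epoch gives a simple check that training is making progress, which is useful when no validation dataloader is supplied.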

#### Visualize performance
The `visualize_performance` method plots how well the probe performs. Here it is evaluated on the same dataloader used for training.
```python
plot = probe.visualize_performance(dataloader)
```
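The README does not say what the plot contains. One common way to summarize a classifier's performance is a confusion count per (true, predicted) pair, sketched below with made-up letters; the helper `confusion_counts` and the example predictions are hypothetical and are not derived from `visualize_performance`.

```python
from collections import Counter


def confusion_counts(true_labels, predicted_labels):
    """Count (true, predicted) pairs; pairs with differing entries are errors."""
    return Counter(zip(true_labels, predicted_labels))


# Made-up labels for illustration only.
true_letters = ["e", "w", "t", "s", "e"]
pred_letters = ["e", "w", "s", "s", "e"]

counts = confusion_counts(true_letters, pred_letters)
correct = sum(n for (t, p), n in counts.items() if t == p)
accuracy = correct / len(true_letters)
```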
