Skip to content

Commit

Permalink
Add probe data generation
Browse files Browse the repository at this point in the history
  • Loading branch information
sharanry committed Oct 29, 2024
1 parent 067cde9 commit a053838
Show file tree
Hide file tree
Showing 8 changed files with 2,434 additions and 175 deletions.
40 changes: 40 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,46 @@
### A framework to enable probing of language models.
![CI](https://github.com/sharanry/probe-lens/actions/workflows/ci.yaml/badge.svg)

## API
### Example Usage

Here is an example of how to use the ProbeLens framework to generate probe data and train a linear probe on a spelling task:

```python
from probe_lens.experiments.spelling import FirstLetterSpelling

words = ["example", "words", "to", "spell"]
spelling_task = FirstLetterSpelling(words)
```
```python
from sae_lens import HookedSAETransformer, SAE
DEVICE = "mps"
model = HookedSAETransformer.from_pretrained_no_processing("gpt2-small", device=DEVICE)
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
sae_id="blocks.8.hook_resid_pre", # won't always be a hook point
device=DEVICE,
)
```

```python
from torch.utils.data import DataLoader
dataset = spelling_task.generate_probe_data(model, sae, device=DEVICE)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
```

```python
from probe_lens.probes import LinearProbe
X, y = next(iter(dataloader))
probe = LinearProbe(X.shape[1], y.shape[1], class_names=spelling_task.get_classes())
```

```python
import torch.optim as optim
probe.train_probe(dataloader, optim.SGD(probe.parameters(), lr=0.01), val_dataloader=None, epochs=1000)
plot = probe.visualize_performance(dataloader)
```


## Roadmap
### Functionalities
Expand Down
2,445 changes: 2,274 additions & 171 deletions poetry.lock

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions probe_lens/experiments/experiments.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F
from sae_lens import SAE, HookedSAETransformer
from tqdm.autonotebook import tqdm
from transformer_lens import HookedTransformer

"""
Probe experiments are used to generate data for probing tasks.
"""
Expand All @@ -16,3 +22,39 @@ def __repr__(self) -> str:
@abstractmethod
def get_data(self) -> list[tuple[str, int]]:
pass

@abstractmethod
def get_classes(self) -> list[str]:
pass

def generate_probe_data(
self,
hooked_model: HookedSAETransformer | HookedTransformer,
sae: SAE | None = None,
device: str = "cpu", # consistent with sae_lens and transformer_lens
) -> torch.utils.data.TensorDataset:
sae_acts = []
answer_classes = []
for i, (prompt, answer_class) in enumerate(
tqdm(self.get_data(), desc="Generating probe data")
):
if sae is not None:
_, cache = hooked_model.run_with_cache_with_saes(
prompt, saes=[sae], stop_at_layer=sae.cfg.hook_layer + 1
)
sae_acts.append(
cache[sae.cfg.hook_name + ".hook_sae_acts_post"][0, -1, :]
)
else:
raise ValueError("Not implemented for non-SAE models.")

answer_classes.append(answer_class)

one_hot_answer_classes = F.one_hot(
torch.tensor(answer_classes, device=device),
num_classes=len(self.get_classes()),
).float()
dataset = torch.utils.data.TensorDataset(
torch.stack(sae_acts), one_hot_answer_classes
)
return dataset
10 changes: 8 additions & 2 deletions probe_lens/experiments/spelling.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Callable

import requests

from probe_lens.experiments.experiments import ProbeExperiment

LETTERS = "abcdefghijklmnopqrstuvwxyz"
Expand All @@ -17,20 +19,24 @@ def first_letter_index(word: str):
class FirstLetterSpelling(ProbeExperiment):
def __init__(
self,
words: list[str],
words: list[str] = requests.get(WORDS_DATASET).text.splitlines(),
prompt_fn: Callable[[str], str] = default_spelling_prompt_generator,
class_fn: Callable[[str], int] = first_letter_index,
):
super().__init__("First Letter Spelling Experiment")
self.words = words
self.prompt_fn = prompt_fn
self.class_fn = class_fn
self.class_names = list(LETTERS)
self.generate_data()

def generate_data(self):
self.classes = [self.class_fn(word) for word in self.words]
self.prompts = [self.prompt_fn(word) for word in self.words]
self.classes = [self.class_fn(word) for word in self.words]
self.data = list(zip(self.prompts, self.classes))

def get_data(self) -> list[tuple[str, int]]:
return self.data

def get_classes(self) -> list[str]:
return self.class_names
14 changes: 13 additions & 1 deletion probe_lens/probes.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,25 @@ def visualize_performance(
plt.show()
return plt

def accuracy(self, dataloader: torch.utils.data.DataLoader):
preds = []
gts = []
for X, y in dataloader:
pred = self(X)
gt = y.argmax(dim=1)
preds.append(pred.argmax(dim=1))
gts.append(gt)
preds = torch.cat(preds)
gts = torch.cat(gts)
return (preds == gts).sum().item() / len(preds)

def train_probe(
self,
dataloader: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
val_dataloader: torch.utils.data.DataLoader | None = None,
loss_fn: nn.Module = nn.BCEWithLogitsLoss(),
epochs: int = 10,
epochs: int = 1000,
verbose: bool = True,
):
tqdm_epochs = tqdm(range(epochs), desc="Training Probe", unit="epoch")
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ torch = "^2.5.0"
matplotlib = "^3.9.2"
scikit-learn = "^1.5.2"
seaborn = "^0.13.2"
sae-lens = "^4.0.9"
transformer-lens = "^2.8.1"

[tool.poetry.group.dev.dependencies]
ruff = "^0.5.1"
Expand Down
53 changes: 53 additions & 0 deletions tests/experiments/test_spelling.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,17 @@
import torch
from sae_lens import SAE, HookedSAETransformer
from torch.utils.data import DataLoader

from probe_lens.experiments.spelling import LETTERS, FirstLetterSpelling
from probe_lens.probes import LinearProbe

DEVICE = (
"cuda"
if torch.cuda.is_available()
else "mps"
if torch.backends.mps.is_available()
else "cpu"
)


def test_first_letter_spelling():
Expand All @@ -7,3 +20,43 @@ def test_first_letter_spelling():
data = spelling_task.data
classes = [c for _, c in data]
assert classes == [LETTERS.index(word.lower()[0]) for word in words]


def test_first_letter_spelling_default_dataset():
spelling_task = FirstLetterSpelling()
assert len(spelling_task.data) == 10000


def test_first_letter_spelling_probe_data():
words = ["example", "words", "to", "spell"]
spelling_task = FirstLetterSpelling(words)
model = HookedSAETransformer.from_pretrained_no_processing(
"gpt2-small", device=DEVICE
)
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
sae_id="blocks.8.hook_resid_pre", # won't always be a hook point
device=DEVICE,
)
dataset = spelling_task.generate_probe_data(model, sae, device=DEVICE)
assert len(dataset) == len(words)


def test_first_letter_spelling_probe_training():
words = ["example", "words", "to", "spell"]
spelling_task = FirstLetterSpelling(words)
model = HookedSAETransformer.from_pretrained_no_processing(
"gpt2-small", device=DEVICE
)
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb", # see other options in sae_lens/pretrained_saes.yaml
sae_id="blocks.8.hook_resid_pre", # won't always be a hook point
device=DEVICE,
)
dataset = spelling_task.generate_probe_data(model, sae, device=DEVICE)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
X, y = next(iter(dataloader))
probe = LinearProbe(X.shape[1], y.shape[1], device=DEVICE, class_names=LETTERS)
probe.train_probe(
dataloader, torch.optim.SGD(probe.parameters(), lr=0.01), val_dataloader=None
)
3 changes: 2 additions & 1 deletion tests/test_probes.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_linear_probe_training():
optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
val_dataloader=dataloader,
loss_fn=nn.MSELoss(),
epochs=1000,
epochs=5000,
verbose=True,
)
print("Model weights: ", model.linear.weight)
Expand Down Expand Up @@ -84,3 +84,4 @@ def test_linear_probe_visualization():
plot.gca().get_title()
== f"Confusion Matrix (Accuracy: {accuracy:.4f}, F2 Score: {f2_score:.4f})"
)
assert model.accuracy(dataloader) == accuracy

0 comments on commit a053838

Please sign in to comment.