From 3d048142016bafc7fe0b9d13556c0374e4e22776 Mon Sep 17 00:00:00 2001 From: Sharan Yalburgi Date: Tue, 29 Oct 2024 15:29:19 +0000 Subject: [PATCH] Walkthrough for the readme example --- README.md | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/README.md b/README.md index c738fa7..d7577ae 100644 --- a/README.md +++ b/README.md @@ -7,12 +7,18 @@ Here is an example of how to use the ProbeLens framework to generate probe data and train a linear probe on a spelling task: + +#### Configure probe experiment +We use the `FirstLetterSpelling` experiment as an example. With the `FirstLetterSpelling` class, we see how well the model's activations encode the first letter of the word. ```python from probe_lens.experiments.spelling import FirstLetterSpelling words = ["example", "words", "to", "spell"] spelling_task = FirstLetterSpelling(words) ``` + +#### Configure hooked transformer model and SAE +We use the `HookedSAETransformer` class from `sae_lens` to hook the transformer model and the `SAE` class to get the SAE. This package is designed to be tightly integrated with `sae_lens` and `transformer_lens`. ```python from sae_lens import HookedSAETransformer, SAE DEVICE = "mps" @@ -24,21 +30,32 @@ sae, cfg_dict, sparsity = SAE.from_pretrained( ) ``` +#### Generate probe data +We use the `generate_probe_data` method to generate the probe data. This involves the querying the model with various prompts and capturing the activations. ```python from torch.utils.data import DataLoader dataset = spelling_task.generate_probe_data(model, sae, device=DEVICE) dataloader = DataLoader(dataset, batch_size=2, shuffle=True) ``` +#### Initialize probe +We initialize a linear probe with the same number of outputs as classes in the experiment. ```python from probe_lens.probes import LinearProbe X, y = next(iter(dataloader)) probe = LinearProbe(X.shape[1], y.shape[1], class_names=spelling_task.get_classes()) ``` +#### Train probe +We use stochastic gradient descent to train the probe. ```python import torch.optim as optim probe.train_probe(dataloader, optim.SGD(probe.parameters(), lr=0.01), val_dataloader=None, epochs=1000) +``` + +#### Visualize performance +We use the `visualize_performance` method to visualize the performance of the probe. +```python plot = probe.visualize_performance(dataloader) ```