diff --git a/README.md b/README.md
index d7577ae..1d933d2 100644
--- a/README.md
+++ b/README.md
@@ -43,7 +43,7 @@ We initialize a linear probe with the same number of outputs as classes in the e
 ```python
 from probe_lens.probes import LinearProbe
 X, y = next(iter(dataloader))
-probe = LinearProbe(X.shape[1], y.shape[1], class_names=spelling_task.get_classes())
+probe = LinearProbe(X.shape[1], y.shape[1], class_names=spelling_task.get_classes(), device=DEVICE)
 ```
 
 #### Train probe
@@ -58,6 +58,7 @@ We use the `visualize_performance` method to visualize the performance of the pr
 ```python
 plot = probe.visualize_performance(dataloader)
 ```
+![Confusion Matrix](confusion_matrix.png)
 
 ## Roadmap
 
@@ -81,3 +82,8 @@ plot = probe.visualize_performance(dataloader)
 - [ ] Add more visualization experiments
 - [ ] ... ?
 
+### Documentation
+- [ ] Add docstrings
+- [ ] Add tutorials
+- [ ] Reproduce experiments from major papers (SAE-Spelling, etc.)
+
diff --git a/confusion_matrix.png b/confusion_matrix.png
new file mode 100644
index 0000000..c96f559
Binary files /dev/null and b/confusion_matrix.png differ
diff --git a/probe_lens/probes.py b/probe_lens/probes.py
index c644bf9..a46d347 100644
--- a/probe_lens/probes.py
+++ b/probe_lens/probes.py
@@ -30,14 +30,13 @@ def visualize_performance(
         accuracy = accuracy_score(gts.cpu(), preds.cpu())
         f2_score = fbeta_score(gts.cpu(), preds.cpu(), beta=2, average="weighted")
 
-        cm = confusion_matrix(gts.cpu(), preds.cpu())
-        plt.figure(figsize=(10, 7))
         _class_names = (
             self.class_names
             if self.class_names
             else [str(i) for i in range(cm.shape[0])]
         )
+        plt.figure(figsize=(10, 7))
 
         sns.heatmap(
             cm,
             annot=True,
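
For context on the API change, here is a rough end-to-end sketch of the README flow this diff touches. It is illustrative only: `dataloader` and `spelling_task` come from earlier README steps not shown in this diff, and `DEVICE` is assumed to be an ordinary torch device string.

```python
# Hypothetical usage sketch; `dataloader` and `spelling_task` are assumed to be
# set up as in the earlier README steps (they are not part of this diff).
import torch

from probe_lens.probes import LinearProbe

# Assumption: any torch device identifier should work here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

X, y = next(iter(dataloader))  # one batch: activations X, one-hot labels y
probe = LinearProbe(
    X.shape[1],                               # input feature dimension
    y.shape[1],                               # number of output classes
    class_names=spelling_task.get_classes(),  # labels the confusion-matrix axes
    device=DEVICE,                            # new keyword introduced by this diff
)

# After training (see "Train probe" in the README), render the confusion
# matrix that the README now embeds as confusion_matrix.png.
plot = probe.visualize_performance(dataloader)
```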