Skip to content

Commit 4ffc891

Browse files
committed
Tuned KL divergence
2 parents 88f1b35 + 6a6368c commit 4ffc891

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

delphi/log/result_analysis.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ def plot_firing_vs_f1(
1515
out_dir.mkdir(parents=True, exist_ok=True)
1616
for module, module_df in latent_df.groupby("module"):
1717

18-
if 'firing_count' not in module_df.columns:
19-
print(f"""WARNING: 'firing_count' column not found for module {module}.
20-
Skipping plot.""")
18+
if "firing_count" not in module_df.columns:
19+
print(
20+
f"""WARNING: 'firing_count' column not found for module {module}.
21+
Skipping plot."""
22+
)
2123
continue
2224

2325
module_df = module_df.copy()
@@ -175,9 +177,11 @@ def parse_score_file(path: Path) -> pd.DataFrame:
175177
return pd.DataFrame()
176178

177179
if not isinstance(data, list):
178-
print(f"""Warning: Expected a list of results in {path},
179-
but found {type(data)}.
180-
Skipping file.""")
180+
print(
181+
f"""Warning: Expected a list of results in {path},
182+
but found {type(data)}.
183+
Skipping file."""
184+
)
181185
return pd.DataFrame()
182186

183187
latent_idx = int(path.stem.split("latent")[-1])
@@ -327,9 +331,11 @@ def log_results(
327331
print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}")
328332
print(f"F1 Score: {score_type_summary['f1_score']:.3f}")
329333

330-
if counts and score_type_summary['weighted_f1'] is not None:
331-
print(f"""Frequency-Weighted F1 Score:
332-
{score_type_summary['weighted_f1']:.3f}""")
334+
if counts and score_type_summary["weighted_f1"] is not None:
335+
print(
336+
f"""Frequency-Weighted F1 Score:
337+
{score_type_summary['weighted_f1']:.3f}"""
338+
)
333339

334340
print(f"Precision: {score_type_summary['precision']:.3f}")
335341
print(f"Recall: {score_type_summary['recall']:.3f}")

delphi/scorers/intervention/surprisal_intervention_scorer.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,11 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any:
138138
try:
139139
return self._find_layer(model, full_path)
140140
except AttributeError as e:
141-
raise AttributeError(f"""Could not resolve path '{full_path}'.
142-
Model structure might be unexpected.
143-
Original error: {e}""")
144-
141+
raise AttributeError(
142+
f"""Could not resolve path '{full_path}'.
143+
Model structure might be unexpected.
144+
Original error: {e}"""
145+
)
145146

146147
def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]:
147148
"""

0 commit comments

Comments
 (0)