Skip to content

Commit

Permalink
add: wandb logging of evaluation summary in HemmEvaluation
Browse files Browse the repository at this point in the history
soumik12345 committed Jun 22, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent a6e42ad commit 0371c03
Showing 2 changed files with 15 additions and 7 deletions.
1 change: 0 additions & 1 deletion hemm/eval_pipelines/eval_pipeline.py
Original file line number Diff line number Diff line change
@@ -103,7 +103,6 @@ def __call__(self, dataset: Union[List[Dict], str]) -> None:
evaluation = HemmEvaluation(
dataset=dataset,
scorers=[metric_fn.evaluate for metric_fn in self.metric_functions],
wandb_summary_table_name=f"Evalution/summary/{self.model.diffusion_model_name_or_path}",
)
with weave.attributes(self.evaluation_configs):
asyncio.run(evaluation.evaluate(self.infer))
21 changes: 15 additions & 6 deletions hemm/eval_pipelines/hemm_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import time
import traceback
from typing import Callable, Optional, Union, cast

@@ -14,20 +13,29 @@
from weave.trace.op import Op


def replace_backslash_dot(d):
if isinstance(d, dict):
new_dict = {}
for k, v in d.items():
new_key = k.replace("\\.", ".")
new_dict[new_key] = replace_backslash_dot(v)
return new_dict
elif isinstance(d, list):
return [replace_backslash_dot(i) for i in d]
else:
return d


class HemmEvaluation(weave.Evaluation):
dataset: Union[Dataset, list]
scorers: Optional[list[Union[Callable, Op, Scorer]]] = None
preprocess_model_input: Optional[Callable] = None
trials: int = 1
wandb_summary_table_name: str = None
wandb_summary_table: wandb.Table = wandb.Table(columns=["summary"])

@weave.op()
async def evaluate(self, model: Union[Callable, Model]) -> dict:
eval_rows = []

start_time = time.time()

async def eval_example(example: dict) -> dict:
try:
eval_row = await self.predict_and_score(model, example)
@@ -43,7 +51,7 @@ async def eval_example(example: dict) -> dict:
dataset = cast(Dataset, self.dataset)
_rows = dataset.rows
trial_rows = list(_rows) * self.trials
async for example, eval_row in async_foreach(
async for _, eval_row in async_foreach(
trial_rows, eval_example, get_weave_parallelism()
):
n_complete += 1
@@ -59,5 +67,6 @@ async def eval_example(example: dict) -> dict:
eval_rows.append(eval_row)

summary = await self.summarize(eval_rows)
wandb.log(replace_backslash_dot(summary))
rich.print("Evaluation summary", summary)
return summary

0 comments on commit 0371c03

Please sign in to comment.