Merge pull request #5 from wandb/update/eval
Add HemmEvaluation
soumik12345 authored Jun 22, 2024
2 parents 3f6a6e6 + 0371c03 commit 781348d
Showing 2 changed files with 77 additions and 11 deletions.
16 changes: 5 additions & 11 deletions hemm/eval_pipelines/eval_pipeline.py
@@ -1,16 +1,14 @@
 import asyncio
-import base64
 from abc import ABC
-from io import BytesIO
-from typing import Callable, Dict, List, Union
+from typing import Dict, List, Union
 
 import wandb
 import weave
-from PIL import Image
-from weave import Evaluation
 
+from .hemm_evaluation import HemmEvaluation
 from .model import BaseDiffusionModel
 from ..metrics.base import BaseMetric
+from ..utils import base64_decode_image
 
 
 class EvaluationPipeline(ABC):
@@ -76,11 +74,7 @@ async def infer(self, prompt: str) -> Dict[str, str]:
             [
                 self.model.diffusion_model_name_or_path,
                 prompt,
-                wandb.Image(
-                    Image.open(
-                        BytesIO(base64.b64decode(output["image"].split(";base64,")[-1]))
-                    )
-                ),
+                wandb.Image(base64_decode_image(output["image"])),
             ]
         )
         return output
@@ -106,7 +100,7 @@ def __call__(self, dataset: Union[List[Dict], str]) -> None:
                 passed, it is assumed to be a Weave dataset reference.
         """
         dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset
-        evaluation = Evaluation(
+        evaluation = HemmEvaluation(
             dataset=dataset,
             scorers=[metric_fn.evaluate for metric_fn in self.metric_functions],
         )
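For context, the base64_decode_image helper imported from ..utils is not shown in this diff. Judging from the inline logic it replaces in infer() above, it presumably looks something like the following sketch; the exact signature is an assumption, only the decoding steps are taken from the removed lines:

import base64
from io import BytesIO

from PIL import Image


def base64_decode_image(image: str) -> Image.Image:
    # Presumed shape of the hemm.utils helper, inferred from the removed
    # inline code: strip an optional "data:image/png;base64," style prefix,
    # then base64-decode the payload into a PIL image.
    return Image.open(BytesIO(base64.b64decode(image.split(";base64,")[-1])))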
72 changes: 72 additions & 0 deletions hemm/eval_pipelines/hemm_evaluation.py
@@ -0,0 +1,72 @@
+import traceback
+from typing import Callable, Optional, Union, cast
+
+import rich
+import wandb
+import weave
+from weave.flow.dataset import Dataset
+from weave.flow.model import Model
+from weave.flow.util import async_foreach
+from weave.flow.scorer import Scorer, get_scorer_attributes
+from weave.trace.errors import OpCallError
+from weave.trace.env import get_weave_parallelism
+from weave.trace.op import Op
+
+
+def replace_backslash_dot(d):
+    if isinstance(d, dict):
+        new_dict = {}
+        for k, v in d.items():
+            new_key = k.replace("\\.", ".")
+            new_dict[new_key] = replace_backslash_dot(v)
+        return new_dict
+    elif isinstance(d, list):
+        return [replace_backslash_dot(i) for i in d]
+    else:
+        return d
+
+
+class HemmEvaluation(weave.Evaluation):
+    dataset: Union[Dataset, list]
+    scorers: Optional[list[Union[Callable, Op, Scorer]]] = None
+    preprocess_model_input: Optional[Callable] = None
+    trials: int = 1
+
+    @weave.op()
+    async def evaluate(self, model: Union[Callable, Model]) -> dict:
+        eval_rows = []
+
+        async def eval_example(example: dict) -> dict:
+            try:
+                eval_row = await self.predict_and_score(model, example)
+            except OpCallError:
+                raise
+            except Exception:
+                rich.print("Predict and score failed")
+                traceback.print_exc()
+                return {"model_output": None, "scores": {}}
+            return eval_row
+
+        n_complete = 0
+        dataset = cast(Dataset, self.dataset)
+        _rows = dataset.rows
+        trial_rows = list(_rows) * self.trials
+        async for _, eval_row in async_foreach(
+            trial_rows, eval_example, get_weave_parallelism()
+        ):
+            n_complete += 1
+            rich.print(f"Evaluated {n_complete} of {len(trial_rows)} examples")
+            if eval_row is None:
+                eval_row = {"model_output": None, "scores": {}}
+            if eval_row["scores"] is None:
+                eval_row["scores"] = {}
+            for scorer in self.scorers or []:
+                scorer_name, _, _ = get_scorer_attributes(scorer)
+                if scorer_name not in eval_row["scores"]:
+                    eval_row["scores"][scorer_name] = {}
+            eval_rows.append(eval_row)
+
+        summary = await self.summarize(eval_rows)
+        wandb.log(replace_backslash_dot(summary))
+        rich.print("Evaluation summary", summary)
+        return summary
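The replace_backslash_dot helper recursively rewrites escaped keys such as "clip_score\.mean" back to "clip_score.mean" before the summary dict is handed to wandb.log, walking nested dicts and lists alike. A minimal illustration; the summary dict below is hypothetical, as real keys come from Weave's summarize():

summary = {"metrics\\.clip_score": {"mean": 0.31}, "rows": [{"a\\.b": 1}]}
print(replace_backslash_dot(summary))
# {'metrics.clip_score': {'mean': 0.31}, 'rows': [{'a.b': 1}]}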

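HemmEvaluation differs from the stock weave.Evaluation it subclasses in two visible ways: a row whose predict_and_score call fails is replaced by an empty {"model_output": None, "scores": {}} record instead of aborting the run, and the final summary is mirrored to Weights & Biases via wandb.log. Rows are evaluated concurrently through async_foreach, with the degree of parallelism taken from get_weave_parallelism(), which appears to read the WEAVE_PARALLELISM environment variable. A minimal, hypothetical driving sketch follows; the toy model, scorer, and project name are illustrative only, and in practice EvaluationPipeline.__call__ wires this up as shown in the first file:

import asyncio

import wandb
import weave

from hemm.eval_pipelines.hemm_evaluation import HemmEvaluation


@weave.op()
def toy_model(prompt: str) -> dict:
    # Stand-in for a diffusion model; returns the key the scorer expects.
    return {"image": None}


@weave.op()
def toy_scorer(prompt: str, model_output: dict) -> dict:
    # Hypothetical scorer; the real metrics live in hemm.metrics.
    return {"image_generated": model_output["image"] is not None}


wandb.init(project="hemm-demo")  # evaluate() calls wandb.log, so a W&B run must exist
weave.init("hemm-demo")          # needed for weave.op tracing

evaluation = HemmEvaluation(
    dataset=[{"prompt": "a red cube"}],  # a plain list is normalized to a weave Dataset upstream
    scorers=[toy_scorer],
)
summary = asyncio.run(evaluation.evaluate(toy_model))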