Add HemmEvaluation #5

Merged · 3 commits · Jun 22, 2024
16 changes: 5 additions & 11 deletions hemm/eval_pipelines/eval_pipeline.py
@@ -1,16 +1,14 @@
import asyncio
-import base64
from abc import ABC
-from io import BytesIO
-from typing import Callable, Dict, List, Union
+from typing import Dict, List, Union

import wandb
import weave
-from PIL import Image
-from weave import Evaluation

+from .hemm_evaluation import HemmEvaluation
from .model import BaseDiffusionModel
from ..metrics.base import BaseMetric
+from ..utils import base64_decode_image


class EvaluationPipeline(ABC):
@@ -76,11 +74,7 @@ async def infer(self, prompt: str) -> Dict[str, str]:
[
self.model.diffusion_model_name_or_path,
prompt,
-wandb.Image(
-    Image.open(
-        BytesIO(base64.b64decode(output["image"].split(";base64,")[-1]))
-    )
-),
+wandb.Image(base64_decode_image(output["image"])),
]
)
return output
@@ -106,7 +100,7 @@ def __call__(self, dataset: Union[List[Dict], str]) -> None:
passed, it is assumed to be a Weave dataset reference.
"""
dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset
-evaluation = Evaluation(
+evaluation = HemmEvaluation(
dataset=dataset,
scorers=[metric_fn.evaluate for metric_fn in self.metric_functions],
)
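
For context, a minimal usage sketch of the pipeline after this change. The constructor argument and the example rows below are assumptions for illustration; only the `__call__` contract shown above (a list of dicts, or a Weave dataset reference string resolved via `weave.ref(dataset).get()`) comes from this diff.

# Hedged sketch; how the pipeline's model and metric functions are configured
# is not part of this diff, so the constructor call below is an assumption.
pipeline = EvaluationPipeline(model=diffusion_model)  # diffusion_model: a BaseDiffusionModel (assumed)

# Pass evaluation rows directly...
pipeline([{"prompt": "a red cube balanced on a blue sphere"}])

# ...or pass a Weave dataset reference string (hypothetical name shown).
pipeline("my-prompt-dataset:v0")
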
72 changes: 72 additions & 0 deletions hemm/eval_pipelines/hemm_evaluation.py
@@ -0,0 +1,72 @@
import traceback
from typing import Callable, Optional, Union, cast

import rich
import wandb
import weave
from weave.flow.dataset import Dataset
from weave.flow.model import Model
from weave.flow.util import async_foreach
from weave.flow.scorer import Scorer, get_scorer_attributes
from weave.trace.errors import OpCallError
from weave.trace.env import get_weave_parallelism
from weave.trace.op import Op


def replace_backslash_dot(d):
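    """Recursively un-escape dict keys, replacing backslash-escaped dots with plain dots; nested dicts and lists are handled."""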
    if isinstance(d, dict):
        new_dict = {}
        for k, v in d.items():
            new_key = k.replace("\\.", ".")
            new_dict[new_key] = replace_backslash_dot(v)
        return new_dict
    elif isinstance(d, list):
        return [replace_backslash_dot(i) for i in d]
    else:
        return d


class HemmEvaluation(weave.Evaluation):
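    """weave.Evaluation subclass that tolerates failed predictions (recording empty
    scores instead of aborting the run) and logs the final summary to Weights & Biases."""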
    dataset: Union[Dataset, list]
    scorers: Optional[list[Union[Callable, Op, Scorer]]] = None
    preprocess_model_input: Optional[Callable] = None
    trials: int = 1

    @weave.op()
    async def evaluate(self, model: Union[Callable, Model]) -> dict:
        eval_rows = []

        async def eval_example(example: dict) -> dict:
            try:
                eval_row = await self.predict_and_score(model, example)
            except OpCallError:
                raise
            except Exception:
                rich.print("Predict and score failed")
                traceback.print_exc()
                return {"model_output": None, "scores": {}}
            return eval_row

        n_complete = 0
        dataset = cast(Dataset, self.dataset)
        _rows = dataset.rows
        trial_rows = list(_rows) * self.trials
        async for _, eval_row in async_foreach(
            trial_rows, eval_example, get_weave_parallelism()
        ):
            n_complete += 1
            rich.print(f"Evaluated {n_complete} of {len(trial_rows)} examples")
            if eval_row is None:
                eval_row = {"model_output": None, "scores": {}}
            if eval_row["scores"] is None:
                eval_row["scores"] = {}
            for scorer in self.scorers or []:
                scorer_name, _, _ = get_scorer_attributes(scorer)
                if scorer_name not in eval_row["scores"]:
                    eval_row["scores"][scorer_name] = {}
            eval_rows.append(eval_row)

        summary = await self.summarize(eval_rows)
        wandb.log(replace_backslash_dot(summary))
        rich.print("Evaluation summary", summary)
        return summary
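
As a quick illustration of what replace_backslash_dot does to the summary before wandb.log is called (the keys below are assumed examples; the behaviour follows directly from the function above):

summary = {"clip_score\\.mean": 0.31, "rows": [{"prompt\\.alignment": {"mean": 0.8}}]}
replace_backslash_dot(summary)
# -> {"clip_score.mean": 0.31, "rows": [{"prompt.alignment": {"mean": 0.8}}]}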