Add option for synchronous evaluation #6

Merged
8 commits merged on Jun 26, 2024

Changes from 1 commit
add: optional async evaluation
soumik12345 committed Jun 22, 2024
commit 44519809a768ef996efd25329167011058fa1ce0
43 changes: 35 additions & 8 deletions hemm/eval_pipelines/eval_pipeline.py
@@ -1,10 +1,11 @@
 from abc import ABC
+import asyncio
 from typing import Dict, List, Union
 
 import wandb
 import weave
 
-from .hemm_evaluation import SyncHemmEvaluation
+from .hemm_evaluation import AsyncHemmEvaluation, SyncHemmEvaluation
 from .model import BaseDiffusionModel
 from ..metrics.base import BaseMetric
 from ..utils import base64_decode_image
@@ -78,6 +79,19 @@ def infer(self, prompt: str) -> Dict[str, str]:
         )
         return output
 
+    @weave.op()
+    async def infer_async(self, prompt: str) -> Dict[str, str]:
+        """Async inference function to generate images for the given prompt.
+
+        Args:
+            prompt (str): Prompt to generate the image.
+
+        Returns:
+            Dict[str, str]: Dictionary containing base64 encoded image to be logged as
+                a Weave object.
+        """
+        return self.infer(prompt)
+
     def log_summary(self):
         """Log the evaluation summary to the Weights & Biases dashboard."""
         config = wandb.config
@@ -91,18 +105,31 @@ def log_summary(self):
             {f"Evalution/{self.model.diffusion_model_name_or_path}": self.wandb_table}
         )
 
-    def __call__(self, dataset: Union[List[Dict], str]) -> None:
+    def __call__(
+        self, dataset: Union[List[Dict], str], async_evaluation: bool = False
+    ) -> None:
         """Evaluate the Stable Diffusion model on the given dataset.
 
         Args:
             dataset (Union[List[Dict], str]): Dataset to evaluate the model on. If a string is
                 passed, it is assumed to be a Weave dataset reference.
+            async_evaluation (bool): Flag to enable asynchronous evaluation.
         """
         dataset = weave.ref(dataset).get() if isinstance(dataset, str) else dataset
-        evaluation = SyncHemmEvaluation(
-            dataset=dataset,
-            scorers=[metric_fn.evaluate for metric_fn in self.metric_functions],
-        )
-        with weave.attributes(self.evaluation_configs):
-            evaluation.evaluate(self.infer)
+        if async_evaluation:
+            evaluation = AsyncHemmEvaluation(
+                dataset=dataset,
+                scorers=[
+                    metric_fn.evaluate_async for metric_fn in self.metric_functions
+                ],
+            )
+            with weave.attributes(self.evaluation_configs):
+                asyncio.run(evaluation.evaluate(self.infer_async))
+        else:
+            evaluation = SyncHemmEvaluation(
+                dataset=dataset,
+                scorers=[metric_fn.evaluate for metric_fn in self.metric_functions],
+            )
+            with weave.attributes(self.evaluation_configs):
+                evaluation.evaluate(self.infer)
         self.log_summary()
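For context, a minimal usage sketch of the new flag. The class name EvaluationPipeline, the add_metric helper, the import paths, the Weave project name, and the dataset reference below are illustrative assumptions, not taken from this diff:

import weave

from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline  # names assumed for illustration
from hemm.metrics.image_quality import PSNRMetric  # assumed import path

weave.init(project_name="hemm-eval")  # hypothetical Weave project

model = BaseDiffusionModel(diffusion_model_name_or_path="stabilityai/stable-diffusion-2-1")
pipeline = EvaluationPipeline(model=model)
pipeline.add_metric(PSNRMetric())  # assumed registration helper

# Default behaviour: synchronous evaluation via SyncHemmEvaluation
pipeline("parti-prompts:v0")  # hypothetical Weave dataset reference

# Opt in to the asynchronous path; __call__ wraps it in asyncio.run(...)
pipeline("parti-prompts:v0", async_evaluation=True)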
4 changes: 4 additions & 0 deletions hemm/metrics/base.py
@@ -10,3 +10,7 @@ def __init__(self) -> None:
     @abstractmethod
     def evaluate(self) -> Dict[str, Any]:
         pass
+
+    @abstractmethod
+    def evaluate_async(self) -> Dict[str, Any]:
+        pass
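Since evaluate_async is now an abstractmethod on BaseMetric, any metric subclass outside this diff must implement both entry points. A minimal sketch of the contract, using a toy metric that is not part of the library:

from typing import Any, Dict

import weave

from hemm.metrics.base import BaseMetric


class PromptLengthMetric(BaseMetric):
    """Toy metric, for illustration only: scores a generation by prompt word count."""

    def __init__(self) -> None:
        super().__init__()
        self.name = "prompt_length"  # assumes subclasses expose a name, as the spatial metric below does

    @weave.op()
    def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:
        # Synchronous scorer, collected by SyncHemmEvaluation
        return {self.name: len(prompt.split())}

    @weave.op()
    async def evaluate_async(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, Any]:
        # Coroutine scorer for AsyncHemmEvaluation; delegates to the sync path, mirroring the metrics below
        return self.evaluate(prompt, model_output)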
11 changes: 9 additions & 2 deletions hemm/metrics/image_quality/lpips.py
@@ -70,8 +70,15 @@ def compute_metric(
         )
 
     @weave.op()
-    def __call__(
+    def evaluate(
         self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
     ) -> Union[float, Dict[str, float]]:
         _ = "LPIPSMetric"
-        return super().__call__(prompt, ground_truth_image, model_output)
+        return super().evaluate(prompt, ground_truth_image, model_output)
+
+    @weave.op()
+    async def evaluate_async(
+        self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
+    ) -> Union[float, Dict[str, float]]:
+        _ = "LPIPSMetric"
+        return self.evaluate(prompt, ground_truth_image, model_output)
11 changes: 9 additions & 2 deletions hemm/metrics/image_quality/psnr.py
@@ -64,8 +64,15 @@ def compute_metric(
         )
 
     @weave.op()
-    def __call__(
+    def evaluate(
        self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
     ) -> Union[float, Dict[str, float]]:
         _ = "PSNRMetric"
-        return super().__call__(prompt, ground_truth_image, model_output)
+        return super().evaluate(prompt, ground_truth_image, model_output)
+
+    @weave.op()
+    async def evaluate_async(
+        self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
+    ) -> Union[float, Dict[str, float]]:
+        _ = "PSNRMetric"
+        return self.evaluate(prompt, ground_truth_image, model_output)
11 changes: 9 additions & 2 deletions hemm/metrics/image_quality/ssim.py
@@ -91,8 +91,15 @@ def compute_metric(
         )
 
     @weave.op()
-    def __call__(
+    def evaluate(
         self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
     ) -> Union[float, Dict[str, float]]:
         _ = "SSIMMetric"
-        return super().__call__(prompt, ground_truth_image, model_output)
+        return super().evaluate(prompt, ground_truth_image, model_output)
+
+    @weave.op()
+    async def evaluate_async(
+        self, prompt: str, ground_truth_image: str, model_output: Dict[str, Any]
+    ) -> Union[float, Dict[str, float]]:
+        _ = "SSIMMetric"
+        return self.evaluate(prompt, ground_truth_image, model_output)
11 changes: 9 additions & 2 deletions hemm/metrics/prompt_alignment/blip_score.py
@@ -48,6 +48,13 @@ def compute_metric(
         )
 
     @weave.op()
-    def __call__(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
+    def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
         _ = "BLIPScoreMertric"
-        return super().__call__(prompt, model_output)
+        return super().evaluate(prompt, model_output)
+
+    @weave.op()
+    async def evaluate_async(
+        self, prompt: str, model_output: Dict[str, Any]
+    ) -> Dict[str, float]:
+        _ = "BLIPScoreMertric"
+        return self.evaluate(prompt, model_output)
11 changes: 9 additions & 2 deletions hemm/metrics/prompt_alignment/clip_iqa_score.py
@@ -80,6 +80,13 @@ def compute_metric(
         return score_dict
 
     @weave.op()
-    def __call__(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
+    def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
         _ = "CLIPImageQualityScoreMetric"
-        return super().__call__(prompt, model_output)
+        return super().evaluate(prompt, model_output)
+
+    @weave.op()
+    async def evaluate_async(
+        self, prompt: str, model_output: Dict[str, Any]
+    ) -> Dict[str, float]:
+        _ = "CLIPImageQualityScoreMetric"
+        return self.evaluate(prompt, model_output)
11 changes: 9 additions & 2 deletions hemm/metrics/prompt_alignment/clip_score.py
@@ -45,6 +45,13 @@ def compute_metric(
         )
 
     @weave.op()
-    def __call__(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
+    def evaluate(self, prompt: str, model_output: Dict[str, Any]) -> Dict[str, float]:
         _ = "CLIPScoreMetric"
-        return super().__call__(prompt, model_output)
+        return super().evaluate(prompt, model_output)
+
+    @weave.op()
+    async def evaluate_async(
+        self, prompt: str, model_output: Dict[str, Any]
+    ) -> Dict[str, float]:
+        _ = "CLIPScoreMetric"
+        return self.evaluate(prompt, model_output)
11 changes: 11 additions & 0 deletions hemm/metrics/spatial_relationship/spatial_relationship_2d.py
@@ -243,3 +243,14 @@ def evaluate(
             prompt, image, entity_1, entity_2, relationship, boxes
         )
         return {self.name: judgement["score"]}
+
+    @weave.op()
+    async def evaluate_async(
+        self,
+        prompt: str,
+        entity_1: str,
+        entity_2: str,
+        relationship: str,
+        model_output: Dict[str, Any],
+    ) -> Dict[str, Union[bool, float, int]]:
+        return self.evaluate(prompt, entity_1, entity_2, relationship, model_output)
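All of the evaluate_async methods added above are thin coroutine wrappers around the existing synchronous evaluate calls, and infer_async likewise delegates to infer. The sketch below is not the library's AsyncHemmEvaluation; it is a toy loop, with argument routing simplified to prompt and model_output, that only illustrates why the asynchronous branch needs awaitable inference and scorer functions:

import asyncio
from typing import Any, Awaitable, Callable, Dict, List


async def run_async_evaluation(
    dataset: List[Dict[str, Any]],
    infer: Callable[[str], Awaitable[Dict[str, str]]],
    scorers: List[Callable[..., Awaitable[Dict[str, Any]]]],
) -> List[Dict[str, Any]]:
    """Toy asyncio evaluation loop: await inference, then await each scorer per row."""
    results = []
    for row in dataset:
        model_output = await infer(row["prompt"])
        scores: Dict[str, Any] = {}
        for scorer in scorers:
            scores.update(await scorer(prompt=row["prompt"], model_output=model_output))
        results.append(scores)
    return results

# The __call__ change above enters its async branch through the same mechanism:
# asyncio.run(evaluation.evaluate(self.infer_async))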