update: OpenAIJudge
soumik12345 committed Sep 5, 2024
1 parent d6a4e27 commit 006dfa3
Showing 4 changed files with 76 additions and 8 deletions.
14 changes: 8 additions & 6 deletions examples/multimodal_llm_eval/evaluate_mllm_metric_complex.py
@@ -1,4 +1,4 @@
-from typing import Optional, Tuple
+from typing import Optional
 
 import fire
 import wandb
@@ -16,23 +16,25 @@ def main(
     dataset_limit: Optional[int] = None,
     diffusion_model_address: str = "stabilityai/stable-diffusion-2-1",
     diffusion_model_enable_cpu_offfload: bool = False,
-    image_size: Tuple[int, int] = (512, 512),
+    image_height: int = 1024,
+    image_width: int = 1024,
 ):
     wandb.init(project=project, entity=entity, job_type="evaluation")
-    weave.init(project_name=project)
+    weave.init(project_name=f"{entity}/{project}")
 
     dataset = weave.ref(dataset_ref).get()
     dataset = dataset.rows[:dataset_limit] if dataset_limit else dataset
 
     diffusion_model = BaseDiffusionModel(
         diffusion_model_name_or_path=diffusion_model_address,
         enable_cpu_offfload=diffusion_model_enable_cpu_offfload,
-        image_height=image_size[0],
-        image_width=image_size[1],
+        image_height=image_height,
+        image_width=image_width,
     )
     diffusion_model._pipeline.set_progress_bar_config(disable=True)
     evaluation_pipeline = EvaluationPipeline(model=diffusion_model)
 
-    judge = OpenAIJudge(prompt_property=PromptCategory.complex)
+    judge = OpenAIJudge(prompt_property=PromptCategory.action)
     metric = MultiModalLLMEvaluationMetric(judge=judge)
     evaluation_pipeline.add_metric(metric)
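For reference, a minimal sketch of how the updated entry point might be called after this change, with image_height and image_width passed separately instead of the old image_size tuple. The module path, the project/entity/dataset_ref parameters, and all values are assumptions inferred from the visible hunk, not part of the commit; since the script imports fire, the same arguments are presumably also available as command-line flags.

# Hypothetical usage sketch -- the import path and argument values are assumptions;
# only image_height/image_width reflect the parameters introduced in this diff.
from examples.multimodal_llm_eval.evaluate_mllm_metric_complex import main

main(
    project="hemm-eval",                 # assumed W&B project name
    entity="my-team",                    # assumed W&B entity
    dataset_ref="my-eval-dataset:v0",    # assumed Weave dataset reference
    dataset_limit=10,
    image_height=1024,
    image_width=1024,
)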
2 changes: 2 additions & 0 deletions hemm/metrics/vqa/judges/mmllm_judges/openai_judge.py
@@ -1,4 +1,5 @@
 import os
+import subprocess
 from typing import List
 
 import instructor
@@ -65,6 +66,7 @@ def __init__(
             max_retries=max_retries,
             seed=seed,
         )
+        subprocess.run(["spacy", "download", "en_core_web_sm"])
         self._nlp_pipeline = spacy.load(self.prompt_pipeline)
         self._openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
         self._instructor_openai_client = instructor.from_openai(
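The new subprocess.run call fetches the en_core_web_sm model on every OpenAIJudge construction. A minimal sketch of an alternative, assuming one wanted to skip the download when the model is already installed (ensure_spacy_model is a hypothetical helper, not part of this commit):

# Hypothetical helper -- not part of the commit; downloads the spaCy model only
# when loading it fails because it is not installed yet.
import subprocess

import spacy


def ensure_spacy_model(model_name: str = "en_core_web_sm"):
    try:
        return spacy.load(model_name)
    except OSError:
        # Model missing: download it once via the spaCy CLI, then load again.
        subprocess.run(["python", "-m", "spacy", "download", model_name], check=True)
        return spacy.load(model_name)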
66 changes: 64 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
@@ -32,6 +32,7 @@ spacy = "^3.7.5"
 instructor = "^1.3.4"
 torchmetrics = { extras = ["multimodal"], version = "^1.4.1" }
 mkdocstrings = {version = "^0.25.2", extras = ["python"]}
+sentencepiece = "^0.2.0"
 
 [tool.poetry.extras]
 core = [
@@ -48,6 +49,7 @@ core = [
     "spacy",
     "instructor",
     "torchmetrics",
+    "sentencepiece",
 ]
 docs = [
     "mkdocs",
