Skip to content

Commit

Permalink
update: docs + docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
soumik12345 committed Jun 21, 2024
1 parent 712f11b commit 6e46579
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/metrics/spatial_relationship.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ This module aims to implement the Spatial relationship metric described in secti
model_address=detr_model_address, revision=detr_revision
)

# Add PSNR Metric to the evaluation pipeline
# Add 2d spatial relationship Metric to the evaluation pipeline
metric = SpatialRelationshipMetric2D(judge=judge, name="2d_spatial_relationship_score")
evaluation_pipeline.add_metric(metric)

Expand Down
31 changes: 31 additions & 0 deletions hemm/metrics/spatial_relationship/spatial_relationship_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,37 @@ class SpatialRelationshipMetric2D:
"""Spatial relationship metric for 2D images as proposed by Section 4.2 from the paper
[T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation](https://arxiv.org/pdf/2307.06350).
??? example "Sample usage"
```python
import wandb
import weave
from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline
from hemm.metrics.image_quality import LPIPSMetric, PSNRMetric, SSIMMetric
# Initialize Weave and WandB
wandb.init(project="image-quality-leaderboard", job_type="evaluation")
weave.init(project_name="image-quality-leaderboard")
# Initialize the diffusion model to be evaluated as a `weave.Model` using `BaseWeaveModel`
model = BaseDiffusionModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4")
# Add the model to the evaluation pipeline
evaluation_pipeline = EvaluationPipeline(model=model)
# Define the judge model for 2d spatial relationship metric
judge = DETRSpatialRelationShipJudge(
model_address=detr_model_address, revision=detr_revision
)
# Add 2d spatial relationship Metric to the evaluation pipeline
metric = SpatialRelationshipMetric2D(judge=judge, name="2d_spatial_relationship_score")
evaluation_pipeline.add_metric(metric)
# Evaluate!
evaluation_pipeline(dataset="t2i_compbench_spatial_prompts:v0")
```
Args:
judge (Union[weave.Model, DETRSpatialRelationShipJudge]): The judge model to predict
the bounding boxes from the generated image.
Expand Down

0 comments on commit 6e46579

Please sign in to comment.