diff --git a/docs/metrics/spatial_relationship.md b/docs/metrics/spatial_relationship.md index 489cb8f..97b5ef8 100644 --- a/docs/metrics/spatial_relationship.md +++ b/docs/metrics/spatial_relationship.md @@ -41,7 +41,7 @@ This module aims to implement the Spatial relationship metric described in secti model_address=detr_model_address, revision=detr_revision ) - # Add PSNR Metric to the evaluation pipeline + # Add 2d spatial relationship Metric to the evaluation pipeline metric = SpatialRelationshipMetric2D(judge=judge, name="2d_spatial_relationship_score") evaluation_pipeline.add_metric(metric) diff --git a/hemm/metrics/spatial_relationship/spatial_relationship_2d.py b/hemm/metrics/spatial_relationship/spatial_relationship_2d.py index d39c116..de600bd 100644 --- a/hemm/metrics/spatial_relationship/spatial_relationship_2d.py +++ b/hemm/metrics/spatial_relationship/spatial_relationship_2d.py @@ -14,6 +14,37 @@ class SpatialRelationshipMetric2D: """Spatial relationship metric for 2D images as proposed by Section 4.2 from the paper [T2I-CompBench: A Comprehensive Benchmark for Open-world Compositional Text-to-image Generation](https://arxiv.org/pdf/2307.06350). + ??? example "Sample usage" + ```python + import wandb + import weave + + from hemm.eval_pipelines import BaseDiffusionModel, EvaluationPipeline + from hemm.metrics.image_quality import LPIPSMetric, PSNRMetric, SSIMMetric + + # Initialize Weave and WandB + wandb.init(project="image-quality-leaderboard", job_type="evaluation") + weave.init(project_name="image-quality-leaderboard") + + # Initialize the diffusion model to be evaluated as a `weave.Model` using `BaseWeaveModel` + model = BaseDiffusionModel(diffusion_model_name_or_path="CompVis/stable-diffusion-v1-4") + + # Add the model to the evaluation pipeline + evaluation_pipeline = EvaluationPipeline(model=model) + + # Define the judge model for 2d spatial relationship metric + judge = DETRSpatialRelationShipJudge( + model_address=detr_model_address, revision=detr_revision + ) + + # Add 2d spatial relationship Metric to the evaluation pipeline + metric = SpatialRelationshipMetric2D(judge=judge, name="2d_spatial_relationship_score") + evaluation_pipeline.add_metric(metric) + + # Evaluate! + evaluation_pipeline(dataset="t2i_compbench_spatial_prompts:v0") + ``` + Args: judge (Union[weave.Model, DETRSpatialRelationShipJudge]): The judge model to predict the bounding boxes from the generated image.