Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/rating_and_ranking.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,9 @@ Converts a rank to a percentile based on the distribution of ranks.
*Parameters:*
- `rank (int)`: The rank to convert.
- `method (Literal["original", "hazen", "weibull"])`: The method to use for conversion. Defaults to `"weibull"`.
- `"original"`: $\text{percentile} = 100.0 \times \frac{\text{rank}}{\text{\#active\_users}}$, capped at 100% when $\text{rank} = \text{\#active\_users} + 1$
- `"hazen"`: $\text{percentile} = 100.0 \times \frac{(\text{rank} - 0.5)}{(\text{\#active\_users} + 1)}$
- `"weibull"`: $\text{percentile} = 100.0 \times \frac{\text{rank}}{(\text{\#active\_users} + 2)}$
- `"original"`: $\text{percentile} = 100.0 \times \frac{\text{rank}}{\text{num\_active\_users}}$, capped at 100% when $\text{rank} = \text{num\_active\_users} + 1$
- `"hazen"`: $\text{percentile} = 100.0 \times \frac{(\text{rank} - 0.5)}{(\text{num\_active\_users} + 1)}$
- `"weibull"`: $\text{percentile} = 100.0 \times \frac{\text{rank}}{(\text{num\_active\_users} + 2)}$

> *Note:* The `"weibull"` method is recommended because it avoids 0%/100% endpoints (exclusive percentiles) and is widely used in the literature. We selected `"weibull"` as default rather than `"hazen"` because it provides a slightly more aligned to the original percentile calculation when the rank is higher. The original paper uses the `"original"` method, but it does not align well with statistical properties. All methods are acceptable as long as the method is documented.

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dependencies = [
"pillow",
"polars>=1",
"pydantic>=2",
"typing_extensions",
]

[project.urls]
Expand Down
62 changes: 43 additions & 19 deletions src/ale_bench/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,52 @@
from typing import Sequence

from PIL import Image
from pydantic import Field, field_serializer, field_validator
from pydantic import Field
from pydantic.functional_serializers import PlainSerializer
from pydantic.functional_validators import BeforeValidator
from pydantic.json_schema import WithJsonSchema
from typing_extensions import Annotated

from ale_bench.data import Problem
from ale_bench.result import CaseResult, Result
from ale_bench.utils import base64_to_pil, pil_to_base64

SerializableImage = Annotated[
Image.Image,
# Deserialize from base64 string
BeforeValidator(lambda v: base64_to_pil(v) if isinstance(v, str) else v),
# Serialize to base64 string
PlainSerializer(lambda img: pil_to_base64(img), return_type=str), # NOTE: when_used="json" may be helpful
# JSON Schema representation
WithJsonSchema(
{
"type": "string",
"format": "byte",
"contentEncoding": "base64",
"contentMediaType": "image/*",
"description": "Base64-encoded image data (png/jpeg/webp)",
}
),
]


class ProblemSerializable(Problem):
"""Serializable version of Problem for JSON serialization.

This class extends Problem to include serialization and deserialization of the `statement_images` field.
This class is especially useful for hosting APIs that need to serialize images in a format suitable for JSON.
"""

statement_images: dict[str, SerializableImage | list[SerializableImage]] = Field(
description="Problem statement images in PIL Image format (key: image name, value: image object)",
default_factory=dict,
)

@classmethod
def from_problem(cls, problem: Problem) -> "ProblemSerializable":
"""Create a ProblemSerializable from an existing Problem."""
return cls.model_validate(problem.model_dump())


class CaseResultSerializable(CaseResult):
"""Serializable version of CaseResult for JSON serialization.
Expand All @@ -16,24 +57,7 @@ class CaseResultSerializable(CaseResult):
This class is especially useful for hosting APIs that need to serialize images in a format suitable for JSON.
"""

@field_serializer("local_visualization")
def serialize_local_visualization(self, value: Image.Image | None) -> str | None:
"""Serialize the local visualization image to a base64 string."""
if value is None:
return None
return pil_to_base64(value)

@field_validator("local_visualization", mode="before")
def deserialize_local_visualization(cls, value: Image.Image | str | None) -> Image.Image | None:
"""Deserialize the local visualization from a base64 string to an Image."""
if isinstance(value, str):
try:
return base64_to_pil(value)
except Exception as e:
raise ValueError(f"Invalid base64 image data: {e}")
elif isinstance(value, Image.Image):
return value
return None
local_visualization: SerializableImage | None = Field(default=None, description="The final state of the submission")

@classmethod
def from_case_result(cls, case_result: CaseResult) -> "CaseResultSerializable":
Expand Down
72 changes: 71 additions & 1 deletion tests/test_schemas.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,83 @@
from __future__ import annotations

import datetime
from typing import Any

import pytest
from PIL import Image

from ale_bench.data import ProblemConstraints, ProblemMetaData, ProblemType, ScoreType
from ale_bench.result import CaseResult, JudgeResult, ResourceUsage, Result
from ale_bench.schemas import CaseResultSerializable, ResultSerializable
from ale_bench.schemas import CaseResultSerializable, ProblemSerializable, ResultSerializable
from ale_bench.utils import pil_to_base64


@pytest.mark.parametrize(
"problem,serialized",
[
pytest.param(
ProblemSerializable(
metadata=ProblemMetaData(
problem_id="test",
start_at=datetime.datetime(2025, 1, 1, 0, 0, 0),
end_at=datetime.datetime(2025, 7, 7, 0, 0, 0),
contest_url="https://example.com/test",
title="Test Problem",
problem_type=ProblemType.BATCH,
score_type=ScoreType.MAXIMIZE,
),
constraints=ProblemConstraints(
time_limit=2.0,
memory_limit=1073741824,
),
statement="Test problem statement. Image:\nimage1\nVideo:\nvideo1",
statement_ja="テスト問題文。画像:\nimage1\n映像:\nvideo1",
statement_images={
"image1": Image.new("RGBA", (100, 100)),
"video1": [Image.new("RGBA", (100, 100), (64 * i,) * 4) for i in range(3)],
},
example_input="Test input",
example_output="Test output",
tool_readme="Test tool README",
),
{
"metadata": {
"problem_id": "test",
"start_at": "2025-01-01T00:00:00",
"end_at": "2025-07-07T00:00:00",
"contest_url": "https://example.com/test",
"title": "Test Problem",
"problem_type": "batch",
"score_type": "maximize",
},
"constraints": {
"time_limit": 2.0,
"memory_limit": 1073741824,
},
"statement": "Test problem statement. Image:\nimage1\nVideo:\nvideo1",
"statement_ja": "テスト問題文。画像:\nimage1\n映像:\nvideo1",
"statement_images": {
"image1": pil_to_base64(Image.new("RGBA", (100, 100))),
"video1": [pil_to_base64(Image.new("RGBA", (100, 100), (64 * i,) * 4)) for i in range(3)],
},
"example_input": "Test input",
"example_output": "Test output",
"tool_readme": "Test tool README",
},
id="problem_serializable_with_image",
),
],
)
def test_problem_serializable(problem: ProblemSerializable, serialized: dict[str, Any]) -> None:
"""Test serialization and deserialization of ProblemSerializable."""
# Test serialization to dict
problem_serialized = problem.model_dump()
assert problem_serialized == serialized
# Test deserialization from dict
problem_restored = ProblemSerializable.model_validate(serialized)
assert problem_restored == problem


@pytest.mark.parametrize(
"case_result,serialized",
[
Expand Down
Loading