Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions packages/prime/src/prime_cli/commands/rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def generate_rl_config_template(environment: str | None = None) -> str:

return f'''\
model = "Qwen/Qwen3.5-0.8B"
loss = "rl" # "rl" | "sft"; OPD is not yet supported on hosted runtimes
loss = "rl" # "rl" | "sft" | "opd"
max_steps = 100

# env_files = ["secrets.env"] # optional file(s) for secrets
Expand All @@ -229,8 +229,8 @@ def generate_rl_config_template(environment: str | None = None) -> str:
# Optional: warm-start from an existing checkpoint
# checkpoint_id = "..."

# Optional: SFT distillation teacher
# To use SFT, change the top-level loss to "sft" and uncomment this block.
# Optional: distillation teacher
# To use SFT or OPD, change the top-level loss and uncomment this block.
# [teacher]
# model = "openai/gpt-oss-120b"
#
Expand Down Expand Up @@ -664,13 +664,8 @@ def validate_config_consistency(self) -> "RLConfig":
raise ValueError("max_inflight_rollouts must be at least rollouts_per_example")
if self.loss == "rl" and self.teacher is not None:
raise ValueError("teacher can only be set when loss is 'sft' or 'opd'")
if self.loss == "sft" and self.teacher is None:
raise ValueError("teacher is required when loss is 'sft'")
if self.loss == "opd":
raise ValueError(
"loss='opd' is not supported for hosted runs yet; OPD requires "
"teacher logprob scoring support in the hosted runtime"
)
if self.loss in {"sft", "opd"} and self.teacher is None:
raise ValueError(f"teacher is required when loss is '{self.loss}'")
Comment on lines +667 to +668

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: Keep rejecting OPD until hosted scoring exists

In the hosted CLI path I inspected, train() still sends this config through RLClient.create_run() to /rft/runs, and this commit does not add the teacher-logprob runtime/API support that the previous guard said OPD requires. With this validator now accepting loss = "opd" whenever a teacher is present, users can create hosted OPD runs that pass local validation but fail or behave incorrectly once scheduled; keep the local rejection until the hosted runtime path is actually wired.

Useful? React with 👍 / 👎.

return self


Expand Down
17 changes: 17 additions & 0 deletions packages/prime/tests/test_rl_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,23 @@ def test_create_run_sends_sft_loss_and_teacher_config() -> None:
}


def test_create_run_sends_opd_loss_and_teacher_config() -> None:
    """Check that an OPD run request carries its loss and teacher config.

    Calls ``create_run`` with ``loss="opd"`` and a teacher model, then
    inspects the first POST recorded by the fake API client.
    """
    fake_api = FakeAPIClient()
    rl_client = RLClient(fake_api)  # type: ignore[arg-type]

    rl_client.create_run(
        model_name="openai/gpt-oss-20b",
        environments=[{"id": "primeintellect/reverse-text"}],
        loss="opd",
        teacher={"model": {"name": "openai/gpt-oss-120b"}},
    )

    # Index 0 of a recorded post is compared against the endpoint path.
    assert fake_api.posts[0][0] == "/rft/runs"

    # Index 1 of the same recorded post is the request payload under test.
    request_body = fake_api.posts[0][1]
    assert request_body["loss"] == "opd"
    assert request_body["teacher"] == {"model": {"name": "openai/gpt-oss-120b"}}


def test_create_run_omits_default_rl_loss() -> None:
api_client = FakeAPIClient()
client = RLClient(api_client) # type: ignore[arg-type]
Expand Down
24 changes: 16 additions & 8 deletions packages/prime/tests/test_rl_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ def test_generate_rl_config_template_keeps_default_surface_minimal() -> None:
def test_generate_rl_config_template_sft_example_loads(tmp_path: Path) -> None:
template = generate_rl_config_template()
template = template.replace(
'loss = "rl" # "rl" | "sft"; OPD is not yet supported on hosted runtimes',
'loss = "sft" # "rl" | "sft"; OPD is not yet supported on hosted runtimes',
'loss = "rl" # "rl" | "sft" | "opd"',
'loss = "sft" # "rl" | "sft" | "opd"',
)

lines: list[str] = []
in_teacher_example = False
for line in template.splitlines():
if line == "# Optional: SFT distillation teacher":
if line == "# Optional: distillation teacher":
in_teacher_example = True
lines.append(line)
continue
Expand Down Expand Up @@ -226,15 +226,23 @@ def test_load_config_rejects_sft_without_teacher(tmp_path: Path) -> None:
load_config(str(config_path))


def test_load_config_rejects_opd_until_hosted_scoring_exists(tmp_path: Path) -> None:
def test_load_config_accepts_opd_teacher(tmp_path: Path) -> None:
config_path = tmp_path / "opd.toml"
config_path.write_text(
'model = "openai/gpt-oss-20b"\n'
'loss = "opd"\n'
"[teacher]\n"
'model = "openai/gpt-oss-120b"\n'
'model = "openai/gpt-oss-20b"\nloss = "opd"\n[teacher]\nmodel = "openai/gpt-oss-120b"\n'
)

cfg = load_config(str(config_path))

assert cfg.loss == "opd"
assert cfg.teacher is not None
assert cfg.teacher.model == "openai/gpt-oss-120b"


def test_load_config_rejects_opd_without_teacher(tmp_path: Path) -> None:
    """Check that an OPD config with no ``[teacher]`` table is rejected.

    Writes a minimal TOML file that sets ``loss = "opd"`` without a teacher
    and expects ``load_config`` to raise ``typer.Exit``.
    """
    opd_toml = tmp_path / "opd.toml"
    opd_toml.write_text('model = "openai/gpt-oss-20b"\nloss = "opd"\n')
    opd_toml_path = str(opd_toml)

    with pytest.raises(typer.Exit):
        load_config(opd_toml_path)

Expand Down
Loading