diff --git a/AGENTS.md b/AGENTS.md index a9decbc7e2..55ee7d4eb5 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -50,11 +50,10 @@ Write tests as plain functions with pytest fixtures. Don't use class-based tests ## Git -- **Branch prefixes**: use the following prefixes for branches: `feat/`, `fix/`, `chore/` +- **Branch prefixes**: use `feat/`, `fix/`, `chore/`; use `exp/` for experiment branches (configs, run summaries, pins, notes). ## GitHub - **Draft PRs**: always create PRs as drafts (`gh pr create --draft`) to avoid triggering CI unnecessarily. - **Pull requests**: do not include a "test plan" section in PR descriptions unless you actually ran tests to verify the changes or the user explicitly asked for one. - **Keep PR descriptions in sync**: every time you push commits to a PR, also update the PR description (`gh pr edit --body-file ...`) so it reflects the current state of the branch — not just what was true when the PR was opened. Preserve any auto-generated blocks (e.g. ``). - diff --git a/configs/general_agent/rl_qwen3_0p6b.toml b/configs/general_agent/rl_qwen3_0p6b.toml new file mode 100644 index 0000000000..7f149181ed --- /dev/null +++ b/configs/general_agent/rl_qwen3_0p6b.toml @@ -0,0 +1,29 @@ +max_steps = 5 +seq_len = 8192 + +[wandb] +project = "general-agent-debug" +name = "qwen3-0p6b-rlm" + +[model] +name = "Qwen/Qwen3-0.6B" + +[orchestrator] +batch_size = 16 +rollouts_per_example = 4 + +[orchestrator.train.sampling] +max_completion_tokens = 4096 + +[[orchestrator.train.env]] +id = "general-agent-solver-rlm" + +[trainer] + +[inference] + +[inference.model] +max_model_len = 8192 + +[inference.parallel] +dp = 1 diff --git a/configs/general_agent/rl_qwen3_30b_a3b.toml b/configs/general_agent/rl_qwen3_30b_a3b.toml new file mode 100644 index 0000000000..49f4b3312c --- /dev/null +++ b/configs/general_agent/rl_qwen3_30b_a3b.toml @@ -0,0 +1,52 @@ +max_steps = 400 +seq_len = 32768 + +[slurm] +job_name = "general-agent-qwen3-30b-a3b-rlm" + +[deployment] +type = "multi_node" +num_train_nodes = 1 +num_infer_nodes = 1 + +[wandb] +project = "general-agent-debug" +name = "qwen3-30b-a3b-rlm" + +[ckpt] +interval = 50 +keep_last = 1 + +[model] +name = "Qwen/Qwen3-30B-A3B-Instruct-2507" + +[trainer] + +[trainer.model] +cp = 2 + +[trainer.model.ac] +freq = 1 + +[trainer.model.compile] + +[orchestrator] +batch_size = 512 +rollouts_per_example = 16 +max_off_policy_steps = 32 + +[[orchestrator.train.env]] +id = "general-agent-solver-rlm" + +[orchestrator.train.env.args] +min_tier = 1 + +[inference] +gpu_memory_utilization = 0.85 + +[inference.model] +max_model_len = 32768 + +[inference.parallel] +dp = 2 +tp = 4 diff --git a/configs/general_agent/rl_qwen3_4b.toml b/configs/general_agent/rl_qwen3_4b.toml new file mode 100644 index 0000000000..7cf5c105e7 --- /dev/null +++ b/configs/general_agent/rl_qwen3_4b.toml @@ -0,0 +1,44 @@ +max_steps = 200 +seq_len = 32768 + +[deployment] +num_train_gpus = 4 +num_infer_gpus = 4 + +[wandb] +project = "general-agent-debug" +name = "qwen3-4b-rlm" + +[ckpt] +interval = 100 +keep_last = 1 + +[model] +name = "Qwen/Qwen3-4B-Instruct-2507" + +[trainer] + +[trainer.model] +cp = 2 + +[trainer.model.ac] +freq = 1 + +[trainer.model.compile] + +[orchestrator] +batch_size = 512 +rollouts_per_example = 8 +max_off_policy_steps = 32 + +[[orchestrator.train.env]] +id = "general-agent-solver-rlm" + +[inference] +gpu_memory_utilization = 0.85 + +[inference.model] +max_model_len = 32768 + +[inference.parallel] +dp = 4 diff --git a/configs/private b/configs/private index 894caa0047..70c3503e1d 160000 --- a/configs/private +++ b/configs/private @@ -1 +1 @@ -Subproject commit 894caa00471d8dd8e0f911d59b0cab388089312e +Subproject commit 70c3503e1dc4ea499b09f0eee206b509169b79bd diff --git a/deps/research-environments b/deps/research-environments index d141472268..c752781984 160000 --- a/deps/research-environments +++ b/deps/research-environments @@ -1 +1 @@ -Subproject commit d141472268551411b6a9924a66b4426db3ce197d +Subproject commit c752781984c1b4fbb0a3d7f4aac1e7ed67cc749e diff --git a/pyproject.toml b/pyproject.toml index 1553034690..fdf35fb747 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,6 +66,7 @@ envs = [ "code-env", "color-codeword", "deepdive", + "general-agent", "gpqa", "hle", "ifeval", @@ -118,6 +119,7 @@ dev = [ "ipywidgets>=8.1.7", "pre-commit>=4.2.0", "pytest>=8.4.1", + "pytest-asyncio>=0.23", "ruff>=0.12.1", ] @@ -137,6 +139,7 @@ members = [ "deps/research-environments/environments/code_env", "deps/research-environments/environments/color_codeword", "deps/research-environments/environments/deepdive", + "deps/research-environments/environments/general_agent", "deps/research-environments/environments/gpqa", "deps/research-environments/environments/hle", "deps/research-environments/environments/ifeval", @@ -203,6 +206,7 @@ alphabet-sort = { workspace = true } code-env = { workspace = true } color-codeword = { workspace = true } deepdive = { workspace = true } +general-agent = { workspace = true } gpqa = { workspace = true } hle = { workspace = true } ifeval = { workspace = true } diff --git a/tests/unit/test_configs.py b/tests/unit/test_configs.py index 77e6c84bf1..75e6fce0aa 100644 --- a/tests/unit/test_configs.py +++ b/tests/unit/test_configs.py @@ -1,3 +1,4 @@ +import tomllib from pathlib import Path from typing import Annotated, Literal @@ -33,9 +34,19 @@ def get_config_files() -> list[Path]: return config_files + example_files +def is_eval_config(path: Path) -> bool: + """vf-eval TOMLs live under configs but are not prime-rl entrypoint configs.""" + with path.open("rb") as f: + data = tomllib.load(f) + return isinstance(data.get("eval"), list) + + @pytest.mark.parametrize("config_file", get_config_files(), ids=lambda x: x.as_posix()) def test_load_configs(config_file: Path): """Tests that all config files can be loaded by at least one config class.""" + if is_eval_config(config_file): + pytest.skip("vf-eval TOML files are not prime-rl entrypoint configs") + could_parse = [] for config_cls in CONFIG_CLASSES: try: diff --git a/uv.lock b/uv.lock index a3d9158657..26b1eb3b29 100644 --- a/uv.lock +++ b/uv.lock @@ -41,6 +41,7 @@ members = [ "code-env", "color-codeword", "deepdive", + "general-agent", "gpqa", "hle", "ifeval", @@ -968,7 +969,7 @@ wheels = [ [[package]] name = "deepdive" -version = "0.2.7" +version = "0.2.9" source = { editable = "deps/research-environments/environments/deepdive" } dependencies = [ { name = "aiohttp", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, @@ -1522,6 +1523,38 @@ http = [ { name = "aiohttp", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] +[[package]] +name = "general-agent" +version = "0.1.4" +source = { editable = "deps/research-environments/environments/general_agent" } +dependencies = [ + { name = "mcp", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "tyro", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "verifiers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] + +[package.optional-dependencies] +dev = [ + { name = "ruff", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "ty", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] +test = [ + { name = "pytest", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pytest-asyncio", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, +] + +[package.metadata] +requires-dist = [ + { name = "mcp", specifier = ">=1.0" }, + { name = "pytest", marker = "extra == 'test'" }, + { name = "pytest-asyncio", marker = "extra == 'test'" }, + { name = "ruff", marker = "extra == 'dev'" }, + { name = "ty", marker = "extra == 'dev'" }, + { name = "tyro", specifier = ">=0.9" }, + { name = "verifiers", specifier = ">=0.1.15.dev2" }, +] +provides-extras = ["dev", "test"] + [[package]] name = "gepa" version = "0.1.1" @@ -3306,7 +3339,7 @@ requires-dist = [ [[package]] name = "opencode-deepdive" -version = "0.1.15" +version = "0.1.16" source = { editable = "deps/research-environments/environments/opencode_deepdive" } dependencies = [ { name = "datasets", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, @@ -3951,6 +3984,7 @@ envs = [ { name = "code-env", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "color-codeword", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "deepdive", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "general-agent", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "gpqa", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "hle", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "ifeval", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, @@ -3995,6 +4029,7 @@ dev = [ { name = "ipywidgets", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "pre-commit", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "pytest", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "pytest-asyncio", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "ruff", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] mamba-ssm = [ @@ -4019,6 +4054,7 @@ requires-dist = [ { name = "flash-attn-3", marker = "extra == 'flash-attn-3'", index = "https://download.pytorch.org/whl/test/cu128" }, { name = "flash-attn-4", marker = "extra == 'flash-attn-cute'", git = "https://github.com/Dao-AILab/flash-attention.git?subdirectory=flash_attn%2Fcute&rev=96bd151" }, { name = "flash-linear-attention", git = "https://github.com/fla-org/flash-linear-attention" }, + { name = "general-agent", marker = "extra == 'envs'", editable = "deps/research-environments/environments/general_agent" }, { name = "gpqa", marker = "extra == 'envs'", editable = "deps/research-environments/environments/gpqa" }, { name = "hle", marker = "extra == 'envs'", editable = "deps/research-environments/environments/hle" }, { name = "ifeval", marker = "extra == 'envs'", editable = "deps/research-environments/environments/ifeval" }, @@ -4088,6 +4124,7 @@ dev = [ { name = "ipywidgets", specifier = ">=8.1.7" }, { name = "pre-commit", specifier = ">=4.2.0" }, { name = "pytest", specifier = ">=8.4.1" }, + { name = "pytest-asyncio", specifier = ">=0.23" }, { name = "ruff", specifier = ">=0.12.1" }, ] mamba-ssm = [{ name = "mamba-ssm", specifier = ">=2.3.0" }]