From 48953164d0287e2fd9c87c57f8379613edfe2dae Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Wed, 13 May 2026 08:54:45 +0530 Subject: [PATCH 01/32] feat: wire r3 v3 routed experts stack --- .../qwen30b_kv_offload_routed_experts.toml | 32 ++ configs/r3_v3/qwen30b_wordle_r3.toml | 108 ++++ pyproject.toml | 12 +- uv.lock | 462 +++++++++++------- 4 files changed, 442 insertions(+), 172 deletions(-) create mode 100644 configs/r3_v3/qwen30b_kv_offload_routed_experts.toml create mode 100644 configs/r3_v3/qwen30b_wordle_r3.toml diff --git a/configs/r3_v3/qwen30b_kv_offload_routed_experts.toml b/configs/r3_v3/qwen30b_kv_offload_routed_experts.toml new file mode 100644 index 0000000000..347c73bfac --- /dev/null +++ b/configs/r3_v3/qwen30b_kv_offload_routed_experts.toml @@ -0,0 +1,32 @@ +output_dir = "outputs/qwen30b-kv-offload-r3-v3" +enable_return_routed_experts = true +enable_prefix_caching = true +gpu_memory_utilization = 0.9 + +[model] +name = "Qwen/Qwen3-30B-A3B-Thinking-2507" +max_model_len = 1024 +enforce_eager = true + +[parallel] +tp = 8 +dp = 1 + +[deployment] +type = "single_node" +gpus_per_node = 8 + +[slurm] +job_name = "qwen30b-kv-offload-r3-v3" +partition = "preempt" +time = "02:00:00" +pre_run_command = "uv sync --all-extras --reinstall-package vllm --reinstall-package nvidia-cudnn-cu12 --reinstall-package nvidia-nccl-cu12 --reinstall-package nvidia-cusparselt-cu12 --reinstall-package nvidia-nvshmem-cu12" + +[kv_cache_offload] +cpu_bytes = 17179869184 + +[vllm_extra] +async_scheduling = false +kv_cache_memory_bytes = 536870912 +max_num_batched_tokens = 1024 +max_num_seqs = 16 diff --git a/configs/r3_v3/qwen30b_wordle_r3.toml b/configs/r3_v3/qwen30b_wordle_r3.toml new file mode 100644 index 0000000000..244ae07445 --- /dev/null +++ b/configs/r3_v3/qwen30b_wordle_r3.toml @@ -0,0 +1,108 @@ +output_dir = "outputs/qwen30b-wordle-r3-v3" +clean_output_dir = true +max_steps = 20 +seq_len = 4096 + +[log] +level = "debug" + +[model] +name = "Qwen/Qwen3-30B-A3B-Thinking-2507" + +[deployment] +type = "multi_node" +num_train_nodes = 1 +num_infer_nodes = 1 + +[slurm] +job_name = "qwen30b-wordle-r3-v3" +partition = "preempt" +time = "06:00:00" +pre_run_command = "uv sync --all-extras --reinstall-package nvidia-cudnn-cu12 --reinstall-package nvidia-nccl-cu12 --reinstall-package nvidia-cusparselt-cu12 --reinstall-package nvidia-nvshmem-cu12" + +[wandb] +project = "qwen30b-wordle" +name = "qwen30b-wordle-r3-v3" +group = "qwen30b-wordle-r3-v3" + +[weight_broadcast] +type = "nccl" +timeout = 3600 + +[trainer] +enable_router_replay = true +max_concurrent_runs = 1 +dist_timeout_seconds = 3600 + +[trainer.model] +impl = "custom" +attn = "flash_attention_3" +ep = 8 +optimization_dtype = "float32" +reduce_dtype = "float32" + +[trainer.model.ac] +mode = "full" +freq = 1 +targets = ["norm"] + +[trainer.model.ac_offloading] +max_inflight_activations = 5 + +[trainer.optim] +type = "adamw" +lr = 1e-6 + +[inference] +enable_return_routed_experts = true + +[inference.model] +max_model_len = 4096 + +[inference.parallel] +tp = 8 +dp = 1 + +[inference.vllm_extra] +async_scheduling = false + +[orchestrator] +filters = [] +batch_size = 64 +max_inflight_rollouts = 64 +rollouts_per_example = 8 +max_off_policy_steps = 8 +use_token_client = true + +[[orchestrator.train.env]] +id = "primeintellect/wordle" +name = "wordle" +num_workers = 1 +max_retries = 0 +max_total_completion_tokens = -1 + +[orchestrator.train.env.extra_env_kwargs] +max_total_completion_tokens = -1 +max_seq_len = 4096 + +[orchestrator.train.sampling] +temperature = 1.0 +repetition_penalty = 1.0 +max_completion_tokens = 1024 +min_tokens = 0 + +[orchestrator.train.sampling.extra_body] +top_k = -1 +min_p = 0.0 +return_token_ids = true + +[orchestrator.client] +timeout = 1200 +wait_for_ready_timeout = 1800 + +[orchestrator.client.extra_headers_from_state] +X-Session-ID = "example_id" + +[orchestrator.buffer] +easy_threshold = 1.0 +hard_threshold = 0.0 diff --git a/pyproject.toml b/pyproject.toml index d9b5468fa0..9e1806908c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "torchaudio", "torchdata>=0.11.0", "transformers", - "vllm>=0.20.2", + "vllm", "wandb>=0.26.1", "ring-flash-attn>=0.1.8", "prime>=0.6.4", @@ -31,7 +31,8 @@ dependencies = [ "uvloop>=0.21.0", "torchtitan", "verifiers", - "renderers==0.1.6", + "renderers>=0.1.8.dev0", + "wordle", "dion", "tilelang>=0.1.8", "flash-linear-attention", @@ -166,15 +167,16 @@ prime-rl-configs = { workspace = true } torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } -verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers.git", rev = "aa428f3" } +verifiers = { path = "third_party/verifiers", editable = true } +wordle = { path = "third_party/verifiers/environments/wordle" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" } pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" } -vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" } +vllm-router = { path = "third_party/router/dist-r3-v3-router-cache/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl" } vllm = [ - { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_x86_64.whl", marker = "platform_machine == 'x86_64'" }, + { path = "third_party/vllm/dist-r3-v3-external/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl", marker = "platform_machine == 'x86_64'" }, { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl", marker = "platform_machine == 'aarch64'" }, ] reverse-text = { index = "primeintellect" } diff --git a/uv.lock b/uv.lock index e5c35957d3..4eac4b601d 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. +exclude-newer = "2026-05-06T12:10:58.853628905Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -25,12 +25,12 @@ color-codeword = false nixl-cu12 = false flash-attn-3 = false prime-tunnel = false -prime = false +prime-sandboxes = false deep-gemm = false aime2024 = false prime-evals = false deepdive = false -prime-sandboxes = false +prime = false reverse-text = false code-env = false mini-swe-agent-plus = false @@ -412,6 +412,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/1f/f041989e93b001bc4e44bb1669ccdcf54d3f00e628229a85b08d330615c5/charset_normalizer-3.4.3-py3-none-any.whl", hash = "sha256:ce571ab16d890d23b5c278547ba694193a45011ff86a9162a71307ed9f86759a", size = 53175, upload-time = "2025-08-09T07:57:26.864Z" }, ] +[[package]] +name = "chess" +version = "1.11.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/93/09/7d04d7581ae3bb8b598017941781bceb7959dd1b13e3ebf7b6a2cd843bc9/chess-1.11.2.tar.gz", hash = "sha256:a8b43e5678fdb3000695bdaa573117ad683761e5ca38e591c4826eba6d25bb39", size = 6131385, upload-time = "2025-02-25T19:10:27.328Z" } + [[package]] name = "chromadb" version = "1.5.4" @@ -1508,6 +1514,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" }, ] +[[package]] +name = "joblib" +version = "1.5.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/41/f2/d34e8b3a08a9cc79a50b2208a93dce981fe615b64d5a4d4abee421d898df/joblib-1.5.3.tar.gz", hash = "sha256:8561a3269e6801106863fd0d6d84bb737be9e7631e33aaed3fb9ce5953688da3", size = 331603, upload-time = "2025-12-15T08:41:46.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl", hash = "sha256:5fc3c5039fc5ca8c0276333a188bbd59d6b7ab37fe6632daa76bc7f9ec18e713", size = 309071, upload-time = "2025-12-15T08:41:44.973Z" }, +] + [[package]] name = "jsonschema" version = "4.25.1" @@ -2118,6 +2133,21 @@ requires-dist = [ { name = "torch" }, ] +[[package]] +name = "nltk" +version = "3.9.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "joblib", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "regex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "tqdm", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/74/a1/b3b4adf15585a5bc4c357adde150c01ebeeb642173ded4d871e89468767c/nltk-3.9.4.tar.gz", hash = "sha256:ed03bc098a40481310320808b2db712d95d13ca65b27372f8a403949c8b523d0", size = 2946864, upload-time = "2026-03-24T06:13:40.641Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl", hash = "sha256:f2fa301c3a12718ce4a0e9305c5675299da5ad9e26068218b69d692fda84828f", size = 1552087, upload-time = "2026-03-24T06:13:38.47Z" }, +] + [[package]] name = "nodeenv" version = "1.9.1" @@ -2782,9 +2812,10 @@ dependencies = [ { name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "uvloop", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "verifiers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "vllm", version = "0.20.1rc1.dev97+g6acff036e.d20260513.precompiled", source = { path = "third_party/vllm/dist-r3-v3-external/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "vllm", version = "0.20.2+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "vllm", version = "0.20.2+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_x86_64.whl" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "wandb", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "wordle", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] [package.optional-dependencies] @@ -2892,7 +2923,7 @@ requires-dist = [ { name = "pyarrow", specifier = ">=21.0.0" }, { name = "pyzmq", specifier = ">=27.1.0" }, { name = "quack-kernels", marker = "extra == 'quack'", specifier = ">=0.3.3" }, - { name = "renderers", specifier = "==0.1.6" }, + { name = "renderers", specifier = ">=0.1.8.dev0" }, { name = "reverse-text", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, { name = "rich", specifier = ">=14.0.0" }, { name = "ring-flash-attn", specifier = ">=0.1.8" }, @@ -2907,13 +2938,14 @@ requires-dist = [ { name = "torchvision", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, { name = "uvloop", specifier = ">=0.21.0" }, - { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers.git?rev=aa428f3" }, - { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'", specifier = ">=0.20.2" }, + { name = "verifiers", editable = "third_party/verifiers" }, + { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, - { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_x86_64.whl" }, - { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" }, + { name = "vllm", marker = "platform_machine == 'x86_64'", path = "third_party/vllm/dist-r3-v3-external/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl" }, + { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", path = "third_party/router/dist-r3-v3-router-cache/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl" }, { name = "wandb", specifier = ">=0.26.1" }, { name = "wiki-search", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, + { name = "wordle", directory = "third_party/verifiers/environments/wordle" }, ] provides-extras = ["flash-attn", "flash-attn-3", "flash-attn-cute", "envs", "disagg", "gpt-oss", "quack", "all"] @@ -3378,7 +3410,7 @@ wheels = [ [[package]] name = "renderers" -version = "0.1.6" +version = "0.1.8.dev0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -3388,9 +3420,9 @@ dependencies = [ { name = "tiktoken", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7c/a7/26162494dab2d7740ff02191cb87c30b68450fb154363c7f0a434e7f3ea9/renderers-0.1.6.tar.gz", hash = "sha256:b74bc3dc870bea3c37ff5b47826ace9b8dd608a4c1f56554c39be1b20b2c63dc", size = 163768, upload-time = "2026-05-07T14:12:36.634Z" } +sdist = { url = "https://files.pythonhosted.org/packages/50/de/a445036157af3367c6a962c13333427c83c08926934c541886eb87f9dcdf/renderers-0.1.8.dev0.tar.gz", hash = "sha256:71eef7bfa3d3f5849ba070d38cd89a1f6387ca7710824f2e50d8c05c9b1048b9", size = 210667, upload-time = "2026-05-12T17:48:45.352Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5e/ad/2cf218b9fafe2333fb3e80e123e3e2022d4923d9a61fa73ee6d79f39b563/renderers-0.1.6-py3-none-any.whl", hash = "sha256:90c626713239ec108716b7c9d194ba81ffcebe94dc003324f14fbd70e6793e89", size = 83348, upload-time = "2026-05-07T14:12:35.218Z" }, + { url = "https://files.pythonhosted.org/packages/e7/33/936a38c7f20fbe096b751842ffc6ef254c9eb2223153aa860a122ce9a834/renderers-0.1.8.dev0-py3-none-any.whl", hash = "sha256:09bb35233f67599519c0ff6edfad469f0836a55a6b78e039cd8e7b5e527bdcb3", size = 98617, upload-time = "2026-05-12T17:48:44.222Z" }, ] [[package]] @@ -3762,6 +3794,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530", size = 6590363, upload-time = "2023-10-23T21:23:35.583Z" }, ] +[[package]] +name = "textarena" +version = "0.7.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "chess", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "nltk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "openai", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "python-dotenv", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "rich", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "websockets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ba/04/4a3ca42093d0be2a9c377ae3335a6c6baac1d278ae932562ec69f339d172/textarena-0.7.4.tar.gz", hash = "sha256:28bb9170d7718f2ae05e4515bea82262422731e563fc7318a9e7983de0cadd4f", size = 954969, upload-time = "2025-10-16T14:41:55.981Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/26/b4/9a9ba65154aff853c75b3d7324319d168ad9c69c6097f4aa3c16da7d9ef3/textarena-0.7.4-py3-none-any.whl", hash = "sha256:684784e78278e518066f67557ee93b47c238d16cbbd15d3abdaa3147562d3024", size = 1073570, upload-time = "2025-10-16T14:41:53.965Z" }, +] + [[package]] name = "textual" version = "8.2.5" @@ -4197,8 +4247,7 @@ wheels = [ [[package]] name = "verifiers" -version = "0.1.14" -source = { git = "https://github.com/PrimeIntellect-ai/verifiers.git?rev=aa428f3#aa428f3941ae35a7cf7c0dad7e60c7eca525bac6" } +source = { editable = "third_party/verifiers" } dependencies = [ { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4223,7 +4272,75 @@ dependencies = [ { name = "setproctitle", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "tenacity", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "textual", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "wget", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + +[package.metadata] +requires-dist = [ + { name = "accelerate", marker = "extra == 'rl'", specifier = ">=1.4.0" }, + { name = "aiohttp", marker = "extra == 'browser'", specifier = ">=3.9.0" }, + { name = "aiolimiter", specifier = ">=1.2.1" }, + { name = "anthropic", specifier = ">=0.78.0" }, + { name = "datasets", specifier = ">=3.0.0,<4.7.0" }, + { name = "deepspeed", marker = "extra == 'rl'", specifier = ">=0.17.6" }, + { name = "flash-attn", marker = "extra == 'rl'", specifier = ">=2.8.3" }, + { name = "gepa" }, + { name = "httpx", specifier = ">=0.27.0" }, + { name = "jinja2", specifier = ">=3.1.6" }, + { name = "liger-kernel", marker = "extra == 'rl'", specifier = ">=0.5.10" }, + { name = "math-verify", specifier = ">=0.8.0" }, + { name = "mcp", specifier = ">=1.14.1" }, + { name = "msgpack", specifier = ">=1.1.2" }, + { name = "nest-asyncio", specifier = ">=1.6.0" }, + { name = "nltk", marker = "extra == 'ta'" }, + { name = "numpy" }, + { name = "openai", specifier = ">=1.108.1" }, + { name = "openai-agents", specifier = ">=0.0.7" }, + { name = "openenv-core", extras = ["core"], marker = "extra == 'openenv'", specifier = "==0.2.1" }, + { name = "peft", marker = "extra == 'rl'" }, + { name = "prime-sandboxes", specifier = ">=0.2.25" }, + { name = "prime-tunnel", specifier = ">=0.1.6" }, + { name = "pydantic", specifier = ">=2.11.9" }, + { name = "python-dotenv", marker = "extra == 'browser'", specifier = ">=1.0.0" }, + { name = "pyzmq", specifier = ">=27.1.0" }, + { name = "reasoning-gym", marker = "extra == 'rg'" }, + { name = "regex", specifier = "<2026.4.4" }, + { name = "renderers", marker = "extra == 'renderers'", specifier = ">=0.1.8.dev0" }, + { name = "requests" }, + { name = "requests", marker = "extra == 'rl'" }, + { name = "rich" }, + { name = "setproctitle", specifier = ">=1.3.0" }, + { name = "stagehand", marker = "extra == 'browser'", specifier = ">=3.0.0" }, + { name = "tenacity", specifier = ">=8.5.0" }, + { name = "textarena", marker = "extra == 'ta'" }, + { name = "textual" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "torch", marker = "extra == 'rl'", specifier = ">=2.8.0,<2.9.0" }, + { name = "transformers", marker = "extra == 'rl'", specifier = ">=4.56.2" }, + { name = "typing-extensions", marker = "python_full_version < '3.12'" }, + { name = "vllm", marker = "extra == 'rl'", specifier = ">=0.10.0,<0.11.0" }, + { name = "wandb", marker = "extra == 'rl'" }, +] +provides-extras = ["browser", "openenv", "renderers", "rg", "rl", "ta"] + +[package.metadata.requires-dev] +dev = [ + { name = "aiohttp", specifier = ">=3.9.0" }, + { name = "ipykernel" }, + { name = "ipywidgets" }, + { name = "nltk" }, + { name = "openenv-core", extras = ["core"], specifier = "==0.2.1" }, + { name = "pre-commit" }, + { name = "pytest", specifier = ">=7.0.0" }, + { name = "pytest-asyncio", specifier = ">=0.21.0" }, + { name = "pytest-cov", specifier = ">=4.0.0" }, + { name = "pytest-xdist", specifier = ">=3.8.0" }, + { name = "python-dotenv", specifier = ">=1.0.0" }, + { name = "reasoning-gym" }, + { name = "renderers", specifier = ">=0.1.8.dev0" }, + { name = "ruff" }, + { name = "stagehand", specifier = ">=3.0.0" }, + { name = "textarena" }, + { name = "ty", specifier = ">=0.0.1a29,<0.0.22" }, ] [[package]] @@ -4242,83 +4359,83 @@ wheels = [ [[package]] name = "vllm" -version = "0.20.2+cu129" -source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" } +version = "0.20.1rc1.dev97+g6acff036e.d20260513.precompiled" +source = { path = "third_party/vllm/dist-r3-v3-external/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl" } resolution-markers = [ - "platform_machine == 'aarch64' and sys_platform == 'linux'", + "platform_machine == 'x86_64' and sys_platform == 'linux'", ] dependencies = [ - { name = "aiohttp", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "anthropic", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "apache-tvm-ffi", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "blake3", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "cachetools", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "cbor2", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "cloudpickle", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "compressed-tensors", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "depyf", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "diskcache", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "einops", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "fastapi", extra = ["standard"], marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "fastsafetensors", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "filelock", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "flashinfer-cubin", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "flashinfer-python", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "gguf", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "ijson", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "lark", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "llguidance", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "lm-format-enforcer", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "mcp", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "mistral-common", extra = ["image"], marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "model-hosting-container-standards", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "msgspec", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "ninja", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "numba", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "numpy", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-frontend", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "nvidia-cutlass-dsl", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "openai", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "openai-harmony", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "opencv-python-headless", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "opentelemetry-api", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "opentelemetry-exporter-otlp", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "opentelemetry-sdk", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "opentelemetry-semantic-conventions-ai", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "outlines-core", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "partial-json-parser", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "pillow", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "prometheus-client", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "prometheus-fastapi-instrumentator", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "protobuf", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "psutil", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "py-cpuinfo", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "pybase64", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "pydantic", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "python-json-logger", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "pyyaml", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "pyzmq", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "quack-kernels", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "regex", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "requests", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "sentencepiece", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "setproctitle", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "setuptools", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "six", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "tiktoken", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "tilelang", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "tokenizers", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "torch", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "torchaudio", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "torchvision", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "tqdm", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "transformers", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "watchfiles", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "xgrammar", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "aiohttp", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "anthropic", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "apache-tvm-ffi", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "blake3", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "cachetools", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "cbor2", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "cloudpickle", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "compressed-tensors", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "depyf", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "diskcache", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "einops", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "fastapi", extra = ["standard"], marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "fastsafetensors", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "filelock", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "flashinfer-cubin", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "flashinfer-python", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "gguf", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "ijson", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "lark", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "llguidance", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "lm-format-enforcer", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "mcp", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "mistral-common", extra = ["image"], marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "model-hosting-container-standards", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "msgspec", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "ninja", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "numba", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "numpy", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-frontend", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cutlass-dsl", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "openai", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "openai-harmony", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "opencv-python-headless", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "opentelemetry-api", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "opentelemetry-exporter-otlp", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "opentelemetry-sdk", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "opentelemetry-semantic-conventions-ai", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "outlines-core", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "partial-json-parser", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "pillow", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "prometheus-client", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "prometheus-fastapi-instrumentator", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "protobuf", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "psutil", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "py-cpuinfo", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "pybase64", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "pydantic", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "python-json-logger", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "pyyaml", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "pyzmq", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "quack-kernels", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "regex", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "requests", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "sentencepiece", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "setproctitle", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "setuptools", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "six", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "tiktoken", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "tilelang", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "tokenizers", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "torchaudio", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "torchvision", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "tqdm", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "transformers", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "watchfiles", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "xgrammar", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ - { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:8a58a086c5c4ed2883eee36aaaf6b79c83463d02da3015454acf92afcc8e150e" }, + { filename = "vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl", hash = "sha256:8db1e80a5f4dd97237d7c5702b33f37a65910db9976b42db4f58937ddd0ffd48" }, ] [package.metadata] @@ -4418,82 +4535,82 @@ provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttens [[package]] name = "vllm" version = "0.20.2+cu129" -source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_x86_64.whl" } +source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" } resolution-markers = [ - "platform_machine == 'x86_64' and sys_platform == 'linux'", + "platform_machine == 'aarch64' and sys_platform == 'linux'", ] dependencies = [ - { name = "aiohttp", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "anthropic", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "apache-tvm-ffi", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "blake3", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "cachetools", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "cbor2", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "cloudpickle", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "compressed-tensors", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "depyf", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "diskcache", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "einops", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "fastapi", extra = ["standard"], marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "fastsafetensors", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "filelock", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "flashinfer-cubin", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "flashinfer-python", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "gguf", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "ijson", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "lark", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "llguidance", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "lm-format-enforcer", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "mcp", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "mistral-common", extra = ["image"], marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "model-hosting-container-standards", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "msgspec", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "ninja", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "numba", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "numpy", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cudnn-frontend", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "nvidia-cutlass-dsl", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "openai", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "openai-harmony", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "opencv-python-headless", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "opentelemetry-api", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "opentelemetry-exporter-otlp", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "opentelemetry-sdk", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "opentelemetry-semantic-conventions-ai", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "outlines-core", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "partial-json-parser", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "pillow", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "prometheus-client", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "prometheus-fastapi-instrumentator", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "protobuf", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "psutil", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "py-cpuinfo", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "pybase64", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "pydantic", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "python-json-logger", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "pyyaml", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "pyzmq", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "quack-kernels", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "regex", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "requests", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "sentencepiece", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "setproctitle", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "setuptools", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "six", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "tiktoken", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "tilelang", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "tokenizers", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "torchaudio", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "torchvision", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "tqdm", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "transformers", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "watchfiles", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "xgrammar", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "aiohttp", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "anthropic", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "apache-tvm-ffi", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "blake3", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "cachetools", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "cbor2", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "cloudpickle", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "compressed-tensors", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "depyf", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "diskcache", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "einops", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "fastapi", extra = ["standard"], marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "fastsafetensors", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "filelock", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "flashinfer-cubin", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "flashinfer-python", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "gguf", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "ijson", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "lark", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "llguidance", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "lm-format-enforcer", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "mcp", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "mistral-common", extra = ["image"], marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "model-hosting-container-standards", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "msgspec", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "ninja", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "numba", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "numpy", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-frontend", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "nvidia-cutlass-dsl", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "openai", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "openai-harmony", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "opencv-python-headless", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "opentelemetry-api", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "opentelemetry-exporter-otlp", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "opentelemetry-sdk", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "opentelemetry-semantic-conventions-ai", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "outlines-core", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "partial-json-parser", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "pillow", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "prometheus-client", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "prometheus-fastapi-instrumentator", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "protobuf", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "psutil", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "py-cpuinfo", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "pybase64", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "pydantic", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "python-json-logger", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "pyyaml", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "pyzmq", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "quack-kernels", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "regex", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "requests", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "sentencepiece", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "setproctitle", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "setuptools", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "six", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "tiktoken", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "tilelang", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "tokenizers", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "torch", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "torchaudio", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "torchvision", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "tqdm", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "transformers", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "watchfiles", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "xgrammar", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, ] wheels = [ - { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_x86_64.whl", hash = "sha256:2f8c2bf2ac6d3d16f930535e66822abd71065468521884eb5b910225b2abef4b" }, + { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:8a58a086c5c4ed2883eee36aaaf6b79c83463d02da3015454acf92afcc8e150e" }, ] [package.metadata] @@ -4592,8 +4709,8 @@ provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttens [[package]] name = "vllm-router" -version = "0.1.22" -source = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" } +version = "0.1.23" +source = { path = "third_party/router/dist-r3-v3-router-cache/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl" } dependencies = [ { name = "aiohttp", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "fastapi", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -4603,7 +4720,7 @@ dependencies = [ { name = "uvicorn", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ - { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:6361a0387241e56932f3ba2e51af27f58d11a462e3187e58286b2f96056e4d15" }, + { filename = "vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl", hash = "sha256:545aee83e3b10901a50d8426b32105bf3eeb9fa8af8c71cdc7f4cfc4303da42f" }, ] [package.metadata] @@ -4711,12 +4828,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/24/ab44c871b0f07f491e5d2ad12c9bd7358e527510618cb1b803a88e986db1/werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e", size = 224498, upload-time = "2024-11-08T15:52:16.132Z" }, ] -[[package]] -name = "wget" -version = "3.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/47/6a/62e288da7bcda82b935ff0c6cfe542970f04e29c756b0e147251b2fb251f/wget-3.2.zip", hash = "sha256:35e630eca2aa50ce998b9b1a127bb26b30dfee573702782aa982f875e3f16061", size = 10857, upload-time = "2015-10-22T15:26:37.51Z" } - [[package]] name = "widgetsnbextension" version = "4.0.14" @@ -4740,6 +4851,23 @@ wheels = [ { url = "https://hub.primeintellect.ai/primeintellect/wiki-search/@10d58ffe/wiki_search-0.1.23-py3-none-any.whl", hash = "sha256:ffeff890f2d14d7b2910baf57c27f6939da0f669ae0c4545916762f3f4edd75b" }, ] +[[package]] +name = "wordle" +version = "0.1.7" +source = { directory = "third_party/verifiers/environments/wordle" } +dependencies = [ + { name = "nltk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "textarena", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "verifiers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + +[package.metadata] +requires-dist = [ + { name = "nltk", specifier = ">=3.9.2" }, + { name = "textarena", specifier = "==0.7.4" }, + { name = "verifiers", specifier = ">=0.1.9.post3" }, +] + [[package]] name = "xgrammar" version = "0.1.33" From 721a87434b6e7db04a93733f4eac5f61f008edae Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Wed, 13 May 2026 17:27:35 +0530 Subject: [PATCH 02/32] feat: reset routing caches on policy update --- .../src/prime_rl/configs/orchestrator.py | 10 +++ skills/config/SKILL.md | 4 + src/prime_rl/entrypoints/rl.py | 11 +++ src/prime_rl/inference/vllm/server.py | 8 +- src/prime_rl/orchestrator/orchestrator.py | 10 ++- src/prime_rl/orchestrator/scheduler.py | 11 ++- src/prime_rl/utils/client.py | 78 ++++++++++++++++--- src/prime_rl/utils/elastic.py | 8 +- 8 files changed, 123 insertions(+), 17 deletions(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 5d04d3369f..42111daf2e 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -1090,6 +1090,16 @@ class OrchestratorConfig(BaseConfig): ), ] = None + reset_prefix_cache_on_policy_update: Annotated[ + bool, + Field( + description=( + "Reset vLLM prefix caches when updating inference weights. This prevents stale KV cache reuse " + "across policy steps, at the cost of losing cross-policy prefix-cache hits." + ), + ), + ] = False + max_async_level: Annotated[ int, Field( diff --git a/skills/config/SKILL.md b/skills/config/SKILL.md index e8dc13216c..a2ecd68dd9 100644 --- a/skills/config/SKILL.md +++ b/skills/config/SKILL.md @@ -157,6 +157,10 @@ If you wish to configure values of the default variant, you don't need to set th For hosted multi-tenant runs where the trainer image's `trainer.loss.type` is fixed, the orchestrator exposes a per-run override that forces SFT loss on every micro-batch without rebuilding the trainer. Set `orchestrator.use_sft_loss = true` alongside `orchestrator.teacher_rollout_model`; both must be configured together (the orchestrator validator enforces this). The orchestrator stamps each `TrainingSample.sft_loss = True`, which the trainer's `compute_loss` honors by dispatching to `sft_loss_fn` per batch — independent of the trainer's configured default loss. +### Router replay with KV offload + +When `trainer.enable_router_replay = true` and inference CPU KV offload is configured, RL config auto-sets `orchestrator.reset_prefix_cache_on_policy_update = true`. This makes policy updates pause vLLM in `clear` mode instead of `keep` mode, so old-policy prefix-cache entries are not reused after new weights are loaded. If the rollout client points at a router, the orchestrator also calls the router's `clear_routing_cache` route after backend weight updates resume. + ### Model fields For `BaseModel | None` fields (like `[ckpt]`, `[wandb]`, `[compile]`), a bare flag enables them with defaults: diff --git a/src/prime_rl/entrypoints/rl.py b/src/prime_rl/entrypoints/rl.py index b740f58ba2..fcf98f02d9 100644 --- a/src/prime_rl/entrypoints/rl.py +++ b/src/prime_rl/entrypoints/rl.py @@ -543,7 +543,18 @@ def rl_slurm(config: RLConfig): logger.success(f"{result.stdout.strip()}\n\n{log_message}") +def finalize_policy_update_cache_reset(config: RLConfig) -> None: + if ( + config.trainer.enable_router_replay + and config.inference is not None + and config.inference.kv_cache_offload is not None + ): + config.orchestrator.reset_prefix_cache_on_policy_update = True + + def rl(config: RLConfig): + finalize_policy_update_cache_reset(config) + resuming = config.ckpt is not None and config.ckpt.resume_step is not None clean = config.clean_output_dir and not os.environ.get("NEVER_CLEAN_OUTPUT_DIR") ckpt_output_dir = config.ckpt.output_dir if config.ckpt else None diff --git a/src/prime_rl/inference/vllm/server.py b/src/prime_rl/inference/vllm/server.py index 53ae22c104..d38ae1b0de 100644 --- a/src/prime_rl/inference/vllm/server.py +++ b/src/prime_rl/inference/vllm/server.py @@ -1,7 +1,7 @@ import asyncio from argparse import Namespace from http import HTTPStatus -from typing import Any +from typing import Any, Literal import uvloop from fastapi import APIRouter, Depends, Request @@ -210,9 +210,9 @@ async def _chat_with_tokens(request: ChatCompletionRequestWithTokens, raw_reques @router.post("/pause") -async def pause(request: Request): - await engine_client(request).pause_generation(mode="keep", clear_cache=False) - return {"status": "paused"} +async def pause(request: Request, mode: Literal["keep", "clear"] = "keep"): + await engine_client(request).pause_generation(mode="keep", clear_cache=mode == "clear") + return {"status": "paused", "mode": mode} @router.post("/resume") diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index bc1128ebc7..e4feb8e09c 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -55,6 +55,7 @@ ) from prime_rl.trainer.model import setup_tokenizer from prime_rl.utils.client import ( + clear_routing_cache, init_nccl_broadcast, setup_inference_pool, ) @@ -316,7 +317,14 @@ async def orchestrate(config: OrchestratorConfig): config.output_dir, scheduler.ckpt_step, check_exists=check_exists, wait_timeout=wait_timeout ) lora_name = config.model.lora.name if config.model.lora else None - await inference_pool.update_weights(weights_path, lora_name=lora_name, step=scheduler.ckpt_step) + await inference_pool.update_weights( + weights_path, + lora_name=lora_name, + step=scheduler.ckpt_step, + reset_prefix_cache=config.reset_prefix_cache_on_policy_update, + ) + if config.reset_prefix_cache_on_policy_update: + await clear_routing_cache(config.client) else: logger.info("Training from scratch") diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py index c266757c1c..ab92c91a24 100644 --- a/src/prime_rl/orchestrator/scheduler.py +++ b/src/prime_rl/orchestrator/scheduler.py @@ -13,7 +13,7 @@ from prime_rl.orchestrator.envs import TrainEnvs from prime_rl.orchestrator.vf_utils import get_seq_len from prime_rl.utils.async_utils import safe_cancel, safe_cancel_all -from prime_rl.utils.client import InferencePool +from prime_rl.utils.client import InferencePool, clear_routing_cache from prime_rl.utils.logger import ProgressTracker, get_logger from prime_rl.utils.utils import ( get_broadcast_dir, @@ -320,7 +320,14 @@ async def _apply_policy_update(self, next_ckpt_step: int) -> None: update_weights_start_time = time.perf_counter() weights_path = get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) - await self.inference_pool.update_weights(weights_path, lora_name=self.lora_name, step=next_ckpt_step) + await self.inference_pool.update_weights( + weights_path, + lora_name=self.lora_name, + step=next_ckpt_step, + reset_prefix_cache=self.config.reset_prefix_cache_on_policy_update, + ) + if self.config.reset_prefix_cache_on_policy_update: + await clear_routing_cache(self.config.client) self.update_weights_time = time.perf_counter() - update_weights_start_time self.logger.debug(f"Updated weights to step {next_ckpt_step} in {self.update_weights_time:.2f}s") diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index 21659dfc46..876503a494 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -42,7 +42,13 @@ async def wait_for_ready(self, model_name: str, timeout: int | None = None) -> N """Wait for inference pool to be ready.""" ... - async def update_weights(self, weight_dir: Path | None, lora_name: str | None = None, step: int = 0) -> None: + async def update_weights( + self, + weight_dir: Path | None, + lora_name: str | None = None, + step: int = 0, + reset_prefix_cache: bool = False, + ) -> None: """Update weights on all inference servers.""" ... @@ -110,8 +116,20 @@ async def wait_for_ready(self, model_name: str, timeout: int | None = None) -> N ) await maybe_check_has_model(self._admin_clients, model_name, skip_model_check=self._skip_model_check) - async def update_weights(self, weight_dir: Path | None, lora_name: str | None = None, step: int = 0) -> None: - await update_weights(self._admin_clients, weight_dir, lora_name=lora_name, step=step) + async def update_weights( + self, + weight_dir: Path | None, + lora_name: str | None = None, + step: int = 0, + reset_prefix_cache: bool = False, + ) -> None: + await update_weights( + self._admin_clients, + weight_dir, + lora_name=lora_name, + step=step, + reset_prefix_cache=reset_prefix_cache, + ) def get_metrics(self) -> dict[str, float]: return {} @@ -287,13 +305,14 @@ async def _check_health(admin_client: AsyncClient) -> None: NCCL_READY_MARKER = "NCCL_READY" -async def _pause_engines(admin_clients: list[AsyncClient]) -> None: +async def _pause_engines(admin_clients: list[AsyncClient], reset_prefix_cache: bool = False) -> None: """Pause all inference engines, waiting for in-flight requests to drain.""" logger = get_logger() - logger.info("Pausing inference engines for weight update") + mode = "clear" if reset_prefix_cache else "keep" + logger.info(f"Pausing inference engines for weight update (mode={mode})") async def _pause(client: AsyncClient) -> None: - response = await client.post("/pause", params={"mode": "keep", "clear_cache": "false"}) + response = await client.post("/pause", params={"mode": mode}) response.raise_for_status() await asyncio.gather(*[_pause(client) for client in admin_clients]) @@ -312,11 +331,52 @@ async def _resume(client: AsyncClient) -> None: logger.info("All inference engines resumed") +async def clear_routing_cache(client_config: ClientConfig) -> None: + """Clear router-local routed-experts cache when a policy update resets prefix cache.""" + logger = get_logger() + if client_config.router_url is not None: + router_urls = [client_config.router_url] + elif client_config.admin_base_url is not None: + router_urls = client_config.base_url + else: + router_urls = [] + + def _setup_router_client(base_url: str) -> AsyncClient: + headers = client_config.headers.copy() + api_key = os.getenv(client_config.api_key_var, "EMPTY") + if api_key and api_key != "EMPTY": + headers["Authorization"] = f"Bearer {api_key}" + + return AsyncClient( + base_url=base_url.rstrip("/").removesuffix("/v1"), + headers=headers, + limits=httpx.Limits(max_connections=4, max_keepalive_connections=1), + timeout=httpx.Timeout(None), + ) + + router_clients = [_setup_router_client(url) for url in router_urls] + if not router_clients: + logger.info("Skipping routing cache clear: no router admin endpoint configured") + return + + async def _clear(client: AsyncClient) -> None: + response = await client.post("/clear_routing_cache") + response.raise_for_status() + + try: + logger.info(f"Clearing router routing cache on {', '.join(str(client.base_url) for client in router_clients)}") + await asyncio.gather(*[_clear(client) for client in router_clients]) + logger.info("Router routing cache cleared") + finally: + await asyncio.gather(*[client.aclose() for client in router_clients]) + + async def update_weights( admin_clients: list[AsyncClient], weight_dir: Path | None, lora_name: str | None = None, step: int = 0, + reset_prefix_cache: bool = False, ) -> None: """Update weights on static inference servers. @@ -324,8 +384,8 @@ async def update_weights( weight update, then resumes. This ensures all DP workers are idle and can participate in the collective weight transfer. - Note: The server-side /update_weights endpoint automatically resets the prefix cache - to invalidate any cached KV states computed with the old weights. + When reset_prefix_cache is enabled, engines are paused in clear mode so vLLM + drops prefix-cache state before loading the new weights. """ logger = get_logger() @@ -340,7 +400,7 @@ async def _update_weights(admin_client: AsyncClient, weight_dir: str | None) -> response.raise_for_status() # Pause engines so all DP workers drain in-flight work and can join the NCCL broadcast - await _pause_engines(admin_clients) + await _pause_engines(admin_clients, reset_prefix_cache=reset_prefix_cache) try: # Create ready marker before servers enter receive path (used by NCCL broadcast) diff --git a/src/prime_rl/utils/elastic.py b/src/prime_rl/utils/elastic.py index 902f873903..5de73497f4 100644 --- a/src/prime_rl/utils/elastic.py +++ b/src/prime_rl/utils/elastic.py @@ -499,7 +499,13 @@ async def wait_for_ready(self, model_name: str = "", timeout: int | None = None, raise TimeoutError(f"Timed out waiting for {min_servers} ready servers (got {self.num_ready_servers})") - async def update_weights(self, weight_dir: Path | None, lora_name: str | None = None, step: int = 0) -> None: + async def update_weights( + self, + weight_dir: Path | None, + lora_name: str | None = None, + step: int = 0, + reset_prefix_cache: bool = False, + ) -> None: if lora_name is None: raise ValueError("Elastic inference pool requires LoRA training (lora_name must be set)") await self.sync_weights(weight_dir, lora_name, step) From baa6935f64c356de5a6f4fe6b2740714a49a4d01 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Wed, 13 May 2026 18:02:13 +0530 Subject: [PATCH 03/32] fix: rely on native vllm routed experts --- src/prime_rl/inference/vllm/server.py | 8 +- .../vllm/serving_chat_with_tokens.py | 57 +-------- src/prime_rl/inference/vllm/serving_tokens.py | 114 +----------------- uv.lock | 4 +- 4 files changed, 13 insertions(+), 170 deletions(-) diff --git a/src/prime_rl/inference/vllm/server.py b/src/prime_rl/inference/vllm/server.py index d38ae1b0de..d9740a6882 100644 --- a/src/prime_rl/inference/vllm/server.py +++ b/src/prime_rl/inference/vllm/server.py @@ -281,8 +281,8 @@ async def custom_init_app_state( so the ``/v1/chat/completions/tokens`` (TITO) endpoint can stream token IDs alongside the rendered chat completion. 3. Replace ``serving_tokens`` with ``PrimeRlServingTokens`` so DP-rank - routing and ``routed_experts`` export survive the migration off the - legacy ``/v1/generate`` endpoint. + routing and server-side ``max_tokens`` defaulting are available on + ``/inference/v1/generate``. """ await init_app_state(engine_client, state, args, supported_tasks) @@ -300,8 +300,8 @@ async def custom_init_app_state( state.openai_serving_chat_with_tokens = None # Swap in our ServingTokens subclass for /inference/v1/generate so the - # X-data-parallel-rank header and routed_experts response field — both - # used by prime-RL's renderer / router-replay paths — keep working. + # X-data-parallel-rank header and server-side max_tokens defaulting keep + # working. if "generate" in supported_tasks and state.serving_tokens is not None: from prime_rl.inference.vllm.serving_tokens import PrimeRlServingTokens diff --git a/src/prime_rl/inference/vllm/serving_chat_with_tokens.py b/src/prime_rl/inference/vllm/serving_chat_with_tokens.py index fae9465fbe..e044a70664 100644 --- a/src/prime_rl/inference/vllm/serving_chat_with_tokens.py +++ b/src/prime_rl/inference/vllm/serving_chat_with_tokens.py @@ -1,4 +1,4 @@ -from collections.abc import AsyncGenerator, AsyncIterator +from collections.abc import AsyncGenerator from typing import ClassVar, Optional, Union from fastapi import Request @@ -10,72 +10,19 @@ from vllm.entrypoints.utils import get_max_tokens from vllm.exceptions import VLLMValidationError from vllm.logger import init_logger -from vllm.outputs import RequestOutput from vllm.reasoning import ReasoningParser from vllm.sampling_params import BeamSearchParams, SamplingParams -from prime_rl.inference.vllm.serving_tokens import _RoutedExpertsCaptureBase - logger = init_logger(__name__) -class _RoutedExpertsCapture(_RoutedExpertsCaptureBase): - """Chat-endpoint variant: mutates choices in-place because - ``ChatCompletionResponseChoice`` is ``extra='allow'``, so an extra - ``routed_experts`` attribute survives serialization.""" - - def post_process(self, response: ChatCompletionResponse) -> None: - for choice in response.choices: - if choice.index in self.routed_experts: - choice.routed_experts = self.routed_experts[choice.index] - - class ChatCompletionRequestWithTokens(ChatCompletionRequest): field_names: ClassVar[Optional[set[str]]] = None tokens: list[int] = Field(description=("Prompt tokens to use for the request.")) class OpenAIServingChatWithTokens(OpenAIServingChat): - """OpenAI-compatible generate API that allows token-in and routed experts capture.""" - - async def chat_completion_full_generator( - self, - request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOutput], - request_id: str, - model_name: str, - conversation, - tokenizer, - request_metadata: RequestResponseMetadata, - reasoning_parser: ReasoningParser | None = None, - ) -> ErrorResponse | ChatCompletionResponse: - # We need to override the full_generator to be able to capture the routed experts - # By default, VLLM does not save the routed experts into ChatCompletionResponse.choices, so we need to capture them manually - # How this works: - # 1. We create a custom generator that encapsulates the original result_generator in self._generator - # 2. We override it's __aiter__ method to also capture the routed experts as an extra field in ChatCompletionResponse.choices - # 3. We override the full_generator method to use the custom generator instead of the original one if expert routing is enabled - if self.model_config.enable_return_routed_experts: - capture = _RoutedExpertsCapture(result_generator) - result_generator = capture - else: - capture = None - - response = await super().chat_completion_full_generator( - request, - result_generator, - request_id, - model_name, - conversation, - tokenizer, - request_metadata, - reasoning_parser, - ) - - if capture and isinstance(response, ChatCompletionResponse): - capture.post_process(response) - - return response + """OpenAI-compatible chat API that allows token-in requests.""" async def create_chat_completion_with_tokens( self, diff --git a/src/prime_rl/inference/vllm/serving_tokens.py b/src/prime_rl/inference/vllm/serving_tokens.py index 9fe0833591..2e1b72e9b7 100644 --- a/src/prime_rl/inference/vllm/serving_tokens.py +++ b/src/prime_rl/inference/vllm/serving_tokens.py @@ -3,18 +3,14 @@ vLLM 0.20 ships a generic tokens-in / tokens-out handler at ``vllm.entrypoints.serve.disagg.serving.ServingTokens`` that already covers prefix-cache salting, lora dispatch, multimodal features, prompt logprobs and -priority. Three prime-RL features are not in the upstream protocol though, so +priority. Two prime-RL features are not in the upstream protocol though, so we subclass it to add them back: 1. ``data_parallel_rank`` routing — read from the ``X-data-parallel-rank`` header and forwarded to ``engine_client.generate``. The DP-replicated inference servers prime-RL runs need this to target a specific replica. -2. ``routed_experts`` per-token export — when the engine emits routing - decisions (``enable_return_routed_experts``), surface them on each choice. - This is what the trainer's router-replay path consumes. - -3. Server-side ``max_tokens`` defaulting — ``ServingTokens`` hands the +2. Server-side ``max_tokens`` defaulting — ``ServingTokens`` hands the client-supplied ``SamplingParams`` to the engine verbatim, and ``SamplingParams.max_tokens`` defaults to ``16`` (a dataclass-level default that predates the OpenAI-compat layer). Every other vLLM @@ -30,87 +26,20 @@ from __future__ import annotations -import base64 from collections.abc import AsyncGenerator from functools import cached_property -import numpy as np from fastapi import Request -from pydantic import Field from vllm.entrypoints.openai.engine.protocol import ErrorResponse, RequestResponseMetadata from vllm.entrypoints.serve.disagg.protocol import ( GenerateRequest, GenerateResponse, - GenerateResponseChoice, ) from vllm.entrypoints.serve.disagg.serving import ServingTokens from vllm.entrypoints.utils import get_max_tokens -from vllm.outputs import RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams -class PrimeRlGenerateResponseChoice(GenerateResponseChoice): - routed_experts: dict | None = Field( - default=None, - description=( - "Per-token expert routing decisions (base85-encoded int32 array + shape). " - "Populated only when the engine was launched with " - "``enable_return_routed_experts=True``; otherwise ``None``." - ), - ) - - -class PrimeRlGenerateResponse(GenerateResponse): - choices: list[PrimeRlGenerateResponseChoice] - - -def encode_routed_experts(arr: np.ndarray) -> dict: - return { - "data": base64.b85encode(arr.tobytes()).decode("ascii"), - "shape": list(arr.shape), - } - - -class _RoutedExpertsCaptureBase: - """Wraps the engine result generator and accumulates a - ``{output_index: encoded_experts}`` map as outputs stream. Subclasses - implement ``post_process`` to fold the captured map into the response - in whatever shape the endpoint returns (in-place vs rebuilt).""" - - def __init__(self, generator: AsyncGenerator[RequestOutput, None]): - self._generator = generator - self.routed_experts: dict[int, dict] = {} - - async def __aiter__(self): - async for request_output in self._generator: - for output in request_output.outputs: - if output.routed_experts is not None: - self.routed_experts[output.index] = encode_routed_experts(output.routed_experts) - yield request_output - - -class _RoutedExpertsCapture(_RoutedExpertsCaptureBase): - """Generate-endpoint variant: rebuilds the response with - ``PrimeRlGenerateResponseChoice`` because upstream's - ``GenerateResponseChoice`` isn't ``extra='allow'``, so an attribute - set after construction wouldn't survive serialization.""" - - def post_process(self, response: GenerateResponse) -> PrimeRlGenerateResponse: - new_choices = [ - PrimeRlGenerateResponseChoice( - **choice.model_dump(), - routed_experts=self.routed_experts.get(choice.index), - ) - for choice in response.choices - ] - return PrimeRlGenerateResponse( - request_id=response.request_id, - choices=new_choices, - prompt_logprobs=response.prompt_logprobs, - kv_transfer_params=response.kv_transfer_params, - ) - - async def _client_set_max_tokens(raw_request: Request | None) -> bool: """Whether the inbound JSON body carried ``sampling_params.max_tokens``. @@ -135,7 +64,7 @@ async def _client_set_max_tokens(raw_request: Request | None) -> bool: class PrimeRlServingTokens(ServingTokens): - """ServingTokens + DP-rank routing + routed_experts export + max_tokens defaulting.""" + """ServingTokens + DP-rank routing + max_tokens defaulting.""" @cached_property def _max_tokens_defaults(self) -> tuple[dict, int | None]: @@ -162,13 +91,11 @@ async def serve_tokens( self, request: GenerateRequest, raw_request: Request | None = None, - ) -> PrimeRlGenerateResponse | ErrorResponse | AsyncGenerator[str, None]: + ) -> GenerateResponse | ErrorResponse | AsyncGenerator[str, None]: # Mirrors upstream ``ServingTokens.serve_tokens`` (vllm 0.20). Diffs: # (a) inject ``data_parallel_rank`` from the inbound header into # ``engine_client.generate``; (b) default ``sampling_params.max_tokens`` - # to ``max_model_len - prompt_len`` when the caller didn't set it; and - # (c) dispatch to our overridden response builder so ``routed_experts`` - # makes it into the JSON. + # to ``max_model_len - prompt_len`` when the caller didn't set it. error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret @@ -261,9 +188,6 @@ async def serve_tokens( ) if request.stream: - # Streaming path: defer to upstream — prime-RL's renderer client - # only consumes the full response, so adding routed_experts to the - # streaming choice schema is unnecessary churn. return self.serve_tokens_stream_generator( request, result_generator, @@ -275,31 +199,3 @@ async def serve_tokens( return await self.serve_tokens_full_generator( request, result_generator, request_id, model_name, request_metadata ) - - async def serve_tokens_full_generator( # type: ignore[override] - self, - request: GenerateRequest, - result_generator: AsyncGenerator[RequestOutput, None], - request_id: str, - model_name: str, - request_metadata: RequestResponseMetadata, - ) -> ErrorResponse | GenerateResponse: - # Mirror serving_chat_with_tokens: wrap the result generator to capture - # routed_experts as it streams, defer the rest to upstream, then post- - # process the response into our PrimeRlGenerateResponse subclass so the - # encoded experts surface in the JSON. Skipping the wrapper when the - # engine isn't producing routed experts keeps us a no-op subclass on - # the common path. - capture: _RoutedExpertsCapture | None = None - if self.model_config.enable_return_routed_experts: - capture = _RoutedExpertsCapture(result_generator) - result_generator = capture # type: ignore[assignment] - - response = await super().serve_tokens_full_generator( - request, result_generator, request_id, model_name, request_metadata - ) - - if capture is not None and isinstance(response, GenerateResponse): - response = capture.post_process(response) - - return response diff --git a/uv.lock b/uv.lock index 4eac4b601d..e0270e7836 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "2026-05-06T12:10:58.853628905Z" +exclude-newer = "2026-05-06T12:31:19.143031393Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -4720,7 +4720,7 @@ dependencies = [ { name = "uvicorn", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ - { filename = "vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl", hash = "sha256:545aee83e3b10901a50d8426b32105bf3eeb9fa8af8c71cdc7f4cfc4303da42f" }, + { filename = "vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl", hash = "sha256:306525b002ec3d8652fcbfaf37cfa46f1fde48180e01ffa0efa0e55d952bbfc2" }, ] [package.metadata] From 18e9a7a73bb765d0b14f4330c584586c97b12ca8 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Wed, 13 May 2026 18:14:54 +0530 Subject: [PATCH 04/32] fix: pin routed experts dependencies for ci --- pyproject.toml | 8 +- tests/unit/inference/test_serving_tokens.py | 62 ++---------- uv.lock | 101 +++----------------- 3 files changed, 24 insertions(+), 147 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9e1806908c..931bf9ab38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -167,16 +167,16 @@ prime-rl-configs = { workspace = true } torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } -verifiers = { path = "third_party/verifiers", editable = true } -wordle = { path = "third_party/verifiers/environments/wordle" } +verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "bb54a3e" } +wordle = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "bb54a3e", subdirectory = "environments/wordle" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" } pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" } -vllm-router = { path = "third_party/router/dist-r3-v3-router-cache/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl" } +vllm-router = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl" } vllm = [ - { path = "third_party/vllm/dist-r3-v3-external/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl", marker = "platform_machine == 'x86_64'" }, + { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl", marker = "platform_machine == 'x86_64'" }, { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl", marker = "platform_machine == 'aarch64'" }, ] reverse-text = { index = "primeintellect" } diff --git a/tests/unit/inference/test_serving_tokens.py b/tests/unit/inference/test_serving_tokens.py index ac5b52b3d4..978d791333 100644 --- a/tests/unit/inference/test_serving_tokens.py +++ b/tests/unit/inference/test_serving_tokens.py @@ -3,25 +3,18 @@ The full happy-path is owned upstream by vLLM 0.20's ``vllm/entrypoints/serve/disagg`` test suite. We only cover the prime-RL deltas here: - * ``encode_routed_experts`` round-trips a numpy array as expected. - * ``PrimeRlGenerateResponseChoice`` accepts the optional field. - * The subclass attaches its overrides without monkey-patching the parent. + * The subclass only overrides ``serve_tokens`` for DP-rank routing and + server-side max-tokens defaulting. * ``_client_set_max_tokens`` distinguishes raw-body shapes correctly. """ from __future__ import annotations import asyncio -import base64 - -import numpy as np from prime_rl.inference.vllm.serving_tokens import ( - PrimeRlGenerateResponse, - PrimeRlGenerateResponseChoice, PrimeRlServingTokens, _client_set_max_tokens, - encode_routed_experts, ) @@ -36,50 +29,9 @@ async def json(self): return self._body -def test_encode_routed_experts_roundtrip(): - arr = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32) - encoded = encode_routed_experts(arr) - - assert encoded["shape"] == [2, 3] - decoded = np.frombuffer(base64.b85decode(encoded["data"]), dtype=np.int32).reshape(encoded["shape"]) - np.testing.assert_array_equal(decoded, arr) - - -def test_routed_experts_choice_accepts_none_and_dict(): - no_re = PrimeRlGenerateResponseChoice(index=0, finish_reason="stop", token_ids=[1, 2]) - assert no_re.routed_experts is None - - encoded = encode_routed_experts(np.zeros((1, 1), dtype=np.int32)) - with_re = PrimeRlGenerateResponseChoice(index=0, finish_reason="stop", token_ids=[1], routed_experts=encoded) - assert with_re.routed_experts == encoded - - -def test_response_only_serializes_declared_fields(): - # Upstream silently drops id=/created=/model=/usage= because they're not - # declared on GenerateResponse. Our subclass adds nothing to that surface - # — it only widens the choices type — so the JSON shape stays slim. - resp = PrimeRlGenerateResponse( - request_id="gen-x", - choices=[PrimeRlGenerateResponseChoice(index=0, finish_reason="stop", token_ids=[7])], - ) - dumped = resp.model_dump() - assert set(dumped.keys()) == { - "request_id", - "choices", - "prompt_logprobs", - "kv_transfer_params", - } - assert dumped["choices"][0]["routed_experts"] is None - - -def test_subclass_inherits_serve_tokens_full_generator(): - # The subclass adds an override; make sure we didn't accidentally rebind - # ``serve_tokens`` to a parent attribute via __dict__-update tricks later. - assert ( - PrimeRlServingTokens.serve_tokens_full_generator - is not PrimeRlServingTokens.__mro__[1].serve_tokens_full_generator - ) +def test_subclass_only_overrides_serve_tokens(): assert PrimeRlServingTokens.serve_tokens is not PrimeRlServingTokens.__mro__[1].serve_tokens + assert "serve_tokens_full_generator" not in PrimeRlServingTokens.__dict__ def test_client_set_max_tokens_recognizes_explicit_value(): @@ -96,12 +48,12 @@ def test_client_set_max_tokens_detects_unset(): def test_client_set_max_tokens_assumes_set_when_body_unreadable(): - # No raw_request → can't tell, don't override. + # No raw_request: can't tell, don't override. assert asyncio.run(_client_set_max_tokens(None)) is True - # body read raises → can't tell, don't override. + # body read raises: can't tell, don't override. err = ValueError("bad json") assert asyncio.run(_client_set_max_tokens(_FakeRawRequest(err))) is True - # non-dict body → can't tell, don't override. + # non-dict body: can't tell, don't override. assert asyncio.run(_client_set_max_tokens(_FakeRawRequest([1, 2, 3]))) is True diff --git a/uv.lock b/uv.lock index e0270e7836..07251a5534 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "2026-05-06T12:31:19.143031393Z" +exclude-newer = "2026-05-06T12:37:29.025733799Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -2812,7 +2812,7 @@ dependencies = [ { name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "uvloop", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "verifiers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "vllm", version = "0.20.1rc1.dev97+g6acff036e.d20260513.precompiled", source = { path = "third_party/vllm/dist-r3-v3-external/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "vllm", version = "0.20.1rc1.dev97+g6acff036e.d20260513.precompiled", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "vllm", version = "0.20.2+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, { name = "wandb", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "wordle", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -2938,14 +2938,14 @@ requires-dist = [ { name = "torchvision", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, { name = "uvloop", specifier = ">=0.21.0" }, - { name = "verifiers", editable = "third_party/verifiers" }, + { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=bb54a3e" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, - { name = "vllm", marker = "platform_machine == 'x86_64'", path = "third_party/vllm/dist-r3-v3-external/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl" }, - { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", path = "third_party/router/dist-r3-v3-router-cache/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl" }, + { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl" }, + { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl" }, { name = "wandb", specifier = ">=0.26.1" }, { name = "wiki-search", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, - { name = "wordle", directory = "third_party/verifiers/environments/wordle" }, + { name = "wordle", git = "https://github.com/PrimeIntellect-ai/verifiers?subdirectory=environments%2Fwordle&rev=bb54a3e" }, ] provides-extras = ["flash-attn", "flash-attn-3", "flash-attn-cute", "envs", "disagg", "gpt-oss", "quack", "all"] @@ -4247,7 +4247,8 @@ wheels = [ [[package]] name = "verifiers" -source = { editable = "third_party/verifiers" } +version = "0.1.15.dev2" +source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=bb54a3e#bb54a3efc5ee1d48b2328fc2d07c690fb50cd7e0" } dependencies = [ { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4274,75 +4275,6 @@ dependencies = [ { name = "textual", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -[package.metadata] -requires-dist = [ - { name = "accelerate", marker = "extra == 'rl'", specifier = ">=1.4.0" }, - { name = "aiohttp", marker = "extra == 'browser'", specifier = ">=3.9.0" }, - { name = "aiolimiter", specifier = ">=1.2.1" }, - { name = "anthropic", specifier = ">=0.78.0" }, - { name = "datasets", specifier = ">=3.0.0,<4.7.0" }, - { name = "deepspeed", marker = "extra == 'rl'", specifier = ">=0.17.6" }, - { name = "flash-attn", marker = "extra == 'rl'", specifier = ">=2.8.3" }, - { name = "gepa" }, - { name = "httpx", specifier = ">=0.27.0" }, - { name = "jinja2", specifier = ">=3.1.6" }, - { name = "liger-kernel", marker = "extra == 'rl'", specifier = ">=0.5.10" }, - { name = "math-verify", specifier = ">=0.8.0" }, - { name = "mcp", specifier = ">=1.14.1" }, - { name = "msgpack", specifier = ">=1.1.2" }, - { name = "nest-asyncio", specifier = ">=1.6.0" }, - { name = "nltk", marker = "extra == 'ta'" }, - { name = "numpy" }, - { name = "openai", specifier = ">=1.108.1" }, - { name = "openai-agents", specifier = ">=0.0.7" }, - { name = "openenv-core", extras = ["core"], marker = "extra == 'openenv'", specifier = "==0.2.1" }, - { name = "peft", marker = "extra == 'rl'" }, - { name = "prime-sandboxes", specifier = ">=0.2.25" }, - { name = "prime-tunnel", specifier = ">=0.1.6" }, - { name = "pydantic", specifier = ">=2.11.9" }, - { name = "python-dotenv", marker = "extra == 'browser'", specifier = ">=1.0.0" }, - { name = "pyzmq", specifier = ">=27.1.0" }, - { name = "reasoning-gym", marker = "extra == 'rg'" }, - { name = "regex", specifier = "<2026.4.4" }, - { name = "renderers", marker = "extra == 'renderers'", specifier = ">=0.1.8.dev0" }, - { name = "requests" }, - { name = "requests", marker = "extra == 'rl'" }, - { name = "rich" }, - { name = "setproctitle", specifier = ">=1.3.0" }, - { name = "stagehand", marker = "extra == 'browser'", specifier = ">=3.0.0" }, - { name = "tenacity", specifier = ">=8.5.0" }, - { name = "textarena", marker = "extra == 'ta'" }, - { name = "textual" }, - { name = "tomli", marker = "python_full_version < '3.11'" }, - { name = "torch", marker = "extra == 'rl'", specifier = ">=2.8.0,<2.9.0" }, - { name = "transformers", marker = "extra == 'rl'", specifier = ">=4.56.2" }, - { name = "typing-extensions", marker = "python_full_version < '3.12'" }, - { name = "vllm", marker = "extra == 'rl'", specifier = ">=0.10.0,<0.11.0" }, - { name = "wandb", marker = "extra == 'rl'" }, -] -provides-extras = ["browser", "openenv", "renderers", "rg", "rl", "ta"] - -[package.metadata.requires-dev] -dev = [ - { name = "aiohttp", specifier = ">=3.9.0" }, - { name = "ipykernel" }, - { name = "ipywidgets" }, - { name = "nltk" }, - { name = "openenv-core", extras = ["core"], specifier = "==0.2.1" }, - { name = "pre-commit" }, - { name = "pytest", specifier = ">=7.0.0" }, - { name = "pytest-asyncio", specifier = ">=0.21.0" }, - { name = "pytest-cov", specifier = ">=4.0.0" }, - { name = "pytest-xdist", specifier = ">=3.8.0" }, - { name = "python-dotenv", specifier = ">=1.0.0" }, - { name = "reasoning-gym" }, - { name = "renderers", specifier = ">=0.1.8.dev0" }, - { name = "ruff" }, - { name = "stagehand", specifier = ">=3.0.0" }, - { name = "textarena" }, - { name = "ty", specifier = ">=0.0.1a29,<0.0.22" }, -] - [[package]] name = "virtualenv" version = "20.34.0" @@ -4360,7 +4292,7 @@ wheels = [ [[package]] name = "vllm" version = "0.20.1rc1.dev97+g6acff036e.d20260513.precompiled" -source = { path = "third_party/vllm/dist-r3-v3-external/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl" } +source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl" } resolution-markers = [ "platform_machine == 'x86_64' and sys_platform == 'linux'", ] @@ -4435,7 +4367,7 @@ dependencies = [ { name = "xgrammar", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ - { filename = "vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl", hash = "sha256:8db1e80a5f4dd97237d7c5702b33f37a65910db9976b42db4f58937ddd0ffd48" }, + { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl", hash = "sha256:8db1e80a5f4dd97237d7c5702b33f37a65910db9976b42db4f58937ddd0ffd48" }, ] [package.metadata] @@ -4710,7 +4642,7 @@ provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttens [[package]] name = "vllm-router" version = "0.1.23" -source = { path = "third_party/router/dist-r3-v3-router-cache/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl" } +source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl" } dependencies = [ { name = "aiohttp", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "fastapi", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -4720,7 +4652,7 @@ dependencies = [ { name = "uvicorn", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ - { filename = "vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl", hash = "sha256:306525b002ec3d8652fcbfaf37cfa46f1fde48180e01ffa0efa0e55d952bbfc2" }, + { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl", hash = "sha256:306525b002ec3d8652fcbfaf37cfa46f1fde48180e01ffa0efa0e55d952bbfc2" }, ] [package.metadata] @@ -4854,20 +4786,13 @@ wheels = [ [[package]] name = "wordle" version = "0.1.7" -source = { directory = "third_party/verifiers/environments/wordle" } +source = { git = "https://github.com/PrimeIntellect-ai/verifiers?subdirectory=environments%2Fwordle&rev=bb54a3e#bb54a3efc5ee1d48b2328fc2d07c690fb50cd7e0" } dependencies = [ { name = "nltk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "textarena", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "verifiers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -[package.metadata] -requires-dist = [ - { name = "nltk", specifier = ">=3.9.2" }, - { name = "textarena", specifier = "==0.7.4" }, - { name = "verifiers", specifier = ">=0.1.9.post3" }, -] - [[package]] name = "xgrammar" version = "0.1.33" From 90d2e3adfde1d3909ef2190fba9a1bd15136100f Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Wed, 13 May 2026 18:34:00 +0530 Subject: [PATCH 05/32] fix: update scheduler tests for prefix reset --- tests/unit/orchestrator/test_scheduler.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/unit/orchestrator/test_scheduler.py b/tests/unit/orchestrator/test_scheduler.py index 9e73b5207d..ae3415c29f 100644 --- a/tests/unit/orchestrator/test_scheduler.py +++ b/tests/unit/orchestrator/test_scheduler.py @@ -15,7 +15,11 @@ def make_scheduler() -> Scheduler: scheduler.strict_async_level = False scheduler.step = 9 scheduler.ckpt_step = 7 - scheduler.config = SimpleNamespace(output_dir=Path("/tmp/prime-rl-test")) + scheduler.config = SimpleNamespace( + output_dir=Path("/tmp/prime-rl-test"), + reset_prefix_cache_on_policy_update=False, + client=None, + ) scheduler.logger = MagicMock() scheduler.checkpoint_ready = asyncio.Event() scheduler.checkpoint_ready.set() @@ -97,7 +101,7 @@ async def run() -> None: release = asyncio.Event() applied_steps: list[int] = [] - async def update_weights(weight_dir, lora_name=None, step=0) -> None: + async def update_weights(weight_dir, lora_name=None, step=0, reset_prefix_cache=False) -> None: applied_steps.append(step) started.set() await release.wait() @@ -135,7 +139,7 @@ async def run() -> None: started = asyncio.Event() cancelled = asyncio.Event() - async def update_weights(weight_dir, lora_name=None, step=0) -> None: + async def update_weights(weight_dir, lora_name=None, step=0, reset_prefix_cache=False) -> None: started.set() try: await asyncio.Future() From 1fea38ec427fff78c08a5efa46177ad54f1d9bad Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Thu, 14 May 2026 19:35:07 +0530 Subject: [PATCH 06/32] fix: clean routed experts replay integration --- .../qwen30b_kv_offload_routed_experts.toml | 32 ------ configs/r3_v3/qwen30b_wordle_r3.toml | 108 ------------------ .../src/prime_rl/configs/orchestrator.py | 10 -- .../src/prime_rl/configs/rl.py | 13 +++ pyproject.toml | 11 +- skills/config/SKILL.md | 4 - src/prime_rl/entrypoints/rl.py | 11 -- src/prime_rl/inference/vllm/routed_experts.py | 48 ++++++++ src/prime_rl/inference/vllm/server.py | 16 +-- .../vllm/serving_chat_with_tokens.py | 39 ++++++- src/prime_rl/inference/vllm/serving_tokens.py | 63 +++++++++- src/prime_rl/orchestrator/orchestrator.py | 10 +- src/prime_rl/orchestrator/scheduler.py | 11 +- src/prime_rl/orchestrator/trajectories.py | 95 ++++++++++----- src/prime_rl/trainer/batch.py | 66 ++++++++++- src/prime_rl/trainer/rl/data.py | 25 +++- src/prime_rl/transport/types.py | 8 +- src/prime_rl/utils/client.py | 78 ++----------- src/prime_rl/utils/elastic.py | 8 +- tests/unit/inference/test_serving_tokens.py | 54 +++++++-- tests/unit/orchestrator/test_batch.py | 27 ++++- tests/unit/orchestrator/test_scheduler.py | 10 +- tests/unit/orchestrator/test_trajectories.py | 64 +++++++---- uv.lock | 94 +++------------ 24 files changed, 465 insertions(+), 440 deletions(-) delete mode 100644 configs/r3_v3/qwen30b_kv_offload_routed_experts.toml delete mode 100644 configs/r3_v3/qwen30b_wordle_r3.toml create mode 100644 src/prime_rl/inference/vllm/routed_experts.py diff --git a/configs/r3_v3/qwen30b_kv_offload_routed_experts.toml b/configs/r3_v3/qwen30b_kv_offload_routed_experts.toml deleted file mode 100644 index 347c73bfac..0000000000 --- a/configs/r3_v3/qwen30b_kv_offload_routed_experts.toml +++ /dev/null @@ -1,32 +0,0 @@ -output_dir = "outputs/qwen30b-kv-offload-r3-v3" -enable_return_routed_experts = true -enable_prefix_caching = true -gpu_memory_utilization = 0.9 - -[model] -name = "Qwen/Qwen3-30B-A3B-Thinking-2507" -max_model_len = 1024 -enforce_eager = true - -[parallel] -tp = 8 -dp = 1 - -[deployment] -type = "single_node" -gpus_per_node = 8 - -[slurm] -job_name = "qwen30b-kv-offload-r3-v3" -partition = "preempt" -time = "02:00:00" -pre_run_command = "uv sync --all-extras --reinstall-package vllm --reinstall-package nvidia-cudnn-cu12 --reinstall-package nvidia-nccl-cu12 --reinstall-package nvidia-cusparselt-cu12 --reinstall-package nvidia-nvshmem-cu12" - -[kv_cache_offload] -cpu_bytes = 17179869184 - -[vllm_extra] -async_scheduling = false -kv_cache_memory_bytes = 536870912 -max_num_batched_tokens = 1024 -max_num_seqs = 16 diff --git a/configs/r3_v3/qwen30b_wordle_r3.toml b/configs/r3_v3/qwen30b_wordle_r3.toml deleted file mode 100644 index 244ae07445..0000000000 --- a/configs/r3_v3/qwen30b_wordle_r3.toml +++ /dev/null @@ -1,108 +0,0 @@ -output_dir = "outputs/qwen30b-wordle-r3-v3" -clean_output_dir = true -max_steps = 20 -seq_len = 4096 - -[log] -level = "debug" - -[model] -name = "Qwen/Qwen3-30B-A3B-Thinking-2507" - -[deployment] -type = "multi_node" -num_train_nodes = 1 -num_infer_nodes = 1 - -[slurm] -job_name = "qwen30b-wordle-r3-v3" -partition = "preempt" -time = "06:00:00" -pre_run_command = "uv sync --all-extras --reinstall-package nvidia-cudnn-cu12 --reinstall-package nvidia-nccl-cu12 --reinstall-package nvidia-cusparselt-cu12 --reinstall-package nvidia-nvshmem-cu12" - -[wandb] -project = "qwen30b-wordle" -name = "qwen30b-wordle-r3-v3" -group = "qwen30b-wordle-r3-v3" - -[weight_broadcast] -type = "nccl" -timeout = 3600 - -[trainer] -enable_router_replay = true -max_concurrent_runs = 1 -dist_timeout_seconds = 3600 - -[trainer.model] -impl = "custom" -attn = "flash_attention_3" -ep = 8 -optimization_dtype = "float32" -reduce_dtype = "float32" - -[trainer.model.ac] -mode = "full" -freq = 1 -targets = ["norm"] - -[trainer.model.ac_offloading] -max_inflight_activations = 5 - -[trainer.optim] -type = "adamw" -lr = 1e-6 - -[inference] -enable_return_routed_experts = true - -[inference.model] -max_model_len = 4096 - -[inference.parallel] -tp = 8 -dp = 1 - -[inference.vllm_extra] -async_scheduling = false - -[orchestrator] -filters = [] -batch_size = 64 -max_inflight_rollouts = 64 -rollouts_per_example = 8 -max_off_policy_steps = 8 -use_token_client = true - -[[orchestrator.train.env]] -id = "primeintellect/wordle" -name = "wordle" -num_workers = 1 -max_retries = 0 -max_total_completion_tokens = -1 - -[orchestrator.train.env.extra_env_kwargs] -max_total_completion_tokens = -1 -max_seq_len = 4096 - -[orchestrator.train.sampling] -temperature = 1.0 -repetition_penalty = 1.0 -max_completion_tokens = 1024 -min_tokens = 0 - -[orchestrator.train.sampling.extra_body] -top_k = -1 -min_p = 0.0 -return_token_ids = true - -[orchestrator.client] -timeout = 1200 -wait_for_ready_timeout = 1800 - -[orchestrator.client.extra_headers_from_state] -X-Session-ID = "example_id" - -[orchestrator.buffer] -easy_threshold = 1.0 -hard_threshold = 0.0 diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 42111daf2e..5d04d3369f 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -1090,16 +1090,6 @@ class OrchestratorConfig(BaseConfig): ), ] = None - reset_prefix_cache_on_policy_update: Annotated[ - bool, - Field( - description=( - "Reset vLLM prefix caches when updating inference weights. This prevents stale KV cache reuse " - "across policy steps, at the cost of losing cross-policy prefix-cache hits." - ), - ), - ] = False - max_async_level: Annotated[ int, Field( diff --git a/packages/prime-rl-configs/src/prime_rl/configs/rl.py b/packages/prime-rl-configs/src/prime_rl/configs/rl.py index a160af2c9f..8d5f087da7 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/rl.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/rl.py @@ -791,6 +791,19 @@ def auto_setup_router_replay(self): ) return self + @model_validator(mode="after") + def validate_router_replay_without_kv_offload(self): + if ( + self.trainer.enable_router_replay + and self.inference is not None + and self.inference.kv_cache_offload is not None + ): + raise ValueError( + "Router replay with inference.kv_cache_offload is not supported. " + "External KV cache hits do not carry routed-expert decisions." + ) + return self + @model_validator(mode="after") def auto_setup_deployment(self): if self.deployment.type == "single_node": # single-node diff --git a/pyproject.toml b/pyproject.toml index 931bf9ab38..dc99465e41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,8 +31,7 @@ dependencies = [ "uvloop>=0.21.0", "torchtitan", "verifiers", - "renderers>=0.1.8.dev0", - "wordle", + "renderers==0.1.6", "dion", "tilelang>=0.1.8", "flash-linear-attention", @@ -167,16 +166,16 @@ prime-rl-configs = { workspace = true } torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } -verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "bb54a3e" } -wordle = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "bb54a3e", subdirectory = "environments/wordle" } +verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "11dbe34" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" } pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" } -vllm-router = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl" } +# TODO: update router wheel when the routed-experts P/D stitching release is ready. +vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" } vllm = [ - { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl", marker = "platform_machine == 'x86_64'" }, + { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev99+g77adbf599.precompiled-cp312-cp312-linux_x86_64.whl", marker = "platform_machine == 'x86_64'" }, { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl", marker = "platform_machine == 'aarch64'" }, ] reverse-text = { index = "primeintellect" } diff --git a/skills/config/SKILL.md b/skills/config/SKILL.md index a2ecd68dd9..e8dc13216c 100644 --- a/skills/config/SKILL.md +++ b/skills/config/SKILL.md @@ -157,10 +157,6 @@ If you wish to configure values of the default variant, you don't need to set th For hosted multi-tenant runs where the trainer image's `trainer.loss.type` is fixed, the orchestrator exposes a per-run override that forces SFT loss on every micro-batch without rebuilding the trainer. Set `orchestrator.use_sft_loss = true` alongside `orchestrator.teacher_rollout_model`; both must be configured together (the orchestrator validator enforces this). The orchestrator stamps each `TrainingSample.sft_loss = True`, which the trainer's `compute_loss` honors by dispatching to `sft_loss_fn` per batch — independent of the trainer's configured default loss. -### Router replay with KV offload - -When `trainer.enable_router_replay = true` and inference CPU KV offload is configured, RL config auto-sets `orchestrator.reset_prefix_cache_on_policy_update = true`. This makes policy updates pause vLLM in `clear` mode instead of `keep` mode, so old-policy prefix-cache entries are not reused after new weights are loaded. If the rollout client points at a router, the orchestrator also calls the router's `clear_routing_cache` route after backend weight updates resume. - ### Model fields For `BaseModel | None` fields (like `[ckpt]`, `[wandb]`, `[compile]`), a bare flag enables them with defaults: diff --git a/src/prime_rl/entrypoints/rl.py b/src/prime_rl/entrypoints/rl.py index fcf98f02d9..b740f58ba2 100644 --- a/src/prime_rl/entrypoints/rl.py +++ b/src/prime_rl/entrypoints/rl.py @@ -543,18 +543,7 @@ def rl_slurm(config: RLConfig): logger.success(f"{result.stdout.strip()}\n\n{log_message}") -def finalize_policy_update_cache_reset(config: RLConfig) -> None: - if ( - config.trainer.enable_router_replay - and config.inference is not None - and config.inference.kv_cache_offload is not None - ): - config.orchestrator.reset_prefix_cache_on_policy_update = True - - def rl(config: RLConfig): - finalize_policy_update_cache_reset(config) - resuming = config.ckpt is not None and config.ckpt.resume_step is not None clean = config.clean_output_dir and not os.environ.get("NEVER_CLEAN_OUTPUT_DIR") ckpt_output_dir = config.ckpt.output_dir if config.ckpt else None diff --git a/src/prime_rl/inference/vllm/routed_experts.py b/src/prime_rl/inference/vllm/routed_experts.py new file mode 100644 index 0000000000..d2a6bf7f78 --- /dev/null +++ b/src/prime_rl/inference/vllm/routed_experts.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import base64 +from collections.abc import AsyncIterator +from io import BytesIO +from typing import Any + +import numpy as np +from vllm.outputs import RequestOutput + + +def serialize_routed_experts(routed_experts: Any) -> str | None: + if routed_experts is None: + return None + + array = np.asarray(routed_experts) + assert array.ndim == 3 + assert np.issubdtype(array.dtype, np.integer) + + if array.size == 0: + compact = array.astype(np.uint8, copy=False) + else: + min_value = array.min() + max_value = array.max() + if min_value >= 0 and max_value <= np.iinfo(np.uint8).max: + compact = array.astype(np.uint8, copy=False) + elif min_value >= np.iinfo(np.int16).min and max_value <= np.iinfo(np.int16).max: + compact = array.astype(np.int16, copy=False) + else: + compact = array.astype(np.int32, copy=False) + + buffer = BytesIO() + np.save(buffer, np.ascontiguousarray(compact), allow_pickle=False) + return base64.b64encode(buffer.getvalue()).decode("ascii") + + +class RoutedExpertsCapture: + def __init__(self, generator: AsyncIterator[RequestOutput]): + self._generator = generator + self.routed_experts: dict[int, str] = {} + + async def __aiter__(self): + async for request_output in self._generator: + for output in request_output.outputs: + encoded = serialize_routed_experts(getattr(output, "routed_experts", None)) + if encoded is not None: + self.routed_experts[output.index] = encoded + yield request_output diff --git a/src/prime_rl/inference/vllm/server.py b/src/prime_rl/inference/vllm/server.py index d9740a6882..53ae22c104 100644 --- a/src/prime_rl/inference/vllm/server.py +++ b/src/prime_rl/inference/vllm/server.py @@ -1,7 +1,7 @@ import asyncio from argparse import Namespace from http import HTTPStatus -from typing import Any, Literal +from typing import Any import uvloop from fastapi import APIRouter, Depends, Request @@ -210,9 +210,9 @@ async def _chat_with_tokens(request: ChatCompletionRequestWithTokens, raw_reques @router.post("/pause") -async def pause(request: Request, mode: Literal["keep", "clear"] = "keep"): - await engine_client(request).pause_generation(mode="keep", clear_cache=mode == "clear") - return {"status": "paused", "mode": mode} +async def pause(request: Request): + await engine_client(request).pause_generation(mode="keep", clear_cache=False) + return {"status": "paused"} @router.post("/resume") @@ -281,8 +281,8 @@ async def custom_init_app_state( so the ``/v1/chat/completions/tokens`` (TITO) endpoint can stream token IDs alongside the rendered chat completion. 3. Replace ``serving_tokens`` with ``PrimeRlServingTokens`` so DP-rank - routing and server-side ``max_tokens`` defaulting are available on - ``/inference/v1/generate``. + routing and ``routed_experts`` export survive the migration off the + legacy ``/v1/generate`` endpoint. """ await init_app_state(engine_client, state, args, supported_tasks) @@ -300,8 +300,8 @@ async def custom_init_app_state( state.openai_serving_chat_with_tokens = None # Swap in our ServingTokens subclass for /inference/v1/generate so the - # X-data-parallel-rank header and server-side max_tokens defaulting keep - # working. + # X-data-parallel-rank header and routed_experts response field — both + # used by prime-RL's renderer / router-replay paths — keep working. if "generate" in supported_tasks and state.serving_tokens is not None: from prime_rl.inference.vllm.serving_tokens import PrimeRlServingTokens diff --git a/src/prime_rl/inference/vllm/serving_chat_with_tokens.py b/src/prime_rl/inference/vllm/serving_chat_with_tokens.py index e044a70664..4895d333f6 100644 --- a/src/prime_rl/inference/vllm/serving_chat_with_tokens.py +++ b/src/prime_rl/inference/vllm/serving_chat_with_tokens.py @@ -1,4 +1,4 @@ -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, AsyncIterator from typing import ClassVar, Optional, Union from fastapi import Request @@ -10,9 +10,12 @@ from vllm.entrypoints.utils import get_max_tokens from vllm.exceptions import VLLMValidationError from vllm.logger import init_logger +from vllm.outputs import RequestOutput from vllm.reasoning import ReasoningParser from vllm.sampling_params import BeamSearchParams, SamplingParams +from prime_rl.inference.vllm.routed_experts import RoutedExpertsCapture + logger = init_logger(__name__) @@ -24,6 +27,40 @@ class ChatCompletionRequestWithTokens(ChatCompletionRequest): class OpenAIServingChatWithTokens(OpenAIServingChat): """OpenAI-compatible chat API that allows token-in requests.""" + async def chat_completion_full_generator( + self, + request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], + request_id: str, + model_name: str, + conversation, + tokenizer, + request_metadata: RequestResponseMetadata, + reasoning_parser: ReasoningParser | None = None, + ) -> ErrorResponse | ChatCompletionResponse: + capture = None + if self.model_config.enable_return_routed_experts: + capture = RoutedExpertsCapture(result_generator) + result_generator = capture + + response = await super().chat_completion_full_generator( + request, + result_generator, + request_id, + model_name, + conversation, + tokenizer, + request_metadata, + reasoning_parser, + ) + + if capture is not None and isinstance(response, ChatCompletionResponse): + for choice in response.choices: + if choice.index in capture.routed_experts: + choice.routed_experts = capture.routed_experts[choice.index] + + return response + async def create_chat_completion_with_tokens( self, request: ChatCompletionRequestWithTokens, diff --git a/src/prime_rl/inference/vllm/serving_tokens.py b/src/prime_rl/inference/vllm/serving_tokens.py index 2e1b72e9b7..19a43ca33a 100644 --- a/src/prime_rl/inference/vllm/serving_tokens.py +++ b/src/prime_rl/inference/vllm/serving_tokens.py @@ -3,14 +3,18 @@ vLLM 0.20 ships a generic tokens-in / tokens-out handler at ``vllm.entrypoints.serve.disagg.serving.ServingTokens`` that already covers prefix-cache salting, lora dispatch, multimodal features, prompt logprobs and -priority. Two prime-RL features are not in the upstream protocol though, so +priority. Three prime-RL features are not in the upstream protocol though, so we subclass it to add them back: 1. ``data_parallel_rank`` routing — read from the ``X-data-parallel-rank`` header and forwarded to ``engine_client.generate``. The DP-replicated inference servers prime-RL runs need this to target a specific replica. -2. Server-side ``max_tokens`` defaulting — ``ServingTokens`` hands the +2. Compact ``routed_experts`` export — when the engine emits routing + decisions, surface them as base64 NumPy payloads without requiring a vLLM + source fork. + +3. Server-side ``max_tokens`` defaulting — ``ServingTokens`` hands the client-supplied ``SamplingParams`` to the engine verbatim, and ``SamplingParams.max_tokens`` defaults to ``16`` (a dataclass-level default that predates the OpenAI-compat layer). Every other vLLM @@ -34,11 +38,42 @@ from vllm.entrypoints.serve.disagg.protocol import ( GenerateRequest, GenerateResponse, + GenerateResponseChoice, ) from vllm.entrypoints.serve.disagg.serving import ServingTokens from vllm.entrypoints.utils import get_max_tokens +from vllm.outputs import RequestOutput from vllm.sampling_params import RequestOutputKind, SamplingParams +from prime_rl.inference.vllm.routed_experts import RoutedExpertsCapture + + +class PrimeRlGenerateResponseChoice(GenerateResponseChoice): + routed_experts: str | None = None + + +class PrimeRlGenerateResponse(GenerateResponse): + choices: list[PrimeRlGenerateResponseChoice] + prompt_token_ids: list[int] + + +class _GenerateRoutedExpertsCapture(RoutedExpertsCapture): + def post_process(self, response: GenerateResponse, prompt_token_ids: list[int]) -> PrimeRlGenerateResponse: + choices = [ + PrimeRlGenerateResponseChoice( + **choice.model_dump(), + routed_experts=self.routed_experts.get(choice.index), + ) + for choice in response.choices + ] + return PrimeRlGenerateResponse( + request_id=response.request_id, + choices=choices, + prompt_token_ids=prompt_token_ids, + prompt_logprobs=response.prompt_logprobs, + kv_transfer_params=response.kv_transfer_params, + ) + async def _client_set_max_tokens(raw_request: Request | None) -> bool: """Whether the inbound JSON body carried ``sampling_params.max_tokens``. @@ -64,7 +99,7 @@ async def _client_set_max_tokens(raw_request: Request | None) -> bool: class PrimeRlServingTokens(ServingTokens): - """ServingTokens + DP-rank routing + max_tokens defaulting.""" + """ServingTokens + DP-rank routing + compact routed experts + max_tokens defaulting.""" @cached_property def _max_tokens_defaults(self) -> tuple[dict, int | None]: @@ -199,3 +234,25 @@ async def serve_tokens( return await self.serve_tokens_full_generator( request, result_generator, request_id, model_name, request_metadata ) + + async def serve_tokens_full_generator( # type: ignore[override] + self, + request: GenerateRequest, + result_generator: AsyncGenerator[RequestOutput, None], + request_id: str, + model_name: str, + request_metadata: RequestResponseMetadata, + ) -> ErrorResponse | GenerateResponse: + capture = None + if self.model_config.enable_return_routed_experts: + capture = _GenerateRoutedExpertsCapture(result_generator) + result_generator = capture + + response = await super().serve_tokens_full_generator( + request, result_generator, request_id, model_name, request_metadata + ) + + if capture is not None and isinstance(response, GenerateResponse): + response = capture.post_process(response, request.token_ids) + + return response diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index e4feb8e09c..bc1128ebc7 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -55,7 +55,6 @@ ) from prime_rl.trainer.model import setup_tokenizer from prime_rl.utils.client import ( - clear_routing_cache, init_nccl_broadcast, setup_inference_pool, ) @@ -317,14 +316,7 @@ async def orchestrate(config: OrchestratorConfig): config.output_dir, scheduler.ckpt_step, check_exists=check_exists, wait_timeout=wait_timeout ) lora_name = config.model.lora.name if config.model.lora else None - await inference_pool.update_weights( - weights_path, - lora_name=lora_name, - step=scheduler.ckpt_step, - reset_prefix_cache=config.reset_prefix_cache_on_policy_update, - ) - if config.reset_prefix_cache_on_policy_update: - await clear_routing_cache(config.client) + await inference_pool.update_weights(weights_path, lora_name=lora_name, step=scheduler.ckpt_step) else: logger.info("Training from scratch") diff --git a/src/prime_rl/orchestrator/scheduler.py b/src/prime_rl/orchestrator/scheduler.py index ab92c91a24..c266757c1c 100644 --- a/src/prime_rl/orchestrator/scheduler.py +++ b/src/prime_rl/orchestrator/scheduler.py @@ -13,7 +13,7 @@ from prime_rl.orchestrator.envs import TrainEnvs from prime_rl.orchestrator.vf_utils import get_seq_len from prime_rl.utils.async_utils import safe_cancel, safe_cancel_all -from prime_rl.utils.client import InferencePool, clear_routing_cache +from prime_rl.utils.client import InferencePool from prime_rl.utils.logger import ProgressTracker, get_logger from prime_rl.utils.utils import ( get_broadcast_dir, @@ -320,14 +320,7 @@ async def _apply_policy_update(self, next_ckpt_step: int) -> None: update_weights_start_time = time.perf_counter() weights_path = get_step_path(get_broadcast_dir(self.config.output_dir), next_ckpt_step) - await self.inference_pool.update_weights( - weights_path, - lora_name=self.lora_name, - step=next_ckpt_step, - reset_prefix_cache=self.config.reset_prefix_cache_on_policy_update, - ) - if self.config.reset_prefix_cache_on_policy_update: - await clear_routing_cache(self.config.client) + await self.inference_pool.update_weights(weights_path, lora_name=self.lora_name, step=next_ckpt_step) self.update_weights_time = time.perf_counter() - update_weights_start_time self.logger.debug(f"Updated weights to step {next_ckpt_step} in {self.update_weights_time:.2f}s") diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 3a45ee9ada..26c86c5b55 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -6,6 +6,7 @@ from pathlib import Path from typing import Any +import numpy as np import torch import verifiers as vf from PIL import Image @@ -25,25 +26,46 @@ # primitives are immutable. pixel_values/image_grid_thw are not mutated after creation. +def _decode_routed_experts(payload: str | None) -> np.ndarray | None: + if payload is None: + return None + routed_experts = np.load(BytesIO(base64.b64decode(payload)), allow_pickle=False) + assert routed_experts.ndim == 3 + return np.ascontiguousarray(routed_experts) + + def _align_routed_experts( - routed_experts: list[list[list[int]]] | None, + routed_experts: np.ndarray | None, expected_len: int, -) -> list[list[list[int]]] | None: +) -> np.ndarray | None: """Align routed_experts length with the expected token count. VLLM's capturer uses `num_tokens - 1` slot mappings because the final generated token was never fed as input to a forward pass and has no routing decision. Append zero-filled entries for the missing positions. """ - if routed_experts is None or not routed_experts: + if routed_experts is None: return routed_experts - deficit = expected_len - len(routed_experts) + assert routed_experts.ndim == 3 + if routed_experts.shape[0] > expected_len: + return np.ascontiguousarray(routed_experts[:expected_len]) + deficit = expected_len - routed_experts.shape[0] if deficit <= 0: return routed_experts - num_layers = len(routed_experts[0]) - topk = len(routed_experts[0][0]) - zero_entry = [[0] * topk for _ in range(num_layers)] - return routed_experts + [zero_entry for _ in range(deficit)] + padding = np.zeros((deficit, routed_experts.shape[1], routed_experts.shape[2]), dtype=routed_experts.dtype) + return np.concatenate((routed_experts, padding), axis=0) + + +def _set_sample_routed_experts(sample: TrainingSample, routed_experts: np.ndarray | None) -> None: + if routed_experts is None: + sample.routed_experts = None + sample.routed_experts_shape = None + sample.routed_experts_dtype = None + return + routed_experts = np.ascontiguousarray(routed_experts) + sample.routed_experts = routed_experts.tobytes() + sample.routed_experts_shape = list(routed_experts.shape) + sample.routed_experts_dtype = str(routed_experts.dtype) def _common_prefix_len(a: list[int], b: list[int]) -> int: @@ -302,7 +324,7 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any "completion_ids": list(tokens["completion_ids"]), "completion_mask": [bool(i) for i in tokens["completion_mask"]], "completion_logprobs": list(tokens["completion_logprobs"]), - "routed_experts": tokens.get("routed_experts"), + "routed_experts": _decode_routed_experts(tokens.get("routed_experts")), } logger.warning(f"Missing rollout tokens for example {output['example_id']} step {step_idx}.") @@ -315,7 +337,7 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any return None prepared_steps.append(prepared) - def make_sample(tokens: dict[str, Any]) -> TrainingSample: + def make_sample(tokens: dict[str, Any]) -> tuple[TrainingSample, np.ndarray | None]: """Create a new TrainingSample from a trajectory step.""" if has_error: completion_mask = [False] * len(tokens["completion_mask"]) @@ -328,7 +350,7 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample: len(tokens["prompt_ids"]) + len(tokens["completion_ids"]), ) prompt_ids = list(tokens["prompt_ids"]) - return TrainingSample( + sample = TrainingSample( prompt_ids=prompt_ids, prompt_mask=[bool(i) for i in tokens["prompt_mask"]], completion_ids=completion_ids, @@ -337,11 +359,17 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample: completion_temperatures=[temperature] * len(completion_ids), teacher_logprobs=None, advantage=None, - routed_experts=routed_experts, mm_token_type_ids=None, ) - - def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> None: + _set_sample_routed_experts(sample, routed_experts) + return sample, routed_experts + + def extend_sample( + sample: TrainingSample, + sample_routed_experts: np.ndarray | None, + prefix_len: int, + step_idx: int, + ) -> np.ndarray | None: """Extend an existing sample with a new trajectory step (extension property holds).""" tokens = prepared_steps[step_idx] @@ -362,24 +390,27 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non sample.completion_logprobs.extend(tokens["completion_logprobs"]) sample.completion_temperatures.extend([temperature] * len(completion_ids)) - if tokens.get("routed_experts") is not None and sample.routed_experts is not None: + if tokens.get("routed_experts") is not None and sample_routed_experts is not None: step_routed = tokens["routed_experts"] # The previous step's last routing entry was zero-padded by _align_routed_experts # (vLLM only captures num_tokens-1 routings per request). This step actually # processed that boundary token as part of its prompt, so replace the zero-fill # with the real routing decision before appending new entries. - if prefix_len > 0 and prefix_len <= len(step_routed): - sample.routed_experts[prefix_len - 1] = step_routed[prefix_len - 1] - sample.routed_experts.extend(step_routed[prefix_len:]) + if prefix_len > 0 and prefix_len <= step_routed.shape[0]: + sample_routed_experts[prefix_len - 1] = step_routed[prefix_len - 1] + sample_routed_experts = np.concatenate((sample_routed_experts, step_routed[prefix_len:]), axis=0) expected_len = len(sample.prompt_ids) + len(sample.completion_ids) - sample.routed_experts = _align_routed_experts(sample.routed_experts, expected_len) + sample_routed_experts = _align_routed_experts(sample_routed_experts, expected_len) + _set_sample_routed_experts(sample, sample_routed_experts) + return sample_routed_experts - # Track [prefix_tokens, sample, last_step_idx] per active sample - active_samples: list[tuple[list[int], TrainingSample, int]] = [] + # Track [prefix_tokens, sample, last_step_idx, routed_experts] per active sample + active_samples: list[tuple[list[int], TrainingSample, int, np.ndarray | None]] = [] first_tokens = prepared_steps[0] first_prefix = first_tokens["prompt_ids"] + first_tokens["completion_ids"] - active_samples.append((first_prefix, make_sample(first_tokens), 0)) + first_sample, first_routed_experts = make_sample(first_tokens) + active_samples.append((first_prefix, first_sample, 0, first_routed_experts)) for step_idx, _step in enumerate(trajectory[1:], start=1): tokens = prepared_steps[step_idx] @@ -387,16 +418,21 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non # Check if this step extends ANY active prefix matched_idx = None - for idx, (prefix_tokens, _, _) in enumerate(active_samples): + for idx, (prefix_tokens, _, _, _) in enumerate(active_samples): if step_prompt_ids[: len(prefix_tokens)] == prefix_tokens: matched_idx = idx break if matched_idx is not None: # Extension holds - merge into matched sample - prefix_tokens, sample, _ = active_samples[matched_idx] - extend_sample(sample, len(prefix_tokens), step_idx=step_idx) - active_samples[matched_idx] = (tokens["prompt_ids"] + tokens["completion_ids"], sample, step_idx) + prefix_tokens, sample, _, sample_routed_experts = active_samples[matched_idx] + sample_routed_experts = extend_sample(sample, sample_routed_experts, len(prefix_tokens), step_idx=step_idx) + active_samples[matched_idx] = ( + tokens["prompt_ids"] + tokens["completion_ids"], + sample, + step_idx, + sample_routed_experts, + ) else: # No prefix matches - start a new sample logger.debug( @@ -404,7 +440,8 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non f"Starting new sample (active_prefixes={len(active_samples)}, step_prompt_len={len(step_prompt_ids)})." ) new_prefix = tokens["prompt_ids"] + tokens["completion_ids"] - active_samples.append((new_prefix, make_sample(tokens), step_idx)) + sample, routed_experts = make_sample(tokens) + active_samples.append((new_prefix, sample, step_idx, routed_experts)) # Attach images once per sample using only the last merged step. Prompt # tokens already contain fully expanded <|image_pad|> placeholders because @@ -413,7 +450,7 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non # fallback path so features and tokens stay 1:1. if vlm_cache is not None: key = output["example_id"] if cache_key is None else cache_key - for _, sample, last_step_idx in active_samples: + for _, sample, last_step_idx, _ in active_samples: pv, shape, grids = vlm_cache.get_for_step(key, last_step_idx) sample.pixel_values = pv sample.pixel_values_shape = shape @@ -423,7 +460,7 @@ def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> Non mm_token_type_ids_mapping.get(token_id, 0) for token_id in sample.prompt_ids + sample.completion_ids ] - return [sample for _, sample, _ in active_samples] + return [sample for _, sample, _, _ in active_samples] # ============================================================================= diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index 662df36a80..b19f702b6e 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -2,6 +2,43 @@ from prime_rl.transport.types import MicroBatch, TrainingSample +ROUTED_EXPERTS_DTYPE_ITEMSIZE = { + "uint8": 1, + "int16": 2, + "int32": 4, +} + + +def _routed_experts_row_size(shape: list[int], dtype: str) -> int: + return shape[1] * shape[2] * ROUTED_EXPERTS_DTYPE_ITEMSIZE[dtype] + + +def _slice_routed_experts(data: bytes, shape: list[int], dtype: str, seq_len: int) -> tuple[bytes, list[int]]: + row_size = _routed_experts_row_size(shape, dtype) + return data[: seq_len * row_size], [seq_len, shape[1], shape[2]] + + +def _append_routed_experts(dst: MicroBatch, src: MicroBatch) -> None: + assert dst.routed_experts is not None + assert dst.routed_experts_shape is not None + assert dst.routed_experts_dtype is not None + assert src.routed_experts is not None + assert src.routed_experts_shape is not None + assert src.routed_experts_dtype is not None + assert dst.routed_experts_dtype == src.routed_experts_dtype + assert dst.routed_experts_shape[1:] == src.routed_experts_shape[1:] + dst.routed_experts += src.routed_experts + dst.routed_experts_shape[0] += src.routed_experts_shape[0] + + +def _pad_routed_experts(micro_batch: MicroBatch, padding_size: int) -> None: + assert micro_batch.routed_experts is not None + assert micro_batch.routed_experts_shape is not None + assert micro_batch.routed_experts_dtype is not None + row_size = _routed_experts_row_size(micro_batch.routed_experts_shape, micro_batch.routed_experts_dtype) + micro_batch.routed_experts += b"\0" * (padding_size * row_size) + micro_batch.routed_experts_shape[0] += padding_size + def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch: """ @@ -24,6 +61,8 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch # computed via prefill in the orchestrator when a teacher model is configured teacher_logprobs = training_example.teacher_logprobs routed_experts = training_example.routed_experts + routed_experts_shape = training_example.routed_experts_shape + routed_experts_dtype = training_example.routed_experts_dtype if len(input_ids) > seq_len: input_ids = input_ids[:seq_len] @@ -35,7 +74,11 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch if teacher_logprobs is not None: teacher_logprobs = teacher_logprobs[:seq_len] if routed_experts is not None: - routed_experts = routed_experts[:seq_len] + assert routed_experts_shape is not None + assert routed_experts_dtype is not None + routed_experts, routed_experts_shape = _slice_routed_experts( + routed_experts, routed_experts_shape, routed_experts_dtype, seq_len + ) if mm_token_type_ids is not None: mm_token_type_ids = mm_token_type_ids[:seq_len] @@ -53,8 +96,13 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch assert len(teacher_logprobs) == len(input_ids), f"teacher_logprobs: {len(teacher_logprobs)}" if routed_experts is not None: - assert len(routed_experts) == len(input_ids), ( - f"routed_experts: {len(routed_experts)}, input_ids: {len(input_ids)}" + assert routed_experts_shape is not None + assert routed_experts_dtype is not None + assert routed_experts_shape[0] == len(input_ids), ( + f"routed_experts: {routed_experts_shape}, input_ids: {len(input_ids)}" + ) + assert len(routed_experts) == len(input_ids) * _routed_experts_row_size( + routed_experts_shape, routed_experts_dtype ) if mm_token_type_ids is not None: @@ -71,6 +119,8 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch teacher_logprobs=teacher_logprobs, temperatures=temperatures, routed_experts=routed_experts, + routed_experts_shape=routed_experts_shape, + routed_experts_dtype=routed_experts_dtype, mm_token_type_ids=mm_token_type_ids, # Multimodal fields (Qwen3-VL) - passed through without modification pixel_values=training_example.pixel_values, @@ -129,10 +179,14 @@ def packed_samples_into_micro_bs( if bin_content.teacher_logprobs is None: bin_content.teacher_logprobs = [] bin_content.teacher_logprobs.extend(sample.teacher_logprobs) + assert (bin_content.routed_experts is None) == (sample.routed_experts is None) if sample.routed_experts is not None: if bin_content.routed_experts is None: - bin_content.routed_experts = [] - bin_content.routed_experts.extend(sample.routed_experts) + bin_content.routed_experts = sample.routed_experts + bin_content.routed_experts_shape = list(sample.routed_experts_shape) + bin_content.routed_experts_dtype = sample.routed_experts_dtype + else: + _append_routed_experts(bin_content, sample) if sample.mm_token_type_ids is not None: if bin_content.mm_token_type_ids is None: bin_content.mm_token_type_ids = [] @@ -178,6 +232,8 @@ def pad_micro_batch(micro_batch: MicroBatch, pad_to_multiple_of: int) -> MicroBa ) if micro_batch.mm_token_type_ids is not None: micro_batch.mm_token_type_ids.extend([0] * padding_size) + if micro_batch.routed_experts is not None: + _pad_routed_experts(micro_batch, padding_size) return micro_batch diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index ffc4bc627f..43ef1bf0ce 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -12,6 +12,12 @@ from prime_rl.trainer.world import get_world from prime_rl.transport import MicroBatch, MicroBatchReceiver, TransportConfig, setup_micro_batch_receiver +ROUTED_EXPERTS_TORCH_DTYPES = { + "uint8": torch.uint8, + "int16": torch.int16, + "int32": torch.int32, +} + class TensorMicroBatch(TypedDict): """A micro batch of data for training.""" @@ -195,6 +201,19 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: if micro_batch.lora_num_tokens is None: micro_batch.lora_num_tokens = [0] * self.multi_run_manager.max_runs micro_batch.lora_num_tokens[0] = len(micro_batch.input_ids) + routed_experts = None + if micro_batch.routed_experts is not None: + assert micro_batch.routed_experts_shape is not None + assert micro_batch.routed_experts_dtype is not None + routed_experts = ( + torch.frombuffer( + micro_batch.routed_experts, + dtype=ROUTED_EXPERTS_TORCH_DTYPES[micro_batch.routed_experts_dtype], + ) + .reshape(micro_batch.routed_experts_shape) + .to(torch.int32) + .unsqueeze(0) + ) return TensorMicroBatch( input_ids=torch.tensor(micro_batch.input_ids, dtype=torch.long).unsqueeze(0), position_ids=torch.tensor(micro_batch.position_ids, dtype=torch.long).unsqueeze(0), @@ -218,10 +237,6 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: mm_token_type_ids=torch.tensor(micro_batch.mm_token_type_ids, dtype=torch.long).unsqueeze(0) if micro_batch.mm_token_type_ids is not None else None, - routed_experts=torch.tensor(micro_batch.routed_experts, dtype=torch.int32).unsqueeze( - 0 - ) # [1, seq_len, layers, topk] - if micro_batch.routed_experts is not None - else None, + routed_experts=routed_experts, sft_loss=micro_batch.sft_loss, ) diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index 4bc594f06d..ff60e59978 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -21,7 +21,9 @@ class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tr # image_grid_thw: grid dimensions [num_images, 3] where each entry is [temporal, height, width] image_grid_thw: list[list[int]] | None = None - routed_experts: list[list[list[int]]] | None = None # [seq_len, layers, topk] + routed_experts: bytes | None = None + routed_experts_shape: list[int] | None = None # [seq_len, layers, topk] + routed_experts_dtype: str | None = None # mm_token_type_ids: token type ids per token [batch seq], int64 (0=text, 1=image, 2=video) mm_token_type_ids: list[int] | None = None @@ -49,7 +51,9 @@ class MicroBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): temperatures: list[float] # Per-token temperatures used during generation teacher_logprobs: list[float] | None = None lora_num_tokens: list[int] | None = None - routed_experts: list[list[list[int]]] | None = None + routed_experts: bytes | None = None + routed_experts_shape: list[int] | None = None # [seq_len, layers, topk] + routed_experts_dtype: str | None = None # Multimodal fields (Qwen3-VL) — pixel_values stored as raw float32 bytes for efficient serialization pixel_values: bytes | None = None diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index 876503a494..21659dfc46 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -42,13 +42,7 @@ async def wait_for_ready(self, model_name: str, timeout: int | None = None) -> N """Wait for inference pool to be ready.""" ... - async def update_weights( - self, - weight_dir: Path | None, - lora_name: str | None = None, - step: int = 0, - reset_prefix_cache: bool = False, - ) -> None: + async def update_weights(self, weight_dir: Path | None, lora_name: str | None = None, step: int = 0) -> None: """Update weights on all inference servers.""" ... @@ -116,20 +110,8 @@ async def wait_for_ready(self, model_name: str, timeout: int | None = None) -> N ) await maybe_check_has_model(self._admin_clients, model_name, skip_model_check=self._skip_model_check) - async def update_weights( - self, - weight_dir: Path | None, - lora_name: str | None = None, - step: int = 0, - reset_prefix_cache: bool = False, - ) -> None: - await update_weights( - self._admin_clients, - weight_dir, - lora_name=lora_name, - step=step, - reset_prefix_cache=reset_prefix_cache, - ) + async def update_weights(self, weight_dir: Path | None, lora_name: str | None = None, step: int = 0) -> None: + await update_weights(self._admin_clients, weight_dir, lora_name=lora_name, step=step) def get_metrics(self) -> dict[str, float]: return {} @@ -305,14 +287,13 @@ async def _check_health(admin_client: AsyncClient) -> None: NCCL_READY_MARKER = "NCCL_READY" -async def _pause_engines(admin_clients: list[AsyncClient], reset_prefix_cache: bool = False) -> None: +async def _pause_engines(admin_clients: list[AsyncClient]) -> None: """Pause all inference engines, waiting for in-flight requests to drain.""" logger = get_logger() - mode = "clear" if reset_prefix_cache else "keep" - logger.info(f"Pausing inference engines for weight update (mode={mode})") + logger.info("Pausing inference engines for weight update") async def _pause(client: AsyncClient) -> None: - response = await client.post("/pause", params={"mode": mode}) + response = await client.post("/pause", params={"mode": "keep", "clear_cache": "false"}) response.raise_for_status() await asyncio.gather(*[_pause(client) for client in admin_clients]) @@ -331,52 +312,11 @@ async def _resume(client: AsyncClient) -> None: logger.info("All inference engines resumed") -async def clear_routing_cache(client_config: ClientConfig) -> None: - """Clear router-local routed-experts cache when a policy update resets prefix cache.""" - logger = get_logger() - if client_config.router_url is not None: - router_urls = [client_config.router_url] - elif client_config.admin_base_url is not None: - router_urls = client_config.base_url - else: - router_urls = [] - - def _setup_router_client(base_url: str) -> AsyncClient: - headers = client_config.headers.copy() - api_key = os.getenv(client_config.api_key_var, "EMPTY") - if api_key and api_key != "EMPTY": - headers["Authorization"] = f"Bearer {api_key}" - - return AsyncClient( - base_url=base_url.rstrip("/").removesuffix("/v1"), - headers=headers, - limits=httpx.Limits(max_connections=4, max_keepalive_connections=1), - timeout=httpx.Timeout(None), - ) - - router_clients = [_setup_router_client(url) for url in router_urls] - if not router_clients: - logger.info("Skipping routing cache clear: no router admin endpoint configured") - return - - async def _clear(client: AsyncClient) -> None: - response = await client.post("/clear_routing_cache") - response.raise_for_status() - - try: - logger.info(f"Clearing router routing cache on {', '.join(str(client.base_url) for client in router_clients)}") - await asyncio.gather(*[_clear(client) for client in router_clients]) - logger.info("Router routing cache cleared") - finally: - await asyncio.gather(*[client.aclose() for client in router_clients]) - - async def update_weights( admin_clients: list[AsyncClient], weight_dir: Path | None, lora_name: str | None = None, step: int = 0, - reset_prefix_cache: bool = False, ) -> None: """Update weights on static inference servers. @@ -384,8 +324,8 @@ async def update_weights( weight update, then resumes. This ensures all DP workers are idle and can participate in the collective weight transfer. - When reset_prefix_cache is enabled, engines are paused in clear mode so vLLM - drops prefix-cache state before loading the new weights. + Note: The server-side /update_weights endpoint automatically resets the prefix cache + to invalidate any cached KV states computed with the old weights. """ logger = get_logger() @@ -400,7 +340,7 @@ async def _update_weights(admin_client: AsyncClient, weight_dir: str | None) -> response.raise_for_status() # Pause engines so all DP workers drain in-flight work and can join the NCCL broadcast - await _pause_engines(admin_clients, reset_prefix_cache=reset_prefix_cache) + await _pause_engines(admin_clients) try: # Create ready marker before servers enter receive path (used by NCCL broadcast) diff --git a/src/prime_rl/utils/elastic.py b/src/prime_rl/utils/elastic.py index 5de73497f4..902f873903 100644 --- a/src/prime_rl/utils/elastic.py +++ b/src/prime_rl/utils/elastic.py @@ -499,13 +499,7 @@ async def wait_for_ready(self, model_name: str = "", timeout: int | None = None, raise TimeoutError(f"Timed out waiting for {min_servers} ready servers (got {self.num_ready_servers})") - async def update_weights( - self, - weight_dir: Path | None, - lora_name: str | None = None, - step: int = 0, - reset_prefix_cache: bool = False, - ) -> None: + async def update_weights(self, weight_dir: Path | None, lora_name: str | None = None, step: int = 0) -> None: if lora_name is None: raise ValueError("Elastic inference pool requires LoRA training (lora_name must be set)") await self.sync_weights(weight_dir, lora_name, step) diff --git a/tests/unit/inference/test_serving_tokens.py b/tests/unit/inference/test_serving_tokens.py index 978d791333..fbee225543 100644 --- a/tests/unit/inference/test_serving_tokens.py +++ b/tests/unit/inference/test_serving_tokens.py @@ -1,17 +1,14 @@ -"""Sanity tests for the prime-RL ``ServingTokens`` subclass. - -The full happy-path is owned upstream by vLLM 0.20's -``vllm/entrypoints/serve/disagg`` test suite. We only cover the prime-RL -deltas here: - * The subclass only overrides ``serve_tokens`` for DP-rank routing and - server-side max-tokens defaulting. - * ``_client_set_max_tokens`` distinguishes raw-body shapes correctly. -""" +"""Sanity tests for the prime-RL ``ServingTokens`` subclass.""" from __future__ import annotations import asyncio +import base64 +from io import BytesIO + +import numpy as np +from prime_rl.inference.vllm.routed_experts import serialize_routed_experts from prime_rl.inference.vllm.serving_tokens import ( PrimeRlServingTokens, _client_set_max_tokens, @@ -31,7 +28,44 @@ async def json(self): def test_subclass_only_overrides_serve_tokens(): assert PrimeRlServingTokens.serve_tokens is not PrimeRlServingTokens.__mro__[1].serve_tokens - assert "serve_tokens_full_generator" not in PrimeRlServingTokens.__dict__ + assert ( + PrimeRlServingTokens.serve_tokens_full_generator + is not PrimeRlServingTokens.__mro__[1].serve_tokens_full_generator + ) + + +def test_serialize_routed_experts_uses_compact_numpy_payload(): + routed_experts = np.array( + [ + [[1, 2], [3, 4]], + [[5, 6], [7, 8]], + ], + dtype=np.int64, + ) + + encoded = serialize_routed_experts(routed_experts) + assert encoded is not None + + decoded = np.load(BytesIO(base64.b64decode(encoded)), allow_pickle=False) + assert decoded.dtype == np.uint8 + np.testing.assert_array_equal(decoded, routed_experts) + + +def test_serialize_routed_experts_uses_int16_for_large_expert_ids(): + routed_experts = np.array( + [ + [[256, 257], [300, 301]], + [[302, 303], [304, 305]], + ], + dtype=np.int64, + ) + + encoded = serialize_routed_experts(routed_experts) + assert encoded is not None + + decoded = np.load(BytesIO(base64.b64decode(encoded)), allow_pickle=False) + assert decoded.dtype == np.int16 + np.testing.assert_array_equal(decoded, routed_experts) def test_client_set_max_tokens_recognizes_explicit_value(): diff --git a/tests/unit/orchestrator/test_batch.py b/tests/unit/orchestrator/test_batch.py index a2e2e50079..202fee7e92 100644 --- a/tests/unit/orchestrator/test_batch.py +++ b/tests/unit/orchestrator/test_batch.py @@ -1,9 +1,15 @@ +import numpy as np import pytest from prime_rl.trainer.batch import prepare_batch, prepare_sample from prime_rl.transport.types import TrainingSample +def _routed_experts(data, dtype=np.uint8): + routed_experts = np.asarray(data, dtype=dtype) + return routed_experts.tobytes(), list(routed_experts.shape), str(routed_experts.dtype) + + @pytest.fixture def make_training_example(): def _make_training_example(temperature: float = 1.0, sft_loss: bool = False) -> TrainingSample: @@ -109,6 +115,7 @@ def test_prepare_sample_with_routed_experts(): """Routed experts are passed through prepare_sample and match input_ids length.""" # 2 prompt + 2 completion = 4 tokens, 2 layers, topk=2 routed_experts = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 2], [1, 3]], [[1, 0], [3, 2]]] + routed_bytes, routed_shape, routed_dtype = _routed_experts(routed_experts) sample = TrainingSample( prompt_ids=[1, 2], prompt_mask=[False, False], @@ -117,18 +124,23 @@ def test_prepare_sample_with_routed_experts(): completion_logprobs=[-0.1, -0.2], completion_temperatures=[1.0, 1.0], advantage=1.0, - routed_experts=routed_experts, + routed_experts=routed_bytes, + routed_experts_shape=routed_shape, + routed_experts_dtype=routed_dtype, ) micro_batch = prepare_sample(sample, seq_len=8) assert micro_batch.routed_experts is not None - assert len(micro_batch.routed_experts) == 4 - assert micro_batch.routed_experts == routed_experts + assert micro_batch.routed_experts == routed_bytes + assert micro_batch.routed_experts_shape == routed_shape + assert micro_batch.routed_experts_dtype == routed_dtype def test_prepare_sample_truncates_routed_experts(): """Routed experts are truncated to seq_len when input exceeds it.""" routed_experts = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]]] + routed_bytes, routed_shape, routed_dtype = _routed_experts(routed_experts) + expected_bytes, expected_shape, _ = _routed_experts(routed_experts[:3]) sample = TrainingSample( prompt_ids=[1, 2], prompt_mask=[False, False], @@ -137,13 +149,16 @@ def test_prepare_sample_truncates_routed_experts(): completion_logprobs=[-0.1, -0.2], completion_temperatures=[1.0, 1.0], advantage=1.0, - routed_experts=routed_experts, + routed_experts=routed_bytes, + routed_experts_shape=routed_shape, + routed_experts_dtype=routed_dtype, ) micro_batch = prepare_sample(sample, seq_len=3) assert micro_batch.routed_experts is not None - assert len(micro_batch.routed_experts) == 3 - assert micro_batch.routed_experts == routed_experts[:3] + assert micro_batch.routed_experts == expected_bytes + assert micro_batch.routed_experts_shape == expected_shape + assert micro_batch.routed_experts_dtype == routed_dtype def test_prepare_sample_none_routed_experts(): diff --git a/tests/unit/orchestrator/test_scheduler.py b/tests/unit/orchestrator/test_scheduler.py index ae3415c29f..9e73b5207d 100644 --- a/tests/unit/orchestrator/test_scheduler.py +++ b/tests/unit/orchestrator/test_scheduler.py @@ -15,11 +15,7 @@ def make_scheduler() -> Scheduler: scheduler.strict_async_level = False scheduler.step = 9 scheduler.ckpt_step = 7 - scheduler.config = SimpleNamespace( - output_dir=Path("/tmp/prime-rl-test"), - reset_prefix_cache_on_policy_update=False, - client=None, - ) + scheduler.config = SimpleNamespace(output_dir=Path("/tmp/prime-rl-test")) scheduler.logger = MagicMock() scheduler.checkpoint_ready = asyncio.Event() scheduler.checkpoint_ready.set() @@ -101,7 +97,7 @@ async def run() -> None: release = asyncio.Event() applied_steps: list[int] = [] - async def update_weights(weight_dir, lora_name=None, step=0, reset_prefix_cache=False) -> None: + async def update_weights(weight_dir, lora_name=None, step=0) -> None: applied_steps.append(step) started.set() await release.wait() @@ -139,7 +135,7 @@ async def run() -> None: started = asyncio.Event() cancelled = asyncio.Event() - async def update_weights(weight_dir, lora_name=None, step=0, reset_prefix_cache=False) -> None: + async def update_weights(weight_dir, lora_name=None, step=0) -> None: started.set() try: await asyncio.Future() diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index 6fa169760c..ff77b73e4f 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -30,6 +30,22 @@ def _decode_pixels(pixel_bytes: bytes, shape: list[int]) -> list[list[float]]: return np.frombuffer(pixel_bytes, dtype=np.float32).reshape(shape).tolist() +def _routed_experts_payload(data, dtype=np.uint8) -> str: + arr = np.asarray(data, dtype=dtype) + buffer = BytesIO() + np.save(buffer, arr, allow_pickle=False) + return base64.b64encode(buffer.getvalue()).decode("ascii") + + +def _sample_routed_experts(sample) -> np.ndarray: + assert sample.routed_experts is not None + assert sample.routed_experts_shape is not None + assert sample.routed_experts_dtype is not None + return np.frombuffer(sample.routed_experts, dtype=np.dtype(sample.routed_experts_dtype)).reshape( + sample.routed_experts_shape + ) + + def test_deserialize_tool_calls_does_not_inject_missing_key(): messages = [{"role": "assistant", "content": "hello"}] @@ -1857,40 +1873,43 @@ def test_align_routed_experts_none(): def test_align_routed_experts_empty(): - result = _align_routed_experts([], 10) - assert result == [] + experts = np.empty((0, 2, 2), dtype=np.uint8) + result = _align_routed_experts(experts, 10) + assert result is not None + assert result.shape == (10, 2, 2) + assert np.all(result == 0) def test_align_routed_experts_no_deficit(): # 3 tokens, 2 layers, topk=2 - experts = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 2], [1, 3]]] + experts = np.asarray([[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 2], [1, 3]]], dtype=np.uint8) result = _align_routed_experts(experts, expected_len=3) - assert result == experts + np.testing.assert_array_equal(result, experts) def test_align_routed_experts_with_deficit(): # 2 tokens but expected 4 (deficit of 2) - experts = [[[1, 2], [3, 4]], [[5, 6], [7, 0]]] + experts = np.asarray([[[1, 2], [3, 4]], [[5, 6], [7, 0]]], dtype=np.uint8) result = _align_routed_experts(experts, expected_len=4) - assert len(result) == 4 - assert result[:2] == experts + assert result is not None + assert result.shape == (4, 2, 2) + np.testing.assert_array_equal(result[:2], experts) # Padded entries should be zero-filled with same shape [layers=2, topk=2] - assert result[2] == [[0, 0], [0, 0]] - assert result[3] == [[0, 0], [0, 0]] + np.testing.assert_array_equal(result[2], [[0, 0], [0, 0]]) + np.testing.assert_array_equal(result[3], [[0, 0], [0, 0]]) def test_align_routed_experts_excess_length(): - experts = [[[1, 2]], [[3, 4]], [[5, 6]]] + experts = np.asarray([[[1, 2]], [[3, 4]], [[5, 6]]], dtype=np.uint8) result = _align_routed_experts(experts, expected_len=2) - # No truncation, just returns as-is - assert result == experts + np.testing.assert_array_equal(result, experts[:2]) def test_interleave_rollout_single_step_with_routed_experts(): """Routed experts are aligned and passed through for a single-step trajectory.""" # prompt_ids=[1,2], completion_ids=[3,4] -> total 4 tokens # vLLM returns num_tokens-1 = 3 routed expert entries - routed_experts_from_vllm = [[[0, 1]], [[2, 3]], [[4, 5]]] # 3 entries, 1 layer, topk=2 + routed_experts_from_vllm = np.asarray([[[0, 1]], [[2, 3]], [[4, 5]]], dtype=np.uint8) output = vf.RolloutOutput( example_id=0, trajectory=[ @@ -1906,7 +1925,7 @@ def test_interleave_rollout_single_step_with_routed_experts(): completion_logprobs=[-0.1, -0.2], overlong_prompt=False, is_truncated=False, - routed_experts=routed_experts_from_vllm, + routed_experts=_routed_experts_payload(routed_experts_from_vllm), ), reward=None, advantage=None, @@ -1926,18 +1945,19 @@ def test_interleave_rollout_single_step_with_routed_experts(): # Should be aligned to 4 tokens (2 prompt + 2 completion) assert sample.routed_experts is not None - assert len(sample.routed_experts) == 4 + routed_experts = _sample_routed_experts(sample) + assert routed_experts.shape == (4, 1, 2) # First 3 are original, last one is zero-padded - assert sample.routed_experts[:3] == routed_experts_from_vllm - assert sample.routed_experts[3] == [[0, 0]] + np.testing.assert_array_equal(routed_experts[:3], routed_experts_from_vllm) + np.testing.assert_array_equal(routed_experts[3], [[0, 0]]) def test_interleave_rollout_multi_step_with_routed_experts(): """Routed experts are extended and aligned across multi-step trajectories.""" # Step 1: prompt=[1,2], completion=[3,4] -> 4 tokens, vLLM returns 3 - step1_experts = [[[1, 2]], [[3, 4]], [[5, 6]]] + step1_experts = np.asarray([[[1, 2]], [[3, 4]], [[5, 6]]], dtype=np.uint8) # Step 2: prompt=[1,2,3,4,5,6], completion=[7,8] -> 8 tokens, vLLM returns 7 - step2_experts = [[[1, 0]], [[2, 0]], [[3, 0]], [[4, 0]], [[5, 0]], [[6, 0]], [[7, 0]]] + step2_experts = np.asarray([[[1, 0]], [[2, 0]], [[3, 0]], [[4, 0]], [[5, 0]], [[6, 0]], [[7, 0]]], dtype=np.uint8) output = vf.RolloutOutput( example_id=0, @@ -1954,7 +1974,7 @@ def test_interleave_rollout_multi_step_with_routed_experts(): completion_logprobs=[-0.1, -0.2], overlong_prompt=False, is_truncated=False, - routed_experts=step1_experts, + routed_experts=_routed_experts_payload(step1_experts), ), reward=None, advantage=None, @@ -1978,7 +1998,7 @@ def test_interleave_rollout_multi_step_with_routed_experts(): completion_logprobs=[-0.3, -0.4], overlong_prompt=False, is_truncated=False, - routed_experts=step2_experts, + routed_experts=_routed_experts_payload(step2_experts), ), reward=None, advantage=None, @@ -1999,7 +2019,7 @@ def test_interleave_rollout_multi_step_with_routed_experts(): # Merged sample: prompt=[1,2], completion=[3,4,5,6,7,8] -> 8 tokens total assert len(sample.prompt_ids) + len(sample.completion_ids) == 8 assert sample.routed_experts is not None - assert len(sample.routed_experts) == 8 + assert _sample_routed_experts(sample).shape == (8, 1, 2) def test_interleave_rollout_none_routed_experts_stays_none(): diff --git a/uv.lock b/uv.lock index 07251a5534..ea168dd95c 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "2026-05-06T12:37:29.025733799Z" +exclude-newer = "2026-05-07T14:12:41.778927491Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -412,12 +412,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8a/1f/f041989e93b001bc4e44bb1669ccdcf54d3f00e628229a85b08d330615c5/charset_normalizer-3.4.3-py3-none-any.whl", hash = "sha256:ce571ab16d890d23b5c278547ba694193a45011ff86a9162a71307ed9f86759a", size = 53175, upload-time = "2025-08-09T07:57:26.864Z" }, ] -[[package]] -name = "chess" -version = "1.11.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/93/09/7d04d7581ae3bb8b598017941781bceb7959dd1b13e3ebf7b6a2cd843bc9/chess-1.11.2.tar.gz", hash = "sha256:a8b43e5678fdb3000695bdaa573117ad683761e5ca38e591c4826eba6d25bb39", size = 6131385, upload-time = "2025-02-25T19:10:27.328Z" } - [[package]] name = "chromadb" version = "1.5.4" @@ -1514,15 +1508,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" }, ] -[[package]] -name = "joblib" -version = "1.5.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/41/f2/d34e8b3a08a9cc79a50b2208a93dce981fe615b64d5a4d4abee421d898df/joblib-1.5.3.tar.gz", hash = "sha256:8561a3269e6801106863fd0d6d84bb737be9e7631e33aaed3fb9ce5953688da3", size = 331603, upload-time = "2025-12-15T08:41:46.427Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl", hash = "sha256:5fc3c5039fc5ca8c0276333a188bbd59d6b7ab37fe6632daa76bc7f9ec18e713", size = 309071, upload-time = "2025-12-15T08:41:44.973Z" }, -] - [[package]] name = "jsonschema" version = "4.25.1" @@ -2133,21 +2118,6 @@ requires-dist = [ { name = "torch" }, ] -[[package]] -name = "nltk" -version = "3.9.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "click", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "joblib", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "regex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "tqdm", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/74/a1/b3b4adf15585a5bc4c357adde150c01ebeeb642173ded4d871e89468767c/nltk-3.9.4.tar.gz", hash = "sha256:ed03bc098a40481310320808b2db712d95d13ca65b27372f8a403949c8b523d0", size = 2946864, upload-time = "2026-03-24T06:13:40.641Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl", hash = "sha256:f2fa301c3a12718ce4a0e9305c5675299da5ad9e26068218b69d692fda84828f", size = 1552087, upload-time = "2026-03-24T06:13:38.47Z" }, -] - [[package]] name = "nodeenv" version = "1.9.1" @@ -2812,10 +2782,9 @@ dependencies = [ { name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "uvloop", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "verifiers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "vllm", version = "0.20.1rc1.dev97+g6acff036e.d20260513.precompiled", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "vllm", version = "0.20.1rc1.dev99+g77adbf599.precompiled", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev99+g77adbf599.precompiled-cp312-cp312-linux_x86_64.whl" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "vllm", version = "0.20.2+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, { name = "wandb", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "wordle", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] [package.optional-dependencies] @@ -2923,7 +2892,7 @@ requires-dist = [ { name = "pyarrow", specifier = ">=21.0.0" }, { name = "pyzmq", specifier = ">=27.1.0" }, { name = "quack-kernels", marker = "extra == 'quack'", specifier = ">=0.3.3" }, - { name = "renderers", specifier = ">=0.1.8.dev0" }, + { name = "renderers", specifier = "==0.1.6" }, { name = "reverse-text", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, { name = "rich", specifier = ">=14.0.0" }, { name = "ring-flash-attn", specifier = ">=0.1.8" }, @@ -2938,14 +2907,13 @@ requires-dist = [ { name = "torchvision", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, { name = "uvloop", specifier = ">=0.21.0" }, - { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=bb54a3e" }, + { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=11dbe34" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, - { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl" }, - { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl" }, + { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev99+g77adbf599.precompiled-cp312-cp312-linux_x86_64.whl" }, + { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" }, { name = "wandb", specifier = ">=0.26.1" }, { name = "wiki-search", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, - { name = "wordle", git = "https://github.com/PrimeIntellect-ai/verifiers?subdirectory=environments%2Fwordle&rev=bb54a3e" }, ] provides-extras = ["flash-attn", "flash-attn-3", "flash-attn-cute", "envs", "disagg", "gpt-oss", "quack", "all"] @@ -3410,7 +3378,7 @@ wheels = [ [[package]] name = "renderers" -version = "0.1.8.dev0" +version = "0.1.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jinja2", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -3420,9 +3388,9 @@ dependencies = [ { name = "tiktoken", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/50/de/a445036157af3367c6a962c13333427c83c08926934c541886eb87f9dcdf/renderers-0.1.8.dev0.tar.gz", hash = "sha256:71eef7bfa3d3f5849ba070d38cd89a1f6387ca7710824f2e50d8c05c9b1048b9", size = 210667, upload-time = "2026-05-12T17:48:45.352Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7c/a7/26162494dab2d7740ff02191cb87c30b68450fb154363c7f0a434e7f3ea9/renderers-0.1.6.tar.gz", hash = "sha256:b74bc3dc870bea3c37ff5b47826ace9b8dd608a4c1f56554c39be1b20b2c63dc", size = 163768, upload-time = "2026-05-07T14:12:36.634Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e7/33/936a38c7f20fbe096b751842ffc6ef254c9eb2223153aa860a122ce9a834/renderers-0.1.8.dev0-py3-none-any.whl", hash = "sha256:09bb35233f67599519c0ff6edfad469f0836a55a6b78e039cd8e7b5e527bdcb3", size = 98617, upload-time = "2026-05-12T17:48:44.222Z" }, + { url = "https://files.pythonhosted.org/packages/5e/ad/2cf218b9fafe2333fb3e80e123e3e2022d4923d9a61fa73ee6d79f39b563/renderers-0.1.6-py3-none-any.whl", hash = "sha256:90c626713239ec108716b7c9d194ba81ffcebe94dc003324f14fbd70e6793e89", size = 83348, upload-time = "2026-05-07T14:12:35.218Z" }, ] [[package]] @@ -3794,24 +3762,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530", size = 6590363, upload-time = "2023-10-23T21:23:35.583Z" }, ] -[[package]] -name = "textarena" -version = "0.7.4" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "chess", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "nltk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "openai", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "python-dotenv", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "rich", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "websockets", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ba/04/4a3ca42093d0be2a9c377ae3335a6c6baac1d278ae932562ec69f339d172/textarena-0.7.4.tar.gz", hash = "sha256:28bb9170d7718f2ae05e4515bea82262422731e563fc7318a9e7983de0cadd4f", size = 954969, upload-time = "2025-10-16T14:41:55.981Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/26/b4/9a9ba65154aff853c75b3d7324319d168ad9c69c6097f4aa3c16da7d9ef3/textarena-0.7.4-py3-none-any.whl", hash = "sha256:684784e78278e518066f67557ee93b47c238d16cbbd15d3abdaa3147562d3024", size = 1073570, upload-time = "2025-10-16T14:41:53.965Z" }, -] - [[package]] name = "textual" version = "8.2.5" @@ -4247,8 +4197,8 @@ wheels = [ [[package]] name = "verifiers" -version = "0.1.15.dev2" -source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=bb54a3e#bb54a3efc5ee1d48b2328fc2d07c690fb50cd7e0" } +version = "0.1.15.dev4" +source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=11dbe34#11dbe340f017d604f880b8467784cb4353ec1233" } dependencies = [ { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4291,8 +4241,8 @@ wheels = [ [[package]] name = "vllm" -version = "0.20.1rc1.dev97+g6acff036e.d20260513.precompiled" -source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl" } +version = "0.20.1rc1.dev99+g77adbf599.precompiled" +source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev99+g77adbf599.precompiled-cp312-cp312-linux_x86_64.whl" } resolution-markers = [ "platform_machine == 'x86_64' and sys_platform == 'linux'", ] @@ -4367,7 +4317,7 @@ dependencies = [ { name = "xgrammar", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ - { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev97+g6acff036e.d20260513.precompiled-cp312-cp312-linux_x86_64.whl", hash = "sha256:8db1e80a5f4dd97237d7c5702b33f37a65910db9976b42db4f58937ddd0ffd48" }, + { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev99+g77adbf599.precompiled-cp312-cp312-linux_x86_64.whl", hash = "sha256:f0dbd42c86463f2952b1d5ff637c1e3cb1b8338686680a1bc4517d2e83d2fdd3" }, ] [package.metadata] @@ -4641,8 +4591,8 @@ provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttens [[package]] name = "vllm-router" -version = "0.1.23" -source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl" } +version = "0.1.22" +source = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" } dependencies = [ { name = "aiohttp", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "fastapi", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -4652,7 +4602,7 @@ dependencies = [ { name = "uvicorn", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ - { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm_router-0.1.23-cp38-abi3-linux_x86_64.whl", hash = "sha256:306525b002ec3d8652fcbfaf37cfa46f1fde48180e01ffa0efa0e55d952bbfc2" }, + { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:6361a0387241e56932f3ba2e51af27f58d11a462e3187e58286b2f96056e4d15" }, ] [package.metadata] @@ -4783,16 +4733,6 @@ wheels = [ { url = "https://hub.primeintellect.ai/primeintellect/wiki-search/@10d58ffe/wiki_search-0.1.23-py3-none-any.whl", hash = "sha256:ffeff890f2d14d7b2910baf57c27f6939da0f669ae0c4545916762f3f4edd75b" }, ] -[[package]] -name = "wordle" -version = "0.1.7" -source = { git = "https://github.com/PrimeIntellect-ai/verifiers?subdirectory=environments%2Fwordle&rev=bb54a3e#bb54a3efc5ee1d48b2328fc2d07c690fb50cd7e0" } -dependencies = [ - { name = "nltk", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "textarena", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "verifiers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] - [[package]] name = "xgrammar" version = "0.1.33" From 2c019e1a85840c6e5f9fce3da543fc7ea4aa4c1e Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Thu, 14 May 2026 20:39:44 +0530 Subject: [PATCH 07/32] fix: keep routed experts transport first class --- .../vllm/serving_chat_with_tokens.py | 8 +- src/prime_rl/inference/vllm/serving_tokens.py | 23 +++-- src/prime_rl/orchestrator/trajectories.py | 21 +++-- src/prime_rl/trainer/batch.py | 84 +++++++++---------- src/prime_rl/trainer/rl/data.py | 11 ++- src/prime_rl/transport/__init__.py | 3 +- src/prime_rl/transport/types.py | 16 ++-- tests/unit/inference/test_serving_tokens.py | 16 +++- tests/unit/orchestrator/test_batch.py | 34 ++++---- tests/unit/orchestrator/test_trajectories.py | 6 +- 10 files changed, 120 insertions(+), 102 deletions(-) diff --git a/src/prime_rl/inference/vllm/serving_chat_with_tokens.py b/src/prime_rl/inference/vllm/serving_chat_with_tokens.py index 4895d333f6..c78a76bde8 100644 --- a/src/prime_rl/inference/vllm/serving_chat_with_tokens.py +++ b/src/prime_rl/inference/vllm/serving_chat_with_tokens.py @@ -25,7 +25,7 @@ class ChatCompletionRequestWithTokens(ChatCompletionRequest): class OpenAIServingChatWithTokens(OpenAIServingChat): - """OpenAI-compatible chat API that allows token-in requests.""" + """OpenAI-compatible generate API that allows token-in and routed experts capture.""" async def chat_completion_full_generator( self, @@ -38,6 +38,12 @@ async def chat_completion_full_generator( request_metadata: RequestResponseMetadata, reasoning_parser: ReasoningParser | None = None, ) -> ErrorResponse | ChatCompletionResponse: + # We need to override the full_generator to be able to capture the routed experts + # By default, VLLM does not save the routed experts into ChatCompletionResponse.choices, so we need to capture them manually + # How this works: + # 1. We create a custom generator that encapsulates the original result_generator in self._generator + # 2. We override it's __aiter__ method to also capture the routed experts as an extra field in ChatCompletionResponse.choices + # 3. We override the full_generator method to use the custom generator instead of the original one if expert routing is enabled capture = None if self.model_config.enable_return_routed_experts: capture = RoutedExpertsCapture(result_generator) diff --git a/src/prime_rl/inference/vllm/serving_tokens.py b/src/prime_rl/inference/vllm/serving_tokens.py index 19a43ca33a..9d8a138d1b 100644 --- a/src/prime_rl/inference/vllm/serving_tokens.py +++ b/src/prime_rl/inference/vllm/serving_tokens.py @@ -54,11 +54,10 @@ class PrimeRlGenerateResponseChoice(GenerateResponseChoice): class PrimeRlGenerateResponse(GenerateResponse): choices: list[PrimeRlGenerateResponseChoice] - prompt_token_ids: list[int] class _GenerateRoutedExpertsCapture(RoutedExpertsCapture): - def post_process(self, response: GenerateResponse, prompt_token_ids: list[int]) -> PrimeRlGenerateResponse: + def post_process(self, response: GenerateResponse) -> PrimeRlGenerateResponse: choices = [ PrimeRlGenerateResponseChoice( **choice.model_dump(), @@ -69,7 +68,6 @@ def post_process(self, response: GenerateResponse, prompt_token_ids: list[int]) return PrimeRlGenerateResponse( request_id=response.request_id, choices=choices, - prompt_token_ids=prompt_token_ids, prompt_logprobs=response.prompt_logprobs, kv_transfer_params=response.kv_transfer_params, ) @@ -126,11 +124,13 @@ async def serve_tokens( self, request: GenerateRequest, raw_request: Request | None = None, - ) -> GenerateResponse | ErrorResponse | AsyncGenerator[str, None]: + ) -> PrimeRlGenerateResponse | ErrorResponse | AsyncGenerator[str, None]: # Mirrors upstream ``ServingTokens.serve_tokens`` (vllm 0.20). Diffs: # (a) inject ``data_parallel_rank`` from the inbound header into # ``engine_client.generate``; (b) default ``sampling_params.max_tokens`` - # to ``max_model_len - prompt_len`` when the caller didn't set it. + # to ``max_model_len - prompt_len`` when the caller didn't set it; and + # (c) dispatch to our overridden response builder so ``routed_experts`` + # makes it into the JSON. error_check_ret = await self._check_model(request) if error_check_ret is not None: return error_check_ret @@ -223,6 +223,9 @@ async def serve_tokens( ) if request.stream: + # Streaming path: defer to upstream — prime-RL's renderer client + # only consumes the full response, so adding routed_experts to the + # streaming choice schema is unnecessary churn. return self.serve_tokens_stream_generator( request, result_generator, @@ -243,7 +246,13 @@ async def serve_tokens_full_generator( # type: ignore[override] model_name: str, request_metadata: RequestResponseMetadata, ) -> ErrorResponse | GenerateResponse: - capture = None + # Mirror serving_chat_with_tokens: wrap the result generator to capture + # routed_experts as it streams, defer the rest to upstream, then post- + # process the response into our PrimeRlGenerateResponse subclass so the + # encoded experts surface in the JSON. Skipping the wrapper when the + # engine isn't producing routed experts keeps us a no-op subclass on + # the common path. + capture: _GenerateRoutedExpertsCapture | None = None if self.model_config.enable_return_routed_experts: capture = _GenerateRoutedExpertsCapture(result_generator) result_generator = capture @@ -253,6 +262,6 @@ async def serve_tokens_full_generator( # type: ignore[override] ) if capture is not None and isinstance(response, GenerateResponse): - response = capture.post_process(response, request.token_ids) + response = capture.post_process(response) return response diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 26c86c5b55..18238c2c9d 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -12,7 +12,7 @@ from PIL import Image from transformers.tokenization_utils import PreTrainedTokenizer -from prime_rl.transport import TrainingSample +from prime_rl.transport import RoutedExperts, TrainingSample from prime_rl.utils.chat_template import ( common_prefix_len, deserialize_tool_calls, @@ -56,16 +56,15 @@ def _align_routed_experts( return np.concatenate((routed_experts, padding), axis=0) -def _set_sample_routed_experts(sample: TrainingSample, routed_experts: np.ndarray | None) -> None: +def _pack_routed_experts(routed_experts: np.ndarray | None) -> RoutedExperts | None: if routed_experts is None: - sample.routed_experts = None - sample.routed_experts_shape = None - sample.routed_experts_dtype = None - return + return None routed_experts = np.ascontiguousarray(routed_experts) - sample.routed_experts = routed_experts.tobytes() - sample.routed_experts_shape = list(routed_experts.shape) - sample.routed_experts_dtype = str(routed_experts.dtype) + return RoutedExperts( + data=routed_experts.tobytes(), + shape=list(routed_experts.shape), + dtype=str(routed_experts.dtype), + ) def _common_prefix_len(a: list[int], b: list[int]) -> int: @@ -359,9 +358,9 @@ def make_sample(tokens: dict[str, Any]) -> tuple[TrainingSample, np.ndarray | No completion_temperatures=[temperature] * len(completion_ids), teacher_logprobs=None, advantage=None, + routed_experts=_pack_routed_experts(routed_experts), mm_token_type_ids=None, ) - _set_sample_routed_experts(sample, routed_experts) return sample, routed_experts def extend_sample( @@ -401,7 +400,7 @@ def extend_sample( sample_routed_experts = np.concatenate((sample_routed_experts, step_routed[prefix_len:]), axis=0) expected_len = len(sample.prompt_ids) + len(sample.completion_ids) sample_routed_experts = _align_routed_experts(sample_routed_experts, expected_len) - _set_sample_routed_experts(sample, sample_routed_experts) + sample.routed_experts = _pack_routed_experts(sample_routed_experts) return sample_routed_experts # Track [prefix_tokens, sample, last_step_idx, routed_experts] per active sample diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index b19f702b6e..ca248a43d4 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -1,6 +1,6 @@ import copy -from prime_rl.transport.types import MicroBatch, TrainingSample +from prime_rl.transport.types import MicroBatch, RoutedExperts, TrainingSample ROUTED_EXPERTS_DTYPE_ITEMSIZE = { "uint8": 1, @@ -9,35 +9,44 @@ } -def _routed_experts_row_size(shape: list[int], dtype: str) -> int: - return shape[1] * shape[2] * ROUTED_EXPERTS_DTYPE_ITEMSIZE[dtype] +def _copy_routed_experts(routed_experts: RoutedExperts) -> RoutedExperts: + return RoutedExperts( + data=routed_experts.data, + shape=list(routed_experts.shape), + dtype=routed_experts.dtype, + ) + + +def _routed_experts_row_size(routed_experts: RoutedExperts) -> int: + return routed_experts.shape[1] * routed_experts.shape[2] * ROUTED_EXPERTS_DTYPE_ITEMSIZE[routed_experts.dtype] -def _slice_routed_experts(data: bytes, shape: list[int], dtype: str, seq_len: int) -> tuple[bytes, list[int]]: - row_size = _routed_experts_row_size(shape, dtype) - return data[: seq_len * row_size], [seq_len, shape[1], shape[2]] +def _slice_routed_experts(routed_experts: RoutedExperts, seq_len: int) -> RoutedExperts: + row_size = _routed_experts_row_size(routed_experts) + return RoutedExperts( + data=routed_experts.data[: seq_len * row_size], + shape=[seq_len, routed_experts.shape[1], routed_experts.shape[2]], + dtype=routed_experts.dtype, + ) def _append_routed_experts(dst: MicroBatch, src: MicroBatch) -> None: - assert dst.routed_experts is not None - assert dst.routed_experts_shape is not None - assert dst.routed_experts_dtype is not None - assert src.routed_experts is not None - assert src.routed_experts_shape is not None - assert src.routed_experts_dtype is not None - assert dst.routed_experts_dtype == src.routed_experts_dtype - assert dst.routed_experts_shape[1:] == src.routed_experts_shape[1:] - dst.routed_experts += src.routed_experts - dst.routed_experts_shape[0] += src.routed_experts_shape[0] + dst_routed = dst.routed_experts + src_routed = src.routed_experts + assert dst_routed is not None + assert src_routed is not None + assert dst_routed.dtype == src_routed.dtype + assert dst_routed.shape[1:] == src_routed.shape[1:] + dst_routed.data += src_routed.data + dst_routed.shape[0] += src_routed.shape[0] def _pad_routed_experts(micro_batch: MicroBatch, padding_size: int) -> None: - assert micro_batch.routed_experts is not None - assert micro_batch.routed_experts_shape is not None - assert micro_batch.routed_experts_dtype is not None - row_size = _routed_experts_row_size(micro_batch.routed_experts_shape, micro_batch.routed_experts_dtype) - micro_batch.routed_experts += b"\0" * (padding_size * row_size) - micro_batch.routed_experts_shape[0] += padding_size + routed_experts = micro_batch.routed_experts + assert routed_experts is not None + row_size = _routed_experts_row_size(routed_experts) + routed_experts.data += b"\0" * (padding_size * row_size) + routed_experts.shape[0] += padding_size def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch: @@ -60,9 +69,9 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch # Teacher logprobs already cover the full sequence (prompt + completion), # computed via prefill in the orchestrator when a teacher model is configured teacher_logprobs = training_example.teacher_logprobs - routed_experts = training_example.routed_experts - routed_experts_shape = training_example.routed_experts_shape - routed_experts_dtype = training_example.routed_experts_dtype + routed_experts = ( + _copy_routed_experts(training_example.routed_experts) if training_example.routed_experts is not None else None + ) if len(input_ids) > seq_len: input_ids = input_ids[:seq_len] @@ -74,11 +83,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch if teacher_logprobs is not None: teacher_logprobs = teacher_logprobs[:seq_len] if routed_experts is not None: - assert routed_experts_shape is not None - assert routed_experts_dtype is not None - routed_experts, routed_experts_shape = _slice_routed_experts( - routed_experts, routed_experts_shape, routed_experts_dtype, seq_len - ) + routed_experts = _slice_routed_experts(routed_experts, seq_len) if mm_token_type_ids is not None: mm_token_type_ids = mm_token_type_ids[:seq_len] @@ -96,14 +101,10 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch assert len(teacher_logprobs) == len(input_ids), f"teacher_logprobs: {len(teacher_logprobs)}" if routed_experts is not None: - assert routed_experts_shape is not None - assert routed_experts_dtype is not None - assert routed_experts_shape[0] == len(input_ids), ( - f"routed_experts: {routed_experts_shape}, input_ids: {len(input_ids)}" - ) - assert len(routed_experts) == len(input_ids) * _routed_experts_row_size( - routed_experts_shape, routed_experts_dtype + assert routed_experts.shape[0] == len(input_ids), ( + f"routed_experts: {routed_experts.shape}, input_ids: {len(input_ids)}" ) + assert len(routed_experts.data) == len(input_ids) * _routed_experts_row_size(routed_experts) if mm_token_type_ids is not None: assert len(mm_token_type_ids) == len(input_ids), ( @@ -119,8 +120,6 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch teacher_logprobs=teacher_logprobs, temperatures=temperatures, routed_experts=routed_experts, - routed_experts_shape=routed_experts_shape, - routed_experts_dtype=routed_experts_dtype, mm_token_type_ids=mm_token_type_ids, # Multimodal fields (Qwen3-VL) - passed through without modification pixel_values=training_example.pixel_values, @@ -181,12 +180,7 @@ def packed_samples_into_micro_bs( bin_content.teacher_logprobs.extend(sample.teacher_logprobs) assert (bin_content.routed_experts is None) == (sample.routed_experts is None) if sample.routed_experts is not None: - if bin_content.routed_experts is None: - bin_content.routed_experts = sample.routed_experts - bin_content.routed_experts_shape = list(sample.routed_experts_shape) - bin_content.routed_experts_dtype = sample.routed_experts_dtype - else: - _append_routed_experts(bin_content, sample) + _append_routed_experts(bin_content, sample) if sample.mm_token_type_ids is not None: if bin_content.mm_token_type_ids is None: bin_content.mm_token_type_ids = [] diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index 43ef1bf0ce..cabd126f59 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -202,15 +202,14 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: micro_batch.lora_num_tokens = [0] * self.multi_run_manager.max_runs micro_batch.lora_num_tokens[0] = len(micro_batch.input_ids) routed_experts = None - if micro_batch.routed_experts is not None: - assert micro_batch.routed_experts_shape is not None - assert micro_batch.routed_experts_dtype is not None + packed_routed_experts = micro_batch.routed_experts + if packed_routed_experts is not None: routed_experts = ( torch.frombuffer( - micro_batch.routed_experts, - dtype=ROUTED_EXPERTS_TORCH_DTYPES[micro_batch.routed_experts_dtype], + packed_routed_experts.data, + dtype=ROUTED_EXPERTS_TORCH_DTYPES[packed_routed_experts.dtype], ) - .reshape(micro_batch.routed_experts_shape) + .reshape(packed_routed_experts.shape) .to(torch.int32) .unsqueeze(0) ) diff --git a/src/prime_rl/transport/__init__.py b/src/prime_rl/transport/__init__.py index e4c3153dc7..5108077920 100644 --- a/src/prime_rl/transport/__init__.py +++ b/src/prime_rl/transport/__init__.py @@ -8,7 +8,7 @@ FileSystemTrainingBatchReceiver, FileSystemTrainingBatchSender, ) -from prime_rl.transport.types import MicroBatch, TrainingBatch, TrainingSample +from prime_rl.transport.types import MicroBatch, RoutedExperts, TrainingBatch, TrainingSample from prime_rl.transport.zmq import ( ZMQMicroBatchReceiver, ZMQMicroBatchSender, @@ -64,6 +64,7 @@ def setup_micro_batch_receiver( "FileSystemMicroBatchReceiver", "MicroBatchReceiver", "MicroBatchSender", + "RoutedExperts", "TrainingSample", "TrainingBatch", "MicroBatch", diff --git a/src/prime_rl/transport/types.py b/src/prime_rl/transport/types.py index ff60e59978..cc943e9b76 100644 --- a/src/prime_rl/transport/types.py +++ b/src/prime_rl/transport/types.py @@ -1,6 +1,14 @@ import msgspec +# Routed experts are large per-token arrays. tolist() is too expensive, so we +# send raw bytes through msgpack and carry the shape/dtype needed to rebuild. +class RoutedExperts(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): + data: bytes + shape: list[int] # [seq_len, layers, topk] + dtype: str + + # Orchestrator -> Packer class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): """A single training example.""" @@ -21,9 +29,7 @@ class TrainingSample(msgspec.Struct, array_like=True, gc=False, omit_defaults=Tr # image_grid_thw: grid dimensions [num_images, 3] where each entry is [temporal, height, width] image_grid_thw: list[list[int]] | None = None - routed_experts: bytes | None = None - routed_experts_shape: list[int] | None = None # [seq_len, layers, topk] - routed_experts_dtype: str | None = None + routed_experts: RoutedExperts | None = None # mm_token_type_ids: token type ids per token [batch seq], int64 (0=text, 1=image, 2=video) mm_token_type_ids: list[int] | None = None @@ -51,9 +57,7 @@ class MicroBatch(msgspec.Struct, array_like=True, gc=False, omit_defaults=True): temperatures: list[float] # Per-token temperatures used during generation teacher_logprobs: list[float] | None = None lora_num_tokens: list[int] | None = None - routed_experts: bytes | None = None - routed_experts_shape: list[int] | None = None # [seq_len, layers, topk] - routed_experts_dtype: str | None = None + routed_experts: RoutedExperts | None = None # Multimodal fields (Qwen3-VL) — pixel_values stored as raw float32 bytes for efficient serialization pixel_values: bytes | None = None diff --git a/tests/unit/inference/test_serving_tokens.py b/tests/unit/inference/test_serving_tokens.py index fbee225543..d88d8dff70 100644 --- a/tests/unit/inference/test_serving_tokens.py +++ b/tests/unit/inference/test_serving_tokens.py @@ -1,4 +1,12 @@ -"""Sanity tests for the prime-RL ``ServingTokens`` subclass.""" +"""Sanity tests for the prime-RL ``ServingTokens`` subclass. + +The full happy-path is owned upstream by vLLM 0.20's +``vllm/entrypoints/serve/disagg`` test suite. We only cover the prime-RL +deltas here: + * ``serialize_routed_experts`` round-trips a numpy array as expected. + * The subclass attaches its overrides without monkey-patching the parent. + * ``_client_set_max_tokens`` distinguishes raw-body shapes correctly. +""" from __future__ import annotations @@ -82,12 +90,12 @@ def test_client_set_max_tokens_detects_unset(): def test_client_set_max_tokens_assumes_set_when_body_unreadable(): - # No raw_request: can't tell, don't override. + # No raw_request → can't tell, don't override. assert asyncio.run(_client_set_max_tokens(None)) is True - # body read raises: can't tell, don't override. + # body read raises → can't tell, don't override. err = ValueError("bad json") assert asyncio.run(_client_set_max_tokens(_FakeRawRequest(err))) is True - # non-dict body: can't tell, don't override. + # non-dict body → can't tell, don't override. assert asyncio.run(_client_set_max_tokens(_FakeRawRequest([1, 2, 3]))) is True diff --git a/tests/unit/orchestrator/test_batch.py b/tests/unit/orchestrator/test_batch.py index 202fee7e92..fc95de4e2f 100644 --- a/tests/unit/orchestrator/test_batch.py +++ b/tests/unit/orchestrator/test_batch.py @@ -2,12 +2,16 @@ import pytest from prime_rl.trainer.batch import prepare_batch, prepare_sample -from prime_rl.transport.types import TrainingSample +from prime_rl.transport.types import RoutedExperts, TrainingSample def _routed_experts(data, dtype=np.uint8): routed_experts = np.asarray(data, dtype=dtype) - return routed_experts.tobytes(), list(routed_experts.shape), str(routed_experts.dtype) + return RoutedExperts( + data=routed_experts.tobytes(), + shape=list(routed_experts.shape), + dtype=str(routed_experts.dtype), + ) @pytest.fixture @@ -115,7 +119,7 @@ def test_prepare_sample_with_routed_experts(): """Routed experts are passed through prepare_sample and match input_ids length.""" # 2 prompt + 2 completion = 4 tokens, 2 layers, topk=2 routed_experts = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 2], [1, 3]], [[1, 0], [3, 2]]] - routed_bytes, routed_shape, routed_dtype = _routed_experts(routed_experts) + routed = _routed_experts(routed_experts) sample = TrainingSample( prompt_ids=[1, 2], prompt_mask=[False, False], @@ -124,23 +128,21 @@ def test_prepare_sample_with_routed_experts(): completion_logprobs=[-0.1, -0.2], completion_temperatures=[1.0, 1.0], advantage=1.0, - routed_experts=routed_bytes, - routed_experts_shape=routed_shape, - routed_experts_dtype=routed_dtype, + routed_experts=routed, ) micro_batch = prepare_sample(sample, seq_len=8) assert micro_batch.routed_experts is not None - assert micro_batch.routed_experts == routed_bytes - assert micro_batch.routed_experts_shape == routed_shape - assert micro_batch.routed_experts_dtype == routed_dtype + assert micro_batch.routed_experts.data == routed.data + assert micro_batch.routed_experts.shape == routed.shape + assert micro_batch.routed_experts.dtype == routed.dtype def test_prepare_sample_truncates_routed_experts(): """Routed experts are truncated to seq_len when input exceeds it.""" routed_experts = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]]] - routed_bytes, routed_shape, routed_dtype = _routed_experts(routed_experts) - expected_bytes, expected_shape, _ = _routed_experts(routed_experts[:3]) + routed = _routed_experts(routed_experts) + expected = _routed_experts(routed_experts[:3]) sample = TrainingSample( prompt_ids=[1, 2], prompt_mask=[False, False], @@ -149,16 +151,14 @@ def test_prepare_sample_truncates_routed_experts(): completion_logprobs=[-0.1, -0.2], completion_temperatures=[1.0, 1.0], advantage=1.0, - routed_experts=routed_bytes, - routed_experts_shape=routed_shape, - routed_experts_dtype=routed_dtype, + routed_experts=routed, ) micro_batch = prepare_sample(sample, seq_len=3) assert micro_batch.routed_experts is not None - assert micro_batch.routed_experts == expected_bytes - assert micro_batch.routed_experts_shape == expected_shape - assert micro_batch.routed_experts_dtype == routed_dtype + assert micro_batch.routed_experts.data == expected.data + assert micro_batch.routed_experts.shape == expected.shape + assert micro_batch.routed_experts.dtype == expected.dtype def test_prepare_sample_none_routed_experts(): diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index ff77b73e4f..7bcb971556 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -39,10 +39,8 @@ def _routed_experts_payload(data, dtype=np.uint8) -> str: def _sample_routed_experts(sample) -> np.ndarray: assert sample.routed_experts is not None - assert sample.routed_experts_shape is not None - assert sample.routed_experts_dtype is not None - return np.frombuffer(sample.routed_experts, dtype=np.dtype(sample.routed_experts_dtype)).reshape( - sample.routed_experts_shape + return np.frombuffer(sample.routed_experts.data, dtype=np.dtype(sample.routed_experts.dtype)).reshape( + sample.routed_experts.shape ) From 803b4ae2dd9c406a62c0d17e88e134370429576a Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Thu, 14 May 2026 20:46:33 +0530 Subject: [PATCH 08/32] fix: keep routed experts on samples --- src/prime_rl/inference/patches.py | 31 ++++++++++++++ src/prime_rl/orchestrator/trajectories.py | 51 ++++++++++------------- src/prime_rl/transport/__init__.py | 3 +- 3 files changed, 54 insertions(+), 31 deletions(-) diff --git a/src/prime_rl/inference/patches.py b/src/prime_rl/inference/patches.py index c33f3ccf42..974aed5f82 100644 --- a/src/prime_rl/inference/patches.py +++ b/src/prime_rl/inference/patches.py @@ -18,6 +18,37 @@ def transformers_v5_compat(): monkey_patch_deep_gemm_ep_scatter() monkey_patch_deep_gemm_silu_mul_quant_int64() monkey_patch_dp_engine_core_pause_resume_deadlock() + monkey_patch_vllm_layerwise_reload_alias_buffers() + + +def monkey_patch_vllm_layerwise_reload_alias_buffers(): + # vLLM's layerwise reload materializes each buffer as an independent tensor + # and then copies it back into the original kernel storage. When a buffer + # aliases a parameter (e.g. NemotronH Mamba's mixer.conv_weights, a view of + # mixer.conv1d.weight), the buffer copy stamps garbage into the parameter's + # storage *after* the parameter has been correctly reloaded. Skip the copy + # for any buffer that shares storage with a parameter; _place_kernel_tensors + # re-registers the original view, which trivially reflects the parameter. + from vllm.logger import init_logger + from vllm.model_executor.model_loader.reload import layerwise as reload_layerwise + + logger = init_logger(__name__) + + def _copy_and_restore_kernel_tensors(layer: torch.nn.Module, info: reload_layerwise.LayerReloadingInfo): + assert info.kernel_tensors is not None + parameters, buffers = info.kernel_tensors + param_storage_ptrs = {p.untyped_storage().data_ptr() for p in layer.parameters(recurse=True)} + for name, param in parameters.items(): + param.data.copy_(getattr(layer, name)) + for name, buffer in buffers.items(): + if buffer.untyped_storage().data_ptr() in param_storage_ptrs: + continue + buffer.data.copy_(getattr(layer, name)) + + reload_layerwise._place_kernel_tensors(layer, info) + + reload_layerwise._copy_and_restore_kernel_tensors = _copy_and_restore_kernel_tensors + logger.warning("Enabled vLLM layerwise reload alias-buffer patch.") @triton.jit diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 18238c2c9d..029e957f94 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -12,7 +12,8 @@ from PIL import Image from transformers.tokenization_utils import PreTrainedTokenizer -from prime_rl.transport import RoutedExperts, TrainingSample +from prime_rl.transport import TrainingSample +from prime_rl.transport.types import RoutedExperts from prime_rl.utils.chat_template import ( common_prefix_len, deserialize_tool_calls, @@ -67,6 +68,10 @@ def _pack_routed_experts(routed_experts: np.ndarray | None) -> RoutedExperts | N ) +def _unpack_routed_experts(routed_experts: RoutedExperts) -> np.ndarray: + return np.frombuffer(routed_experts.data, dtype=np.dtype(routed_experts.dtype)).reshape(routed_experts.shape).copy() + + def _common_prefix_len(a: list[int], b: list[int]) -> int: return common_prefix_len(a, b) @@ -336,7 +341,7 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any return None prepared_steps.append(prepared) - def make_sample(tokens: dict[str, Any]) -> tuple[TrainingSample, np.ndarray | None]: + def make_sample(tokens: dict[str, Any]) -> TrainingSample: """Create a new TrainingSample from a trajectory step.""" if has_error: completion_mask = [False] * len(tokens["completion_mask"]) @@ -361,14 +366,9 @@ def make_sample(tokens: dict[str, Any]) -> tuple[TrainingSample, np.ndarray | No routed_experts=_pack_routed_experts(routed_experts), mm_token_type_ids=None, ) - return sample, routed_experts - - def extend_sample( - sample: TrainingSample, - sample_routed_experts: np.ndarray | None, - prefix_len: int, - step_idx: int, - ) -> np.ndarray | None: + return sample + + def extend_sample(sample: TrainingSample, prefix_len: int, step_idx: int) -> None: """Extend an existing sample with a new trajectory step (extension property holds).""" tokens = prepared_steps[step_idx] @@ -389,8 +389,9 @@ def extend_sample( sample.completion_logprobs.extend(tokens["completion_logprobs"]) sample.completion_temperatures.extend([temperature] * len(completion_ids)) - if tokens.get("routed_experts") is not None and sample_routed_experts is not None: + if tokens.get("routed_experts") is not None and sample.routed_experts is not None: step_routed = tokens["routed_experts"] + sample_routed_experts = _unpack_routed_experts(sample.routed_experts) # The previous step's last routing entry was zero-padded by _align_routed_experts # (vLLM only captures num_tokens-1 routings per request). This step actually # processed that boundary token as part of its prompt, so replace the zero-fill @@ -401,15 +402,13 @@ def extend_sample( expected_len = len(sample.prompt_ids) + len(sample.completion_ids) sample_routed_experts = _align_routed_experts(sample_routed_experts, expected_len) sample.routed_experts = _pack_routed_experts(sample_routed_experts) - return sample_routed_experts - # Track [prefix_tokens, sample, last_step_idx, routed_experts] per active sample - active_samples: list[tuple[list[int], TrainingSample, int, np.ndarray | None]] = [] + # Track [prefix_tokens, sample, last_step_idx] per active sample + active_samples: list[tuple[list[int], TrainingSample, int]] = [] first_tokens = prepared_steps[0] first_prefix = first_tokens["prompt_ids"] + first_tokens["completion_ids"] - first_sample, first_routed_experts = make_sample(first_tokens) - active_samples.append((first_prefix, first_sample, 0, first_routed_experts)) + active_samples.append((first_prefix, make_sample(first_tokens), 0)) for step_idx, _step in enumerate(trajectory[1:], start=1): tokens = prepared_steps[step_idx] @@ -417,21 +416,16 @@ def extend_sample( # Check if this step extends ANY active prefix matched_idx = None - for idx, (prefix_tokens, _, _, _) in enumerate(active_samples): + for idx, (prefix_tokens, _, _) in enumerate(active_samples): if step_prompt_ids[: len(prefix_tokens)] == prefix_tokens: matched_idx = idx break if matched_idx is not None: # Extension holds - merge into matched sample - prefix_tokens, sample, _, sample_routed_experts = active_samples[matched_idx] - sample_routed_experts = extend_sample(sample, sample_routed_experts, len(prefix_tokens), step_idx=step_idx) - active_samples[matched_idx] = ( - tokens["prompt_ids"] + tokens["completion_ids"], - sample, - step_idx, - sample_routed_experts, - ) + prefix_tokens, sample, _ = active_samples[matched_idx] + extend_sample(sample, len(prefix_tokens), step_idx=step_idx) + active_samples[matched_idx] = (tokens["prompt_ids"] + tokens["completion_ids"], sample, step_idx) else: # No prefix matches - start a new sample logger.debug( @@ -439,8 +433,7 @@ def extend_sample( f"Starting new sample (active_prefixes={len(active_samples)}, step_prompt_len={len(step_prompt_ids)})." ) new_prefix = tokens["prompt_ids"] + tokens["completion_ids"] - sample, routed_experts = make_sample(tokens) - active_samples.append((new_prefix, sample, step_idx, routed_experts)) + active_samples.append((new_prefix, make_sample(tokens), step_idx)) # Attach images once per sample using only the last merged step. Prompt # tokens already contain fully expanded <|image_pad|> placeholders because @@ -449,7 +442,7 @@ def extend_sample( # fallback path so features and tokens stay 1:1. if vlm_cache is not None: key = output["example_id"] if cache_key is None else cache_key - for _, sample, last_step_idx, _ in active_samples: + for _, sample, last_step_idx in active_samples: pv, shape, grids = vlm_cache.get_for_step(key, last_step_idx) sample.pixel_values = pv sample.pixel_values_shape = shape @@ -459,7 +452,7 @@ def extend_sample( mm_token_type_ids_mapping.get(token_id, 0) for token_id in sample.prompt_ids + sample.completion_ids ] - return [sample for _, sample, _, _ in active_samples] + return [sample for _, sample, _ in active_samples] # ============================================================================= diff --git a/src/prime_rl/transport/__init__.py b/src/prime_rl/transport/__init__.py index 5108077920..e4c3153dc7 100644 --- a/src/prime_rl/transport/__init__.py +++ b/src/prime_rl/transport/__init__.py @@ -8,7 +8,7 @@ FileSystemTrainingBatchReceiver, FileSystemTrainingBatchSender, ) -from prime_rl.transport.types import MicroBatch, RoutedExperts, TrainingBatch, TrainingSample +from prime_rl.transport.types import MicroBatch, TrainingBatch, TrainingSample from prime_rl.transport.zmq import ( ZMQMicroBatchReceiver, ZMQMicroBatchSender, @@ -64,7 +64,6 @@ def setup_micro_batch_receiver( "FileSystemMicroBatchReceiver", "MicroBatchReceiver", "MicroBatchSender", - "RoutedExperts", "TrainingSample", "TrainingBatch", "MicroBatch", From 9092eca6edeefa2e825970bb930eb4b31553db96 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Thu, 14 May 2026 21:11:00 +0530 Subject: [PATCH 09/32] fix: use upstream vllm nightly wheel --- pyproject.toml | 3 +- src/prime_rl/inference/patches.py | 17 ++- uv.lock | 203 ++++++++++++++++++++++-------- 3 files changed, 167 insertions(+), 56 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index dc99465e41..b4dad25184 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -130,6 +130,7 @@ override-dependencies = [ [tool.uv.exclude-newer-package] # we want latest vllm, remove next patch vllm = false +tokenspeed-mla = false flash_attn_3 = false # Self-vendored packages on our primeintellect index reverse-text = false @@ -175,7 +176,7 @@ pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branc # TODO: update router wheel when the routed-experts P/D stitching release is ready. vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" } vllm = [ - { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev99+g77adbf599.precompiled-cp312-cp312-linux_x86_64.whl", marker = "platform_machine == 'x86_64'" }, + { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" }, { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl", marker = "platform_machine == 'aarch64'" }, ] reverse-text = { index = "primeintellect" } diff --git a/src/prime_rl/inference/patches.py b/src/prime_rl/inference/patches.py index 974aed5f82..b0b0146edc 100644 --- a/src/prime_rl/inference/patches.py +++ b/src/prime_rl/inference/patches.py @@ -897,9 +897,9 @@ def monkey_patch_dp_engine_core_pause_resume_deadlock(): - on resume, wake every DP rank and force an immediate global unfinished sync instead of waiting for the normal 32-step cadence - This keeps the upstream pause-side fix from - https://github.com/vllm-project/vllm/pull/37024 and extends it with the - resume-side wave-state fix. + This also bypasses vLLM's two-phase DP pause implementation + (https://github.com/vllm-project/vllm/pull/39366), which makes resume + reject states that our weight-update flow can validly hit. """ from vllm.config import ParallelConfig from vllm.v1.core.sched.interface import PauseState @@ -909,7 +909,8 @@ def monkey_patch_dp_engine_core_pause_resume_deadlock(): _base_add_request = EngineCore.add_request _base_handle_client_request = EngineCoreProc._handle_client_request - _base_resume_scheduler = DPEngineCoreProc.resume_scheduler + _base_pause_complete = EngineCoreProc._pause_complete + _base_resume_scheduler = EngineCoreProc.resume_scheduler def _patched_add_request(self, request: Request, request_wave: int = 0): _base_add_request(self, request, request_wave) @@ -930,8 +931,15 @@ def _patched_handle_client_request(self, request_type, request): else: _base_handle_client_request(self, request_type, request) + def _patched_pause_complete(self) -> bool: + self.pending_pause = False + self.ignore_start_dp_wave = False + return _base_pause_complete(self) + def _patched_resume_scheduler(self): was_paused = self.scheduler.pause_state != PauseState.UNPAUSED + self.pending_pause = False + self.ignore_start_dp_wave = False _base_resume_scheduler(self) if was_paused: self.engines_running = True @@ -948,6 +956,7 @@ def _patched_has_global_unfinished_reqs(self, local_unfinished: bool) -> bool: DPEngineCoreProc.add_request = _patched_add_request DPEngineCoreProc._handle_client_request = _patched_handle_client_request + DPEngineCoreProc._pause_complete = _patched_pause_complete DPEngineCoreProc.resume_scheduler = _patched_resume_scheduler DPEngineCoreProc._has_global_unfinished_reqs = _patched_has_global_unfinished_reqs diff --git a/uv.lock b/uv.lock index ea168dd95c..e7cbfd3448 100644 --- a/uv.lock +++ b/uv.lock @@ -11,38 +11,39 @@ supported-markers = [ ] [options] -exclude-newer = "2026-05-07T14:12:41.778927491Z" +exclude-newer = "2026-05-07T15:38:18.75074342Z" exclude-newer-span = "P7D" [options.exclude-newer-package] -vllm = false verifiers = false -vllm-router = false dion = false alphabet-sort = false science-env = false -color-codeword = false -nixl-cu12 = false -flash-attn-3 = false -prime-tunnel = false -prime-sandboxes = false -deep-gemm = false -aime2024 = false prime-evals = false deepdive = false -prime = false reverse-text = false code-env = false mini-swe-agent-plus = false deep-ep = false pydantic-config = false renderers = false -math-env = false -logic-env = false wiki-search = false math-python = false math500 = false aime2025 = false +vllm = false +vllm-router = false +color-codeword = false +nixl-cu12 = false +flash-attn-3 = false +prime-tunnel = false +deep-gemm = false +aime2024 = false +tokenspeed-mla = false +math-env = false +logic-env = false +prime-sandboxes = false +prime = false [manifest] members = [ @@ -1890,15 +1891,18 @@ wheels = [ name = "mistral-common" version = "1.11.0" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "platform_machine == 'aarch64' and sys_platform == 'linux'", +] dependencies = [ - { name = "jsonschema", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "pillow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "pydantic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "pydantic-extra-types", extra = ["pycountry"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "tiktoken", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "jsonschema", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "numpy", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "pillow", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "pydantic", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "pydantic-extra-types", extra = ["pycountry"], marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "requests", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "tiktoken", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/61/97/753c85b5c0a19f4331ac99e0300ac8da06d4b29b629c9cb03064b38561bd/mistral_common-1.11.0.tar.gz", hash = "sha256:439b7fa38f9c3f020154af51bdf30eb81def507643017d8ce9f798384ec47ec3", size = 6355512, upload-time = "2026-04-01T13:54:12.36Z" } wheels = [ @@ -1907,7 +1911,34 @@ wheels = [ [package.optional-dependencies] image = [ - { name = "opencv-python-headless", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "opencv-python-headless", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, +] + +[[package]] +name = "mistral-common" +version = "1.11.2" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "platform_machine == 'x86_64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "jsonschema", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "numpy", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "pillow", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "pydantic", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "pydantic-extra-types", extra = ["pycountry"], marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "requests", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "tiktoken", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c2/eb/12167a1bea9714582e5b4f539f9c019323363e314a499c72855ff0e5ad43/mistral_common-1.11.2.tar.gz", hash = "sha256:79f68fc2d1190f28637f40e053f919c8c2697e00b2aa679ddee562a95183f4ad", size = 6357845, upload-time = "2026-05-04T19:47:40.413Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/f0/6a5d604b972e442b9d36c117d01788feddad099e4965699e3516ee6fefc3/mistral_common-1.11.2-py3-none-any.whl", hash = "sha256:ebb42062cd705a0aa2bc69b4cde2b83d446ae58150b7e29322c90cb08fcfca6c", size = 6531968, upload-time = "2026-05-04T19:47:37.718Z" }, +] + +[package.optional-dependencies] +image = [ + { name = "opencv-python-headless", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] [[package]] @@ -1964,20 +1995,44 @@ wheels = [ name = "model-hosting-container-standards" version = "0.1.13" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "platform_machine == 'aarch64' and sys_platform == 'linux'", +] dependencies = [ - { name = "fastapi", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "httpx", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "jmespath", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "pydantic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "setuptools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "starlette", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "supervisor", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "fastapi", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "httpx", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "jmespath", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "pydantic", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "setuptools", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "starlette", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "supervisor", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d7/b7/a6a31b4dfd30d14b1019dc358f09c9d88ca38e555ba7c976e7d3e6b593fe/model_hosting_container_standards-0.1.13.tar.gz", hash = "sha256:27a1333410dde2719286a300a2803e24fdde407baa91894eb845c0f268aa194d", size = 79116, upload-time = "2026-01-09T21:45:20.683Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/8c/37/6dc61971ba31450bbed460b5f40543f0915e352680534e3bcaf57116d8d7/model_hosting_container_standards-0.1.13-py3-none-any.whl", hash = "sha256:be307d4a988cc660df4e6bd8bdedb7917844bac940e332f9fd001cb385d7994c", size = 105738, upload-time = "2026-01-09T21:45:18.959Z" }, ] +[[package]] +name = "model-hosting-container-standards" +version = "0.1.15" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "platform_machine == 'x86_64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "fastapi", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "httpx", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "jmespath", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "pydantic", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "setuptools", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "starlette", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "supervisor", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/03/5a/d669bdeb5ba96db42c6ef010835a25119b05f8c35ee5f1c3f715626625fe/model_hosting_container_standards-0.1.15.tar.gz", hash = "sha256:ae8dd74d3250545c14f0a7068186c7b0f0ab6563d31e7137f556b6b660c8a6a9", size = 93994, upload-time = "2026-05-05T18:22:29.357Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/26/c7aea197f1719f31d0dd686eb4475982fe9efd7668ce259cb52b62c676b6/model_hosting_container_standards-0.1.15-py3-none-any.whl", hash = "sha256:849e08c4732203ee861c8c24966b4e916ea4420fa324b430f7f74a1e1fe8811a", size = 125418, upload-time = "2026-05-05T18:22:27.819Z" }, +] + [[package]] name = "mpmath" version = "1.3.0" @@ -2782,7 +2837,7 @@ dependencies = [ { name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "uvloop", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "verifiers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "vllm", version = "0.20.1rc1.dev99+g77adbf599.precompiled", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev99+g77adbf599.precompiled-cp312-cp312-linux_x86_64.whl" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "vllm", version = "0.20.2rc1.dev354+g24337fb86.cu129", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "vllm", version = "0.20.2+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, { name = "wandb", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] @@ -2910,7 +2965,7 @@ requires-dist = [ { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=11dbe34" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, - { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev99+g77adbf599.precompiled-cp312-cp312-linux_x86_64.whl" }, + { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" }, { name = "wandb", specifier = ">=0.26.1" }, { name = "wiki-search", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, @@ -3858,6 +3913,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/05/a1/d62dfe7376beaaf1394917e0f8e93ee5f67fea8fcf4107501db35996586b/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:38337540fbbddff8e999d59970f3c6f35a82de10053206a7562f1ea02d046fa5", size = 10033429, upload-time = "2026-01-05T10:45:14.333Z" }, ] +[[package]] +name = "tokenspeed-mla" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "apache-tvm-ffi", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cutlass-dsl", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "tokenspeed-triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/01/4bf8b74ead3e8e7c1c809435396254c067a33fde48acc20f602aae622d97/tokenspeed_mla-0.1.2-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:c9466a351fe039792e56cf49f3e79744c1dc28c7af10306a02e62b8e92fa5985", size = 748681, upload-time = "2026-05-13T03:30:56.718Z" }, +] + +[[package]] +name = "tokenspeed-triton" +version = "3.7.10.post20260505" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/c3/4808d86016368fed9495c3a3408cc7f912e7863ff3432937404bd0a551a6/tokenspeed_triton-3.7.10.post20260505-cp312-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:19618c7db01a9bd33885f7acbf8945adb2f5534668aa97629b56d481753cbcad", size = 89127692, upload-time = "2026-05-05T07:49:04.22Z" }, +] + [[package]] name = "toml" version = "0.10.2" @@ -4241,8 +4318,8 @@ wheels = [ [[package]] name = "vllm" -version = "0.20.1rc1.dev99+g77adbf599.precompiled" -source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev99+g77adbf599.precompiled-cp312-cp312-linux_x86_64.whl" } +version = "0.20.2rc1.dev354+g24337fb86.cu129" +source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" } resolution-markers = [ "platform_machine == 'x86_64' and sys_platform == 'linux'", ] @@ -4269,8 +4346,8 @@ dependencies = [ { name = "llguidance", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "lm-format-enforcer", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "mcp", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "mistral-common", extra = ["image"], marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "model-hosting-container-standards", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "mistral-common", version = "1.11.2", source = { registry = "https://pypi.org/simple" }, extra = ["image"], marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "model-hosting-container-standards", version = "0.1.15", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "msgspec", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "ninja", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "numba", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -4307,6 +4384,7 @@ dependencies = [ { name = "tiktoken", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "tilelang", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "tokenizers", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "tokenspeed-mla", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "torchaudio", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "torchvision", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -4314,10 +4392,10 @@ dependencies = [ { name = "transformers", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "watchfiles", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "xgrammar", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "xgrammar", version = "0.2.0", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ - { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.1rc1.dev99+g77adbf599.precompiled-cp312-cp312-linux_x86_64.whl", hash = "sha256:f0dbd42c86463f2952b1d5ff637c1e3cb1b8338686680a1bc4517d2e83d2fdd3" }, + { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a16f4fd2d468f0bb0afd84e3e96f4016654e8525892879909f7a095e33101668" }, ] [package.metadata] @@ -4351,14 +4429,14 @@ requires-dist = [ { name = "matplotlib", marker = "extra == 'bench'" }, { name = "mcp" }, { name = "mistral-common", extras = ["audio"], marker = "extra == 'audio'" }, - { name = "mistral-common", extras = ["image"], specifier = ">=1.11.0" }, - { name = "model-hosting-container-standards", specifier = ">=0.1.13,<1.0.0" }, + { name = "mistral-common", extras = ["image"], specifier = ">=1.11.2" }, + { name = "model-hosting-container-standards", specifier = ">=0.1.14,<1.0.0" }, { name = "msgspec" }, { name = "ninja" }, { name = "numba", specifier = "==0.65.0" }, { name = "numpy" }, { name = "nvidia-cudnn-frontend", specifier = ">=1.13.0,<1.19.0" }, - { name = "nvidia-cutlass-dsl", specifier = ">=4.4.2" }, + { name = "nvidia-cutlass-dsl", specifier = "==4.5.0" }, { name = "openai", specifier = ">=2.0.0" }, { name = "openai-harmony", specifier = ">=0.0.3" }, { name = "opencv-python-headless", specifier = ">=4.13.0" }, @@ -4402,6 +4480,7 @@ requires-dist = [ { name = "tiktoken", specifier = ">=0.6.0" }, { name = "tilelang", specifier = "==0.1.9" }, { name = "tokenizers", specifier = ">=0.21.1" }, + { name = "tokenspeed-mla", specifier = "==0.1.2" }, { name = "torch", specifier = "==2.11.0" }, { name = "torchaudio", specifier = "==2.11.0" }, { name = "torchvision", specifier = "==0.26.0" }, @@ -4409,7 +4488,7 @@ requires-dist = [ { name = "transformers", specifier = ">=4.56.0,!=5.0.*,!=5.1.*,!=5.2.*,!=5.3.*,!=5.4.*,!=5.5.0" }, { name = "typing-extensions", specifier = ">=4.10" }, { name = "watchfiles" }, - { name = "xgrammar", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 's390x' or platform_machine == 'x86_64'", specifier = ">=0.1.32,<1.0.0" }, + { name = "xgrammar", marker = "platform_machine == 'aarch64' or platform_machine == 'arm64' or platform_machine == 'ppc64le' or platform_machine == 's390x' or platform_machine == 'x86_64'", specifier = ">=0.2.0,<1.0.0" }, { name = "zentorch-weekly", marker = "extra == 'zen'", specifier = "==5.2.1.dev20260408" }, ] provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttensor", "runai", "audio", "video", "flashinfer", "helion", "grpc", "otel"] @@ -4444,8 +4523,8 @@ dependencies = [ { name = "llguidance", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, { name = "lm-format-enforcer", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, { name = "mcp", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "mistral-common", extra = ["image"], marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "model-hosting-container-standards", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "mistral-common", version = "1.11.0", source = { registry = "https://pypi.org/simple" }, extra = ["image"], marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "model-hosting-container-standards", version = "0.1.13", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, { name = "msgspec", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, { name = "ninja", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, { name = "numba", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, @@ -4489,7 +4568,7 @@ dependencies = [ { name = "transformers", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, { name = "typing-extensions", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, { name = "watchfiles", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, - { name = "xgrammar", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "xgrammar", version = "0.1.33", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, ] wheels = [ { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl", hash = "sha256:8a58a086c5c4ed2883eee36aaaf6b79c83463d02da3015454acf92afcc8e150e" }, @@ -4737,18 +4816,40 @@ wheels = [ name = "xgrammar" version = "0.1.33" source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "platform_machine == 'aarch64' and sys_platform == 'linux'", +] dependencies = [ - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "pydantic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "torch", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "transformers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "numpy", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "pydantic", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "torch", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "transformers", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "platform_machine == 'aarch64' and sys_platform == 'linux'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/db/43/e5dfddb1d2a4fccf3e3a88f103e88698cdefc3182f4e169a359ffe1c1794/xgrammar-0.1.33.tar.gz", hash = "sha256:8dbe5fc3d76651ab1fac7a68fc2a118b885fa0ec7189927fb6e0dce0081aea99", size = 2398956, upload-time = "2026-03-27T10:16:36.582Z" } wheels = [ { url = "https://files.pythonhosted.org/packages/4e/04/43d4baca876f5ae1b45897ec30a59801a2da37f16da1fcd85f9555e4c125/xgrammar-0.1.33-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c803e60d791854c5d1f271ece7e1f34d73c82dd4a8b2a06b7af5331482a78ac", size = 42133168, upload-time = "2026-03-27T10:15:16.994Z" }, - { url = "https://files.pythonhosted.org/packages/f0/a8/672833a3cff027253793aa999401d8364896ebf396967e475c7a878b895f/xgrammar-0.1.33-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:52b8eaa533282a0efb0835db6998ae72e7b3c7875d7a52e360ffebff9b78c30a", size = 42205803, upload-time = "2026-03-27T10:15:21.599Z" }, +] + +[[package]] +name = "xgrammar" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "platform_machine == 'x86_64' and sys_platform == 'linux'", +] +dependencies = [ + { name = "apache-tvm-ffi", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "numpy", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "pydantic", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "torch", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "transformers", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a0/54/7e593fc41ffcaf5ac7c0379e0aec0cf03e53a742d1a91f64c6c7e79a6ac1/xgrammar-0.2.0.tar.gz", hash = "sha256:c4f0238a89869343171d43d069b8c5da874f3c2c25f408f20cd5987219a6adef", size = 2421093, upload-time = "2026-05-01T18:33:54.474Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/30/99f4e83821db16d58dd41249ba46038ed47bce274c57ad5567030775fc62/xgrammar-0.2.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a36c744d24d93e178c138486aa02b390a80326b64ff11e222e063a028dd65849", size = 44616361, upload-time = "2026-05-01T18:32:42.536Z" }, ] [[package]] From f49caec2311a5ab0de90f90fec33966422a33bdb Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Thu, 14 May 2026 21:23:38 +0530 Subject: [PATCH 10/32] fix: pin latest routed experts verifiers --- pyproject.toml | 2 +- uv.lock | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b4dad25184..3614e0c08b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -167,7 +167,7 @@ prime-rl-configs = { workspace = true } torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } -verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "11dbe34" } +verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "48a203f" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } diff --git a/uv.lock b/uv.lock index e7cbfd3448..18ad2a47d8 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "2026-05-07T15:38:18.75074342Z" +exclude-newer = "2026-05-07T15:52:53.627120476Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -2962,7 +2962,7 @@ requires-dist = [ { name = "torchvision", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, { name = "uvloop", specifier = ">=0.21.0" }, - { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=11dbe34" }, + { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=48a203f" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, @@ -4275,7 +4275,7 @@ wheels = [ [[package]] name = "verifiers" version = "0.1.15.dev4" -source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=11dbe34#11dbe340f017d604f880b8467784cb4353ec1233" } +source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=48a203f#48a203feb82eeadce37ead5fc6d24d1cd7014575" } dependencies = [ { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, From 61a03882b1a7c3deb2655495bdb9219e702a9c4f Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Fri, 15 May 2026 00:27:54 +0530 Subject: [PATCH 11/32] fix: pin routed experts dependencies --- pyproject.toml | 5 ++--- uv.lock | 16 ++++++++-------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3614e0c08b..0afe2049fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -167,14 +167,13 @@ prime-rl-configs = { workspace = true } torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } -verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "48a203f" } +verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "7fdf522" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" } pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" } -# TODO: update router wheel when the routed-experts P/D stitching release is ready. -vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" } +vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.24/vllm_router-0.1.24-cp38-abi3-manylinux_2_28_x86_64.whl" } vllm = [ { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" }, { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl", marker = "platform_machine == 'aarch64'" }, diff --git a/uv.lock b/uv.lock index 18ad2a47d8..1ccac8633d 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "2026-05-07T15:52:53.627120476Z" +exclude-newer = "2026-05-07T18:56:25.00241089Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -2962,11 +2962,11 @@ requires-dist = [ { name = "torchvision", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, { name = "uvloop", specifier = ">=0.21.0" }, - { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=48a203f" }, + { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=7fdf522" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, - { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" }, + { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.24/vllm_router-0.1.24-cp38-abi3-manylinux_2_28_x86_64.whl" }, { name = "wandb", specifier = ">=0.26.1" }, { name = "wiki-search", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, ] @@ -4274,8 +4274,8 @@ wheels = [ [[package]] name = "verifiers" -version = "0.1.15.dev4" -source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=48a203f#48a203feb82eeadce37ead5fc6d24d1cd7014575" } +version = "0.1.15.dev5" +source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=7fdf522#7fdf52219347086c76f00b37659a3626562a58ec" } dependencies = [ { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4670,8 +4670,8 @@ provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttens [[package]] name = "vllm-router" -version = "0.1.22" -source = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl" } +version = "0.1.24" +source = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.24/vllm_router-0.1.24-cp38-abi3-manylinux_2_28_x86_64.whl" } dependencies = [ { name = "aiohttp", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "fastapi", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -4681,7 +4681,7 @@ dependencies = [ { name = "uvicorn", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] wheels = [ - { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.22/vllm_router-0.1.22-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:6361a0387241e56932f3ba2e51af27f58d11a462e3187e58286b2f96056e4d15" }, + { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.24/vllm_router-0.1.24-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:5b45d59871b73357ab1c3d48790de233e190055b6ace810aefb3cb1416fe0d00" }, ] [package.metadata] From 094d233a2e7ddd4633ef167463b6714e0ff2b285 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Fri, 15 May 2026 00:28:59 +0530 Subject: [PATCH 12/32] fix: allow routed experts with nixl --- src/prime_rl/inference/patches.py | 48 +++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/prime_rl/inference/patches.py b/src/prime_rl/inference/patches.py index b0b0146edc..a9d3bc6f6d 100644 --- a/src/prime_rl/inference/patches.py +++ b/src/prime_rl/inference/patches.py @@ -19,6 +19,54 @@ def transformers_v5_compat(): monkey_patch_deep_gemm_silu_mul_quant_int64() monkey_patch_dp_engine_core_pause_resume_deadlock() monkey_patch_vllm_layerwise_reload_alias_buffers() + monkey_patch_return_routed_experts_with_nixl_connector() + + +def monkey_patch_return_routed_experts_with_nixl_connector(): + from vllm import envs + from vllm.config.vllm import VllmConfig + from vllm.logger import init_logger + + logger = init_logger(__name__) + original_post_init = VllmConfig.__post_init__ + + if getattr(original_post_init, "_prime_rl_allows_nixl_routed_experts", False): + return + + def _is_nixl_routed_experts_pd_config(config: VllmConfig) -> bool: + kv_transfer_config = config.kv_transfer_config + return ( + config.model_config is not None + and config.model_config.enable_return_routed_experts + and kv_transfer_config is not None + and kv_transfer_config.kv_connector == "NixlConnector" + and kv_transfer_config.is_kv_transfer_instance + ) + + def _post_init(config: VllmConfig): + if not _is_nixl_routed_experts_pd_config(config): + return original_post_init(config) + + if config.parallel_config.pipeline_parallel_size > 1: + raise ValueError( + "--enable-return-routed-experts is incompatible with " + "pipeline parallelism (PP > 1)." + ) + if envs.VLLM_USE_V2_MODEL_RUNNER: + raise ValueError("VLLM_USE_V2_MODEL_RUNNER does not yet support: routed experts capture") + + # vLLM rejects every KV connector, but our P/D path uses NIXL and + # stitches prefill/decode routed experts in the router. CPU KV offload + # remains rejected by prime-rl config validation. + config.model_config.enable_return_routed_experts = False + try: + return original_post_init(config) + finally: + config.model_config.enable_return_routed_experts = True + + _post_init._prime_rl_allows_nixl_routed_experts = True + VllmConfig.__post_init__ = _post_init + logger.warning("Enabled vLLM routed-experts capture with NIXL connector patch.") def monkey_patch_vllm_layerwise_reload_alias_buffers(): From d6d06b44922e29966f6828baee5e33389c284a81 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Fri, 15 May 2026 00:31:00 +0530 Subject: [PATCH 13/32] style: format nixl patch --- src/prime_rl/inference/patches.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/prime_rl/inference/patches.py b/src/prime_rl/inference/patches.py index a9d3bc6f6d..780086be08 100644 --- a/src/prime_rl/inference/patches.py +++ b/src/prime_rl/inference/patches.py @@ -48,10 +48,7 @@ def _post_init(config: VllmConfig): return original_post_init(config) if config.parallel_config.pipeline_parallel_size > 1: - raise ValueError( - "--enable-return-routed-experts is incompatible with " - "pipeline parallelism (PP > 1)." - ) + raise ValueError("--enable-return-routed-experts is incompatible with pipeline parallelism (PP > 1).") if envs.VLLM_USE_V2_MODEL_RUNNER: raise ValueError("VLLM_USE_V2_MODEL_RUNNER does not yet support: routed experts capture") From 9317cef0066651a7680e20d660d6eda2de9768e8 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Sat, 16 May 2026 02:50:24 +0530 Subject: [PATCH 14/32] Use raw uint8 routed experts payloads --- pyproject.toml | 7 +- src/prime_rl/inference/vllm/routed_experts.py | 30 ++- src/prime_rl/inference/vllm/serving_tokens.py | 5 +- src/prime_rl/orchestrator/trajectories.py | 14 +- tests/unit/inference/test_serving_tokens.py | 33 ++-- tests/unit/orchestrator/test_trajectories.py | 12 +- uv.lock | 183 +++++++++++++++--- 7 files changed, 207 insertions(+), 77 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0afe2049fe..6ad14d78f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "tilelang>=0.1.8", "flash-linear-attention", "nvidia-ml-py>=12.575.51", + "pybase64>=1.4.2", ] [project.scripts] @@ -72,6 +73,7 @@ envs = [ "aime2024", "aime2025", "mini_swe_agent_plus", + "rlm-swe", "deepdive" ] disagg = [ @@ -167,13 +169,13 @@ prime-rl-configs = { workspace = true } torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } -verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "7fdf522" } +verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "461a730" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" } pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" } -vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.24/vllm_router-0.1.24-cp38-abi3-manylinux_2_28_x86_64.whl" } +vllm-router = { git = "https://github.com/PrimeIntellect-ai/router", rev = "510092f" } vllm = [ { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" }, { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl", marker = "platform_machine == 'aarch64'" }, @@ -192,6 +194,7 @@ aime2024 = { index = "primeintellect" } aime2025 = { index = "primeintellect" } deepdive = { index = "primeintellect" } mini_swe_agent_plus = { index = "primeintellect" } +rlm-swe = { path = "/shared/research-prod/research-environments/environments/rlm_swe", editable = true } deep-ep = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_ep-1.2.1+29d31c0-cp312-cp312-linux_x86_64.whl" } deep-gemm = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_gemm-2.5.0+891d57b-cp312-cp312-linux_x86_64.whl" } nixl-cu12 = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/nixl_cu12-0.10.1-cp312-cp312-linux_x86_64.whl" } diff --git a/src/prime_rl/inference/vllm/routed_experts.py b/src/prime_rl/inference/vllm/routed_experts.py index d2a6bf7f78..cad97e8574 100644 --- a/src/prime_rl/inference/vllm/routed_experts.py +++ b/src/prime_rl/inference/vllm/routed_experts.py @@ -1,43 +1,35 @@ from __future__ import annotations -import base64 from collections.abc import AsyncIterator -from io import BytesIO from typing import Any import numpy as np +import pybase64 from vllm.outputs import RequestOutput -def serialize_routed_experts(routed_experts: Any) -> str | None: +def serialize_routed_experts(routed_experts: Any) -> dict[str, Any] | None: if routed_experts is None: return None array = np.asarray(routed_experts) assert array.ndim == 3 assert np.issubdtype(array.dtype, np.integer) + if array.size: + assert array.min() >= 0 + assert array.max() <= np.iinfo(np.uint8).max - if array.size == 0: - compact = array.astype(np.uint8, copy=False) - else: - min_value = array.min() - max_value = array.max() - if min_value >= 0 and max_value <= np.iinfo(np.uint8).max: - compact = array.astype(np.uint8, copy=False) - elif min_value >= np.iinfo(np.int16).min and max_value <= np.iinfo(np.int16).max: - compact = array.astype(np.int16, copy=False) - else: - compact = array.astype(np.int32, copy=False) - - buffer = BytesIO() - np.save(buffer, np.ascontiguousarray(compact), allow_pickle=False) - return base64.b64encode(buffer.getvalue()).decode("ascii") + compact = np.ascontiguousarray(array.astype(np.uint8, copy=False)) + return { + "data": pybase64.b64encode(memoryview(compact)).decode("ascii"), + "shape": list(compact.shape), + } class RoutedExpertsCapture: def __init__(self, generator: AsyncIterator[RequestOutput]): self._generator = generator - self.routed_experts: dict[int, str] = {} + self.routed_experts: dict[int, dict[str, Any]] = {} async def __aiter__(self): async for request_output in self._generator: diff --git a/src/prime_rl/inference/vllm/serving_tokens.py b/src/prime_rl/inference/vllm/serving_tokens.py index 40d92263db..229e78cc93 100644 --- a/src/prime_rl/inference/vllm/serving_tokens.py +++ b/src/prime_rl/inference/vllm/serving_tokens.py @@ -11,7 +11,7 @@ inference servers prime-RL runs need this to target a specific replica. 2. Compact ``routed_experts`` export — when the engine emits routing - decisions, surface them as base64 NumPy payloads without requiring a vLLM + decisions, surface them as base64 raw-byte payloads without requiring a vLLM source fork. 3. Server-side ``max_tokens`` defaulting — ``ServingTokens`` hands the @@ -32,6 +32,7 @@ from collections.abc import AsyncGenerator from functools import cached_property +from typing import Any from fastapi import Request from vllm.entrypoints.openai.engine.protocol import ErrorResponse, RequestResponseMetadata @@ -49,7 +50,7 @@ class PrimeRlGenerateResponseChoice(GenerateResponseChoice): - routed_experts: str | None = None + routed_experts: dict[str, Any] | None = None class PrimeRlGenerateResponse(GenerateResponse): diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 029e957f94..4cd6f5643c 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -7,6 +7,7 @@ from typing import Any import numpy as np +import pybase64 import torch import verifiers as vf from PIL import Image @@ -27,12 +28,16 @@ # primitives are immutable. pixel_values/image_grid_thw are not mutated after creation. -def _decode_routed_experts(payload: str | None) -> np.ndarray | None: +def _decode_routed_experts(payload: dict[str, Any] | None) -> np.ndarray | None: if payload is None: return None - routed_experts = np.load(BytesIO(base64.b64decode(payload)), allow_pickle=False) + shape = [int(dim) for dim in payload["shape"]] + decoded = pybase64.b64decode_as_bytearray(payload["data"]) + expected_size = int(np.prod(shape, dtype=np.int64)) + assert len(decoded) == expected_size, (len(decoded), expected_size, shape) + routed_experts = np.frombuffer(decoded, dtype=np.uint8).reshape(shape) assert routed_experts.ndim == 3 - return np.ascontiguousarray(routed_experts) + return routed_experts def _align_routed_experts( @@ -322,13 +327,14 @@ def interleave_rollout( def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any] | None: tokens = step["tokens"] if tokens is not None: + routed_experts = _decode_routed_experts(tokens.get("routed_experts")) return { "prompt_ids": list(tokens["prompt_ids"]), "prompt_mask": [bool(i) for i in tokens["prompt_mask"]], "completion_ids": list(tokens["completion_ids"]), "completion_mask": [bool(i) for i in tokens["completion_mask"]], "completion_logprobs": list(tokens["completion_logprobs"]), - "routed_experts": _decode_routed_experts(tokens.get("routed_experts")), + "routed_experts": routed_experts, } logger.warning(f"Missing rollout tokens for example {output['example_id']} step {step_idx}.") diff --git a/tests/unit/inference/test_serving_tokens.py b/tests/unit/inference/test_serving_tokens.py index d88d8dff70..bdda7fc485 100644 --- a/tests/unit/inference/test_serving_tokens.py +++ b/tests/unit/inference/test_serving_tokens.py @@ -3,7 +3,7 @@ The full happy-path is owned upstream by vLLM 0.20's ``vllm/entrypoints/serve/disagg`` test suite. We only cover the prime-RL deltas here: - * ``serialize_routed_experts`` round-trips a numpy array as expected. + * ``serialize_routed_experts`` round-trips a compact raw-byte payload. * The subclass attaches its overrides without monkey-patching the parent. * ``_client_set_max_tokens`` distinguishes raw-body shapes correctly. """ @@ -11,10 +11,9 @@ from __future__ import annotations import asyncio -import base64 -from io import BytesIO import numpy as np +import pybase64 from prime_rl.inference.vllm.routed_experts import serialize_routed_experts from prime_rl.inference.vllm.serving_tokens import ( @@ -23,6 +22,13 @@ ) +def _decode_routed_experts(encoded: dict) -> np.ndarray: + return np.frombuffer( + pybase64.b64decode_as_bytearray(encoded["data"]), + dtype=np.uint8, + ).reshape(encoded["shape"]) + + class _FakeRawRequest: def __init__(self, body): self._body = body @@ -42,7 +48,7 @@ def test_subclass_only_overrides_serve_tokens(): ) -def test_serialize_routed_experts_uses_compact_numpy_payload(): +def test_serialize_routed_experts_uses_compact_raw_payload(): routed_experts = np.array( [ [[1, 2], [3, 4]], @@ -54,28 +60,11 @@ def test_serialize_routed_experts_uses_compact_numpy_payload(): encoded = serialize_routed_experts(routed_experts) assert encoded is not None - decoded = np.load(BytesIO(base64.b64decode(encoded)), allow_pickle=False) + decoded = _decode_routed_experts(encoded) assert decoded.dtype == np.uint8 np.testing.assert_array_equal(decoded, routed_experts) -def test_serialize_routed_experts_uses_int16_for_large_expert_ids(): - routed_experts = np.array( - [ - [[256, 257], [300, 301]], - [[302, 303], [304, 305]], - ], - dtype=np.int64, - ) - - encoded = serialize_routed_experts(routed_experts) - assert encoded is not None - - decoded = np.load(BytesIO(base64.b64decode(encoded)), allow_pickle=False) - assert decoded.dtype == np.int16 - np.testing.assert_array_equal(decoded, routed_experts) - - def test_client_set_max_tokens_recognizes_explicit_value(): body = {"token_ids": [1, 2, 3], "sampling_params": {"max_tokens": 256}} assert asyncio.run(_client_set_max_tokens(_FakeRawRequest(body))) is True diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index 7bcb971556..28e2f15156 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -4,6 +4,7 @@ import numpy as np import pytest +import pybase64 import verifiers as vf from PIL import Image @@ -30,11 +31,12 @@ def _decode_pixels(pixel_bytes: bytes, shape: list[int]) -> list[list[float]]: return np.frombuffer(pixel_bytes, dtype=np.float32).reshape(shape).tolist() -def _routed_experts_payload(data, dtype=np.uint8) -> str: - arr = np.asarray(data, dtype=dtype) - buffer = BytesIO() - np.save(buffer, arr, allow_pickle=False) - return base64.b64encode(buffer.getvalue()).decode("ascii") +def _routed_experts_payload(data) -> dict: + arr = np.asarray(data, dtype=np.uint8) + return { + "data": pybase64.b64encode(memoryview(np.ascontiguousarray(arr))).decode("ascii"), + "shape": list(arr.shape), + } def _sample_routed_experts(sample) -> np.ndarray: diff --git a/uv.lock b/uv.lock index 1ccac8633d..0a60624026 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "2026-05-07T18:56:25.00241089Z" +exclude-newer = "2026-05-08T21:17:02.773057077Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -260,6 +260,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" }, ] +[[package]] +name = "bashlex" +version = "0.18" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/76/60/aae0bb54f9af5e0128ba90eb83d8d0d506ee8f0475c4fdda3deeda20b1d2/bashlex-0.18.tar.gz", hash = "sha256:5bb03a01c6d5676338c36fd1028009c8ad07e7d61d8a1ce3f513b7fff52796ee", size = 68742, upload-time = "2023-01-18T15:21:26.402Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/be/6985abb1011fda8a523cfe21ed9629e397d6e06fb5bae99750402b25c95b/bashlex-0.18-py2.py3-none-any.whl", hash = "sha256:91d73a23a3e51711919c1c899083890cdecffc91d8c088942725ac13e9dcfffa", size = 69539, upload-time = "2023-01-18T15:21:24.167Z" }, +] + [[package]] name = "bcrypt" version = "5.0.0" @@ -653,6 +662,19 @@ nvtx = [ { name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] +[[package]] +name = "dataclasses-json" +version = "0.6.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "marshmallow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-inspect", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/64/a4/f71d9cf3a5ac257c993b5ca3f93df5f7fb395c725e7f1e6479d2514173c3/dataclasses_json-0.6.7.tar.gz", hash = "sha256:b6b3e528266ea45b9535223bc53ca645f5208833c29229e847b3f26a1cc55fc0", size = 32227, upload-time = "2024-06-09T16:20:19.103Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686, upload-time = "2024-06-09T16:20:16.715Z" }, +] + [[package]] name = "datasets" version = "4.0.0" @@ -1766,6 +1788,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/82/8be4c96ffee03c5b4a034e60a31294daf481e12c7c43ab8e34a1453ee48b/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48", size = 23352, upload-time = "2024-10-18T15:21:20.971Z" }, ] +[[package]] +name = "marshmallow" +version = "3.26.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/55/79/de6c16cc902f4fc372236926b0ce2ab7845268dcc30fb2fbb7f71b418631/marshmallow-3.26.2.tar.gz", hash = "sha256:bbe2adb5a03e6e3571b573f42527c6fe926e17467833660bebd11593ab8dfd57", size = 222095, upload-time = "2025-12-22T06:53:53.309Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/2f/5108cb3ee4ba6501748c4908b908e55f42a5b66245b4cfe0c99326e1ef6e/marshmallow-3.26.2-py3-none-any.whl", hash = "sha256:013fa8a3c4c276c24d26d84ce934dc964e2aa794345a0f8c7e5a7191482c8a73", size = 50964, upload-time = "2025-12-22T06:53:51.801Z" }, +] + [[package]] name = "math-env" version = "0.1.2" @@ -2066,6 +2100,26 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9d/e6/8ad51bdc806aac1dc501e8fe43f759f9ed7284043d722b53323ea421c360/msgspec-0.19.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5f0f65f29b45e2816d8bded36e6b837a4bf5fb60ec4bc3c625fa2c6da4124537", size = 219081, upload-time = "2024-12-27T17:39:55.142Z" }, ] +[[package]] +name = "multi-swe-bench" +version = "1.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dataclasses-json", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "docker", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "gitpython", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pygithub", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pyyaml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "swe-rex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "toml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "tqdm", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "unidiff", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/48/ad/6b7cda600a50392c790b14ee420b9a3bb318a982a298c05f2d1c066a434f/multi_swe_bench-1.1.2.tar.gz", hash = "sha256:44944bc6608d7d9b8d4390f3ce0a3b2c69122ea6be6e35766c6fde2328f50392", size = 1267660, upload-time = "2025-12-18T07:16:09.584Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/a8/060eb46096742944d8d37c34094d4e0fb34b28c6291a877388543ea65660/multi_swe_bench-1.1.2-py3-none-any.whl", hash = "sha256:09a5770096d6a035383c5240762ffa8c87b1e8df7d374110de8fb781b4e5a9f9", size = 4942355, upload-time = "2025-12-18T07:16:07.468Z" }, +] + [[package]] name = "multidict" version = "6.6.4" @@ -2095,6 +2149,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" }, ] +[[package]] +name = "mypy-extensions" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, +] + [[package]] name = "nest-asyncio" version = "1.6.0" @@ -2822,6 +2885,7 @@ dependencies = [ { name = "prime", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "prime-rl-configs", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pyarrow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pybase64", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pyzmq", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "renderers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "rich", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -2874,6 +2938,7 @@ envs = [ { name = "math500", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "mini-swe-agent-plus", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "reverse-text", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "rlm-swe", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "science-env", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "wiki-search", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] @@ -2945,12 +3010,14 @@ requires-dist = [ { name = "prime-rl", extras = ["quack"], marker = "extra == 'all'" }, { name = "prime-rl-configs", editable = "packages/prime-rl-configs" }, { name = "pyarrow", specifier = ">=21.0.0" }, + { name = "pybase64", specifier = ">=1.4.2" }, { name = "pyzmq", specifier = ">=27.1.0" }, { name = "quack-kernels", marker = "extra == 'quack'", specifier = ">=0.3.3" }, { name = "renderers", specifier = "==0.1.6" }, { name = "reverse-text", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, { name = "rich", specifier = ">=14.0.0" }, { name = "ring-flash-attn", specifier = ">=0.1.8" }, + { name = "rlm-swe", marker = "extra == 'envs'", editable = "/shared/research-prod/research-environments/environments/rlm_swe" }, { name = "science-env", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, { name = "setproctitle", specifier = ">=1.3.0" }, { name = "tenacity", specifier = ">=8.2.0" }, @@ -2962,11 +3029,11 @@ requires-dist = [ { name = "torchvision", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, { name = "uvloop", specifier = ">=0.21.0" }, - { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=7fdf522" }, + { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=461a730" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, - { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.24/vllm_router-0.1.24-cp38-abi3-manylinux_2_28_x86_64.whl" }, + { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", git = "https://github.com/PrimeIntellect-ai/router?rev=510092f" }, { name = "wandb", specifier = ">=0.26.1" }, { name = "wiki-search", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, ] @@ -3247,6 +3314,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/4b/ccc026168948fec4f7555b9164c724cf4125eac006e176541483d2c959be/pydantic_settings-2.13.1-py3-none-any.whl", hash = "sha256:d56fd801823dbeae7f0975e1f8c8e25c258eb75d278ea7abb5d9cebb01b56237", size = 58929, upload-time = "2026-02-19T13:45:06.034Z" }, ] +[[package]] +name = "pygithub" +version = "2.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pyjwt", extra = ["crypto"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pynacl", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "urllib3", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ab/c3/8465a311197e16cf5ab68789fe689535e90f6b61ab524cc32a39e67237ae/pygithub-2.9.1.tar.gz", hash = "sha256:59771d7ff63d54d427be2e7d0dad2208dfffc2b0a045fec959263787739b611c", size = 2594989, upload-time = "2026-04-14T07:26:13.622Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/77/aa/81a5506f089a26338bff17535e4339b3b22049ebd1bcdeff756c4d7a7559/pygithub-2.9.1-py3-none-any.whl", hash = "sha256:2ec78fca30092d51a42d76f4ddb02131b6f0c666a35dfdf364cf302cdda115b9", size = 449710, upload-time = "2026-04-14T07:26:12.382Z" }, +] + [[package]] name = "pygments" version = "2.19.2" @@ -3270,6 +3353,25 @@ crypto = [ { name = "cryptography", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] +[[package]] +name = "pynacl" +version = "1.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "(platform_machine == 'aarch64' and platform_python_implementation != 'PyPy' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_python_implementation != 'PyPy' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/9a/4019b524b03a13438637b11538c82781a5eda427394380381af8f04f467a/pynacl-1.6.2.tar.gz", hash = "sha256:018494d6d696ae03c7e656e5e74cdfd8ea1326962cc401bcf018f1ed8436811c", size = 3511692, upload-time = "2026-01-01T17:48:10.851Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/b4/e927e0653ba63b02a4ca5b4d852a8d1d678afbf69b3dbf9c4d0785ac905c/pynacl-1.6.2-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8845c0631c0be43abdd865511c41eab235e0be69c81dc66a50911594198679b0", size = 800020, upload-time = "2026-01-01T17:32:18.34Z" }, + { url = "https://files.pythonhosted.org/packages/7f/81/d60984052df5c97b1d24365bc1e30024379b42c4edcd79d2436b1b9806f2/pynacl-1.6.2-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:22de65bb9010a725b0dac248f353bb072969c94fa8d6b1f34b87d7953cf7bbe4", size = 1399174, upload-time = "2026-01-01T17:32:20.239Z" }, + { url = "https://files.pythonhosted.org/packages/68/f7/322f2f9915c4ef27d140101dd0ed26b479f7e6f5f183590fd32dfc48c4d3/pynacl-1.6.2-cp38-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:46065496ab748469cdd999246d17e301b2c24ae2fdf739132e580a0e94c94a87", size = 835085, upload-time = "2026-01-01T17:32:22.24Z" }, + { url = "https://files.pythonhosted.org/packages/3e/d0/f301f83ac8dbe53442c5a43f6a39016f94f754d7a9815a875b65e218a307/pynacl-1.6.2-cp38-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8a66d6fb6ae7661c58995f9c6435bda2b1e68b54b598a6a10247bfcdadac996c", size = 1437614, upload-time = "2026-01-01T17:32:23.766Z" }, + { url = "https://files.pythonhosted.org/packages/c4/58/fc6e649762b029315325ace1a8c6be66125e42f67416d3dbd47b69563d61/pynacl-1.6.2-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:26bfcd00dcf2cf160f122186af731ae30ab120c18e8375684ec2670dccd28130", size = 818251, upload-time = "2026-01-01T17:32:25.69Z" }, + { url = "https://files.pythonhosted.org/packages/c9/a8/b917096b1accc9acd878819a49d3d84875731a41eb665f6ebc826b1af99e/pynacl-1.6.2-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:c8a231e36ec2cab018c4ad4358c386e36eede0319a0c41fed24f840b1dac59f6", size = 1402859, upload-time = "2026-01-01T17:32:27.215Z" }, + { url = "https://files.pythonhosted.org/packages/85/42/fe60b5f4473e12c72f977548e4028156f4d340b884c635ec6b063fe7e9a5/pynacl-1.6.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:68be3a09455743ff9505491220b64440ced8973fe930f270c8e07ccfa25b1f9e", size = 791926, upload-time = "2026-01-01T17:32:29.314Z" }, + { url = "https://files.pythonhosted.org/packages/fa/f9/e40e318c604259301cc091a2a63f237d9e7b424c4851cafaea4ea7c4834e/pynacl-1.6.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8b097553b380236d51ed11356c953bf8ce36a29a3e596e934ecabe76c985a577", size = 1363101, upload-time = "2026-01-01T17:32:31.263Z" }, +] + [[package]] name = "pypika" version = "0.51.1" @@ -3535,6 +3637,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/02/18ba0727a1c755c528d6a52b363d62c0b7a8e64cf961b3030c046107db4d/ring_flash_attn-0.1.8-py3-none-any.whl", hash = "sha256:296c929516c3b21f7bcdaeca44a99bb541779a7b63979eb0f67837dcb18a2bb9", size = 25437, upload-time = "2025-09-10T11:53:07.565Z" }, ] +[[package]] +name = "rlm-swe" +version = "0.3.3" +source = { editable = "/shared/research-prod/research-environments/environments/rlm_swe" } +dependencies = [ + { name = "multi-swe-bench", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "prime-sandboxes", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "swebench", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "verifiers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] + +[package.metadata] +requires-dist = [ + { name = "multi-swe-bench", specifier = ">=1.1.2" }, + { name = "prime-sandboxes", specifier = ">=0.2.19" }, + { name = "swebench", specifier = "==4.1.0" }, + { name = "verifiers", specifier = ">=0.1.13.dev8" }, +] + [[package]] name = "rpds-py" version = "0.27.1" @@ -3721,6 +3842,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/65/5e726c372da8a5e35022a94388b12252710aad0c2351699c3d76ae8dba78/supervisor-4.3.0-py2.py3-none-any.whl", hash = "sha256:0bcb763fddafba410f35cbde226aa7f8514b9fb82eb05a0c85f6588d1c13f8db", size = 320736, upload-time = "2025-08-23T18:25:00.767Z" }, ] +[[package]] +name = "swe-rex" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "bashlex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "fastapi", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pexpect", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pydantic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "python-multipart", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "rich", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "uvicorn", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/86/a069f93ec866151a4d476d546e60220e66b3788878b6e248b2df3ab2c5f1/swe_rex-1.4.0.tar.gz", hash = "sha256:14f8a24c49a63f9e251340b1109ac75a4aacbaece410f8599209de9bfca843c0", size = 41755, upload-time = "2025-08-14T01:19:20.22Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/0d/d06ab2aa78138055c297490762cd7b4d8ac58a544783f874c869cdb7b534/swe_rex-1.4.0-py3-none-any.whl", hash = "sha256:61261ad03eb23b717b5901cd5d229f24f6e1be2e120aad5c2e5ea3384a1d15ad", size = 47756, upload-time = "2025-08-14T01:19:18.93Z" }, +] + [[package]] name = "swebench" version = "4.1.0" @@ -4175,6 +4315,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, ] +[[package]] +name = "typing-inspect" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mypy-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/dc/74/1789779d91f1961fa9438e9a8710cdae6bd138c80d7303996933d117264a/typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78", size = 13825, upload-time = "2023-05-24T20:25:47.612Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/65/f3/107a22063bf27bdccf2024833d3445f4eea42b2e598abfbd46f6a63b6cb0/typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f", size = 8827, upload-time = "2023-05-24T20:25:45.287Z" }, +] + [[package]] name = "typing-inspection" version = "0.4.2" @@ -4275,7 +4428,7 @@ wheels = [ [[package]] name = "verifiers" version = "0.1.15.dev5" -source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=7fdf522#7fdf52219347086c76f00b37659a3626562a58ec" } +source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=461a730#461a730024710c1d3ed9a513c9b0ff85339e7db4" } dependencies = [ { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4292,6 +4445,7 @@ dependencies = [ { name = "openai-agents", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "prime-sandboxes", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "prime-tunnel", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, + { name = "pybase64", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pydantic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pyzmq", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "regex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4670,8 +4824,8 @@ provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttens [[package]] name = "vllm-router" -version = "0.1.24" -source = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.24/vllm_router-0.1.24-cp38-abi3-manylinux_2_28_x86_64.whl" } +version = "0.1.25" +source = { git = "https://github.com/PrimeIntellect-ai/router?rev=510092f#510092fdf456ee6d45657f425b99d4c35508664f" } dependencies = [ { name = "aiohttp", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "fastapi", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -4680,23 +4834,6 @@ dependencies = [ { name = "setproctitle", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "uvicorn", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] -wheels = [ - { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.24/vllm_router-0.1.24-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:5b45d59871b73357ab1c3d48790de233e190055b6ace810aefb3cb1416fe0d00" }, -] - -[package.metadata] -requires-dist = [ - { name = "aiohttp" }, - { name = "fastapi" }, - { name = "orjson" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, - { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, - { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0.0" }, - { name = "requests", specifier = ">=2.25.0" }, - { name = "setproctitle" }, - { name = "uvicorn" }, -] -provides-extras = ["dev"] [[package]] name = "wadler-lindig" From 3cb834594315aad5494f5309cd4aeadd15a766bf Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Sat, 16 May 2026 03:08:01 +0530 Subject: [PATCH 15/32] Remove unrelated rlm-swe dependency --- pyproject.toml | 2 - tests/unit/orchestrator/test_trajectories.py | 2 +- uv.lock | 153 +------------------ 3 files changed, 2 insertions(+), 155 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6ad14d78f3..e6fd981798 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,6 @@ envs = [ "aime2024", "aime2025", "mini_swe_agent_plus", - "rlm-swe", "deepdive" ] disagg = [ @@ -194,7 +193,6 @@ aime2024 = { index = "primeintellect" } aime2025 = { index = "primeintellect" } deepdive = { index = "primeintellect" } mini_swe_agent_plus = { index = "primeintellect" } -rlm-swe = { path = "/shared/research-prod/research-environments/environments/rlm_swe", editable = true } deep-ep = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_ep-1.2.1+29d31c0-cp312-cp312-linux_x86_64.whl" } deep-gemm = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_gemm-2.5.0+891d57b-cp312-cp312-linux_x86_64.whl" } nixl-cu12 = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/nixl_cu12-0.10.1-cp312-cp312-linux_x86_64.whl" } diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index 28e2f15156..303a02fd11 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -3,8 +3,8 @@ from unittest.mock import MagicMock import numpy as np -import pytest import pybase64 +import pytest import verifiers as vf from PIL import Image diff --git a/uv.lock b/uv.lock index 0a60624026..328e6fb7f4 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "2026-05-08T21:17:02.773057077Z" +exclude-newer = "2026-05-08T21:37:36.737039338Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -260,15 +260,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/77/06/bb80f5f86020c4551da315d78b3ab75e8228f89f0162f2c3a819e407941a/attrs-25.3.0-py3-none-any.whl", hash = "sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3", size = 63815, upload-time = "2025-03-13T11:10:21.14Z" }, ] -[[package]] -name = "bashlex" -version = "0.18" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/76/60/aae0bb54f9af5e0128ba90eb83d8d0d506ee8f0475c4fdda3deeda20b1d2/bashlex-0.18.tar.gz", hash = "sha256:5bb03a01c6d5676338c36fd1028009c8ad07e7d61d8a1ce3f513b7fff52796ee", size = 68742, upload-time = "2023-01-18T15:21:26.402Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f4/be/6985abb1011fda8a523cfe21ed9629e397d6e06fb5bae99750402b25c95b/bashlex-0.18-py2.py3-none-any.whl", hash = "sha256:91d73a23a3e51711919c1c899083890cdecffc91d8c088942725ac13e9dcfffa", size = 69539, upload-time = "2023-01-18T15:21:24.167Z" }, -] - [[package]] name = "bcrypt" version = "5.0.0" @@ -662,19 +653,6 @@ nvtx = [ { name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -[[package]] -name = "dataclasses-json" -version = "0.6.7" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "marshmallow", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "typing-inspect", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/64/a4/f71d9cf3a5ac257c993b5ca3f93df5f7fb395c725e7f1e6479d2514173c3/dataclasses_json-0.6.7.tar.gz", hash = "sha256:b6b3e528266ea45b9535223bc53ca645f5208833c29229e847b3f26a1cc55fc0", size = 32227, upload-time = "2024-06-09T16:20:19.103Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/be/d0d44e092656fe7a06b55e6103cbce807cdbdee17884a5367c68c9860853/dataclasses_json-0.6.7-py3-none-any.whl", hash = "sha256:0dbf33f26c8d5305befd61b39d2b3414e8a407bedc2834dea9b8d642666fb40a", size = 28686, upload-time = "2024-06-09T16:20:16.715Z" }, -] - [[package]] name = "datasets" version = "4.0.0" @@ -1788,18 +1766,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a2/82/8be4c96ffee03c5b4a034e60a31294daf481e12c7c43ab8e34a1453ee48b/MarkupSafe-3.0.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48", size = 23352, upload-time = "2024-10-18T15:21:20.971Z" }, ] -[[package]] -name = "marshmallow" -version = "3.26.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/55/79/de6c16cc902f4fc372236926b0ce2ab7845268dcc30fb2fbb7f71b418631/marshmallow-3.26.2.tar.gz", hash = "sha256:bbe2adb5a03e6e3571b573f42527c6fe926e17467833660bebd11593ab8dfd57", size = 222095, upload-time = "2025-12-22T06:53:53.309Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/be/2f/5108cb3ee4ba6501748c4908b908e55f42a5b66245b4cfe0c99326e1ef6e/marshmallow-3.26.2-py3-none-any.whl", hash = "sha256:013fa8a3c4c276c24d26d84ce934dc964e2aa794345a0f8c7e5a7191482c8a73", size = 50964, upload-time = "2025-12-22T06:53:51.801Z" }, -] - [[package]] name = "math-env" version = "0.1.2" @@ -2100,26 +2066,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9d/e6/8ad51bdc806aac1dc501e8fe43f759f9ed7284043d722b53323ea421c360/msgspec-0.19.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5f0f65f29b45e2816d8bded36e6b837a4bf5fb60ec4bc3c625fa2c6da4124537", size = 219081, upload-time = "2024-12-27T17:39:55.142Z" }, ] -[[package]] -name = "multi-swe-bench" -version = "1.1.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "dataclasses-json", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "docker", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "gitpython", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "pygithub", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "pyyaml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "swe-rex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "toml", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "tqdm", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "unidiff", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/48/ad/6b7cda600a50392c790b14ee420b9a3bb318a982a298c05f2d1c066a434f/multi_swe_bench-1.1.2.tar.gz", hash = "sha256:44944bc6608d7d9b8d4390f3ce0a3b2c69122ea6be6e35766c6fde2328f50392", size = 1267660, upload-time = "2025-12-18T07:16:09.584Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/08/a8/060eb46096742944d8d37c34094d4e0fb34b28c6291a877388543ea65660/multi_swe_bench-1.1.2-py3-none-any.whl", hash = "sha256:09a5770096d6a035383c5240762ffa8c87b1e8df7d374110de8fb781b4e5a9f9", size = 4942355, upload-time = "2025-12-18T07:16:07.468Z" }, -] - [[package]] name = "multidict" version = "6.6.4" @@ -2149,15 +2095,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" }, ] -[[package]] -name = "mypy-extensions" -version = "1.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/6e/371856a3fb9d31ca8dac321cda606860fa4548858c0cc45d9d1d4ca2628b/mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558", size = 6343, upload-time = "2025-04-22T14:54:24.164Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, -] - [[package]] name = "nest-asyncio" version = "1.6.0" @@ -2938,7 +2875,6 @@ envs = [ { name = "math500", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "mini-swe-agent-plus", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "reverse-text", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "rlm-swe", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "science-env", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "wiki-search", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] @@ -3017,7 +2953,6 @@ requires-dist = [ { name = "reverse-text", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, { name = "rich", specifier = ">=14.0.0" }, { name = "ring-flash-attn", specifier = ">=0.1.8" }, - { name = "rlm-swe", marker = "extra == 'envs'", editable = "/shared/research-prod/research-environments/environments/rlm_swe" }, { name = "science-env", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, { name = "setproctitle", specifier = ">=1.3.0" }, { name = "tenacity", specifier = ">=8.2.0" }, @@ -3314,22 +3249,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/4b/ccc026168948fec4f7555b9164c724cf4125eac006e176541483d2c959be/pydantic_settings-2.13.1-py3-none-any.whl", hash = "sha256:d56fd801823dbeae7f0975e1f8c8e25c258eb75d278ea7abb5d9cebb01b56237", size = 58929, upload-time = "2026-02-19T13:45:06.034Z" }, ] -[[package]] -name = "pygithub" -version = "2.9.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyjwt", extra = ["crypto"], marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "pynacl", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "urllib3", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/ab/c3/8465a311197e16cf5ab68789fe689535e90f6b61ab524cc32a39e67237ae/pygithub-2.9.1.tar.gz", hash = "sha256:59771d7ff63d54d427be2e7d0dad2208dfffc2b0a045fec959263787739b611c", size = 2594989, upload-time = "2026-04-14T07:26:13.622Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/77/aa/81a5506f089a26338bff17535e4339b3b22049ebd1bcdeff756c4d7a7559/pygithub-2.9.1-py3-none-any.whl", hash = "sha256:2ec78fca30092d51a42d76f4ddb02131b6f0c666a35dfdf364cf302cdda115b9", size = 449710, upload-time = "2026-04-14T07:26:12.382Z" }, -] - [[package]] name = "pygments" version = "2.19.2" @@ -3353,25 +3272,6 @@ crypto = [ { name = "cryptography", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, ] -[[package]] -name = "pynacl" -version = "1.6.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cffi", marker = "(platform_machine == 'aarch64' and platform_python_implementation != 'PyPy' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_python_implementation != 'PyPy' and sys_platform == 'linux')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d9/9a/4019b524b03a13438637b11538c82781a5eda427394380381af8f04f467a/pynacl-1.6.2.tar.gz", hash = "sha256:018494d6d696ae03c7e656e5e74cdfd8ea1326962cc401bcf018f1ed8436811c", size = 3511692, upload-time = "2026-01-01T17:48:10.851Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/b4/e927e0653ba63b02a4ca5b4d852a8d1d678afbf69b3dbf9c4d0785ac905c/pynacl-1.6.2-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8845c0631c0be43abdd865511c41eab235e0be69c81dc66a50911594198679b0", size = 800020, upload-time = "2026-01-01T17:32:18.34Z" }, - { url = "https://files.pythonhosted.org/packages/7f/81/d60984052df5c97b1d24365bc1e30024379b42c4edcd79d2436b1b9806f2/pynacl-1.6.2-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:22de65bb9010a725b0dac248f353bb072969c94fa8d6b1f34b87d7953cf7bbe4", size = 1399174, upload-time = "2026-01-01T17:32:20.239Z" }, - { url = "https://files.pythonhosted.org/packages/68/f7/322f2f9915c4ef27d140101dd0ed26b479f7e6f5f183590fd32dfc48c4d3/pynacl-1.6.2-cp38-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:46065496ab748469cdd999246d17e301b2c24ae2fdf739132e580a0e94c94a87", size = 835085, upload-time = "2026-01-01T17:32:22.24Z" }, - { url = "https://files.pythonhosted.org/packages/3e/d0/f301f83ac8dbe53442c5a43f6a39016f94f754d7a9815a875b65e218a307/pynacl-1.6.2-cp38-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8a66d6fb6ae7661c58995f9c6435bda2b1e68b54b598a6a10247bfcdadac996c", size = 1437614, upload-time = "2026-01-01T17:32:23.766Z" }, - { url = "https://files.pythonhosted.org/packages/c4/58/fc6e649762b029315325ace1a8c6be66125e42f67416d3dbd47b69563d61/pynacl-1.6.2-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:26bfcd00dcf2cf160f122186af731ae30ab120c18e8375684ec2670dccd28130", size = 818251, upload-time = "2026-01-01T17:32:25.69Z" }, - { url = "https://files.pythonhosted.org/packages/c9/a8/b917096b1accc9acd878819a49d3d84875731a41eb665f6ebc826b1af99e/pynacl-1.6.2-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:c8a231e36ec2cab018c4ad4358c386e36eede0319a0c41fed24f840b1dac59f6", size = 1402859, upload-time = "2026-01-01T17:32:27.215Z" }, - { url = "https://files.pythonhosted.org/packages/85/42/fe60b5f4473e12c72f977548e4028156f4d340b884c635ec6b063fe7e9a5/pynacl-1.6.2-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:68be3a09455743ff9505491220b64440ced8973fe930f270c8e07ccfa25b1f9e", size = 791926, upload-time = "2026-01-01T17:32:29.314Z" }, - { url = "https://files.pythonhosted.org/packages/fa/f9/e40e318c604259301cc091a2a63f237d9e7b424c4851cafaea4ea7c4834e/pynacl-1.6.2-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8b097553b380236d51ed11356c953bf8ce36a29a3e596e934ecabe76c985a577", size = 1363101, upload-time = "2026-01-01T17:32:31.263Z" }, -] - [[package]] name = "pypika" version = "0.51.1" @@ -3637,25 +3537,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/02/18ba0727a1c755c528d6a52b363d62c0b7a8e64cf961b3030c046107db4d/ring_flash_attn-0.1.8-py3-none-any.whl", hash = "sha256:296c929516c3b21f7bcdaeca44a99bb541779a7b63979eb0f67837dcb18a2bb9", size = 25437, upload-time = "2025-09-10T11:53:07.565Z" }, ] -[[package]] -name = "rlm-swe" -version = "0.3.3" -source = { editable = "/shared/research-prod/research-environments/environments/rlm_swe" } -dependencies = [ - { name = "multi-swe-bench", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "prime-sandboxes", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "swebench", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "verifiers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] - -[package.metadata] -requires-dist = [ - { name = "multi-swe-bench", specifier = ">=1.1.2" }, - { name = "prime-sandboxes", specifier = ">=0.2.19" }, - { name = "swebench", specifier = "==4.1.0" }, - { name = "verifiers", specifier = ">=0.1.13.dev8" }, -] - [[package]] name = "rpds-py" version = "0.27.1" @@ -3842,25 +3723,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/65/5e726c372da8a5e35022a94388b12252710aad0c2351699c3d76ae8dba78/supervisor-4.3.0-py2.py3-none-any.whl", hash = "sha256:0bcb763fddafba410f35cbde226aa7f8514b9fb82eb05a0c85f6588d1c13f8db", size = 320736, upload-time = "2025-08-23T18:25:00.767Z" }, ] -[[package]] -name = "swe-rex" -version = "1.4.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "bashlex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "fastapi", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "pexpect", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "pydantic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "python-multipart", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "rich", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "uvicorn", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/94/86/a069f93ec866151a4d476d546e60220e66b3788878b6e248b2df3ab2c5f1/swe_rex-1.4.0.tar.gz", hash = "sha256:14f8a24c49a63f9e251340b1109ac75a4aacbaece410f8599209de9bfca843c0", size = 41755, upload-time = "2025-08-14T01:19:20.22Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/98/0d/d06ab2aa78138055c297490762cd7b4d8ac58a544783f874c869cdb7b534/swe_rex-1.4.0-py3-none-any.whl", hash = "sha256:61261ad03eb23b717b5901cd5d229f24f6e1be2e120aad5c2e5ea3384a1d15ad", size = 47756, upload-time = "2025-08-14T01:19:18.93Z" }, -] - [[package]] name = "swebench" version = "4.1.0" @@ -4315,19 +4177,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, ] -[[package]] -name = "typing-inspect" -version = "0.9.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mypy-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/dc/74/1789779d91f1961fa9438e9a8710cdae6bd138c80d7303996933d117264a/typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78", size = 13825, upload-time = "2023-05-24T20:25:47.612Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/65/f3/107a22063bf27bdccf2024833d3445f4eea42b2e598abfbd46f6a63b6cb0/typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f", size = 8827, upload-time = "2023-05-24T20:25:45.287Z" }, -] - [[package]] name = "typing-inspection" version = "0.4.2" From 66a298462991f5e651d53f22a5883bbd9e849507 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Sat, 16 May 2026 03:17:01 +0530 Subject: [PATCH 16/32] Pin vllm-router 0.1.25 wheel --- pyproject.toml | 2 +- uv.lock | 23 ++++++++++++++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e6fd981798..2ddab51903 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,7 +174,7 @@ dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdirectory = "flash_attn/cute", rev = "96bd151" } pydantic-config = { git = "https://github.com/samsja/pydantic_config.git", branch = "main" } -vllm-router = { git = "https://github.com/PrimeIntellect-ai/router", rev = "510092f" } +vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" } vllm = [ { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" }, { url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl", marker = "platform_machine == 'aarch64'" }, diff --git a/uv.lock b/uv.lock index 328e6fb7f4..952e403ed2 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "2026-05-08T21:37:36.737039338Z" +exclude-newer = "2026-05-08T21:46:44.310954962Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -2968,7 +2968,7 @@ requires-dist = [ { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, - { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", git = "https://github.com/PrimeIntellect-ai/router?rev=510092f" }, + { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" }, { name = "wandb", specifier = ">=0.26.1" }, { name = "wiki-search", marker = "extra == 'envs'", index = "https://hub.primeintellect.ai/primeintellect/simple/" }, ] @@ -4674,7 +4674,7 @@ provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttens [[package]] name = "vllm-router" version = "0.1.25" -source = { git = "https://github.com/PrimeIntellect-ai/router?rev=510092f#510092fdf456ee6d45657f425b99d4c35508664f" } +source = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" } dependencies = [ { name = "aiohttp", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "fastapi", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, @@ -4683,6 +4683,23 @@ dependencies = [ { name = "setproctitle", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "uvicorn", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, ] +wheels = [ + { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:e84e731a0779f820bfe3cf4ce78cea2d09993c0a6501c63bcda93826bcd21fd0" }, +] + +[package.metadata] +requires-dist = [ + { name = "aiohttp" }, + { name = "fastapi" }, + { name = "orjson" }, + { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, + { name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.21.0" }, + { name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.0.0" }, + { name = "requests", specifier = ">=2.25.0" }, + { name = "setproctitle" }, + { name = "uvicorn" }, +] +provides-extras = ["dev"] [[package]] name = "wadler-lindig" From 777aae7153597850eb08ffa9071f0f8fd464dd88 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Sat, 16 May 2026 07:37:42 +0530 Subject: [PATCH 17/32] Keep verifiers routed experts opaque --- pyproject.toml | 2 +- uv.lock | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2ddab51903..a54daaa387 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,7 +168,7 @@ prime-rl-configs = { workspace = true } torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } -verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "461a730" } +verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "3708ede" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } diff --git a/uv.lock b/uv.lock index 952e403ed2..d3d6d567fe 100644 --- a/uv.lock +++ b/uv.lock @@ -11,7 +11,7 @@ supported-markers = [ ] [options] -exclude-newer = "2026-05-08T21:46:44.310954962Z" +exclude-newer = "2026-05-09T02:04:57.89664956Z" exclude-newer-span = "P7D" [options.exclude-newer-package] @@ -2964,7 +2964,7 @@ requires-dist = [ { name = "torchvision", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, { name = "uvloop", specifier = ">=0.21.0" }, - { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=461a730" }, + { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=3708ede" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, @@ -4277,7 +4277,7 @@ wheels = [ [[package]] name = "verifiers" version = "0.1.15.dev5" -source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=461a730#461a730024710c1d3ed9a513c9b0ff85339e7db4" } +source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=3708ede#3708ede31d16b77866befa3c7a97cf94b5062cd3" } dependencies = [ { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, From a74e7f56b8c80e9d060caf32b388006aef678a68 Mon Sep 17 00:00:00 2001 From: Sami Date: Sat, 16 May 2026 12:01:32 +0530 Subject: [PATCH 18/32] Forward renderer thinking preservation config --- .../src/prime_rl/configs/orchestrator.py | 4 ++++ .../src/prime_rl/configs/shared.py | 20 +++++++++++++++++++ src/prime_rl/orchestrator/orchestrator.py | 4 ++++ src/prime_rl/utils/client.py | 14 +++++++++++++ src/prime_rl/utils/elastic.py | 10 ++++++++++ .../orchestrator/test_orchestrator_setup.py | 6 ++++++ tests/unit/utils/test_client.py | 3 +++ 7 files changed, 61 insertions(+) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 5d04d3369f..8feba60128 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -1269,6 +1269,10 @@ def validate_renderer_args(self): renderer_args_set.append(f"renderer.reasoning_parser={self.renderer.reasoning_parser!r}") if self.renderer.pool_size is not None: renderer_args_set.append(f"renderer.pool_size={self.renderer.pool_size!r}") + if self.renderer.preserve_all_thinking: + renderer_args_set.append("renderer.preserve_all_thinking=true") + if self.renderer.preserve_thinking_between_tool_calls: + renderer_args_set.append("renderer.preserve_thinking_between_tool_calls=true") if renderer_args_set: raise ValueError( diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index d26c33d9a9..1651e17b9f 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -186,6 +186,26 @@ class RendererConfig(BaseConfig): ), ] = None + preserve_all_thinking: Annotated[ + bool, + Field( + description=( + "Forward preserve_all_thinking to the renderer client. When true, " + "past-assistant reasoning_content is re-emitted on subsequent renders." + ), + ), + ] = False + + preserve_thinking_between_tool_calls: Annotated[ + bool, + Field( + description=( + "Forward preserve_thinking_between_tool_calls to the renderer client. " + "This preserves thinking only inside the active assistant/tool block." + ), + ), + ] = False + class ElasticConfig(BaseConfig): """Configures elastic inference pool with DNS-based service discovery. diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index bc1128ebc7..67ef7bfa1d 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -926,6 +926,8 @@ async def setup_rollout_inference_pool( renderer=config.renderer.name, tool_parser=config.renderer.tool_parser, reasoning_parser=config.renderer.reasoning_parser, + preserve_all_thinking=config.renderer.preserve_all_thinking, + preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls, ) logger.info(f"Initialized {type(renderer).__name__} for {config.model.name}") inference_pool = await setup_inference_pool( @@ -937,6 +939,8 @@ async def setup_rollout_inference_pool( tool_parser=config.renderer.tool_parser, reasoning_parser=config.renderer.reasoning_parser, renderer_pool_size=config.renderer.pool_size, + preserve_all_thinking=config.renderer.preserve_all_thinking, + preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls, ) logger.info("Using direct renderer rollout client") return renderer, inference_pool diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index 21659dfc46..fedbdddb8e 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -68,6 +68,8 @@ def __init__( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ): renderer_model_name = model_name if train_client_type == "renderer" else None self._train_clients = setup_clients( @@ -78,6 +80,8 @@ def __init__( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) self._eval_clients = setup_clients(client_config, client_type=eval_client_type) self._admin_clients = setup_admin_clients(client_config) @@ -129,6 +133,8 @@ async def setup_inference_pool( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ) -> InferencePool: """Create an inference pool from config (static or elastic).""" logger = get_logger() @@ -152,6 +158,8 @@ async def setup_inference_pool( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) logger.info( @@ -168,6 +176,8 @@ async def setup_inference_pool( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) @@ -179,6 +189,8 @@ def setup_clients( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ) -> list[vf.ClientConfig]: clients = [] client_idx = 0 @@ -196,6 +208,8 @@ def setup_clients( renderer_pool_size=renderer_pool_size, tool_parser=tool_parser, reasoning_parser=reasoning_parser, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, api_base_url=base_url, api_key_var=client_config.api_key_var, timeout=client_config.timeout, diff --git a/src/prime_rl/utils/elastic.py b/src/prime_rl/utils/elastic.py index 902f873903..c59f81e27f 100644 --- a/src/prime_rl/utils/elastic.py +++ b/src/prime_rl/utils/elastic.py @@ -110,6 +110,8 @@ def __init__( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ): self.logger = get_logger() self.client_config = client_config @@ -125,6 +127,8 @@ def __init__( self.tool_parser = tool_parser self.reasoning_parser = reasoning_parser self.renderer_pool_size = renderer_pool_size + self.preserve_all_thinking = preserve_all_thinking + self.preserve_thinking_between_tool_calls = preserve_thinking_between_tool_calls self.router_url = client_config.router_url self._servers: dict[str, ServerState] = {} @@ -152,6 +156,8 @@ async def from_config( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, + preserve_all_thinking: bool = False, + preserve_thinking_between_tool_calls: bool = False, ) -> ElasticInferencePool: if client_config.elastic is None: raise ValueError("Elastic inference pool requires elastic config") @@ -164,6 +170,8 @@ async def from_config( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, + preserve_all_thinking=preserve_all_thinking, + preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) await pool.start() return pool @@ -214,6 +222,8 @@ def _rebuild_clients(self) -> None: tool_parser=self.tool_parser, reasoning_parser=self.reasoning_parser, renderer_pool_size=self.renderer_pool_size, + preserve_all_thinking=self.preserve_all_thinking, + preserve_thinking_between_tool_calls=self.preserve_thinking_between_tool_calls, ) if urls else [] diff --git a/tests/unit/orchestrator/test_orchestrator_setup.py b/tests/unit/orchestrator/test_orchestrator_setup.py index ff9bb5b79f..5c5b420fc5 100644 --- a/tests/unit/orchestrator/test_orchestrator_setup.py +++ b/tests/unit/orchestrator/test_orchestrator_setup.py @@ -50,6 +50,8 @@ async def run() -> None: tool_parser=None, reasoning_parser=None, pool_size=None, + preserve_all_thinking=True, + preserve_thinking_between_tool_calls=False, ), ) rollout_client_config = SimpleNamespace(base_url=["http://localhost:8000/v1"]) @@ -79,6 +81,8 @@ async def run() -> None: renderer="qwen3_vl", tool_parser=None, reasoning_parser=None, + preserve_all_thinking=True, + preserve_thinking_between_tool_calls=False, ) setup_pool_mock.assert_awaited_once_with( rollout_client_config, @@ -89,6 +93,8 @@ async def run() -> None: tool_parser=None, reasoning_parser=None, renderer_pool_size=None, + preserve_all_thinking=True, + preserve_thinking_between_tool_calls=False, ) asyncio.run(run()) diff --git a/tests/unit/utils/test_client.py b/tests/unit/utils/test_client.py index 6b48790ef3..3b13e30bd3 100644 --- a/tests/unit/utils/test_client.py +++ b/tests/unit/utils/test_client.py @@ -62,10 +62,13 @@ def test_setup_clients_assigns_renderer_and_dp_rank_headers(): client_config, client_type="renderer", renderer_name="qwen3_vl", + preserve_all_thinking=True, ) assert [client.client_type for client in clients] == ["renderer", "renderer"] assert [client.renderer for client in clients] == ["qwen3_vl", "qwen3_vl"] + assert [client.preserve_all_thinking for client in clients] == [True, True] + assert [client.preserve_thinking_between_tool_calls for client in clients] == [False, False] assert [client.renderer_model_name for client in clients] == [None, None] assert [client.api_base_url for client in clients] == ["http://worker-a:8000/v1"] * 2 assert [client.extra_headers["X-data-parallel-rank"] for client in clients] == ["0", "1"] From a723ac03ff13a3002ada3f75e18ff8118df915df Mon Sep 17 00:00:00 2001 From: Sami Date: Sat, 16 May 2026 12:45:31 +0530 Subject: [PATCH 19/32] Avoid duplicate routed experts in token responses --- src/prime_rl/inference/vllm/serving_tokens.py | 2 +- tests/unit/inference/test_serving_tokens.py | 27 +++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/prime_rl/inference/vllm/serving_tokens.py b/src/prime_rl/inference/vllm/serving_tokens.py index 229e78cc93..789b361c19 100644 --- a/src/prime_rl/inference/vllm/serving_tokens.py +++ b/src/prime_rl/inference/vllm/serving_tokens.py @@ -61,7 +61,7 @@ class _GenerateRoutedExpertsCapture(RoutedExpertsCapture): def post_process(self, response: GenerateResponse) -> PrimeRlGenerateResponse: choices = [ PrimeRlGenerateResponseChoice( - **choice.model_dump(), + **choice.model_dump(exclude={"routed_experts"}), routed_experts=self.routed_experts.get(choice.index), ) for choice in response.choices diff --git a/tests/unit/inference/test_serving_tokens.py b/tests/unit/inference/test_serving_tokens.py index bdda7fc485..1882e57e55 100644 --- a/tests/unit/inference/test_serving_tokens.py +++ b/tests/unit/inference/test_serving_tokens.py @@ -14,11 +14,13 @@ import numpy as np import pybase64 +from vllm.entrypoints.serve.disagg.protocol import GenerateResponse, GenerateResponseChoice from prime_rl.inference.vllm.routed_experts import serialize_routed_experts from prime_rl.inference.vllm.serving_tokens import ( PrimeRlServingTokens, _client_set_max_tokens, + _GenerateRoutedExpertsCapture, ) @@ -40,6 +42,11 @@ async def json(self): return self._body +async def _empty_request_outputs(): + if False: + yield + + def test_subclass_only_overrides_serve_tokens(): assert PrimeRlServingTokens.serve_tokens is not PrimeRlServingTokens.__mro__[1].serve_tokens assert ( @@ -65,6 +72,26 @@ def test_serialize_routed_experts_uses_compact_raw_payload(): np.testing.assert_array_equal(decoded, routed_experts) +def test_generate_response_post_process_replaces_upstream_routed_experts(): + compact_routed_experts = {"data": "AQID", "shape": [1, 1, 3]} + capture = _GenerateRoutedExpertsCapture(_empty_request_outputs()) + capture.routed_experts[0] = compact_routed_experts + response = GenerateResponse( + request_id="request-id", + choices=[ + GenerateResponseChoice( + index=0, + token_ids=[1, 2, 3], + routed_experts="upstream-npy-payload", + ) + ], + ) + + processed = capture.post_process(response) + + assert processed.choices[0].routed_experts == compact_routed_experts + + def test_client_set_max_tokens_recognizes_explicit_value(): body = {"token_ids": [1, 2, 3], "sampling_params": {"max_tokens": 256}} assert asyncio.run(_client_set_max_tokens(_FakeRawRequest(body))) is True From b9d09bd90aa7ff38d0cd0b57c974dcd898dd7734 Mon Sep 17 00:00:00 2001 From: samsja <55492238+samsja@users.noreply.github.com> Date: Mon, 18 May 2026 15:19:47 -0700 Subject: [PATCH 20/32] [codex] Guard checkpoint disk metrics mkdir (#2523) * Guard checkpoint disk metrics mkdir * Remove test_trainer_utils.py per review feedback Co-Authored-By: Claude Opus 4.7 (1M context) * Simplify ckpt disk metrics guard Drop the rank-0 gate and the disk_usage path fallback per review feedback. Catching FileExistsError on mkdir is sufficient: every rank that races on mkdir either wins or harmlessly catches the BeegFS race, and shutil.disk_usage can then operate on the now-existing ckpt_dir. Co-Authored-By: Claude Opus 4.7 (1M context) --------- Co-authored-by: Claude Opus 4.7 (1M context) --- src/prime_rl/trainer/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/prime_rl/trainer/utils.py b/src/prime_rl/trainer/utils.py index 72c4b4999e..406527acb0 100644 --- a/src/prime_rl/trainer/utils.py +++ b/src/prime_rl/trainer/utils.py @@ -119,7 +119,11 @@ def get_ckpt_disk_metrics(output_dir: Path) -> dict[str, float]: monitor.log(...) call (once per step). """ ckpt_dir = get_ckpt_dir(output_dir) - ckpt_dir.mkdir(parents=True, exist_ok=True) + try: + ckpt_dir.mkdir(parents=True, exist_ok=True) + except FileExistsError: + # BeegFS can surface FileExistsError from exist_ok=True when another rank wins the mkdir race. + pass usage = shutil.disk_usage(str(ckpt_dir)) total = float(usage.total) if usage.total else 0.0 return { From e2cffa18ff89b1dbd1766fdebfba82290a62a377 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Tue, 19 May 2026 19:46:20 +0530 Subject: [PATCH 21/32] Pin verifiers routed experts sidecar --- pyproject.toml | 2 +- uv.lock | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a54daaa387..e3ebdcac3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,7 +168,7 @@ prime-rl-configs = { workspace = true } torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } -verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "3708ede" } +verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "044f28c" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } diff --git a/uv.lock b/uv.lock index d3d6d567fe..dcdc3b3461 100644 --- a/uv.lock +++ b/uv.lock @@ -2964,7 +2964,7 @@ requires-dist = [ { name = "torchvision", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, { name = "uvloop", specifier = ">=0.21.0" }, - { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=3708ede" }, + { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=044f28c" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, @@ -4277,7 +4277,7 @@ wheels = [ [[package]] name = "verifiers" version = "0.1.15.dev5" -source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=3708ede#3708ede31d16b77866befa3c7a97cf94b5062cd3" } +source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=044f28c#044f28c102ead58169c38202de0e3e009c25d5a9" } dependencies = [ { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, From 0edc0c598a3f495c33cbc28b2e632e5af72ee5e6 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Tue, 19 May 2026 20:13:33 +0530 Subject: [PATCH 22/32] Pin cleaned verifiers routed experts handling --- pyproject.toml | 2 +- uv.lock | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e3ebdcac3f..8806bb66e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,7 +168,7 @@ prime-rl-configs = { workspace = true } torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } -verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "044f28c" } +verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "3821b17" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } diff --git a/uv.lock b/uv.lock index dcdc3b3461..7ca7c6bf44 100644 --- a/uv.lock +++ b/uv.lock @@ -2964,7 +2964,7 @@ requires-dist = [ { name = "torchvision", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, { name = "uvloop", specifier = ">=0.21.0" }, - { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=044f28c" }, + { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=3821b17" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, @@ -4277,7 +4277,7 @@ wheels = [ [[package]] name = "verifiers" version = "0.1.15.dev5" -source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=044f28c#044f28c102ead58169c38202de0e3e009c25d5a9" } +source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=3821b17#3821b1762318181c5b114b036c084f28c37f1a9d" } dependencies = [ { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, From 4402d7e20e62278b1fd615f49ccea712afd0de4d Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Tue, 19 May 2026 20:19:21 +0530 Subject: [PATCH 23/32] Pin rebased verifiers routed experts handling --- pyproject.toml | 2 +- uv.lock | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8806bb66e4..fcf274ca71 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,7 +168,7 @@ prime-rl-configs = { workspace = true } torch = { index = "pytorch-cu128" } torchvision = { index = "pytorch-cu128" } torchaudio = { index = "pytorch-cu128" } -verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "3821b17" } +verifiers = { git = "https://github.com/PrimeIntellect-ai/verifiers", rev = "0640852" } torchtitan = { git = "https://github.com/pytorch/torchtitan", rev = "a1fdd7e" } dion = { git = "https://github.com/samsja/dion.git", rev = "d891eeb" } transformers = { git = "https://github.com/huggingface/transformers.git", rev = "c1c3424" } diff --git a/uv.lock b/uv.lock index 7ca7c6bf44..ef20a73dd8 100644 --- a/uv.lock +++ b/uv.lock @@ -2964,7 +2964,7 @@ requires-dist = [ { name = "torchvision", index = "https://download.pytorch.org/whl/cu128" }, { name = "transformers", git = "https://github.com/huggingface/transformers.git?rev=c1c3424" }, { name = "uvloop", specifier = ">=0.21.0" }, - { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=3821b17" }, + { name = "verifiers", git = "https://github.com/PrimeIntellect-ai/verifiers?rev=0640852" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.20.2/vllm-0.20.2+cu129-cp38-abi3-manylinux_2_31_aarch64.whl" }, { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.20.2rc1.dev354+g24337fb86.cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, @@ -4276,8 +4276,8 @@ wheels = [ [[package]] name = "verifiers" -version = "0.1.15.dev5" -source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=3821b17#3821b1762318181c5b114b036c084f28c37f1a9d" } +version = "0.1.15.dev7" +source = { git = "https://github.com/PrimeIntellect-ai/verifiers?rev=0640852#0640852f194667b36f3625c5006eec2d953bd3f2" } dependencies = [ { name = "aiolimiter", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "anthropic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, @@ -4294,7 +4294,6 @@ dependencies = [ { name = "openai-agents", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "prime-sandboxes", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "prime-tunnel", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, - { name = "pybase64", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pydantic", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "pyzmq", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, { name = "regex", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux')" }, From 7076bb1d37e0d22fd64d0eabbf927aad68136505 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Thu, 21 May 2026 20:45:10 +0530 Subject: [PATCH 24/32] fix: remove unrelated prime-rl changes --- .../src/prime_rl/configs/orchestrator.py | 4 ---- .../src/prime_rl/configs/shared.py | 20 ------------------- src/prime_rl/orchestrator/orchestrator.py | 4 ---- src/prime_rl/trainer/utils.py | 6 +----- src/prime_rl/utils/client.py | 14 ------------- src/prime_rl/utils/elastic.py | 10 ---------- .../orchestrator/test_orchestrator_setup.py | 6 ------ tests/unit/utils/test_client.py | 3 --- 8 files changed, 1 insertion(+), 66 deletions(-) diff --git a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py index 8feba60128..5d04d3369f 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/orchestrator.py @@ -1269,10 +1269,6 @@ def validate_renderer_args(self): renderer_args_set.append(f"renderer.reasoning_parser={self.renderer.reasoning_parser!r}") if self.renderer.pool_size is not None: renderer_args_set.append(f"renderer.pool_size={self.renderer.pool_size!r}") - if self.renderer.preserve_all_thinking: - renderer_args_set.append("renderer.preserve_all_thinking=true") - if self.renderer.preserve_thinking_between_tool_calls: - renderer_args_set.append("renderer.preserve_thinking_between_tool_calls=true") if renderer_args_set: raise ValueError( diff --git a/packages/prime-rl-configs/src/prime_rl/configs/shared.py b/packages/prime-rl-configs/src/prime_rl/configs/shared.py index 1651e17b9f..d26c33d9a9 100644 --- a/packages/prime-rl-configs/src/prime_rl/configs/shared.py +++ b/packages/prime-rl-configs/src/prime_rl/configs/shared.py @@ -186,26 +186,6 @@ class RendererConfig(BaseConfig): ), ] = None - preserve_all_thinking: Annotated[ - bool, - Field( - description=( - "Forward preserve_all_thinking to the renderer client. When true, " - "past-assistant reasoning_content is re-emitted on subsequent renders." - ), - ), - ] = False - - preserve_thinking_between_tool_calls: Annotated[ - bool, - Field( - description=( - "Forward preserve_thinking_between_tool_calls to the renderer client. " - "This preserves thinking only inside the active assistant/tool block." - ), - ), - ] = False - class ElasticConfig(BaseConfig): """Configures elastic inference pool with DNS-based service discovery. diff --git a/src/prime_rl/orchestrator/orchestrator.py b/src/prime_rl/orchestrator/orchestrator.py index 67ef7bfa1d..bc1128ebc7 100644 --- a/src/prime_rl/orchestrator/orchestrator.py +++ b/src/prime_rl/orchestrator/orchestrator.py @@ -926,8 +926,6 @@ async def setup_rollout_inference_pool( renderer=config.renderer.name, tool_parser=config.renderer.tool_parser, reasoning_parser=config.renderer.reasoning_parser, - preserve_all_thinking=config.renderer.preserve_all_thinking, - preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls, ) logger.info(f"Initialized {type(renderer).__name__} for {config.model.name}") inference_pool = await setup_inference_pool( @@ -939,8 +937,6 @@ async def setup_rollout_inference_pool( tool_parser=config.renderer.tool_parser, reasoning_parser=config.renderer.reasoning_parser, renderer_pool_size=config.renderer.pool_size, - preserve_all_thinking=config.renderer.preserve_all_thinking, - preserve_thinking_between_tool_calls=config.renderer.preserve_thinking_between_tool_calls, ) logger.info("Using direct renderer rollout client") return renderer, inference_pool diff --git a/src/prime_rl/trainer/utils.py b/src/prime_rl/trainer/utils.py index 406527acb0..72c4b4999e 100644 --- a/src/prime_rl/trainer/utils.py +++ b/src/prime_rl/trainer/utils.py @@ -119,11 +119,7 @@ def get_ckpt_disk_metrics(output_dir: Path) -> dict[str, float]: monitor.log(...) call (once per step). """ ckpt_dir = get_ckpt_dir(output_dir) - try: - ckpt_dir.mkdir(parents=True, exist_ok=True) - except FileExistsError: - # BeegFS can surface FileExistsError from exist_ok=True when another rank wins the mkdir race. - pass + ckpt_dir.mkdir(parents=True, exist_ok=True) usage = shutil.disk_usage(str(ckpt_dir)) total = float(usage.total) if usage.total else 0.0 return { diff --git a/src/prime_rl/utils/client.py b/src/prime_rl/utils/client.py index fedbdddb8e..21659dfc46 100644 --- a/src/prime_rl/utils/client.py +++ b/src/prime_rl/utils/client.py @@ -68,8 +68,6 @@ def __init__( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, ): renderer_model_name = model_name if train_client_type == "renderer" else None self._train_clients = setup_clients( @@ -80,8 +78,6 @@ def __init__( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, - preserve_all_thinking=preserve_all_thinking, - preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) self._eval_clients = setup_clients(client_config, client_type=eval_client_type) self._admin_clients = setup_admin_clients(client_config) @@ -133,8 +129,6 @@ async def setup_inference_pool( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, ) -> InferencePool: """Create an inference pool from config (static or elastic).""" logger = get_logger() @@ -158,8 +152,6 @@ async def setup_inference_pool( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, - preserve_all_thinking=preserve_all_thinking, - preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) logger.info( @@ -176,8 +168,6 @@ async def setup_inference_pool( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, - preserve_all_thinking=preserve_all_thinking, - preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) @@ -189,8 +179,6 @@ def setup_clients( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, ) -> list[vf.ClientConfig]: clients = [] client_idx = 0 @@ -208,8 +196,6 @@ def setup_clients( renderer_pool_size=renderer_pool_size, tool_parser=tool_parser, reasoning_parser=reasoning_parser, - preserve_all_thinking=preserve_all_thinking, - preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, api_base_url=base_url, api_key_var=client_config.api_key_var, timeout=client_config.timeout, diff --git a/src/prime_rl/utils/elastic.py b/src/prime_rl/utils/elastic.py index c59f81e27f..902f873903 100644 --- a/src/prime_rl/utils/elastic.py +++ b/src/prime_rl/utils/elastic.py @@ -110,8 +110,6 @@ def __init__( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, ): self.logger = get_logger() self.client_config = client_config @@ -127,8 +125,6 @@ def __init__( self.tool_parser = tool_parser self.reasoning_parser = reasoning_parser self.renderer_pool_size = renderer_pool_size - self.preserve_all_thinking = preserve_all_thinking - self.preserve_thinking_between_tool_calls = preserve_thinking_between_tool_calls self.router_url = client_config.router_url self._servers: dict[str, ServerState] = {} @@ -156,8 +152,6 @@ async def from_config( tool_parser: str | None = None, reasoning_parser: str | None = None, renderer_pool_size: int | None = None, - preserve_all_thinking: bool = False, - preserve_thinking_between_tool_calls: bool = False, ) -> ElasticInferencePool: if client_config.elastic is None: raise ValueError("Elastic inference pool requires elastic config") @@ -170,8 +164,6 @@ async def from_config( tool_parser=tool_parser, reasoning_parser=reasoning_parser, renderer_pool_size=renderer_pool_size, - preserve_all_thinking=preserve_all_thinking, - preserve_thinking_between_tool_calls=preserve_thinking_between_tool_calls, ) await pool.start() return pool @@ -222,8 +214,6 @@ def _rebuild_clients(self) -> None: tool_parser=self.tool_parser, reasoning_parser=self.reasoning_parser, renderer_pool_size=self.renderer_pool_size, - preserve_all_thinking=self.preserve_all_thinking, - preserve_thinking_between_tool_calls=self.preserve_thinking_between_tool_calls, ) if urls else [] diff --git a/tests/unit/orchestrator/test_orchestrator_setup.py b/tests/unit/orchestrator/test_orchestrator_setup.py index 5c5b420fc5..ff9bb5b79f 100644 --- a/tests/unit/orchestrator/test_orchestrator_setup.py +++ b/tests/unit/orchestrator/test_orchestrator_setup.py @@ -50,8 +50,6 @@ async def run() -> None: tool_parser=None, reasoning_parser=None, pool_size=None, - preserve_all_thinking=True, - preserve_thinking_between_tool_calls=False, ), ) rollout_client_config = SimpleNamespace(base_url=["http://localhost:8000/v1"]) @@ -81,8 +79,6 @@ async def run() -> None: renderer="qwen3_vl", tool_parser=None, reasoning_parser=None, - preserve_all_thinking=True, - preserve_thinking_between_tool_calls=False, ) setup_pool_mock.assert_awaited_once_with( rollout_client_config, @@ -93,8 +89,6 @@ async def run() -> None: tool_parser=None, reasoning_parser=None, renderer_pool_size=None, - preserve_all_thinking=True, - preserve_thinking_between_tool_calls=False, ) asyncio.run(run()) diff --git a/tests/unit/utils/test_client.py b/tests/unit/utils/test_client.py index 3b13e30bd3..6b48790ef3 100644 --- a/tests/unit/utils/test_client.py +++ b/tests/unit/utils/test_client.py @@ -62,13 +62,10 @@ def test_setup_clients_assigns_renderer_and_dp_rank_headers(): client_config, client_type="renderer", renderer_name="qwen3_vl", - preserve_all_thinking=True, ) assert [client.client_type for client in clients] == ["renderer", "renderer"] assert [client.renderer for client in clients] == ["qwen3_vl", "qwen3_vl"] - assert [client.preserve_all_thinking for client in clients] == [True, True] - assert [client.preserve_thinking_between_tool_calls for client in clients] == [False, False] assert [client.renderer_model_name for client in clients] == [None, None] assert [client.api_base_url for client in clients] == ["http://worker-a:8000/v1"] * 2 assert [client.extra_headers["X-data-parallel-rank"] for client in clients] == ["0", "1"] From 62cc96bc7e75d037315bb8ab19988f7b6a0f8124 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Fri, 22 May 2026 05:07:25 +0530 Subject: [PATCH 25/32] fix: pack routed experts as typed payloads --- src/prime_rl/orchestrator/trajectories.py | 10 +-- src/prime_rl/trainer/batch.py | 74 ++++++++++++++--------- src/prime_rl/trainer/rl/data.py | 10 +-- src/prime_rl/transport/__init__.py | 3 +- tests/unit/orchestrator/test_batch.py | 26 ++++---- 5 files changed, 71 insertions(+), 52 deletions(-) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index 692c06d54d..fd36148fce 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -9,7 +9,7 @@ import verifiers as vf from transformers.tokenization_utils import PreTrainedTokenizer -from prime_rl.transport import TrainingSample +from prime_rl.transport import RoutedExperts, TrainingSample from prime_rl.utils.chat_template import ( common_prefix_len, deserialize_tool_calls, @@ -60,11 +60,13 @@ def _align_routed_experts( def _set_sample_routed_experts(sample: TrainingSample, routed_experts: np.ndarray | None) -> None: if routed_experts is None: sample.routed_experts = None - sample.routed_experts_shape = None return routed_experts = np.ascontiguousarray(routed_experts) - sample.routed_experts = routed_experts.tobytes() - sample.routed_experts_shape = list(routed_experts.shape) + sample.routed_experts = RoutedExperts( + data=routed_experts.tobytes(), + shape=list(routed_experts.shape), + dtype=str(routed_experts.dtype), + ) def _common_prefix_len(a: list[int], b: list[int]) -> int: diff --git a/src/prime_rl/trainer/batch.py b/src/prime_rl/trainer/batch.py index 6510df7cc4..3c323457f3 100644 --- a/src/prime_rl/trainer/batch.py +++ b/src/prime_rl/trainer/batch.py @@ -1,33 +1,52 @@ import copy -from prime_rl.transport.types import MicroBatch, TrainingSample +from prime_rl.transport.types import MicroBatch, RoutedExperts, TrainingSample +ROUTED_EXPERTS_DTYPE_ITEMSIZE = { + "uint8": 1, + "int16": 2, + "int32": 4, +} -def _routed_experts_row_size(shape: list[int]) -> int: - return shape[1] * shape[2] +def _copy_routed_experts(routed_experts: RoutedExperts) -> RoutedExperts: + return RoutedExperts( + data=routed_experts.data, + shape=list(routed_experts.shape), + dtype=routed_experts.dtype, + ) + + +def _routed_experts_row_size(routed_experts: RoutedExperts) -> int: + return routed_experts.shape[1] * routed_experts.shape[2] * ROUTED_EXPERTS_DTYPE_ITEMSIZE[routed_experts.dtype] -def _slice_routed_experts(data: bytes, shape: list[int], seq_len: int) -> tuple[bytes, list[int]]: - row_size = _routed_experts_row_size(shape) - return data[: seq_len * row_size], [seq_len, shape[1], shape[2]] + +def _slice_routed_experts(routed_experts: RoutedExperts, seq_len: int) -> RoutedExperts: + row_size = _routed_experts_row_size(routed_experts) + return RoutedExperts( + data=routed_experts.data[: seq_len * row_size], + shape=[seq_len, routed_experts.shape[1], routed_experts.shape[2]], + dtype=routed_experts.dtype, + ) def _append_routed_experts(dst: MicroBatch, src: MicroBatch) -> None: - assert dst.routed_experts is not None - assert dst.routed_experts_shape is not None - assert src.routed_experts is not None - assert src.routed_experts_shape is not None - assert dst.routed_experts_shape[1:] == src.routed_experts_shape[1:] - dst.routed_experts += src.routed_experts - dst.routed_experts_shape[0] += src.routed_experts_shape[0] + dst_routed = dst.routed_experts + src_routed = src.routed_experts + assert dst_routed is not None + assert src_routed is not None + assert dst_routed.dtype == src_routed.dtype + assert dst_routed.shape[1:] == src_routed.shape[1:] + dst_routed.data += src_routed.data + dst_routed.shape[0] += src_routed.shape[0] def _pad_routed_experts(micro_batch: MicroBatch, padding_size: int) -> None: - assert micro_batch.routed_experts is not None - assert micro_batch.routed_experts_shape is not None - row_size = _routed_experts_row_size(micro_batch.routed_experts_shape) - micro_batch.routed_experts += b"\0" * (padding_size * row_size) - micro_batch.routed_experts_shape[0] += padding_size + routed_experts = micro_batch.routed_experts + assert routed_experts is not None + row_size = _routed_experts_row_size(routed_experts) + routed_experts.data += b"\0" * (padding_size * row_size) + routed_experts.shape[0] += padding_size def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch: @@ -52,8 +71,9 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch # Teacher logprobs already cover the full sequence (prompt + completion), # computed via prefill in the orchestrator when a teacher model is configured teacher_logprobs = training_example.teacher_logprobs - routed_experts = training_example.routed_experts - routed_experts_shape = training_example.routed_experts_shape + routed_experts = ( + _copy_routed_experts(training_example.routed_experts) if training_example.routed_experts is not None else None + ) if len(input_ids) > seq_len: input_ids = input_ids[:seq_len] @@ -65,8 +85,7 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch if teacher_logprobs is not None: teacher_logprobs = teacher_logprobs[:seq_len] if routed_experts is not None: - assert routed_experts_shape is not None - routed_experts, routed_experts_shape = _slice_routed_experts(routed_experts, routed_experts_shape, seq_len) + routed_experts = _slice_routed_experts(routed_experts, seq_len) if mm_token_type_ids is not None: mm_token_type_ids = mm_token_type_ids[:seq_len] env_names = env_names[:seq_len] @@ -85,11 +104,10 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch assert len(teacher_logprobs) == len(input_ids), f"teacher_logprobs: {len(teacher_logprobs)}" if routed_experts is not None: - assert routed_experts_shape is not None - assert routed_experts_shape[0] == len(input_ids), ( - f"routed_experts: {routed_experts_shape}, input_ids: {len(input_ids)}" + assert routed_experts.shape[0] == len(input_ids), ( + f"routed_experts: {routed_experts.shape}, input_ids: {len(input_ids)}" ) - assert len(routed_experts) == len(input_ids) * _routed_experts_row_size(routed_experts_shape) + assert len(routed_experts.data) == len(input_ids) * _routed_experts_row_size(routed_experts) if mm_token_type_ids is not None: assert len(mm_token_type_ids) == len(input_ids), ( @@ -106,7 +124,6 @@ def prepare_sample(training_example: TrainingSample, seq_len: int) -> MicroBatch teacher_logprobs=teacher_logprobs, temperatures=temperatures, routed_experts=routed_experts, - routed_experts_shape=routed_experts_shape, mm_token_type_ids=mm_token_type_ids, env_names=env_names, mm_kwargs=training_example.mm_kwargs, @@ -166,8 +183,7 @@ def packed_samples_into_micro_bs( assert (bin_content.routed_experts is None) == (sample.routed_experts is None) if sample.routed_experts is not None: if bin_content.routed_experts is None: - bin_content.routed_experts = sample.routed_experts - bin_content.routed_experts_shape = list(sample.routed_experts_shape) + bin_content.routed_experts = _copy_routed_experts(sample.routed_experts) else: _append_routed_experts(bin_content, sample) if sample.mm_token_type_ids is not None: diff --git a/src/prime_rl/trainer/rl/data.py b/src/prime_rl/trainer/rl/data.py index 2ee3918d6e..b08fcf2666 100644 --- a/src/prime_rl/trainer/rl/data.py +++ b/src/prime_rl/trainer/rl/data.py @@ -208,14 +208,14 @@ def _micro_batch_to_tensor(self, micro_batch: MicroBatch) -> TensorMicroBatch: for key, payload in micro_batch.mm_kwargs.items() } routed_experts = None - if micro_batch.routed_experts is not None: - assert micro_batch.routed_experts_shape is not None + packed_routed_experts = micro_batch.routed_experts + if packed_routed_experts is not None: routed_experts = ( torch.frombuffer( - micro_batch.routed_experts, - dtype=torch.uint8, + packed_routed_experts.data, + dtype=_torch_dtype(packed_routed_experts.dtype), ) - .reshape(micro_batch.routed_experts_shape) + .reshape(packed_routed_experts.shape) .to(torch.int32) .unsqueeze(0) ) diff --git a/src/prime_rl/transport/__init__.py b/src/prime_rl/transport/__init__.py index e4c3153dc7..bad9d6c806 100644 --- a/src/prime_rl/transport/__init__.py +++ b/src/prime_rl/transport/__init__.py @@ -8,7 +8,7 @@ FileSystemTrainingBatchReceiver, FileSystemTrainingBatchSender, ) -from prime_rl.transport.types import MicroBatch, TrainingBatch, TrainingSample +from prime_rl.transport.types import MicroBatch, RoutedExperts, TrainingBatch, TrainingSample from prime_rl.transport.zmq import ( ZMQMicroBatchReceiver, ZMQMicroBatchSender, @@ -67,6 +67,7 @@ def setup_micro_batch_receiver( "TrainingSample", "TrainingBatch", "MicroBatch", + "RoutedExperts", "setup_training_batch_sender", "setup_training_batch_receiver", "setup_micro_batch_sender", diff --git a/tests/unit/orchestrator/test_batch.py b/tests/unit/orchestrator/test_batch.py index 62a2363649..7531423c72 100644 --- a/tests/unit/orchestrator/test_batch.py +++ b/tests/unit/orchestrator/test_batch.py @@ -2,12 +2,16 @@ import pytest from prime_rl.trainer.batch import prepare_batch, prepare_sample -from prime_rl.transport.types import TrainingSample +from prime_rl.transport.types import RoutedExperts, TrainingSample def _routed_experts(data, dtype=np.uint8): routed_experts = np.asarray(data, dtype=dtype) - return routed_experts.tobytes(), list(routed_experts.shape) + return RoutedExperts( + data=routed_experts.tobytes(), + shape=list(routed_experts.shape), + dtype=str(routed_experts.dtype), + ) @pytest.fixture @@ -134,7 +138,7 @@ def test_prepare_sample_with_routed_experts(): """Routed experts are passed through prepare_sample and match input_ids length.""" # 2 prompt + 2 completion = 4 tokens, 2 layers, topk=2 routed_experts = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 2], [1, 3]], [[1, 0], [3, 2]]] - routed_bytes, routed_shape = _routed_experts(routed_experts) + routed_payload = _routed_experts(routed_experts) sample = TrainingSample( prompt_ids=[1, 2], prompt_mask=[False, False], @@ -144,21 +148,19 @@ def test_prepare_sample_with_routed_experts(): completion_temperatures=[1.0, 1.0], advantage=1.0, env_name="test-env", - routed_experts=routed_bytes, - routed_experts_shape=routed_shape, + routed_experts=routed_payload, ) micro_batch = prepare_sample(sample, seq_len=8) assert micro_batch.routed_experts is not None - assert micro_batch.routed_experts == routed_bytes - assert micro_batch.routed_experts_shape == routed_shape + assert micro_batch.routed_experts == routed_payload def test_prepare_sample_truncates_routed_experts(): """Routed experts are truncated to seq_len when input exceeds it.""" routed_experts = [[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]]] - routed_bytes, routed_shape = _routed_experts(routed_experts) - expected_bytes, expected_shape = _routed_experts(routed_experts[:3]) + routed_payload = _routed_experts(routed_experts) + expected_payload = _routed_experts(routed_experts[:3]) sample = TrainingSample( prompt_ids=[1, 2], prompt_mask=[False, False], @@ -168,14 +170,12 @@ def test_prepare_sample_truncates_routed_experts(): completion_temperatures=[1.0, 1.0], advantage=1.0, env_name="test-env", - routed_experts=routed_bytes, - routed_experts_shape=routed_shape, + routed_experts=routed_payload, ) micro_batch = prepare_sample(sample, seq_len=3) assert micro_batch.routed_experts is not None - assert micro_batch.routed_experts == expected_bytes - assert micro_batch.routed_experts_shape == expected_shape + assert micro_batch.routed_experts == expected_payload assert micro_batch.env_names == ["test-env"] * 3 From ae6b8b3204049271eb4f0a5e058982aca08418d4 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Fri, 22 May 2026 05:25:40 +0530 Subject: [PATCH 26/32] refactor: inline routed experts trajectory packing --- src/prime_rl/orchestrator/trajectories.py | 114 +++++++------------ tests/unit/orchestrator/test_trajectories.py | 18 +-- 2 files changed, 51 insertions(+), 81 deletions(-) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index fd36148fce..f1b75ed058 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -23,19 +23,7 @@ # primitives are immutable. mm_kwargs payloads are not mutated after creation. -def _decode_routed_experts(payload: dict[str, Any] | None) -> np.ndarray | None: - if payload is None: - return None - shape = [int(dim) for dim in payload["shape"]] - decoded = pybase64.b64decode_as_bytearray(payload["data"]) - expected_size = int(np.prod(shape, dtype=np.int64)) - assert len(decoded) == expected_size, (len(decoded), expected_size, shape) - routed_experts = np.frombuffer(decoded, dtype=np.uint8).reshape(shape) - assert routed_experts.ndim == 3 - return routed_experts - - -def _align_routed_experts( +def align_routed_experts( routed_experts: np.ndarray | None, expected_len: int, ) -> np.ndarray | None: @@ -57,58 +45,16 @@ def _align_routed_experts( return np.concatenate((routed_experts, padding), axis=0) -def _set_sample_routed_experts(sample: TrainingSample, routed_experts: np.ndarray | None) -> None: - if routed_experts is None: - sample.routed_experts = None - return - routed_experts = np.ascontiguousarray(routed_experts) - sample.routed_experts = RoutedExperts( - data=routed_experts.tobytes(), - shape=list(routed_experts.shape), - dtype=str(routed_experts.dtype), - ) - - -def _common_prefix_len(a: list[int], b: list[int]) -> int: - return common_prefix_len(a, b) - - -def _normalize_messages(messages: Any, default_role: str) -> list[dict[str, Any]]: - return normalize_messages(messages, default_role) - - -def _deserialize_tool_calls(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - return deserialize_tool_calls(messages) - - -def _strip_message_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: - return strip_message_content(messages) - - -def _render_messages( - tokenizer: PreTrainedTokenizer, - messages: list[dict[str, Any]], - add_generation_prompt: bool = False, - tools: list[dict[str, Any]] | None = None, -) -> list[int]: - return render_messages( - tokenizer, - messages, - add_generation_prompt=add_generation_prompt, - tools=tools, - ) - - def _tokenize_step_from_messages( step: vf.TrajectoryStep, tokenizer: PreTrainedTokenizer, tools: list[dict[str, Any]] | None = None, ) -> dict[str, Any]: - prompt = _normalize_messages(step.get("prompt"), default_role="user") - completion = _normalize_messages(step.get("completion"), default_role="assistant") + prompt = normalize_messages(step.get("prompt"), default_role="user") + completion = normalize_messages(step.get("completion"), default_role="assistant") - prompt = _strip_message_content(_deserialize_tool_calls(prompt)) - completion = _strip_message_content(_deserialize_tool_calls(completion)) + prompt = strip_message_content(deserialize_tool_calls(prompt)) + completion = strip_message_content(deserialize_tool_calls(completion)) assert all(m.get("role") == "assistant" for m in completion), ( "Expected all completion messages to be assistant role for SFT distillation, " @@ -117,19 +63,19 @@ def _tokenize_step_from_messages( all_messages = prompt + completion prompt_has_assistant_completion = len(completion) > 0 and completion[0].get("role") == "assistant" - prompt_ids = _render_messages( + prompt_ids = render_messages( tokenizer, prompt, add_generation_prompt=prompt_has_assistant_completion, tools=tools, ) - full_ids = _render_messages( + full_ids = render_messages( tokenizer, all_messages, tools=tools, ) - split_idx = _common_prefix_len(prompt_ids, full_ids) + split_idx = common_prefix_len(prompt_ids, full_ids) original_prompt_len = len(prompt_ids) prompt_ids = full_ids[:split_idx] @@ -181,10 +127,10 @@ def _tokenize_step_with_renderer( """Tokenize a trajectory step using a Renderer.""" from renderers.base import build_trajectory_step - prompt = _normalize_messages(step.get("prompt"), default_role="user") - completion = _normalize_messages(step.get("completion"), default_role="assistant") - prompt = _strip_message_content(_deserialize_tool_calls(prompt)) - completion = _strip_message_content(_deserialize_tool_calls(completion)) + prompt = normalize_messages(step.get("prompt"), default_role="user") + completion = normalize_messages(step.get("completion"), default_role="assistant") + prompt = strip_message_content(deserialize_tool_calls(prompt)) + completion = strip_message_content(deserialize_tool_calls(completion)) return build_trajectory_step(renderer, prompt, completion, tools=tools) @@ -263,7 +209,20 @@ def interleave_rollout( def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any] | None: tokens = step["tokens"] if tokens is not None: - routed_experts = _decode_routed_experts(tokens.get("routed_experts")) + routed_experts_payload = tokens.get("routed_experts") + routed_experts = None + if routed_experts_payload is not None: + shape = [int(dim) for dim in routed_experts_payload["shape"]] + decoded_routed_experts = pybase64.b64decode_as_bytearray(routed_experts_payload["data"]) + expected_size = int(np.prod(shape, dtype=np.int64)) + assert len(decoded_routed_experts) == expected_size, ( + len(decoded_routed_experts), + expected_size, + shape, + ) + routed_experts = np.frombuffer(decoded_routed_experts, dtype=np.uint8).reshape(shape) + assert routed_experts.ndim == 3 + return { "prompt_ids": list(tokens["prompt_ids"]), "prompt_mask": [bool(i) for i in tokens["prompt_mask"]], @@ -296,7 +255,7 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample: completion_mask = [bool(i) for i in tokens["completion_mask"]] completion_ids = list(tokens["completion_ids"]) - routed_experts = _align_routed_experts( + routed_experts = align_routed_experts( tokens.get("routed_experts"), len(tokens["prompt_ids"]) + len(tokens["completion_ids"]), ) @@ -313,7 +272,13 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample: env_name=output["env_name"], mm_token_type_ids=None, ) - _set_sample_routed_experts(sample, routed_experts) + if routed_experts is not None: + routed_experts = np.ascontiguousarray(routed_experts) + sample.routed_experts = RoutedExperts( + data=routed_experts.tobytes(), + shape=list(routed_experts.shape), + dtype=str(routed_experts.dtype), + ) return sample, routed_experts def extend_sample( @@ -344,7 +309,7 @@ def extend_sample( if tokens.get("routed_experts") is not None and sample_routed_experts is not None: step_routed = tokens["routed_experts"] - # The previous step's last routing entry was zero-padded by _align_routed_experts + # The previous step's last routing entry was zero-padded by align_routed_experts # (vLLM only captures num_tokens-1 routings per request). This step actually # processed that boundary token as part of its prompt, so replace the zero-fill # with the real routing decision before appending new entries. @@ -352,8 +317,13 @@ def extend_sample( sample_routed_experts[prefix_len - 1] = step_routed[prefix_len - 1] sample_routed_experts = np.concatenate((sample_routed_experts, step_routed[prefix_len:]), axis=0) expected_len = len(sample.prompt_ids) + len(sample.completion_ids) - sample_routed_experts = _align_routed_experts(sample_routed_experts, expected_len) - _set_sample_routed_experts(sample, sample_routed_experts) + sample_routed_experts = align_routed_experts(sample_routed_experts, expected_len) + sample_routed_experts = np.ascontiguousarray(sample_routed_experts) + sample.routed_experts = RoutedExperts( + data=sample_routed_experts.tobytes(), + shape=list(sample_routed_experts.shape), + dtype=str(sample_routed_experts.dtype), + ) return sample_routed_experts # Track (prefix_tokens, sample, step_indices) per active sample. step_indices diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index 0a788c0282..b37de4abb6 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -6,10 +6,10 @@ import verifiers as vf from prime_rl.orchestrator.trajectories import ( - _align_routed_experts, - _deserialize_tool_calls, + align_routed_experts, interleave_rollout, ) +from prime_rl.utils.chat_template import deserialize_tool_calls _interleave_rollout = interleave_rollout @@ -49,7 +49,7 @@ def _sample_routed_experts(sample) -> np.ndarray: def test_deserialize_tool_calls_does_not_inject_missing_key(): messages = [{"role": "assistant", "content": "hello"}] - deserialized = _deserialize_tool_calls(messages) + deserialized = deserialize_tool_calls(messages) assert "tool_calls" not in deserialized[0] @@ -68,7 +68,7 @@ def test_deserialize_tool_calls_parses_arguments_when_present(): } ] - deserialized = _deserialize_tool_calls(messages) + deserialized = deserialize_tool_calls(messages) assert deserialized[0]["tool_calls"][0]["function"]["arguments"] == {"x": 1} @@ -823,12 +823,12 @@ def test_interleave_rollout_error_masks_all_false(): def test_align_routed_experts_none(): - assert _align_routed_experts(None, 10) is None + assert align_routed_experts(None, 10) is None def test_align_routed_experts_empty(): experts = np.empty((0, 2, 2), dtype=np.uint8) - result = _align_routed_experts(experts, 10) + result = align_routed_experts(experts, 10) assert result is not None assert result.shape == (10, 2, 2) assert np.all(result == 0) @@ -837,14 +837,14 @@ def test_align_routed_experts_empty(): def test_align_routed_experts_no_deficit(): # 3 tokens, 2 layers, topk=2 experts = np.asarray([[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[0, 2], [1, 3]]], dtype=np.uint8) - result = _align_routed_experts(experts, expected_len=3) + result = align_routed_experts(experts, expected_len=3) np.testing.assert_array_equal(result, experts) def test_align_routed_experts_with_deficit(): # 2 tokens but expected 4 (deficit of 2) experts = np.asarray([[[1, 2], [3, 4]], [[5, 6], [7, 0]]], dtype=np.uint8) - result = _align_routed_experts(experts, expected_len=4) + result = align_routed_experts(experts, expected_len=4) assert result is not None assert result.shape == (4, 2, 2) np.testing.assert_array_equal(result[:2], experts) @@ -855,7 +855,7 @@ def test_align_routed_experts_with_deficit(): def test_align_routed_experts_excess_length(): experts = np.asarray([[[1, 2]], [[3, 4]], [[5, 6]]], dtype=np.uint8) - result = _align_routed_experts(experts, expected_len=2) + result = align_routed_experts(experts, expected_len=2) np.testing.assert_array_equal(result, experts[:2]) From 6de7fd123ea197874b3902ee60839fc4af103c07 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Fri, 22 May 2026 05:28:03 +0530 Subject: [PATCH 27/32] fix: restore trajectory tokenization helpers --- src/prime_rl/orchestrator/trajectories.py | 52 +++++++++++++++----- tests/unit/orchestrator/test_trajectories.py | 6 +-- 2 files changed, 44 insertions(+), 14 deletions(-) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index f1b75ed058..bd827e5374 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -45,16 +45,46 @@ def align_routed_experts( return np.concatenate((routed_experts, padding), axis=0) +def _common_prefix_len(a: list[int], b: list[int]) -> int: + return common_prefix_len(a, b) + + +def _normalize_messages(messages: Any, default_role: str) -> list[dict[str, Any]]: + return normalize_messages(messages, default_role) + + +def _deserialize_tool_calls(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + return deserialize_tool_calls(messages) + + +def _strip_message_content(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + return strip_message_content(messages) + + +def _render_messages( + tokenizer: PreTrainedTokenizer, + messages: list[dict[str, Any]], + add_generation_prompt: bool = False, + tools: list[dict[str, Any]] | None = None, +) -> list[int]: + return render_messages( + tokenizer, + messages, + add_generation_prompt=add_generation_prompt, + tools=tools, + ) + + def _tokenize_step_from_messages( step: vf.TrajectoryStep, tokenizer: PreTrainedTokenizer, tools: list[dict[str, Any]] | None = None, ) -> dict[str, Any]: - prompt = normalize_messages(step.get("prompt"), default_role="user") - completion = normalize_messages(step.get("completion"), default_role="assistant") + prompt = _normalize_messages(step.get("prompt"), default_role="user") + completion = _normalize_messages(step.get("completion"), default_role="assistant") - prompt = strip_message_content(deserialize_tool_calls(prompt)) - completion = strip_message_content(deserialize_tool_calls(completion)) + prompt = _strip_message_content(_deserialize_tool_calls(prompt)) + completion = _strip_message_content(_deserialize_tool_calls(completion)) assert all(m.get("role") == "assistant" for m in completion), ( "Expected all completion messages to be assistant role for SFT distillation, " @@ -63,19 +93,19 @@ def _tokenize_step_from_messages( all_messages = prompt + completion prompt_has_assistant_completion = len(completion) > 0 and completion[0].get("role") == "assistant" - prompt_ids = render_messages( + prompt_ids = _render_messages( tokenizer, prompt, add_generation_prompt=prompt_has_assistant_completion, tools=tools, ) - full_ids = render_messages( + full_ids = _render_messages( tokenizer, all_messages, tools=tools, ) - split_idx = common_prefix_len(prompt_ids, full_ids) + split_idx = _common_prefix_len(prompt_ids, full_ids) original_prompt_len = len(prompt_ids) prompt_ids = full_ids[:split_idx] @@ -127,10 +157,10 @@ def _tokenize_step_with_renderer( """Tokenize a trajectory step using a Renderer.""" from renderers.base import build_trajectory_step - prompt = normalize_messages(step.get("prompt"), default_role="user") - completion = normalize_messages(step.get("completion"), default_role="assistant") - prompt = strip_message_content(deserialize_tool_calls(prompt)) - completion = strip_message_content(deserialize_tool_calls(completion)) + prompt = _normalize_messages(step.get("prompt"), default_role="user") + completion = _normalize_messages(step.get("completion"), default_role="assistant") + prompt = _strip_message_content(_deserialize_tool_calls(prompt)) + completion = _strip_message_content(_deserialize_tool_calls(completion)) return build_trajectory_step(renderer, prompt, completion, tools=tools) diff --git a/tests/unit/orchestrator/test_trajectories.py b/tests/unit/orchestrator/test_trajectories.py index b37de4abb6..36c9ef1008 100644 --- a/tests/unit/orchestrator/test_trajectories.py +++ b/tests/unit/orchestrator/test_trajectories.py @@ -6,10 +6,10 @@ import verifiers as vf from prime_rl.orchestrator.trajectories import ( + _deserialize_tool_calls, align_routed_experts, interleave_rollout, ) -from prime_rl.utils.chat_template import deserialize_tool_calls _interleave_rollout = interleave_rollout @@ -49,7 +49,7 @@ def _sample_routed_experts(sample) -> np.ndarray: def test_deserialize_tool_calls_does_not_inject_missing_key(): messages = [{"role": "assistant", "content": "hello"}] - deserialized = deserialize_tool_calls(messages) + deserialized = _deserialize_tool_calls(messages) assert "tool_calls" not in deserialized[0] @@ -68,7 +68,7 @@ def test_deserialize_tool_calls_parses_arguments_when_present(): } ] - deserialized = deserialize_tool_calls(messages) + deserialized = _deserialize_tool_calls(messages) assert deserialized[0]["tool_calls"][0]["function"]["arguments"] == {"x": 1} From f50ad90cceced1b7329c15535b9302625cde926a Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Fri, 22 May 2026 05:32:13 +0530 Subject: [PATCH 28/32] refactor: simplify routed experts packing --- src/prime_rl/orchestrator/trajectories.py | 30 +++++++++++------------ 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/prime_rl/orchestrator/trajectories.py b/src/prime_rl/orchestrator/trajectories.py index bd827e5374..5e693b0a76 100644 --- a/src/prime_rl/orchestrator/trajectories.py +++ b/src/prime_rl/orchestrator/trajectories.py @@ -242,16 +242,10 @@ def prepare_step_tokens(step: vf.TrajectoryStep, step_idx: int) -> dict[str, Any routed_experts_payload = tokens.get("routed_experts") routed_experts = None if routed_experts_payload is not None: - shape = [int(dim) for dim in routed_experts_payload["shape"]] decoded_routed_experts = pybase64.b64decode_as_bytearray(routed_experts_payload["data"]) - expected_size = int(np.prod(shape, dtype=np.int64)) - assert len(decoded_routed_experts) == expected_size, ( - len(decoded_routed_experts), - expected_size, - shape, + routed_experts = np.frombuffer(decoded_routed_experts, dtype=np.uint8).reshape( + routed_experts_payload["shape"] ) - routed_experts = np.frombuffer(decoded_routed_experts, dtype=np.uint8).reshape(shape) - assert routed_experts.ndim == 3 return { "prompt_ids": list(tokens["prompt_ids"]), @@ -289,6 +283,15 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample: tokens.get("routed_experts"), len(tokens["prompt_ids"]) + len(tokens["completion_ids"]), ) + packed_routed_experts = None + if routed_experts is not None: + routed_experts = np.ascontiguousarray(routed_experts) + packed_routed_experts = RoutedExperts( + data=routed_experts.tobytes(), + shape=list(routed_experts.shape), + dtype=str(routed_experts.dtype), + ) + prompt_ids = list(tokens["prompt_ids"]) sample = TrainingSample( prompt_ids=prompt_ids, @@ -301,14 +304,8 @@ def make_sample(tokens: dict[str, Any]) -> TrainingSample: advantage=None, env_name=output["env_name"], mm_token_type_ids=None, + routed_experts=packed_routed_experts, ) - if routed_experts is not None: - routed_experts = np.ascontiguousarray(routed_experts) - sample.routed_experts = RoutedExperts( - data=routed_experts.tobytes(), - shape=list(routed_experts.shape), - dtype=str(routed_experts.dtype), - ) return sample, routed_experts def extend_sample( @@ -349,11 +346,12 @@ def extend_sample( expected_len = len(sample.prompt_ids) + len(sample.completion_ids) sample_routed_experts = align_routed_experts(sample_routed_experts, expected_len) sample_routed_experts = np.ascontiguousarray(sample_routed_experts) - sample.routed_experts = RoutedExperts( + packed_routed_experts = RoutedExperts( data=sample_routed_experts.tobytes(), shape=list(sample_routed_experts.shape), dtype=str(sample_routed_experts.dtype), ) + sample.routed_experts = packed_routed_experts return sample_routed_experts # Track (prefix_tokens, sample, step_indices) per active sample. step_indices From 97c65a21cd7e1957ce7c67b8f8e17197d8414c63 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Sat, 23 May 2026 02:24:18 +0530 Subject: [PATCH 29/32] chore: pin vllm router wheel --- pyproject.toml | 2 +- uv.lock | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bc934b2a38..7c1d94e495 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -233,7 +233,7 @@ flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdi prime-pydantic-config = { workspace = true } vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" } vllm = [ - { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" }, + { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" }, { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl", marker = "platform_machine == 'aarch64'" }, ] deep-ep = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_ep-1.2.1+29d31c0-cp312-cp312-linux_x86_64.whl" } diff --git a/uv.lock b/uv.lock index c4d4ae7c60..0cc9c208f8 100644 --- a/uv.lock +++ b/uv.lock @@ -3920,7 +3920,7 @@ dependencies = [ { name = "uvloop", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "verifiers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "vllm", version = "0.21.0+cu129.router.f96fddf", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "wandb", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] @@ -4074,7 +4074,7 @@ requires-dist = [ { name = "verifiers", editable = "deps/verifiers" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'", specifier = ">=0.21.0" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, - { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, + { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl" }, { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" }, { name = "wandb", specifier = ">=0.26.1" }, { name = "wiki-search", marker = "extra == 'envs'", editable = "deps/verifiers/environments/wiki_search" }, @@ -5933,7 +5933,7 @@ rl = [ { name = "torch", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "transformers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "vllm", version = "0.21.0+cu129.router.f96fddf", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "wandb", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] ta = [ @@ -6009,7 +6009,7 @@ requires-dist = [ { name = "typing-extensions", marker = "python_full_version < '3.12'" }, { name = "vllm", marker = "platform_machine == 'aarch64' and extra == 'rl'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'rl'", specifier = ">=0.10.0,<0.11.0" }, - { name = "vllm", marker = "platform_machine == 'x86_64' and extra == 'rl'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl" }, + { name = "vllm", marker = "platform_machine == 'x86_64' and extra == 'rl'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl" }, { name = "wandb", marker = "extra == 'rl'" }, ] provides-extras = ["browser", "openenv", "renderers", "rg", "rl", "ta"] @@ -6229,8 +6229,8 @@ provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttens [[package]] name = "vllm" -version = "0.21.0+cu129" -source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl" } +version = "0.21.0+cu129.router.f96fddf" +source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl" } resolution-markers = [ "platform_machine == 'x86_64' and sys_platform == 'linux'", ] @@ -6306,7 +6306,7 @@ dependencies = [ { name = "xgrammar", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] wheels = [ - { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:920777691e340df7a8328adfb1e57b9996dbb537edfb654dd32f70844f5f423d" }, + { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:f08dc9bd405fbd77c582e78dd002cc0df9904d991379213f0836ad768090713c" }, ] [package.metadata] From fbf16de18c70822ccfd3786b2ca5180c46247f24 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Sat, 23 May 2026 02:42:00 +0530 Subject: [PATCH 30/32] chore: update renderers and verifiers --- deps/renderers | 2 +- deps/verifiers | 2 +- pyproject.toml | 1 + tests/unit/orchestrator/test_qwen3_vl_e2e.py | 5 ++++- uv.lock | 19 +++++++++++-------- 5 files changed, 18 insertions(+), 11 deletions(-) diff --git a/deps/renderers b/deps/renderers index 8704f9d502..3ae276c446 160000 --- a/deps/renderers +++ b/deps/renderers @@ -1 +1 @@ -Subproject commit 8704f9d50252692a4a677177eb98d274f8d3ac5d +Subproject commit 3ae276c44683f8b11115b0c9f365abbb4beb850c diff --git a/deps/verifiers b/deps/verifiers index 58b119fa1b..521d436c55 160000 --- a/deps/verifiers +++ b/deps/verifiers @@ -1 +1 @@ -Subproject commit 58b119fa1b24eff85b74a75ccf3e132523b3c6c3 +Subproject commit 521d436c551b9a706cd3bcebd7200ae7e8907abc diff --git a/pyproject.toml b/pyproject.toml index 7c1d94e495..bf56e23067 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -179,6 +179,7 @@ override-dependencies = [ # we want latest vllm, remove next patch vllm = false tokenspeed-mla = false +fastokens = false flash_attn_3 = false # PrimeIntellect-published on PyPI (trusted publisher) prime = false diff --git a/tests/unit/orchestrator/test_qwen3_vl_e2e.py b/tests/unit/orchestrator/test_qwen3_vl_e2e.py index a08fa30fae..ffbdc45457 100644 --- a/tests/unit/orchestrator/test_qwen3_vl_e2e.py +++ b/tests/unit/orchestrator/test_qwen3_vl_e2e.py @@ -14,10 +14,12 @@ from __future__ import annotations import asyncio +import json from pathlib import Path from typing import Any from unittest.mock import MagicMock +import httpx import pytest _HF_CACHE = Path("~/.cache/huggingface/hub").expanduser() @@ -54,7 +56,7 @@ async def post(self, path, *, cast_to=dict, body=None, options=None): self.calls.append({"path": path, "body": body, "options": options}) # Reply with two sampled tokens + <|im_end|>. The renderer's # parse_response slices the content tokens. - return { + payload = { "request_id": "qwen-vl-e2e", "choices": [ { @@ -71,6 +73,7 @@ async def post(self, path, *, cast_to=dict, body=None, options=None): }, ], } + return httpx.Response(200, content=json.dumps(payload).encode()) def test_renderer_client_qwen3_vl_e2e_features_payload_roundtrips_through_vllm(): diff --git a/uv.lock b/uv.lock index 0cc9c208f8..3d6252df51 100644 --- a/uv.lock +++ b/uv.lock @@ -23,6 +23,7 @@ vllm = false vllm-router = false dion = false tokenspeed-mla = false +prime = false nixl-cu12 = false deep-ep = false flash-attn-3 = false @@ -30,7 +31,7 @@ prime-sandboxes = false prime-tunnel = false deep-gemm = false prime-evals = false -prime = false +fastokens = false [manifest] members = [ @@ -1309,14 +1310,14 @@ wheels = [ [[package]] name = "fastokens" -version = "0.1.2" +version = "0.2.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f2/bd/e65b2989eb045863e1d4b1d161d122f69c8d3b8e23fa287e2a8f1eb4c8ab/fastokens-0.1.2.tar.gz", hash = "sha256:71da0dd9b198d3a00c1cdfae06aff7a616513bced4ba6b2ab0da63b688302c0d", size = 675220, upload-time = "2026-05-07T14:34:31.372Z" } +sdist = { url = "https://files.pythonhosted.org/packages/14/8e/7e88ec1d48db5a6e8d8d44318ce285e38c04b81508bdc2a60e17045a116f/fastokens-0.2.0.tar.gz", hash = "sha256:ef0e175de5c8cb1b616b3210d75dce1fab78e35fc02f77f03f7847d4678be686", size = 675822, upload-time = "2026-05-17T10:32:55.642Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/79/fd6f087929423289df4cf11c5c05d0c13f5274b6f1ff187d322b15ee35bc/fastokens-0.1.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:07737f126ea0c6b92123f13c6aef9fada45923d37efdcf3d6bb23e677ec782a6", size = 3304086, upload-time = "2026-05-07T14:34:13.13Z" }, - { url = "https://files.pythonhosted.org/packages/3c/9d/393fa72d1d9a4e251221e077e42bdccce736f86636563b785d8460d655d1/fastokens-0.1.2-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:cd5c630f190a29492d86da7fcc53dd642f59dec4bf3eb204a378a32131003970", size = 3252648, upload-time = "2026-05-07T14:34:01.934Z" }, - { url = "https://files.pythonhosted.org/packages/70/8e/728c46b32fd6c10088a6a1b268732f50c03b0efb035bbea7d3f22b8de47e/fastokens-0.1.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:62b3bbbeb4e0ec72ff895b15e4b0ad04a791c87caa0edb1f63bd9d7c9896c86e", size = 3335929, upload-time = "2026-05-07T14:34:21.548Z" }, - { url = "https://files.pythonhosted.org/packages/04/3d/4ccb53de21bfb87ec13f1dfddc7567cb01732f5c755a083a7cc9e6eebfec/fastokens-0.1.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:8e1cb2331e4ac377a636411d722e7a0a2a12c00a46c2d64ab6522004a38c8918", size = 3598235, upload-time = "2026-05-07T14:34:30.102Z" }, + { url = "https://files.pythonhosted.org/packages/b4/54/e0e4318ee1ad0b5196df72cf93615bba0b81f7869d659a44ccc475969151/fastokens-0.2.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:160253f8d30747cf66e7ed895c513e16f7b173dd9e644fa641e2eecbd43a616a", size = 3303534, upload-time = "2026-05-17T10:32:37.462Z" }, + { url = "https://files.pythonhosted.org/packages/64/44/bfff90e4b1a43c17edf7305dafbd56dc992bbe832cc08da78f1f50104c2d/fastokens-0.2.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:b61b9fe5b41e0bb36ad86e7551dc53293c9833909ef07b1cdbaa2055b06c3b3e", size = 3254096, upload-time = "2026-05-17T10:32:28.489Z" }, + { url = "https://files.pythonhosted.org/packages/05/bf/1cad7f0e8d03f5f5b2b417cda8859e4d968d2eebdca0cd336b23d7dbbdbb/fastokens-0.2.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:01b9bdba818d7b2c67d57d9917faf7a1dad32ece0734440130de94ad768b819f", size = 3336689, upload-time = "2026-05-17T10:32:46.21Z" }, + { url = "https://files.pythonhosted.org/packages/97/d7/f5fb2564e16b1f5733e05c41b090f95a3fe767f6b888ba7d864193bc5447/fastokens-0.2.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:d068bc50082ad67d5d542847075f1f7b8d10f703274e56e241312f18b4d9e772", size = 3598064, upload-time = "2026-05-17T10:32:54.109Z" }, ] [[package]] @@ -4758,7 +4759,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "fastokens", specifier = ">=0.1.1" }, + { name = "fastokens", specifier = ">=0.2.0" }, { name = "jinja2" }, { name = "numpy" }, { name = "openai", specifier = ">=1.108.1" }, @@ -5895,6 +5896,7 @@ dependencies = [ { name = "numpy", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "openai", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "openai-agents", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "prime-pydantic-config", extra = ["toml"], marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "prime-sandboxes", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "prime-tunnel", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "pydantic", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, @@ -5987,6 +5989,7 @@ requires-dist = [ { name = "openai-agents", specifier = ">=0.0.7" }, { name = "openenv-core", marker = "extra == 'openenv'", specifier = ">=0.3.0" }, { name = "peft", marker = "extra == 'rl'" }, + { name = "prime-pydantic-config", extras = ["toml"], editable = "deps/pydantic-config" }, { name = "prime-sandboxes", specifier = ">=0.2.25" }, { name = "prime-tunnel", specifier = ">=0.1.6" }, { name = "pydantic", specifier = ">=2.11.9" }, From 2630ce90d6e983a8292b38631af87a1ca50074d4 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Sat, 23 May 2026 04:19:15 +0530 Subject: [PATCH 31/32] Pin vLLM PR39568 backport wheel --- pyproject.toml | 2 +- uv.lock | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bf56e23067..a3d6fa8e87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -234,7 +234,7 @@ flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdi prime-pydantic-config = { workspace = true } vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" } vllm = [ - { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl", marker = "platform_machine == 'x86_64'" }, + { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl", marker = "platform_machine == 'x86_64'" }, { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl", marker = "platform_machine == 'aarch64'" }, ] deep-ep = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_ep-1.2.1+29d31c0-cp312-cp312-linux_x86_64.whl" } diff --git a/uv.lock b/uv.lock index 3d6252df51..bbfbc51dbf 100644 --- a/uv.lock +++ b/uv.lock @@ -3921,7 +3921,7 @@ dependencies = [ { name = "uvloop", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "verifiers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "vllm", version = "0.21.0+cu129.router.f96fddf", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "vllm", version = "0.21.0+cu129.pr39568.efa1cac", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "wandb", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] @@ -4075,7 +4075,7 @@ requires-dist = [ { name = "verifiers", editable = "deps/verifiers" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'", specifier = ">=0.21.0" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, - { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl" }, + { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl" }, { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" }, { name = "wandb", specifier = ">=0.26.1" }, { name = "wiki-search", marker = "extra == 'envs'", editable = "deps/verifiers/environments/wiki_search" }, @@ -5935,7 +5935,7 @@ rl = [ { name = "torch", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "transformers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "vllm", version = "0.21.0+cu129.router.f96fddf", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "vllm", version = "0.21.0+cu129.pr39568.efa1cac", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "wandb", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] ta = [ @@ -6012,7 +6012,7 @@ requires-dist = [ { name = "typing-extensions", marker = "python_full_version < '3.12'" }, { name = "vllm", marker = "platform_machine == 'aarch64' and extra == 'rl'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'rl'", specifier = ">=0.10.0,<0.11.0" }, - { name = "vllm", marker = "platform_machine == 'x86_64' and extra == 'rl'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl" }, + { name = "vllm", marker = "platform_machine == 'x86_64' and extra == 'rl'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl" }, { name = "wandb", marker = "extra == 'rl'" }, ] provides-extras = ["browser", "openenv", "renderers", "rg", "rl", "ta"] @@ -6232,8 +6232,8 @@ provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttens [[package]] name = "vllm" -version = "0.21.0+cu129.router.f96fddf" -source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl" } +version = "0.21.0+cu129.pr39568.efa1cac" +source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl" } resolution-markers = [ "platform_machine == 'x86_64' and sys_platform == 'linux'", ] @@ -6309,7 +6309,7 @@ dependencies = [ { name = "xgrammar", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] wheels = [ - { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.router.f96fddf-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:f08dc9bd405fbd77c582e78dd002cc0df9904d991379213f0836ad768090713c" }, + { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl", hash = "sha256:726223381676b552cbc66c5d7e67f5cf8a01e8018af003ed2afd918a743474c9" }, ] [package.metadata] From c3ffa15e333c064ffa06d0fa91591129bfe77e62 Mon Sep 17 00:00:00 2001 From: S1ro1 Date: Sat, 23 May 2026 04:47:57 +0530 Subject: [PATCH 32/32] Pin vLLM revert42434 PR39568 wheel --- pyproject.toml | 2 +- uv.lock | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a3d6fa8e87..6a4bae2629 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -234,7 +234,7 @@ flash-attn-4 = { git = "https://github.com/Dao-AILab/flash-attention.git", subdi prime-pydantic-config = { workspace = true } vllm-router = { url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" } vllm = [ - { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl", marker = "platform_machine == 'x86_64'" }, + { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_24_x86_64.whl", marker = "platform_machine == 'x86_64'" }, { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl", marker = "platform_machine == 'aarch64'" }, ] deep-ep = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/deep_ep-1.2.1+29d31c0-cp312-cp312-linux_x86_64.whl" } diff --git a/uv.lock b/uv.lock index bbfbc51dbf..0aab78a784 100644 --- a/uv.lock +++ b/uv.lock @@ -3921,7 +3921,7 @@ dependencies = [ { name = "uvloop", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "verifiers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "vllm", version = "0.21.0+cu129.pr39568.efa1cac", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "vllm", version = "0.21.0+cu129.r42434.pr39568.a106aa6", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_24_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "wandb", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] @@ -4075,7 +4075,7 @@ requires-dist = [ { name = "verifiers", editable = "deps/verifiers" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64'", specifier = ">=0.21.0" }, { name = "vllm", marker = "platform_machine == 'aarch64'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, - { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl" }, + { name = "vllm", marker = "platform_machine == 'x86_64'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_24_x86_64.whl" }, { name = "vllm-router", marker = "platform_machine == 'x86_64' and extra == 'disagg'", url = "https://github.com/PrimeIntellect-ai/router/releases/download/v0.1.25/vllm_router-0.1.25-cp38-abi3-manylinux_2_28_x86_64.whl" }, { name = "wandb", specifier = ">=0.26.1" }, { name = "wiki-search", marker = "extra == 'envs'", editable = "deps/verifiers/environments/wiki_search" }, @@ -5935,7 +5935,7 @@ rl = [ { name = "torch", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "transformers", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "vllm", version = "0.21.0+cu129", source = { url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine != 'aarch64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, - { name = "vllm", version = "0.21.0+cu129.pr39568.efa1cac", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, + { name = "vllm", version = "0.21.0+cu129.r42434.pr39568.a106aa6", source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_24_x86_64.whl" }, marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, { name = "wandb", marker = "(platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] ta = [ @@ -6012,7 +6012,7 @@ requires-dist = [ { name = "typing-extensions", marker = "python_full_version < '3.12'" }, { name = "vllm", marker = "platform_machine == 'aarch64' and extra == 'rl'", url = "https://github.com/vllm-project/vllm/releases/download/v0.21.0/vllm-0.21.0+cu129-cp38-abi3-manylinux_2_34_aarch64.whl" }, { name = "vllm", marker = "platform_machine != 'aarch64' and platform_machine != 'x86_64' and extra == 'rl'", specifier = ">=0.10.0,<0.11.0" }, - { name = "vllm", marker = "platform_machine == 'x86_64' and extra == 'rl'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl" }, + { name = "vllm", marker = "platform_machine == 'x86_64' and extra == 'rl'", url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_24_x86_64.whl" }, { name = "wandb", marker = "extra == 'rl'" }, ] provides-extras = ["browser", "openenv", "renderers", "rg", "rl", "ta"] @@ -6232,8 +6232,8 @@ provides-extras = ["zen", "bench", "tensorizer", "fastsafetensors", "instanttens [[package]] name = "vllm" -version = "0.21.0+cu129.pr39568.efa1cac" -source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl" } +version = "0.21.0+cu129.r42434.pr39568.a106aa6" +source = { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_24_x86_64.whl" } resolution-markers = [ "platform_machine == 'x86_64' and sys_platform == 'linux'", ] @@ -6309,7 +6309,7 @@ dependencies = [ { name = "xgrammar", marker = "(platform_machine == 'x86_64' and sys_platform == 'linux') or (platform_machine != 'x86_64' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy') or (sys_platform != 'linux' and extra == 'extra-9-verifiers-openenv' and extra == 'group-9-verifiers-policy')" }, ] wheels = [ - { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.pr39568.efa1cac-cp38-abi3-manylinux_2_24_x86_64.whl", hash = "sha256:726223381676b552cbc66c5d7e67f5cf8a01e8018af003ed2afd918a743474c9" }, + { url = "https://github.com/PrimeIntellect-ai/prime-rl/releases/download/v0.5.0/vllm-0.21.0+cu129.r42434.pr39568.a106aa6-cp38-abi3-manylinux_2_24_x86_64.whl", hash = "sha256:80dbe20d6df474df0d9f87c4b82a68de2c96c36d9a1a5e55620e69d3f306fd4b" }, ] [package.metadata]