Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/content/docs/configuration/config.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ algorithm:

- `token_mean`: computes average loss over all valid tokens in the batch. Used in [DAPO](https://dapo-sia.github.io/).
- `sequence_mean`: computes per-sequence avg token loss, then averages over the batch.
- `seq_mean_token_sum_norm`: computes the sum of token losses for each sequence, normalizes by `max_seq_len`, and then averages over the batch. This is used in [Dr. GRPO](https://arxiv.org/abs/2503.20783). If `algorithm.max_seq_len` is not explicitly set, it defaults to `generator.max_input_length + generator.sampling_params.max_generate_length`.
- `seq_mean_token_sum_norm`: computes the sum of token losses for each sequence, normalizes by `max_seq_len`, and then averages over the batch. This is used in [Dr. GRPO](https://arxiv.org/abs/2503.20783). `algorithm.max_seq_len` must be set explicitly for this mode because the appropriate token budget depends on the workload (for example, multi-turn rollouts).

- `algorithm.grpo_norm_by_std`: Whether to normalize advantages by the standard deviation in GRPO. This is set to `false` in [Dr. GRPO](https://arxiv.org/abs/2503.20783).
- `algorithm.zero_variance_filter`: Whether to mask out of the loss any prompts whose rewards have zero variance. This is only applicable when rewards are response-level.
Expand All @@ -505,7 +505,7 @@ algorithm:
- `algorithm.dynamic_sampling.max_sample_batches`: Maximum number of batches to sample before stopping. Set to `-1` to sample forever.
- `algorithm.dynamic_sampling.min_replace_ratio`: Minimum proportion of good samples with which to replace bad samples for `replace` strategy.
- `algorithm.use_tis`: Whether to use Truncated Importance Sampling (TIS) as proposed in [this blog](https://fengyao.notion.site/off-policy-rl). This flag is slated for deprecation; use `off_policy_correction.tis_ratio_type = "token"` instead.
- `max_seq_len`: Used for `seq_mean_token_sum_norm` `loss_reduction`. Users should set this value for multi-turn for that loss. If not set, will be calculated as generator.max_input_length + generator.sampling_params.max_generate_length, which is incorrect for multi-turn.
- `max_seq_len`: Used for `seq_mean_token_sum_norm` `loss_reduction`. Required when using that reduction mode. Set it to the total sequence-length normalization constant for your setup; this often matches the model context window / vLLM `max_model_len` when that is the intended budget.
- `algorithm.tis_imp_ratio_cap`: Cap parameter for the importance ratio in TIS. This flag is slated for deprecation; use `off_policy_correction.token_tis_ratio_clip_high` instead.
- `algorithm.clip_cov`: Clip-Cov parameters (only used when `policy_loss_type` is `clip_cov`):

Expand Down
15 changes: 2 additions & 13 deletions skyrl/train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,8 @@ class AlgorithmConfig(BaseConfig):
cispo: CISPOConfig = field(default_factory=CISPOConfig)
"""Only used when ``policy_loss_type="cispo"``."""
max_seq_len: Optional[int] = None
"""Used for ``seq_mean_token_sum_norm`` loss reduction; set explicitly for multi-turn.
If ``None``, calculated as ``generator.max_input_length + generator.sampling_params.max_generate_length``."""
"""Used for ``seq_mean_token_sum_norm`` loss reduction.
Must be set explicitly for that reduction mode; otherwise can remain ``None``."""
Comment on lines +374 to +375
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Stale comment in reduce_loss claims max_seq_len has a default fallback that no longer exists

At skyrl/backends/skyrl_train/utils/ppo_utils.py:999-1000, the comment says "NOTE: max_seq_len can be set explicitly via algorithm.max_seq_len, otherwise defaults to cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length". This auto-calculation default was removed by this PR (deleted from skyrl/train/config/config.py:713-722), and max_seq_len must now always be set explicitly when using seq_mean_token_sum_norm. The stale comment will mislead developers into thinking a fallback still exists.

Prompt for agents
Update the stale comment in skyrl/backends/skyrl_train/utils/ppo_utils.py at lines 999-1000. The comment currently reads:
  # NOTE: max_seq_len can be set explicitly via algorithm.max_seq_len, otherwise defaults to
  # cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length

It should be updated to something like:
  # NOTE: max_seq_len must be set explicitly via algorithm.max_seq_len when using seq_mean_token_sum_norm loss reduction.

This aligns with the new docstring at skyrl/train/config/config.py:374-375 and the validation at skyrl/train/utils/utils.py:279-285.
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.



# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -710,17 +710,6 @@ def __post_init__(self):
if self.trainer.algorithm.temperature is None:
self.trainer.algorithm.temperature = self.generator.sampling_params.temperature

if self.trainer.algorithm.max_seq_len is None:
# NOTE (erictang000): this is the max sequence length including the prompt, since max response length
# per batch can be variable based on the prompt length. This is used to normalize the loss for
# seq_mean_token_sum_norm loss reduction.
# TODO(Charlie): This calculation is not correct for multi-turn and users should use `max_seq_len` instead.
# Should we just force users to set max_seq_len if loss reduction is seq_mean_token_sum_norm, regardless of
# multi-turn or not?
self.trainer.algorithm.max_seq_len = (
self.generator.max_input_length + self.generator.sampling_params.max_generate_length
)

@classmethod
def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SkyRLTrainConfig":
"""Construct a SkyRLTrainConfig from CLI arguments or a dict of overrides.
Expand Down
5 changes: 3 additions & 2 deletions skyrl/train/config/ppo_base_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,9 @@ trainer:
tis_imp_ratio_cap: -1.0
use_tis: false

# Used for seq_mean_token_sum_norm loss reduction. Users should set this value for multi-turn for that loss.
# If not set, will be calculated as generator.max_input_length + generator.sampling_params.max_generate_length.
# Used for seq_mean_token_sum_norm loss reduction.
# Must be set explicitly when trainer.algorithm.loss_reduction=seq_mean_token_sum_norm.
# Choose the total sequence-length normalization constant for your setup.
max_seq_len: null

# references
Expand Down
8 changes: 8 additions & 0 deletions skyrl/train/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,14 @@ def validate_cfg(cfg: SkyRLTrainConfig):
f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. "
f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`"
)
if cfg.trainer.algorithm.loss_reduction == "seq_mean_token_sum_norm":
if cfg.trainer.algorithm.max_seq_len is None:
raise ValueError(
"`trainer.algorithm.max_seq_len` must be set explicitly when "
"`trainer.algorithm.loss_reduction='seq_mean_token_sum_norm'`. "
"Choose the total sequence-length normalization constant for your setup; "
"this often matches the model context window / vLLM `max_model_len` when appropriate."
Comment on lines +279 to +285
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Breaking change: Dr. GRPO example script fails because auto-calculated max_seq_len fallback was removed

The PR removes the max_seq_len auto-calculation from SkyRLTrainConfig.__post_init__ (skyrl/train/config/config.py:713-722 on LEFT) and adds a hard validation check (a raised ValueError) requiring it to be set explicitly when loss_reduction='seq_mean_token_sum_norm'. However, the official Dr. GRPO example script at examples/train/algorithms/drgrpo/run_drgrpo_gsm8k.sh:15,23 uses LOSS_REDUCTION="seq_mean_token_sum_norm" but never passes trainer.algorithm.max_seq_len. This script previously worked because __post_init__ auto-computed max_seq_len = max_input_length + max_generate_length. Now it will fail with a ValueError at validation time.

Same issue in skyrl-agent example

skyrl-agent/examples/run_skyrl/run_skyrl_swe.sh:67 also sets trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" without setting max_seq_len, so it will also fail.

Prompt for agents
Two example scripts need to be updated to explicitly pass trainer.algorithm.max_seq_len now that the auto-calculation fallback has been removed:

1. examples/train/algorithms/drgrpo/run_drgrpo_gsm8k.sh: Add a line like trainer.algorithm.max_seq_len=1536 (512 + 1024, matching max_prompt_length + max_generate_length from the script) to the uv run command.

2. skyrl-agent/examples/run_skyrl/run_skyrl_swe.sh: Add a line like trainer.algorithm.max_seq_len=40768 (8000 + 32768, matching max_prompt_length + max_generate_length from the script) to the uv run command.

Both scripts use loss_reduction=seq_mean_token_sum_norm and will now fail the new ValueError check at skyrl/train/utils/utils.py:279-285 without this fix.
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

Comment on lines +279 to +285
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Breaking change: Dr. GRPO example script fails because auto-calculated max_seq_len fallback was removed

The PR removes the max_seq_len auto-calculation from SkyRLTrainConfig.__post_init__ (skyrl/train/config/config.py:713-722 on LEFT) and adds a hard validation check (a raised ValueError) requiring it to be set explicitly when loss_reduction='seq_mean_token_sum_norm'. However, the official Dr. GRPO example script at examples/train/algorithms/drgrpo/run_drgrpo_gsm8k.sh:15,23 uses LOSS_REDUCTION="seq_mean_token_sum_norm" but never passes trainer.algorithm.max_seq_len. This script previously worked because __post_init__ auto-computed max_seq_len = max_input_length + max_generate_length. Now it will fail with a ValueError at validation time.

Same issue in skyrl-agent example

skyrl-agent/examples/run_skyrl/run_skyrl_swe.sh:67 also sets trainer.algorithm.loss_reduction="seq_mean_token_sum_norm" without setting max_seq_len, so it will also fail.

Prompt for agents
Two example scripts need to be updated to explicitly pass trainer.algorithm.max_seq_len now that the auto-calculation fallback has been removed:

1. examples/train/algorithms/drgrpo/run_drgrpo_gsm8k.sh: Add a line like trainer.algorithm.max_seq_len=1536 (512 + 1024, matching max_prompt_length + max_generate_length from the script) to the uv run command.

2. skyrl-agent/examples/run_skyrl/run_skyrl_swe.sh: Add a line like trainer.algorithm.max_seq_len=40768 (8000 + 32768, matching max_prompt_length + max_generate_length from the script) to the uv run command.

Both scripts use loss_reduction=seq_mean_token_sum_norm and will now fail the new ValueError check at skyrl/train/utils/utils.py:279-285 without this fix.
Open in Devin Review

Was this helpful? React with 👍 or 👎 to provide feedback.

)

# TODO (erictang000): remove this after deprecation period
if cfg.trainer.algorithm.use_tis:
Expand Down
47 changes: 36 additions & 11 deletions tests/train/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,16 @@
_resolve_dataclass_type,
)
from skyrl.train.config.utils import get_legacy_config
from skyrl.train.utils.utils import validate_cfg
from tests.train.util import example_dummy_config


def _make_validated_test_config():
"""Return a small config that passes validate_batch_sizes()."""
cfg = example_dummy_config()
cfg.trainer.policy_mini_batch_size = cfg.trainer.train_batch_size
cfg.trainer.critic_mini_batch_size = cfg.trainer.train_batch_size
return cfg


# Helper dataclasses for testing
Expand Down Expand Up @@ -196,20 +206,35 @@ def test_cross_field_defaults():


class TestMaxSeqLenValidation:
"""Tests for the max_seq_len auto-calculation and explicit-override logic in __post_init__."""
"""Tests for max_seq_len defaults and validation behavior."""

def test_max_seq_len_auto_calculated_when_none(self):
"""When max_seq_len is None (default), __post_init__ should compute it as
max_input_length + max_generate_length."""
# implicitly set max_seq_len to None
def test_max_seq_len_defaults_to_none_when_not_set(self):
cfg = SkyRLTrainConfig.from_cli_overrides([])

expected = cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length
assert cfg.trainer.algorithm.max_seq_len == expected
assert cfg.trainer.algorithm.max_seq_len is None

def test_max_seq_len_preserved_when_explicitly_set(self):
"""When max_seq_len is explicitly set by the user, __post_init__ should NOT overwrite it."""
# explicitly set max_seq_len to 32768
cfg = SkyRLTrainConfig.from_cli_overrides(["trainer.algorithm.max_seq_len=32768"])

assert cfg.trainer.algorithm.max_seq_len == 32768

def test_validate_cfg_requires_explicit_max_seq_len_for_seq_mean_token_sum_norm(self):
cfg = _make_validated_test_config()
cfg.trainer.algorithm.loss_reduction = "seq_mean_token_sum_norm"
cfg.trainer.algorithm.max_seq_len = None

with pytest.raises(ValueError, match=r"trainer\.algorithm\.max_seq_len"):
validate_cfg(cfg)

@pytest.mark.parametrize("loss_reduction", ["token_mean", "sequence_mean"])
def test_validate_cfg_allows_missing_max_seq_len_for_other_reductions(self, loss_reduction):
cfg = _make_validated_test_config()
cfg.trainer.algorithm.loss_reduction = loss_reduction
cfg.trainer.algorithm.max_seq_len = None

validate_cfg(cfg)

def test_validate_cfg_allows_explicit_max_seq_len_for_seq_mean_token_sum_norm(self):
cfg = _make_validated_test_config()
cfg.trainer.algorithm.loss_reduction = "seq_mean_token_sum_norm"
cfg.trainer.algorithm.max_seq_len = 4096

validate_cfg(cfg)
Loading