diff --git a/docs/content/docs/configuration/config.mdx b/docs/content/docs/configuration/config.mdx index 6d6655f2ae..cca1962dc2 100644 --- a/docs/content/docs/configuration/config.mdx +++ b/docs/content/docs/configuration/config.mdx @@ -490,7 +490,7 @@ algorithm: - `token_mean`: computes average loss over all valid tokens in the batch. Used in [DAPO](https://dapo-sia.github.io/). - `sequence_mean`: computes per-sequence avg token loss, then averages over the batch. - - `seq_mean_token_sum_norm`: computes the sum of token losses for each sequence, normalizes by `max_seq_len`, and then averages over the batch. This is used in [Dr. GRPO](https://arxiv.org/abs/2503.20783). If `algorithm.max_seq_len` is not explicitly set, it defaults to `generator.max_input_length + generator.sampling_params.max_generate_length`. + - `seq_mean_token_sum_norm`: computes the sum of token losses for each sequence, normalizes by `max_seq_len`, and then averages over the batch. This is used in [Dr. GRPO](https://arxiv.org/abs/2503.20783). `algorithm.max_seq_len` must be set explicitly for this mode because multi-turn/token budgets are workload-dependent. - `algorithm.grpo_norm_by_std`: Whether to normalize advantages by the standard deviation in GRPO. This is set to `false` in [Dr. GRPO](https://arxiv.org/abs/2503.20783). - `algorithm.zero_variance_filter`: Whether to loss mask out prompts with zero variance rewards. This is only applicable when rewards are response-level. @@ -505,7 +505,7 @@ algorithm: - `algorithm.dynamic_sampling.max_sample_batches`: Maximum number of batches to sample before stopping. Set to `-1` to sample forever. - `algorithm.dynamic_sampling.min_replace_ratio`: Minimum proportion of good samples with which to replace bad samples for `replace` strategy. - `algorithm.use_tis`: Whether to use Truncated Importance Sampling (TIS) as proposed in [this blog](https://fengyao.notion.site/off-policy-rl). 
This flag is to be deprecated, use `off_policy_correction.tis_ratio_type = "token"` instead. -- `max_seq_len`: Used for `seq_mean_token_sum_norm` `loss_reduction`. Required when using that reduction mode. Set it to the total sequence-length normalization constant for your setup; this often matches the model context window / vLLM `max_model_len` when that is the intended budget. - `algorithm.tis_imp_ratio_cap`: Cap parameter for the importance ratio in TIS. This flag is to be deprecated, use `off_policy_correction.token_tis_ratio_clip_high` instead. - `algorithm.clip_cov`: Clip-Cov parameters (only used when `policy_loss_type` is `clip_cov`): diff --git a/skyrl/train/config/config.py b/skyrl/train/config/config.py index 4a5c4fbef1..28397a48e8 100644 --- a/skyrl/train/config/config.py +++ b/skyrl/train/config/config.py @@ -371,8 +371,8 @@ class AlgorithmConfig(BaseConfig): cispo: CISPOConfig = field(default_factory=CISPOConfig) """Only used when ``policy_loss_type="cispo"``.""" max_seq_len: Optional[int] = None - """Used for ``seq_mean_token_sum_norm`` loss reduction; set explicitly for multi-turn. - If ``None``, calculated as ``generator.max_input_length + generator.sampling_params.max_generate_length``.""" + """Used for ``seq_mean_token_sum_norm`` loss reduction. 
+ Must be set explicitly for that reduction mode; otherwise can remain ``None``.""" # --------------------------------------------------------------------------- @@ -710,17 +710,6 @@ def __post_init__(self): if self.trainer.algorithm.temperature is None: self.trainer.algorithm.temperature = self.generator.sampling_params.temperature - if self.trainer.algorithm.max_seq_len is None: - # NOTE (erictang000): this is the max sequence length including the prompt, since max response length - # per batch can be variable based on the prompt length. This is used to normalize the loss for - # seq_mean_token_sum_norm loss reduction. - # TODO(Charlie): This calculation is not correct for multi-turn and users should use `max_seq_len` instead. - # Should we just force users to set max_seq_len if loss reduction is seq_mean_token_sum_norm, regardless of - # multi-turn or not? - self.trainer.algorithm.max_seq_len = ( - self.generator.max_input_length + self.generator.sampling_params.max_generate_length - ) - @classmethod def from_cli_overrides(cls, args: Union[List[str], dict]) -> "SkyRLTrainConfig": """Construct a SkyRLTrainConfig from CLI arguments or a dict of overrides. diff --git a/skyrl/train/config/ppo_base_config.yaml b/skyrl/train/config/ppo_base_config.yaml index c3f48297a8..d3ca423270 100644 --- a/skyrl/train/config/ppo_base_config.yaml +++ b/skyrl/train/config/ppo_base_config.yaml @@ -127,8 +127,9 @@ trainer: tis_imp_ratio_cap: -1.0 use_tis: false - # Used for seq_mean_token_sum_norm loss reduction. Users should set this value for multi-turn for that loss. - # If not set, will be calculated as generator.max_input_length + generator.sampling_params.max_generate_length. + # Used for seq_mean_token_sum_norm loss reduction. + # Must be set explicitly when trainer.algorithm.loss_reduction=seq_mean_token_sum_norm. + # Choose the total sequence-length normalization constant for your setup. 
max_seq_len: null # references diff --git a/skyrl/train/utils/utils.py b/skyrl/train/utils/utils.py index bea1d6bb0c..bd8c285a3b 100644 --- a/skyrl/train/utils/utils.py +++ b/skyrl/train/utils/utils.py @@ -276,6 +276,14 @@ def validate_cfg(cfg: SkyRLTrainConfig): f"invalid loss_reduction: {cfg.trainer.algorithm.loss_reduction}. " f"Must be one of `['token_mean', 'sequence_mean', 'seq_mean_token_sum_norm']`" ) + if cfg.trainer.algorithm.loss_reduction == "seq_mean_token_sum_norm": + if cfg.trainer.algorithm.max_seq_len is None: + raise ValueError( + "`trainer.algorithm.max_seq_len` must be set explicitly when " + "`trainer.algorithm.loss_reduction='seq_mean_token_sum_norm'`. " + "Choose the total sequence-length normalization constant for your setup; " + "this often matches the model context window / vLLM `max_model_len` when appropriate." + ) # TODO (erictang000): remove this after deprecation period if cfg.trainer.algorithm.use_tis: diff --git a/tests/train/test_config.py b/tests/train/test_config.py index b8813d2fa0..7c997fc725 100644 --- a/tests/train/test_config.py +++ b/tests/train/test_config.py @@ -17,6 +17,16 @@ _resolve_dataclass_type, ) from skyrl.train.config.utils import get_legacy_config +from skyrl.train.utils.utils import validate_cfg +from tests.train.util import example_dummy_config + + +def _make_validated_test_config(): + """Return a small config that passes validate_batch_sizes().""" + cfg = example_dummy_config() + cfg.trainer.policy_mini_batch_size = cfg.trainer.train_batch_size + cfg.trainer.critic_mini_batch_size = cfg.trainer.train_batch_size + return cfg # Helper dataclasses for testing @@ -196,20 +206,35 @@ def test_cross_field_defaults(): class TestMaxSeqLenValidation: - """Tests for the max_seq_len auto-calculation and explicit-override logic in __post_init__.""" + """Tests for max_seq_len defaults and validation behavior.""" - def test_max_seq_len_auto_calculated_when_none(self): - """When max_seq_len is None (default), __post_init__ 
should compute it as - max_input_length + max_generate_length.""" - # implicitly set max_seq_len to None + def test_max_seq_len_defaults_to_none_when_not_set(self): cfg = SkyRLTrainConfig.from_cli_overrides([]) - - expected = cfg.generator.max_input_length + cfg.generator.sampling_params.max_generate_length - assert cfg.trainer.algorithm.max_seq_len == expected + assert cfg.trainer.algorithm.max_seq_len is None def test_max_seq_len_preserved_when_explicitly_set(self): - """When max_seq_len is explicitly set by the user, __post_init__ should NOT overwrite it.""" - # explicitly set max_seq_len to 32768 cfg = SkyRLTrainConfig.from_cli_overrides(["trainer.algorithm.max_seq_len=32768"]) - assert cfg.trainer.algorithm.max_seq_len == 32768 + + def test_validate_cfg_requires_explicit_max_seq_len_for_seq_mean_token_sum_norm(self): + cfg = _make_validated_test_config() + cfg.trainer.algorithm.loss_reduction = "seq_mean_token_sum_norm" + cfg.trainer.algorithm.max_seq_len = None + + with pytest.raises(ValueError, match=r"trainer\.algorithm\.max_seq_len"): + validate_cfg(cfg) + + @pytest.mark.parametrize("loss_reduction", ["token_mean", "sequence_mean"]) + def test_validate_cfg_allows_missing_max_seq_len_for_other_reductions(self, loss_reduction): + cfg = _make_validated_test_config() + cfg.trainer.algorithm.loss_reduction = loss_reduction + cfg.trainer.algorithm.max_seq_len = None + + validate_cfg(cfg) + + def test_validate_cfg_allows_explicit_max_seq_len_for_seq_mean_token_sum_norm(self): + cfg = _make_validated_test_config() + cfg.trainer.algorithm.loss_reduction = "seq_mean_token_sum_norm" + cfg.trainer.algorithm.max_seq_len = 4096 + + validate_cfg(cfg)