diff --git a/src/art/dev/openai_server.py b/src/art/dev/openai_server.py
index f6639bda..1da49661 100644
--- a/src/art/dev/openai_server.py
+++ b/src/art/dev/openai_server.py
@@ -1,9 +1,17 @@
+import warnings
 from typing import Literal
 
 from typing_extensions import TypedDict
 
 from .engine import EngineArgs
 
+ENGINE_INIT_ONLY_ARGS = {
+    "max_logprobs",
+    "gpu_memory_utilization",
+    "tensor_parallel_size",
+    "max_model_len",
+}
+
 
 def get_openai_server_config(
     model_name: str,
@@ -35,6 +43,16 @@ def get_openai_server_config(
         generation_config="vllm",
     )
     engine_args.update(config.get("engine_args", {}))
+    user_engine_args = config.get("engine_args", {})
+    ignored_args = set(user_engine_args.keys()) & ENGINE_INIT_ONLY_ARGS
+    if ignored_args:
+        warnings.warn(
+            f"OpenAIServerConfig.engine_args contains {ignored_args} which will be "
+            f"ignored. The vLLM engine is initialized by Unsloth before this config "
+            f"is applied. Use TrainableModel._internal_config.engine_args instead.",
+            UserWarning,
+            stacklevel=2,
+        )
     return OpenAIServerConfig(
         log_file=log_file, server_args=server_args, engine_args=engine_args
     )
diff --git a/src/art/loss.py b/src/art/loss.py
index f209cc6e..246f0356 100644
--- a/src/art/loss.py
+++ b/src/art/loss.py
@@ -16,6 +16,9 @@ class Loss(BaseModel):
     mean_kl: torch.Tensor
     mean_entropy: torch.Tensor | None
     probs_corr: torch.Tensor
+    frac_old_logprobs_valid: float
+    mean_importance_ratio: torch.Tensor
+    clip_fraction: torch.Tensor
 
 
 def loss_fn(
@@ -32,6 +35,9 @@ def loss_fn(
     )
     weights = shift_tensor(inputs["weights"], 0.0)
     old_logprobs_mask = ~torch.isnan(old_logprobs)
+    frac_old_logprobs_valid = (
+        old_logprobs_mask.float().sum() / (old_logprobs.numel() + 1e-6)
+    ).item()
     probs_corr = torch.corrcoef(
         torch.stack(
             [
@@ -77,15 +83,23 @@ def loss_fn(
     )
     if tau := experimental_config.get("kimi_k2_tau", None):
         advantages -= tau * logprob_diff.detach()
+    clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high)
+    is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
+    clip_fraction = (is_clipped.float() * assistant_mask).sum() / (
+        assistant_mask.sum() + 1e-6
+    )
+    mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (
+        assistant_mask.sum() + 1e-6
+    )
     if experimental_config.get("ppo", True):
         policy_loss = -torch.min(
             prob_ratio * advantages,
-            torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages,
+            clipped_ratio * advantages,
         )
     else:
         # Modified REINFORCE or Clipped IS-weight Policy Optimization (CISPO)
         policy_loss = -(
-            torch.clip(prob_ratio.detach(), 1 - epsilon, 1 + epsilon_high)
+            clipped_ratio.detach()
             * advantages
             * new_logprobs
         )
@@ -123,6 +137,9 @@ def loss_fn(
         mean_kl=mean_kl,
         mean_entropy=mean_entropy,
         probs_corr=probs_corr,
+        frac_old_logprobs_valid=frac_old_logprobs_valid,
+        mean_importance_ratio=mean_importance_ratio,
+        clip_fraction=clip_fraction,
     )
 
 
diff --git a/src/art/model.py b/src/art/model.py
index 43c519b2..040f157d 100644
--- a/src/art/model.py
+++ b/src/art/model.py
@@ -1,8 +1,17 @@
-from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generic,
+    Iterable,
+    Optional,
+    TypeVar,
+    cast,
+    overload,
+)
 
 import httpx
 from openai import AsyncOpenAI, DefaultAsyncHttpxClient
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
 from typing_extensions import Never
 
 from . import dev
@@ -279,6 +288,19 @@ def __init__(
         # Bypass BaseModel __setattr__ to allow setting private attr
         object.__setattr__(self, "_internal_config", _internal_config)
 
+    @model_validator(mode="wrap")
+    @classmethod
+    def _preserve_internal_config(
+        cls, data: Any, handler: Any
+    ) -> "TrainableModel[ModelConfig]":
+        internal_config = None
+        if isinstance(data, dict) and "_internal_config" in data:
+            internal_config = data.pop("_internal_config")
+        model = handler(data)
+        if internal_config is not None:
+            object.__setattr__(model, "_internal_config", internal_config)
+        return model
+
     @overload
     def __new__(
         cls,
diff --git a/src/art/preprocessing/tokenize.py b/src/art/preprocessing/tokenize.py
index 6ffc462f..0e8c7d14 100644
--- a/src/art/preprocessing/tokenize.py
+++ b/src/art/preprocessing/tokenize.py
@@ -159,23 +159,28 @@ def tokenize_trajectory(
         if history.tools is not None
         else None
     )
-    chat = cast(
-        str,
-        tokenizer.apply_chat_template(
-            cast(list[dict], messages),
-            tools=tools,  # type: ignore
-            continue_final_message=True,
-            tokenize=False,
-        ),
-    )
-    original_token_ids = cast(
-        list[int],
-        tokenizer.apply_chat_template(
-            cast(list[dict], messages),
-            tools=tools,  # type: ignore
-            continue_final_message=True,
-        ),
-    )
+    try:
+        chat = cast(
+            str,
+            tokenizer.apply_chat_template(
+                cast(list[dict], messages),
+                tools=tools,  # type: ignore
+                continue_final_message=True,
+                tokenize=False,
+            ),
+        )
+        original_token_ids = cast(
+            list[int],
+            tokenizer.apply_chat_template(
+                cast(list[dict], messages),
+                tools=tools,  # type: ignore
+                continue_final_message=True,
+            ),
+        )
+    except ValueError as e:
+        if "continue_final_message" in str(e):
+            return None
+        raise
     sentinal_token_id = max(
         set(range(cast(int, tokenizer.vocab_size))) - set(original_token_ids)
     )
@@ -216,13 +221,46 @@ def tokenize_trajectory(
         if isinstance(message, dict):
             content = message.get("content")
             assert isinstance(content, str)
-            content_token_ids = tokenizer.encode(
-                content,
-                add_special_tokens=False,
-            )
-            token_ids[start:end] = content_token_ids
-            logprobs[start:end] = [float("nan")] * len(content_token_ids)
-            assistant_mask[start:end] = [1] * len(content_token_ids)
+            msg_token_ids = message.get("token_ids")
+            dict_logprobs = message.get("logprobs")
+            if (
+                msg_token_ids is not None
+                and dict_logprobs
+                and "values" in dict_logprobs
+            ):
+                token_ids[start:end] = msg_token_ids
+                logprobs[start:end] = dict_logprobs["values"]
+                assistant_mask[start:end] = [1] * len(msg_token_ids)
+            elif (
+                dict_logprobs
+                and "content" in dict_logprobs
+                and dict_logprobs["content"]
+            ):
+                token_logprobs = dict_logprobs["content"]
+                try:
+                    token_ids[start:end] = [
+                        int(lp["token"].split(":")[1]) for lp in token_logprobs
+                    ]
+                except (IndexError, ValueError, KeyError):
+                    token_ids[start:end] = [
+                        token_id if token_id is not None else tokenizer.eos_token_id
+                        for token_id in tokenizer.convert_tokens_to_ids(
+                            [
+                                lp.get("token") or tokenizer.eos_token
+                                for lp in token_logprobs
+                            ]
+                        )
+                    ]
+                logprobs[start:end] = [lp["logprob"] for lp in token_logprobs]
+                assistant_mask[start:end] = [1] * len(token_logprobs)
+            else:
+                content_token_ids = tokenizer.encode(
+                    content,
+                    add_special_tokens=False,
+                )
+                token_ids[start:end] = content_token_ids
+                logprobs[start:end] = [float("nan")] * len(content_token_ids)
+                assistant_mask[start:end] = [1] * len(content_token_ids)
         else:
             choice = message
             assert choice.logprobs or allow_training_without_logprobs, (
diff --git a/src/art/rewards/ruler.py b/src/art/rewards/ruler.py
index 2ea33312..41054470 100644
--- a/src/art/rewards/ruler.py
+++ b/src/art/rewards/ruler.py
@@ -20,6 +20,7 @@
 from rich import print
 
 import art
+from art.utils.strip_logprobs import strip_logprobs
 
 
 class TrajectoryScore(BaseModel):
@@ -287,9 +288,10 @@ async def ruler_score_group(
         new_trajectories.append(new_traj)
 
     # Extract message lists and preserve original rewards for comparison
+    # Strip logprobs to avoid sending huge token probability data to the judge
    message_lists: list[list[ChatCompletionMessageParam]] = []
    for traj in new_trajectories:
-        message_lists.append(traj.messages())
+        message_lists.append(strip_logprobs(traj.messages()))
         traj.metrics["independent_reward"] = traj.reward
 
    try:
diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py
index ddacaafd..320b407a 100644
--- a/src/art/unsloth/train.py
+++ b/src/art/unsloth/train.py
@@ -167,7 +167,10 @@ def compute_loss(
     trainer._metrics["train"]["learning_rate"].append(config.learning_rate)
     trainer._metrics["train"]["policy_loss"].append(loss.mean_policy_loss.item())
     if loss.mean_entropy is not None:
-        trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item())  # type: ignore
+        trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item())
+    trainer._metrics["train"]["frac_old_logprobs_valid"].append(loss.frac_old_logprobs_valid)
+    trainer._metrics["train"]["mean_importance_ratio"].append(loss.mean_importance_ratio.item())
+    trainer._metrics["train"]["clip_fraction"].append(loss.clip_fraction.item())
     if config.beta > 0.0:
         trainer._metrics["train"]["kl_div"].append(loss.mean_kl.item())
     return loss.mean_policy_loss + config.beta * loss.mean_kl  # type: ignore
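
Note for reviewers: the two new clipping diagnostics in loss.py are masked means over assistant tokens. Below is a minimal standalone sketch of that math; the helper name, the default epsilon values, and the [batch, seq] tensor shapes are illustrative assumptions, while the tensor names and formulas follow the patch.

import torch

def clip_diagnostics(
    prob_ratio: torch.Tensor,      # importance ratio exp(new_logprobs - old_logprobs), shape [batch, seq]
    assistant_mask: torch.Tensor,  # 1.0 where the token was produced by the assistant, else 0.0
    epsilon: float = 0.2,
    epsilon_high: float = 0.2,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Fraction of assistant tokens whose importance ratio falls outside the PPO clip range.
    is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
    clip_fraction = (is_clipped.float() * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
    # Mean unclipped importance ratio over the same tokens; values far from 1.0 suggest
    # the sampled data is stale relative to the policy being optimized.
    mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (assistant_mask.sum() + 1e-6)
    return clip_fraction, mean_importance_ratio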
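
Similarly, the new elif branch in tokenize_trajectory recovers assistant token ids from a logprobs dict. A condensed restatement with a hypothetical helper name is shown below; the "prefix:<id>" encoding of lp["token"] is inferred from the split(":") call in the patch rather than stated anywhere in the diff.

def recover_token_ids(token_logprobs: list[dict], tokenizer) -> list[int]:
    # Fast path: entries whose "token" field embeds the id after a colon, e.g. "token_id:1234".
    try:
        return [int(lp["token"].split(":")[1]) for lp in token_logprobs]
    except (IndexError, ValueError, KeyError):
        # Fallback: look the token strings up in the vocabulary, substituting the
        # EOS token/id for anything missing or unknown.
        tokens = [lp.get("token") or tokenizer.eos_token for lp in token_logprobs]
        ids = tokenizer.convert_tokens_to_ids(tokens)
        return [i if i is not None else tokenizer.eos_token_id for i in ids]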