18 changes: 18 additions & 0 deletions src/art/dev/openai_server.py
@@ -1,9 +1,17 @@
import warnings
from typing import Literal

from typing_extensions import TypedDict

from .engine import EngineArgs

ENGINE_INIT_ONLY_ARGS = {
"max_logprobs",
"gpu_memory_utilization",
"tensor_parallel_size",
"max_model_len",
}


def get_openai_server_config(
model_name: str,
@@ -35,6 +43,16 @@ def get_openai_server_config(
generation_config="vllm",
)
engine_args.update(config.get("engine_args", {}))
user_engine_args = config.get("engine_args", {})
ignored_args = set(user_engine_args.keys()) & ENGINE_INIT_ONLY_ARGS
if ignored_args:
warnings.warn(
f"OpenAIServerConfig.engine_args contains {ignored_args} which will be "
f"ignored. The vLLM engine is initialized by Unsloth before this config "
f"is applied. Use TrainableModel._internal_config.engine_args instead.",
UserWarning,
stacklevel=2,
)
return OpenAIServerConfig(
log_file=log_file, server_args=server_args, engine_args=engine_args
)
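A minimal standalone sketch of the new guard's behavior (the full `get_openai_server_config` signature isn't shown in this hunk, so the helper name `check_engine_args` is hypothetical): keys overlapping `ENGINE_INIT_ONLY_ARGS` still flow into `engine_args`, but the caller gets a `UserWarning` pointing at `TrainableModel._internal_config.engine_args`.

```python
import warnings

ENGINE_INIT_ONLY_ARGS = {
    "max_logprobs",
    "gpu_memory_utilization",
    "tensor_parallel_size",
    "max_model_len",
}


def check_engine_args(user_engine_args: dict) -> None:
    # Hypothetical stand-in for the check inside get_openai_server_config:
    # these args only take effect at vLLM engine init, which Unsloth has
    # already performed by the time this config is applied.
    ignored = set(user_engine_args) & ENGINE_INIT_ONLY_ARGS
    if ignored:
        warnings.warn(
            f"engine_args contains {ignored}, which will be ignored; "
            "set them via TrainableModel._internal_config.engine_args instead.",
            UserWarning,
            stacklevel=2,
        )


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    check_engine_args({"gpu_memory_utilization": 0.8, "max_tokens": 512})
    assert any("gpu_memory_utilization" in str(w.message) for w in caught)
```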
21 changes: 19 additions & 2 deletions src/art/loss.py
@@ -16,6 +16,9 @@ class Loss(BaseModel):
mean_kl: torch.Tensor
mean_entropy: torch.Tensor | None
probs_corr: torch.Tensor
frac_old_logprobs_valid: float
mean_importance_ratio: torch.Tensor
clip_fraction: torch.Tensor


def loss_fn(
@@ -32,6 +35,9 @@ def loss_fn(
)
weights = shift_tensor(inputs["weights"], 0.0)
old_logprobs_mask = ~torch.isnan(old_logprobs)
frac_old_logprobs_valid = (
old_logprobs_mask.float().sum() / (old_logprobs.numel() + 1e-6)
).item()
probs_corr = torch.corrcoef(
torch.stack(
[
@@ -77,15 +83,23 @@
)
if tau := experimental_config.get("kimi_k2_tau", None):
advantages -= tau * logprob_diff.detach()
clipped_ratio = torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high)
is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (
assistant_mask.sum() + 1e-6
)
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (
assistant_mask.sum() + 1e-6
)
if experimental_config.get("ppo", True):
policy_loss = -torch.min(
prob_ratio * advantages,
torch.clip(prob_ratio, 1 - epsilon, 1 + epsilon_high) * advantages,
clipped_ratio * advantages,
)
else:
# Modified REINFORCE or Clipped IS-weight Policy Optimization (CISPO)
policy_loss = -(
torch.clip(prob_ratio.detach(), 1 - epsilon, 1 + epsilon_high)
clipped_ratio.detach()
* advantages
* new_logprobs
)
@@ -123,6 +137,9 @@ def loss_fn(
mean_kl=mean_kl,
mean_entropy=mean_entropy,
probs_corr=probs_corr,
frac_old_logprobs_valid=frac_old_logprobs_valid,
mean_importance_ratio=mean_importance_ratio,
clip_fraction=clip_fraction,
)


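A rough sketch of the three new diagnostics on toy tensors (shapes, clip bounds, and the NaN handling here are illustrative, not the real `loss_fn` inputs): the fraction of non-NaN `old_logprobs`, the assistant-masked mean importance ratio, and the fraction of assistant tokens whose ratio falls outside `[1 - epsilon, 1 + epsilon_high]`.

```python
import torch

epsilon, epsilon_high = 0.2, 0.2  # illustrative clip bounds, not library defaults

# Toy per-token values standing in for loss_fn's shifted tensors.
old_logprobs = torch.tensor([-1.2, float("nan"), -0.7, -2.0])
new_logprobs = torch.tensor([-1.0, -0.9, -0.8, -1.5])
assistant_mask = torch.tensor([1.0, 1.0, 1.0, 0.0])

old_logprobs_mask = ~torch.isnan(old_logprobs)
frac_old_logprobs_valid = (
    old_logprobs_mask.float().sum() / (old_logprobs.numel() + 1e-6)
).item()

# For this illustration only: treat NaN old logprobs as equal to the new ones
# (ratio 1.0) so they don't poison the ratio; the real loss_fn masks them.
safe_old = torch.where(old_logprobs_mask, old_logprobs, new_logprobs)
prob_ratio = torch.exp(new_logprobs - safe_old)

is_clipped = (prob_ratio < 1 - epsilon) | (prob_ratio > 1 + epsilon_high)
clip_fraction = (is_clipped.float() * assistant_mask).sum() / (
    assistant_mask.sum() + 1e-6
)
mean_importance_ratio = (prob_ratio * assistant_mask).sum() / (
    assistant_mask.sum() + 1e-6
)

# Expected here: ~0.75 valid old logprobs, 1/3 of assistant tokens clipped.
print(frac_old_logprobs_valid, clip_fraction.item(), mean_importance_ratio.item())
```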
26 changes: 24 additions & 2 deletions src/art/model.py
@@ -1,8 +1,17 @@
from typing import TYPE_CHECKING, Generic, Iterable, Optional, TypeVar, cast, overload
from typing import (
TYPE_CHECKING,
Any,
Generic,
Iterable,
Optional,
TypeVar,
cast,
overload,
)

import httpx
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from typing_extensions import Never

from . import dev
@@ -279,6 +288,19 @@ def __init__(
# Bypass BaseModel __setattr__ to allow setting private attr
object.__setattr__(self, "_internal_config", _internal_config)

@model_validator(mode="wrap")
@classmethod
def _preserve_internal_config(
cls, data: Any, handler: Any
) -> "TrainableModel[ModelConfig]":
internal_config = None
if isinstance(data, dict) and "_internal_config" in data:
internal_config = data.pop("_internal_config")
model = handler(data)
if internal_config is not None:
object.__setattr__(model, "_internal_config", internal_config)
return model

@overload
def __new__(
cls,
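A minimal sketch of the pydantic v2 `mode="wrap"` pattern used above, in isolation: pop the private value out of the incoming dict, let the inner handler build the model, then re-attach it with `object.__setattr__` (the same mechanism `TrainableModel.__init__` already uses). The `Example` model and `_internal` name are illustrative, not part of the library.

```python
from typing import Any

from pydantic import BaseModel, PrivateAttr, model_validator


class Example(BaseModel):
    name: str
    _internal: Any = PrivateAttr(default=None)

    @model_validator(mode="wrap")
    @classmethod
    def _preserve_internal(cls, data: Any, handler: Any) -> "Example":
        # Pull the private value out before normal validation runs, since
        # pydantic would otherwise drop keys it doesn't recognize as fields.
        internal = None
        if isinstance(data, dict) and "_internal" in data:
            internal = data.pop("_internal")
        model = handler(data)
        if internal is not None:
            object.__setattr__(model, "_internal", internal)
        return model


original = Example(name="agent-001")
object.__setattr__(original, "_internal", {"engine_args": {"max_model_len": 8192}})

# Round-tripping through a dict (model_validate-style flows) keeps the
# private value attached instead of silently dropping it.
restored = Example.model_validate(
    {**original.model_dump(), "_internal": original._internal}
)
assert restored._internal == {"engine_args": {"max_model_len": 8192}}
```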
86 changes: 62 additions & 24 deletions src/art/preprocessing/tokenize.py
@@ -159,23 +164,28 @@ def tokenize_trajectory(
if history.tools is not None
else None
)
chat = cast(
str,
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
tokenize=False,
),
)
original_token_ids = cast(
list[int],
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
),
)
try:
chat = cast(
str,
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
tokenize=False,
),
)
original_token_ids = cast(
list[int],
tokenizer.apply_chat_template(
cast(list[dict], messages),
tools=tools, # type: ignore
continue_final_message=True,
),
)
except ValueError as e:
if "continue_final_message" in str(e):
return None
raise
sentinal_token_id = max(
set(range(cast(int, tokenizer.vocab_size))) - set(original_token_ids)
)
@@ -216,13 +221,46 @@
if isinstance(message, dict):
content = message.get("content")
assert isinstance(content, str)
content_token_ids = tokenizer.encode(
content,
add_special_tokens=False,
)
token_ids[start:end] = content_token_ids
logprobs[start:end] = [float("nan")] * len(content_token_ids)
assistant_mask[start:end] = [1] * len(content_token_ids)
msg_token_ids = message.get("token_ids")
dict_logprobs = message.get("logprobs")
if (
msg_token_ids is not None
and dict_logprobs
and "values" in dict_logprobs
):
token_ids[start:end] = msg_token_ids
logprobs[start:end] = dict_logprobs["values"]
assistant_mask[start:end] = [1] * len(msg_token_ids)
elif (
dict_logprobs
and "content" in dict_logprobs
and dict_logprobs["content"]
):
token_logprobs = dict_logprobs["content"]
try:
token_ids[start:end] = [
int(lp["token"].split(":")[1]) for lp in token_logprobs
]
except (IndexError, ValueError, KeyError):
token_ids[start:end] = [
token_id if token_id is not None else tokenizer.eos_token_id
for token_id in tokenizer.convert_tokens_to_ids(
[
lp.get("token") or tokenizer.eos_token
for lp in token_logprobs
]
)
]
logprobs[start:end] = [lp["logprob"] for lp in token_logprobs]
assistant_mask[start:end] = [1] * len(token_logprobs)
else:
content_token_ids = tokenizer.encode(
content,
add_special_tokens=False,
)
token_ids[start:end] = content_token_ids
logprobs[start:end] = [float("nan")] * len(content_token_ids)
assistant_mask[start:end] = [1] * len(content_token_ids)
else:
choice = message
assert choice.logprobs or allow_training_without_logprobs, (
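A hedged illustration of the two new parsing paths for logprob entries: first try to read an integer ID out of the token string (assuming a `"token_id:<int>"`-style encoding such as vLLM's `return_tokens_as_token_ids` output, which is not guaranteed by every server), then fall back to `convert_tokens_to_ids`, substituting `eos_token_id` for anything unresolved. The checkpoint name is only a placeholder.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Entries shaped like OpenAI-style `logprobs.content` items; the
# "token_id:<int>" token format is an assumption, not universal.
token_logprobs = [
    {"token": "token_id:9707", "logprob": -0.01},
    {"token": "token_id:0", "logprob": -1.20},
]

try:
    # Primary path: the token string encodes the ID directly.
    token_ids = [int(lp["token"].split(":")[1]) for lp in token_logprobs]
except (IndexError, ValueError, KeyError):
    # Fallback: map raw token strings through the tokenizer, using eos for
    # anything that can't be resolved.
    token_ids = [
        token_id if token_id is not None else tokenizer.eos_token_id
        for token_id in tokenizer.convert_tokens_to_ids(
            [lp.get("token") or tokenizer.eos_token for lp in token_logprobs]
        )
    ]
logprobs = [lp["logprob"] for lp in token_logprobs]
print(token_ids, logprobs)
```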
4 changes: 3 additions & 1 deletion src/art/rewards/ruler.py
@@ -20,6 +20,7 @@
from rich import print

import art
from art.utils.strip_logprobs import strip_logprobs


class TrajectoryScore(BaseModel):
@@ -287,9 +288,10 @@ async def ruler_score_group(
new_trajectories.append(new_traj)

# Extract message lists and preserve original rewards for comparison
# Strip logprobs to avoid sending huge token probability data to the judge
message_lists: list[list[ChatCompletionMessageParam]] = []
for traj in new_trajectories:
message_lists.append(traj.messages())
message_lists.append(strip_logprobs(traj.messages()))
traj.metrics["independent_reward"] = traj.reward

try:
Expand Down
5 changes: 4 additions & 1 deletion src/art/unsloth/train.py
@@ -167,7 +167,10 @@ def compute_loss(
trainer._metrics["train"]["learning_rate"].append(config.learning_rate)
trainer._metrics["train"]["policy_loss"].append(loss.mean_policy_loss.item())
if loss.mean_entropy is not None:
trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item()) # type: ignore
trainer._metrics["train"]["entropy"].append(loss.mean_entropy.item())
trainer._metrics["train"]["frac_old_logprobs_valid"].append(loss.frac_old_logprobs_valid)
trainer._metrics["train"]["mean_importance_ratio"].append(loss.mean_importance_ratio.item())
trainer._metrics["train"]["clip_fraction"].append(loss.clip_fraction.item())
if config.beta > 0.0:
trainer._metrics["train"]["kl_div"].append(loss.mean_kl.item())
return loss.mean_policy_loss + config.beta * loss.mean_kl # type: ignore