diff --git a/src/opengradient/client/llm.py b/src/opengradient/client/llm.py
index 20e32fb..20a570e 100644
--- a/src/opengradient/client/llm.py
+++ b/src/opengradient/client/llm.py
@@ -53,6 +53,15 @@ class _ChatParams:
     x402_settlement_mode: x402SettlementMode
 
 
+def _model_id(model: str) -> str:
+    """Return the model ID portion of a ``provider/model`` identifier.
+
+    Everything after the first ``/`` is kept, so IDs that themselves contain
+    a slash are not truncated. A value with no ``/`` is returned unchanged.
+    """
+    return model.split("/", 1)[1] if "/" in model else model
+
+
 class LLM:
     """
     LLM inference namespace.
@@ -277,7 +286,7 @@ async def completion(
         Raises:
             RuntimeError: If the inference fails.
         """
-        model_id = model.split("/")[1]
+        model_id = _model_id(model)
         payload: Dict = {
             "model": model_id,
             "prompt": prompt,
@@ -360,7 +369,7 @@ async def chat(
             RuntimeError: If the inference fails.
         """
         if response_format is not None and response_format.type == "json_object":
-            provider = model.split("/")[0]
+            provider = model.split("/")[0] if "/" in model else None
             if provider == "anthropic":
                 raise ValueError(
                     "Anthropic models do not support response_format type 'json_object'. "
@@ -368,7 +377,7 @@ async def chat(
                 )
 
         params = _ChatParams(
-            model=model.split("/")[1],
+            model=_model_id(model),
             max_tokens=max_tokens,
             temperature=temperature,
             stop_sequence=stop_sequence,