Skip to content

Commit

Permalink
fix: update SK model adapter constructor (#5150)
Browse files Browse the repository at this point in the history
* update constructor

* fix typing error

* revert/fix doc changes

* add unsaved changes

---------

Co-authored-by: Leonardo Pinheiro <[email protected]>
  • Loading branch information
lspinheiro and lpinheiroms authored Jan 23, 2025
1 parent 5e9b24c commit 3fe1066
Showing 1 changed file with 35 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ class SKChatCompletionAdapter(ChatCompletionClient):
Args:
sk_client (ChatCompletionClientBase):
The Semantic Kernel client to wrap (e.g., AzureChatCompletion, GoogleAIChatCompletion, OllamaChatCompletion).
kernel (Optional[Kernel]):
The Semantic Kernel instance to use for executing requests. If not provided, one must be passed
in the extra_create_args for each request.
prompt_settings (Optional[PromptExecutionSettings]):
Default prompt execution settings to use. Can be overridden per request.
model_info (Optional[ModelInfo]):
Information about the model's capabilities.
service_id (Optional[str]):
Optional service identifier.
Example usage:
Expand Down Expand Up @@ -100,8 +109,8 @@ async def main():
api_key = "<AZURE_OPENAI_API_KEY>"
azure_client = AzureChatCompletion(deployment_name=deployment_name, endpoint=endpoint, api_key=api_key)
azure_request_settings = AzureChatPromptExecutionSettings(temperature=0.8)
azure_adapter = SKChatCompletionAdapter(sk_client=azure_client, default_prompt_settings=azure_request_settings)
azure_settings = AzureChatPromptExecutionSettings(temperature=0.8)
azure_adapter = SKChatCompletionAdapter(sk_client=azure_client, kernel=kernel, prompt_settings=azure_settings)
# ----------------------------------------------------------------
# Example B: Google Gemini
Expand All @@ -127,7 +136,7 @@ async def main():
"temperature": 0.8,
},
)
ollama_adapter = SKChatCompletionAdapter(sk_client=ollama_client, default_prompt_settings=request_settings)
ollama_adapter = SKChatCompletionAdapter(sk_client=ollama_client, prompt_settings=request_settings)
# 3) Create a tool and register it with the kernel
calc_tool = CalculatorTool()
Expand All @@ -143,23 +152,20 @@ async def main():
azure_result = await azure_adapter.create(
messages=messages,
tools=[calc_tool],
extra_create_args={"kernel": kernel, "prompt_execution_settings": azure_request_settings},
)
print("Azure result:", azure_result.content)
# Google example
google_result = await google_adapter.create(
messages=messages,
tools=[calc_tool],
extra_create_args={"kernel": kernel},
)
print("Google result:", google_result.content)
# Ollama example
ollama_result = await ollama_adapter.create(
messages=messages,
tools=[calc_tool],
extra_create_args={"kernel": kernel, "prompt_execution_settings": request_settings},
)
print("Ollama result:", ollama_result.content)
Expand All @@ -171,12 +177,14 @@ async def main():
def __init__(
self,
sk_client: ChatCompletionClientBase,
kernel: Optional[Kernel] = None,
prompt_settings: Optional[PromptExecutionSettings] = None,
model_info: Optional[ModelInfo] = None,
service_id: Optional[str] = None,
default_prompt_settings: Optional[PromptExecutionSettings] = None,
):
self._service_id = service_id
self._default_prompt_settings = default_prompt_settings
self._kernel = kernel
self._prompt_settings = prompt_settings
self._sk_client = sk_client
self._model_info = model_info or ModelInfo(
vision=False, function_calling=False, json_output=False, family=ModelFamily.UNKNOWN
Expand Down Expand Up @@ -287,6 +295,17 @@ def _process_tool_calls(self, result: ChatMessageContent) -> list[FunctionCall]:
function_calls.append(FunctionCall(id=item.id, name=full_name, arguments=arguments))
return function_calls

def _get_kernel(self, extra_create_args: Mapping[str, Any]) -> Kernel:
    """Resolve the kernel to use for a single request.

    Looks up ``"kernel"`` in *extra_create_args*, falling back to the
    kernel supplied to the constructor when the key is absent.

    Raises:
        ValueError: If no kernel is available from either source, or the
            resolved value is not a ``semantic_kernel.kernel.Kernel``.
    """
    candidate = extra_create_args.get("kernel", self._kernel)
    if candidate:
        if isinstance(candidate, Kernel):
            return candidate
        raise ValueError("kernel must be an instance of semantic_kernel.kernel.Kernel")
    raise ValueError("kernel must be provided either in constructor or extra_create_args")

def _get_prompt_settings(self, extra_create_args: Mapping[str, Any]) -> Optional[PromptExecutionSettings]:
    """Resolve prompt execution settings for a request.

    Per-request ``"prompt_execution_settings"`` from *extra_create_args*
    wins when truthy; otherwise the constructor-supplied default (which
    may be ``None``) is returned.
    """
    per_request = extra_create_args.get("prompt_execution_settings")
    if per_request:
        return per_request
    return self._prompt_settings

async def create(
self,
messages: Sequence[LLMMessage],
Expand All @@ -300,9 +319,9 @@ async def create(
The `extra_create_args` dictionary can include two special keys:
1) `"kernel"` (required):
1) `"kernel"` (optional):
An instance of :class:`semantic_kernel.Kernel` used to execute the request.
If not provided, a ValueError is raised.
If not provided either in constructor or extra_create_args, a ValueError is raised.
2) `"prompt_execution_settings"` (optional):
An instance of a :class:`PromptExecutionSettings` subclass corresponding to the
Expand All @@ -320,19 +339,9 @@ async def create(
Returns:
CreateResult: The result of the chat completion.
"""
if "kernel" not in extra_create_args:
raise ValueError("kernel is required in extra_create_args")

kernel = extra_create_args["kernel"]
if not isinstance(kernel, Kernel):
raise ValueError("kernel must be an instance of semantic_kernel.kernel.Kernel")

kernel = self._get_kernel(extra_create_args)
chat_history = self._convert_to_chat_history(messages)

# Build execution settings from extra args and tools
user_settings = extra_create_args.get("prompt_execution_settings", None)
if user_settings is None:
user_settings = self._default_prompt_settings
user_settings = self._get_prompt_settings(extra_create_args)
settings = self._build_execution_settings(user_settings, tools)

# Sync tools with kernel
Expand Down Expand Up @@ -380,9 +389,9 @@ async def create_stream(
The `extra_create_args` dictionary can include two special keys:
1) `"kernel"` (required):
1) `"kernel"` (optional):
An instance of :class:`semantic_kernel.Kernel` used to execute the request.
If not provided, a ValueError is raised.
If not provided either in constructor or extra_create_args, a ValueError is raised.
2) `"prompt_execution_settings"` (optional):
An instance of a :class:`PromptExecutionSettings` subclass corresponding to the
Expand All @@ -400,17 +409,9 @@ async def create_stream(
Yields:
Union[str, CreateResult]: Either a string chunk of the response or a CreateResult containing function calls.
"""
if "kernel" not in extra_create_args:
raise ValueError("kernel is required in extra_create_args")

kernel = extra_create_args["kernel"]
if not isinstance(kernel, Kernel):
raise ValueError("kernel must be an instance of semantic_kernel.kernel.Kernel")

kernel = self._get_kernel(extra_create_args)
chat_history = self._convert_to_chat_history(messages)
user_settings = extra_create_args.get("prompt_execution_settings", None)
if user_settings is None:
user_settings = self._default_prompt_settings
user_settings = self._get_prompt_settings(extra_create_args)
settings = self._build_execution_settings(user_settings, tools)
self._sync_tools_with_kernel(kernel, tools)

Expand Down

0 comments on commit 3fe1066

Please sign in to comment.