Skip to content

Commit

Permalink
Updates per pr feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben Thomas committed Feb 11, 2025
1 parent 23fce9e commit 4fdd301
Show file tree
Hide file tree
Showing 7 changed files with 44 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import aiohttp

from semantic_kernel.connectors.memory.astradb.utils import AsyncSession
from semantic_kernel.exceptions import ServiceResponseException
from semantic_kernel.utils.experimental_decorator import experimental_class
from semantic_kernel.utils.http.utils import AsyncSession
from semantic_kernel.utils.telemetry.user_agent import APP_INFO

ASTRA_CALLER_IDENTITY: str
Expand Down
17 changes: 0 additions & 17 deletions python/semantic_kernel/connectors/memory/astradb/utils.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,11 @@
# Copyright (c) Microsoft. All rights reserved.
from typing import Any

import aiohttp
import numpy

from semantic_kernel.memory.memory_record import MemoryRecord


class AsyncSession:
"""A wrapper around aiohttp.ClientSession that can be used as an async context manager."""

def __init__(self, session: aiohttp.ClientSession = None):
"""Initializes a new instance of the AsyncSession class."""
self._session = session if session else aiohttp.ClientSession()

async def __aenter__(self):
"""Enter the session."""
return await self._session.__aenter__()

async def __aexit__(self, *args, **kwargs):
"""Close the session."""
await self._session.close()


def build_payload(record: MemoryRecord) -> dict[str, Any]:
"""Builds a metadata payload to be sent to AstraDb from a MemoryRecord."""
payload: dict[str, Any] = {}
Expand Down
26 changes: 15 additions & 11 deletions python/semantic_kernel/core_plugins/crew_ai/crew_ai_enterprise.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,31 +135,35 @@ async def wait_for_crew_completion(self, kickoff_id: str) -> str:
FunctionExecutionException: If the task fails or an error occurs while waiting for completion.
"""
try:
status_response = None
status = CrewAIEnterpriseKickoffState.Pending
status_response: CrewAIStatusResponse | None = None
state: str = CrewAIEnterpriseKickoffState.Pending

async def poll_status():
nonlocal status, status_response
while status not in [
nonlocal state, status_response
while state not in [
CrewAIEnterpriseKickoffState.Failed,
CrewAIEnterpriseKickoffState.Failure,
CrewAIEnterpriseKickoffState.Success,
CrewAIEnterpriseKickoffState.Not_Found,
]:
logger.debug(
f"Waiting for CrewAI Crew with kickoff Id: {kickoff_id} to complete. Current state: {status}"
f"Waiting for CrewAI Crew with kickoff Id: {kickoff_id} to complete. Current state: {state}"
)

await asyncio.sleep(self.polling_interval)
status_response = await self.client.get_status(kickoff_id)
status = status_response.state
state = status_response.state

await asyncio.wait_for(poll_status(), timeout=self.polling_timeout)

logger.info(f"CrewAI Crew with kickoff Id: {kickoff_id} completed with status: {status_response.state}")
if status in ["Failed", "Failure"]:
raise FunctionResultError(f"CrewAI Crew failed with error: {status_response.result}")
return status_response.result or ""
logger.info(f"CrewAI Crew with kickoff Id: {kickoff_id} completed with status: {state}")
result = (
status_response.result if status_response is not None and status_response.result is not None else ""
)

if state in ["Failed", "Failure"]:
raise FunctionResultError(f"CrewAI Crew failed with error: {result}")
return result
except Exception as ex:
raise FunctionExecutionException(
f"Failed to wait for completion of CrewAI Crew with kickoff Id: {kickoff_id}."
Expand All @@ -173,7 +177,7 @@ def create_kernel_plugin(
task_webhook_url: str | None = None,
step_webhook_url: str | None = None,
crew_webhook_url: str | None = None,
) -> dict[str, Any]:
) -> KernelPlugin:
"""Creates a kernel plugin that can be used to invoke the CrewAI Crew.
Args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,18 @@

import aiohttp

from semantic_kernel.connectors.memory.astradb.utils import AsyncSession
from semantic_kernel.core_plugins.crew_ai.crew_ai_models import (
CrewAIKickoffResponse,
CrewAIRequiredInputs,
CrewAIStatusResponse,
)
from semantic_kernel.kernel_pydantic import KernelBaseModel
from semantic_kernel.utils.http.utils import AsyncSession
from semantic_kernel.utils.telemetry.user_agent import SEMANTIC_KERNEL_USER_AGENT


class CrewAIEnterpriseClient(KernelBaseModel):
class CrewAIEnterpriseClient:
"""Client to interact with the Crew AI Enterprise API."""

endpoint: str
auth_token: str
request_header: dict[str, str]
session: aiohttp.ClientSession | None

def __init__(
self,
endpoint: str,
Expand All @@ -35,23 +29,23 @@ def __init__(
auth_token (str): The authentication token.
session (Optional[aiohttp.ClientSession], optional): The HTTP client session. Defaults to None.
"""
request_header = {
self.endpoint = endpoint
self.auth_token = auth_token
self._session = session
self.request_header = {
"Authorization": f"Bearer {auth_token}",
"Content-Type": "application/json",
"user_agent": SEMANTIC_KERNEL_USER_AGENT,
}

session = session
super().__init__(endpoint=endpoint, auth_token=auth_token, request_header=request_header, session=session)

async def get_inputs(self) -> CrewAIRequiredInputs:
"""Get the required inputs for Crew AI.
Returns:
CrewAIRequiredInputs: The required inputs for Crew AI.
"""
async with (
AsyncSession(self.session) as session,
AsyncSession(self._session) as session,
session.get(f"{self.endpoint}/inputs", headers=self.request_header) as response,
):
response.raise_for_status()
Expand Down Expand Up @@ -82,7 +76,7 @@ async def kickoff(
"crewWebhookUrl": crew_webhook_url,
}
async with (
AsyncSession(self.session) as session,
AsyncSession(self._session) as session,
session.post(f"{self.endpoint}/kickoff", json=content, headers=self.request_header) as response,
):
response.raise_for_status()
Expand All @@ -99,7 +93,7 @@ async def get_status(self, task_id: str) -> CrewAIStatusResponse:
CrewAIStatusResponse: The status response of the task.
"""
async with (
AsyncSession(self.session) as session,
AsyncSession(self._session) as session,
session.get(f"{self.endpoint}/status/{task_id}", headers=self.request_header) as response,
):
response.raise_for_status()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ class CrewAISettings(KernelBaseSettings):
env_prefix: ClassVar[str] = "CREW_AI_"

endpoint: str
auth_token: SecretStr | None = None
auth_token: SecretStr
polling_interval: float = 1.0
polling_timeout: float = 30.0
Empty file.
18 changes: 18 additions & 0 deletions python/semantic_kernel/utils/http/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) Microsoft. All rights reserved.
import aiohttp


class AsyncSession:
"""A wrapper around aiohttp.ClientSession that can be used as an async context manager."""

def __init__(self, session: aiohttp.ClientSession = None):
"""Initializes a new instance of the AsyncSession class."""
self._session = session if session else aiohttp.ClientSession()

async def __aenter__(self):
"""Enter the session."""
return await self._session.__aenter__()

async def __aexit__(self, *args, **kwargs):
"""Close the session."""
await self._session.close()

0 comments on commit 4fdd301

Please sign in to comment.