Skip to content

Commit

Permalink
Fix code interpreter
Browse files Browse the repository at this point in the history
  • Loading branch information
vegito22 committed Nov 8, 2024
1 parent 5d278dc commit 9575417
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 28 deletions.
2 changes: 1 addition & 1 deletion llmstack/apps/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,7 @@ async def get_app_runner_async(self, session_id, source, request_user, input_dat
input_schema = json.loads(processor_cls.get_input_schema())
input_fields = []
for property in input_schema["properties"]:
input_fields.append({"name": property, "type": input_schema["properties"][property]["type"]})
input_fields.append({"name": property, "type": input_schema["properties"][property].get("type", "string")})

app_data = {
"name": f"Processor {provider_slug}_{processor_slug}",
Expand Down
2 changes: 1 addition & 1 deletion llmstack/play/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _done_callback(future):
except CancelledError:
logger.info("Task cancelled")
except Exception as e:
logger.error(f"Task in loop {name} failed with error: {e}")
logger.exception(f"Task in loop {name} failed with error: {e}")
finally:
# Find and cancel all pending tasks before stopping the loop
for task in asyncio.all_tasks(loop):
Expand Down
68 changes: 42 additions & 26 deletions llmstack/processors/providers/promptly/code_interpreter.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import base64
import logging
import uuid
from typing import Dict, List, Optional
from typing import List, Optional

import grpc
from asgiref.sync import async_to_sync
from django.conf import settings
from langrocks.client.code_runner import (
CodeRunner,
CodeRunnerSession,
CodeRunnerState,
Content,
ContentMimeType,
)
from pydantic import Field

from llmstack.apps.schemas import OutputTemplate
Expand All @@ -14,7 +20,10 @@
ApiProcessorInterface,
ApiProcessorSchema,
)
from llmstack.processors.providers.promptly import Content, ContentMimeType
from llmstack.processors.providers.promptly import Content as PromptlyContent
from llmstack.processors.providers.promptly import (
ContentMimeType as PromptlyContentMimeType,
)

logger = logging.getLogger(__name__)

Expand All @@ -25,15 +34,21 @@ class CodeInterpreterLanguage(StrEnum):

class CodeInterpreterInput(ApiProcessorSchema):
code: str = Field(description="The code to run", json_schema_extra={"widget": "textarea"}, default="")
files: Optional[str] = Field(
description="Workspace files as a comma separated list",
default=None,
json_schema_extra={
"widget": "file",
},
)
language: CodeInterpreterLanguage = Field(
title="Language", description="The language of the code", default=CodeInterpreterLanguage.PYTHON
)


class CodeInterpreterOutput(ApiProcessorSchema):
stdout: List[Content] = Field(default=[], description="Standard output as a list of Content objects")
stdout: List[PromptlyContent] = Field(default=[], description="Standard output as a list of Content objects")
stderr: str = Field(default="", description="Standard error")
local_variables: Optional[Dict] = Field(description="Local variables as a JSON object")
exit_code: int = Field(default=0, description="Exit code of the process")


Expand Down Expand Up @@ -106,33 +121,34 @@ def convert_stdout_to_content(self, stdout) -> List[Content]:
return content

def process_session_data(self, session_data):
self._kernel_session_id = session_data.get("kernel_session_id", None)
self._interpreter_session_id = session_data.get("interpreter_session_id", str(uuid.uuid4()))
self._interpreter_session_data = session_data.get("interpreter_session_data", "")

def session_data_to_persist(self) -> dict:
return {
"kernel_session_id": self._kernel_session_id,
"interpreter_session_id": self._interpreter_session_id,
"interpreter_session_data": self._interpreter_session_data,
}

def process(self) -> dict:
from langrocks.common.models import runner_pb2, runner_pb2_grpc

kernel_session_id = self._kernel_session_id if self._kernel_session_id else str(uuid.uuid4())

channel = grpc.insecure_channel(f"{settings.RUNNER_HOST}:{settings.RUNNER_PORT}")
stub = runner_pb2_grpc.RunnerStub(channel)

request = runner_pb2.CodeRunnerRequest(source_code=self._input.code, timeout_secs=self._config.timeout)
response_iter = stub.GetCodeRunner(
iter([request]),
metadata=(("kernel_session_id", kernel_session_id),),
)
for response in response_iter:
if response.stdout:
stdout_result = self.convert_stdout_to_content(response.stdout)
async_to_sync(self._output_stream.write)(CodeInterpreterOutput(stdout=stdout_result))

if response.stderr:
async_to_sync(self._output_stream.write)(CodeInterpreterOutput(stderr=response.stderr))
with CodeRunner(
base_url=f"{settings.RUNNER_HOST}:{settings.RUNNER_PORT}",
session=CodeRunnerSession(
session_id=self._interpreter_session_id, session_data=self._interpreter_session_data
),
) as code_runner:
current_state = code_runner.get_state()
if current_state == CodeRunnerState.CODE_RUNNING:
respose_iter = code_runner.run_code(source_code=self._input.code)
for response in respose_iter:
async_to_sync(self._output_stream.write)(
CodeInterpreterOutput(
stdout=[PromptlyContent(mime_type=PromptlyContentMimeType.TEXT, data=response.decode())]
)
)

session_data = code_runner.get_session()
self._interpreter_session_data = session_data.session_data

output = self._output_stream.finalize()
return output

0 comments on commit 9575417

Please sign in to comment.