Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CodeNode #992

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 98 additions & 1 deletion apps/pipelines/nodes/nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import datetime
import json
import time
from typing import Literal

import tiktoken
Expand All @@ -14,13 +16,14 @@
from pydantic.config import ConfigDict
from pydantic_core import PydanticCustomError
from pydantic_core.core_schema import FieldValidationInfo
from RestrictedPython import compile_restricted_function, safe_builtins, safe_globals

from apps.assistants.models import OpenAiAssistant
from apps.channels.datamodels import Attachment
from apps.chat.conversation import compress_chat_history, compress_pipeline_chat_history
from apps.chat.models import ChatMessageType
from apps.experiments.models import ExperimentSession, ParticipantData
from apps.pipelines.exceptions import PipelineNodeBuildError
from apps.pipelines.exceptions import PipelineNodeBuildError, PipelineNodeRunError
from apps.pipelines.models import PipelineChatHistory, PipelineChatHistoryTypes
from apps.pipelines.nodes.base import NodeSchema, OptionsSource, PipelineNode, PipelineState, UiSchema, Widgets
from apps.pipelines.tasks import send_email_from_pipeline
Expand Down Expand Up @@ -622,3 +625,97 @@ def _get_assistant_runnable(self, assistant: OpenAiAssistant, session: Experimen
return AgentAssistantChat(adapter=adapter, history_manager=history_manager)
else:
return AssistantChat(adapter=adapter, history_manager=history_manager)


class CodeNode(PipelineNode):
"""Runs python"""

model_config = ConfigDict(json_schema_extra=NodeSchema(label="Python Node"))
code: str = Field(
description="The code to run",
json_schema_extra=UiSchema(widget=Widgets.expandable_text), # TODO: add a code widget
)

@field_validator("code")
def validate_code(cls, value, info: FieldValidationInfo):
if not value:
value = "return input"

byte_code = compile_restricted_function(
"input,shared_state",
value,
name="main",
filename="<inline code>",
)

if byte_code.errors:
raise PydanticCustomError("invalid_code", "{errors}", {"errors": "\n".join(byte_code.errors)})
return value

def _process(self, input: str, state: PipelineState, node_id: str) -> PipelineState:
function_name = "main"
function_args = "input"
byte_code = compile_restricted_function(
function_args,
self.code,
name=function_name,
filename="<inline code>",
)

custom_locals = {}
custom_globals = self._get_custom_globals()
exec(byte_code.code, custom_globals, custom_locals)

try:
result = str(custom_locals[function_name](input))
except Exception as exc:
raise PipelineNodeRunError(exc) from exc
return PipelineState.from_node_output(node_id=node_id, output=result)

def _get_custom_globals(self):
from RestrictedPython.Eval import (
default_guarded_getitem,
default_guarded_getiter,
)

custom_globals = safe_globals.copy()
custom_globals.update(
{
"__builtins__": self._get_custom_builtins(),
"json": json,
"datetime": datetime,
"time": time,
"_getitem_": default_guarded_getitem,
"_getiter_": default_guarded_getiter,
"_write_": lambda x: x,
}
)
return custom_globals

def _get_custom_builtins(self):
allowed_modules = {
"json",
"re",
"datetime",
"time",
}
custom_builtins = safe_builtins.copy()
custom_builtins.update(
{
"min": min,
"max": max,
"sum": sum,
"abs": abs,
"all": all,
"any": any,
"datetime": datetime,
}
)

def guarded_import(name, *args, **kwargs):
if name not in allowed_modules:
raise ImportError(f"Importing '{name}' is not allowed")
return __import__(name, *args, **kwargs)

custom_builtins["__import__"] = guarded_import
return custom_builtins
88 changes: 88 additions & 0 deletions apps/pipelines/tests/test_code_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import json
from unittest import mock

import pytest

from apps.pipelines.exceptions import PipelineNodeBuildError, PipelineNodeRunError
from apps.pipelines.nodes.base import PipelineState
from apps.pipelines.tests.utils import (
code_node,
create_runnable,
end_node,
start_node,
)
from apps.utils.factories.pipelines import PipelineFactory
from apps.utils.pytest import django_db_with_data


@pytest.fixture()
def pipeline():
return PipelineFactory()


EXTRA_FUNCTION = """
def other(foo):
return f"other {foo}"

return other(input)
"""

IMPORTS = """
import json
import datetime
import re
import time
return json.loads(input)
"""


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
@pytest.mark.parametrize(
("code", "input", "output"),
[
("return f'Hello, {input}!'", "World", "Hello, World!"),
("", "foo", "foo"), # No code just returns the input
(EXTRA_FUNCTION, "blah", "other blah"), # Calling a separate function is possible
("'foo'", "", "None"), # No return value will return "None"
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we should instead raise an exception in this case, as it was probably not intentional to not return a string.

(IMPORTS, json.dumps({"a": "b"}), str(json.loads('{"a": "b"}'))), # Importing json will work
],
)
def test_code_node(pipeline, code, input, output):
nodes = [
start_node(),
code_node(code),
end_node(),
]
assert create_runnable(pipeline, nodes).invoke(PipelineState(messages=[input]))["messages"][-1] == output


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
def test_code_node_syntax_error(pipeline):
nodes = [
start_node(),
code_node("this{}"),
end_node(),
]
with pytest.raises(PipelineNodeBuildError, match="SyntaxError: invalid syntax at statement: 'this{}'"):
create_runnable(pipeline, nodes).invoke(PipelineState(messages=["World"]))["messages"][-1]


@django_db_with_data(available_apps=("apps.service_providers",))
@mock.patch("apps.pipelines.nodes.base.PipelineNode.logger", mock.Mock())
@pytest.mark.parametrize(
("code", "input", "error"),
[
("import collections", "", "Importing 'collections' is not allowed"),
("return f'Hello, {blah}!'", "", "name 'blah' is not defined"),
],
)
def test_code_node_runtime_errors(pipeline, code, input, error):
nodes = [
start_node(),
code_node(code),
end_node(),
]
with pytest.raises(PipelineNodeRunError, match=error):
create_runnable(pipeline, nodes).invoke(PipelineState(messages=[input]))["messages"][-1]
12 changes: 12 additions & 0 deletions apps/pipelines/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,15 @@ def extract_structured_data_node(provider_id: str, provider_model_id: str, data_
"data_schema": data_schema,
},
}


def code_node(code: str | None = None):
if code is None:
code = "return f'Hello, {input}!'"
return {
"id": str(uuid4()),
"type": nodes.CodeNode.__name__,
"params": {
"code": code,
},
}
1 change: 1 addition & 0 deletions requirements/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ psycopg[binary]
pyTelegramBotAPI==4.12.0
pydantic
pydub # Audio transcription
RestrictedPython
sentry-sdk
slack-bolt
taskbadger
Expand Down
2 changes: 2 additions & 0 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,8 @@ requests-oauthlib==1.3.1
# via django-allauth
requests-toolbelt==1.0.0
# via langsmith
restrictedpython==7.4
# via -r requirements.in
rich==13.6.0
# via typer
rpds-py==0.12.0
Expand Down
Loading