Skip to content

Commit b6f255c

Browse files
Finished functionality to add wrkspace system prompt
1 parent a7cd9de commit b6f255c

File tree

11 files changed

+137
-75
lines changed

11 files changed

+137
-75
lines changed
Binary file not shown.

migrations/versions/a692c8b52308_add_workspace_system_prompt.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
Create Date: 2025-01-17 16:33:58.464223
66
77
"""
8+
89
from typing import Sequence, Union
910

1011
from alembic import op
11-
import sqlalchemy as sa
12-
1312

1413
# revision identifiers, used by Alembic.
15-
revision: str = 'a692c8b52308'
16-
down_revision: Union[str, None] = '5c2f3eee5f90'
14+
revision: str = "a692c8b52308"
15+
down_revision: Union[str, None] = "5c2f3eee5f90"
1716
branch_labels: Union[str, Sequence[str], None] = None
1817
depends_on: Union[str, Sequence[str], None] = None
1918

src/codegate/api/v1.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,12 @@ async def create_workspace(request: v1_models.CreateWorkspaceRequest):
6262
except AlreadyExistsError:
6363
raise HTTPException(status_code=409, detail="Workspace already exists")
6464
except ValidationError:
65-
raise HTTPException(status_code=400,
66-
detail=("Invalid workspace name. "
67-
"Please use only alphanumeric characters and dashes"))
65+
raise HTTPException(
66+
status_code=400,
67+
detail=(
68+
"Invalid workspace name. " "Please use only alphanumeric characters and dashes"
69+
),
70+
)
6871
except Exception:
6972
raise HTTPException(status_code=500, detail="Internal server error")
7073

src/codegate/db/connection.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
alert_queue = asyncio.Queue()
3131
fim_cache = FimCache()
3232

33+
3334
class AlreadyExistsError(Exception):
3435
pass
3536

37+
3638
class DbCodeGate:
3739
_instance = None
3840

@@ -266,7 +268,8 @@ async def add_workspace(self, workspace_name: str) -> Optional[Workspace]:
266268

267269
try:
268270
added_workspace = await self._execute_update_pydantic_model(
269-
workspace, sql, should_raise=True)
271+
workspace, sql, should_raise=True
272+
)
270273
except IntegrityError as e:
271274
logger.debug(f"Exception type: {type(e)}")
272275
raise AlreadyExistsError(f"Workspace {workspace_name} already exists.")
@@ -424,7 +427,7 @@ async def get_active_workspace(self) -> Optional[ActiveWorkspace]:
424427
sql = text(
425428
"""
426429
SELECT
427-
w.id, w.name, s.id as session_id, s.last_update
430+
w.id, w.name, w.system_prompt, s.id as session_id, s.last_update
428431
FROM sessions s
429432
INNER JOIN workspaces w ON w.id = s.active_workspace_id
430433
"""

src/codegate/db/models.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,6 @@ class Workspace(BaseModel):
5050
def validate_name(cls, value):
5151
if not re.match(r"^[a-zA-Z0-9_-]+$", value):
5252
raise ValueError("name must be alphanumeric and can only contain _ and -")
53-
# Avoid workspace names that are the same as commands that way we can do stuff like
54-
# `codegate workspace list` and
55-
# `codegate workspace my-ws system-prompt` without any conflicts
56-
elif value in ["list", "add", "activate", "system-prompt"]:
57-
raise ValueError("name cannot be the same as a command")
5853
return value
5954

6055

@@ -104,5 +99,6 @@ class WorkspaceActive(BaseModel):
10499
class ActiveWorkspace(BaseModel):
105100
id: str
106101
name: str
102+
system_prompt: Optional[str]
107103
session_id: str
108104
last_update: datetime.datetime

src/codegate/pipeline/cli/commands.py

+45-27
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@ async def run(self, args: List[str]) -> str:
3131
@property
3232
def help(self) -> str:
3333
return (
34-
"### CodeGate Version\n\n"
34+
"### CodeGate Version\n"
3535
"Prints the version of CodeGate.\n\n"
36+
"*args*: None\n\n"
3637
"**Usage**: `codegate version`\n\n"
37-
"*args*: None"
3838
)
3939

4040

@@ -46,6 +46,7 @@ def __init__(self):
4646
"list": self._list_workspaces,
4747
"add": self._add_workspace,
4848
"activate": self._activate_workspace,
49+
"system-prompt": self._add_system_prompt,
4950
}
5051

5152
async def _list_workspaces(self, *args: List[str]) -> str:
@@ -66,52 +67,63 @@ async def _add_workspace(self, args: List[str]) -> str:
6667
Add a workspace
6768
"""
6869
if args is None or len(args) == 0:
69-
return "Please provide a name. Use `codegate workspace add your_workspace_name`"
70+
return "Please provide a name. Use `codegate workspace add <workspace_name>`"
7071

7172
new_workspace_name = args[0]
7273
if not new_workspace_name:
73-
return "Please provide a name. Use `codegate workspace add your_workspace_name`"
74+
return "Please provide a name. Use `codegate workspace add <workspace_name>`"
7475

7576
try:
7677
_ = await self.workspace_crud.add_workspace(new_workspace_name)
7778
except ValidationError:
7879
return "Invalid workspace name: It should be alphanumeric and dashes"
7980
except AlreadyExistsError:
80-
return f"Workspace **{new_workspace_name}** already exists"
81+
return f"Workspace `{new_workspace_name}` already exists"
8182
except Exception:
8283
return "An error occurred while adding the workspace"
8384

84-
return f"Workspace **{new_workspace_name}** has been added"
85+
return f"Workspace `{new_workspace_name}` has been added"
8586

8687
async def _activate_workspace(self, args: List[str]) -> str:
8788
"""
8889
Activate a workspace
8990
"""
9091
if args is None or len(args) == 0:
91-
return "Please provide a name. Use `codegate workspace activate workspace_name`"
92+
return "Please provide a name. Use `codegate workspace activate <workspace_name>`"
9293

9394
workspace_name = args[0]
9495
if not workspace_name:
95-
return "Please provide a name. Use `codegate workspace activate workspace_name`"
96+
return "Please provide a name. Use `codegate workspace activate <workspace_name>`"
9697

9798
was_activated = await self.workspace_crud.activate_workspace(workspace_name)
9899
if not was_activated:
99100
return (
100-
f"Workspace **{workspace_name}** does not exist or was already active. "
101+
f"Workspace `{workspace_name}` does not exist or was already active. "
101102
f"Use `codegate workspace add {workspace_name}` to add it"
102103
)
103-
return f"Workspace **{workspace_name}** has been activated"
104+
return f"Workspace `{workspace_name}` has been activated"
104105

105-
async def _add_system_prompt(self, workspace_name: str, sys_prompt_lst: List[str]):
106-
updated_worksapce = await self.workspace_crud.update_workspace_system_prompt(workspace_name, sys_prompt_lst)
106+
async def _add_system_prompt(self, args: List[str]):
107+
if len(args) < 2:
108+
return (
109+
"Please provide a workspace name and a system prompt. "
110+
"Use `codegate workspace system-prompt <workspace_name> <system_prompt>`"
111+
)
112+
113+
workspace_name = args[0]
114+
sys_prompt_lst = args[1:]
115+
116+
updated_worksapce = await self.workspace_crud.update_workspace_system_prompt(
117+
workspace_name, sys_prompt_lst
118+
)
107119
if not updated_worksapce:
108120
return (
109121
f"Workspace system prompt not updated. "
110-
f"Check if the workspace **{workspace_name}** exists"
122+
f"Check if the workspace `{workspace_name}` exists"
111123
)
112124
return (
113-
f"Workspace **{updated_worksapce.name}** system prompt "
114-
f"updated to:\n\n```{updated_worksapce.system_prompt}```"
125+
f"Workspace `{updated_worksapce.name}` system prompt "
126+
f"updated to:\n```\n{updated_worksapce.system_prompt}\n```"
115127
)
116128

117129
async def run(self, args: List[str]) -> str:
@@ -122,23 +134,29 @@ async def run(self, args: List[str]) -> str:
122134
if command_to_execute is not None:
123135
return await command_to_execute(args[1:])
124136
else:
125-
if len(args) >= 2 and args[1] == "system-prompt":
126-
return await self._add_system_prompt(args[0], args[2:])
127137
return "Command not found. Use `codegate workspace -h` to see available commands"
128138

129139
@property
130140
def help(self) -> str:
131141
return (
132-
"### CodeGate Workspace\n\n"
142+
"### CodeGate Workspace\n"
133143
"Manage workspaces.\n\n"
134144
"**Usage**: `codegate workspace <command> [args]`\n\n"
135-
"Available commands:\n\n"
136-
"- `list`: List all workspaces\n\n"
137-
" - *args*: None\n\n"
138-
"- `add`: Add a workspace\n\n"
139-
" - *args*:\n\n"
140-
" - `workspace_name`\n\n"
141-
"- `activate`: Activate a workspace\n\n"
142-
" - *args*:\n\n"
143-
" - `workspace_name`"
145+
"Available commands:\n"
146+
"- `list`: List all workspaces\n"
147+
" - *args*: None\n"
148+
" - **Usage**: `codegate workspace list`\n"
149+
"- `add`: Add a workspace\n"
150+
" - *args*:\n"
151+
" - `workspace_name`\n"
152+
" - **Usage**: `codegate workspace add <workspace_name>`\n"
153+
"- `activate`: Activate a workspace\n"
154+
" - *args*:\n"
155+
" - `workspace_name`\n"
156+
" - **Usage**: `codegate workspace activate <workspace_name>`\n"
157+
"- `system-prompt`: Modify the system-prompt of a workspace\n"
158+
" - *args*:\n"
159+
" - `workspace_name`\n"
160+
" - `system_prompt`\n"
161+
" - **Usage**: `codegate workspace system-prompt <workspace_name> <system_prompt>`\n"
144162
)
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import json
1+
from typing import Optional
22

33
from litellm import ChatCompletionRequest, ChatCompletionSystemMessage
44

@@ -7,6 +7,7 @@
77
PipelineResult,
88
PipelineStep,
99
)
10+
from codegate.workspaces.crud import WorkspaceCrud
1011

1112

1213
class SystemPrompt(PipelineStep):
@@ -16,7 +17,7 @@ class SystemPrompt(PipelineStep):
1617
"""
1718

1819
def __init__(self, system_prompt: str):
19-
self._system_message = ChatCompletionSystemMessage(content=system_prompt, role="system")
20+
self.codegate_system_prompt = system_prompt
2021

2122
@property
2223
def name(self) -> str:
@@ -25,6 +26,44 @@ def name(self) -> str:
2526
"""
2627
return "system-prompt"
2728

29+
async def _get_workspace_system_prompt(self) -> str:
30+
wksp_crud = WorkspaceCrud()
31+
workspace = await wksp_crud.get_active_workspace()
32+
if not workspace:
33+
return ""
34+
35+
return workspace.system_prompt
36+
37+
async def _construct_system_prompt(
38+
self,
39+
wrksp_sys_prompt: str,
40+
req_sys_prompt: Optional[str],
41+
should_add_codegate_sys_prompt: bool,
42+
) -> ChatCompletionSystemMessage:
43+
44+
def _start_or_append(existing_prompt: str, new_prompt: str) -> str:
45+
if existing_prompt:
46+
return existing_prompt + "\n\nHere are additional instructions:\n\n" + new_prompt
47+
return new_prompt
48+
49+
system_prompt = ""
50+
# Add codegate system prompt if secrets or bad packages are found at the beginning
51+
if should_add_codegate_sys_prompt:
52+
system_prompt = _start_or_append(system_prompt, self.codegate_system_prompt)
53+
54+
# Add workspace system prompt if present
55+
if wrksp_sys_prompt:
56+
system_prompt = _start_or_append(system_prompt, wrksp_sys_prompt)
57+
58+
# Add request system prompt if present
59+
if req_sys_prompt and "codegate" not in req_sys_prompt.lower():
60+
system_prompt = _start_or_append(system_prompt, req_sys_prompt)
61+
62+
return system_prompt
63+
64+
async def _should_add_codegate_system_prompt(self, context: PipelineContext) -> bool:
65+
return context.secrets_found or context.bad_packages_found
66+
2867
async def process(
2968
self, request: ChatCompletionRequest, context: PipelineContext
3069
) -> PipelineResult:
@@ -33,32 +72,35 @@ async def process(
3372
to the existing system prompt
3473
"""
3574

36-
# Nothing to do if no secrets or bad_packages are found
37-
if not (context.secrets_found or context.bad_packages_found):
75+
wrksp_sys_prompt = await self._get_workspace_system_prompt()
76+
should_add_codegate_sys_prompt = await self._should_add_codegate_system_prompt(context)
77+
78+
# Nothing to do if no secrets or bad_packages are found and we don't have a workspace
79+
# system prompt
80+
if not should_add_codegate_sys_prompt and not wrksp_sys_prompt:
3881
return PipelineResult(request=request, context=context)
3982

4083
new_request = request.copy()
4184

4285
if "messages" not in new_request:
4386
new_request["messages"] = []
4487

45-
request_system_message = None
88+
request_system_message = {}
4689
for message in new_request["messages"]:
4790
if message["role"] == "system":
4891
request_system_message = message
92+
req_sys_prompt = request_system_message.get("content")
4993

50-
if request_system_message is None:
51-
# Add system message
52-
context.add_alert(self.name, trigger_string=json.dumps(self._system_message))
53-
new_request["messages"].insert(0, self._system_message)
54-
elif "codegate" not in request_system_message["content"].lower():
55-
# Prepend to the system message
56-
prepended_message = (
57-
self._system_message["content"]
58-
+ "\n Here are additional instructions. \n "
59-
+ request_system_message["content"]
60-
)
61-
context.add_alert(self.name, trigger_string=prepended_message)
62-
request_system_message["content"] = prepended_message
94+
system_prompt = await self._construct_system_prompt(
95+
wrksp_sys_prompt, req_sys_prompt, should_add_codegate_sys_prompt
96+
)
97+
context.add_alert(self.name, trigger_string=system_prompt)
98+
if not request_system_message:
99+
# Insert the system prompt at the beginning of the messages
100+
sytem_message = ChatCompletionSystemMessage(content=system_prompt, role="system")
101+
new_request["messages"].insert(0, sytem_message)
102+
else:
103+
# Update the existing system prompt
104+
request_system_message["content"] = system_prompt
63105

64106
return PipelineResult(request=new_request, context=context)

src/codegate/workspaces/crud.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
class WorkspaceCrudError(Exception):
99
pass
1010

11+
1112
class WorkspaceCrud:
1213

1314
def __init__(self):
@@ -24,7 +25,7 @@ async def add_workspace(self, new_workspace_name: str) -> Workspace:
2425
workspace_created = await db_recorder.add_workspace(new_workspace_name)
2526
return workspace_created
2627

27-
async def get_workspaces(self)-> List[WorkspaceActive]:
28+
async def get_workspaces(self) -> List[WorkspaceActive]:
2829
"""
2930
Get all workspaces
3031
"""
@@ -79,8 +80,8 @@ async def activate_workspace(self, workspace_name: str) -> bool:
7980
return True
8081

8182
async def update_workspace_system_prompt(
82-
self, workspace_name: str, sys_prompt_lst: List[str]
83-
) -> Optional[Workspace]:
83+
self, workspace_name: str, sys_prompt_lst: List[str]
84+
) -> Optional[Workspace]:
8485
selected_workspace = await self._db_reader.get_workspace_by_name(workspace_name)
8586
if not selected_workspace:
8687
return None

0 commit comments

Comments
 (0)