Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions 05_src/assignment_chat/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
chroma_store/
a2-env/
Empty file.
Empty file.
61 changes: 61 additions & 0 deletions 05_src/assignment_chat/app/core/guardrails.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import annotations

import re

# Restricted topics (must refuse)
RESTRICTED = [
    r"\bcats?\b",
    r"\bdogs?\b",
    r"\bhoroscopes?\b",
    r"\bzodiac\b",
    r"\btaylor\s+swift\b",
]

# Prompt injection / system prompt exfiltration (must refuse)
# IMPORTANT: keep these specific. Avoid generic phrases like "tell me".
PROMPT_ATTACK = [
    r"\bsystem\s+prompt\b",
    r"\bdeveloper\s+message\b",
    r"\binternal\s+instructions\b",
    r"\breveal\b.*\b(system|developer)\b",
    r"\bshow\b.*\b(system|developer)\b",
    r"\bprint\b.*\b(system|developer)\b",
    r"\bignore\b.*\b(instructions|previous)\b",
    r"\boverride\b.*\b(instructions|system)\b",
    r"\bjailbreak\b",
    r"\bprompt\s+injection\b",
]

# Compile the pattern lists once at import time so each guardrail check does
# not pay a per-pattern `re` cache lookup. The raw-string lists above remain
# the public interface so other modules/tests can still inspect or extend them.
_PROMPT_ATTACK_PATTERNS = [re.compile(p) for p in PROMPT_ATTACK]
_RESTRICTED_PATTERNS = [re.compile(p) for p in RESTRICTED]

# Allow safe memory/recall phrasing to pass without false positives
MEMORY_OR_RECALL = re.compile(
    r"^\s*(remember|note|save|store)\b|"
    r"\bwhat\s+did\s+i\s+(tell|say)\s+you\b|"
    r"\bwhat\s+did\s+i\s+mention\b|"
    r"\bwhat\s+did\s+i\s+ask\b|"
    r"\bdo\s+you\s+remember\b",
    re.IGNORECASE,
)


def check_guardrails(user_text: str | None) -> tuple[bool, str]:
    """Screen one user message against the chat guardrails.

    The check order matters:
      1. refuse prompt-injection / system-prompt exfiltration attempts,
      2. then whitelist memory/recall phrasing (so e.g. "remember my dog Rex"
         is not caught by the restricted-topic list below),
      3. then refuse restricted topics.

    Parameters
    ----------
    user_text : str | None
        Raw user message; ``None`` is treated as empty input.

    Returns
    -------
    tuple[bool, str]
        ``(True, "")`` when the message may proceed, otherwise
        ``(False, refusal_message)``.
    """
    text = (user_text or "").strip()
    low = text.lower()

    # 1) Block prompt attacks (specific patterns only)
    for pattern in _PROMPT_ATTACK_PATTERNS:
        if pattern.search(low):
            return (
                False,
                "I can’t share or modify system/developer instructions. "
                "Tell me what you’re trying to do and I’ll help safely."
            )

    # 2) Allow memory/recall phrasing (prevents accidental blocks)
    if MEMORY_OR_RECALL.search(text):
        return (True, "")

    # 3) Block restricted topics
    for pattern in _RESTRICTED_PATTERNS:
        if pattern.search(low):
            return (False, "Sorry — I can’t help with that topic.")

    return (True, "")
45 changes: 45 additions & 0 deletions 05_src/assignment_chat/app/core/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class SessionState:
    """Mutable per-session conversation state."""

    # Chat transcript in OpenAI message format: [{"role": ..., "content": ...}].
    messages: List[dict] = field(default_factory=list)
    # Running free-text summary; explicit "remember" notes are appended here.
    rolling_summary: str = ""


class MemoryStore:
    """In-process store of per-session chat state with a bounded history window."""

    def __init__(self, max_turns: int = 12):
        # One "turn" is roughly a user message plus an assistant reply,
        # so the raw-message window is max_turns * 2.
        self.max_turns = max_turns
        self.sessions: Dict[str, SessionState] = {}

    def get(self, session_id: str) -> SessionState:
        """Return the state for *session_id*, creating an empty one on first use."""
        state = self.sessions.get(session_id)
        if state is None:
            state = SessionState()
            self.sessions[session_id] = state
        return state

    def append(self, session_id: str, role: str, content: str) -> None:
        """Record one message, then trim the transcript to the history window."""
        state = self.get(session_id)
        state.messages.append({"role": role, "content": content})

        window = self.max_turns * 2  # ~2 messages per conversational turn
        if len(state.messages) > window:
            state.messages = state.messages[-window:]

    def get_context_messages(self, session_id: str) -> list[dict]:
        """Build the model context: optional summary preamble plus the transcript."""
        state = self.get(session_id)
        preamble: list[dict] = []
        if state.rolling_summary:
            preamble.append(
                {"role": "system", "content": f"Conversation summary so far: {state.rolling_summary}"}
            )
        return preamble + state.messages

    def remember(self, session_id: str, fact: str) -> None:
        """Append a user-stated fact to the session's rolling summary."""
        state = self.get(session_id)
        note = fact.strip()
        if not note:
            return
        # Existing notes are chained with " | " separators.
        prefix = f"{state.rolling_summary.rstrip()} | " if state.rolling_summary else ""
        state.rolling_summary = f"{prefix}User note: {note}"
20 changes: 20 additions & 0 deletions 05_src/assignment_chat/app/core/openai_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from __future__ import annotations

import os
from dotenv import load_dotenv
from openai import OpenAI

# Load secrets (notably API_GATEWAY_KEY) from a local ".secrets" file into the
# process environment; python-dotenv silently no-ops if the file is missing.
load_dotenv(".secrets")

# AWS API Gateway endpoint that fronts the OpenAI API (note the /openai/v1 path).
BASE_URL = "https://k7uffyg03f.execute-api.us-east-1.amazonaws.com/prod/openai/v1"

def get_client() -> OpenAI:
    """Return an OpenAI SDK client pointed at the API-Gateway proxy.

    Raises
    ------
    RuntimeError
        If API_GATEWAY_KEY is absent from the environment (load .secrets first).
    """
    gateway_key = os.getenv("API_GATEWAY_KEY")
    if not gateway_key:
        raise RuntimeError(
            "API_GATEWAY_KEY not set. Create .secrets with API_GATEWAY_KEY=XXX at assignment_chat root directory level."
        )

    # The gateway authenticates via the x-api-key header; the SDK still
    # requires some api_key string, so a placeholder is supplied.
    extra_headers = {"x-api-key": gateway_key}
    return OpenAI(
        base_url=BASE_URL,
        api_key="unused key (but required)",
        default_headers=extra_headers,
    )
64 changes: 64 additions & 0 deletions 05_src/assignment_chat/app/gradio_ui.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

import uuid
import requests
import gradio as gr

API_URL = "http://127.0.0.1:8000/chat"


def call_backend(session_id: str, message: str) -> str:
    """POST one chat message to the backend and return its reply text.

    Raises ``requests.HTTPError`` on a non-2xx response (via raise_for_status).
    """
    response = requests.post(
        API_URL,
        json={"session_id": session_id, "message": message},
        timeout=60,
    )
    response.raise_for_status()
    body = response.json()
    return body["reply"]


def chat_fn(message, history, session_id):
    """Gradio submit handler: relay *message* to the backend and grow the chat.

    Returns ``("", history, session_id)`` so Gradio clears the textbox,
    re-renders the chatbot, and carries the session id forward unchanged.
    """
    history = history if history is not None else []

    reply = call_backend(session_id, message)

    # Chatbot "messages" format: one role/content dict per utterance.
    for role, content in (("user", message), ("assistant", reply)):
        history.append({"role": role, "content": content})

    return "", history, session_id


def new_session():
    """Start a fresh chat: a new random session id and an empty history."""
    fresh_id = uuid.uuid4()
    return str(fresh_id), []


# Declarative Gradio layout: a chatbot pane, a message textbox, and a
# "New Session" button; the per-tab session id lives in gr.State.
with gr.Blocks(title="StudyMate 🤖📚") as demo:
    gr.Markdown("## StudyMate 🤖📚")
    gr.Markdown(
        "Ask about embeddings, ChromaDB, RAG, or generate Mermaid diagrams."
    )

    # Random id generated once per page load; replaced by the button below.
    session_id = gr.State(str(uuid.uuid4()))
    # NOTE(review): chat_fn emits the role/content "messages" dict format —
    # confirm this Chatbot (no type= argument) accepts it in the pinned
    # Gradio version; newer versions expect type="messages" for dicts.
    chatbot = gr.Chatbot(height=450)  # works across versions
    msg = gr.Textbox(
        placeholder="Ask something...",
        show_label=False,
        container=False,
    )

    clear_btn = gr.Button("New Session")

    # Submitting the textbox sends it to the backend and re-renders the chat;
    # the first output ("") clears the textbox.
    msg.submit(
        chat_fn,
        inputs=[msg, chatbot, session_id],
        outputs=[msg, chatbot, session_id],
    )

    # "New Session" swaps in a fresh session id and empties the chat pane.
    clear_btn.click(
        fn=new_session,
        inputs=None,
        outputs=[session_id, chatbot],
    )


if __name__ == "__main__":
    demo.launch()
Loading