diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml index b5e8cfd..4c1bf4e 100644 --- a/.github/workflows/claude-code-review.yml +++ b/.github/workflows/claude-code-review.yml @@ -33,6 +33,7 @@ jobs: - name: Run Claude Code Review id: claude-review + continue-on-error: true uses: anthropics/claude-code-action@v1 with: claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index 6b15fac..8b3fb60 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -32,6 +32,7 @@ jobs: - name: Run Claude Code id: claude + continue-on-error: true uses: anthropics/claude-code-action@v1 with: claude_code_oauth_token: ${{ secrets.CLAUDE_CODE_OAUTH_TOKEN }} diff --git a/backend/db/supabase_schema.sql b/backend/db/supabase_schema.sql index 8db0f34..b785d48 100644 --- a/backend/db/supabase_schema.sql +++ b/backend/db/supabase_schema.sql @@ -1,24 +1,54 @@ -- ============================================================ --- Sapling — Supabase Schema +-- Sapling — Supabase Schema (course_id migration) -- Run this in: Supabase Dashboard → SQL Editor → New query -- ============================================================ -- Users CREATE TABLE IF NOT EXISTS users ( - id TEXT PRIMARY KEY, - name TEXT NOT NULL, - email TEXT, - streak_count INTEGER DEFAULT 0, + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + email TEXT, + streak_count INTEGER DEFAULT 0, last_active_date TEXT, - room_id TEXT, - created_at TIMESTAMPTZ DEFAULT now(), - google_id TEXT UNIQUE, - avatar_url TEXT, - auth_provider TEXT DEFAULT 'google' + room_id TEXT, + created_at TIMESTAMPTZ DEFAULT now(), + google_id TEXT UNIQUE, + avatar_url TEXT, + auth_provider TEXT DEFAULT 'google' ); CREATE INDEX IF NOT EXISTS idx_users_google_id ON users(google_id); +-- Canonical course catalog (no user_id — shared across all students) +CREATE TABLE IF NOT EXISTS courses ( + id TEXT PRIMARY KEY, + course_code TEXT NOT NULL, + course_name TEXT NOT NULL, + department TEXT, + credits INTEGER, + semester TEXT DEFAULT 'Spring 2026', + instructor_name TEXT, + meeting_times TEXT, + location TEXT, + description TEXT, + syllabus_url TEXT, + school TEXT, + created_at TIMESTAMPTZ DEFAULT now() +); + +-- Enrollment join table (user ↔ canonical course) +CREATE TABLE IF NOT EXISTS user_courses ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL REFERENCES users(id), + course_id TEXT NOT NULL REFERENCES courses(id), + color TEXT, + nickname TEXT, + enrolled_at TIMESTAMPTZ DEFAULT now(), + UNIQUE (user_id, course_id) +); + +CREATE INDEX IF NOT EXISTS idx_user_courses_user_id ON user_courses(user_id); + -- Knowledge graph nodes CREATE TABLE IF NOT EXISTS graph_nodes ( id TEXT PRIMARY KEY, @@ -29,10 +59,13 @@ CREATE TABLE IF NOT EXISTS graph_nodes ( times_studied INTEGER DEFAULT 0, last_studied_at TIMESTAMPTZ, subject TEXT, + course_id TEXT REFERENCES courses(id), created_at TIMESTAMPTZ DEFAULT now(), - mastery_events JSONB DEFAULT '[]' -- array of {ts, delta, reason, event_type} — last 20 events + mastery_events JSONB DEFAULT '[]' ); +CREATE INDEX IF NOT EXISTS idx_graph_nodes_user_course ON graph_nodes(user_id, course_id); + -- Knowledge graph edges CREATE TABLE IF NOT EXISTS graph_edges ( id TEXT PRIMARY KEY, @@ -41,42 +74,30 @@ CREATE TABLE IF NOT EXISTS graph_edges ( target_node_id TEXT NOT NULL REFERENCES graph_nodes(id), strength DOUBLE PRECISION DEFAULT 0.5, created_at TIMESTAMPTZ DEFAULT now(), - relationship_type TEXT DEFAULT 'related' -- 'prerequisite' | 'builds_on' | 'related' -); - --- Migrations (run these if the table already exists) --- ALTER TABLE graph_nodes ADD COLUMN IF NOT EXISTS mastery_events JSONB DEFAULT '[]'; --- ALTER TABLE graph_edges ADD COLUMN IF NOT EXISTS relationship_type TEXT DEFAULT 'related'; - --- Courses -CREATE TABLE IF NOT EXISTS courses ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL REFERENCES users(id), - course_name TEXT NOT NULL, - color TEXT, - created_at TIMESTAMPTZ DEFAULT now(), - UNIQUE (user_id, course_name) + relationship_type TEXT DEFAULT 'related' ); -- Learning sessions CREATE TABLE IF NOT EXISTS sessions ( - id TEXT PRIMARY KEY, - user_id TEXT NOT NULL REFERENCES users(id), - mode TEXT NOT NULL, - topic TEXT NOT NULL, - started_at TIMESTAMPTZ DEFAULT now(), - ended_at TIMESTAMPTZ, - summary_json JSONB + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL REFERENCES users(id), + mode TEXT NOT NULL, + topic TEXT NOT NULL, + course_id TEXT REFERENCES courses(id), + started_at TIMESTAMPTZ DEFAULT now(), + ended_at TIMESTAMPTZ, + summary_json JSONB, + name TEXT ); -- Chat messages within a session CREATE TABLE IF NOT EXISTS messages ( - id TEXT PRIMARY KEY, - session_id TEXT NOT NULL REFERENCES sessions(id), - role TEXT NOT NULL, - content TEXT NOT NULL, + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL REFERENCES sessions(id), + role TEXT NOT NULL, + content TEXT NOT NULL, graph_update_json JSONB, - created_at TIMESTAMPTZ DEFAULT now() + created_at TIMESTAMPTZ DEFAULT now() ); -- Quiz attempts @@ -107,7 +128,7 @@ CREATE TABLE IF NOT EXISTS assignments ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL REFERENCES users(id), title TEXT NOT NULL, - course_name TEXT, + course_id TEXT REFERENCES courses(id), due_date TEXT NOT NULL, assignment_type TEXT, notes TEXT, @@ -115,6 +136,64 @@ CREATE TABLE IF NOT EXISTS assignments ( created_at TIMESTAMPTZ DEFAULT now() ); +CREATE INDEX IF NOT EXISTS idx_assignments_user_due ON assignments(user_id, due_date); + +-- Documents (uploaded course materials) +CREATE TABLE IF NOT EXISTS documents ( + id TEXT PRIMARY KEY, + user_id TEXT NOT NULL REFERENCES users(id), + course_id TEXT NOT NULL REFERENCES courses(id), + file_name TEXT NOT NULL, + category TEXT NOT NULL, + summary TEXT, + key_takeaways JSONB, + flashcards JSONB, + created_at TIMESTAMPTZ DEFAULT now(), + processed_at TIMESTAMPTZ +); + +-- Study guides +CREATE TABLE IF NOT EXISTS study_guides ( + id TEXT PRIMARY KEY DEFAULT gen_random_uuid()::TEXT, + user_id TEXT NOT NULL REFERENCES users(id), + course_id TEXT NOT NULL REFERENCES courses(id), + exam_id TEXT NOT NULL, + generated_at TIMESTAMPTZ DEFAULT now(), + content JSONB NOT NULL +); + +-- Per-concept aggregated course stats (across all enrolled students) +CREATE TABLE IF NOT EXISTS course_concept_stats ( + id TEXT PRIMARY KEY DEFAULT gen_random_uuid()::TEXT, + course_id TEXT NOT NULL REFERENCES courses(id), + concept_name TEXT NOT NULL, + semester TEXT NOT NULL DEFAULT 'Spring 2026', + student_count INTEGER DEFAULT 0, + avg_mastery_score DOUBLE PRECISION DEFAULT 0.0, + pct_mastered DOUBLE PRECISION DEFAULT 0.0, + pct_struggling DOUBLE PRECISION DEFAULT 0.0, + pct_unexplored DOUBLE PRECISION DEFAULT 0.0, + common_misconceptions TEXT[] DEFAULT '{}', + effective_explanations TEXT[] DEFAULT '{}', + prerequisite_gaps TEXT[] DEFAULT '{}', + updated_at TIMESTAMPTZ DEFAULT now(), + UNIQUE (course_id, concept_name, semester) +); + +-- Course-wide summary (rolled up from course_concept_stats) +CREATE TABLE IF NOT EXISTS course_summary ( + course_id TEXT NOT NULL REFERENCES courses(id), + semester TEXT NOT NULL DEFAULT 'Spring 2026', + student_count INTEGER DEFAULT 0, + avg_class_mastery DOUBLE PRECISION DEFAULT 0.0, + top_struggling_concepts TEXT[] DEFAULT '{}', + top_mastered_concepts TEXT[] DEFAULT '{}', + summary_text TEXT, + summary_hash TEXT, + updated_at TIMESTAMPTZ DEFAULT now(), + PRIMARY KEY (course_id, semester) +); + -- Study rooms CREATE TABLE IF NOT EXISTS rooms ( id TEXT PRIMARY KEY, @@ -169,21 +248,16 @@ CREATE INDEX IF NOT EXISTS idx_room_messages_room_id ON room_messages(room_id, c -- Emoji reactions on room messages CREATE TABLE IF NOT EXISTS room_reactions ( - id UUID PRIMARY KEY DEFAULT gen_random_uuid(), - message_id UUID NOT NULL REFERENCES room_messages(id) ON DELETE CASCADE, - user_id TEXT NOT NULL REFERENCES users(id), - emoji TEXT NOT NULL, - created_at TIMESTAMPTZ DEFAULT now(), + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + message_id UUID NOT NULL REFERENCES room_messages(id) ON DELETE CASCADE, + user_id TEXT NOT NULL REFERENCES users(id), + emoji TEXT NOT NULL, + created_at TIMESTAMPTZ DEFAULT now(), UNIQUE(message_id, user_id, emoji) ); CREATE INDEX IF NOT EXISTS idx_room_reactions_message_id ON room_reactions(message_id); --- Migrations (run if tables already exist without these columns): --- ALTER TABLE room_messages ADD COLUMN IF NOT EXISTS reply_to_id UUID REFERENCES room_messages(id); --- ALTER TABLE room_messages ADD COLUMN IF NOT EXISTS is_deleted BOOLEAN NOT NULL DEFAULT FALSE; --- ALTER TABLE room_messages ADD COLUMN IF NOT EXISTS edited_at TIMESTAMPTZ; - -- Cached AI summaries for study rooms CREATE TABLE IF NOT EXISTS room_summaries ( room_id TEXT PRIMARY KEY REFERENCES rooms(id), @@ -192,14 +266,7 @@ CREATE TABLE IF NOT EXISTS room_summaries ( updated_at TIMESTAMPTZ DEFAULT now() ); --- Shared course-level learning context (aggregated from all students, no Gemini) -CREATE TABLE IF NOT EXISTS course_context ( - course_name TEXT PRIMARY KEY, - context_json JSONB NOT NULL, - student_count INTEGER DEFAULT 0, - updated_at TIMESTAMPTZ DEFAULT now() -); - +-- Flashcards CREATE TABLE IF NOT EXISTS flashcards ( id TEXT PRIMARY KEY, user_id TEXT NOT NULL REFERENCES users(id), @@ -214,18 +281,38 @@ CREATE TABLE IF NOT EXISTS flashcards ( CREATE INDEX IF NOT EXISTS idx_flashcards_user_topic ON flashcards(user_id, topic); -CREATE TABLE public.flashcards ( - id text NOT NULL, - user_id text NOT NULL, - topic text NOT NULL, - front text NOT NULL, - back text NOT NULL, - times_reviewed integer DEFAULT 0, - last_rating integer, - last_reviewed_at timestamp with time zone, - created_at timestamp with time zone DEFAULT now(), - CONSTRAINT flashcards_pkey PRIMARY KEY (id), - CONSTRAINT flashcards_user_id_fkey FOREIGN KEY (user_id) REFERENCES public.users(id) -); - -CREATE INDEX IF NOT EXISTS idx_flashcards_user_topic ON public.flashcards(user_id, topic); \ No newline at end of file +-- Feedback +CREATE TABLE IF NOT EXISTS feedback ( + id SERIAL PRIMARY KEY, + user_id TEXT NOT NULL, + type TEXT NOT NULL, + rating INTEGER NOT NULL, + selected_options JSONB DEFAULT '[]', + comment TEXT, + session_id TEXT, + topic TEXT, + created_at TIMESTAMPTZ DEFAULT now() +); + +-- Issue reports +CREATE TABLE IF NOT EXISTS issue_reports ( + id SERIAL PRIMARY KEY, + user_id TEXT NOT NULL, + topic TEXT NOT NULL, + description TEXT NOT NULL, + screenshot_urls JSONB DEFAULT '[]', + created_at TIMESTAMPTZ DEFAULT now() +); + +-- Job applications +CREATE TABLE IF NOT EXISTS job_applications ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + position TEXT NOT NULL, + full_name TEXT NOT NULL, + email TEXT NOT NULL, + phone TEXT, + linkedin_url TEXT NOT NULL, + resume TEXT, + portfolio_link TEXT, + submitted_at TIMESTAMPTZ DEFAULT now() +); diff --git a/backend/main.py b/backend/main.py index ee1dc05..a230ec9 100644 --- a/backend/main.py +++ b/backend/main.py @@ -7,8 +7,11 @@ from config import FRONTEND_URL, PORT from routes import graph, learn, quiz, calendar, social, extract, auth, documents, flashcards, study_guide, feedback, careers -from recost.frameworks.fastapi import RecostMiddleware +try: + from recost.frameworks.fastapi import RecostMiddleware +except ImportError: + RecostMiddleware = None # optional; tests/CI without recost package load_dotenv(Path(__file__).with_name(".env")) @@ -17,7 +20,7 @@ app = FastAPI(title="Sapling API", version="1.0.0") -if recost_api_key: +if recost_api_key and RecostMiddleware is not None: app.add_middleware( RecostMiddleware, api_key=recost_api_key, diff --git a/backend/models/__init__.py b/backend/models/__init__.py index e1495e4..4bf677e 100644 --- a/backend/models/__init__.py +++ b/backend/models/__init__.py @@ -9,6 +9,7 @@ class StartSessionBody(BaseModel): topic: str = "" mode: str = "socratic" use_shared_context: bool = True + course_id: Optional[str] = None # Direct course_id lookup instead of resolving from topic class ChatBody(BaseModel): @@ -56,7 +57,7 @@ class SubmitQuizBody(BaseModel): class AssignmentItem(BaseModel): title: str - course_name: str = "" + course_id: str = "" # Changed from course_name to course_id due_date: str assignment_type: str = "other" notes: Optional[str] = None @@ -84,6 +85,20 @@ class ImportSaveBody(BaseModel): assignments: list[AssignmentItem] +# ── Graph (Courses) ───────────────────────────────────────────────────────── + +class AddCourseBody(BaseModel): + """Body for enrolling a user in a course (creating a user_courses record).""" + course_id: str + color: Optional[str] = None + nickname: Optional[str] = None + + +class UpdateCourseColorBody(BaseModel): + """Body for updating a course enrollment's color.""" + color: str + + # ── Social ──────────────────────────────────────────────────────────────────── class CreateRoomBody(BaseModel): @@ -153,4 +168,12 @@ class SubmitIssueReportBody(BaseModel): user_id: str topic: str description: str - screenshot_urls: list[str] = [] \ No newline at end of file + screenshot_urls: list[str] = [] + + +# ── Documents ────────────────────────────────────────────────────────────────── + +class UploadDocumentBody(BaseModel): + """Body for document upload.""" + course_id: str + user_id: str diff --git a/backend/routes/calendar.py b/backend/routes/calendar.py index 8d24413..a6cf9b5 100644 --- a/backend/routes/calendar.py +++ b/backend/routes/calendar.py @@ -101,7 +101,7 @@ def save_assignments(body: SaveAssignmentsBody): payload = [ { "title": a.title, - "course_name": a.course_name, + "course_id": a.course_id, "due_date": a.due_date, "assignment_type": a.assignment_type, "notes": a.notes, @@ -116,43 +116,76 @@ def save_assignments(body: SaveAssignmentsBody): def get_upcoming(user_id: str): today = datetime.utcnow().strftime("%Y-%m-%d") rows = table("assignments").select( - "*", + "*,courses!left(course_code,course_name)", filters={"user_id": f"eq.{user_id}", "due_date": f"gte.{today}"}, order="due_date.asc", limit=20, ) - return {"assignments": rows} + assignments = [] + for r in rows: + course = r.get("courses", {}) if isinstance(r.get("courses"), dict) else {} + assignments.append({ + "id": r["id"], + "user_id": r["user_id"], + "title": r["title"], + "due_date": r["due_date"], + "assignment_type": r.get("assignment_type"), + "notes": r.get("notes"), + "google_event_id": r.get("google_event_id"), + "course_id": r.get("course_id"), + "course_code": course.get("course_code") or r.get("course_code") or "", + "course_name": course.get("course_name") or r.get("course_name") or "", + }) + return {"assignments": assignments} @router.get("/all/{user_id}") def get_all_assignments(user_id: str): """Return all assignments for a user (past and future) for the calendar view.""" rows = table("assignments").select( - "*", + "*,courses!left(course_code,course_name)", filters={"user_id": f"eq.{user_id}"}, order="due_date.asc", ) - return {"assignments": rows} + assignments = [] + for r in rows: + course = r.get("courses", {}) if isinstance(r.get("courses"), dict) else {} + assignments.append({ + "id": r["id"], + "user_id": r["user_id"], + "title": r["title"], + "due_date": r["due_date"], + "assignment_type": r.get("assignment_type"), + "notes": r.get("notes"), + "google_event_id": r.get("google_event_id"), + "course_id": r.get("course_id"), + "course_code": course.get("course_code") or r.get("course_code") or "", + "course_name": course.get("course_name") or r.get("course_name") or "", + }) + return {"assignments": assignments} @router.post("/suggest-study-blocks") def suggest_study_blocks(body: StudyBlockBody): today = datetime.utcnow().strftime("%Y-%m-%d") assignments = table("assignments").select( - "*", + "*,courses!left(course_code,course_name)", filters={"user_id": f"eq.{body.user_id}", "due_date": f"gte.{today}"}, order="due_date.asc", ) - blocks = [ - { - "topic": a["title"], + blocks = [] + for a in assignments: + course = a.get("courses", {}) if isinstance(a.get("courses"), dict) else {} + cc = course.get("course_code") or a.get("course_code") or "" + cn = course.get("course_name") or a.get("course_name") or "" + course_label = f"[{cc}] " if cc else (f"{cn}: " if cn else "") + blocks.append({ + "topic": f"{course_label}{a['title']}" if course_label else a["title"], "suggested_date": a["due_date"], "duration_minutes": 60, "reason": f"Due {a['due_date']}", "related_assignment_id": a["id"], - } - for a in assignments - ] + }) return {"study_blocks": blocks[:5]} @@ -223,7 +256,7 @@ def sync_to_google(body: SyncBody): service = build("calendar", "v3", credentials=creds) unsynced = table("assignments").select( - "*", + "*,courses!left(course_code,course_name)", filters={ "user_id": f"eq.{body.user_id}", "google_event_id": "is.null", @@ -231,7 +264,7 @@ def sync_to_google(body: SyncBody): ) # Also catch empty-string google_event_id unsynced += table("assignments").select( - "*", + "*,courses!left(course_code,course_name)", filters={ "user_id": f"eq.{body.user_id}", "google_event_id": "eq.", @@ -242,8 +275,14 @@ def sync_to_google(body: SyncBody): for a in unsynced: if not a.get("due_date"): continue + + course = a.get("courses", {}) if isinstance(a.get("courses"), dict) else {} + course_code = course.get("course_code") or a.get("course_code") or "" + course_name = course.get("course_name") or a.get("course_name") or "" + course_label = f"[{course_code}] " if course_code else (f"{course_name}: " if course_name else "") + event = { - "summary": f"[{a['course_name']}] {a['title']}" if a.get("course_name") else a["title"], + "summary": f"{course_label}{a['title']}" if course_label else a["title"], "description": a.get("notes") or "", "start": {"date": a["due_date"]}, "end": {"date": a["due_date"]}, @@ -268,7 +307,7 @@ def export_to_google(body: ExportBody): exported = 0 skipped = 0 for aid in body.assignment_ids: - rows = table("assignments").select("*", filters={"id": f"eq.{aid}"}) + rows = table("assignments").select("*,courses!left(course_code,course_name)", filters={"id": f"eq.{aid}"}) if not rows: continue a = rows[0] @@ -277,8 +316,13 @@ def export_to_google(body: ExportBody): skipped += 1 continue + course = a.get("courses", {}) if isinstance(a.get("courses"), dict) else {} + course_code = course.get("course_code") or a.get("course_code") or "" + course_name = course.get("course_name") or a.get("course_name") or "" + course_label = f"[{course_code}] " if course_code else (f"{course_name}: " if course_name else "") + event = { - "summary": f"[{a['course_name']}] {a['title']}" if a.get("course_name") else a["title"], + "summary": f"{course_label}{a['title']}" if course_label else a["title"], "description": a.get("notes") or "", "start": {"date": a["due_date"]}, "end": {"date": a["due_date"]}, diff --git a/backend/routes/documents.py b/backend/routes/documents.py index ffdd472..c30ef21 100644 --- a/backend/routes/documents.py +++ b/backend/routes/documents.py @@ -140,8 +140,11 @@ async def upload_document( if ai.get("category") == "syllabus": try: assignments = ai.get("assignments") or [] - if assignments: - save_assignments_to_db(user_id, assignments) + filtered = [a for a in assignments if isinstance(a, dict)] + for a in filtered: + a["course_id"] = course_id + if filtered: + save_assignments_to_db(user_id, filtered) except Exception: logger.exception("Assignment save failed for '%s' (best-effort)", filename) diff --git a/backend/routes/graph.py b/backend/routes/graph.py index 5a0bb0f..15284ec 100644 --- a/backend/routes/graph.py +++ b/backend/routes/graph.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter +from fastapi import APIRouter, HTTPException from pydantic import BaseModel from typing import Optional @@ -23,8 +23,9 @@ def get_user_recommendations(user_id: str): # ── Course endpoints ────────────────────────────────────────────────────────── class AddCourseBody(BaseModel): - course_name: str + course_id: str color: Optional[str] = None + nickname: Optional[str] = None class UpdateCourseColorBody(BaseModel): @@ -38,14 +39,17 @@ def list_courses(user_id: str): @router.post("/{user_id}/courses") def create_course(user_id: str, body: AddCourseBody): - return add_course(user_id, body.course_name, body.color) + result = add_course(user_id, body.course_id, body.color, body.nickname) + if "error" in result: + raise HTTPException(status_code=404, detail=result["error"]) + return result -@router.patch("/{user_id}/courses/{course_name}/color") -def set_course_color(user_id: str, course_name: str, body: UpdateCourseColorBody): - return update_course_color(user_id, course_name, body.color) +@router.patch("/{user_id}/courses/{course_id}/color") +def set_course_color(user_id: str, course_id: str, body: UpdateCourseColorBody): + return update_course_color(user_id, course_id, body.color) -@router.delete("/{user_id}/courses/{course_name}") -def remove_course(user_id: str, course_name: str): - return delete_course(user_id, course_name) +@router.delete("/{user_id}/courses/{course_id}") +def remove_course(user_id: str, course_id: str): + return delete_course(user_id, course_id) diff --git a/backend/routes/learn.py b/backend/routes/learn.py index 8252098..c9fb745 100644 --- a/backend/routes/learn.py +++ b/backend/routes/learn.py @@ -39,58 +39,85 @@ def _load_prompt(name: str) -> str: } -def _resolve_course(topic: str, user_id: str) -> str: - """Return the subject/course the topic belongs to, or '' if unknown.""" +def _get_course_id_for_topic(topic: str, user_id: str) -> str: + """ + Find the course_id associated with a topic/concept for a user. + First checks if topic matches a course_code or course_name, + then falls back to finding via graph_nodes. + """ if not topic: return "" topic_trim = topic.strip() if not topic_trim: return "" - subject_match = table("graph_nodes").select( - "subject", filters={"user_id": f"eq.{user_id}", "subject": f"eq.{topic_trim}"}, limit=1 - ) - if subject_match: - return topic_trim - concept_match = table("graph_nodes").select( - "subject", filters={"user_id": f"eq.{user_id}", "concept_name": f"eq.{topic_trim}"}, limit=1 - ) - if concept_match: - return concept_match[0].get("subject") or "" - course_match = table("courses").select( - "course_name", filters={"user_id": f"eq.{user_id}", "course_name": f"eq.{topic_trim}"}, limit=1 - ) - if course_match: - return topic_trim + + # First, check if topic matches a course code or name in user's enrolled courses try: - course_rows = table("courses").select( - "course_name", + enrolled = table("user_courses").select( + "course_id,courses!inner(course_code,course_name)", filters={"user_id": f"eq.{user_id}"}, ) - except Exception: - course_rows = [] - for row in course_rows or []: - cn = row.get("course_name") or "" - if cn.lower() == topic_trim.lower(): - return cn + for row in enrolled: + course = row.get("courses", {}) if isinstance(row.get("courses"), dict) else {} + course_code = course.get("course_code", "") + course_name = course.get("course_name", "") + + # Match on course_code (exact or case-insensitive) + if topic_trim.upper() == course_code.upper(): + return row["course_id"] + # Match on course_name + if topic_trim.lower() == course_name.lower(): + return row["course_id"] + # Same label as graph subject roots (graph_service) + label = f"{course_code} - {course_name}" if course_code else course_name + if label and topic_trim == label: + return row["course_id"] + except Exception as e: + print(f"Failed to resolve course_id for topic={topic_trim!r} user_id={user_id!r}: {e}") + + # Fallback: find via graph_nodes - look for nodes matching topic + # that have a course_id + node_rows = table("graph_nodes").select( + "course_id", + filters={ + "user_id": f"eq.{user_id}", + "concept_name": f"eq.{topic_trim}", + }, + limit=10, + ) + for row in (node_rows or []): + if row.get("course_id"): + return row["course_id"] + + # Try matching on subject field (legacy support) + subject_rows = table("graph_nodes").select( + "course_id", + filters={ + "user_id": f"eq.{user_id}", + "subject": f"eq.{topic_trim}", + }, + limit=1, + ) + for row in (subject_rows or []): + if row.get("course_id"): + return row["course_id"] + return "" -def _get_session_topic(session_id: str) -> str: - rows = table("sessions").select("topic", filters={"id": f"eq.{session_id}"}, limit=1) - return rows[0]["topic"] if rows else "" +def _get_session_course_id(session_id: str) -> str: + """Get the course_id from a session if it exists.""" + rows = table("sessions").select("course_id", filters={"id": f"eq.{session_id}"}, limit=1) + if rows and rows[0].get("course_id"): + return rows[0]["course_id"] + return "" -def _get_course_documents(user_id: str, course_name: str) -> list: +def _get_course_documents(user_id: str, course_id: str) -> list: """Fetch uploaded document summaries and key takeaways for a user's course.""" - if not course_name: + if not course_id: return [] try: - course_rows = table("courses").select( - "id", filters={"user_id": f"eq.{user_id}", "course_name": f"eq.{course_name}"}, limit=1 - ) - if not course_rows: - return [] - course_id = course_rows[0]["id"] docs = table("documents").select( "file_name,category,summary,key_takeaways", filters={"user_id": f"eq.{user_id}", "course_id": f"eq.{course_id}"}, @@ -100,12 +127,32 @@ def _get_course_documents(user_id: str, course_name: str) -> list: return [] +def _get_course_info(course_id: str) -> dict: + """Get course info (code and name) for a course_id.""" + if not course_id: + return {"course_code": "", "course_name": ""} + try: + rows = table("courses").select( + "course_code,course_name", + filters={"id": f"eq.{course_id}"}, + limit=1, + ) + if rows: + return { + "course_code": rows[0].get("course_code", ""), + "course_name": rows[0].get("course_name", ""), + } + except Exception as e: + print(f"Failed to load course info for course_id={course_id!r}: {e}") + return {"course_code": "", "course_name": ""} + + def build_system_prompt( mode: str, student_name: str, graph_json: str, last_summary: str = "", - course_name: str = "", + course_id: str = "", use_shared_context: bool = True, documents: list | None = None, ) -> str: @@ -133,12 +180,14 @@ def build_system_prompt( + "\n\n---\n\n".join(doc_blocks) ) - if use_shared_context and course_name: - ctx = get_course_context(course_name) + if use_shared_context and course_id: + ctx = get_course_context(course_id) if ctx: + course_info = _get_course_info(course_id) + course_label = f"{course_info['course_code']} - {course_info['course_name']}" if course_info['course_code'] else course_info['course_name'] shared_block = ( SHARED_CONTEXT_TEMPLATE - .replace("{course_name}", course_name) + .replace("{course_name}", course_label) .replace("{shared_context_json}", json.dumps(ctx, indent=2)) ) parts.append(shared_block) @@ -179,12 +228,18 @@ def _consume_pending(session_id: str, user_id: str) -> None: pending = PENDING_SESSIONS.pop(session_id) if pending["user_id"] != user_id: raise HTTPException(status_code=403, detail="Session user mismatch") - table("sessions").insert({ + + # Include course_id in session creation + session_data = { "id": session_id, "user_id": user_id, "mode": pending["mode"], "topic": pending["topic"], - }) + } + if pending.get("course_id"): + session_data["course_id"] = pending["course_id"] + + table("sessions").insert(session_data) save_message(session_id, "assistant", pending["assistant_reply"], pending["graph_update"]) @@ -199,11 +254,14 @@ def start_session(body: StartSessionBody): student_name = get_user_name(body.user_id) graph_data = get_graph(body.user_id) - course_name = _resolve_course(body.topic, body.user_id) - documents = _get_course_documents(body.user_id, course_name) + + # Use course_id from body, or try to resolve from topic + course_id = body.course_id or _get_course_id_for_topic(body.topic, body.user_id) + documents = _get_course_documents(body.user_id, course_id) + system_prompt = build_system_prompt( body.mode, student_name, json.dumps(graph_data, indent=2), - course_name=course_name, use_shared_context=body.use_shared_context, + course_id=course_id, use_shared_context=body.use_shared_context, documents=documents, ) user_message = ( @@ -217,11 +275,13 @@ def start_session(body: StartSessionBody): raise HTTPException(status_code=502, detail=f"Gemini error: {e}") reply, graph_update = extract_graph_update(raw) - apply_graph_update(body.user_id, graph_update) + apply_graph_update(body.user_id, graph_update, course_id=course_id) + PENDING_SESSIONS[session_id] = { "user_id": body.user_id, "mode": body.mode, "topic": body.topic, + "course_id": course_id, "use_shared_context": body.use_shared_context, "assistant_reply": reply, "graph_update": graph_update, @@ -243,12 +303,14 @@ def chat(body: ChatBody): graph_data = get_graph(body.user_id) # Exclude the just-saved user message so history is prior turns only history = get_conversation_history(body.session_id)[:-1] - topic = _get_session_topic(body.session_id) - course_name = _resolve_course(topic, body.user_id) - documents = _get_course_documents(body.user_id, course_name) + + # Get course_id from session if available + course_id = _get_session_course_id(body.session_id) + documents = _get_course_documents(body.user_id, course_id) + system_prompt = build_system_prompt( body.mode, student_name, json.dumps(graph_data, indent=2), - course_name=course_name, use_shared_context=body.use_shared_context, + course_id=course_id, use_shared_context=body.use_shared_context, documents=documents, ) @@ -259,7 +321,7 @@ def chat(body: ChatBody): reply, graph_update = extract_graph_update(raw) save_message(body.session_id, "assistant", reply, graph_update) - mastery_changes = apply_graph_update(body.user_id, graph_update) + mastery_changes = apply_graph_update(body.user_id, graph_update, course_id=course_id) return {"reply": reply, "graph_update": graph_update, "mastery_changes": mastery_changes} @@ -349,6 +411,7 @@ def list_sessions(user_id: str, limit: int = 10): "id": s["id"], "topic": s["topic"], "mode": s["mode"], + "course_id": s.get("course_id"), "started_at": s["started_at"], "ended_at": s.get("ended_at"), "message_count": len(msgs), @@ -381,6 +444,7 @@ def resume_session(session_id: str): "user_id": p["user_id"], "topic": p["topic"], "mode": p["mode"], + "course_id": p.get("course_id"), "started_at": now, "ended_at": None, }, @@ -395,7 +459,7 @@ def resume_session(session_id: str): } session_rows = table("sessions").select( - "id,user_id,topic,mode,started_at,ended_at", + "id,user_id,topic,mode,started_at,ended_at,course_id", filters={"id": f"eq.{session_id}"}, ) if not session_rows: @@ -424,12 +488,14 @@ def action(body: ActionBody): student_name = get_user_name(body.user_id) graph_data = get_graph(body.user_id) history = get_conversation_history(body.session_id) - topic = _get_session_topic(body.session_id) - course_name = _resolve_course(topic, body.user_id) - documents = _get_course_documents(body.user_id, course_name) + + # Get course_id from session + course_id = _get_session_course_id(body.session_id) + documents = _get_course_documents(body.user_id, course_id) + system_prompt = build_system_prompt( body.mode, student_name, json.dumps(graph_data, indent=2), - course_name=course_name, use_shared_context=body.use_shared_context, + course_id=course_id, use_shared_context=body.use_shared_context, documents=documents, ) action_message = f"[ACTION: {action_prompts.get(body.action_type, '')}]" @@ -441,7 +507,7 @@ def action(body: ActionBody): reply, graph_update = extract_graph_update(raw) save_message(body.session_id, "assistant", reply, graph_update) - apply_graph_update(body.user_id, graph_update) + apply_graph_update(body.user_id, graph_update, course_id=course_id) return {"reply": reply, "graph_update": graph_update} @@ -449,7 +515,11 @@ def action(body: ActionBody): def mode_switch(body: ModeSwitchBody): _ensure_session_ready(body.session_id, body.user_id) student_name = get_user_name(body.user_id).split()[0] - topic = _get_session_topic(body.session_id) + session_rows = table("sessions").select( + "topic", filters={"id": f"eq.{body.session_id}"}, limit=1 + ) + topic = session_rows[0]["topic"] if session_rows else "this topic" + mode_label = MODE_DISPLAY_NAMES.get(body.new_mode, body.new_mode) reply = ( diff --git a/backend/routes/quiz.py b/backend/routes/quiz.py index ffa1899..12a5da9 100644 --- a/backend/routes/quiz.py +++ b/backend/routes/quiz.py @@ -44,13 +44,28 @@ def generate_quiz(body: GenerateQuizBody): ) # Append shared course-level context (misconceptions + weak areas) if available - subject = node.get("subject", "") - if body.use_shared_context and subject: + course_id = node.get("course_id", "") + if body.use_shared_context and course_id: from services.course_context_service import get_course_context - course_ctx = get_course_context(subject) + course_ctx = get_course_context(course_id) if course_ctx: - misconceptions = course_ctx.get("common_misconceptions", []) - weak_areas = course_ctx.get("weak_areas", []) + misconceptions: list[str] = [] + weak_areas: list[str] = [] + seen_m: set[str] = set() + seen_w: set[str] = set() + for row in course_ctx.get("concept_stats") or []: + if not isinstance(row, dict): + continue + for m in row.get("common_misconceptions") or []: + m = (m or "").strip() + if m and m.lower() not in seen_m: + seen_m.add(m.lower()) + misconceptions.append(m) + for w in row.get("prerequisite_gaps") or []: + w = (w or "").strip() + if w and w.lower() not in seen_w: + seen_w.add(w.lower()) + weak_areas.append(w) if misconceptions or weak_areas: addendum_parts = [] if misconceptions: diff --git a/backend/services/calendar_service.py b/backend/services/calendar_service.py index b37597f..1216503 100644 --- a/backend/services/calendar_service.py +++ b/backend/services/calendar_service.py @@ -41,7 +41,7 @@ def insert_new_assignments(user_id: str, assignments: list[dict]) -> int: "id": str(uuid.uuid4()), "user_id": user_id, "title": title, - "course_name": a.get("course_name") or "", + "course_id": a.get("course_id") or None, "due_date": key[1], "assignment_type": a.get("assignment_type") or "other", "notes": a.get("notes"), diff --git a/backend/services/course_context_service.py b/backend/services/course_context_service.py index 21e148c..6dcbf55 100644 --- a/backend/services/course_context_service.py +++ b/backend/services/course_context_service.py @@ -1,170 +1,333 @@ """ services/course_context_service.py -Builds and caches a shared course-level context from real DB data. -No Gemini calls — pure aggregation of graph_nodes and quiz_context. +Builds and caches shared course-level context from real DB data. +Aggregates graph_nodes mastery data and quiz_context across all students in a course. -The context is stored in the course_context table and refreshed every time -any student's graph is updated (via apply_graph_update in graph_service.py). +Stores data in: +- course_concept_stats: per-concept aggregated metrics +- course_summary: course-wide summary with Gemini-generated text """ +import json +import hashlib from datetime import datetime, timezone from db.connection import table +from services.gemini_service import call_gemini -def get_course_context(course_name: str) -> dict: - """Return the cached context_json for a course, or {} if not yet built.""" - if not course_name: +def _generate_data_hash(stats_rows: list) -> str: + """Generate a hash of the stats data to detect changes.""" + data_str = json.dumps(stats_rows, sort_keys=True, default=str) + return hashlib.sha256(data_str.encode()).hexdigest() + + +def _generate_summary_with_gemini( + course_code: str, + course_name: str, + avg_class_mastery: float, + top_struggling: list, + top_mastered: list, + student_count: int, +) -> str: + """Generate a natural language summary using Gemini.""" + prompt = f"""You are an expert education analyst summarizing a course for instructors. + +Course: {course_code} - {course_name} +Students enrolled: {student_count} +Average class mastery: {avg_class_mastery:.1%} + +Top struggling concepts (needs attention): +{chr(10).join(f"- {c}" for c in top_struggling) if top_struggling else "None identified"} + +Top mastered concepts (students doing well): +{chr(10).join(f"- {c}" for c in top_mastered) if top_mastered else "None identified"} + +Write a concise 2-3 paragraph summary that: +1. Describes the overall class performance +2. Highlights specific areas where students are struggling and may need intervention +3. Notes areas where students are excelling +4. Provides actionable recommendations for the instructor + +Write in a professional but approachable tone. Be specific and data-driven.""" + + try: + return call_gemini(prompt, retries=1) + except Exception: + # Fallback summary if Gemini fails + return ( + f"Class average mastery: {avg_class_mastery:.1%}. " + f"Students are struggling with: {', '.join(top_struggling[:3]) if top_struggling else 'No major areas identified'}. " + f"Students have mastered: {', '.join(top_mastered[:3]) if top_mastered else 'No areas identified yet'}." + ) + + +def get_course_context(course_id: str) -> dict: + """ + Return the cached course context including summary and concept stats. + Returns dict with course_summary + course_concept_stats, or {} if not found. + """ + if not course_id: return {} + try: - rows = table("course_context").select( - "context_json", - filters={"course_name": f"eq.{course_name}"}, + # Get course summary + summary_rows = table("course_summary").select( + "*", + filters={"course_id": f"eq.{course_id}"}, + ) + if not summary_rows: + return {} + + summary = summary_rows[0] + semester = summary["semester"] + + # Get concept stats for this course, scoped to the same semester + stats_rows = table("course_concept_stats").select( + "*", + filters={"course_id": f"eq.{course_id}", "semester": f"eq.{semester}"}, ) - return rows[0]["context_json"] if rows else {} + + return { + "course_summary": { + "course_id": summary["course_id"], + "semester": summary["semester"], + "student_count": summary["student_count"], + "avg_class_mastery": summary["avg_class_mastery"], + "top_struggling_concepts": summary.get("top_struggling_concepts", []), + "top_mastered_concepts": summary.get("top_mastered_concepts", []), + "summary_text": summary.get("summary_text", ""), + "updated_at": summary["updated_at"], + }, + "concept_stats": stats_rows or [], + } except Exception: return {} -def update_course_context(course_name: str) -> None: +def update_course_context(course_id: str) -> None: """ - Aggregate mastery + quiz data for all students in course_name and upsert - into the course_context table. Called automatically after any graph update. + Aggregate mastery + quiz data for all students enrolled in the course and upsert + into course_concept_stats and course_summary tables. + Called automatically after any graph update. """ - if not course_name: + if not course_id: return - # ── 1. All graph nodes for this course across every user ───────────────── + # ── 1. Get all students enrolled in this course via user_courses ─────────── + enrollment_rows = table("user_courses").select( + "user_id", + filters={"course_id": f"eq.{course_id}"}, + ) + if not enrollment_rows: + # No students enrolled — purge any stale aggregates + table("course_concept_stats").delete({"course_id": f"eq.{course_id}"}) + table("course_summary").delete({"course_id": f"eq.{course_id}"}) + return + + user_ids = [r["user_id"] for r in enrollment_rows] + student_count = len(user_ids) + + # ── 2. Get course info (including canonical semester) for the summary ────── + course_rows = table("courses").select( + "course_code,course_name,semester", + filters={"id": f"eq.{course_id}"}, + ) + course_info = course_rows[0] if course_rows else {"course_code": "", "course_name": "", "semester": "Spring 2026"} + semester = course_info.get("semester") or "Spring 2026" + + # ── 3. All graph nodes for this course across every enrolled student ────── + # Build user_id filter for PostgREST + user_filter = ",".join(user_ids) node_rows = table("graph_nodes").select( "id,concept_name,mastery_score,mastery_tier,user_id", - filters={"subject": f"eq.{course_name}"}, + filters={"course_id": f"eq.{course_id}", "user_id": f"in.({user_filter})"}, ) if not node_rows: - return + return # No graph data yet for this course - # ── 2. Group by concept_name, track per-user scores ────────────────────── + # ── 4. Group by concept_name, track per-user scores ─────────────────────── concept_data: dict = {} - user_ids: set = set() - node_id_set: set = set() for n in node_rows: - user_ids.add(n["user_id"]) - node_id_set.add(n["id"]) name = n["concept_name"] if name not in concept_data: - concept_data[name] = {"scores": [], "tiers": []} + concept_data[name] = {"scores": [], "tiers": [], "node_ids": []} concept_data[name]["scores"].append(float(n["mastery_score"] or 0.0)) concept_data[name]["tiers"].append(n["mastery_tier"] or "unexplored") + concept_data[name]["node_ids"].append(n["id"]) - student_count = len(user_ids) - - # ── 3. Compute per-concept metrics ──────────────────────────────────────── + # ── 5. Compute per-concept metrics ──────────────────────────────────────── concept_metrics: dict = {} + all_scores: list = [] + for name, data in concept_data.items(): scores = data["scores"] tiers = data["tiers"] n_s = len(scores) + + if n_s == 0: + continue + avg_mastery = sum(scores) / n_s - struggling_pct = sum(1 for t in tiers if t == "struggling") / n_s - mastered_pct = sum(1 for t in tiers if t == "mastered") / n_s + all_scores.extend(scores) + + struggling_count = sum(1 for t in tiers if t == "struggling") + mastered_count = sum(1 for t in tiers if t == "mastered") + unexplored_count = sum(1 for t in tiers if t == "unexplored") + concept_metrics[name] = { - "avg_mastery": round(avg_mastery, 3), - "struggling_pct": round(struggling_pct, 2), - "mastered_pct": round(mastered_pct, 2), + "avg_mastery_score": round(avg_mastery, 4), + "pct_struggling": round(struggling_count / n_s, 4), + "pct_mastered": round(mastered_count / n_s, 4), + "pct_unexplored": round(unexplored_count / n_s, 4), + "student_count": n_s, + "node_ids": data["node_ids"], } - struggling_concepts = sorted( - [ - { - "concept": name, - "avg_mastery": m["avg_mastery"], - "struggling_pct": m["struggling_pct"], - } - for name, m in concept_metrics.items() - if m["struggling_pct"] > 0.2 - ], - key=lambda x: x["avg_mastery"], - ) + # ── 6. Helpers: quiz_context rows for a set of graph node ids ──────────── + def _fetch_quiz_context_rows(node_ids: list) -> list: + if not node_ids: + return [] + chunk_size = 80 + out = [] + for i in range(0, len(node_ids), chunk_size): + chunk = node_ids[i : i + chunk_size] + node_filter = ",".join(chunk) + try: + rows = table("quiz_context").select( + "concept_node_id,context_json", + filters={"concept_node_id": f"in.({node_filter})"}, + ) + except Exception: + rows = [] + out.extend(rows or []) + return out + + def _parse_quiz_context_to_arrays(ctx_rows: list) -> tuple[list, list, list]: + common_misconceptions: list = [] + effective_explanations: list = [] + prerequisite_gaps: list = [] + seen_misconceptions: set = set() + seen_explanations: set = set() + seen_prereqs: set = set() + + for ctx in ctx_rows: + cj = ctx.get("context_json") or {} + if isinstance(cj, str): + try: + cj = json.loads(cj) + except Exception: + cj = {} + + for m in cj.get("common_mistakes", []): + m = (m or "").strip() + if m and m.lower() not in seen_misconceptions: + seen_misconceptions.add(m.lower()) + common_misconceptions.append(m) + + for w in cj.get("weak_areas", []): + w = (w or "").strip() + if w and w.lower() not in seen_prereqs: + seen_prereqs.add(w.lower()) + prerequisite_gaps.append(w) + + for exp in cj.get("effective_explanations", []): + exp = (exp or "").strip() + if exp and exp.lower() not in seen_explanations: + seen_explanations.add(exp.lower()) + effective_explanations.append(exp) + + return common_misconceptions[:20], effective_explanations[:20], prerequisite_gaps[:20] + + # ── 7. Upsert into course_concept_stats (quiz arrays per concept) ───────── + for name, metrics in concept_metrics.items(): + node_ids_for_concept = concept_data.get(name, {}).get("node_ids", []) + ctx_rows = _fetch_quiz_context_rows(node_ids_for_concept) + cm, ee, pg = _parse_quiz_context_to_arrays(ctx_rows) - mastered_concepts = sorted( - [ + table("course_concept_stats").upsert( { - "concept": name, - "avg_mastery": m["avg_mastery"], - "mastered_pct": m["mastered_pct"], - } - for name, m in concept_metrics.items() - if m["mastered_pct"] > 0.6 - ], - key=lambda x: x["avg_mastery"], + "course_id": course_id, + "concept_name": name, + "semester": semester, + "student_count": metrics["student_count"], + "avg_mastery_score": metrics["avg_mastery_score"], + "pct_mastered": metrics["pct_mastered"], + "pct_struggling": metrics["pct_struggling"], + "pct_unexplored": metrics["pct_unexplored"], + "common_misconceptions": cm, + "effective_explanations": ee, + "prerequisite_gaps": pg, + "updated_at": datetime.now(timezone.utc).isoformat(), + }, + on_conflict="course_id,concept_name,semester", + ) + + # ── 8. Compute course-wide summary metrics ──────────────────────────────── + avg_class_mastery = round(sum(all_scores) / len(all_scores), 4) if all_scores else 0.0 + + # Sort for top struggling (highest pct_struggling) and top mastered, excluding zeros + sorted_by_struggling = sorted( + [(name, m) for name, m in concept_metrics.items() if m["pct_struggling"] > 0.0], + key=lambda x: x[1]["pct_struggling"], reverse=True, ) + top_struggling_concepts = [name for name, _ in sorted_by_struggling[:5]] - concept_difficulty_ranking = sorted( - [ - {"concept": name, "avg_mastery": m["avg_mastery"]} - for name, m in concept_metrics.items() - ], - key=lambda x: x["avg_mastery"], # lowest mastery = hardest + sorted_by_mastered = sorted( + [(name, m) for name, m in concept_metrics.items() if m["pct_mastered"] > 0.0], + key=lambda x: x[1]["pct_mastered"], + reverse=True, ) + top_mastered_concepts = [name for name, _ in sorted_by_mastered[:5]] - # ── 4. Pull quiz_context for this course's students, filter by node ─────── - user_id_list = list(user_ids) - try: - ctx_rows_all = table("quiz_context").select( - "concept_node_id,context_json", - filters={"user_id": f"in.({','.join(user_id_list)})"}, - ) - except Exception: - ctx_rows_all = [] - - # Keep only contexts for concepts that belong to this course - ctx_rows = [r for r in ctx_rows_all if r["concept_node_id"] in node_id_set] + # Generate data hash to detect changes + stats_for_hash = [ + { + "concept": name, + "avg_mastery": m["avg_mastery_score"], + "pct_struggling": m["pct_struggling"], + "pct_mastered": m["pct_mastered"], + } + for name, m in concept_metrics.items() + ] + current_hash = _generate_data_hash(stats_for_hash) - # ── 5. Deduplicate misconceptions and weak areas (case-insensitive) ─────── - seen: set = set() - common_misconceptions: list = [] - seen2: set = set() - weak_areas: list = [] + # ── 9. Check if summary needs regeneration ───────────────────────────────── + existing_summary_rows = table("course_summary").select( + "summary_hash,summary_text", + filters={"course_id": f"eq.{course_id}", "semester": f"eq.{semester}"}, + ) + + existing_hash = existing_summary_rows[0]["summary_hash"] if existing_summary_rows else None + + # Regenerate summary only if data changed or no existing summary + if current_hash != existing_hash or not existing_summary_rows: + summary_text = _generate_summary_with_gemini( + course_info.get("course_code", ""), + course_info.get("course_name", ""), + avg_class_mastery, + top_struggling_concepts, + top_mastered_concepts, + student_count, + ) + else: + summary_text = existing_summary_rows[0].get("summary_text", "") - for ctx in ctx_rows: - cj = ctx.get("context_json") or {} - if isinstance(cj, str): - import json as _json - try: - cj = _json.loads(cj) - except Exception: - cj = {} - - for m in cj.get("common_mistakes", []): - m = (m or "").strip() - if m and m.lower() not in seen: - seen.add(m.lower()) - common_misconceptions.append(m) - - for w in cj.get("weak_areas", []): - w = (w or "").strip() - if w and w.lower() not in seen2: - seen2.add(w.lower()) - weak_areas.append(w) - - # ── 6. Upsert into course_context ───────────────────────────────────────── - context = { - "struggling_concepts": struggling_concepts, - "mastered_concepts": mastered_concepts, - "concept_difficulty_ranking": concept_difficulty_ranking, - "common_misconceptions": common_misconceptions, - "weak_areas": weak_areas, - "student_count": student_count, - } - - table("course_context").upsert( + # ── 10. Upsert into course_summary ──────────────────────────────────────── + table("course_summary").upsert( { - "course_name": course_name, - "context_json": context, + "course_id": course_id, + "semester": semester, "student_count": student_count, + "avg_class_mastery": avg_class_mastery, + "top_struggling_concepts": top_struggling_concepts, + "top_mastered_concepts": top_mastered_concepts, + "summary_text": summary_text, + "summary_hash": current_hash, "updated_at": datetime.now(timezone.utc).isoformat(), }, - on_conflict="course_name", + on_conflict="course_id,semester", ) diff --git a/backend/services/graph_service.py b/backend/services/graph_service.py index beb8446..3fc4515 100644 --- a/backend/services/graph_service.py +++ b/backend/services/graph_service.py @@ -5,27 +5,24 @@ from db.connection import table -def _user_course_titles(user_id: str) -> set[str]: +def _user_enrolled_courses(user_id: str) -> list[dict]: + """Get all courses a user is enrolled in via user_courses join.""" try: - rows = table("courses").select( - "course_name", + rows = table("user_courses").select( + "id,course_id,color,nickname,enrolled_at,courses!inner(course_code,course_name,department,school)", filters={"user_id": f"eq.{user_id}"}, ) except Exception: - return set() - return {r["course_name"] for r in (rows or []) if r.get("course_name")} + return [] + return rows or [] -def _filter_course_title_seed_nodes(nodes: list, course_titles: set[str]) -> list: - """Remove rows that only duplicated the course hub (concept_name == subject == registered course).""" - out = [] - for n in nodes: - subj = (n.get("subject") or "").strip() - concept = (n.get("concept_name") or "").strip() - if subj and concept == subj and subj in course_titles: - continue - out.append(n) - return out +def _get_course_nodes(user_id: str, course_id: str) -> list: + """Get graph nodes for a specific course.""" + return table("graph_nodes").select( + "*", + filters={"user_id": f"eq.{user_id}", "course_id": f"eq.{course_id}"}, + ) or [] def ensure_user_exists(user_id: str) -> None: @@ -89,9 +86,13 @@ def _compute_velocity(events: list) -> float: def get_graph(user_id: str) -> dict: ensure_user_exists(user_id) - course_titles = _user_course_titles(user_id) + + # Get all enrolled courses for this user + enrolled_courses = _user_enrolled_courses(user_id) + + # Get all graph nodes for this user nodes_raw = table("graph_nodes").select("*", filters={"user_id": f"eq.{user_id}"}) - nodes = _filter_course_title_seed_nodes(nodes_raw or [], course_titles) + nodes = nodes_raw or [] node_ids = {n["id"] for n in nodes} edges_raw = table("graph_edges").select("*", filters={"user_id": f"eq.{user_id}"}) @@ -134,33 +135,40 @@ def get_graph(user_id: str) -> dict: "avg_learning_velocity": avg_velocity, } - subject_map: dict = {} - for n in nodes: - subj = n.get("subject") or "General" - subject_map.setdefault(subj, []).append(n) - - for title in course_titles: - subject_map.setdefault(title, []) - + # Build subject root hubs from enrolled courses subject_nodes = [] subject_edges = [] - for subj, subj_nodes in subject_map.items(): - root_id = f"subject_root__{subj}" + + for enrollment in enrolled_courses: + course_id = enrollment["course_id"] + course = enrollment.get("courses", {}) if isinstance(enrollment.get("courses"), dict) else {} + course_code = course.get("course_code", "") + course_name = course.get("course_name", "") + + # Use "Course Code - Course Name" as the subject label + subject_label = f"{course_code} - {course_name}" if course_code else course_name + + # Find all nodes belonging to this course + subj_nodes = [n for n in nodes if n.get("course_id") == course_id] + + root_id = f"subject_root__{course_id}" if subj_nodes: avg_mastery = sum(n["mastery_score"] for n in subj_nodes) / len(subj_nodes) else: avg_mastery = 0.0 + subject_nodes.append({ "id": root_id, "user_id": user_id, - "concept_name": subj, + "concept_name": subject_label, "mastery_score": round(avg_mastery, 4), "mastery_tier": "subject_root", - "subject": subj, + "course_id": course_id, "times_studied": sum(n.get("times_studied", 0) for n in subj_nodes), "last_studied_at": None, "is_subject_root": True, }) + for n in subj_nodes: subject_edges.append({ "id": f"subject_edge__{root_id}__{n['id']}", @@ -176,91 +184,139 @@ def get_graph(user_id: str) -> dict: # ── Course management ────────────────────────────────────────────────────────── def get_courses(user_id: str) -> list: + """ + Return user's enrolled courses joined with canonical course data. + Returns list of dicts with: enrollment_id, course_id, course_code, course_name, + school, department, color, nickname, node_count, enrolled_at + """ try: - rows = table("courses").select( - "id,course_name,color,created_at", + rows = table("user_courses").select( + "id,course_id,color,nickname,enrolled_at,courses!inner(course_code,course_name,school,department)", filters={"user_id": f"eq.{user_id}"}, - order="created_at.asc", + order="enrolled_at.asc", ) except Exception: return [] + result = [] for r in rows: + course = r.get("courses", {}) if isinstance(r.get("courses"), dict) else {} + course_id = r["course_id"] + + # Count nodes for this course node_rows = table("graph_nodes").select( "id", - filters={"user_id": f"eq.{user_id}", "subject": f"eq.{r['course_name']}"}, + filters={"user_id": f"eq.{user_id}", "course_id": f"eq.{course_id}"}, ) + result.append({ - "id": r["id"], - "course_name": r["course_name"], - "color": r["color"], + "enrollment_id": r["id"], + "course_id": course_id, + "course_code": course.get("course_code", ""), + "course_name": course.get("course_name", ""), + "school": course.get("school", ""), + "department": course.get("department", ""), + "color": r.get("color"), + "nickname": r.get("nickname"), "node_count": len(node_rows), - "created_at": r["created_at"], + "enrolled_at": r["enrolled_at"], }) return result -def add_course(user_id: str, course_name: str, color: str | None = None) -> dict: - existing = table("courses").select( +def add_course(user_id: str, course_id: str, color: str | None = None, nickname: str | None = None) -> dict: + """ + Enroll a user in a course (insert into user_courses). + course_id refers to the canonical courses table. + """ + # Check if already enrolled + existing = table("user_courses").select( "id", - filters={"user_id": f"eq.{user_id}", "course_name": f"eq.{course_name}"}, + filters={"user_id": f"eq.{user_id}", "course_id": f"eq.{course_id}"}, ) if existing: - return {"course_name": course_name, "already_existed": True} - table("courses").insert({ + return {"course_id": course_id, "already_existed": True} + + # Verify the course exists in canonical courses + course_check = table("courses").select("id", filters={"id": f"eq.{course_id}"}) + if not course_check: + return {"course_id": course_id, "error": "Course not found in catalog"} + + table("user_courses").insert({ "id": str(uuid.uuid4()), "user_id": user_id, - "course_name": course_name, + "course_id": course_id, "color": color, + "nickname": nickname, }) - return {"course_name": course_name, "already_existed": False} + try: + from services.course_context_service import update_course_context + update_course_context(course_id) + except Exception: + pass + return {"course_id": course_id, "already_existed": False} -def update_course_color(user_id: str, course_name: str, color: str) -> dict: - table("courses").update( +def update_course_color(user_id: str, course_id: str, color: str) -> dict: + """Update the color for a user's course enrollment.""" + table("user_courses").update( {"color": color}, - filters={"user_id": f"eq.{user_id}", "course_name": f"eq.{course_name}"}, + filters={"user_id": f"eq.{user_id}", "course_id": f"eq.{course_id}"}, ) return {"updated": True} -def delete_course(user_id: str, course_name: str) -> dict: - node_rows = table("graph_nodes").select( - "id", - filters={"user_id": f"eq.{user_id}", "subject": f"eq.{course_name}"}, +def update_course_nickname(user_id: str, course_id: str, nickname: str) -> dict: + """Update the nickname for a user's course enrollment.""" + table("user_courses").update( + {"nickname": nickname}, + filters={"user_id": f"eq.{user_id}", "course_id": f"eq.{course_id}"}, ) - node_ids = [n["id"] for n in node_rows] - - if node_ids: - ids_str = ",".join(node_ids) - # Delete all tables that FK-reference graph_nodes before deleting nodes - table("quiz_context").delete({"concept_node_id": f"in.({ids_str})"}) - table("quiz_attempts").delete({"concept_node_id": f"in.({ids_str})"}) - table("graph_edges").delete({"source_node_id": f"in.({ids_str})"}) - table("graph_edges").delete({"target_node_id": f"in.({ids_str})"}) - table("graph_nodes").delete( - {"user_id": f"eq.{user_id}", "subject": f"eq.{course_name}"} - ) + return {"updated": True} - table("courses").delete( - {"user_id": f"eq.{user_id}", "course_name": f"eq.{course_name}"} + +def delete_course(user_id: str, course_id: str) -> dict: + """ + Unenroll a user from a course (delete from user_courses). + Note: We don't delete the graph nodes - they remain for potential re-enrollment. + """ + # Just delete the enrollment, not the nodes + table("user_courses").delete( + {"user_id": f"eq.{user_id}", "course_id": f"eq.{course_id}"} ) + try: + from services.course_context_service import update_course_context + update_course_context(course_id) + except Exception: + pass return {"deleted": True} -def apply_graph_update(user_id: str, graph_update: dict) -> list: - """Apply a graph_update dict to the DB. Returns mastery_changes list.""" +def _node_filters(user_id: str, concept_name: str, course_id: str | None) -> dict: + f = {"user_id": f"eq.{user_id}", "concept_name": f"eq.{concept_name}"} + if course_id: + f["course_id"] = f"eq.{course_id}" + return f + + +def apply_graph_update(user_id: str, graph_update: dict, course_id: str | None = None) -> list: + """ + Apply a graph_update dict to the DB. Returns mastery_changes list. + If course_id is provided, all new/updated nodes will be associated with that course. + """ mastery_changes = [] - touched_subjects: set = set() + touched_courses: set = set() for new_node in graph_update.get("new_nodes", []): name = new_node.get("concept_name", "") - subject = new_node.get("subject", "General") + node_course_id = course_id or new_node.get("course_id") init_m = float(new_node.get("initial_mastery", 0.0)) + existing = table("graph_nodes").select( "id", - filters={"user_id": f"eq.{user_id}", "concept_name": f"eq.{name}"}, + filters=_node_filters(user_id, name, node_course_id), ) + if not existing: table("graph_nodes").insert({ "id": str(uuid.uuid4()), @@ -268,18 +324,18 @@ def apply_graph_update(user_id: str, graph_update: dict) -> list: "concept_name": name, "mastery_score": init_m, "mastery_tier": get_mastery_tier(init_m), - "subject": subject, + "course_id": node_course_id, "mastery_events": [], }) - if subject and subject != "General": - touched_subjects.add(subject) + if node_course_id: + touched_courses.add(node_course_id) for upd in graph_update.get("updated_nodes", []): name = upd.get("concept_name", "") delta = float(upd.get("mastery_delta", 0.0)) rows = table("graph_nodes").select( - "id,mastery_score,times_studied,subject,mastery_events", - filters={"user_id": f"eq.{user_id}", "concept_name": f"eq.{name}"}, + "id,mastery_score,times_studied,course_id,mastery_events", + filters=_node_filters(user_id, name, course_id), ) if rows: row = rows[0] @@ -306,9 +362,11 @@ def apply_graph_update(user_id: str, graph_update: dict) -> list: filters={"id": f"eq.{row['id']}"}, ) mastery_changes.append({"concept": name, "before": before, "after": after}) - subj = row.get("subject", "") - if subj and subj != "General": - touched_subjects.add(subj) + + cid = row.get("course_id") + if cid: + touched_courses.add(cid) + if mastery_changes: update_streak(user_id) @@ -318,10 +376,10 @@ def apply_graph_update(user_id: str, graph_update: dict) -> list: strength = float(new_edge.get("strength", 0.5)) relationship_type = new_edge.get("relationship_type", "related") src_rows = table("graph_nodes").select( - "id", filters={"user_id": f"eq.{user_id}", "concept_name": f"eq.{src_name}"} + "id", filters=_node_filters(user_id, src_name, course_id) ) tgt_rows = table("graph_nodes").select( - "id", filters={"user_id": f"eq.{user_id}", "concept_name": f"eq.{tgt_name}"} + "id", filters=_node_filters(user_id, tgt_name, course_id) ) if src_rows and tgt_rows: src_id = src_rows[0]["id"] @@ -344,12 +402,12 @@ def apply_graph_update(user_id: str, graph_update: dict) -> list: "relationship_type": relationship_type, }) - # Refresh shared course context for every subject touched in this update - if touched_subjects: + # Refresh shared course context for every course touched in this update + if touched_courses: from services.course_context_service import update_course_context - for subj in touched_subjects: + for cid in touched_courses: try: - update_course_context(subj) + update_course_context(cid) except Exception: pass # never block the main response for a context refresh diff --git a/backend/tests/test_calendar_routes.py b/backend/tests/test_calendar_routes.py index a7ce3d7..98c8496 100644 --- a/backend/tests/test_calendar_routes.py +++ b/backend/tests/test_calendar_routes.py @@ -119,16 +119,24 @@ def test_save_skips_when_iso_datetime_matches_existing_date(self): class TestGetUpcoming: def test_returns_assignments_from_db(self): mock_rows = [ - {"id": "a1", "title": "HW1", "due_date": "2026-03-01"}, - {"id": "a2", "title": "Quiz", "due_date": "2026-03-10"}, + {"id": "a1", "user_id": "user_andres", "title": "HW1", + "due_date": "2026-03-01", "assignment_type": "homework", + "notes": None, "google_event_id": None, "course_id": None, "courses": None}, + {"id": "a2", "user_id": "user_andres", "title": "Quiz", + "due_date": "2026-03-10", "assignment_type": "quiz", + "notes": None, "google_event_id": None, "course_id": None, "courses": None}, ] with patch("routes.calendar.table") as t: t.return_value.select.return_value = mock_rows r = client.get("/api/calendar/upcoming/user_andres") assert r.status_code == 200 - assert len(r.json()["assignments"]) == 2 - assert r.json()["assignments"][0]["title"] == "HW1" + assignments = r.json()["assignments"] + assert len(assignments) == 2 + assert assignments[0]["title"] == "HW1" + assert assignments[0]["user_id"] == "user_andres" + assert assignments[0]["course_code"] == "" + assert assignments[0]["course_name"] == "" def test_returns_empty_list_when_none(self): with patch("routes.calendar.table") as t: diff --git a/backend/tests/test_graph_service.py b/backend/tests/test_graph_service.py index 24e470c..2f924e2 100644 --- a/backend/tests/test_graph_service.py +++ b/backend/tests/test_graph_service.py @@ -44,6 +44,17 @@ def _simple_mock(select_returns=None): return mock +def _enrollment_row(course_id: str, code: str = "", name: str = "Course"): + return { + "id": f"e-{course_id}", + "course_id": course_id, + "color": None, + "nickname": None, + "enrolled_at": "2026-01-01", + "courses": {"course_code": code, "course_name": name, "school": "", "department": ""}, + } + + # ── get_graph ───────────────────────────────────────────────────────────────── class TestGetGraph: @@ -52,7 +63,7 @@ def test_empty_graph_returns_zero_stats(self): "users": [{"streak_count": 5}], "graph_nodes": [], "graph_edges": [], - "courses": [], + "user_courses": [], }) with patch("services.graph_service.table", side_effect=factory): result = get_graph("u1") @@ -64,12 +75,17 @@ def test_empty_graph_returns_zero_stats(self): def test_counts_each_mastery_tier(self): nodes = [ - {"id": "n1", "concept_name": "A", "mastery_tier": "mastered", "mastery_score": 0.9, "subject": "Math", "times_studied": 1, "user_id": "u1"}, - {"id": "n2", "concept_name": "B", "mastery_tier": "learning", "mastery_score": 0.5, "subject": "Math", "times_studied": 1, "user_id": "u1"}, - {"id": "n3", "concept_name": "C", "mastery_tier": "struggling", "mastery_score": 0.2, "subject": "Math", "times_studied": 1, "user_id": "u1"}, - {"id": "n4", "concept_name": "D", "mastery_tier": "unexplored", "mastery_score": 0.0, "subject": "Math", "times_studied": 0, "user_id": "u1"}, + {"id": "n1", "concept_name": "A", "mastery_tier": "mastered", "mastery_score": 0.9, "subject": "Math", "times_studied": 1, "user_id": "u1", "course_id": "c1"}, + {"id": "n2", "concept_name": "B", "mastery_tier": "learning", "mastery_score": 0.5, "subject": "Math", "times_studied": 1, "user_id": "u1", "course_id": "c1"}, + {"id": "n3", "concept_name": "C", "mastery_tier": "struggling", "mastery_score": 0.2, "subject": "Math", "times_studied": 1, "user_id": "u1", "course_id": "c1"}, + {"id": "n4", "concept_name": "D", "mastery_tier": "unexplored", "mastery_score": 0.0, "subject": "Math", "times_studied": 0, "user_id": "u1", "course_id": "c1"}, ] - factory = _mock_table({"users": [{"streak_count": 0}], "graph_nodes": nodes, "graph_edges": [], "courses": []}) + factory = _mock_table({ + "users": [{"streak_count": 0}], + "graph_nodes": nodes, + "graph_edges": [], + "user_courses": [_enrollment_row("c1", "M", "Math")], + }) with patch("services.graph_service.table", side_effect=factory): stats = get_graph("u1")["stats"] @@ -81,31 +97,36 @@ def test_counts_each_mastery_tier(self): def test_adds_subject_root_node_per_subject(self): nodes = [ - {"id": "n1", "concept_name": "Loops", "mastery_tier": "learning", "mastery_score": 0.5, "subject": "CS101", "times_studied": 2, "user_id": "u1"}, - {"id": "n2", "concept_name": "Functions", "mastery_tier": "mastered", "mastery_score": 0.8, "subject": "CS101", "times_studied": 3, "user_id": "u1"}, + {"id": "n1", "concept_name": "Loops", "mastery_tier": "learning", "mastery_score": 0.5, + "subject": "CS101", "course_id": "c1", "times_studied": 2, "user_id": "u1"}, + {"id": "n2", "concept_name": "Functions", "mastery_tier": "mastered", "mastery_score": 0.8, + "subject": "CS101", "course_id": "c1", "times_studied": 3, "user_id": "u1"}, ] - factory = _mock_table({"users": [{"streak_count": 0}], "graph_nodes": nodes, "graph_edges": [], "courses": []}) + enrollment = [{"course_id": "c1", "courses": {"course_code": "CS101", "course_name": "Intro CS"}}] + factory = _mock_table({ + "users": [{"streak_count": 0}], + "user_courses": enrollment, + "graph_nodes": nodes, + "graph_edges": [], + }) with patch("services.graph_service.table", side_effect=factory): result = get_graph("u1") roots = [n for n in result["nodes"] if n.get("is_subject_root")] assert len(roots) == 1 - assert roots[0]["concept_name"] == "CS101" + assert "CS101" in roots[0]["concept_name"] assert roots[0]["mastery_tier"] == "subject_root" def test_legacy_seed_same_as_course_title_shows_only_subject_hub(self): - """Old seed rows (concept_name == course) are hidden; the course is only the big hub.""" - nodes = [ - { - "id": "n1", "concept_name": "EK 103: LINEAR ALGEBRA", "mastery_tier": "unexplored", - "mastery_score": 0.0, "subject": "EK 103: LINEAR ALGEBRA", "times_studied": 0, "user_id": "u1", - }, + """Course enrolled but no concept nodes — only the subject hub appears.""" + enrollment = [ + {"course_id": "c1", "courses": {"course_code": "", "course_name": "EK 103: LINEAR ALGEBRA"}} ] factory = _mock_table({ "users": [{"streak_count": 0}], - "graph_nodes": nodes, + "user_courses": enrollment, + "graph_nodes": [], "graph_edges": [], - "courses": [{"course_name": "EK 103: LINEAR ALGEBRA"}], }) with patch("services.graph_service.table", side_effect=factory): result = get_graph("u1") @@ -114,15 +135,14 @@ def test_legacy_seed_same_as_course_title_shows_only_subject_hub(self): assert len(roots) == 1 assert roots[0]["concept_name"] == "EK 103: LINEAR ALGEBRA" assert roots[0]["mastery_score"] == 0.0 - concepts = [n for n in result["nodes"] if not n.get("is_subject_root")] - assert len(concepts) == 0 def test_course_with_no_graph_nodes_still_shows_subject_hub(self): + enrollment = [{"course_id": "c1", "courses": {"course_code": "", "course_name": "Philosophy"}}] factory = _mock_table({ "users": [{"streak_count": 0}], + "user_courses": enrollment, "graph_nodes": [], "graph_edges": [], - "courses": [{"course_name": "Philosophy"}], }) with patch("services.graph_service.table", side_effect=factory): result = get_graph("u1") @@ -133,9 +153,14 @@ def test_course_with_no_graph_nodes_still_shows_subject_hub(self): assert roots[0]["mastery_score"] == 0.0 def test_edges_are_remapped(self): - nodes = [{"id": "n1", "concept_name": "A", "mastery_tier": "learning", "mastery_score": 0.5, "subject": "X", "times_studied": 0, "user_id": "u1"}] + nodes = [{"id": "n1", "concept_name": "A", "mastery_tier": "learning", "mastery_score": 0.5, "subject": "X", "times_studied": 0, "user_id": "u1", "course_id": "c-x"}] edges = [{"id": "e1", "source_node_id": "n1", "target_node_id": "n1", "strength": 0.9}] - factory = _mock_table({"users": [{"streak_count": 0}], "graph_nodes": nodes, "graph_edges": edges, "courses": []}) + factory = _mock_table({ + "users": [{"streak_count": 0}], + "graph_nodes": nodes, + "graph_edges": edges, + "user_courses": [_enrollment_row("c-x", "", "X")], + }) with patch("services.graph_service.table", side_effect=factory): result = get_graph("u1") @@ -145,7 +170,7 @@ def test_edges_are_remapped(self): assert graph_edges[0]["strength"] == 0.9 def test_streak_defaults_to_zero_when_no_user_row(self): - factory = _mock_table({"users": [], "graph_nodes": [], "graph_edges": [], "courses": []}) + factory = _mock_table({"users": [], "graph_nodes": [], "graph_edges": [], "user_courses": []}) with patch("services.graph_service.table", side_effect=factory): result = get_graph("u1") assert result["stats"]["streak"] == 0 @@ -157,10 +182,17 @@ class TestGetCourses: def test_returns_courses_with_node_count(self): def factory(name): mock = MagicMock() - if name == "courses": - mock.select.return_value = [{"id": "c1", "course_name": "Math", "color": "#fff", "created_at": "2026-01-01"}] + if name == "user_courses": + mock.select.return_value = [{ + "id": "e1", "course_id": "c1", "color": "#fff", + "nickname": None, "enrolled_at": "2026-01-01", + "courses": {"course_code": "MATH101", "course_name": "Math", + "school": "BU", "department": "Math"}, + }] elif name == "graph_nodes": mock.select.return_value = [{"id": "n1"}, {"id": "n2"}] + else: + mock.select.return_value = [] return mock with patch("services.graph_service.table", side_effect=factory): @@ -185,42 +217,50 @@ def factory(name): class TestAddCourse: def test_inserts_new_course(self): - mock = _simple_mock(select_returns=[]) - with patch("services.graph_service.table", return_value=mock): - result = add_course("u1", "Physics") + def factory(name): + mock = MagicMock() + if name == "user_courses": + # First call: check existing enrollment → not found + mock.select.return_value = [] + elif name == "courses": + # Check canonical course exists → found + mock.select.return_value = [{"id": "c1"}] + else: + mock.select.return_value = [] + mock.insert.return_value = [] + return mock - assert result["course_name"] == "Physics" + with patch("services.graph_service.table", side_effect=factory): + result = add_course("u1", "c1") + + assert result["course_id"] == "c1" assert result["already_existed"] is False - mock.insert.assert_called_once() def test_skips_insert_for_existing_course(self): mock = _simple_mock(select_returns=[{"id": "existing"}]) with patch("services.graph_service.table", return_value=mock): - result = add_course("u1", "Physics") + result = add_course("u1", "c1") assert result["already_existed"] is True - mock.insert.assert_not_called() # ── delete_course ───────────────────────────────────────────────────────────── class TestDeleteCourse: - def test_deletes_nodes_edges_and_course(self): - def factory(name): - mock = MagicMock() - mock.select.return_value = [{"id": "n1"}, {"id": "n2"}] if name == "graph_nodes" else [] - mock.delete.return_value = [] - return mock - - with patch("services.graph_service.table", side_effect=factory): - result = delete_course("u1", "Math") + def test_unenrolls_user_from_course(self): + mock = MagicMock() + mock.delete.return_value = [] + with patch("services.graph_service.table", return_value=mock): + result = delete_course("u1", "course-id-1") assert result == {"deleted": True} + mock.delete.assert_called_once() - def test_deletes_course_with_no_nodes(self): - mock = _simple_mock(select_returns=[]) + def test_unenroll_with_no_prior_nodes(self): + mock = MagicMock() + mock.delete.return_value = [] with patch("services.graph_service.table", return_value=mock): - result = delete_course("u1", "EmptyCourse") + result = delete_course("u1", "empty-course-id") assert result == {"deleted": True} diff --git a/backend/tests/test_learn_routes.py b/backend/tests/test_learn_routes.py index 706f7dd..db27904 100644 --- a/backend/tests/test_learn_routes.py +++ b/backend/tests/test_learn_routes.py @@ -8,58 +8,103 @@ from unittest.mock import MagicMock, patch from fastapi.testclient import TestClient -from routes.learn import _resolve_course from main import app client = TestClient(app) -# ── _resolve_course ─────────────────────────────────────────────────────────── +# ── _get_course_id_for_topic ────────────────────────────────────────────────── -class TestResolveCourse: +class TestGetCourseIdForTopic: def test_empty_topic_returns_empty(self): + from routes.learn import _get_course_id_for_topic with patch("routes.learn.table"): - assert _resolve_course("", "u1") == "" + assert _get_course_id_for_topic("", "u1") == "" + + def test_matches_enrolled_course_code(self): + from routes.learn import _get_course_id_for_topic + uc = MagicMock() + uc.select.return_value = [ + {"course_id": "cid-math", "courses": {"course_code": "MATH", "course_name": "Calculus"}}, + ] - def test_topic_matches_subject_name(self): def factory(name): - mock = MagicMock() - mock.select.return_value = [{"subject": "Math"}] - return mock + if name == "user_courses": + return uc + m = MagicMock() + m.select.return_value = [] + return m with patch("routes.learn.table", side_effect=factory): - assert _resolve_course("Math", "u1") == "Math" + assert _get_course_id_for_topic("math", "u1") == "cid-math" - def test_topic_matches_concept_name(self): - call_count = {"n": 0} + def test_matches_enrolled_course_name(self): + from routes.learn import _get_course_id_for_topic + uc = MagicMock() + uc.select.return_value = [ + {"course_id": "cid-bio", "courses": {"course_code": "", "course_name": "Biology 101"}}, + ] def factory(name): - mock = MagicMock() - call_count["n"] += 1 - # First call: subject lookup → no match - # Second call: concept lookup → match with subject - mock.select.return_value = [] if call_count["n"] == 1 else [{"subject": "CS101"}] - return mock + if name == "user_courses": + return uc + m = MagicMock() + m.select.return_value = [] + return m + + with patch("routes.learn.table", side_effect=factory): + assert _get_course_id_for_topic("biology 101", "u1") == "cid-bio" + + def test_matches_graph_subject_label(self): + from routes.learn import _get_course_id_for_topic + uc = MagicMock() + uc.select.return_value = [ + { + "course_id": "cid-x", + "courses": {"course_code": "CS", "course_name": "Intro"}, + }, + ] + + def factory(name): + if name == "user_courses": + return uc + if name == "graph_nodes": + m = MagicMock() + m.select.return_value = [] + return m + m = MagicMock() + m.select.return_value = [] + return m + + with patch("routes.learn.table", side_effect=factory): + assert _get_course_id_for_topic("CS - Intro", "u1") == "cid-x" + + def test_concept_node_with_course_id(self): + from routes.learn import _get_course_id_for_topic + uc = MagicMock() + uc.select.return_value = [] + + gn = MagicMock() + gn.select.return_value = [{"course_id": "cid-from-node"}] + + def factory(name): + if name == "user_courses": + return uc + if name == "graph_nodes": + return gn + m = MagicMock() + m.select.return_value = [] + return m with patch("routes.learn.table", side_effect=factory): - assert _resolve_course("Recursion", "u1") == "CS101" + assert _get_course_id_for_topic("Recursion", "u1") == "cid-from-node" def test_unknown_topic_returns_empty(self): + from routes.learn import _get_course_id_for_topic mock = MagicMock() mock.select.return_value = [] with patch("routes.learn.table", return_value=mock): - assert _resolve_course("UnknownXyzzy", "u1") == "" - - def test_topic_matches_course_case_insensitive(self): - mock = MagicMock() - mock.select.side_effect = [ - [], - [], - [], - [{"course_name": "Biology 101"}], - ] - with patch("routes.learn.table", return_value=mock): - assert _resolve_course("biology 101", "u1") == "Biology 101" + assert _get_course_id_for_topic("UnknownXyzzy", "u1") == "" # ── GET /api/learn/sessions/{user_id} ──────────────────────────────────────── diff --git a/backend/tests/test_ocr_pipeline.py b/backend/tests/test_ocr_pipeline.py index 0b75661..de70e94 100644 --- a/backend/tests/test_ocr_pipeline.py +++ b/backend/tests/test_ocr_pipeline.py @@ -10,6 +10,11 @@ import pytest +pytestmark = pytest.mark.skipif( + not os.getenv("GEMINI_API_KEY"), + reason="OCR/Gemini integration tests require GEMINI_API_KEY", +) + sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from services.calendar_service import save_assignments_to_db, process_and_save_syllabus diff --git a/backend/tests/test_quiz_routes.py b/backend/tests/test_quiz_routes.py index 8d99c70..6a5d2c2 100644 --- a/backend/tests/test_quiz_routes.py +++ b/backend/tests/test_quiz_routes.py @@ -7,6 +7,7 @@ - POST /api/quiz/submit — 404 when quiz not found """ import pytest +from contextlib import contextmanager from unittest.mock import MagicMock, patch from fastapi.testclient import TestClient @@ -98,51 +99,50 @@ def factory(name): return factory +@contextmanager +def _submit_quiz_mocks(questions=None): + with ( + patch("routes.quiz.table", side_effect=_make_table(questions)), + patch("routes.quiz.update_streak"), + patch("routes.quiz.get_quiz_context", return_value={}), + patch("routes.quiz.call_gemini_json", return_value={}), + ): + yield + + class TestSubmitQuiz: def test_all_correct_returns_full_score(self): - with patch("routes.quiz.table", side_effect=_make_table()): - with patch("routes.quiz.get_graph", return_value={"nodes": [], "edges": [], "stats": {}}): - with patch("routes.quiz.get_quiz_context", return_value={}): - with patch("routes.quiz.call_gemini_json", return_value={}): - r = client.post("/api/quiz/submit", json={ - "quiz_id": "quiz1", - "answers": [ - {"question_id": 1, "selected_label": "A"}, - {"question_id": 2, "selected_label": "D"}, - ], - }) - + with _submit_quiz_mocks(): + r = client.post("/api/quiz/submit", json={ + "quiz_id": "quiz1", + "answers": [ + {"question_id": 1, "selected_label": "A"}, + {"question_id": 2, "selected_label": "D"}, + ], + }) assert r.status_code == 200 data = r.json() assert data["score"] == 2 assert data["total"] == 2 def test_all_wrong_returns_zero_score(self): - with patch("routes.quiz.table", side_effect=_make_table()): - with patch("routes.quiz.get_graph", return_value={"nodes": [], "edges": [], "stats": {}}): - with patch("routes.quiz.get_quiz_context", return_value={}): - with patch("routes.quiz.call_gemini_json", return_value={}): - r = client.post("/api/quiz/submit", json={ - "quiz_id": "quiz1", - "answers": [ - {"question_id": 1, "selected_label": "B"}, - {"question_id": 2, "selected_label": "C"}, - ], - }) - + with _submit_quiz_mocks(): + r = client.post("/api/quiz/submit", json={ + "quiz_id": "quiz1", + "answers": [ + {"question_id": 1, "selected_label": "B"}, + {"question_id": 2, "selected_label": "C"}, + ], + }) assert r.status_code == 200 assert r.json()["score"] == 0 def test_result_shape_contains_required_fields(self): - with patch("routes.quiz.table", side_effect=_make_table()): - with patch("routes.quiz.get_graph", return_value={"nodes": [], "edges": [], "stats": {}}): - with patch("routes.quiz.get_quiz_context", return_value={}): - with patch("routes.quiz.call_gemini_json", return_value={}): - r = client.post("/api/quiz/submit", json={ - "quiz_id": "quiz1", - "answers": [{"question_id": 1, "selected_label": "A"}], - }) - + with _submit_quiz_mocks(): + r = client.post("/api/quiz/submit", json={ + "quiz_id": "quiz1", + "answers": [{"question_id": 1, "selected_label": "A"}], + }) data = r.json() assert "score" in data assert "total" in data @@ -151,52 +151,40 @@ def test_result_shape_contains_required_fields(self): assert "results" in data def test_each_result_has_correct_flag(self): - with patch("routes.quiz.table", side_effect=_make_table()): - with patch("routes.quiz.get_graph", return_value={"nodes": [], "edges": [], "stats": {}}): - with patch("routes.quiz.get_quiz_context", return_value={}): - with patch("routes.quiz.call_gemini_json", return_value={}): - r = client.post("/api/quiz/submit", json={ - "quiz_id": "quiz1", - "answers": [ - {"question_id": 1, "selected_label": "A"}, # correct - {"question_id": 2, "selected_label": "C"}, # wrong - ], - }) - + with _submit_quiz_mocks(): + r = client.post("/api/quiz/submit", json={ + "quiz_id": "quiz1", + "answers": [ + {"question_id": 1, "selected_label": "A"}, # correct + {"question_id": 2, "selected_label": "C"}, # wrong + ], + }) results = r.json()["results"] correct_flags = {str(res["question_id"]): res["correct"] for res in results} assert correct_flags["1"] is True assert correct_flags["2"] is False def test_mastery_after_is_higher_on_perfect_score(self): - with patch("routes.quiz.table", side_effect=_make_table()): - with patch("routes.quiz.get_graph", return_value={"nodes": [], "edges": [], "stats": {}}): - with patch("routes.quiz.get_quiz_context", return_value={}): - with patch("routes.quiz.call_gemini_json", return_value={}): - r = client.post("/api/quiz/submit", json={ - "quiz_id": "quiz1", - "answers": [ - {"question_id": 1, "selected_label": "A"}, - {"question_id": 2, "selected_label": "D"}, - ], - }) - + with _submit_quiz_mocks(): + r = client.post("/api/quiz/submit", json={ + "quiz_id": "quiz1", + "answers": [ + {"question_id": 1, "selected_label": "A"}, + {"question_id": 2, "selected_label": "D"}, + ], + }) data = r.json() assert data["mastery_after"] > data["mastery_before"] def test_mastery_after_is_lower_on_zero_score(self): - with patch("routes.quiz.table", side_effect=_make_table()): - with patch("routes.quiz.get_graph", return_value={"nodes": [], "edges": [], "stats": {}}): - with patch("routes.quiz.get_quiz_context", return_value={}): - with patch("routes.quiz.call_gemini_json", return_value={}): - r = client.post("/api/quiz/submit", json={ - "quiz_id": "quiz1", - "answers": [ - {"question_id": 1, "selected_label": "B"}, - {"question_id": 2, "selected_label": "C"}, - ], - }) - + with _submit_quiz_mocks(): + r = client.post("/api/quiz/submit", json={ + "quiz_id": "quiz1", + "answers": [ + {"question_id": 1, "selected_label": "B"}, + {"question_id": 2, "selected_label": "C"}, + ], + }) data = r.json() assert data["mastery_after"] < data["mastery_before"] @@ -233,7 +221,7 @@ def factory(name): return mock with patch("routes.quiz.table", side_effect=factory): - with patch("routes.quiz.get_graph", return_value={"nodes": [], "edges": [], "stats": {}}): + with patch("routes.quiz.update_streak"): with patch("routes.quiz.get_quiz_context", return_value={}): with patch("routes.quiz.call_gemini_json", return_value={}): r = client.post("/api/quiz/submit", json={ diff --git a/backend/tests/test_shared_course_context.py b/backend/tests/test_shared_course_context.py index db3f5a9..51ec1d9 100644 --- a/backend/tests/test_shared_course_context.py +++ b/backend/tests/test_shared_course_context.py @@ -2,8 +2,7 @@ Unit tests for the shared course context system. Tests: course_context_service, graph_service (apply_graph_update side-effects), - learn.py helpers (_resolve_course, _get_session_topic, build_system_prompt), - quiz.py (generate_quiz prompt augmentation). + learn.py (build_system_prompt), quiz.py (generate_quiz prompt augmentation). Run from backend/: python -m pytest tests/test_shared_course_context.py -v @@ -48,12 +47,45 @@ def test_empty_course_name_returns_empty_dict(self): @patch("services.course_context_service.table") def test_returns_context_json_when_found(self, mock_table): - ctx = {"struggling_concepts": [{"concept": "Pointers", "avg_mastery": 0.2}]} - mock_table.return_value.select.return_value = [{"context_json": ctx}] + summary_row = { + "course_id": "CS101", + "semester": "Spring 2026", + "student_count": 5, + "avg_class_mastery": 0.6, + "top_struggling_concepts": ["Pointers"], + "top_mastered_concepts": ["Variables"], + "summary_text": "Good progress.", + "updated_at": "2026-04-01T00:00:00+00:00", + } + stat_row = { + "course_id": "CS101", + "concept_name": "Pointers", + "semester": "Spring 2026", + "avg_mastery_score": 0.2, + "pct_struggling": 0.6, + "pct_mastered": 0.1, + "pct_unexplored": 0.3, + "student_count": 5, + "common_misconceptions": ["Dangling pointer"], + } + + def _tbl(name): + m = MagicMock() + if name == "course_summary": + m.select.return_value = [summary_row] + elif name == "course_concept_stats": + m.select.return_value = [stat_row] + else: + m.select.return_value = [] + return m + mock_table.side_effect = _tbl from services.course_context_service import get_course_context result = get_course_context("CS101") - self.assertEqual(result, ctx) + self.assertIn("course_summary", result) + self.assertIn("concept_stats", result) + self.assertEqual(result["course_summary"]["course_id"], "CS101") + self.assertEqual(result["concept_stats"][0]["concept_name"], "Pointers") @patch("services.course_context_service.table") def test_returns_empty_dict_when_not_found(self, mock_table): @@ -96,7 +128,9 @@ def test_no_op_when_no_nodes(self, mock_table): @patch("services.course_context_service.table") def test_aggregates_mastery_and_upserts(self, mock_table): - # Two users, same concept "Loops" — one struggling, one mastered + # Two students enrolled, same concept "Loops" — one struggling, one mastered + enrollment_rows = [{"user_id": "u1"}, {"user_id": "u2"}] + course_rows = [{"course_code": "CS101", "course_name": "Intro CS"}] node_rows = [ {"id": "n1", "concept_name": "Loops", "mastery_score": 0.2, "mastery_tier": "struggling", "user_id": "u1"}, @@ -104,45 +138,51 @@ def test_aggregates_mastery_and_upserts(self, mock_table): "mastery_tier": "mastered", "user_id": "u2"}, ] - def _select(*args, **kwargs): - return node_rows - - node_tbl = MagicMock() - node_tbl.select.side_effect = _select + stats_tbl = MagicMock() + stats_tbl.upsert.return_value = None - quiz_tbl = MagicMock() - quiz_tbl.select.return_value = [] - - ctx_tbl = MagicMock() + summary_tbl = MagicMock() + summary_tbl.select.return_value = [] # no existing summary + summary_tbl.upsert.return_value = None def _table(name): - if name == "graph_nodes": - return node_tbl - if name == "quiz_context": - return quiz_tbl - if name == "course_context": - return ctx_tbl - return MagicMock() + m = MagicMock() + if name == "user_courses": + m.select.return_value = enrollment_rows + elif name == "courses": + m.select.return_value = course_rows + elif name == "graph_nodes": + m.select.return_value = node_rows + elif name == "quiz_context": + m.select.return_value = [] + elif name == "course_concept_stats": + return stats_tbl + elif name == "course_summary": + return summary_tbl + else: + m.select.return_value = [] + return m mock_table.side_effect = _table - from services.course_context_service import update_course_context - update_course_context("CS101") + with patch("services.course_context_service._generate_summary_with_gemini", return_value="summary"): + from services.course_context_service import update_course_context + update_course_context("c-cs101") - ctx_tbl.upsert.assert_called_once() - upsert_payload = ctx_tbl.upsert.call_args[0][0] - self.assertEqual(upsert_payload["course_name"], "CS101") + # course_concept_stats should be upserted for "Loops" + stats_tbl.upsert.assert_called_once() + upsert_payload = stats_tbl.upsert.call_args[0][0] + self.assertEqual(upsert_payload["course_id"], "c-cs101") + self.assertEqual(upsert_payload["concept_name"], "Loops") self.assertEqual(upsert_payload["student_count"], 2) - - ctx_json = upsert_payload["context_json"] # avg mastery for Loops = (0.2 + 0.9) / 2 = 0.55 - self.assertAlmostEqual( - ctx_json["concept_difficulty_ranking"][0]["avg_mastery"], 0.55, places=2 - ) + self.assertAlmostEqual(upsert_payload["avg_mastery_score"], 0.55, places=2) @patch("services.course_context_service.table") def test_struggling_concepts_threshold(self, mock_table): - """struggling_pct > 0.2 should appear in struggling_concepts.""" + """Concepts with pct_struggling > 0 should appear in top_struggling_concepts.""" + enrollment_rows = [{"user_id": "u1"}, {"user_id": "u2"}] + course_rows = [{"course_code": "CS101", "course_name": "Intro CS"}] node_rows = [ {"id": "n1", "concept_name": "Recursion", "mastery_score": 0.1, "mastery_tier": "struggling", "user_id": "u1"}, @@ -152,36 +192,50 @@ def test_struggling_concepts_threshold(self, mock_table): "mastery_tier": "mastered", "user_id": "u1"}, ] - node_tbl = MagicMock() - node_tbl.select.return_value = node_rows - quiz_tbl = MagicMock() - quiz_tbl.select.return_value = [] - ctx_tbl = MagicMock() + stats_tbl = MagicMock() + summary_tbl = MagicMock() + summary_tbl.select.return_value = [] def _table(name): - if name == "graph_nodes": return node_tbl - if name == "quiz_context": return quiz_tbl - return ctx_tbl + m = MagicMock() + if name == "user_courses": + m.select.return_value = enrollment_rows + elif name == "courses": + m.select.return_value = course_rows + elif name == "graph_nodes": + m.select.return_value = node_rows + elif name == "quiz_context": + m.select.return_value = [] + elif name == "course_concept_stats": + return stats_tbl + elif name == "course_summary": + return summary_tbl + else: + m.select.return_value = [] + return m mock_table.side_effect = _table - from services.course_context_service import update_course_context - update_course_context("CS101") + with patch("services.course_context_service._generate_summary_with_gemini", return_value="summary"): + from services.course_context_service import update_course_context + update_course_context("c-cs101") - payload = ctx_tbl.upsert.call_args[0][0]["context_json"] - struggling_names = [c["concept"] for c in payload["struggling_concepts"]] - self.assertIn("Recursion", struggling_names) - self.assertNotIn("Loops", struggling_names) + # course_summary upsert should have Recursion in top_struggling_concepts + summary_tbl.upsert.assert_called_once() + summary_payload = summary_tbl.upsert.call_args[0][0] + self.assertIn("Recursion", summary_payload["top_struggling_concepts"]) + self.assertNotIn("Loops", summary_payload["top_struggling_concepts"]) @patch("services.course_context_service.table") def test_deduplicates_misconceptions_case_insensitive(self, mock_table): + enrollment_rows = [{"user_id": "u1"}, {"user_id": "u2"}] + course_rows = [{"course_code": "CS101", "course_name": "Intro CS"}] node_rows = [ {"id": "n1", "concept_name": "Loops", "mastery_score": 0.3, "mastery_tier": "learning", "user_id": "u1"}, {"id": "n2", "concept_name": "Loops", "mastery_score": 0.3, "mastery_tier": "learning", "user_id": "u2"}, ] - node_id_set = {"n1", "n2"} quiz_rows = [ {"concept_node_id": "n1", "context_json": {"common_mistakes": ["Off-by-one error", "off-by-one error"], "weak_areas": []}}, @@ -189,26 +243,39 @@ def test_deduplicates_misconceptions_case_insensitive(self, mock_table): "context_json": {"common_mistakes": ["OFF-BY-ONE ERROR"], "weak_areas": ["boundary conditions"]}}, ] - node_tbl = MagicMock() - node_tbl.select.return_value = node_rows - quiz_tbl = MagicMock() - quiz_tbl.select.return_value = quiz_rows - ctx_tbl = MagicMock() + stats_tbl = MagicMock() + summary_tbl = MagicMock() + summary_tbl.select.return_value = [] def _table(name): - if name == "graph_nodes": return node_tbl - if name == "quiz_context": return quiz_tbl - return ctx_tbl + m = MagicMock() + if name == "user_courses": + m.select.return_value = enrollment_rows + elif name == "courses": + m.select.return_value = course_rows + elif name == "graph_nodes": + m.select.return_value = node_rows + elif name == "quiz_context": + m.select.return_value = quiz_rows + elif name == "course_concept_stats": + return stats_tbl + elif name == "course_summary": + return summary_tbl + else: + m.select.return_value = [] + return m mock_table.side_effect = _table - from services.course_context_service import update_course_context - update_course_context("CS101") + with patch("services.course_context_service._generate_summary_with_gemini", return_value="summary"): + from services.course_context_service import update_course_context + update_course_context("c-cs101") - payload = ctx_tbl.upsert.call_args[0][0]["context_json"] # All three "off-by-one" variants are the same after .lower() — only one kept - self.assertEqual(len(payload["common_misconceptions"]), 1) - self.assertEqual(len(payload["weak_areas"]), 1) + stats_tbl.upsert.assert_called_once() + upsert_payload = stats_tbl.upsert.call_args[0][0] + self.assertEqual(len(upsert_payload["common_misconceptions"]), 1) + self.assertEqual(len(upsert_payload["prerequisite_gaps"]), 1) # ───────────────────────────────────────────────────────────────────────────── @@ -226,7 +293,8 @@ def test_update_course_context_called_for_touched_subjects( # patch it at the source module so the import resolves to our mock. node_tbl = MagicMock() node_tbl.select.return_value = [ - {"id": "n1", "mastery_score": 0.4, "times_studied": 2, "subject": "CS101"} + {"id": "n1", "mastery_score": 0.4, "times_studied": 2, + "subject": "CS101", "course_id": "course-1", "mastery_events": []} ] def _table(name): @@ -243,7 +311,7 @@ def _table(name): "new_edges": []} ) - mock_update_ctx.assert_called_once_with("CS101") + mock_update_ctx.assert_called_once_with("course-1") @patch("services.graph_service.table") @patch("services.course_context_service.update_course_context", @@ -254,7 +322,7 @@ def test_update_course_context_exception_does_not_raise( """A failure in update_course_context must never surface to the caller.""" node_tbl = MagicMock() node_tbl.select.return_value = [ - {"id": "n1", "mastery_score": 0.4, "times_studied": 2, "subject": "CS101"} + {"id": "n1", "mastery_score": 0.4, "times_studied": 2, "course_id": "c1", "subject": "CS101"} ] def _table(name): @@ -298,100 +366,124 @@ def _table(name): # ───────────────────────────────────────────────────────────────────────────── -# 4. learn.py — _resolve_course, _get_session_topic, build_system_prompt +# 4. learn.py — build_system_prompt # ───────────────────────────────────────────────────────────────────────────── class TestLearnHelpers(unittest.TestCase): @patch("routes.learn.table") - def test_resolve_course_when_topic_is_subject(self, mock_table): - mock_table.return_value.select.return_value = [{"subject": "CS101"}] + def test_resolve_course_when_topic_matches_course_code(self, mock_table): + enrolled_tbl = MagicMock() + enrolled_tbl.select.return_value = [ + {"course_id": "course-1", "courses": {"course_code": "CS101", "course_name": "Intro CS"}} + ] + node_tbl = MagicMock() + node_tbl.select.return_value = [] + + def _factory(name): + if name == "user_courses": return enrolled_tbl + return node_tbl - from routes.learn import _resolve_course - result = _resolve_course("CS101", "user1") - self.assertEqual(result, "CS101") + mock_table.side_effect = _factory + + from routes.learn import _get_course_id_for_topic + result = _get_course_id_for_topic("CS101", "user1") + self.assertEqual(result, "course-1") @patch("routes.learn.table") def test_resolve_course_when_topic_is_concept(self, mock_table): - tbl = MagicMock() - # First call (is topic a subject?) → not found - # Second call (is topic a concept?) → found with subject - tbl.select.side_effect = [[], [{"subject": "CS101"}]] - mock_table.return_value = tbl + enrolled_tbl = MagicMock() + enrolled_tbl.select.return_value = [] + node_tbl = MagicMock() + # First call (concept_name match) → found with course_id + node_tbl.select.side_effect = [[{"course_id": "course-1"}], []] + + def _factory(name): + if name == "user_courses": return enrolled_tbl + return node_tbl + + mock_table.side_effect = _factory - from routes.learn import _resolve_course - result = _resolve_course("Loops", "user1") - self.assertEqual(result, "CS101") + from routes.learn import _get_course_id_for_topic + result = _get_course_id_for_topic("Loops", "user1") + self.assertEqual(result, "course-1") @patch("routes.learn.table") def test_resolve_course_unknown_topic_returns_empty(self, mock_table): mock_table.return_value.select.return_value = [] - from routes.learn import _resolve_course - result = _resolve_course("RandomTopic", "user1") + from routes.learn import _get_course_id_for_topic + result = _get_course_id_for_topic("RandomTopic", "user1") self.assertEqual(result, "") def test_resolve_course_empty_topic_returns_empty(self): - from routes.learn import _resolve_course - result = _resolve_course("", "user1") + from routes.learn import _get_course_id_for_topic + result = _get_course_id_for_topic("", "user1") self.assertEqual(result, "") @patch("routes.learn.table") - def test_get_session_topic_found(self, mock_table): - mock_table.return_value.select.return_value = [{"topic": "Recursion"}] + def test_get_session_course_id_found(self, mock_table): + mock_table.return_value.select.return_value = [{"course_id": "course-1"}] - from routes.learn import _get_session_topic - result = _get_session_topic("session-abc") - self.assertEqual(result, "Recursion") + from routes.learn import _get_session_course_id + result = _get_session_course_id("session-abc") + self.assertEqual(result, "course-1") @patch("routes.learn.table") - def test_get_session_topic_not_found(self, mock_table): + def test_get_session_course_id_not_found(self, mock_table): mock_table.return_value.select.return_value = [] - from routes.learn import _get_session_topic - result = _get_session_topic("session-missing") + from routes.learn import _get_session_course_id + result = _get_session_course_id("session-missing") self.assertEqual(result, "") - # get_course_context is lazily imported inside build_system_prompt; - # patch it at the source module so the `from ... import` resolves to our mock. @patch("services.course_context_service.get_course_context", return_value={}) - def test_build_system_prompt_no_course_name(self, mock_ctx): + def test_build_system_prompt_no_course_id(self, mock_ctx): from routes.learn import build_system_prompt prompt = build_system_prompt("socratic", "Alice", "{}") self.assertNotIn("COURSE INTELLIGENCE", prompt) mock_ctx.assert_not_called() + @patch("routes.learn.table") @patch("services.course_context_service.get_course_context", return_value={}) - def test_build_system_prompt_course_name_but_empty_ctx(self, mock_ctx): + def test_build_system_prompt_course_id_but_empty_ctx(self, mock_ctx, mock_table): + mock_table.return_value.select.return_value = [] from routes.learn import build_system_prompt - prompt = build_system_prompt("socratic", "Alice", "{}", course_name="CS101") + prompt = build_system_prompt("socratic", "Alice", "{}", course_id="course-1") self.assertNotIn("COURSE INTELLIGENCE", prompt) - mock_ctx.assert_called_once_with("CS101") + mock_ctx.assert_called_once_with("course-1") + @patch("routes.learn.table") @patch("services.course_context_service.get_course_context") - def test_build_system_prompt_injects_shared_block(self, mock_ctx): + def test_build_system_prompt_injects_shared_block(self, mock_ctx, mock_table): mock_ctx.return_value = { - "struggling_concepts": [{"concept": "Pointers", "avg_mastery": 0.2}], - "mastered_concepts": [], - "concept_difficulty_ranking": [], - "common_misconceptions": ["Off-by-one errors"], - "weak_areas": [], - "student_count": 10, + "course_summary": {"avg_class_mastery": 0.6, "top_struggling_concepts": ["Pointers"]}, + "concept_stats": [], } + mock_table.return_value.select.return_value = [ + {"course_code": "CS101", "course_name": "Intro CS"} + ] from routes.learn import build_system_prompt - prompt = build_system_prompt("socratic", "Alice", "{}", course_name="CS101") + prompt = build_system_prompt("socratic", "Alice", "{}", course_id="course-1") self.assertIn("COURSE INTELLIGENCE", prompt) self.assertIn("CS101", prompt) - mock_ctx.assert_called_once_with("CS101") + mock_ctx.assert_called_once_with("course-1") + @patch("routes.learn.table") @patch("services.course_context_service.get_course_context") - def test_build_system_prompt_mode_appended_after_shared_block(self, mock_ctx): + def test_build_system_prompt_mode_appended_after_shared_block(self, mock_ctx, mock_table): """Mode prompt must always be the last section.""" - mock_ctx.return_value = {"struggling_concepts": [], "student_count": 5} + mock_ctx.return_value = { + "course_summary": {"avg_class_mastery": 0.5, "top_struggling_concepts": []}, + "concept_stats": [], + } + mock_table.return_value.select.return_value = [ + {"course_code": "CS101", "course_name": "Intro CS"} + ] from routes.learn import build_system_prompt, MODE_PROMPTS - prompt = build_system_prompt("expository", "Bob", "{}", course_name="CS101") + prompt = build_system_prompt("expository", "Bob", "{}", course_id="course-1") expository_text = MODE_PROMPTS["expository"] ctx_pos = prompt.find("COURSE INTELLIGENCE") mode_pos = prompt.find(expository_text[:40]) @@ -425,13 +517,19 @@ def test_misconceptions_appended_to_prompt( ): mock_table.return_value.select.return_value = [{ "id": "node-abc", "concept_name": "Pointers", - "mastery_score": 0.3, "subject": "CS101", + "mastery_score": 0.3, "course_id": "c1", }] mock_table.return_value.insert.return_value = None mock_graph.return_value = {"nodes": [], "edges": []} mock_ctx.return_value = { - "common_misconceptions": ["Dangling pointers", "Memory leaks"], - "weak_areas": ["Pointer arithmetic"], + "course_summary": {"avg_class_mastery": 0.4}, + "concept_stats": [ + { + "concept_name": "Pointers", + "common_misconceptions": ["Dangling pointers", "Memory leaks"], + "prerequisite_gaps": ["Pointer arithmetic"], + } + ], } mock_gemini.return_value = {"questions": []} @@ -453,7 +551,7 @@ def test_no_augmentation_when_ctx_empty( ): mock_table.return_value.select.return_value = [{ "id": "node-abc", "concept_name": "Loops", - "mastery_score": 0.5, "subject": "CS101", + "mastery_score": 0.5, "course_id": "c1", }] mock_table.return_value.insert.return_value = None mock_graph.return_value = {"nodes": [], "edges": []} @@ -463,8 +561,6 @@ def test_no_augmentation_when_ctx_empty( generate_quiz(self._make_generate_body()) actual_prompt = mock_gemini.call_args[0][0] - # The base quiz_generation.txt mentions "misconceptions" in its rules; - # assert the course-level addendum header is NOT present when ctx is empty. self.assertNotIn("Common misconceptions seen across the class", actual_prompt) @patch("services.course_context_service.get_course_context") @@ -477,13 +573,19 @@ def test_augmentation_capped_at_10_items( ): mock_table.return_value.select.return_value = [{ "id": "node-abc", "concept_name": "Pointers", - "mastery_score": 0.3, "subject": "CS101", + "mastery_score": 0.3, "course_id": "c1", }] mock_table.return_value.insert.return_value = None mock_graph.return_value = {"nodes": [], "edges": []} mock_ctx.return_value = { - "common_misconceptions": [f"mistake_{i}" for i in range(20)], - "weak_areas": [], + "course_summary": {"avg_class_mastery": 0.4}, + "concept_stats": [ + { + "concept_name": "Pointers", + "common_misconceptions": [f"mistake_{i}" for i in range(20)], + "prerequisite_gaps": [], + } + ], } mock_gemini.return_value = {"questions": []} @@ -498,12 +600,12 @@ def test_augmentation_capped_at_10_items( @patch("routes.quiz.get_quiz_context", return_value=None) @patch("routes.quiz.get_graph") @patch("routes.quiz.table") - def test_no_augmentation_when_node_has_no_subject( + def test_no_augmentation_when_node_has_no_course_id( self, mock_table, mock_graph, mock_quiz_ctx, mock_gemini ): mock_table.return_value.select.return_value = [{ "id": "node-abc", "concept_name": "GenericConcept", - "mastery_score": 0.5, "subject": "", + "mastery_score": 0.5, "course_id": "", }] mock_table.return_value.insert.return_value = None mock_graph.return_value = {"nodes": [], "edges": []} diff --git a/frontend/src/__tests__/api.test.ts b/frontend/src/__tests__/api.test.ts index fbbc5d4..00465aa 100644 --- a/frontend/src/__tests__/api.test.ts +++ b/frontend/src/__tests__/api.test.ts @@ -113,17 +113,17 @@ describe('getCourses', () => { }); describe('addCourse', () => { - it('POST /api/graph/:userId/courses with course_name', async () => { - mockFetch({ course_name: 'Math', already_existed: false }); + it('POST /api/graph/:userId/courses with course_id', async () => { + mockFetch({ course_id: 'Math', already_existed: false }); await addCourse('user_andres', 'Math'); const [url, opts] = lastCall(); expect(url).toBe('/api/graph/user_andres/courses'); expect(opts?.method).toBe('POST'); - expect(JSON.parse(opts?.body as string)).toMatchObject({ course_name: 'Math' }); + expect(JSON.parse(opts?.body as string)).toMatchObject({ course_id: 'Math' }); }); it('includes color when provided', async () => { - mockFetch({ course_name: 'Math', already_existed: false }); + mockFetch({ course_id: 'Math', already_existed: false }); await addCourse('user_andres', 'Math', '#ff0000'); const body = JSON.parse(lastCall()[1]?.body as string); expect(body.color).toBe('#ff0000'); @@ -292,7 +292,7 @@ describe('extractSyllabus', () => { describe('saveAssignments', () => { it('POST /api/calendar/save with user_id and assignments', async () => { mockFetch({ saved_count: 2 }); - const assignments = [{ title: 'HW1', due_date: '2026-03-01', assignment_type: 'homework', course_name: '' }]; + const assignments = [{ title: 'HW1', due_date: '2026-03-01', assignment_type: 'homework', course_id: 'c1' }]; await saveAssignments('user_andres', assignments); const [url, opts] = lastCall(); expect(url).toBe('/api/calendar/save'); diff --git a/frontend/src/app/calendar/page.tsx b/frontend/src/app/calendar/page.tsx index c80c923..00b3c9e 100644 --- a/frontend/src/app/calendar/page.tsx +++ b/frontend/src/app/calendar/page.tsx @@ -13,6 +13,9 @@ import { syncToGoogleCalendar, importGoogleEvents, disconnectGoogleCalendar, + getCourses, + type SaveAssignmentItem, + type EnrolledCourse, } from '@/lib/api'; import { useUser } from '@/context/UserContext'; @@ -53,7 +56,7 @@ const TYPE_COLORS: Record }; function AssignmentChip({ a, isMobile }: { a: Assignment; isMobile?: boolean }) { - const c = TYPE_COLORS[a.assignment_type] ?? TYPE_COLORS.other; + const c = TYPE_COLORS[a.assignment_type ?? 'other'] ?? TYPE_COLORS.other; return (
{dayAssignments.map(a => { - const c = TYPE_COLORS[a.assignment_type] ?? TYPE_COLORS.other; + const c = TYPE_COLORS[a.assignment_type ?? 'other'] ?? TYPE_COLORS.other; return (
{a.title} @@ -321,6 +324,51 @@ function CalendarGrid({ assignments }: { assignments: Assignment[] }) { ); } + +function normalizeAssignments(items: any[]): Assignment[] { + return (items ?? []).map((a: any, index: number) => ({ + id: a.id ?? `missing-id-${index}`, + title: a.title ?? '', + course_name: a.course_name ?? '', + course_code: a.course_code ?? '', + course_id: a.course_id ?? '', + due_date: a.due_date ?? '', + assignment_type: a.assignment_type ?? 'other', + notes: a.notes ?? null, + google_event_id: a.google_event_id ?? null, + })); +} + +function buildCourseNameToIdMap(courses: EnrolledCourse[]): Map { + const m = new Map(); + for (const c of courses) { + const label = c.course_code ? `${c.course_code} - ${c.course_name}` : c.course_name; + const keys = [c.course_name, label, c.nickname].filter(Boolean) as string[]; + for (const k of keys) { + const lower = k.toLowerCase(); + if (!m.has(lower)) m.set(lower, c.course_id); + } + } + return m; +} + +function toSaveItems(assignments: Assignment[], nameToId: Map): SaveAssignmentItem[] { + return assignments.map(a => { + let courseId = (a.course_id ?? '').trim(); + if (!courseId && a.course_name?.trim()) { + const key = a.course_name.trim().toLowerCase(); + courseId = nameToId.get(key) ?? ''; + } + return { + title: a.title, + course_id: courseId, + due_date: a.due_date, + assignment_type: a.assignment_type ?? 'other', + notes: a.notes ?? undefined, + }; + }); +} + function CalendarInner() { const { userId: USER_ID, userReady } = useUser(); const searchParams = useSearchParams(); @@ -340,6 +388,7 @@ function CalendarInner() { const [googleConnected, setGoogleConnected] = useState(false); const [googleEvents, setGoogleEvents] = useState([]); const [importingGoogle, setImportingGoogle] = useState(false); + const [courseNameToId, setCourseNameToId] = useState>(() => new Map()); /** Orange (notice) vs red (error) — matches syllabusInlineMessage fallback when kind is stale. */ const syllabusNoticeStyle = @@ -351,11 +400,14 @@ function CalendarInner() { useEffect(() => { if (!userReady) return; getAllAssignments(USER_ID) - .then(data => setAssignments(data.assignments ?? [])) + .then(data => setAssignments(normalizeAssignments(data.assignments ?? []))) .catch(console.error); getCalendarStatus(USER_ID) .then(res => setGoogleConnected(res.connected)) .catch(() => {}); + getCourses(USER_ID) + .then(data => setCourseNameToId(buildCourseNameToIdMap(data.courses ?? []))) + .catch(console.error); }, [USER_ID, userReady]); // Handle OAuth redirect (?connected=true) once on mount, independently @@ -390,6 +442,8 @@ function CalendarInner() { id: `extracted_${i}_${Date.now()}`, title: a.title ?? '', course_name: a.course_name ?? '', + course_code: a.course_code ?? '', + course_id: a.course_id ?? '', due_date: a.due_date ?? '', assignment_type: a.assignment_type ?? 'other', notes: a.notes ?? null, @@ -408,9 +462,9 @@ function CalendarInner() { const handleSaveDetected = async () => { setSaving(true); try { - await saveAssignments(USER_ID, extractedAssignments); + await saveAssignments(USER_ID, toSaveItems(extractedAssignments, courseNameToId)); const data = await getAllAssignments(USER_ID); - setAssignments(data.assignments ?? []); + setAssignments(normalizeAssignments(data.assignments ?? [])); setExtractedAssignments([]); setFileProcessed(false); setWarnings([]); @@ -443,7 +497,7 @@ function CalendarInner() { setSyncedCount(res.synced_count); // Refresh so google_event_id values are up to date getAllAssignments(USER_ID) - .then(data => setAssignments(data.assignments ?? [])) + .then(data => setAssignments(normalizeAssignments(data.assignments ?? []))) .catch(console.error); } catch (e: any) { alert(e.message); diff --git a/frontend/src/app/dashboard/page.tsx b/frontend/src/app/dashboard/page.tsx index 22005c6..2eed773 100644 --- a/frontend/src/app/dashboard/page.tsx +++ b/frontend/src/app/dashboard/page.tsx @@ -4,7 +4,16 @@ import { useEffect, useLayoutEffect, useState, useRef, useMemo, useCallback, Sus import { useRouter, useSearchParams } from 'next/navigation'; import KnowledgeGraph from '@/components/KnowledgeGraph'; import { GraphNode, GraphStats, Recommendation, Assignment } from '@/lib/types'; -import { getGraph, getRecommendations, getUpcomingAssignments, getCourses, addCourse, deleteCourse, updateCourseColor } from '@/lib/api'; +import { + getGraph, + getRecommendations, + getUpcomingAssignments, + getCourses, + addCourse, + deleteCourse, + updateCourseColor, + type EnrolledCourse, +} from '@/lib/api'; import { getMasteryColor, getMasteryLabel, formatDueDate, formatRelativeTime, getCourseColor, PRESET_COURSE_COLORS, RAINBOW_COLORS } from '@/lib/graphUtils'; import { useUser } from '@/context/UserContext'; import Link from 'next/link'; @@ -106,7 +115,7 @@ function DashboardInner() { // Courses panel state const [showCourses, setShowCourses] = useState(false); - const [courseList, setCourseList] = useState<{ id: string; course_name: string; color: string | null; node_count: number }[]>([]); + const [courseList, setCourseList] = useState([]); const [courseColorMap, setCourseColorMap] = useState>({}); const [newCourseName, setNewCourseName] = useState(''); const [courseAdding, setCourseAdding] = useState(false); @@ -312,17 +321,22 @@ function DashboardInner() { setEdges(graphData.edges); } } catch (e: any) { - setCourseError(e.message || 'Failed to add course.'); + let msg = e.message || 'Failed to add course.'; + try { + const j = JSON.parse(msg); + if (j.detail) msg = typeof j.detail === 'string' ? j.detail : JSON.stringify(j.detail); + } catch { /* keep msg */ } + setCourseError(msg); } finally { setCourseAdding(false); } }; - const handleColorChange = async (courseName: string, newHex: string) => { + const handleColorChange = async (courseId: string, courseName: string, newHex: string) => { if (!/^#[0-9a-fA-F]{6}$/.test(newHex)) return; try { - await updateCourseColor(userId, courseName, newHex); - setCourseList(prev => prev.map(c => c.course_name === courseName ? { ...c, color: newHex } : c)); + await updateCourseColor(userId, courseId, newHex); + setCourseList(prev => prev.map(c => (c.course_id === courseId ? { ...c, color: newHex } : c))); setCourseColorMap(prev => ({ ...prev, [courseName]: newHex })); setEditingColorFor(null); } catch (e) { @@ -330,11 +344,11 @@ function DashboardInner() { } }; - const handleDeleteCourse = async (courseName: string) => { - setCourseDeleting(courseName); + const handleDeleteCourse = async (courseId: string) => { + setCourseDeleting(courseId); try { - await deleteCourse(userId, courseName); - setCourseList(prev => prev.filter(c => c.course_name !== courseName)); + await deleteCourse(userId, courseId); + setCourseList(prev => prev.filter(c => c.course_id !== courseId)); // Refresh graph so the removed subject-root node disappears const graphData = await getGraph(userId); setNodes(graphData.nodes); @@ -892,7 +906,8 @@ function DashboardInner() { ) : (
{allAssignments.slice(0, 4).map(a => { - const c = getCourseColor(a.course_name, courseColorMap[a.course_name]); + const cn = a.course_name ?? ''; + const c = getCourseColor(cn, courseColorMap[cn]); return (
{ const color = getCourseColor(c.course_name, c.color); - const isDeleting = courseDeleting === c.course_name; - const isEditingColor = editingColorFor === c.course_name; + const isDeleting = courseDeleting === c.course_id; + const isEditingColor = editingColorFor === c.course_id; return ( -
+
{/* Main row */}
{/* Color swatch — click to open picker */}
- {confirmDeleteCourse === c.course_name ? ( + {confirmDeleteCourse === c.course_id ? (
) : (
diff --git a/frontend/src/app/library/page.tsx b/frontend/src/app/library/page.tsx index 7159cfc..2e13981 100644 --- a/frontend/src/app/library/page.tsx +++ b/frontend/src/app/library/page.tsx @@ -1,7 +1,15 @@ 'use client'; import { useEffect, useState, useMemo, useRef, DragEvent } from 'react'; -import { getCourses, addCourse, getDocuments, uploadDocument, deleteDocument, updateDocument } from '@/lib/api'; +import { + getCourses, + addCourse, + getDocuments, + uploadDocument, + deleteDocument, + updateDocument, + type EnrolledCourse, +} from '@/lib/api'; import CustomSelect from '@/components/CustomSelect'; import { getCourseColor, PRESET_COURSE_COLORS } from '@/lib/graphUtils'; import { useUser } from '@/context/UserContext'; @@ -83,7 +91,6 @@ function formatDate(iso: string): string { return new Date(iso).toLocaleDateString('en-US', { month: 'long', day: 'numeric', year: 'numeric' }); } -interface Course { id: string; course_name: string; color: string | null; node_count: number; } interface Flashcard { question: string; answer: string; } interface Doc { id: string; course_id: string; file_name: string; category: string; @@ -102,7 +109,7 @@ export default function LibraryPage() { const { userId, userReady } = useUser(); const isMobile = useIsMobile(); - const [courses, setCourses] = useState([]); + const [courses, setCourses] = useState([]); const [docs, setDocs] = useState([]); const [loading, setLoading] = useState(true); @@ -167,8 +174,10 @@ export default function LibraryPage() { }, [userId, userReady]); const courseById = useMemo(() => { - const m: Record = {}; - courses.forEach(c => { m[c.id] = c; }); + const m: Record = {}; + courses.forEach(c => { + m[c.course_id] = c; + }); return m; }, [courses]); @@ -351,13 +360,18 @@ export default function LibraryPage() { } else { const updated = await getCourses(userId); setCourses(updated.courses); - const created = updated.courses.find(c => c.course_name === name); - if (created) setSelectedCourseId(created.id); + const created = updated.courses.find(c => c.course_id === res.course_id); + if (created) setSelectedCourseId(created.course_id); setNewCourseName(''); setShowAddCourse(false); } } catch (e: any) { - setCourseAddError(e.message || 'Failed to add course.'); + let msg = e.message || 'Failed to add course.'; + try { + const j = JSON.parse(msg); + if (j.detail) msg = typeof j.detail === 'string' ? j.detail : JSON.stringify(j.detail); + } catch { /* keep msg */ } + setCourseAddError(msg); } finally { setCourseAdding(false); } @@ -407,8 +421,8 @@ export default function LibraryPage() { {courses.map(c => { const col = getCourseColor(c.course_name, c.color); return ( - ); @@ -675,7 +689,7 @@ export default function LibraryPage() { value={selectedCourseId} onChange={setSelectedCourseId} placeholder="Select a course…" - options={courses.map(c => ({ value: c.id, label: c.course_name }))} + options={courses.map(c => ({ value: c.course_id, label: c.course_name }))} style={{ width: '100%', display: 'block' }} /> diff --git a/frontend/src/components/AssignmentTable.tsx b/frontend/src/components/AssignmentTable.tsx index 175d861..c15d46c 100644 --- a/frontend/src/components/AssignmentTable.tsx +++ b/frontend/src/components/AssignmentTable.tsx @@ -74,6 +74,8 @@ export default function AssignmentTable({ assignments, onChange, selectedIds, on id: `temp_${Date.now()}`, title: '', course_name: '', + course_code: '', + course_id: '', due_date: '', assignment_type: 'homework', notes: null, @@ -256,7 +258,7 @@ export default function AssignmentTable({ assignments, onChange, selectedIds, on update(index, 'assignment_type', val)} options={TYPES.map(t => ({ value: t, label: t }))} compact diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index 4326f94..78f22be 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -27,35 +27,54 @@ export const getGraph = (userId: string) => export const getRecommendations = (userId: string) => fetchJSON<{ recommendations: any[] }>(`/api/graph/${userId}/recommendations`); +export interface EnrolledCourse { + enrollment_id: string; + course_id: string; + course_code: string; + course_name: string; + school: string; + department: string; + color: string | null; + nickname: string | null; + node_count: number; + enrolled_at: string; +} + export const getCourses = (userId: string) => - fetchJSON<{ courses: { id: string; course_name: string; color: string | null; node_count: number; created_at: string }[] }>( - `/api/graph/${userId}/courses` - ); + fetchJSON<{ courses: EnrolledCourse[] }>(`/api/graph/${userId}/courses`); -export const addCourse = (userId: string, courseName: string, color?: string) => - fetchJSON<{ course_name: string; already_existed: boolean }>(`/api/graph/${userId}/courses`, { +export const addCourse = (userId: string, courseId: string, color?: string, nickname?: string) => + fetchJSON<{ course_id: string; already_existed: boolean; error?: string }>(`/api/graph/${userId}/courses`, { method: 'POST', - body: JSON.stringify({ course_name: courseName, ...(color ? { color } : {}) }), + body: JSON.stringify({ course_id: courseId, ...(color ? { color } : {}), ...(nickname ? { nickname } : {}) }), }); -export const updateCourseColor = (userId: string, courseName: string, color: string) => +export const updateCourseColor = (userId: string, courseId: string, color: string) => fetchJSON<{ updated: boolean }>( - `/api/graph/${userId}/courses/${encodeURIComponent(courseName)}/color`, + `/api/graph/${userId}/courses/${encodeURIComponent(courseId)}/color`, { method: 'PATCH', body: JSON.stringify({ color }) } ); -export const deleteCourse = (userId: string, courseName: string) => +export const deleteCourse = (userId: string, courseId: string) => fetchJSON<{ deleted: boolean }>( - `/api/graph/${userId}/courses/${encodeURIComponent(courseName)}`, + `/api/graph/${userId}/courses/${encodeURIComponent(courseId)}`, { method: 'DELETE' } ); // ── Learn ───────────────────────────────────────────────────────────────────── -export const startSession = (userId: string, topic: string, mode: string, useSharedContext = true) => +export interface StartSessionRequest { + user_id: string; + topic: string; + mode: string; + use_shared_context?: boolean; + course_id?: string; +} + +export const startSession = (userId: string, topic: string, mode: string, courseId?: string, useSharedContext = true) => fetchJSON<{ session_id: string; initial_message: string; graph_state: any }>('/api/learn/start-session', { method: 'POST', - body: JSON.stringify({ user_id: userId, topic, mode, use_shared_context: useSharedContext }), + body: JSON.stringify({ user_id: userId, topic, mode, use_shared_context: useSharedContext, course_id: courseId }), }); export const sendChat = (sessionId: string, userId: string, message: string, mode: string, useSharedContext = true) => @@ -76,18 +95,19 @@ export const endSession = (sessionId: string, userId: string) => body: JSON.stringify({ session_id: sessionId, user_id: userId }), }); +export interface Session { + id: string; + topic: string; + mode: string; + course_id: string | null; + started_at: string; + ended_at: string | null; + message_count: number; + is_active: boolean; +} + export const getSessions = (userId: string, limit = 10) => - fetchJSON<{ - sessions: { - id: string; - topic: string; - mode: string; - started_at: string; - ended_at: string | null; - message_count: number; - is_active: boolean; - }[]; - }>(`/api/learn/sessions/${userId}?limit=${limit}`); + fetchJSON<{ sessions: Session[] }>(`/api/learn/sessions/${userId}?limit=${limit}`); export const switchMode = (sessionId: string, userId: string, newMode: string) => fetchJSON<{ reply: string }>('/api/learn/mode-switch', { @@ -103,7 +123,7 @@ export const deleteSession = (sessionId: string, userId: string) => export const resumeSession = (sessionId: string) => fetchJSON<{ - session: { id: string; topic: string; mode: string; started_at: string; ended_at: string | null }; + session: { id: string; user_id: string; topic: string; mode: string; course_id: string | null; started_at: string; ended_at: string | null }; messages: { id: string; role: string; content: string; created_at: string }[]; }>(`/api/learn/sessions/${sessionId}/resume`); @@ -152,13 +172,34 @@ export const extractSyllabus = (formData: FormData, userId?: string): Promise - fetchJSON<{ assignments: any[] }>(`/api/calendar/upcoming/${userId}`); + fetchJSON<{ assignments: Assignment[] }>(`/api/calendar/upcoming/${userId}`); export const getAllAssignments = (userId: string) => - fetchJSON<{ assignments: any[] }>(`/api/calendar/all/${userId}`); + fetchJSON<{ assignments: Assignment[] }>(`/api/calendar/all/${userId}`); + +export interface SaveAssignmentItem { + title: string; + course_id: string; + due_date: string; + assignment_type?: string; + notes?: string; +} -export const saveAssignments = (userId: string, assignments: any[]) => +export const saveAssignments = (userId: string, assignments: SaveAssignmentItem[]) => fetchJSON<{ saved_count: number }>('/api/calendar/save', { method: 'POST', body: JSON.stringify({ user_id: userId, assignments }), @@ -349,4 +390,4 @@ export const submitJobApplication = async (data: { throw new Error(err || `HTTP ${res.status}`); } return res.json(); -}; \ No newline at end of file +}; diff --git a/frontend/src/lib/types.ts b/frontend/src/lib/types.ts index 524fc00..5e110b7 100644 --- a/frontend/src/lib/types.ts +++ b/frontend/src/lib/types.ts @@ -6,6 +6,7 @@ export interface GraphNode { times_studied: number; last_studied_at: string | null; subject: string; + course_id?: string | null; is_subject_root?: boolean; x?: number; y?: number; @@ -96,11 +97,14 @@ export interface QuizContext { export interface Assignment { id: string; title: string; - course_name: string; + course_name?: string; + course_code?: string; + /** Canonical course FK when known (required for save API). */ + course_id?: string; due_date: string; - assignment_type: string; - notes: string | null; - google_event_id: string | null; + assignment_type?: string; + notes?: string | null; + google_event_id?: string | null; } export interface StudyBlockSuggestion {