diff --git a/api/test/api/conftest.py b/api/test/api/conftest.py index 488d9ce21..3aedda974 100644 --- a/api/test/api/conftest.py +++ b/api/test/api/conftest.py @@ -1,7 +1,10 @@ +import asyncio +import os +from contextlib import asynccontextmanager + import pytest +from fastapi import FastAPI from fastapi.testclient import TestClient -import os -import asyncio # Create test directories before setting environment variables os.makedirs("test/tmp/", exist_ok=True) @@ -24,12 +27,28 @@ os.environ["TRANSFORMERLAB_REFRESH_SECRET"] = "test-refresh-secret-for-testing-only" os.environ["EMAIL_METHOD"] = "dev" # Use dev mode for tests (no actual email sending) -# Use in-memory database for tests -os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:" +# Use in-memory database with shared cache for tests (allows multiple connections to share the same DB) +os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:?cache=shared" from api import app # noqa: E402 +@asynccontextmanager +async def _test_noop_lifespan(app: FastAPI): + """ + Replace the production lifespan for tests so we don't start + background tasks or subprocesses that keep pytest from exiting. + + We still manually init/seed/close the DB in the test fixtures. + """ + yield + + +# Override the app lifespan for tests to avoid spawning background +# tasks (run_over_and_over, migrations, fastchat controller, etc.). +app.router.lifespan_context = _test_noop_lifespan + + class AuthenticatedTestClient(TestClient): """TestClient that automatically adds admin authentication headers to all requests""" @@ -76,11 +95,9 @@ def request(self, method, url, **kwargs): @pytest.fixture(scope="session") def client(): # Initialize database tables for tests - from transformerlab.shared.models.user_model import create_db_and_tables # noqa: E402 + import transformerlab.db.session as db # noqa: E402 from transformerlab.services.experiment_init import seed_default_admin_user # noqa: E402 - asyncio.run(create_db_and_tables()) - asyncio.run(seed_default_admin_user()) controller_log_dir = os.path.join("test", "tmp", "workspace", "logs") os.makedirs(controller_log_dir, exist_ok=True) controller_log_path = os.path.join(controller_log_dir, "controller.log") @@ -88,5 +105,14 @@ def client(): with open(controller_log_path, "w") as f: f.write("") # Empty dummy file Empty dummy file + # Initialize database and run migrations (replaces create_db_and_tables) + asyncio.run(db.init()) + + # Seed admin user BEFORE creating the test client (client tries to login in __init__) + asyncio.run(seed_default_admin_user()) + with AuthenticatedTestClient(app) as c: yield c + + # Cleanup: close database connection + asyncio.run(db.close()) diff --git a/api/test/api/test_teams.py b/api/test/api/test_teams.py index 432a24ce1..c8b1bce7b 100644 --- a/api/test/api/test_teams.py +++ b/api/test/api/test_teams.py @@ -5,9 +5,9 @@ def verify_user_in_db(email: str): """Helper to mark a user as verified in the database (for testing)""" import asyncio from sqlalchemy import select - from transformerlab.shared.models.user_model import User + from transformerlab.shared.models.models import User from transformerlab.db.session import async_session - + async def _verify(): async with async_session() as session: stmt = select(User).where(User.email == email) @@ -18,7 +18,7 @@ async def _verify(): await session.commit() return True return False - + # Get or create event loop try: loop = asyncio.get_running_loop() @@ -31,7 +31,7 @@ async def _verify(): except RuntimeError: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - + if asyncio.iscoroutinefunction(_verify): return loop.run_until_complete(_verify()) @@ -52,10 +52,10 @@ def owner_user(client): } resp = client.post("/auth/register", json=user_data) assert resp.status_code in (200, 201, 400) # 400 if already exists - + # Verify user in database (for testing, bypass email verification) verify_user_in_db("owner@test.com") - + # Login login_data = { "username": "owner@test.com", @@ -64,7 +64,7 @@ def owner_user(client): resp = client.post("/auth/jwt/login", data=login_data) assert resp.status_code == 200 token = resp.json()["access_token"] - + return {"email": "owner@test.com", "token": token} @@ -78,10 +78,10 @@ def member_user(client): } resp = client.post("/auth/register", json=user_data) assert resp.status_code in (200, 201, 400) # 400 if already exists - + # Verify user in database (for testing, bypass email verification) verify_user_in_db("member@test.com") - + # Login login_data = { "username": "member@test.com", @@ -90,7 +90,7 @@ def member_user(client): resp = client.post("/auth/jwt/login", data=login_data) assert resp.status_code == 200 token = resp.json()["access_token"] - + return {"email": "member@test.com", "token": token} @@ -110,22 +110,19 @@ def test_team(client, owner_user): @pytest.fixture(scope="function") def member_in_test_team(client, owner_user, member_user, test_team): """Ensure member_user is added to test_team via invitation""" - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } - + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} + # Check if member is already in the team resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) if resp.status_code == 200: members = resp.json()["members"] if any(m["email"] == member_user["email"] for m in members): return True - + # If not, invite and accept invite_data = {"email": member_user["email"], "role": "member"} resp = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) - + if resp.status_code == 200: resp_json = resp.json() # Extract token from invitation_url (format: http://.../#/?invitation_token=TOKEN) @@ -137,7 +134,7 @@ def member_in_test_team(client, owner_user, member_user, test_team): return False else: return False - + headers_member = {"Authorization": f"Bearer {member_user['token']}"} accept_data = {"token": token} resp = client.post("/invitations/accept", json=accept_data, headers=headers_member) @@ -151,14 +148,15 @@ def member_in_test_team(client, owner_user, member_user, test_team): def fresh_owner_user(client): """Create a fresh owner user for invitation tests""" import time + email = f"fresh_owner_{int(time.time() * 1000)}@test.com" user_data = {"email": email, "password": "password123"} resp = client.post("/auth/register", json=user_data) assert resp.status_code in (200, 201) - + # Verify user in database (for testing, bypass email verification) verify_user_in_db(email) - + login_data = {"username": email, "password": "password123"} resp = client.post("/auth/jwt/login", data=login_data) assert resp.status_code == 200 @@ -181,7 +179,7 @@ def test_create_team(client, owner_user): headers = {"Authorization": f"Bearer {owner_user['token']}"} team_data = {"name": "New Team"} resp = client.post("/teams", json=team_data, headers=headers) - + assert resp.status_code == 200 team = resp.json() assert "id" in team @@ -192,16 +190,16 @@ def test_get_user_teams(client, owner_user, test_team): """Test getting user's teams""" headers = {"Authorization": f"Bearer {owner_user['token']}"} resp = client.get("/users/me/teams", headers=headers) - + assert resp.status_code == 200 data = resp.json() assert "teams" in data assert len(data["teams"]) > 0 - + # Check that test_team is in the list team_ids = [t["id"] for t in data["teams"]] assert test_team["id"] in team_ids - + # Check that the user has owner role for the test team test_team_data = next(t for t in data["teams"] if t["id"] == test_team["id"]) assert test_team_data["role"] == "owner" @@ -209,21 +207,18 @@ def test_get_user_teams(client, owner_user, test_team): def test_list_team_members(client, owner_user, test_team): """Test listing team members""" - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) - + assert resp.status_code == 200 data = resp.json() assert "members" in data assert len(data["members"]) >= 1 - + # Check owner is in the list emails = [m["email"] for m in data["members"]] assert owner_user["email"] in emails - + # Check owner has owner role owner_data = next(m for m in data["members"] if m["email"] == owner_user["email"]) assert owner_data["role"] == "owner" @@ -231,16 +226,10 @@ def test_list_team_members(client, owner_user, test_team): def test_invite_member(client, owner_user, member_user, test_team): """Test inviting a member to the team""" - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } - invite_data = { - "email": member_user["email"], - "role": "member" - } + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} + invite_data = {"email": member_user["email"], "role": "member"} resp = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) - + assert resp.status_code == 200 data = resp.json() assert data["message"] == "Invitation created successfully" @@ -251,7 +240,7 @@ def test_invite_member(client, owner_user, member_user, test_team): # In dev mode, email_sent should be True (no actual sending, just logged) assert data["email_sent"] assert data["email_error"] is None - + # Accept the invitation token = data["invitation_url"].split("token=")[1] headers_member = {"Authorization": f"Bearer {member_user['token']}"} @@ -262,20 +251,14 @@ def test_invite_member(client, owner_user, member_user, test_team): def test_invite_duplicate_member(client, owner_user, test_team): """Test sending duplicate invitations to the same email""" - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } - invite_data = { - "email": "duplicate_test@test.com", - "role": "member" - } - + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} + invite_data = {"email": "duplicate_test@test.com", "role": "member"} + # First invitation resp1 = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) assert resp1.status_code == 200 first_url = resp1.json()["invitation_url"] - + # Second invitation - should return same URL (idempotent) resp2 = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) assert resp2.status_code == 200 @@ -285,16 +268,10 @@ def test_invite_duplicate_member(client, owner_user, test_team): def test_invite_nonexistent_user(client, owner_user, test_team): """Test inviting a user that doesn't exist""" - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } - invite_data = { - "email": "nonexistent@test.com", - "role": "member" - } + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} + invite_data = {"email": "nonexistent@test.com", "role": "member"} resp = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) - + # Invitation is created even if user doesn't exist yet assert resp.status_code == 200 data = resp.json() @@ -306,12 +283,9 @@ def test_invite_nonexistent_user(client, owner_user, test_team): def test_member_can_view_members(client, member_user, test_team, member_in_test_team): """Test that a member can view team members""" - headers = { - "Authorization": f"Bearer {member_user['token']}", - "X-Team-Id": test_team["id"] - } + headers = {"Authorization": f"Bearer {member_user['token']}", "X-Team-Id": test_team["id"]} resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) - + assert resp.status_code == 200 data = resp.json() assert "members" in data @@ -320,27 +294,20 @@ def test_member_can_view_members(client, member_user, test_team, member_in_test_ def test_update_member_role_to_owner(client, owner_user, member_user, test_team, member_in_test_team): """Test promoting a member to owner""" - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } - + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} + # Get the member's user_id resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) members = resp.json()["members"] member_data = next((m for m in members if m["email"] == member_user["email"]), None) - + assert member_data is not None member_id = member_data["user_id"] - + # Promote to owner role_data = {"role": "owner"} - resp = client.put( - f"/teams/{test_team['id']}/members/{member_id}/role", - json=role_data, - headers=headers - ) - + resp = client.put(f"/teams/{test_team['id']}/members/{member_id}/role", json=role_data, headers=headers) + assert resp.status_code == 200 data = resp.json() assert data["message"] == "Role updated successfully" @@ -349,27 +316,20 @@ def test_update_member_role_to_owner(client, owner_user, member_user, test_team, def test_update_member_role_to_member(client, owner_user, member_user, test_team, member_in_test_team): """Test demoting an owner to member""" - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } - + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} + # Get the member's user_id resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) members = resp.json()["members"] member_data = next((m for m in members if m["email"] == member_user["email"]), None) - + assert member_data is not None member_id = member_data["user_id"] - + # Demote to member role_data = {"role": "member"} - resp = client.put( - f"/teams/{test_team['id']}/members/{member_id}/role", - json=role_data, - headers=headers - ) - + resp = client.put(f"/teams/{test_team['id']}/members/{member_id}/role", json=role_data, headers=headers) + assert resp.status_code == 200 data = resp.json() assert data["message"] == "Role updated successfully" @@ -379,39 +339,26 @@ def test_update_member_role_to_member(client, owner_user, member_user, test_team def test_cannot_demote_last_owner(client, owner_user, test_team): """Test that the last owner cannot be demoted""" # Get owner's user_id - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) members = resp.json()["members"] owner_data = next(m for m in members if m["email"] == owner_user["email"]) owner_id = owner_data["user_id"] - + # Try to demote role_data = {"role": "member"} - resp = client.put( - f"/teams/{test_team['id']}/members/{owner_id}/role", - json=role_data, - headers=headers - ) - + resp = client.put(f"/teams/{test_team['id']}/members/{owner_id}/role", json=role_data, headers=headers) + assert resp.status_code == 400 assert "last owner" in resp.json()["detail"].lower() def test_member_cannot_invite(client, member_user, test_team, member_in_test_team): """Test that a member cannot invite other users""" - headers = { - "Authorization": f"Bearer {member_user['token']}", - "X-Team-Id": test_team["id"] - } - invite_data = { - "email": "another@test.com", - "role": "member" - } + headers = {"Authorization": f"Bearer {member_user['token']}", "X-Team-Id": test_team["id"]} + invite_data = {"email": "another@test.com", "role": "member"} resp = client.post(f"/teams/{test_team['id']}/members", json=invite_data, headers=headers) - + assert resp.status_code == 403 assert "owner" in resp.json()["detail"].lower() @@ -419,27 +366,17 @@ def test_member_cannot_invite(client, member_user, test_team, member_in_test_tea def test_member_cannot_update_roles(client, owner_user, member_user, test_team, member_in_test_team): """Test that a member cannot change roles""" # Get owner's user_id - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) members = resp.json()["members"] owner_data = next(m for m in members if m["email"] == owner_user["email"]) owner_id = owner_data["user_id"] - + # Try to update role as member - headers_member = { - "Authorization": f"Bearer {member_user['token']}", - "X-Team-Id": test_team["id"] - } + headers_member = {"Authorization": f"Bearer {member_user['token']}", "X-Team-Id": test_team["id"]} role_data = {"role": "member"} - resp = client.put( - f"/teams/{test_team['id']}/members/{owner_id}/role", - json=role_data, - headers=headers_member - ) - + resp = client.put(f"/teams/{test_team['id']}/members/{owner_id}/role", json=role_data, headers=headers_member) + assert resp.status_code == 403 assert "owner" in resp.json()["detail"].lower() @@ -450,10 +387,7 @@ def test_remove_member(client, owner_user, member_user, test_team, member_in_tes assert member_in_test_team # Get member's user_id - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) members = resp.json()["members"] member_data = next((m for m in members if m["email"] == member_user["email"]), None) @@ -463,10 +397,7 @@ def test_remove_member(client, owner_user, member_user, test_team, member_in_tes member_id = member_data["user_id"] # Remove member - resp = client.delete( - f"/teams/{test_team['id']}/members/{member_id}", - headers=headers - ) + resp = client.delete(f"/teams/{test_team['id']}/members/{member_id}", headers=headers) assert resp.status_code == 200 assert resp.json()["message"] == "Member removed successfully" @@ -480,21 +411,15 @@ def test_remove_member(client, owner_user, member_user, test_team, member_in_tes def test_cannot_remove_last_owner(client, owner_user, test_team): """Test that the last owner cannot be removed""" # Get owner's user_id - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) members = resp.json()["members"] owner_data = next(m for m in members if m["email"] == owner_user["email"]) owner_id = owner_data["user_id"] - + # Try to remove - resp = client.delete( - f"/teams/{test_team['id']}/members/{owner_id}", - headers=headers - ) - + resp = client.delete(f"/teams/{test_team['id']}/members/{owner_id}", headers=headers) + assert resp.status_code == 400 assert "last owner" in resp.json()["detail"].lower() @@ -502,25 +427,16 @@ def test_cannot_remove_last_owner(client, owner_user, test_team): def test_member_cannot_remove_members(client, owner_user, member_user, test_team, member_in_test_team): """Test that a member cannot remove other members""" # Get owner's user_id - headers_owner = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } + headers_owner = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} resp = client.get(f"/teams/{test_team['id']}/members", headers=headers_owner) members = resp.json()["members"] owner_data = next(m for m in members if m["email"] == owner_user["email"]) owner_id = owner_data["user_id"] - + # Try to remove as member - headers_member = { - "Authorization": f"Bearer {member_user['token']}", - "X-Team-Id": test_team["id"] - } - resp = client.delete( - f"/teams/{test_team['id']}/members/{owner_id}", - headers=headers_member - ) - + headers_member = {"Authorization": f"Bearer {member_user['token']}", "X-Team-Id": test_team["id"]} + resp = client.delete(f"/teams/{test_team['id']}/members/{owner_id}", headers=headers_member) + # Should fail - either because member doesn't have permission or isn't in the team assert resp.status_code in (403, 400) detail = resp.json()["detail"].lower() @@ -529,13 +445,10 @@ def test_member_cannot_remove_members(client, owner_user, member_user, test_team def test_update_team_name(client, owner_user, test_team): """Test updating team name as owner""" - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} update_data = {"name": "Updated Team Name"} resp = client.put(f"/teams/{test_team['id']}", json=update_data, headers=headers) - + assert resp.status_code == 200 team = resp.json() assert team["name"] == "Updated Team Name" @@ -543,13 +456,10 @@ def test_update_team_name(client, owner_user, test_team): def test_member_cannot_update_team_name(client, member_user, test_team, member_in_test_team): """Test that a member cannot update team name""" - headers = { - "Authorization": f"Bearer {member_user['token']}", - "X-Team-Id": test_team["id"] - } + headers = {"Authorization": f"Bearer {member_user['token']}", "X-Team-Id": test_team["id"]} update_data = {"name": "Hacked Name"} resp = client.put(f"/teams/{test_team['id']}", json=update_data, headers=headers) - + # Should fail - either because member doesn't have permission or isn't in the team assert resp.status_code in (403, 400) detail = resp.json()["detail"].lower() @@ -563,14 +473,11 @@ def test_delete_team(client, owner_user): team_data = {"name": "Team to Delete"} resp = client.post("/teams", json=team_data, headers=headers) team = resp.json() - + # Delete it - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": team["id"] - } + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": team["id"]} resp = client.delete(f"/teams/{team['id']}", headers=headers) - + assert resp.status_code == 200 assert resp.json()["message"] == "Team deleted" @@ -592,14 +499,11 @@ def test_cannot_delete_last_team(client, owner_user): if team["name"] == f"{username}'s Team": personal_team = team break - + assert personal_team is not None, "Personal team not found" # Try to delete the personal team - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": personal_team["id"] - } + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": personal_team["id"]} resp = client.delete(f"/teams/{personal_team['id']}", headers=headers) # Should fail because it's the personal team (last team) @@ -609,12 +513,9 @@ def test_cannot_delete_last_team(client, owner_user): def test_member_cannot_delete_team(client, member_user, test_team, member_in_test_team): """Test that a member cannot delete a team""" - headers = { - "Authorization": f"Bearer {member_user['token']}", - "X-Team-Id": test_team["id"] - } + headers = {"Authorization": f"Bearer {member_user['token']}", "X-Team-Id": test_team["id"]} resp = client.delete(f"/teams/{test_team['id']}", headers=headers) - + # Should fail - either because member doesn't have permission or isn't in the team assert resp.status_code in (403, 400) detail = resp.json()["detail"].lower() @@ -625,21 +526,18 @@ def test_cannot_delete_team_with_multiple_users(client, owner_user, member_user, """Test that a team with multiple users cannot be deleted""" # Ensure member was added successfully assert member_in_test_team - - headers = { - "Authorization": f"Bearer {owner_user['token']}", - "X-Team-Id": test_team["id"] - } - + + headers = {"Authorization": f"Bearer {owner_user['token']}", "X-Team-Id": test_team["id"]} + # Verify team has multiple members before attempting delete resp = client.get(f"/teams/{test_team['id']}/members", headers=headers) assert resp.status_code == 200 members = resp.json()["members"] assert len(members) >= 2, f"Team should have at least 2 members, but has {len(members)}" - + # Now try to delete - should fail resp = client.delete(f"/teams/{test_team['id']}", headers=headers) - + assert resp.status_code == 400 assert "multiple users" in resp.json()["detail"].lower() @@ -657,10 +555,10 @@ def invited_user(client): } resp = client.post("/auth/register", json=user_data) assert resp.status_code in (200, 201, 400) # 400 if already exists - + # Verify user in database (for testing, bypass email verification) verify_user_in_db("invited@test.com") - + # Login login_data = { "username": "invited@test.com", @@ -669,7 +567,7 @@ def invited_user(client): resp = client.post("/auth/jwt/login", data=login_data) assert resp.status_code == 200 token = resp.json()["access_token"] - + return {"email": "invited@test.com", "token": token} @@ -683,10 +581,10 @@ def reject_user(client): } resp = client.post("/auth/register", json=user_data) assert resp.status_code in (200, 201, 400) # 400 if already exists - + # Verify user in database (for testing, bypass email verification) verify_user_in_db("reject@test.com") - + # Login login_data = { "username": "reject@test.com", @@ -695,22 +593,16 @@ def reject_user(client): resp = client.post("/auth/jwt/login", data=login_data) assert resp.status_code == 200 token = resp.json()["access_token"] - + return {"email": "reject@test.com", "token": token} def test_create_invitation(client, fresh_owner_user, fresh_test_team): """Test creating a team invitation""" - headers = { - "Authorization": f"Bearer {fresh_owner_user['token']}", - "X-Team-Id": fresh_test_team["id"] - } - invitation_data = { - "email": "newinvite@test.com", - "role": "member" - } + headers = {"Authorization": f"Bearer {fresh_owner_user['token']}", "X-Team-Id": fresh_test_team["id"]} + invitation_data = {"email": "newinvite@test.com", "role": "member"} resp = client.post(f"/teams/{fresh_test_team['id']}/members", json=invitation_data, headers=headers) - + assert resp.status_code == 200 data = resp.json() assert data["message"] == "Invitation created successfully" @@ -724,27 +616,21 @@ def test_create_invitation(client, fresh_owner_user, fresh_test_team): def test_duplicate_invitation_returns_existing_url(client, fresh_owner_user, fresh_test_team): """Test that creating duplicate invitation returns existing URL""" - headers = { - "Authorization": f"Bearer {fresh_owner_user['token']}", - "X-Team-Id": fresh_test_team["id"] - } - invitation_data = { - "email": "duplicate@test.com", - "role": "member" - } - + headers = {"Authorization": f"Bearer {fresh_owner_user['token']}", "X-Team-Id": fresh_test_team["id"]} + invitation_data = {"email": "duplicate@test.com", "role": "member"} + # Create first invitation resp1 = client.post(f"/teams/{fresh_test_team['id']}/members", json=invitation_data, headers=headers) assert resp1.status_code == 200 first_url = resp1.json()["invitation_url"] first_id = resp1.json()["invitation_id"] - + # Create duplicate invitation resp2 = client.post(f"/teams/{fresh_test_team['id']}/members", json=invitation_data, headers=headers) assert resp2.status_code == 200 second_url = resp2.json()["invitation_url"] second_id = resp2.json()["invitation_id"] - + # Should return same invitation assert first_url == second_url assert first_id == second_id @@ -753,26 +639,20 @@ def test_duplicate_invitation_returns_existing_url(client, fresh_owner_user, fre def test_get_pending_invitations(client, fresh_owner_user, invited_user, fresh_test_team): """Test getting pending invitations for a user""" # Create invitation - headers = { - "Authorization": f"Bearer {fresh_owner_user['token']}", - "X-Team-Id": fresh_test_team["id"] - } - invitation_data = { - "email": invited_user["email"], - "role": "member" - } + headers = {"Authorization": f"Bearer {fresh_owner_user['token']}", "X-Team-Id": fresh_test_team["id"]} + invitation_data = {"email": invited_user["email"], "role": "member"} resp = client.post(f"/teams/{fresh_test_team['id']}/members", json=invitation_data, headers=headers) assert resp.status_code == 200 - + # Get invitations as invited user headers = {"Authorization": f"Bearer {invited_user['token']}"} resp = client.get("/invitations/me", headers=headers) - + assert resp.status_code == 200 data = resp.json() assert "invitations" in data assert len(data["invitations"]) > 0 - + # Check invitation details invitation = data["invitations"][0] assert invitation["email"] == invited_user["email"] @@ -784,29 +664,23 @@ def test_get_pending_invitations(client, fresh_owner_user, invited_user, fresh_t def test_accept_invitation(client, fresh_owner_user, invited_user, fresh_test_team): """Test accepting a team invitation""" # Create invitation - headers = { - "Authorization": f"Bearer {fresh_owner_user['token']}", - "X-Team-Id": fresh_test_team["id"] - } - invitation_data = { - "email": invited_user["email"], - "role": "member" - } + headers = {"Authorization": f"Bearer {fresh_owner_user['token']}", "X-Team-Id": fresh_test_team["id"]} + invitation_data = {"email": invited_user["email"], "role": "member"} resp = client.post(f"/teams/{fresh_test_team['id']}/members", json=invitation_data, headers=headers) assert resp.status_code == 200 token = resp.json()["invitation_url"].split("token=")[1] - + # Accept invitation headers = {"Authorization": f"Bearer {invited_user['token']}"} accept_data = {"token": token} resp = client.post("/invitations/accept", json=accept_data, headers=headers) - + assert resp.status_code == 200 data = resp.json() assert data["message"] == "Invitation accepted successfully" assert data["team_id"] == fresh_test_team["id"] assert data["role"] == "member" - + # Verify user is now in the team resp = client.get("/users/me/teams", headers=headers) assert resp.status_code == 200 @@ -818,26 +692,20 @@ def test_accept_invitation(client, fresh_owner_user, invited_user, fresh_test_te def test_reject_invitation(client, fresh_owner_user, reject_user, fresh_test_team): """Test rejecting a team invitation""" # Create invitation - headers = { - "Authorization": f"Bearer {fresh_owner_user['token']}", - "X-Team-Id": fresh_test_team["id"] - } - invitation_data = { - "email": reject_user["email"], - "role": "member" - } + headers = {"Authorization": f"Bearer {fresh_owner_user['token']}", "X-Team-Id": fresh_test_team["id"]} + invitation_data = {"email": reject_user["email"], "role": "member"} resp = client.post(f"/teams/{fresh_test_team['id']}/members", json=invitation_data, headers=headers) assert resp.status_code == 200 invitation_id = resp.json()["invitation_id"] - + # Reject invitation headers = {"Authorization": f"Bearer {reject_user['token']}"} resp = client.post(f"/invitations/{invitation_id}/reject", headers=headers) - + assert resp.status_code == 200 data = resp.json() assert data["message"] == "Invitation rejected successfully" - + # Verify user is not in the team resp = client.get("/users/me/teams", headers=headers) assert resp.status_code == 200 @@ -848,27 +716,24 @@ def test_reject_invitation(client, fresh_owner_user, reject_user, fresh_test_tea def test_get_team_invitations(client, fresh_owner_user, fresh_test_team): """Test getting all invitations for a team (owner only)""" - headers = { - "Authorization": f"Bearer {fresh_owner_user['token']}", - "X-Team-Id": fresh_test_team["id"] - } - + headers = {"Authorization": f"Bearer {fresh_owner_user['token']}", "X-Team-Id": fresh_test_team["id"]} + # Create an invitation first invitation_data = {"email": "testinvite@test.com", "role": "member"} resp = client.post(f"/teams/{fresh_test_team['id']}/members", json=invitation_data, headers=headers) assert resp.status_code == 200 - + # Now get all invitations for the team resp = client.get(f"/teams/{fresh_test_team['id']}/invitations", headers=headers) - + assert resp.status_code == 200 data = resp.json() assert "invitations" in data assert data["team_id"] == fresh_test_team["id"] - + # Should have at least the invitation we just created assert len(data["invitations"]) > 0 - + # Check that invitations have all required fields invitation = data["invitations"][0] assert "id" in invitation @@ -881,25 +746,19 @@ def test_get_team_invitations(client, fresh_owner_user, fresh_test_team): def test_cancel_invitation(client, fresh_owner_user, fresh_test_team): """Test cancelling a pending invitation""" # Create invitation - headers = { - "Authorization": f"Bearer {fresh_owner_user['token']}", - "X-Team-Id": fresh_test_team["id"] - } - invitation_data = { - "email": "cancel@test.com", - "role": "member" - } + headers = {"Authorization": f"Bearer {fresh_owner_user['token']}", "X-Team-Id": fresh_test_team["id"]} + invitation_data = {"email": "cancel@test.com", "role": "member"} resp = client.post(f"/teams/{fresh_test_team['id']}/members", json=invitation_data, headers=headers) assert resp.status_code == 200 invitation_id = resp.json()["invitation_id"] - + # Cancel invitation resp = client.delete(f"/teams/{fresh_test_team['id']}/invitations/{invitation_id}", headers=headers) - + assert resp.status_code == 200 data = resp.json() assert data["message"] == "Invitation cancelled successfully" - + # Verify invitation is marked as cancelled resp = client.get(f"/teams/{fresh_test_team['id']}/invitations", headers=headers) assert resp.status_code == 200 @@ -912,23 +771,17 @@ def test_cancel_invitation(client, fresh_owner_user, fresh_test_team): def test_cannot_accept_invitation_wrong_email(client, fresh_owner_user, member_user, fresh_test_team): """Test that a user cannot accept invitation meant for different email""" # Create invitation for one user - headers = { - "Authorization": f"Bearer {fresh_owner_user['token']}", - "X-Team-Id": fresh_test_team["id"] - } - invitation_data = { - "email": "someone@test.com", - "role": "member" - } + headers = {"Authorization": f"Bearer {fresh_owner_user['token']}", "X-Team-Id": fresh_test_team["id"]} + invitation_data = {"email": "someone@test.com", "role": "member"} resp = client.post(f"/teams/{fresh_test_team['id']}/members", json=invitation_data, headers=headers) assert resp.status_code == 200 token = resp.json()["invitation_url"].split("token=")[1] - + # Try to accept as different user headers = {"Authorization": f"Bearer {member_user['token']}"} accept_data = {"token": token} resp = client.post("/invitations/accept", json=accept_data, headers=headers) - + assert resp.status_code == 403 assert "not for your email" in resp.json()["detail"].lower() @@ -936,41 +789,32 @@ def test_cannot_accept_invitation_wrong_email(client, fresh_owner_user, member_u def test_cannot_cancel_non_pending_invitation(client, fresh_owner_user, invited_user, fresh_test_team): """Test that accepted/rejected invitations cannot be cancelled""" # Create and accept invitation - headers = { - "Authorization": f"Bearer {fresh_owner_user['token']}", - "X-Team-Id": fresh_test_team["id"] - } - invitation_data = { - "email": "accepted@test.com", - "role": "member" - } + headers = {"Authorization": f"Bearer {fresh_owner_user['token']}", "X-Team-Id": fresh_test_team["id"]} + invitation_data = {"email": "accepted@test.com", "role": "member"} resp = client.post(f"/teams/{fresh_test_team['id']}/members", json=invitation_data, headers=headers) assert resp.status_code == 200 invitation_id = resp.json()["invitation_id"] token = resp.json()["invitation_url"].split("token=")[1] - + # Register and accept as new user user_data = {"email": "accepted@test.com", "password": "password123"} client.post("/auth/register", json=user_data) - + # Verify user in database (for testing, bypass email verification) verify_user_in_db("accepted@test.com") - + login_data = {"username": "accepted@test.com", "password": "password123"} resp = client.post("/auth/jwt/login", data=login_data) new_token = resp.json()["access_token"] - + headers_new = {"Authorization": f"Bearer {new_token}"} accept_data = {"token": token} resp = client.post("/invitations/accept", json=accept_data, headers=headers_new) assert resp.status_code == 200 - + # Try to cancel the accepted invitation - headers = { - "Authorization": f"Bearer {fresh_owner_user['token']}", - "X-Team-Id": fresh_test_team["id"] - } + headers = {"Authorization": f"Bearer {fresh_owner_user['token']}", "X-Team-Id": fresh_test_team["id"]} resp = client.delete(f"/teams/{fresh_test_team['id']}/invitations/{invitation_id}", headers=headers) - + assert resp.status_code == 400 assert "cannot cancel" in resp.json()["detail"].lower() diff --git a/api/transformerlab/models/users.py b/api/transformerlab/models/users.py index 4b435d693..725fbd904 100644 --- a/api/transformerlab/models/users.py +++ b/api/transformerlab/models/users.py @@ -5,8 +5,9 @@ from fastapi_users import BaseUserManager, FastAPIUsers, UUIDIDMixin, schemas from fastapi_users.authentication import AuthenticationBackend, BearerTransport, JWTStrategy, Strategy from fastapi_users.db import SQLAlchemyUserDatabase -from transformerlab.shared.models.user_model import User, get_async_session, create_personal_team -from transformerlab.shared.models.models import UserTeam, TeamRole +from transformerlab.shared.models.models import User, UserTeam, TeamRole +from transformerlab.db.db import get_async_session +from transformerlab.shared.models.user_model import create_personal_team from transformerlab.utils.email import send_password_reset_email, send_email_verification_link from sqlalchemy.ext.asyncio import AsyncSession import os diff --git a/api/transformerlab/routers/auth.py b/api/transformerlab/routers/auth.py index ac45c0163..c43e41caa 100644 --- a/api/transformerlab/routers/auth.py +++ b/api/transformerlab/routers/auth.py @@ -1,6 +1,7 @@ from fastapi import APIRouter, Depends, HTTPException, Header -from transformerlab.shared.models.user_model import User, get_async_session, create_personal_team -from transformerlab.shared.models.models import Team, UserTeam, TeamRole +from transformerlab.shared.models.models import User, Team, UserTeam, TeamRole +from transformerlab.db.db import get_async_session +from transformerlab.shared.models.user_model import create_personal_team from transformerlab.models.users import ( fastapi_users, auth_backend, diff --git a/api/transformerlab/routers/teams.py b/api/transformerlab/routers/teams.py index 985e16299..4e0c4965f 100644 --- a/api/transformerlab/routers/teams.py +++ b/api/transformerlab/routers/teams.py @@ -1,7 +1,7 @@ from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.ext.asyncio import AsyncSession -from transformerlab.shared.models.user_model import User, get_async_session -from transformerlab.shared.models.models import Team, UserTeam, TeamRole, TeamInvitation, InvitationStatus +from transformerlab.shared.models.models import User, Team, UserTeam, TeamRole, TeamInvitation, InvitationStatus +from transformerlab.db.db import get_async_session from transformerlab.models.users import current_active_user from transformerlab.routers.auth import require_team_owner, get_user_and_team from transformerlab.utils.email import send_team_invitation_email diff --git a/api/transformerlab/services/experiment_init.py b/api/transformerlab/services/experiment_init.py index 2fdd84916..597cc78c0 100644 --- a/api/transformerlab/services/experiment_init.py +++ b/api/transformerlab/services/experiment_init.py @@ -4,7 +4,9 @@ from lab import HOME_DIR from sqlalchemy import select -from transformerlab.shared.models.user_model import User, AsyncSessionLocal, create_personal_team +from transformerlab.shared.models.models import User +from transformerlab.db.session import async_session +from transformerlab.shared.models.user_model import create_personal_team from transformerlab.shared.models.models import UserTeam, TeamRole from transformerlab.models.users import UserManager, UserCreate from fastapi_users.db import SQLAlchemyUserDatabase @@ -16,7 +18,7 @@ async def seed_default_admin_user(): """Create a default admin user with credentials admin@example.com / admin123 if one doesn't exist.""" try: - async with AsyncSessionLocal() as session: + async with async_session() as session: # Check if admin user already exists stmt = select(User).where(User.email == "admin@example.com") result = await session.execute(stmt) diff --git a/api/transformerlab/shared/models/models.py b/api/transformerlab/shared/models/models.py index 8993a27e9..8db5bf3ef 100644 --- a/api/transformerlab/shared/models/models.py +++ b/api/transformerlab/shared/models/models.py @@ -1,6 +1,7 @@ from typing import Optional from sqlalchemy import String, JSON, DateTime, func, Integer, Index from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from fastapi_users.db import SQLAlchemyBaseUserTableUUID import uuid import enum @@ -66,6 +67,23 @@ class Team(Base): name: Mapped[str] = mapped_column(String, nullable=False) +class User(SQLAlchemyBaseUserTableUUID, Base): + """ + User database model. Inherits from SQLAlchemyBaseUserTableUUID which provides: + - id (UUID primary key) + - email (unique, indexed) + - hashed_password + - is_active (boolean) + - is_superuser (boolean) + - is_verified (boolean) + + We add custom fields below: + """ + + first_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + last_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) + + class TeamRole(str, enum.Enum): """Enum for user roles within a team.""" diff --git a/api/transformerlab/shared/models/user_model.py b/api/transformerlab/shared/models/user_model.py index ce0bf9496..7f342fff4 100644 --- a/api/transformerlab/shared/models/user_model.py +++ b/api/transformerlab/shared/models/user_model.py @@ -1,47 +1,12 @@ # database.py -from typing import AsyncGenerator, Optional -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy.orm import sessionmaker, Mapped, mapped_column -from fastapi_users.db import SQLAlchemyBaseUserTableUUID -from sqlalchemy import String +from sqlalchemy.ext.asyncio import AsyncSession -# Replace with your actual database URL (e.g., PostgreSQL, SQLite) -from transformerlab.db.constants import DATABASE_URL -from .models import Base, Team +from .models import Team -# 1. Define your User Model (inherits from a FastAPI Users base class) -class User(SQLAlchemyBaseUserTableUUID, Base): - """ - User database model. Inherits from SQLAlchemyBaseUserTableUUID which provides: - - id (UUID primary key) - - email (unique, indexed) - - hashed_password - - is_active (boolean) - - is_superuser (boolean) - - is_verified (boolean) - - We add custom fields below: - """ - first_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) - last_name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True) - - -# 2. Setup the Async Engine and Session -engine = create_async_engine(DATABASE_URL) -AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) - - -# 3. Utility to create tables (run this on app startup) -async def create_db_and_tables(): - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - -# 4. Database session dependency -async def get_async_session() -> AsyncGenerator[AsyncSession, None]: - async with AsyncSessionLocal() as session: - yield session +# # 2. Setup the Async Engine and Session +# engine = create_async_engine(DATABASE_URL) +# AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) # 5. Create personal team for user @@ -49,11 +14,11 @@ async def create_personal_team(session: AsyncSession, user) -> Team: """ Create a personal team for the user named "Username's Team". Each user gets their own team. - + Args: session: Database session user: User object with first_name, last_name, or email - + Returns: Team: The created personal team """ @@ -63,7 +28,7 @@ async def create_personal_team(session: AsyncSession, user) -> Team: else: # Fallback to email username if no first_name team_name = f"{user.email.split('@')[0]}'s Team" - + # Create new team for this user team = Team(name=team_name) session.add(team)