Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 33 additions & 7 deletions api/test/api/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import asyncio
import os
from contextlib import asynccontextmanager

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
import os
import asyncio

# Create test directories before setting environment variables
os.makedirs("test/tmp/", exist_ok=True)
Expand All @@ -24,12 +27,28 @@
os.environ["TRANSFORMERLAB_REFRESH_SECRET"] = "test-refresh-secret-for-testing-only"
os.environ["EMAIL_METHOD"] = "dev" # Use dev mode for tests (no actual email sending)

# Use in-memory database for tests
os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:"
# Use in-memory database with shared cache for tests (allows multiple connections to share the same DB)
os.environ["DATABASE_URL"] = "sqlite+aiosqlite:///:memory:?cache=shared"

from api import app # noqa: E402


@asynccontextmanager
async def _test_noop_lifespan(app: FastAPI):
"""
Replace the production lifespan for tests so we don't start
background tasks or subprocesses that keep pytest from exiting.

We still manually init/seed/close the DB in the test fixtures.
"""
yield


# Override the app lifespan for tests to avoid spawning background
# tasks (run_over_and_over, migrations, fastchat controller, etc.).
app.router.lifespan_context = _test_noop_lifespan


class AuthenticatedTestClient(TestClient):
"""TestClient that automatically adds admin authentication headers to all requests"""

Expand Down Expand Up @@ -76,17 +95,24 @@ def request(self, method, url, **kwargs):
@pytest.fixture(scope="session")
def client():
# Initialize database tables for tests
from transformerlab.shared.models.user_model import create_db_and_tables # noqa: E402
import transformerlab.db.session as db # noqa: E402
from transformerlab.services.experiment_init import seed_default_admin_user # noqa: E402

asyncio.run(create_db_and_tables())
asyncio.run(seed_default_admin_user())
controller_log_dir = os.path.join("test", "tmp", "workspace", "logs")
os.makedirs(controller_log_dir, exist_ok=True)
controller_log_path = os.path.join(controller_log_dir, "controller.log")
# Create the file (or truncate if it exists)
with open(controller_log_path, "w") as f:
f.write("") # Empty dummy file Empty dummy file

# Initialize database and run migrations (replaces create_db_and_tables)
asyncio.run(db.init())

# Seed admin user BEFORE creating the test client (client tries to login in __init__)
asyncio.run(seed_default_admin_user())

with AuthenticatedTestClient(app) as c:
yield c

# Cleanup: close database connection
asyncio.run(db.close())
Loading
Loading