Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions crates/api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ pub async fn init_domain_services_with_pool(
let models_repo = Arc::new(database::repositories::ModelRepository::new(
database.pool().clone(),
));
let api_key_model_affinity_repo = Arc::new(
database::repositories::PgApiKeyModelAffinityRepository::new(database.pool().clone()),
)
as Arc<dyn services::completions::ports::ApiKeyModelAffinityRepository>;

// Note: inference_url models and external providers are loaded in init_inference_providers.
// Periodic refresh is also started there.
Expand Down Expand Up @@ -376,6 +380,7 @@ pub async fn init_domain_services_with_pool(
usage_service.clone(),
metrics_service.clone(),
models_repo.clone() as Arc<dyn services::models::ModelsRepository>,
api_key_model_affinity_repo,
org_limit_repository,
));

Expand Down
141 changes: 140 additions & 1 deletion crates/api/tests/e2e_repositories.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,39 @@
// These tests directly test repository behavior with the database
mod common;

use async_trait::async_trait;
use chrono::{Duration, Utc};
use database::OAuthStateRepository;
use database::{OAuthStateRepository, PgApiKeyModelAffinityRepository};
use services::completions::ports::ApiKeyModelAffinityRepository;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use std::time::Duration as StdDuration;
use tokio::sync::Barrier;
use tokio::time::timeout;
use uuid::Uuid;

/// Builds the shared test fixture and returns a clone of its database pool.
///
/// The other three elements of the fixture tuple are not used by these tests;
/// they are bound to underscore-prefixed names so they are only dropped when
/// this function returns.
async fn get_test_pool() -> database::pool::DbPool {
    let fixture = common::setup_test_server_with_pool().await;
    let (_server, _provider_pool, _mock, db) = fixture;
    db.pool().clone()
}

/// Test double for `ProviderUrlSelector`: it always yields the same URL and
/// counts how many times it was asked.
struct TestProviderUrlSelector {
/// URL returned by every `select_provider_url` call.
provider_url: String,
/// Shared counter incremented once per `select_provider_url` call, so a test
/// can check how often the selector actually ran.
call_count: Arc<AtomicUsize>,
}

#[async_trait]
impl services::completions::ports::ProviderUrlSelector for TestProviderUrlSelector {
    /// Records the call, pauses for 50 ms, then returns the configured URL.
    ///
    /// Never returns an error and never returns `None`.
    async fn select_provider_url(&self) -> Result<Option<String>, anyhow::Error> {
        // NOTE(review): the pause presumably keeps a selection in flight long
        // enough for concurrent callers to overlap with it — confirm intent.
        let simulated_latency = StdDuration::from_millis(50);

        self.call_count.fetch_add(1, Ordering::SeqCst);
        tokio::time::sleep(simulated_latency).await;

        let chosen = self.provider_url.clone();
        Ok(Some(chosen))
    }
}

// ============================================
// OAuth State Repository Tests
// ============================================
Expand Down Expand Up @@ -114,3 +138,118 @@ async fn test_state_replay_protection() {
let second = repo.get_and_delete(&state).await.unwrap();
assert!(second.is_none());
}

// ============================================
// API Key Model Affinity Repository Tests
// ============================================

#[tokio::test]
async fn test_affinity_get_or_create_uses_advisory_lock_for_concurrent_miss() {
let pool = get_test_pool().await;
let repo = Arc::new(PgApiKeyModelAffinityRepository::new(pool.clone()));

let api_key_id = Uuid::new_v4();
let model_name = format!("test-model-{}", Uuid::new_v4());
let provider_url = "http://10.0.0.7:8000".to_string();
let selector_calls = Arc::new(AtomicUsize::new(0));
let barrier = Arc::new(Barrier::new(4));

let mut handles: Vec<tokio::task::JoinHandle<anyhow::Result<Option<String>>>> = Vec::new();
for _ in 0..4 {
let repo = repo.clone();
let barrier = barrier.clone();
let selector_calls = selector_calls.clone();
let provider_url = provider_url.clone();
let model_name = model_name.clone();

handles.push(tokio::spawn(async move {
let selector = TestProviderUrlSelector {
provider_url,
call_count: selector_calls,
};

barrier.wait().await;

repo.get_or_create_active_provider_url(
api_key_id,
&model_name,
StdDuration::from_secs(300),
&selector,
)
.await
}));
}

let mut results = Vec::new();
for handle in handles {
results.push(handle.await.unwrap().unwrap());
}

assert_eq!(selector_calls.load(Ordering::SeqCst), 1);
assert_eq!(results.len(), 4);
assert!(results
.iter()
.all(|value: &Option<String>| value.as_deref() == Some(provider_url.as_str())));

let stored = repo
.get_active_provider_url(api_key_id, &model_name)
.await
.unwrap();
assert_eq!(stored.as_deref(), Some(provider_url.as_str()));
}

/// With a provider URL already stored for the pair, a lookup must return it
/// within 200 ms and without calling the selector, even though another
/// transaction is holding the pair's advisory lock for the whole lookup.
#[tokio::test]
async fn test_affinity_hit_does_not_wait_on_advisory_lock() {
    let pool = get_test_pool().await;
    let repo = Arc::new(PgApiKeyModelAffinityRepository::new(pool.clone()));

    let api_key_id = Uuid::new_v4();
    let model_name = format!("test-model-{}", Uuid::new_v4());
    let provider_url = "http://10.0.0.9:8000".to_string();
    let selector_calls = Arc::new(AtomicUsize::new(0));
    let five_minutes = StdDuration::from_secs(300);

    // Seed the row that the lookup below is expected to hit.
    repo.upsert_provider_url(api_key_id, &model_name, &provider_url, five_minutes)
        .await
        .unwrap();

    // Take the pair's advisory lock inside a separate transaction that stays
    // open. `pg_advisory_xact_lock` is held until its transaction ends, so any
    // code path that needed the same lock would block past the timeout below.
    let mut conn = pool.get().await.unwrap();
    let lock_holder = conn.transaction().await.unwrap();
    let lock_key = PgApiKeyModelAffinityRepository::advisory_lock_key(api_key_id, &model_name);
    lock_holder
        .execute("SELECT pg_advisory_xact_lock($1)", &[&lock_key])
        .await
        .unwrap();

    // The selector points at a different URL, so an unexpected selector call
    // would show up in the returned value as well as in the counter.
    let selector = TestProviderUrlSelector {
        provider_url: "http://10.0.0.10:8000".to_string(),
        call_count: Arc::clone(&selector_calls),
    };

    let lookup =
        repo.get_or_create_active_provider_url(api_key_id, &model_name, five_minutes, &selector);
    let result = timeout(StdDuration::from_millis(200), lookup)
        .await
        .expect("cache hit should not wait on advisory lock")
        .unwrap();

    assert_eq!(result.as_deref(), Some(provider_url.as_str()));
    assert_eq!(selector_calls.load(Ordering::SeqCst), 0);

    // Release the advisory lock only after the lookup has completed.
    drop(lock_holder);
}
8 changes: 4 additions & 4 deletions crates/database/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ pub use constants::*;
pub use models::*;
pub use pool::DbPool;
pub use repositories::{
ApiKeyRepository, McpConnectorRepository, OAuthStateRepository, PgAttestationRepository,
PgConversationRepository, PgOrganizationInvitationRepository, PgOrganizationRepository,
PgResponseItemsRepository, PgResponseRepository, PostgresNearNonceRepository,
SessionRepository, UserRepository,
ApiKeyRepository, McpConnectorRepository, OAuthStateRepository,
PgApiKeyModelAffinityRepository, PgAttestationRepository, PgConversationRepository,
PgOrganizationInvitationRepository, PgOrganizationRepository, PgResponseItemsRepository,
PgResponseRepository, PostgresNearNonceRepository, SessionRepository, UserRepository,
};
pub use shutdown_coordinator::{ShutdownCoordinator, ShutdownStage, ShutdownStageResult};

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Maps an (api_key_id, model_name) pair to a single provider URL together with
-- an expiry timestamp.
CREATE TABLE api_key_model_affinity (
api_key_id UUID NOT NULL,
model_name TEXT NOT NULL,
provider_url TEXT NOT NULL,
-- NOTE(review): presumably a row past expires_at is treated as inactive by the
-- repository queries — confirm there; nothing in this table enforces it.
expires_at TIMESTAMPTZ NOT NULL,
-- Defaults to the insertion time. NOTE(review): no trigger is defined here, so
-- updates must set this column explicitly if it should track modifications.
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
-- Composite key: at most one row per API key and model.
PRIMARY KEY (api_key_id, model_name)
);
Loading
Loading