Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 245 additions & 12 deletions Cargo.lock

Large diffs are not rendered by default.

72 changes: 72 additions & 0 deletions crates/api/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,21 @@ use opentelemetry_sdk::{
};
use services::{
analytics::AnalyticsServiceImpl,
audit::service::AuditServiceImpl,
auth::OAuthServiceImpl,
conversation::service::ConversationServiceImpl,
domain::service::DomainVerificationServiceImpl,
file::service::FileServiceImpl,
metrics::{MockMetricsService, OtlpMetricsService},
model::service::ModelServiceImpl,
organization::service::OrganizationServiceImpl,
rbac::service::{PermissionServiceImpl, RoleServiceImpl},
response::service::OpenAIProxy,
saml::service::SamlServiceImpl,
user::UserServiceImpl,
user::UserSettingsServiceImpl,
vpc::{initialize_vpc_credentials, VpcAuthConfig},
workspace::service::WorkspaceServiceImpl,
};
use std::sync::Arc;
use tracing_subscriber::EnvFilter;
Expand Down Expand Up @@ -75,6 +81,16 @@ async fn main() -> anyhow::Result<()> {
let system_configs_repo = db.system_configs_repository();
let model_repo = db.model_repository();

// Enterprise repositories
let organization_repo = db.organization_repository();
let workspace_repo = db.workspace_repository();
let permission_repo = db.permission_repository();
let role_repo = db.role_repository();
let audit_repo = db.audit_repository();
let saml_idp_config_repo = db.saml_idp_config_repository();
let saml_auth_state_repo = db.saml_auth_state_repository();
let domain_repo = db.domain_repository();

// Create services
tracing::info!("Initializing services...");
let oauth_service = Arc::new(OAuthServiceImpl::new(
Expand Down Expand Up @@ -151,6 +167,51 @@ async fn main() -> anyhow::Result<()> {
tracing::info!("Initializing analytics service...");
let analytics_service = Arc::new(AnalyticsServiceImpl::new(analytics_repo));

// Initialize enterprise services
tracing::info!("Initializing enterprise services...");

let organization_service = Arc::new(OrganizationServiceImpl::new(
organization_repo.clone(),
workspace_repo.clone(),
));

let workspace_service = Arc::new(WorkspaceServiceImpl::new(
workspace_repo.clone(),
));

let permission_service = Arc::new(PermissionServiceImpl::new(
permission_repo.clone(),
role_repo.clone(),
));

let role_service = Arc::new(RoleServiceImpl::new(
role_repo.clone(),
permission_repo.clone(),
));

let audit_service = Arc::new(AuditServiceImpl::new(audit_repo.clone()));

let domain_service = Arc::new(DomainVerificationServiceImpl::new(domain_repo));

// SAML service is optional - only initialize if SAML is enabled
let saml_service: Option<Arc<dyn services::saml::ports::SamlService>> = if config.saml.enabled {
tracing::info!("Initializing SAML SSO service...");
tracing::info!(
"SAML SP Base URL: {}, Entity ID: {}",
config.saml.sp_base_url,
config.saml.get_sp_entity_id()
);
Some(Arc::new(SamlServiceImpl::new(
saml_idp_config_repo,
saml_auth_state_repo,
config.saml.sp_base_url.clone(),
)))
} else {
tracing::info!("SAML SSO is disabled (set SAML_ENABLED=true to enable)");
let _ = (saml_idp_config_repo, saml_auth_state_repo); // Suppress unused warnings
None
};

Comment on lines +197 to +214
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve resource management and code clarity, it's better to initialize saml_idp_config_repo and saml_auth_state_repo only when SAML is enabled. This avoids initializing them unconditionally and then having to suppress unused variable warnings.

With the suggested change, you should also remove the unconditional initialization of these repositories on lines 90-91.

Suggested change
let saml_service: Option<Arc<dyn services::saml::ports::SamlService>> = if config.saml.enabled {
tracing::info!("Initializing SAML SSO service...");
tracing::info!(
"SAML SP Base URL: {}, Entity ID: {}",
config.saml.sp_base_url,
config.saml.get_sp_entity_id()
);
Some(Arc::new(SamlServiceImpl::new(
saml_idp_config_repo,
saml_auth_state_repo,
config.saml.sp_base_url.clone(),
)))
} else {
tracing::info!("SAML SSO is disabled (set SAML_ENABLED=true to enable)");
let _ = (saml_idp_config_repo, saml_auth_state_repo); // Suppress unused warnings
None
};
let saml_service: Option<Arc<dyn services::saml::ports::SamlService>> = if config.saml.enabled {
tracing::info!("Initializing SAML SSO service...");
tracing::info!(
"SAML SP Base URL: {}, Entity ID: {}",
config.saml.sp_base_url,
config.saml.get_sp_entity_id()
);
let saml_idp_config_repo = db.saml_idp_config_repository();
let saml_auth_state_repo = db.saml_auth_state_repository();
Some(Arc::new(SamlServiceImpl::new(
saml_idp_config_repo,
saml_auth_state_repo,
config.saml.sp_base_url.clone(),
)))
} else {
tracing::info!("SAML SSO is disabled (set SAML_ENABLED=true to enable)");
None
};

// Initialize system configs service
tracing::info!("Initializing system configs service...");
let system_configs_service = Arc::new(
Expand Down Expand Up @@ -235,6 +296,17 @@ async fn main() -> anyhow::Result<()> {
near_rpc_url: config.near.rpc_url.clone(),
near_balance_cache: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
model_settings_cache: Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::new())),
// Enterprise services
organization_service,
organization_repository: organization_repo,
workspace_service,
workspace_repository: workspace_repo,
permission_service,
role_service,
role_repository: role_repo,
audit_service,
saml_service,
domain_service,
};

// Create router with CORS support
Expand Down
5 changes: 5 additions & 0 deletions crates/api/src/middleware/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
pub mod auth;
pub mod metrics;
pub mod rate_limit;
pub mod tenant;

pub use auth::{admin_auth_middleware, auth_middleware, AuthState, AuthenticatedUser};
pub use metrics::{http_metrics_middleware, MetricsState};
pub use rate_limit::{rate_limit_middleware, RateLimitConfig, RateLimitState};
pub use tenant::{
has_all_permissions, has_any_permission, require_permission, tenant_middleware, TenantContext,
TenantState,
};
Loading