Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
500 changes: 472 additions & 28 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions crates/api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ base64 = "0.22.1"
opentelemetry = { version = "0.31", features = ["metrics"] }
opentelemetry_sdk = { version = "0.31", features = ["rt-tokio", "metrics"] }
opentelemetry-otlp = { version = "0.31", features = ["metrics", "grpc-tonic"] }
aws-config = { version = "1.8.15", default-features = false, features = ["rt-tokio", "rustls"] }
aws-sdk-scheduler = "1.97.0"
aws-sdk-sqs = "1.97.0"
aws-smithy-http-client = "1.1.4"

[dev-dependencies]
serde_json = "1.0"
Expand Down
233 changes: 233 additions & 0 deletions crates/api/src/bin/task_worker.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
use anyhow::{anyhow, Context};
use async_trait::async_trait;
use aws_smithy_http_client::{
tls::{self, rustls_provider::CryptoMode},
Builder as AwsHttpClientBuilder,
};
use chrono::{Duration, Utc};
use services::jobs::{CleanupCanceledInstancesTaskPayload, NoopTaskPayload, TaskExecutor};
use services::{agent::ports::AgentService, UserId};
use std::sync::Arc;

/// Production `TaskExecutor` used by the SQS task-worker binary: holds the
/// real database pool and agent service the task handlers operate on.
struct DefaultTaskExecutor {
    // Postgres pool used to query `subscriptions` for cleanup-eligible users.
    db_pool: database::DbPool,
    // Service used to list and delete a user's agent instances.
    agent_service: Arc<dyn AgentService>,
}

#[async_trait]
impl TaskExecutor for DefaultTaskExecutor {
async fn execute_noop(&self, _payload: &NoopTaskPayload) -> anyhow::Result<()> {
tracing::info!("noop task received");
Ok(())
}

async fn execute_cleanup_canceled_instances(
&self,
payload: &CleanupCanceledInstancesTaskPayload,
) -> anyhow::Result<()> {
if payload.grace_days < 0 {
return Err(anyhow!("grace_days must be >= 0"));
}

let cutoff = Utc::now() - Duration::days(payload.grace_days);
let mut offset: i64 = 0;
let batch_size: i64 = 200;
let mut total_users = 0usize;
let mut total_instances = 0usize;
let mut failed_instances = 0usize;

loop {
let client = self
.db_pool
.get()
.await
.context("failed to get DB client")?;
let rows = client
.query(
"SELECT DISTINCT s.user_id
FROM subscriptions s
WHERE s.status = 'canceled'
AND s.updated_at <= $1
AND NOT EXISTS (
SELECT 1
FROM subscriptions active_sub
WHERE active_sub.user_id = s.user_id
AND active_sub.status IN ('active', 'trialing')
)
ORDER BY s.user_id
LIMIT $2 OFFSET $3",
&[&cutoff, &batch_size, &offset],
)
.await
.context("failed to query canceled users for cleanup")?;

if rows.is_empty() {
break;
}

for row in &rows {
let user_id: UserId = row.get("user_id");
total_users += 1;

let (instances, _) = match self.agent_service.list_instances(user_id, 1000, 0).await
{
Ok(result) => result,
Err(err) => {
tracing::error!(
"cleanup task: failed to list instances user_id={} err={}",
user_id,
err
);
continue;
}
};
Comment on lines +72 to +83
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation fetches only up to 1000 instances per user due to the hardcoded limit and lack of pagination. If a user has more than 1000 instances, any additional instances will not be cleaned up, leading to orphaned resources. To ensure all instances are processed, you should paginate through the results of list_instances.

                let instances = {
                    let mut all_instances = Vec::new();
                    let mut offset = 0;
                    const LIMIT: i32 = 1000;
                    loop {
                        match self.agent_service.list_instances(user_id, LIMIT, offset).await {
                            Ok((batch, total)) => {
                                let fetched_count = batch.len();
                                all_instances.extend(batch);
                                if all_instances.len() >= total as usize
                                    || fetched_count < LIMIT as usize
                                {
                                    break Ok(all_instances);
                                }
                                offset += fetched_count as i32;
                            }
                            Err(err) => {
                                tracing::error!(
                                    "cleanup task: failed to list instances user_id={} err={}",
                                    user_id,
                                    err
                                );
                                break Err(());
                            }
                        }
                    }
                };

                let instances = match instances {
                    Ok(i) => i,
                    Err(_) => continue,
                };


let mut cleanup_targets = instances
.into_iter()
.filter(|instance| instance.status != "deleted")
.collect::<Vec<_>>();
total_instances += cleanup_targets.len();

if payload.dry_run {
for instance in &cleanup_targets {
tracing::info!(
"cleanup task dry-run: would delete instance_id={} user_id={} status={}",
instance.id,
user_id,
instance.status
);
}
continue;
}

for instance in cleanup_targets.drain(..) {
if let Err(err) = self.agent_service.delete_instance(instance.id).await {
failed_instances += 1;
tracing::error!(
"cleanup task: delete failed instance_id={} user_id={} status={} err={}",
instance.id,
user_id,
instance.status,
err
);
} else {
tracing::info!(
"cleanup task: deleted instance_id={} user_id={} previous_status={}",
instance.id,
user_id,
instance.status
);
}
}
Comment on lines +103 to +121
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current implementation deletes instances for a user sequentially. If a user has many instances, this could be slow. You can improve performance by deleting them concurrently using futures::stream.

You'll need to add use futures::StreamExt; to the top of the file.

                let results = futures::stream::iter(cleanup_targets.into_iter())
                    .map(|instance| {
                        let agent_service = Arc::clone(&self.agent_service);
                        async move {
                            let res = agent_service.delete_instance(instance.id).await;
                            (res, instance)
                        }
                    })
                    .buffer_unordered(10)
                    .collect::<Vec<_>>()
                    .await;

                for (result, instance) in results {
                    match result {
                        Ok(_) => {
                            tracing::info!(
                                "cleanup task: deleted instance_id={} user_id={} previous_status={}",
                                instance.id,
                                user_id,
                                instance.status
                            );
                        }
                        Err(err) => {
                            failed_instances += 1;
                            tracing::error!(
                                "cleanup task: delete failed instance_id={} user_id={} status={} err={}",
                                instance.id,
                                user_id,
                                instance.status,
                                err
                            );
                        }
                    }
                }

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this is a performance concern, but not a correctness issue. Given our current instance limits (5 max per user), it’s unlikely to materially affect normal production users

}

offset += rows.len() as i64;
if rows.len() < batch_size as usize {
break;
}
}

tracing::info!(
"cleanup task finished grace_days={} dry_run={} users_scanned={} instances_targeted={} delete_failures={}",
payload.grace_days,
payload.dry_run,
total_users,
total_instances,
failed_instances
);

if failed_instances > 0 {
return Err(anyhow!(
"cleanup completed with {} failed instance deletions",
failed_instances
));
}

Ok(())
}
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Loading .env is best-effort: containerized deploys supply real env vars.
    if let Err(e) = dotenvy::dotenv() {
        eprintln!("Warning: Could not load .env file: {e}");
        eprintln!("Continuing with environment variables...");
    }

    let app_config = config::Config::from_env();
    let task_settings = app_config.tasks.clone();

    // Honor RUST_LOG-style filtering; fall back to `info` when unset/invalid.
    let env_filter = tracing_subscriber::EnvFilter::try_from_default_env()
        .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info"));
    tracing_subscriber::fmt().with_env_filter(env_filter).init();

    if !task_settings.enabled {
        return Err(anyhow!(
            "task worker is disabled: set TASKS_ENABLED=true to run"
        ));
    }

    // Both the region and the queue URL are mandatory for the worker to run.
    let Some(region) = task_settings.aws_region.clone() else {
        return Err(anyhow!("TASKS_AWS_REGION or AWS_REGION is required"));
    };
    let Some(queue_url) = task_settings.sqs_queue_url.clone() else {
        return Err(anyhow!("TASKS_SQS_QUEUE_URL is required"));
    };

    let db = database::Database::from_config(&app_config.database)
        .await
        .context("failed to connect database for task worker")?;

    let system_configs_service = Arc::new(
        services::system_configs::service::SystemConfigsServiceImpl::new(
            db.system_configs_repository()
                as Arc<dyn services::system_configs::ports::SystemConfigsRepository>,
        ),
    );

    let agent_service = Arc::new(services::agent::AgentServiceImpl::new(
        db.agent_repository() as Arc<dyn services::agent::ports::AgentRepository>,
        app_config.agent.managers.clone(),
        app_config.agent.nearai_api_url.clone(),
        system_configs_service as Arc<dyn services::system_configs::ports::SystemConfigsService>,
        app_config.agent.channel_relay_url.clone(),
        app_config.agent.non_tee_agent_url_pattern.clone(),
    ));

    // AWS SDK HTTP stack backed by rustls with the aws-lc crypto provider.
    let https_client = AwsHttpClientBuilder::new()
        .tls_provider(tls::Provider::Rustls(CryptoMode::AwsLc))
        .build_https();
    let sdk_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
        .http_client(https_client)
        .region(aws_sdk_sqs::config::Region::new(region))
        .load()
        .await;

    let executor = Arc::new(DefaultTaskExecutor {
        db_pool: db.pool().clone(),
        agent_service,
    });

    let worker = api::tasks::AwsSqsTaskWorker::new(
        aws_sdk_sqs::Client::new(&sdk_config),
        queue_url,
        task_settings.worker_max_concurrency,
        task_settings.worker_wait_seconds,
        task_settings.worker_visibility_timeout,
        task_settings.worker_max_messages,
        executor,
    );

    worker
        .run_forever()
        .await
        .context("task worker loop exited unexpectedly")
}
1 change: 1 addition & 0 deletions crates/api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub mod openapi;
pub mod routes;
pub mod state;
pub mod static_files;
pub mod tasks;
pub mod usage_parsing;
pub mod validation;
pub mod web_search_pricing;
Expand Down
10 changes: 10 additions & 0 deletions crates/api/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,16 @@ async fn main() -> anyhow::Result<()> {
// Initialize tracing based on configuration
init_tracing(&config.logging);

if config.tasks.enabled {
if config.tasks.is_scheduler_configured() {
if let Err(err) = api::tasks::ensure_daily_cleanup_task(&config.tasks).await {
tracing::warn!("failed to ensure daily cleanup schedule: {}", err);
}
} else {
tracing::info!("tasks scheduler not configured; skipping daily cleanup schedule setup");
}
}

tracing::info!("Starting API server...");

tracing::info!(
Expand Down
Loading
Loading