diff --git a/Cargo.lock b/Cargo.lock index 502cc79..5542f88 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -511,6 +511,7 @@ dependencies = [ "time", "tokio", "tokio-postgres", + "tokio-util", "tower 0.5.1", "tower-http", "tracing", @@ -2649,16 +2650,17 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.10" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5419f34732d9eb6ee4c3578b7989078579b7f039cbbb9ca2c4da015749371e15" +checksum = "61e7c3654c13bcd040d4a03abee2c75b1d14a37b423cf5a813ceae1cc903ec6a" dependencies = [ "bytes", "futures-core", "futures-sink", + "futures-util", + "hashbrown 0.14.3", "pin-project-lite", "tokio", - "tracing", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 8491724..c1df051 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,7 @@ tokio-postgres = { version = "0.7.12", features = [ "with-serde_json-1", "with-time-0_3", ] } +tokio-util = { version = "0.7.12", features = ["rt"] } tower = "0.5.1" tower-http = { version = "0.6.2", features = ["auth", "fs", "set-header", "trace"] } tracing = "0.1.40" diff --git a/clowarden-cli/src/main.rs b/clowarden-cli/src/main.rs index 125862c..53e8a5e 100644 --- a/clowarden-cli/src/main.rs +++ b/clowarden-cli/src/main.rs @@ -22,6 +22,9 @@ use clowarden_core::{ }, }; +/// Environment variable containing Github token. +const GITHUB_TOKEN: &str = "GITHUB_TOKEN"; + #[derive(Parser)] #[command( version, @@ -83,9 +86,6 @@ struct GenerateArgs { output_file: PathBuf, } -/// Environment variable containing Github token. -const GITHUB_TOKEN: &str = "GITHUB_TOKEN"; - #[tokio::main] async fn main() -> Result<()> { let cli = Cli::parse(); diff --git a/clowarden-core/src/services/mod.rs b/clowarden-core/src/services/mod.rs index 8a59d9c..5a8e250 100644 --- a/clowarden-core/src/services/mod.rs +++ b/clowarden-core/src/services/mod.rs @@ -1,7 +1,7 @@ //! This module defines some types and traits that service handlers //! implementations will rely upon. -use std::fmt::Debug; +use std::{fmt::Debug, sync::Arc}; use anyhow::Result; use as_any::AsAny; @@ -27,7 +27,7 @@ pub trait ServiceHandler { } /// Type alias to represent a service handler trait object. -pub type DynServiceHandler = Box; +pub type DynServiceHandler = Arc; /// Represents a summary of changes detected in the service's state as defined /// in the configuration from the base to the head reference. diff --git a/clowarden-server/Cargo.toml b/clowarden-server/Cargo.toml index c63ad9d..40e82a2 100644 --- a/clowarden-server/Cargo.toml +++ b/clowarden-server/Cargo.toml @@ -36,6 +36,7 @@ thiserror = { workspace = true } time = { workspace = true } tokio = { workspace = true } tokio-postgres = { workspace = true } +tokio-util = { workspace = true } tower = { workspace = true } tower-http = { workspace = true } tracing = { workspace = true } diff --git a/clowarden-server/src/github.rs b/clowarden-server/src/github.rs index 31b7cf1..f263f95 100644 --- a/clowarden-server/src/github.rs +++ b/clowarden-server/src/github.rs @@ -20,6 +20,9 @@ use thiserror::Error; use clowarden_core::cfg::{GitHubApp, Organization}; +/// Name used for the check run in GitHub. +const CHECK_RUN_NAME: &str = "CLOWarden"; + /// Trait that defines some operations a GH implementation must support. #[async_trait] #[cfg_attr(test, automock)] @@ -165,9 +168,6 @@ pub(crate) enum PullRequestEventAction { Other, } -/// Name used for the check run in GitHub. -const CHECK_RUN_NAME: &str = "CLOWarden"; - /// Helper function to create a new ChecksCreateRequest instance. pub(crate) fn new_checks_create_request( head_sha: String, diff --git a/clowarden-server/src/jobs.rs b/clowarden-server/src/jobs.rs index 4889453..5fc035c 100644 --- a/clowarden-server/src/jobs.rs +++ b/clowarden-server/src/jobs.rs @@ -1,7 +1,7 @@ //! This module defines the types and functionality needed to schedule and //! process jobs. -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{collections::HashMap, time::Duration}; use ::time::OffsetDateTime; use anyhow::{Error, Result}; @@ -10,10 +10,11 @@ use futures::future::{self, JoinAll}; use octorust::types::{ChecksCreateRequestConclusion, JobStatus, PullRequestData}; use serde::{Deserialize, Serialize}; use tokio::{ - sync::{broadcast, mpsc}, + sync::mpsc, task::JoinHandle, time::{self, sleep, MissedTickBehavior}, }; +use tokio_util::sync::CancellationToken; use tracing::{debug, error, instrument}; use self::core::github::Source; @@ -31,6 +32,9 @@ use crate::{ tmpl, }; +/// How often periodic reconcile jobs should be scheduled (in seconds). +const RECONCILE_FREQUENCY: u64 = 60 * 60; // Every hour + /// Represents a job to be executed. #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] @@ -117,84 +121,83 @@ impl ValidateInput { } } -/// A jobs handler is in charge of executing the received jobs. -pub(crate) struct Handler { +/// A jobs handler is in charge of executing the received jobs. It will create +/// a worker for each organization, plus an additional task to route jobs to +/// the corresponding organization worker. All tasks will stop when the +/// cancellation token is cancelled. +pub(crate) fn handler( + db: &DynDB, + gh: &DynGH, + ghc: &core::github::DynGH, + services: &HashMap, + mut jobs_rx: mpsc::UnboundedReceiver, + cancel_token: CancellationToken, + orgs: Vec, +) -> JoinAll> { + let mut handles = Vec::with_capacity(orgs.len() + 1); + let mut orgs_jobs_tx_channels = HashMap::new(); + + // Create a worker for each organization + for org in orgs { + let (org_jobs_tx, org_jobs_rx) = mpsc::unbounded_channel(); + orgs_jobs_tx_channels.insert(org.name, org_jobs_tx); + let org_worker = OrgWorker::new(db.clone(), gh.clone(), ghc.clone(), services.clone()); + handles.push(org_worker.run(org_jobs_rx, cancel_token.clone())); + } + + // Create a worker to route jobs to the corresponding org worker + let jobs_router = tokio::spawn(async move { + loop { + tokio::select! { + biased; + + // Pick next job from the queue and send it to the corresponding org worker + Some(job) = jobs_rx.recv() => { + if let Some(org_jobs_tx) = orgs_jobs_tx_channels.get(job.org_name()) { + _ = org_jobs_tx.send(job); + } + } + + // Exit if the handler has been asked to stop + () = cancel_token.cancelled() => break, + } + } + }); + handles.push(jobs_router); + + future::join_all(handles) +} + +/// An organization worker is in charge of processing jobs for a given +/// organization. +struct OrgWorker { db: DynDB, gh: DynGH, ghc: core::github::DynGH, services: HashMap, } -impl Handler { - /// Create a new handler instance. - pub(crate) fn new( +impl OrgWorker { + /// Create a new organization worker instance. + fn new( db: DynDB, gh: DynGH, ghc: core::github::DynGH, services: HashMap, - ) -> Arc { - Arc::new(Self { + ) -> Self { + Self { db, gh, ghc, services, - }) - } - - /// Spawn some tasks to process jobs received on the jobs channel. We will - /// create one worker per organization, plus an additional task to route - /// jobs to the corresponding organization worker. All tasks will stop when - /// notified on the stop channel provided. - pub(crate) fn start( - self: Arc, - mut jobs_rx: mpsc::UnboundedReceiver, - stop_tx: &broadcast::Sender<()>, - orgs: Vec, - ) -> JoinAll> { - let mut handles = Vec::with_capacity(orgs.len() + 1); - let mut orgs_jobs_tx_channels = HashMap::new(); - - // Create a worker for each organization - for org in orgs { - let (org_jobs_tx, org_jobs_rx) = mpsc::unbounded_channel(); - orgs_jobs_tx_channels.insert(org.name, org_jobs_tx); - let org_worker = self.clone().organization_worker(org_jobs_rx, stop_tx.subscribe()); - handles.push(org_worker); } - - // Create a worker to route jobs to the corresponding org worker - let mut stop_rx = stop_tx.subscribe(); - let jobs_router = tokio::spawn(async move { - loop { - tokio::select! { - biased; - - // Pick next job from the queue and send it to the corresponding org worker - Some(job) = jobs_rx.recv() => { - if let Some(org_jobs_tx) = orgs_jobs_tx_channels.get(job.org_name()) { - _ = org_jobs_tx.send(job); - } - } - - // Exit if the handler has been asked to stop - _ = stop_rx.recv() => { - break - } - } - } - }); - handles.push(jobs_router); - - future::join_all(handles) } - /// Spawn a worker that will take care of processing jobs for a given - /// organization. The worker will stop when notified on the stop channel - /// provided. - fn organization_worker( - self: Arc, + /// Run organization worker. + fn run( + self, mut org_jobs_rx: mpsc::UnboundedReceiver, - mut stop_rx: broadcast::Receiver<()>, + cancel_token: CancellationToken, ) -> JoinHandle<()> { tokio::spawn(async move { loop { @@ -210,9 +213,7 @@ impl Handler { } // Exit if the handler has been asked to stop - _ = stop_rx.recv() => { - break - } + () = cancel_token.cancelled() => break, } } }) @@ -349,14 +350,11 @@ impl Handler { } } -/// How often periodic reconcile jobs should be scheduled (in seconds). -const RECONCILE_FREQUENCY: u64 = 60 * 60; - /// A jobs scheduler is in charge of scheduling the execution of some jobs /// periodically. pub(crate) fn scheduler( jobs_tx: mpsc::UnboundedSender, - mut stop_rx: broadcast::Receiver<()>, + cancel_token: CancellationToken, orgs: Vec, ) -> JoinAll> { let scheduler = tokio::spawn(async move { @@ -369,9 +367,7 @@ pub(crate) fn scheduler( biased; // Exit if the scheduler has been asked to stop - _ = stop_rx.recv() => { - break - } + () = cancel_token.cancelled() => break, // Schedule reconcile job for each of the registered organizations _ = reconcile.tick() => { diff --git a/clowarden-server/src/main.rs b/clowarden-server/src/main.rs index 905ec1e..0c53f06 100644 --- a/clowarden-server/src/main.rs +++ b/clowarden-server/src/main.rs @@ -6,15 +6,14 @@ use std::{collections::HashMap, net::SocketAddr, path::PathBuf, sync::Arc}; use anyhow::{Context, Result}; use clap::Parser; use config::{Config, File}; +use db::DynDB; use deadpool_postgres::{Config as DbConfig, Runtime}; use futures::future; +use github::DynGH; use openssl::ssl::{SslConnector, SslMethod, SslVerifyMode}; use postgres_openssl::MakeTlsConnector; -use tokio::{ - net::TcpListener, - signal, - sync::{broadcast, mpsc}, -}; +use tokio::{net::TcpListener, signal, sync::mpsc}; +use tokio_util::sync::CancellationToken; use tracing::{error, info}; use tracing_subscriber::EnvFilter; @@ -69,12 +68,12 @@ async fn main() -> Result<()> { let connector = MakeTlsConnector::new(builder.build()); let db_cfg: DbConfig = cfg.get("db")?; let pool = db_cfg.create_pool(Some(Runtime::Tokio1), connector)?; - let db = Arc::new(PgDB::new(pool)); + let db: DynDB = Arc::new(PgDB::new(pool)); // Setup GitHub clients let gh_app: core::cfg::GitHubApp = cfg.get("server.githubApp")?; - let gh = Arc::new(github::GHApi::new(&gh_app).context("error setting up github client")?); - let ghc = Arc::new( + let gh: DynGH = Arc::new(github::GHApi::new(&gh_app).context("error setting up github client")?); + let ghc: core::github::DynGH = Arc::new( core::github::GHApi::new_with_app_creds(&gh_app).context("error setting up core github client")?, ); @@ -84,18 +83,24 @@ async fn main() -> Result<()> { let svc = Arc::new(services::github::service::SvcApi::new_with_app_creds(&gh_app)?); services.insert( services::github::SERVICE_NAME, - Box::new(services::github::Handler::new(ghc.clone(), svc)), + Arc::new(services::github::Handler::new(ghc.clone(), svc)), ); } // Setup and launch jobs workers - let (stop_tx, _): (broadcast::Sender<()>, _) = broadcast::channel(1); + let cancel_token = CancellationToken::new(); let (jobs_tx, jobs_rx) = mpsc::unbounded_channel(); - let jobs_handler = jobs::Handler::new(db.clone(), gh.clone(), ghc.clone(), services); - let jobs_workers_done = future::join_all([ - jobs_handler.start(jobs_rx, &stop_tx, cfg.get("organizations")?), - jobs::scheduler(jobs_tx.clone(), stop_tx.subscribe(), cfg.get("organizations")?), - ]); + let jobs_handler = jobs::handler( + &db, + &gh, + &ghc, + &services, + jobs_rx, + cancel_token.clone(), + cfg.get("organizations")?, + ); + let jobs_scheduler = jobs::scheduler(jobs_tx.clone(), cancel_token.clone(), cfg.get("organizations")?); + let jobs_workers_done = future::join_all([jobs_handler, jobs_scheduler]); // Setup and launch HTTP server let router = handlers::setup_router(&cfg, db.clone(), gh.clone(), jobs_tx) @@ -110,7 +115,7 @@ async fn main() -> Result<()> { } // Ask jobs workers to stop and wait for them to finish - drop(stop_tx); + cancel_token.cancel(); jobs_workers_done.await; info!("server stopped");