Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 108 additions & 111 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,7 @@ pulldown-cmark = "0.13"
rand = "0.8"
reqwest = { version = "0.12", features = ["json"] }
rsa = "0.9"
# 0.21.2 causes config parsing errors
rust-ini = "=0.21.1"
rust-ini = "0.21"
semver = { version = "1.0", features = ["serde"] }
secrecy = { version = "0.10", features = ["serde"] }
serde = { version = "1.0", features = ["derive"] }
Expand Down
15 changes: 8 additions & 7 deletions crates/defguard/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use defguard_core::{
grpc::{
WorkerState,
gateway::{client_state::ClientMap, map::GatewayMap},
run_grpc_bidi_stream, run_grpc_server,
run_grpc_bidi_stream, run_grpc_gateway_stream, run_grpc_server,
},
init_dev_env, init_vpn_location, run_web_server,
utility_thread::run_utility_thread,
Expand Down Expand Up @@ -153,6 +153,13 @@ async fn main() -> Result<(), anyhow::Error> {

// run services
tokio::select! {
res = run_grpc_gateway_stream(
pool.clone(),
client_state,
wireguard_tx.clone(),
mail_tx.clone(),
grpc_event_tx,
) => error!("Gateway gRPC stream returned early: {res:?}"),
res = run_grpc_bidi_stream(
pool.clone(),
wireguard_tx.clone(),
Expand All @@ -163,15 +170,9 @@ async fn main() -> Result<(), anyhow::Error> {
res = run_grpc_server(
Arc::clone(&worker_state),
pool.clone(),
Arc::clone(&gateway_state),
client_state,
wireguard_tx.clone(),
mail_tx.clone(),
grpc_cert,
grpc_key,
failed_logins.clone(),
grpc_event_tx,
Arc::clone(&incompatible_components),
) => error!("gRPC server returned early: {res:?}"),
res = run_web_server(
worker_state,
Expand Down
26 changes: 24 additions & 2 deletions crates/defguard_common/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{net::IpAddr, sync::OnceLock};
use std::{fs::read_to_string, io, net::IpAddr, sync::OnceLock};

use clap::{Args, Parser, Subcommand};
use humantime::Duration;
Expand All @@ -13,6 +13,7 @@ use rsa::{
};
use secrecy::{ExposeSecret, SecretString};
use serde::Serialize;
use tonic::transport::{Certificate, ClientTlsConfig, Identity};

pub static SERVER_CONFIG: OnceLock<DefGuardConfig> = OnceLock::new();

Expand Down Expand Up @@ -65,9 +66,11 @@ pub struct DefGuardConfig {
#[arg(long, env = "DEFGUARD_GRPC_PORT", default_value_t = 50055)]
pub grpc_port: u16,

// Certificate authority (CA), certificate, and key for gRPC communication over HTTPS.
#[arg(long, env = "DEFGUARD_GRPC_CA")]
pub grpc_ca: Option<String>,
#[arg(long, env = "DEFGUARD_GRPC_CERT")]
pub grpc_cert: Option<String>,

#[arg(long, env = "DEFGUARD_GRPC_KEY")]
pub grpc_key: Option<String>,

Expand Down Expand Up @@ -298,6 +301,25 @@ impl DefGuardConfig {
}
url
}

/// Provide [`ClientTlsConfig`] from paths to cerfiticate, key, and cerfiticate authority (CA).
pub fn grpc_client_tls_config(&self) -> Result<Option<ClientTlsConfig>, io::Error> {
if self.grpc_ca.is_none() && (self.grpc_cert.is_none() || self.grpc_key.is_none()) {
return Ok(None);
}
let mut tls = ClientTlsConfig::new();
if let (Some(cert_path), Some(key_path)) = (&self.grpc_cert, &self.grpc_key) {
let cert = read_to_string(cert_path)?;
let key = read_to_string(key_path)?;
tls = tls.identity(Identity::from_pem(cert, key));
}
if let Some(ca_path) = &self.grpc_ca {
let ca = read_to_string(ca_path)?;
tls = tls.ca_certificate(Certificate::from_pem(ca));
}

Ok(Some(tls))
}
}

impl Default for DefGuardConfig {
Expand Down
15 changes: 15 additions & 0 deletions crates/defguard_common/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,18 @@ pub async fn setup_pool(options: PgConnectOptions) -> PgPool {
.expect("Cannot run database migrations.");
pool
}

#[derive(Deserialize)]
#[serde(rename_all = "UPPERCASE")]
pub enum TriggerOperation {
Insert,
Update,
Delete,
}

#[derive(Deserialize)]
pub struct ChangeNotification<T> {
pub operation: TriggerOperation,
pub old: Option<T>,
pub new: Option<T>,
}
105 changes: 105 additions & 0 deletions crates/defguard_core/src/db/models/gateway.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use std::fmt;

use chrono::{NaiveDateTime, Utc};
use model_derive::Model;
use sqlx::{PgExecutor, query, query_as};

use defguard_common::db::{Id, NoId};

#[derive(Clone, Debug, Deserialize, Model, PartialEq, Serialize)]
pub(crate) struct Gateway<I = NoId> {
pub id: I,
pub network_id: Id,
pub url: String,
pub hostname: Option<String>,
pub connected_at: Option<NaiveDateTime>,
pub disconnected_at: Option<NaiveDateTime>,
}

impl Gateway {
#[must_use]
pub(crate) fn new<S: Into<String>>(network_id: Id, url: S) -> Self {
Self {
id: NoId,
network_id,
url: url.into(),
hostname: None,
connected_at: None,
disconnected_at: None,
}
}
}

impl Gateway<Id> {
pub(crate) async fn find_by_network_id<'e, E>(
executor: E,
network_id: Id,
) -> Result<Vec<Self>, sqlx::Error>
where
E: PgExecutor<'e>,
{
query_as!(
Self,
"SELECT * FROM gateway WHERE network_id = $1 ORDER BY id",
network_id
)
.fetch_all(executor)
.await
}

/// Update `hostname` and set `connected_at` to the current time and save it to the database.
pub(crate) async fn touch_connected<'e, E>(
&mut self,
executor: E,
hostname: String,
) -> Result<(), sqlx::Error>
where
E: PgExecutor<'e>,
{
self.hostname = Some(hostname);
self.connected_at = Some(Utc::now().naive_utc());
query!(
"UPDATE gateway SET hostname = $2, connected_at = $3 WHERE id = $1",
self.id,
self.hostname,
self.connected_at
)
.execute(executor)
.await?;

Ok(())
}

/// Set `disconnected_at` to the current time and save it to the database.
pub(crate) async fn touch_disconnected<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error>
where
E: PgExecutor<'e>,
{
self.disconnected_at = Some(Utc::now().naive_utc());
query!(
"UPDATE gateway SET disconnected_at = $2 WHERE id = $1",
self.id,
self.disconnected_at
)
.execute(executor)
.await?;

Ok(())
}

pub(crate) fn is_connected(&self) -> bool {
if let (Some(connected_at), Some(disconnected_at)) =
(self.connected_at, self.disconnected_at)
{
disconnected_at <= connected_at
} else {
self.connected_at.is_some()
}
}
}

impl fmt::Display for Gateway<Id> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Gateway(ID {}; URL {})", self.id, self.url)
}
}
1 change: 1 addition & 0 deletions crates/defguard_core/src/db/models/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod activity_log;
pub mod device;
pub mod enrollment;
pub mod gateway;
pub mod group;
pub mod oauth2authorizedapp;
pub mod oauth2client;
Expand Down
11 changes: 7 additions & 4 deletions crates/defguard_core/src/db/models/polling_token.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use defguard_common::{
random::gen_alphanumeric,
};
use model_derive::Model;
use sqlx::{Error as SqlxError, PgExecutor, PgPool, query_as};
use sqlx::{PgExecutor, query_as};

// Token used for polling requests.
#[derive(Clone, Debug, Model)]
Expand All @@ -28,18 +28,21 @@ impl PollingToken {
}

impl PollingToken<Id> {
pub async fn find(pool: &PgPool, token: &str) -> Result<Option<Self>, SqlxError> {
pub async fn find<'e, E>(executor: E, token: &str) -> Result<Option<Self>, sqlx::Error>
where
E: PgExecutor<'e>,
{
query_as!(
Self,
"SELECT id, token, device_id, created_at \
FROM pollingtoken WHERE token = $1",
token
)
.fetch_optional(pool)
.fetch_optional(executor)
.await
}

pub async fn delete_for_device_id<'e, E>(executor: E, device_id: Id) -> Result<(), SqlxError>
pub async fn delete_for_device_id<'e, E>(executor: E, device_id: Id) -> Result<(), sqlx::Error>
where
E: PgExecutor<'e>,
{
Expand Down
Loading
Loading