Skip to content

Commit 2da8a93

Browse files
committed
fix: implement Envoy-compliant connection draining with mixed protocol support
- Add ListenerProtocolConfig enum to handle HTTP, TCP, and mixed protocol listeners
- Respect HttpConnectionManager.drain_timeout field from configuration
- Support listeners with both http_connection_manager and tcp_proxy filters
- Remove ambiguous 'Gradual' strategy, align with Envoy's draining behavior
- Add initiate_listener_drain_from_filter_analysis() for proper integration

Signed-off-by: Eeshu-Yadav <[email protected]>
1 parent a8e8919 commit 2da8a93

File tree

6 files changed

+171
-55
lines changed

6 files changed

+171
-55
lines changed

orion-lib/src/listeners/drain_signaling.rs

Lines changed: 124 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
//
1717

1818
use crate::{Error, Result};
19-
use orion_configuration::config::listener::DrainType as ConfigDrainType;
19+
use orion_configuration::config::listener::{DrainType as ConfigDrainType, FilterChain, MainFilter};
2020
use pingora_timeout::fast_timeout::fast_timeout;
2121
use std::collections::HashMap;
2222
use std::sync::Arc;
@@ -25,6 +25,13 @@ use tokio::sync::RwLock;
2525
use tokio::time::sleep;
2626
use tracing::{debug, info, warn};
2727

/// Protocol make-up of a listener, used to choose how its connections are drained.
///
/// The derives mirror `DrainScenario` (`Copy`, `PartialEq`, `Eq`) so values can be
/// compared directly with `assert_eq!` and passed around without explicit clones.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ListenerProtocolConfig {
    /// Listener serving HTTP only; `drain_timeout` is the optional configured HTTP drain timeout.
    Http { drain_timeout: Option<Duration> },
    /// Listener serving raw TCP only.
    Tcp,
    /// Listener serving both HTTP and TCP.
    Mixed {
        /// Optional configured HTTP drain timeout.
        http_drain_timeout: Option<Duration>,
        /// Whether the listener has TCP filter chains.
        has_tcp: bool,
        /// Whether the listener has HTTP filter chains.
        has_http: bool,
    },
}
34+
2835
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
2936
pub enum DrainScenario {
3037
HealthCheckFail,
@@ -46,8 +53,8 @@ impl DrainScenario {
4653
pub enum DrainStrategy {
4754
Tcp { global_timeout: Duration },
4855
Http { global_timeout: Duration, drain_timeout: Duration },
56+
Mixed { global_timeout: Duration, http_drain_timeout: Duration, tcp_connections: bool, http_connections: bool },
4957
Immediate,
50-
Gradual,
5158
}
5259

5360
#[derive(Debug, Clone)]
@@ -102,9 +109,10 @@ impl ListenerDrainContext {
102109

103110
pub fn is_timeout_exceeded(&self) -> bool {
104111
let global_timeout = match &self.strategy {
105-
DrainStrategy::Tcp { global_timeout } | DrainStrategy::Http { global_timeout, .. } => *global_timeout,
112+
DrainStrategy::Tcp { global_timeout }
113+
| DrainStrategy::Http { global_timeout, .. }
114+
| DrainStrategy::Mixed { global_timeout, .. } => *global_timeout,
106115
DrainStrategy::Immediate => Duration::from_secs(0),
107-
DrainStrategy::Gradual => Duration::from_secs(600),
108116
};
109117

110118
self.drain_start.elapsed() >= global_timeout
@@ -113,6 +121,7 @@ impl ListenerDrainContext {
113121
pub fn get_http_drain_timeout(&self) -> Option<Duration> {
114122
match &self.strategy {
115123
DrainStrategy::Http { drain_timeout, .. } => Some(*drain_timeout),
124+
DrainStrategy::Mixed { http_drain_timeout, .. } => Some(*http_drain_timeout),
116125
_ => None,
117126
}
118127
}
@@ -126,12 +135,30 @@ pub struct DrainSignalingManager {
126135
listener_drain_state: Arc<RwLock<Option<ListenerDrainState>>>,
127136
}
128137

138+
impl ListenerProtocolConfig {
139+
pub fn from_listener_analysis(
140+
has_http_connection_manager: bool,
141+
has_tcp_proxy: bool,
142+
http_drain_timeout: Option<Duration>,
143+
) -> Self {
144+
match (has_http_connection_manager, has_tcp_proxy) {
145+
(true, true) => Self::Mixed { http_drain_timeout, has_tcp: true, has_http: true },
146+
(true, false) => Self::Http { drain_timeout: http_drain_timeout },
147+
(false, true) => Self::Tcp,
148+
(false, false) => {
149+
warn!("No HTTP connection manager or TCP proxy found in listener, defaulting to TCP draining");
150+
Self::Tcp
151+
},
152+
}
153+
}
154+
}
155+
129156
impl DrainSignalingManager {
130157
pub fn new() -> Self {
131158
Self {
132159
drain_contexts: Arc::new(RwLock::new(HashMap::new())),
133160
global_drain_timeout: Duration::from_secs(600),
134-
default_http_drain_timeout: Duration::from_millis(5000),
161+
default_http_drain_timeout: Duration::from_secs(5),
135162
listener_drain_state: Arc::new(RwLock::new(None)),
136163
}
137164
}
@@ -207,17 +234,21 @@ impl DrainSignalingManager {
207234
pub async fn initiate_listener_drain(
208235
&self,
209236
listener_id: String,
210-
is_http: bool,
211-
http_drain_timeout: Option<Duration>,
237+
protocol_config: ListenerProtocolConfig,
212238
active_connections: usize,
213239
) -> Result<Arc<ListenerDrainContext>> {
214-
let strategy = if is_http {
215-
DrainStrategy::Http {
240+
let strategy = match protocol_config {
241+
ListenerProtocolConfig::Http { drain_timeout } => DrainStrategy::Http {
216242
global_timeout: self.global_drain_timeout,
217-
drain_timeout: http_drain_timeout.unwrap_or(self.default_http_drain_timeout),
218-
}
219-
} else {
220-
DrainStrategy::Tcp { global_timeout: self.global_drain_timeout }
243+
drain_timeout: drain_timeout.unwrap_or(self.default_http_drain_timeout),
244+
},
245+
ListenerProtocolConfig::Tcp => DrainStrategy::Tcp { global_timeout: self.global_drain_timeout },
246+
ListenerProtocolConfig::Mixed { http_drain_timeout, has_tcp, has_http } => DrainStrategy::Mixed {
247+
global_timeout: self.global_drain_timeout,
248+
http_drain_timeout: http_drain_timeout.unwrap_or(self.default_http_drain_timeout),
249+
tcp_connections: has_tcp,
250+
http_connections: has_http,
251+
},
221252
};
222253

223254
let context = Arc::new(ListenerDrainContext::new(listener_id.clone(), strategy.clone(), active_connections));
@@ -326,6 +357,33 @@ impl DrainSignalingManager {
326357
Err(Error::new("Timeout waiting for listener drain completion"))
327358
}
328359
}
360+
361+
pub async fn initiate_listener_drain_from_filter_analysis(
362+
&self,
363+
listener_id: String,
364+
filter_chains: &[FilterChain],
365+
active_connections: usize,
366+
) -> Result<Arc<ListenerDrainContext>> {
367+
let mut has_http = false;
368+
let mut has_tcp = false;
369+
let mut http_drain_timeout: Option<Duration> = None;
370+
371+
for filter_chain in filter_chains {
372+
match &filter_chain.terminal_filter {
373+
MainFilter::Http(http_config) => {
374+
has_http = true;
375+
http_drain_timeout = http_config.drain_timeout;
376+
},
377+
MainFilter::Tcp(_) => {
378+
has_tcp = true;
379+
},
380+
}
381+
}
382+
383+
let protocol_config = ListenerProtocolConfig::from_listener_analysis(has_http, has_tcp, http_drain_timeout);
384+
385+
self.initiate_listener_drain(listener_id, protocol_config, active_connections).await
386+
}
329387
}
330388

331389
impl Clone for DrainSignalingManager {
@@ -403,7 +461,8 @@ mod tests {
403461
let manager = DrainSignalingManager::new();
404462
assert!(!manager.has_draining_listeners().await);
405463

406-
let context = manager.initiate_listener_drain("test".to_string(), false, None, 1).await.unwrap();
464+
let context =
465+
manager.initiate_listener_drain("test".to_string(), ListenerProtocolConfig::Tcp, 1).await.unwrap();
407466

408467
assert!(manager.has_draining_listeners().await);
409468
assert_eq!(manager.get_draining_listeners().await, vec!["test"]);
@@ -417,7 +476,14 @@ mod tests {
417476
async fn test_timeout_behavior() {
418477
let manager = DrainSignalingManager::with_timeouts(Duration::from_millis(50), Duration::from_millis(25));
419478

420-
let context = manager.initiate_listener_drain("timeout-test".to_string(), true, None, 5).await.unwrap();
479+
let context = manager
480+
.initiate_listener_drain(
481+
"timeout-test".to_string(),
482+
ListenerProtocolConfig::Http { drain_timeout: None },
483+
5,
484+
)
485+
.await
486+
.unwrap();
421487

422488
sleep(Duration::from_millis(10)).await;
423489
sleep(Duration::from_millis(60)).await;
@@ -435,4 +501,47 @@ mod tests {
435501
"Expected manager to no longer track the listener after timeout"
436502
);
437503
}
504+
505+
#[tokio::test]
506+
async fn test_mixed_protocol_drain_context() {
507+
let strategy = DrainStrategy::Mixed {
508+
global_timeout: Duration::from_secs(600),
509+
http_drain_timeout: Duration::from_secs(5),
510+
tcp_connections: true,
511+
http_connections: true,
512+
};
513+
let context = ListenerDrainContext::new("test-mixed".to_string(), strategy, 10);
514+
515+
assert_eq!(context.get_active_connections().await, 10);
516+
assert!(!context.is_completed().await);
517+
assert_eq!(context.get_http_drain_timeout(), Some(Duration::from_secs(5)));
518+
assert!(!context.is_timeout_exceeded());
519+
}
520+
521+
#[tokio::test]
522+
async fn test_listener_protocol_config_analysis() {
523+
let http_config = ListenerProtocolConfig::from_listener_analysis(true, false, Some(Duration::from_secs(10)));
524+
match http_config {
525+
ListenerProtocolConfig::Http { drain_timeout } => {
526+
assert_eq!(drain_timeout, Some(Duration::from_secs(10)));
527+
},
528+
_ => panic!("Expected HTTP config"),
529+
}
530+
531+
let tcp_config = ListenerProtocolConfig::from_listener_analysis(false, true, None);
532+
match tcp_config {
533+
ListenerProtocolConfig::Tcp => {},
534+
_ => panic!("Expected TCP config"),
535+
}
536+
537+
let mixed_config = ListenerProtocolConfig::from_listener_analysis(true, true, Some(Duration::from_secs(3)));
538+
match mixed_config {
539+
ListenerProtocolConfig::Mixed { http_drain_timeout, has_tcp, has_http } => {
540+
assert_eq!(http_drain_timeout, Some(Duration::from_secs(3)));
541+
assert!(has_tcp);
542+
assert!(has_http);
543+
},
544+
_ => panic!("Expected Mixed config"),
545+
}
546+
}
438547
}

orion-lib/src/listeners/filterchain.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,10 @@ impl FilterchainType {
235235
.serve_connection_with_upgrades(
236236
stream,
237237
hyper::service::service_fn(|req: Request<hyper::body::Incoming>| {
238-
let handler_req =
239-
ExtendedRequest { request: req, downstream_metadata: downstream_metadata.clone() };
238+
let handler_req = ExtendedRequest {
239+
request: req,
240+
downstream_metadata: Arc::new(downstream_metadata.connection.clone()),
241+
};
240242
req_handler.call(handler_req).map_err(orion_error::Error::into_inner)
241243
}),
242244
)

orion-lib/src/listeners/http_connection_manager.rs

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,15 @@ use parking_lot::Mutex;
7171
use route::MatchedRequest;
7272
use scopeguard::defer;
7373
use std::collections::HashSet;
74-
use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
74+
use std::sync::atomic::AtomicUsize;
7575
use std::thread::ThreadId;
7676
use std::time::Instant;
77-
use std::{fmt, future::Future, result::Result as StdResult, sync::Arc};
77+
use std::{
78+
fmt,
79+
future::Future,
80+
result::Result as StdResult,
81+
sync::{Arc, LazyLock},
82+
};
7883
use tokio::sync::mpsc::Permit;
7984
use tokio::sync::watch;
8085
use upgrades as upgrade_utils;
@@ -91,8 +96,11 @@ use crate::{
9196
use crate::{
9297
body::body_with_timeout::BodyWithTimeout,
9398
listeners::{
94-
access_log::AccessLogContext, drain_signaling::DrainSignalingManager,
95-
filter_state::DownstreamConnectionMetadata, listeners_manager::ConnectionManager, rate_limiter::LocalRateLimit,
99+
access_log::AccessLogContext,
100+
drain_signaling::DrainSignalingManager,
101+
filter_state::{DownstreamConnectionMetadata, DownstreamMetadata},
102+
listeners_manager::ConnectionManager,
103+
rate_limiter::LocalRateLimit,
96104
synthetic_http_response::SyntheticHttpResponse,
97105
},
98106
utils::http::{request_head_size, response_head_size},
@@ -102,6 +110,9 @@ use orion_tracing::http_tracer::{HttpTracer, SpanKind, SpanName};
102110
use orion_tracing::request_id::{RequestId, RequestIdManager};
103111
use orion_tracing::trace_context::TraceContext;
104112

113+
static EMPTY_HASHMAP: LazyLock<Arc<HashMap<RouteMatch, Vec<Arc<HttpFilter>>>>> =
114+
LazyLock::new(|| Arc::new(HashMap::new()));
115+
105116
#[derive(Debug, Clone)]
106117
pub struct HttpConnectionManagerBuilder {
107118
listener_name: Option<&'static str>,
@@ -405,7 +416,7 @@ impl HttpConnectionManager {
405416
}
406417

407418
pub fn remove_route(&self) {
408-
self.http_filters_per_route.swap(Arc::new(HashMap::new()));
419+
self.http_filters_per_route.swap(EMPTY_HASHMAP.clone());
409420
let _ = self.router_sender.send_replace(None);
410421
}
411422

@@ -418,10 +429,10 @@ impl HttpConnectionManager {
418429
}
419430

420431
pub async fn start_draining(&self, drain_state: crate::listeners::drain_signaling::ListenerDrainState) {
421-
if let Some(drain_timeout) = self.drain_timeout {
422-
let listener_id = format!("{}-{}", self.listener_name, self.filter_chain_match_hash);
423-
let _ = self.drain_signaling.initiate_listener_drain(listener_id, true, Some(drain_timeout), 0).await;
424-
}
432+
let listener_id = format!("{}-{}", self.listener_name, self.filter_chain_match_hash);
433+
let protocol_config =
434+
crate::listeners::drain_signaling::ListenerProtocolConfig::Http { drain_timeout: self.drain_timeout };
435+
let _ = self.drain_signaling.initiate_listener_drain(listener_id, protocol_config, 0).await;
425436

426437
self.drain_signaling.start_listener_draining(drain_state).await;
427438
}
@@ -628,7 +639,7 @@ pub(crate) struct HttpRequestHandler {
628639

629640
pub struct ExtendedRequest<B> {
630641
pub request: Request<B>,
631-
pub downstream_metadata: Arc<DownstreamMetadata>,
642+
pub downstream_metadata: Arc<DownstreamConnectionMetadata>,
632643
}
633644

634645
#[derive(Debug)]
@@ -1174,7 +1185,7 @@ impl Service<ExtendedRequest<Incoming>> for HttpRequestHandler {
11741185

11751186
//
11761187
// 1. evaluate InitHttpContext, if logging is enabled
1177-
eval_http_init_context(&request, &trans_handler, downstream_metadata.server_name.as_deref());
1188+
eval_http_init_context(&request, &trans_handler, None);
11781189

11791190
//
11801191
// 2. create the MetricsBody, which will track the size of the request body
@@ -1329,9 +1340,12 @@ impl Service<ExtendedRequest<Incoming>> for HttpRequestHandler {
13291340
return Ok(response);
13301341
};
13311342

1343+
let downstream_metadata_with_server_name =
1344+
Arc::new(DownstreamMetadata::new(downstream_metadata.as_ref().clone(), None::<String>));
1345+
13321346
let response = trans_handler
13331347
.clone()
1334-
.handle_transaction(route_conf, manager, permit, request, downstream_metadata)
1348+
.handle_transaction(route_conf, manager, permit, request, downstream_metadata_with_server_name)
13351349
.await;
13361350

13371351
trans_handler.trace_status_code(response, listener_name)

orion-lib/src/listeners/lds_update.rs

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -127,22 +127,11 @@ impl LdsManager {
127127

128128
let mut listeners_guard = listeners.write().await;
129129
if let Some(versions) = listeners_guard.get_vec_mut(&name) {
130-
let mut to_remove = Vec::new();
131-
for (i, listener_info) in versions.iter().enumerate() {
132-
if listener_info.is_draining() {
133-
to_remove.push(i);
134-
}
135-
}
136-
137-
for &index in to_remove.iter().rev() {
138-
if let Some(listener_info) = versions.get_mut(index) {
139-
listener_info.handle.abort();
140-
info!("LDS: Draining version of listener '{}' forcibly closed after timeout", name);
141-
}
142-
}
143-
for &index in to_remove.iter().rev() {
144-
versions.remove(index);
145-
}
130+
versions.iter_mut().filter(|listener_info| listener_info.is_draining()).for_each(|listener_info| {
131+
listener_info.handle.abort();
132+
info!("LDS: Draining version of listener '{}' forcibly closed after timeout", name);
133+
});
134+
versions.retain(|listener_info| !listener_info.is_draining());
146135
if versions.is_empty() {
147136
listeners_guard.remove(&name);
148137
}

orion-lib/src/listeners/listener.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ use std::{
4444
fmt::Debug,
4545
net::SocketAddr,
4646
sync::{
47-
atomic::{AtomicBool, Ordering},
47+
atomic::{AtomicBool, AtomicU64, Ordering},
4848
Arc,
4949
},
5050
};
@@ -53,7 +53,8 @@ use tokio::{
5353
sync::broadcast::{self},
5454
};
5555
use tracing::{debug, info, warn};
56-
use uuid;
56+
57+
static CONNECTION_COUNTER: AtomicU64 = AtomicU64::new(1);
5758

5859
#[derive(Debug, Clone)]
5960
struct PartialListener {
@@ -375,7 +376,8 @@ impl Listener {
375376
drain_handler: Option<Arc<DefaultConnectionHandler>>,
376377
) -> Result<()> {
377378
let shard_id = std::thread::current().id();
378-
let connection_id = format!("{}:{}:{}", local_address, peer_addr, uuid::Uuid::new_v4());
379+
let connection_id =
380+
format!("{}:{}:{}", local_address, peer_addr, CONNECTION_COUNTER.fetch_add(1, Ordering::Relaxed));
379381

380382
debug!("New connection {} established on listener {}", connection_id, listener_name);
381383

0 commit comments

Comments
 (0)