diff --git a/grpc/src/client/channel.rs b/grpc/src/client/channel.rs index b817fc2de..599e7b9fb 100644 --- a/grpc/src/client/channel.rs +++ b/grpc/src/client/channel.rs @@ -444,6 +444,7 @@ impl load_balancing::ChannelController for InternalChannelController { } // A channel that is not idle (connecting, ready, or erroring). +#[derive(Debug)] pub(super) struct GracefulSwitchBalancer { pub(super) policy: Mutex>>, policy_builder: Mutex>>, diff --git a/grpc/src/client/load_balancing/child_manager.rs b/grpc/src/client/load_balancing/child_manager.rs index 1c1cba8df..d3e8ca5b2 100644 --- a/grpc/src/client/load_balancing/child_manager.rs +++ b/grpc/src/client/load_balancing/child_manager.rs @@ -30,6 +30,7 @@ // production. Also, support for the work scheduler is missing. use std::collections::HashSet; +use std::fmt::Debug; use std::sync::Mutex; use std::{collections::HashMap, error::Error, hash::Hash, mem, sync::Arc}; @@ -44,64 +45,73 @@ use crate::rt::Runtime; use super::{Subchannel, SubchannelState}; // An LbPolicy implementation that manages multiple children. -pub struct ChildManager { - subchannel_child_map: HashMap, +#[derive(Debug)] +pub(crate) struct ChildManager> { + subchannel_to_child_idx: HashMap, children: Vec>, - update_sharder: Box>, + update_sharder: S, pending_work: Arc>>, runtime: Arc, + updated: bool, // Set when any child updates its picker; cleared when accessed. } -struct Child { - identifier: T, +#[non_exhaustive] +#[derive(Debug)] +pub(crate) struct Child { + pub identifier: T, + pub builder: Arc, + pub state: LbState, policy: Box, - state: LbState, work_scheduler: Arc, } /// A collection of data sent to a child of the ChildManager. -pub struct ChildUpdate { +pub(crate) struct ChildUpdate { /// The identifier the ChildManager should use for this child. pub child_identifier: T, /// The builder the ChildManager should use to create this child if it does - /// not exist. + /// not exist. The child_policy_builder's name is effectively a part of the + /// child_identifier. If two identifiers are identical but have different + /// builder names, they are treated as different children. pub child_policy_builder: Arc, - /// The relevant ResolverUpdate to send to this child. - pub child_update: ResolverUpdate, + /// The relevant ResolverUpdate and LbConfig to send to this child. If + /// None, then resolver_update will not be called on the child. Should + /// generally be Some for any new children, otherwise they will not be + /// called. + pub child_update: Option<(ResolverUpdate, Option)>, } -pub trait ResolverUpdateSharder: Send { - /// Performs the operation of sharding an aggregate ResolverUpdate into one - /// or more ChildUpdates. Called automatically by the ChildManager when its - /// resolver_update method is called. The key in the returned map is the - /// identifier the ChildManager should use for this child. +pub(crate) trait ResolverUpdateSharder: Send { + /// Performs the operation of sharding an aggregate ResolverUpdate/LbConfig + /// into one or more ChildUpdates. Called automatically by the ChildManager + /// when its resolver_update method is called. fn shard_update( - &self, - resolver_update: ResolverUpdate, - ) -> Result>>, Box>; + &mut self, + update: ResolverUpdate, + config: Option<&LbConfig>, + ) -> Result>, Box>; } -impl ChildManager { +impl ChildManager +where + S: ResolverUpdateSharder, +{ /// Creates a new ChildManager LB policy. shard_update is called whenever a /// resolver_update operation occurs. 
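+    ///
+    /// A minimal usage sketch (`MySharder` stands in for any
+    /// ResolverUpdateSharder implementation):
+    ///
+    /// ```ignore
+    /// let manager = ChildManager::new(MySharder::new(), default_runtime());
+    /// ```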
-    pub fn new(
-        update_sharder: Box>,
-        runtime: Arc,
-    ) -> Self {
+    pub fn new(update_sharder: S, runtime: Arc) -> Self {
         Self {
             update_sharder,
-            subchannel_child_map: Default::default(),
+            subchannel_to_child_idx: Default::default(),
             children: Default::default(),
             pending_work: Default::default(),
             runtime,
+            updated: false,
         }
     }

     /// Returns data for all current children.
-    pub fn child_states(&mut self) -> impl Iterator {
-        self.children
-            .iter()
-            .map(|child| (&child.identifier, &child.state))
+    pub fn children(&self) -> impl Iterator> {
+        self.children.iter()
     }

     /// Aggregates states from child policies.
@@ -110,7 +120,7 @@ impl ChildManager {
     /// Otherwise, if any child is CONNECTING, then report CONNECTING.
     /// Otherwise, if any child is IDLE, then report IDLE.
     /// Report TRANSIENT FAILURE if no conditions above apply.
-    pub fn aggregate_states(&mut self) -> ConnectivityState {
+    pub fn aggregate_states(&self) -> ConnectivityState {
         let mut is_connecting = false;
         let mut is_idle = false;

@@ -153,16 +163,33 @@ impl ChildManager {
     ) {
         // Add all created subchannels into the subchannel_child_map.
         for csc in channel_controller.created_subchannels {
-            self.subchannel_child_map.insert(csc.into(), child_idx);
+            self.subchannel_to_child_idx.insert(csc.into(), child_idx);
         }
         // Update the tracked state if the child produced an update.
         if let Some(state) = channel_controller.picker_update {
             self.children[child_idx].state = state;
+            self.updated = true;
         };
     }
+
+    /// Returns a mutable reference to the update sharder so callers can
+    /// operate on any state it retains.
+    pub fn update_sharder(&mut self) -> &mut S {
+        &mut self.update_sharder
+    }
+
+    /// Returns true if any child has updated its picker since the last call to
+    /// child_updated.
+    pub fn child_updated(&mut self) -> bool {
+        mem::take(&mut self.updated)
+    }
 }

-impl LbPolicy for ChildManager {
+impl LbPolicy for ChildManager
+where
+    T: PartialEq + Hash + Eq + Send + Sync + 'static,
+    S: ResolverUpdateSharder,
+{
     fn resolver_update(
         &mut self,
         resolver_update: ResolverUpdate,
@@ -170,7 +197,7 @@ impl LbPolicy for ChildManager
         channel_controller: &mut dyn ChannelController,
     ) -> Result<(), Box> {
         // First determine if the incoming update is valid.
-        let child_updates = self.update_sharder.shard_update(resolver_update)?;
+        let child_updates = self.update_sharder.shard_update(resolver_update, config)?;

         // Hold the lock to prevent new work requests during this operation and
         // rewrite the indices.
@@ -184,27 +211,38 @@ impl LbPolicy for ChildManager
         let old_children = mem::take(&mut self.children);

         // Replace the subchannel map with an empty map.
-        let old_subchannel_child_map = mem::take(&mut self.subchannel_child_map);
+        let old_subchannel_child_map = mem::take(&mut self.subchannel_to_child_idx);

-        // Reverse the old subchannel map.
-        let mut old_child_subchannels_map: HashMap> = HashMap::new();
+        // Reverse the old subchannel map into a vector indexed by old child index.
+        let mut old_child_subchannels: Vec> = Vec::new();
+        old_child_subchannels.resize_with(old_children.len(), Vec::new);

-        for (subchannel, child_idx) in old_subchannel_child_map {
-            old_child_subchannels_map
-                .entry(child_idx)
-                .or_default()
-                .push(subchannel);
+        for (subchannel, old_idx) in old_subchannel_child_map {
+            old_child_subchannels[old_idx].push(subchannel);
         }

         // Build a map of the old children from their IDs for efficient lookups.
-        let old_children = old_children
+        // The map values reuse the Child struct, with `identifier` repurposed
+        // to hold the child's index in the old self.children vector.
+        let mut old_children: HashMap<(&'static str, T), _> = old_children
             .into_iter()
             .enumerate()
-            .map(|(old_idx, e)| (e.identifier, (e.policy, e.state, old_idx, e.work_scheduler)));
-        let mut old_children: HashMap = old_children.collect();
+            .map(|(old_idx, e)| {
+                (
+                    (e.builder.name(), e.identifier),
+                    Child {
+                        identifier: old_idx,
+                        policy: e.policy,
+                        builder: e.builder,
+                        state: e.state,
+                        work_scheduler: e.work_scheduler,
+                    },
+                )
+            })
+            .collect();

         // Split the child updates into the IDs and builders, and the
-        // ResolverUpdates.
+        // ResolverUpdates/LbConfigs.
         let (ids_builders, updates): (Vec<_>, Vec<_>) = child_updates
             .map(|e| ((e.child_identifier, e.child_policy_builder), e.child_update))
             .unzip();
@@ -213,24 +251,22 @@ impl LbPolicy for ChildManager
         // update, and create new children. Add entries back into the
         // subchannel map.
         for (new_idx, (identifier, builder)) in ids_builders.into_iter().enumerate() {
-            if let Some((policy, state, old_idx, work_scheduler)) = old_children.remove(&identifier)
-            {
-                for subchannel in old_child_subchannels_map
-                    .remove(&old_idx)
-                    .into_iter()
-                    .flatten()
-                {
-                    self.subchannel_child_map.insert(subchannel, new_idx);
+            let k = (builder.name(), identifier);
+            if let Some(old_child) = old_children.remove(&k) {
+                let old_idx = old_child.identifier;
+                for subchannel in mem::take(&mut old_child_subchannels[old_idx]) {
+                    self.subchannel_to_child_idx.insert(subchannel, new_idx);
                 }
                 if old_pending_work.contains(&old_idx) {
                     pending_work.insert(new_idx);
                 }
-                *work_scheduler.idx.lock().unwrap() = Some(new_idx);
+                *old_child.work_scheduler.idx.lock().unwrap() = Some(new_idx);
                 self.children.push(Child {
-                    identifier,
-                    state,
-                    policy,
-                    work_scheduler,
+                    builder,
+                    identifier: k.1,
+                    state: old_child.state,
+                    policy: old_child.policy,
+                    work_scheduler: old_child.work_scheduler,
                 });
             } else {
                 let work_scheduler = Arc::new(ChildWorkScheduler {
@@ -241,10 +277,10 @@ impl LbPolicy for ChildManager
                     work_scheduler: work_scheduler.clone(),
                     runtime: self.runtime.clone(),
                 });
-                let state = LbState::initial();
                 self.children.push(Child {
-                    identifier,
-                    state,
+                    builder,
+                    identifier: k.1,
+                    state: LbState::initial(),
                     policy,
                     work_scheduler,
                 });
@@ -252,8 +288,8 @@ impl LbPolicy for ChildManager
         }

         // Invalidate all deleted children's work_schedulers.
-        for (_, (_, _, _, work_scheduler)) in old_children {
-            *work_scheduler.idx.lock().unwrap() = None;
+        for (_, old_child) in old_children {
+            *old_child.work_scheduler.idx.lock().unwrap() = None;
         }

         // Release the pending_work mutex before calling into the children to
@@ -267,15 +303,22 @@ impl LbPolicy for ChildManager
         for child_idx in 0..self.children.len() {
             let child = &mut self.children[child_idx];
             let child_update = updates.next().unwrap();
+            let Some((resolver_update, config)) = child_update else {
+                continue;
+            };
             let mut channel_controller = WrappedController::new(channel_controller);
-            let _ = child
-                .policy
-                .resolver_update(child_update, config, &mut channel_controller);
+            let _ = child.policy.resolver_update(
+                resolver_update,
+                config.as_ref(),
+                &mut channel_controller,
+            );
             self.resolve_child_controller(channel_controller, child_idx);
         }
         Ok(())
     }

+    // Forwards the subchannel_update to the child that created the subchannel
+    // being updated.
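+    // The map lookup below can unwrap because the channel only delivers
+    // subchannel updates for subchannels that this policy itself created.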
fn subchannel_update( &mut self, subchannel: Arc, @@ -284,7 +327,7 @@ impl LbPolicy for ChildManager ) { // Determine which child created this subchannel. let child_idx = *self - .subchannel_child_map + .subchannel_to_child_idx .get(&WeakSubchannel::new(&subchannel)) .unwrap(); let policy = &mut self.children[child_idx].policy; @@ -295,6 +338,7 @@ impl LbPolicy for ChildManager self.resolve_child_controller(channel_controller, child_idx); } + // Calls work on any children that scheduled work via our work scheduler. fn work(&mut self, channel_controller: &mut dyn ChannelController) { let child_idxes = mem::take(&mut *self.pending_work.lock().unwrap()); for child_idx in child_idxes { @@ -306,8 +350,14 @@ impl LbPolicy for ChildManager } } - fn exit_idle(&mut self, _channel_controller: &mut dyn ChannelController) { - todo!("implement exit_idle") + // Simply calls exit_idle on all children. + fn exit_idle(&mut self, channel_controller: &mut dyn ChannelController) { + for child_idx in 0..self.children.len() { + let child = &mut self.children[child_idx]; + let mut channel_controller = WrappedController::new(channel_controller); + child.policy.exit_idle(&mut channel_controller); + self.resolve_child_controller(channel_controller, child_idx); + } } } @@ -343,6 +393,7 @@ impl ChannelController for WrappedController<'_> { } } +#[derive(Debug)] struct ChildWorkScheduler { pending_work: Arc>>, // Must be taken first for correctness idx: Mutex>, // None if the child is deleted. @@ -363,13 +414,14 @@ mod test { ChildManager, ChildUpdate, ResolverUpdateSharder, }; use crate::client::load_balancing::test_utils::{ - self, StubPolicyFuncs, TestChannelController, TestEvent, + self, StubPolicyData, StubPolicyFuncs, TestChannelController, TestEvent, }; use crate::client::load_balancing::{ ChannelController, LbPolicy, LbPolicyBuilder, LbState, QueuingPicker, Subchannel, SubchannelState, GLOBAL_LB_REGISTRY, }; use crate::client::name_resolution::{Address, Endpoint, ResolverUpdate}; + use crate::client::service_config::LbConfig; use crate::client::ConnectivityState; use crate::rt::default_runtime; use std::error::Error; @@ -380,31 +432,36 @@ mod test { // TODO: This needs to be moved to a common place that can be shared between // round_robin and this test. This EndpointSharder maps endpoints to // children policies. 
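+    // Each endpoint becomes its own child: the endpoint itself serves as the
+    // child's identifier, and the child receives a single-endpoint
+    // ResolverUpdate.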
+ #[derive(Debug)] struct EndpointSharder { builder: Arc, } impl ResolverUpdateSharder for EndpointSharder { fn shard_update( - &self, + &mut self, resolver_update: ResolverUpdate, - ) -> Result>>, Box> + config: Option<&LbConfig>, + ) -> Result>, Box> { let mut sharded_endpoints = Vec::new(); - for endpoint in resolver_update.endpoints.unwrap().iter() { + for endpoint in resolver_update.endpoints.unwrap().into_iter() { let child_update = ChildUpdate { child_identifier: endpoint.clone(), child_policy_builder: self.builder.clone(), - child_update: ResolverUpdate { - attributes: resolver_update.attributes.clone(), - endpoints: Ok(vec![endpoint.clone()]), - service_config: resolver_update.service_config.clone(), - resolution_note: resolver_update.resolution_note.clone(), - }, + child_update: Some(( + ResolverUpdate { + attributes: resolver_update.attributes.clone(), + endpoints: Ok(vec![endpoint]), + service_config: resolver_update.service_config.clone(), + resolution_note: resolver_update.resolution_note.clone(), + }, + config.cloned(), + )), }; sharded_endpoints.push(child_update); } - Ok(Box::new(sharded_endpoints.into_iter())) + Ok(sharded_endpoints.into_iter()) } } @@ -430,16 +487,16 @@ mod test { test_name: &'static str, ) -> ( mpsc::UnboundedReceiver, - Box>, + ChildManager, Box, ) { test_utils::reg_stub_policy(test_name, funcs); let (tx_events, rx_events) = mpsc::unbounded_channel::(); let tcc = Box::new(TestChannelController { tx_events }); let builder: Arc = GLOBAL_LB_REGISTRY.get_policy(test_name).unwrap(); - let endpoint_sharder = EndpointSharder { builder: builder }; - let child_manager = ChildManager::new(Box::new(endpoint_sharder), default_runtime()); - (rx_events, Box::new(child_manager), tcc) + let endpoint_sharder = EndpointSharder { builder }; + let child_manager = ChildManager::new(endpoint_sharder, default_runtime()); + (rx_events, child_manager, tcc) } fn create_n_endpoints_with_k_addresses(n: usize, k: usize) -> Vec { @@ -474,7 +531,7 @@ mod test { } fn move_subchannel_to_state( - lb_policy: &mut dyn LbPolicy, + lb_policy: &mut impl LbPolicy, subchannel: Arc, tcc: &mut dyn ChannelController, state: ConnectivityState, @@ -510,25 +567,29 @@ mod test { // Defines the functions resolver_update and subchannel_update to test // aggregate_states. fn create_verifying_funcs_for_aggregate_tests() -> StubPolicyFuncs { + let data = StubPolicyData::new(); StubPolicyFuncs { // Closure for resolver_update. resolver_update should only receive // one endpoint and create one subchannel for the endpoint it // receives. - resolver_update: Some(move |update: ResolverUpdate, _, controller| { - assert_eq!(update.endpoints.iter().len(), 1); - let endpoint = update.endpoints.unwrap().pop().unwrap(); - let subchannel = controller.new_subchannel(&endpoint.addresses[0]); - Ok(()) - }), + resolver_update: Some(Arc::new( + move |data, update: ResolverUpdate, _, controller| { + assert_eq!(update.endpoints.iter().len(), 1); + let endpoint = update.endpoints.unwrap().pop().unwrap(); + let subchannel = controller.new_subchannel(&endpoint.addresses[0]); + Ok(()) + }, + )), // Closure for subchannel_update. Sends a picker of the same state // that was passed to it. 
- subchannel_update: Some(move |updated_subchannel, state, controller| { - controller.update_picker(LbState { - connectivity_state: state.connectivity_state, - picker: Arc::new(QueuingPicker {}), - }); - }), - ..Default::default() + subchannel_update: Some(Arc::new( + move |data, updated_subchannel, state, controller| { + controller.update_picker(LbState { + connectivity_state: state.connectivity_state, + picker: Arc::new(QueuingPicker {}), + }); + }, + )), } } @@ -542,7 +603,7 @@ mod test { "stub-childmanager_aggregate_state_is_ready_if_any_child_is_ready", ); let endpoints = create_n_endpoints_with_k_addresses(4, 1); - send_resolver_update_to_policy(child_manager.as_mut(), endpoints.clone(), tcc.as_mut()); + send_resolver_update_to_policy(&mut child_manager, endpoints.clone(), tcc.as_mut()); let mut subchannels = vec![]; for endpoint in endpoints { subchannels.push( @@ -554,25 +615,25 @@ mod test { let mut subchannels = subchannels.into_iter(); move_subchannel_to_state( - child_manager.as_mut(), + &mut child_manager, subchannels.next().unwrap(), tcc.as_mut(), ConnectivityState::TransientFailure, ); move_subchannel_to_state( - child_manager.as_mut(), + &mut child_manager, subchannels.next().unwrap(), tcc.as_mut(), ConnectivityState::Idle, ); move_subchannel_to_state( - child_manager.as_mut(), + &mut child_manager, subchannels.next().unwrap(), tcc.as_mut(), ConnectivityState::Connecting, ); move_subchannel_to_state( - child_manager.as_mut(), + &mut child_manager, subchannels.next().unwrap(), tcc.as_mut(), ConnectivityState::Ready, @@ -590,7 +651,7 @@ mod test { "stub-childmanager_aggregate_state_is_connecting_if_no_child_is_ready", ); let endpoints = create_n_endpoints_with_k_addresses(3, 1); - send_resolver_update_to_policy(child_manager.as_mut(), endpoints.clone(), tcc.as_mut()); + send_resolver_update_to_policy(&mut child_manager, endpoints.clone(), tcc.as_mut()); let mut subchannels = vec![]; for endpoint in endpoints { subchannels.push( @@ -601,19 +662,19 @@ mod test { } let mut subchannels = subchannels.into_iter(); move_subchannel_to_state( - child_manager.as_mut(), + &mut child_manager, subchannels.next().unwrap(), tcc.as_mut(), ConnectivityState::TransientFailure, ); move_subchannel_to_state( - child_manager.as_mut(), + &mut child_manager, subchannels.next().unwrap(), tcc.as_mut(), ConnectivityState::Idle, ); move_subchannel_to_state( - child_manager.as_mut(), + &mut child_manager, subchannels.next().unwrap(), tcc.as_mut(), ConnectivityState::Connecting, @@ -636,7 +697,7 @@ mod test { ); let endpoints = create_n_endpoints_with_k_addresses(2, 1); - send_resolver_update_to_policy(child_manager.as_mut(), endpoints.clone(), tcc.as_mut()); + send_resolver_update_to_policy(&mut child_manager, endpoints.clone(), tcc.as_mut()); let mut subchannels = vec![]; for endpoint in endpoints { subchannels.push( @@ -647,13 +708,13 @@ mod test { } let mut subchannels = subchannels.into_iter(); move_subchannel_to_state( - child_manager.as_mut(), + &mut child_manager, subchannels.next().unwrap(), tcc.as_mut(), ConnectivityState::TransientFailure, ); move_subchannel_to_state( - child_manager.as_mut(), + &mut child_manager, subchannels.next().unwrap(), tcc.as_mut(), ConnectivityState::Idle, @@ -671,7 +732,7 @@ mod test { "stub-childmanager_aggregate_state_is_transient_failure_if_all_children_are", ); let endpoints = create_n_endpoints_with_k_addresses(2, 1); - send_resolver_update_to_policy(child_manager.as_mut(), endpoints.clone(), tcc.as_mut()); + send_resolver_update_to_policy(&mut child_manager, 
endpoints.clone(), tcc.as_mut()); let mut subchannels = vec![]; for endpoint in endpoints { subchannels.push( @@ -682,13 +743,13 @@ mod test { } let mut subchannels = subchannels.into_iter(); move_subchannel_to_state( - child_manager.as_mut(), + &mut child_manager, subchannels.next().unwrap(), tcc.as_mut(), ConnectivityState::TransientFailure, ); move_subchannel_to_state( - child_manager.as_mut(), + &mut child_manager, subchannels.next().unwrap(), tcc.as_mut(), ConnectivityState::TransientFailure, diff --git a/grpc/src/client/load_balancing/graceful_switch.rs b/grpc/src/client/load_balancing/graceful_switch.rs new file mode 100644 index 000000000..0f4430df9 --- /dev/null +++ b/grpc/src/client/load_balancing/graceful_switch.rs @@ -0,0 +1,1008 @@ +use crate::client::load_balancing::child_manager::{ + self, ChildManager, ChildUpdate, ResolverUpdateSharder, +}; +use crate::client::load_balancing::{ + ChannelController, LbConfig, LbPolicy, LbPolicyBuilder, LbState, ParsedJsonLbConfig, + Subchannel, SubchannelState, GLOBAL_LB_REGISTRY, +}; +use crate::client::name_resolution::ResolverUpdate; +use crate::client::ConnectivityState; +use crate::rt::Runtime; + +use std::collections::HashMap; +use std::error::Error; +use std::sync::Arc; + +enum GracefulSwitchLbConfig { + Update(Arc, Option), + Swap(Arc), +} + +#[derive(Debug)] +struct UpdateSharder { + active_child_builder: Option>, +} + +impl ResolverUpdateSharder<()> for UpdateSharder { + fn shard_update( + &mut self, + resolver_update: ResolverUpdate, + config: Option<&LbConfig>, // The config is always produced based on the state stored in the sharder. + ) -> Result>, Box> + { + let config = config.expect("graceful switch should always get an LbConfig"); + + let gsb_config: Arc = config.convert_to().expect("invalid config"); + + let child_config; + let child_builder; + match &*gsb_config { + GracefulSwitchLbConfig::Swap(child_builder) => { + // When swapping we update the active_child_builder to the one + // we are swapping to and send an empty update that only + // includes that child, which removes the other child. + self.active_child_builder = Some(child_builder.clone()); + return Ok(vec![ChildUpdate { + child_policy_builder: child_builder.clone(), + child_identifier: (), + child_update: None, + }] + .into_iter()); + } + GracefulSwitchLbConfig::Update(cb, cc) => { + child_builder = cb; + child_config = cc; + } + } + + if self.active_child_builder.is_none() { + // When there are no children yet, the current update immediately + // becomes the active child. + self.active_child_builder = Some(child_builder.clone()); + } + let active_child_builder = self.active_child_builder.as_ref().unwrap(); + + let mut resp = Vec::with_capacity(2); + + // Always include the incoming update. + resp.push(ChildUpdate { + child_policy_builder: child_builder.clone(), + child_identifier: (), + child_update: Some((resolver_update, child_config.clone())), + }); + + // Include the active child if it does not match the updated child so + // that the child manager will not delete it. + if child_builder.name() != active_child_builder.name() { + resp.push(ChildUpdate { + child_policy_builder: active_child_builder.clone(), + child_identifier: (), + child_update: None, + }); + } + + Ok(resp.into_iter()) + } +} + +impl UpdateSharder { + fn new() -> Self { + Self { + active_child_builder: None, + } + } +} + +/// A graceful switching load balancing policy. In graceful switch, there is +/// always either one or two child policies. 
When there is one policy, all
+/// operations are delegated to it. When the child policy type needs to change,
+/// graceful switch creates a "pending" child policy alongside the "active"
+/// policy. When the pending policy leaves the CONNECTING state, or when the
+/// active policy is not READY, graceful switch will promote the pending policy
+/// to active and tear down the previously active policy.
+#[derive(Debug)]
+pub(crate) struct GracefulSwitchPolicy {
+    child_manager: ChildManager<(), UpdateSharder>, // Child ID is the name of the child policy.
+    last_update: LbState, // Saves the last output LbState to determine if an update is needed.
+}
+
+impl LbPolicy for GracefulSwitchPolicy {
+    fn resolver_update(
+        &mut self,
+        update: ResolverUpdate,
+        config: Option<&LbConfig>,
+        channel_controller: &mut dyn ChannelController,
+    ) -> Result<(), Box> {
+        self.child_manager
+            .resolver_update(update, config, channel_controller)?;
+        self.update_picker(channel_controller);
+        Ok(())
+    }
+
+    fn subchannel_update(
+        &mut self,
+        subchannel: Arc,
+        state: &SubchannelState,
+        channel_controller: &mut dyn ChannelController,
+    ) {
+        self.child_manager
+            .subchannel_update(subchannel, state, channel_controller);
+        self.update_picker(channel_controller);
+    }
+
+    fn work(&mut self, channel_controller: &mut dyn ChannelController) {
+        self.child_manager.work(channel_controller);
+        self.update_picker(channel_controller);
+    }
+
+    fn exit_idle(&mut self, channel_controller: &mut dyn ChannelController) {
+        self.child_manager.exit_idle(channel_controller);
+        self.update_picker(channel_controller);
+    }
+}
+
+impl GracefulSwitchPolicy {
+    /// Creates a new Graceful Switch policy.
+    pub fn new(runtime: Arc) -> Self {
+        GracefulSwitchPolicy {
+            child_manager: ChildManager::new(UpdateSharder::new(), runtime),
+            last_update: LbState::initial(),
+        }
+    }
+
+    /// Parses a child config list and returns an LB config for the
+    /// GracefulSwitchPolicy. Config is expected to contain a JSON array of LB
+    /// policy names + configs matching the format of the "loadBalancingConfig"
+    /// field in the gRPC ServiceConfig. It returns a type that should be passed
+    /// to resolver_update in the LbConfig.config field.
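+    ///
+    /// A sketch of the expected JSON shape (the policy name here is
+    /// illustrative and must exist in the LB registry):
+    ///
+    /// ```ignore
+    /// let json = ParsedJsonLbConfig {
+    ///     value: serde_json::json!([{ "pick_first": {} }]),
+    /// };
+    /// let config = GracefulSwitchPolicy::parse_config(&json)?;
+    /// ```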
+ pub fn parse_config( + config: &ParsedJsonLbConfig, + ) -> Result> { + let cfg: Vec> = match config.convert_to() { + Ok(c) => c, + Err(e) => { + return Err(format!("failed to parse JSON config: {}", e).into()); + } + }; + for c in cfg { + if c.len() != 1 { + return Err(format!( + "Each element in array must contain exactly one policy name/config; found {:?}", + c.keys() + ) + .into()); + } + let (policy_name, policy_config) = c.into_iter().next().unwrap(); + let Some(child_builder) = GLOBAL_LB_REGISTRY.get_policy(policy_name.as_str()) else { + continue; + }; + let parsed_config = ParsedJsonLbConfig { + value: policy_config, + }; + let child_config = child_builder.parse_config(&parsed_config)?; + let gsb_config = GracefulSwitchLbConfig::Update(child_builder, child_config); + return Ok(LbConfig::new(gsb_config)); + } + Err("no supported policies found in config".into()) + } + + fn update_picker(&mut self, channel_controller: &mut dyn ChannelController) { + let Some(update) = self.maybe_swap(channel_controller) else { + return; + }; + if self.last_update.connectivity_state == update.connectivity_state + && std::ptr::addr_eq( + Arc::as_ptr(&self.last_update.picker), + Arc::as_ptr(&update.picker), + ) + { + return; + } + channel_controller.update_picker(update.clone()); + self.last_update = update; + } + + // Determines the appropriate state to output + fn maybe_swap(&mut self, channel_controller: &mut dyn ChannelController) -> Option { + if !self.child_manager.child_updated() { + return None; + } + + let active_name = self + .child_manager + .update_sharder() + .active_child_builder + .as_ref() + .unwrap() + .name(); + + let mut active_child = None; + let mut pending_child = None; + for child in self.child_manager.children() { + if child.builder.name() == active_name { + active_child = Some(child); + } else { + pending_child = Some(child); + } + } + let active_child = active_child.expect("There should always be an active child policy"); + let Some(pending_child) = pending_child else { + return Some(active_child.state.clone()); + }; + + if active_child.state.connectivity_state == ConnectivityState::Ready + && pending_child.state.connectivity_state == ConnectivityState::Connecting + { + return Some(active_child.state.clone()); + } + + let config = &LbConfig::new(GracefulSwitchLbConfig::Swap(pending_child.builder.clone())); + let state = pending_child.state.clone(); + self.child_manager + .resolver_update(ResolverUpdate::default(), Some(config), channel_controller) + .expect("resolver_update with an empty update should not fail"); + return Some(state); + } +} + +#[cfg(test)] +mod test { + use crate::client::load_balancing::graceful_switch::GracefulSwitchPolicy; + use crate::client::load_balancing::test_utils::{ + self, reg_stub_policy, StubPolicyData, StubPolicyFuncs, TestChannelController, TestEvent, + TestSubchannel, TestWorkScheduler, + }; + use crate::client::load_balancing::{ + ChannelController, LbPolicy, ParsedJsonLbConfig, PickResult, Picker, Subchannel, + SubchannelState, + }; + use crate::client::load_balancing::{LbState, Pick}; + use crate::client::name_resolution::{Address, Endpoint, ResolverUpdate}; + use crate::client::ConnectivityState; + use crate::rt::default_runtime; + use crate::service::Request; + use std::time::Duration; + use std::{panic, sync::Arc}; + use tokio::select; + use tokio::sync::mpsc::{self, UnboundedReceiver}; + use tonic::metadata::MetadataMap; + + const DEFAULT_TEST_SHORT_TIMEOUT: Duration = Duration::from_millis(10); + + struct TestSubchannelList { + 
subchannels: Vec>, + } + + impl TestSubchannelList { + fn new(addresses: &Vec
, channel_controller: &mut dyn ChannelController) -> Self {
+            let mut scl = TestSubchannelList {
+                subchannels: Vec::new(),
+            };
+            for address in addresses {
+                let sc = channel_controller.new_subchannel(address);
+                scl.subchannels.push(sc.clone());
+            }
+            scl
+        }
+
+        fn contains(&self, sc: &Arc) -> bool {
+            self.subchannels.contains(sc)
+        }
+    }
+
+    #[derive(Debug)]
+    struct TestPicker {
+        name: &'static str,
+    }
+
+    impl TestPicker {
+        fn new(name: &'static str) -> Self {
+            Self { name }
+        }
+    }
+    impl Picker for TestPicker {
+        fn pick(&self, _req: &Request) -> PickResult {
+            PickResult::Pick(Pick {
+                subchannel: Arc::new(TestSubchannel::new(
+                    Address {
+                        address: self.name.to_string().into(),
+                        ..Default::default()
+                    },
+                    mpsc::unbounded_channel().0,
+                )),
+                metadata: MetadataMap::new(),
+                on_complete: None,
+            })
+        }
+    }
+
+    struct TestState {
+        subchannel_list: TestSubchannelList,
+    }
+
+    // Defines the functions resolver_update and subchannel_update to test
+    // graceful switch.
+    fn create_funcs_for_gracefulswitch_tests(name: &'static str) -> StubPolicyFuncs {
+        StubPolicyFuncs {
+            // Closure for resolver_update. Creates a subchannel for each
+            // address in the update and records them in the policy's test
+            // data.
+            resolver_update: Some(Arc::new(
+                move |data: &mut StubPolicyData, update: ResolverUpdate, _, channel_controller| {
+                    if let Ok(ref endpoints) = update.endpoints {
+                        let addresses: Vec<_> = endpoints
+                            .iter()
+                            .flat_map(|ep| ep.addresses.clone())
+                            .collect();
+                        let scl = TestSubchannelList::new(&addresses, channel_controller);
+                        let child_state = TestState {
+                            subchannel_list: scl,
+                        };
+                        data.test_data = Some(Box::new(child_state));
+                    } else {
+                        data.test_data = None;
+                    }
+                    Ok(())
+                },
+            )),
+            // Closure for subchannel_update. Verifies that the subchannel
+            // being updated is one that this child policy created in
+            // resolver_update. It then sends a picker of the same state that
+            // was passed to it.
+            subchannel_update: Some(Arc::new(
+                move |data: &mut StubPolicyData, updated_subchannel, state, channel_controller| {
+                    // Retrieve the specific TestState from the generic test_data field.
+                    // This downcasts the `Any` trait object.
+                    let test_data = data.test_data.as_mut().unwrap();
+                    let test_state = test_data.downcast_mut::().unwrap();
+                    let scl = &mut test_state.subchannel_list;
+                    assert!(
+                        scl.contains(&updated_subchannel),
+                        "subchannel_update received an update for a subchannel it does not own."
+                    );
+                    channel_controller.update_picker(LbState {
+                        connectivity_state: state.connectivity_state,
+                        picker: Arc::new(TestPicker { name }),
+                    });
+                },
+            )),
+        }
+    }
+
+    // Sets up the test environment.
+    //
+    // Performs the following:
+    // 1. Creates a work scheduler.
+    // 2. Creates a fake channel that acts as a channel controller.
+    // 3. Creates a GracefulSwitchPolicy.
+    //
+    // Returns the following:
+    // 1. A receiver for events initiated by the LB policy (like creating a new
+    //    subchannel, sending a new picker, etc.).
+    // 2. The GracefulSwitchPolicy, to which the test sends resolver and
+    //    subchannel updates.
+    // 3. The controller to pass to the LB policy as part of the updates.
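+    //
+    // Typical usage from a test:
+    //
+    //   let (mut rx_events, mut graceful_switch, mut tcc) = setup();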
+    fn setup() -> (
+        mpsc::UnboundedReceiver,
+        Box,
+        Box,
+    ) {
+        let (tx_events, rx_events) = mpsc::unbounded_channel::();
+        let work_scheduler = Arc::new(TestWorkScheduler {
+            tx_events: tx_events.clone(),
+        });
+
+        let tcc = Box::new(TestChannelController { tx_events });
+
+        let graceful_switch = GracefulSwitchPolicy::new(default_runtime());
+        (rx_events, Box::new(graceful_switch), tcc)
+    }
+
+    fn create_endpoint_with_one_address(addr: String) -> Endpoint {
+        Endpoint {
+            addresses: vec![Address {
+                address: addr.into(),
+                ..Default::default()
+            }],
+            ..Default::default()
+        }
+    }
+
+    // Verifies that the next event on rx_events channel is NewSubchannel.
+    // Returns the subchannel created.
+    async fn verify_subchannel_creation_from_policy(
+        rx_events: &mut mpsc::UnboundedReceiver,
+    ) -> Arc {
+        match rx_events.recv().await.unwrap() {
+            TestEvent::NewSubchannel(sc) => {
+                return sc;
+            }
+            other => panic!("unexpected event {:?}", other),
+        };
+    }
+
+    // Verifies that the next event on the rx_events channel is a picker
+    // update, and that picks from that picker are routed to a subchannel whose
+    // address matches `name`.
+    async fn verify_correct_picker_from_policy(
+        rx_events: &mut mpsc::UnboundedReceiver,
+        name: &str,
+    ) {
+        let event = rx_events.recv().await.unwrap();
+        let TestEvent::UpdatePicker(update) = event else {
+            panic!("unexpected event {:?}", event);
+        };
+        let req = test_utils::new_request();
+
+        let pick = update.picker.pick(&req);
+        let PickResult::Pick(pick) = pick else {
+            panic!("unexpected pick result: {:?}", pick);
+        };
+        let received_address = &pick.subchannel.address().address.to_string();
+        let expected_address = name.to_string();
+        assert_eq!(received_address, &expected_address);
+    }
+
+    fn move_subchannel_to_state(
+        lb_policy: &mut dyn LbPolicy,
+        subchannel: Arc,
+        tcc: &mut dyn ChannelController,
+        state: ConnectivityState,
+    ) {
+        lb_policy.subchannel_update(
+            subchannel,
+            &SubchannelState {
+                connectivity_state: state,
+                ..Default::default()
+            },
+            tcc,
+        );
+    }
+
+    // Tests that the gracefulswitch policy correctly sets a child and sends
+    // updates to that child when it receives its first config.
+ #[tokio::test] + async fn gracefulswitch_successful_first_update() { + reg_stub_policy( + "stub-gracefulswitch_successful_first_update-one", + create_funcs_for_gracefulswitch_tests( + "stub-gracefulswitch_successful_first_update-one", + ), + ); + reg_stub_policy( + "stub-gracefulswitch_successful_first_update-two", + create_funcs_for_gracefulswitch_tests( + "stub-gracefulswitch_successful_first_update-two", + ), + ); + + let (mut rx_events, mut graceful_switch, mut tcc) = setup(); + let service_config = serde_json::json!([ + { "stub-gracefulswitch_successful_first_update-one": serde_json::json!({}) }, + { "stub-gracefulswitch_successful_first_update-two": serde_json::json!({}) } + ] + ); + + let parsed_config = GracefulSwitchPolicy::parse_config(&ParsedJsonLbConfig { + value: service_config, + }) + .unwrap(); + + let endpoint = create_endpoint_with_one_address("127.0.0.1:1234".to_string()); + let update = ResolverUpdate { + endpoints: Ok(vec![endpoint.clone()]), + ..Default::default() + }; + graceful_switch + .resolver_update(update.clone(), Some(&parsed_config), &mut *tcc) + .unwrap(); + + let subchannel = verify_subchannel_creation_from_policy(&mut rx_events).await; + move_subchannel_to_state( + &mut *graceful_switch, + subchannel, + tcc.as_mut(), + ConnectivityState::Ready, + ); + verify_correct_picker_from_policy( + &mut rx_events, + "stub-gracefulswitch_successful_first_update-one", + ) + .await; + } + + // Tests that the gracefulswitch policy correctly sets a pending child and + // sends subchannel updates to that child when it receives a new config. + #[tokio::test] + async fn gracefulswitch_switching_to_resolver_update() { + let (mut rx_events, mut graceful_switch, mut tcc) = setup(); + reg_stub_policy( + "stub-gracefulswitch_switching_to_resolver_update-one", + create_funcs_for_gracefulswitch_tests( + "stub-gracefulswitch_switching_to_resolver_update-one", + ), + ); + reg_stub_policy( + "stub-gracefulswitch_switching_to_resolver_update-two", + create_funcs_for_gracefulswitch_tests( + "stub-gracefulswitch_switching_to_resolver_update-two", + ), + ); + + let service_config = serde_json::json!([ + { "stub-gracefulswitch_switching_to_resolver_update-one": serde_json::json!({}) } + ] + ); + let parsed_config = GracefulSwitchPolicy::parse_config(&ParsedJsonLbConfig { + value: service_config, + }) + .unwrap(); + + let endpoint = create_endpoint_with_one_address("127.0.0.1:1234".to_string()); + let update = ResolverUpdate { + endpoints: Ok(vec![endpoint.clone()]), + ..Default::default() + }; + + graceful_switch + .resolver_update(update.clone(), Some(&parsed_config), &mut *tcc) + .unwrap(); + + // Subchannel creation and ready + let subchannel = verify_subchannel_creation_from_policy(&mut rx_events).await; + move_subchannel_to_state( + &mut *graceful_switch, + subchannel, + tcc.as_mut(), + ConnectivityState::Ready, + ); + + // Assert picker is TestPickerOne by checking subchannel address + verify_correct_picker_from_policy( + &mut rx_events, + "stub-gracefulswitch_switching_to_resolver_update-one", + ) + .await; + + // 2. 
Switch to mock_policy_two as pending + let new_service_config = serde_json::json!([ + { "stub-gracefulswitch_switching_to_resolver_update-two": serde_json::json!({}) } + ] + ); + let new_parsed_config = GracefulSwitchPolicy::parse_config(&ParsedJsonLbConfig { + value: new_service_config, + }) + .unwrap(); + graceful_switch + .resolver_update(update.clone(), Some(&new_parsed_config), &mut *tcc) + .unwrap(); + + // Simulate subchannel creation and ready for pending + let subchannel_two = verify_subchannel_creation_from_policy(&mut rx_events).await; + move_subchannel_to_state( + &mut *graceful_switch, + subchannel_two, + tcc.as_mut(), + ConnectivityState::Ready, + ); + // Assert picker is TestPickerTwo by checking subchannel address + verify_correct_picker_from_policy( + &mut rx_events, + "stub-gracefulswitch_switching_to_resolver_update-two", + ) + .await; + assert_channel_empty(&mut rx_events).await; + } + + async fn assert_channel_empty(rx_events: &mut UnboundedReceiver) { + select! { + event = rx_events.recv() => { + panic!("Received unexpected event from policy: {event:?}"); + } + _ = tokio::time::sleep(DEFAULT_TEST_SHORT_TIMEOUT) => {} + }; + } + + // Tests that the gracefulswitch policy should do nothing when it receives a + // new config of the same policy that it received before. + #[tokio::test] + async fn gracefulswitch_two_policies_same_type() { + let (mut rx_events, mut graceful_switch, mut tcc) = setup(); + reg_stub_policy( + "stub-gracefulswitch_two_policies_same_type-one", + create_funcs_for_gracefulswitch_tests("stub-gracefulswitch_two_policies_same_type-one"), + ); + let service_config = serde_json::json!( + [ + { "stub-gracefulswitch_two_policies_same_type-one": serde_json::json!({}) } + ] + ); + let parsed_config = GracefulSwitchPolicy::parse_config(&ParsedJsonLbConfig { + value: service_config, + }) + .unwrap(); + let endpoint = create_endpoint_with_one_address("127.0.0.1:1234".to_string()); + let update = ResolverUpdate { + endpoints: Ok(vec![endpoint.clone()]), + ..Default::default() + }; + graceful_switch + .resolver_update(update.clone(), Some(&parsed_config), &mut *tcc) + .unwrap(); + let subchannel = verify_subchannel_creation_from_policy(&mut rx_events).await; + move_subchannel_to_state( + &mut *graceful_switch, + subchannel, + tcc.as_mut(), + ConnectivityState::Ready, + ); + verify_correct_picker_from_policy( + &mut rx_events, + "stub-gracefulswitch_two_policies_same_type-one", + ) + .await; + + let service_config2 = serde_json::json!( + [ + { "stub-gracefulswitch_two_policies_same_type-one": serde_json::json!({}) } + ] + ); + let parsed_config2 = GracefulSwitchPolicy::parse_config(&ParsedJsonLbConfig { + value: service_config2, + }) + .unwrap(); + graceful_switch + .resolver_update(update.clone(), Some(&parsed_config2), &mut *tcc) + .unwrap(); + let subchannel = verify_subchannel_creation_from_policy(&mut rx_events).await; + assert_eq!(&*subchannel.address().address, "127.0.0.1:1234"); + assert_channel_empty(&mut rx_events).await; + } + + // Tests that the gracefulswitch policy should replace the current child + // with the pending child if the current child isn't ready. 
+ #[tokio::test] + async fn gracefulswitch_current_not_ready_pending_update() { + let (mut rx_events, mut graceful_switch, mut tcc) = setup(); + reg_stub_policy( + "stub-gracefulswitch_current_not_ready_pending_update-one", + create_funcs_for_gracefulswitch_tests( + "stub-gracefulswitch_current_not_ready_pending_update-one", + ), + ); + reg_stub_policy( + "stub-gracefulswitch_current_not_ready_pending_update-two", + create_funcs_for_gracefulswitch_tests( + "stub-gracefulswitch_current_not_ready_pending_update-two", + ), + ); + + let service_config = serde_json::json!([ + { "stub-gracefulswitch_current_not_ready_pending_update-one": serde_json::json!({}) } + ] + ); + + let parsed_config = GracefulSwitchPolicy::parse_config(&ParsedJsonLbConfig { + value: service_config, + }) + .unwrap(); + + let endpoint = create_endpoint_with_one_address("127.0.0.1:1234".to_string()); + let second_endpoint = create_endpoint_with_one_address("0.0.0.0.0".to_string()); + let update = ResolverUpdate { + endpoints: Ok(vec![endpoint.clone()]), + ..Default::default() + }; + + // Switch to first one (current) + graceful_switch + .resolver_update(update.clone(), Some(&parsed_config), &mut *tcc) + .unwrap(); + + let current_subchannels = verify_subchannel_creation_from_policy(&mut rx_events).await; + assert_channel_empty(&mut rx_events).await; + + let new_service_config = serde_json::json!([ + { "stub-gracefulswitch_current_not_ready_pending_update-two": serde_json::json!({ "shuffleAddressList": false }) }, + ] + ); + let second_update = ResolverUpdate { + endpoints: Ok(vec![second_endpoint.clone()]), + ..Default::default() + }; + let new_parsed_config = GracefulSwitchPolicy::parse_config(&ParsedJsonLbConfig { + value: new_service_config, + }) + .unwrap(); + graceful_switch + .resolver_update(second_update.clone(), Some(&new_parsed_config), &mut *tcc) + .unwrap(); + + let second_subchannel = verify_subchannel_creation_from_policy(&mut rx_events).await; + assert_channel_empty(&mut rx_events).await; + + move_subchannel_to_state( + &mut *graceful_switch, + second_subchannel, + tcc.as_mut(), + ConnectivityState::Ready, + ); + verify_correct_picker_from_policy( + &mut rx_events, + "stub-gracefulswitch_current_not_ready_pending_update-two", + ) + .await; + assert_channel_empty(&mut rx_events).await; + } + + // Tests that the gracefulswitch policy should replace the current child + // with the pending child if the current child was ready but then leaves ready. 
+ #[tokio::test] + async fn gracefulswitch_current_leaving_ready() { + let (mut rx_events, mut graceful_switch, mut tcc) = setup(); + reg_stub_policy( + "stub-gracefulswitch_current_leaving_ready-one", + create_funcs_for_gracefulswitch_tests("stub-gracefulswitch_current_leaving_ready-one"), + ); + reg_stub_policy( + "stub-gracefulswitch_current_leaving_ready-two", + create_funcs_for_gracefulswitch_tests("stub-gracefulswitch_current_leaving_ready-two"), + ); + let service_config = serde_json::json!([ + { "stub-gracefulswitch_current_leaving_ready-one": serde_json::json!({}) } + ] + ); + let parsed_config = GracefulSwitchPolicy::parse_config(&ParsedJsonLbConfig { + value: service_config, + }) + .unwrap(); + + let endpoint = create_endpoint_with_one_address("127.0.0.1:1234".to_string()); + let endpoint2 = create_endpoint_with_one_address("127.0.0.1:1235".to_string()); + let update = ResolverUpdate { + endpoints: Ok(vec![endpoint.clone()]), + ..Default::default() + }; + + // Switch to first one (current) + graceful_switch + .resolver_update(update.clone(), Some(&parsed_config), &mut *tcc) + .unwrap(); + + let current_subchannel = verify_subchannel_creation_from_policy(&mut rx_events).await; + move_subchannel_to_state( + &mut *graceful_switch, + current_subchannel.clone(), + tcc.as_mut(), + ConnectivityState::Ready, + ); + verify_correct_picker_from_policy( + &mut rx_events, + "stub-gracefulswitch_current_leaving_ready-one", + ) + .await; + let new_service_config = serde_json::json!( + [ + { "stub-gracefulswitch_current_leaving_ready-two": serde_json::json!({}) }, + + ] + ); + let new_update = ResolverUpdate { + endpoints: Ok(vec![endpoint2.clone()]), + ..Default::default() + }; + let new_parsed_config = GracefulSwitchPolicy::parse_config(&ParsedJsonLbConfig { + value: new_service_config, + }) + .unwrap(); + graceful_switch + .resolver_update(new_update.clone(), Some(&new_parsed_config), &mut *tcc) + .unwrap(); + + let pending_subchannel = verify_subchannel_creation_from_policy(&mut rx_events).await; + + move_subchannel_to_state( + &mut *graceful_switch, + pending_subchannel, + tcc.as_mut(), + ConnectivityState::Connecting, + ); + // This should not produce an update. + assert_channel_empty(&mut rx_events).await; + move_subchannel_to_state( + &mut *graceful_switch, + current_subchannel, + tcc.as_mut(), + ConnectivityState::Connecting, + ); + verify_correct_picker_from_policy( + &mut rx_events, + "stub-gracefulswitch_current_leaving_ready-two", + ) + .await; + } + + // Tests that the gracefulswitch policy should replace the current child + // with the pending child if the pending child leaves connecting. 
+    #[tokio::test]
+    async fn gracefulswitch_pending_leaving_connecting() {
+        let (mut rx_events, mut graceful_switch, mut tcc) = setup();
+        reg_stub_policy(
+            "stub-gracefulswitch_pending_leaving_connecting-one",
+            create_funcs_for_gracefulswitch_tests("stub-gracefulswitch_pending_leaving_connecting-one"),
+        );
+        reg_stub_policy(
+            "stub-gracefulswitch_pending_leaving_connecting-two",
+            create_funcs_for_gracefulswitch_tests("stub-gracefulswitch_pending_leaving_connecting-two"),
+        );
+        let service_config = serde_json::json!(
+            [
+                { "stub-gracefulswitch_pending_leaving_connecting-one": serde_json::json!({}) }
+            ]
+        );
+        let parsed_config = GracefulSwitchPolicy::parse_config(&ParsedJsonLbConfig {
+            value: service_config,
+        })
+        .unwrap();
+        let endpoint = create_endpoint_with_one_address("127.0.0.1:1234".to_string());
+        let endpoint2 = create_endpoint_with_one_address("127.0.0.1:1235".to_string());
+        let update = ResolverUpdate {
+            endpoints: Ok(vec![endpoint.clone()]),
+            ..Default::default()
+        };
+
+        // Switch to the first policy (current).
+        graceful_switch
+            .resolver_update(update.clone(), Some(&parsed_config), &mut *tcc)
+            .unwrap();
+
+        let current_subchannel = verify_subchannel_creation_from_policy(&mut rx_events).await;
+        move_subchannel_to_state(
+            &mut *graceful_switch,
+            current_subchannel,
+            tcc.as_mut(),
+            ConnectivityState::Ready,
+        );
+        verify_correct_picker_from_policy(
+            &mut rx_events,
+            "stub-gracefulswitch_pending_leaving_connecting-one",
+        )
+        .await;
+        let new_service_config = serde_json::json!(
+            [
+                { "stub-gracefulswitch_pending_leaving_connecting-two": serde_json::json!({}) },
+            ]
+        );
+        let new_update = ResolverUpdate {
+            endpoints: Ok(vec![endpoint2.clone()]),
+            ..Default::default()
+        };
+        let new_parsed_config = GracefulSwitchPolicy::parse_config(&ParsedJsonLbConfig {
+            value: new_service_config,
+        })
+        .unwrap();
+
+        graceful_switch
+            .resolver_update(new_update.clone(), Some(&new_parsed_config), &mut *tcc)
+            .unwrap();
+
+        let pending_subchannel = verify_subchannel_creation_from_policy(&mut rx_events).await;
+
+        move_subchannel_to_state(
+            &mut *graceful_switch,
+            pending_subchannel.clone(),
+            tcc.as_mut(),
+            ConnectivityState::TransientFailure,
+        );
+        verify_correct_picker_from_policy(
+            &mut rx_events,
+            "stub-gracefulswitch_pending_leaving_connecting-two",
+        )
+        .await;
+        move_subchannel_to_state(
+            &mut *graceful_switch,
+            pending_subchannel,
+            tcc.as_mut(),
+            ConnectivityState::Connecting,
+        );
+        verify_correct_picker_from_policy(
+            &mut rx_events,
+            "stub-gracefulswitch_pending_leaving_connecting-two",
+        )
+        .await;
+    }
+
+    // Tests that the gracefulswitch policy should remove the current child's
+    // subchannels after swapping.
+    #[tokio::test]
+    async fn gracefulswitch_subchannels_removed_after_current_child_swapped() {
+        let (mut rx_events, mut graceful_switch, mut tcc) = setup();
+        reg_stub_policy(
+            "stub-gracefulswitch_subchannels_removed_after_current_child_swapped-one",
+            create_funcs_for_gracefulswitch_tests(
+                "stub-gracefulswitch_subchannels_removed_after_current_child_swapped-one",
+            ),
+        );
+        reg_stub_policy(
+            "stub-gracefulswitch_subchannels_removed_after_current_child_swapped-two",
+            create_funcs_for_gracefulswitch_tests(
+                "stub-gracefulswitch_subchannels_removed_after_current_child_swapped-two",
+            ),
+        );
+        let service_config = serde_json::json!(
+            [
+                { "stub-gracefulswitch_subchannels_removed_after_current_child_swapped-one": serde_json::json!({}) }
+            ]
+        );
+        let parsed_config = GracefulSwitchPolicy::parse_config(&ParsedJsonLbConfig {
+            value: service_config,
+        })
+        .unwrap();
+        let endpoint = create_endpoint_with_one_address("127.0.0.1:1234".to_string());
+        let update = ResolverUpdate {
+            endpoints: Ok(vec![endpoint.clone()]),
+            ..Default::default()
+        };
+        graceful_switch
+            .resolver_update(update.clone(), Some(&parsed_config), &mut *tcc)
+            .unwrap();
+
+        let current_subchannel = verify_subchannel_creation_from_policy(&mut rx_events).await;
+        move_subchannel_to_state(
+            &mut *graceful_switch,
+            current_subchannel.clone(),
+            tcc.as_mut(),
+            ConnectivityState::Ready,
+        );
+        verify_correct_picker_from_policy(
+            &mut rx_events,
+            "stub-gracefulswitch_subchannels_removed_after_current_child_swapped-one",
+        )
+        .await;
+        let new_service_config = serde_json::json!(
+            [
+                { "stub-gracefulswitch_subchannels_removed_after_current_child_swapped-two": serde_json::json!({ "shuffleAddressList": false }) },
+            ]
+        );
+        let second_endpoint = create_endpoint_with_one_address("127.0.0.1:1235".to_string());
+        let second_update = ResolverUpdate {
+            endpoints: Ok(vec![second_endpoint.clone()]),
+            ..Default::default()
+        };
+        let new_parsed_config = GracefulSwitchPolicy::parse_config(&ParsedJsonLbConfig {
+            value: new_service_config,
+        })
+        .unwrap();
+        graceful_switch
+            .resolver_update(second_update.clone(), Some(&new_parsed_config), &mut *tcc)
+            .unwrap();
+        let pending_subchannel = verify_subchannel_creation_from_policy(&mut rx_events).await;
+        move_subchannel_to_state(
+            &mut *graceful_switch,
+            pending_subchannel,
+            tcc.as_mut(),
+            ConnectivityState::Idle,
+        );
+        verify_correct_picker_from_policy(
+            &mut rx_events,
+            "stub-gracefulswitch_subchannels_removed_after_current_child_swapped-two",
+        )
+        .await;
+        assert!(Arc::strong_count(&current_subchannel) == 1);
+    }
+}
diff --git a/grpc/src/client/load_balancing/mod.rs b/grpc/src/client/load_balancing/mod.rs
index 01f9dea41..a15e030a1 100644
--- a/grpc/src/client/load_balancing/mod.rs
+++ b/grpc/src/client/load_balancing/mod.rs
@@ -46,10 +46,12 @@ use crate::client::{
     ConnectivityState,
 };

-pub mod child_manager;
-pub mod pick_first;
+pub(crate) mod child_manager;
+pub(crate) mod graceful_switch;
+pub(crate) mod pick_first;
+
 #[cfg(test)]
-pub mod test_utils;
+pub(crate) mod test_utils;
 pub(crate) mod registry;

 use super::{service_config::LbConfig, subchannel::SubchannelStateWatcher};
@@ -57,7 +59,7 @@ pub(crate) use registry::GLOBAL_LB_REGISTRY;

 /// A collection of data configured on the channel that is constructing this
 /// LbPolicy.
-pub struct LbPolicyOptions { +pub(crate) struct LbPolicyOptions { /// A hook into the channel's work scheduler that allows the LbPolicy to /// request the ability to perform operations on the ChannelController. pub work_scheduler: Arc, @@ -67,7 +69,7 @@ pub struct LbPolicyOptions { /// Used to asynchronously request a call into the LbPolicy's work method if /// the LbPolicy needs to provide an update without waiting for an update /// from the channel first. -pub trait WorkScheduler: Send + Sync { +pub(crate) trait WorkScheduler: Send + Sync + Debug { // Schedules a call into the LbPolicy's work method. If there is already a // pending work call that has not yet started, this may not schedule another // call. @@ -78,7 +80,7 @@ pub trait WorkScheduler: Send + Sync { /// JSON. Hides internal storage details and includes a method to deserialize /// the JSON into a concrete policy struct. #[derive(Debug)] -pub struct ParsedJsonLbConfig { +pub(crate) struct ParsedJsonLbConfig { value: serde_json::Value, } @@ -115,7 +117,7 @@ impl ParsedJsonLbConfig { /// An LB policy factory that produces LbPolicy instances used by the channel /// to manage connections and pick connections for RPCs. -pub(crate) trait LbPolicyBuilder: Send + Sync { +pub(crate) trait LbPolicyBuilder: Send + Sync + Debug { /// Builds and returns a new LB policy instance. /// /// Note that build must not fail. Any optional configuration is delivered @@ -145,7 +147,7 @@ pub(crate) trait LbPolicyBuilder: Send + Sync { /// LB policies are responsible for creating connections (modeled as /// Subchannels) and producing Picker instances for picking connections for /// RPCs. -pub trait LbPolicy: Send { +pub(crate) trait LbPolicy: Send + Debug { /// Called by the channel when the name resolver produces a new set of /// resolved addresses or a new service config. fn resolver_update( @@ -174,7 +176,7 @@ pub trait LbPolicy: Send { } /// Controls channel behaviors. -pub trait ChannelController: Send + Sync { +pub(crate) trait ChannelController: Send + Sync { /// Creates a new subchannel in IDLE state. fn new_subchannel(&mut self, address: &Address) -> Arc; @@ -189,7 +191,7 @@ pub trait ChannelController: Send + Sync { /// Represents the current state of a Subchannel. #[derive(Debug, Clone)] -pub struct SubchannelState { +pub(crate) struct SubchannelState { /// The connectivity state of the subchannel. See SubChannel for a /// description of the various states and their valid transitions. pub connectivity_state: ConnectivityState, @@ -238,7 +240,7 @@ impl Display for SubchannelState { /// /// If the ConnectivityState is TransientFailure, the Picker should return an /// Err with an error that describes why connections are failing. -pub trait Picker: Send + Sync { +pub(crate) trait Picker: Send + Sync + Debug { /// Picks a connection to use for the request. /// /// This function should not block. If the Picker needs to do blocking or @@ -248,7 +250,8 @@ pub trait Picker: Send + Sync { fn pick(&self, request: &Request) -> PickResult; } -pub enum PickResult { +#[derive(Debug)] +pub(crate) enum PickResult { /// Indicates the Subchannel in the Pick should be used for the request. Pick(Pick), /// Indicates the LbPolicy is attempting to connect to a server to use for @@ -309,8 +312,8 @@ impl Display for PickResult { } } /// Data provided by the LB policy. 
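+///
+/// For example, a policy with a connected subchannel reports
+/// `ConnectivityState::Ready` together with a picker that routes RPCs to that
+/// subchannel (see `OneSubchannelPicker` in pick_first for a sketch).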
-#[derive(Clone)] -pub struct LbState { +#[derive(Clone, Debug)] +pub(crate) struct LbState { pub connectivity_state: super::ConnectivityState, pub picker: Arc, } @@ -327,10 +330,10 @@ impl LbState { } /// Type alias for the completion callback function. -pub type CompletionCallback = Box; +pub(crate) type CompletionCallback = Box; /// A collection of data used by the channel for routing a request. -pub struct Pick { +pub(crate) struct Pick { /// The Subchannel for the request. pub subchannel: Arc, // Metadata to be added to existing outgoing metadata. @@ -339,7 +342,17 @@ pub struct Pick { pub on_complete: Option, } -pub trait DynHash { +impl Debug for Pick { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Pick") + .field("subchannel", &self.subchannel) + .field("metadata", &self.metadata) + .field("on_complete", &format_args!("{:p}", &self.on_complete)) + .finish() + } +} + +pub(crate) trait DynHash { #[allow(clippy::redundant_allocation)] fn dyn_hash(&self, state: &mut Box<&mut dyn Hasher>); } @@ -350,7 +363,7 @@ impl DynHash for T { } } -pub trait DynPartialEq { +pub(crate) trait DynPartialEq { fn dyn_eq(&self, other: &&dyn Any) -> bool; } @@ -367,7 +380,7 @@ mod private { pub trait Sealed {} } -pub trait SealedSubchannel: private::Sealed {} +pub(crate) trait SealedSubchannel: private::Sealed {} /// A Subchannel represents a method of communicating with a server which may be /// connected or disconnected many times across its lifetime. @@ -386,7 +399,9 @@ pub trait SealedSubchannel: private::Sealed {} /// /// When a Subchannel is dropped, it is disconnected automatically, and no /// subsequent state updates will be provided for it to the LB policy. -pub trait Subchannel: SealedSubchannel + DynHash + DynPartialEq + Any + Send + Sync { +pub(crate) trait Subchannel: + SealedSubchannel + DynHash + DynPartialEq + Any + Send + Sync +{ /// Returns the address of the Subchannel. /// TODO: Consider whether this should really be public. fn address(&self) -> Address; @@ -430,6 +445,7 @@ impl Display for dyn Subchannel { } } +#[derive(Debug)] struct WeakSubchannel(Weak); impl From> for WeakSubchannel { @@ -540,7 +556,7 @@ impl Display for ExternalSubchannel { } } -pub trait ForwardingSubchannel: DynHash + DynPartialEq + Any + Send + Sync { +pub(crate) trait ForwardingSubchannel: DynHash + DynPartialEq + Any + Send + Sync { fn delegate(&self) -> Arc; fn address(&self) -> Address { @@ -564,7 +580,8 @@ impl private::Sealed for T {} /// QueuingPicker always returns Queue. LB policies that are not actively /// Connecting should not use this picker. 
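+///
+/// A minimal sketch of reporting it while a policy is connecting:
+///
+/// ```ignore
+/// controller.update_picker(LbState {
+///     connectivity_state: ConnectivityState::Connecting,
+///     picker: Arc::new(QueuingPicker {}),
+/// });
+/// ```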
/// QueuingPicker always returns Queue. LB policies that are not actively /// Connecting should not use this picker. -pub struct QueuingPicker {} +#[derive(Debug)] +pub(crate) struct QueuingPicker {} impl Picker for QueuingPicker { fn pick(&self, _request: &Request) -> PickResult { @@ -572,7 +589,8 @@ impl Picker for QueuingPicker { } } -pub struct Failing { +#[derive(Debug)] +pub(crate) struct Failing { pub error: String, } diff --git a/grpc/src/client/load_balancing/pick_first.rs b/grpc/src/client/load_balancing/pick_first.rs index 2f9cbd9f8..b3940ce06 100644 --- a/grpc/src/client/load_balancing/pick_first.rs +++ b/grpc/src/client/load_balancing/pick_first.rs @@ -17,8 +17,9 @@ use super::{ SubchannelState, WorkScheduler, }; -pub static POLICY_NAME: &str = "pick_first"; +pub(crate) static POLICY_NAME: &str = "pick_first"; +#[derive(Debug)] struct Builder {} impl LbPolicyBuilder for Builder { @@ -36,10 +37,11 @@ impl LbPolicyBuilder for Builder { } } -pub fn reg() { +pub(crate) fn reg() { super::GLOBAL_LB_REGISTRY.add_builder(Builder {}) } +#[derive(Debug)] struct PickFirstPolicy { work_scheduler: Arc<dyn WorkScheduler>, subchannel: Option<Arc<dyn Subchannel>>, @@ -104,6 +106,7 @@ impl LbPolicy for PickFirstPolicy { } } +#[derive(Debug)] struct OneSubchannelPicker { sc: Arc<dyn Subchannel>, } @@ -112,8 +115,8 @@ impl Picker for OneSubchannelPicker { fn pick(&self, request: &Request) -> PickResult { PickResult::Pick(Pick { subchannel: self.sc.clone(), - on_complete: None, metadata: MetadataMap::new(), + on_complete: None, }) } } diff --git a/grpc/src/client/load_balancing/registry.rs b/grpc/src/client/load_balancing/registry.rs index d3dbeffb3..45c757c89 100644 --- a/grpc/src/client/load_balancing/registry.rs +++ b/grpc/src/client/load_balancing/registry.rs @@ -7,7 +7,7 @@ use super::LbPolicyBuilder; /// A registry to store and retrieve LB policies. LB policies are indexed by /// their names. -pub struct LbPolicyRegistry { +pub(crate) struct LbPolicyRegistry { m: Arc<Mutex<HashMap<String, Arc<dyn LbPolicyBuilder>>>>, } @@ -37,4 +37,5 @@ impl Default for LbPolicyRegistry { /// The registry used if a local registry is not provided to a channel or if it /// does not exist in the local registry. -pub static GLOBAL_LB_REGISTRY: LazyLock<LbPolicyRegistry> = LazyLock::new(LbPolicyRegistry::new); +pub(crate) static GLOBAL_LB_REGISTRY: LazyLock<LbPolicyRegistry> = + LazyLock::new(LbPolicyRegistry::new); diff --git a/grpc/src/client/load_balancing/test_utils.rs b/grpc/src/client/load_balancing/test_utils.rs index 0069fc929..fc6ba4422 100644 --- a/grpc/src/client/load_balancing/test_utils.rs +++ b/grpc/src/client/load_balancing/test_utils.rs @@ -29,6 +29,8 @@ use crate::client::load_balancing::{ use crate::client::name_resolution::{Address, ResolverUpdate}; use crate::client::service_config::LbConfig; use crate::service::{Message, Request}; +use serde::{Deserialize, Serialize}; +use std::any::Any; use std::error::Error; use std::hash::Hash; use std::{fmt::Debug, sync::Arc}; @@ -138,6 +140,7 @@ impl ChannelController for TestChannelController { } } +#[derive(Debug)] pub(crate) struct TestWorkScheduler { pub(crate) tx_events: mpsc::UnboundedSender, } @@ -149,26 +152,56 @@ impl WorkScheduler for TestWorkScheduler { } // The callback to invoke when resolver_update is invoked on the stub policy. -type ResolverUpdateFn = fn( ResolverUpdate, Option<&LbConfig>, &mut dyn ChannelController, ) -> Result<(), Box<dyn Error + Send + Sync>>; +type ResolverUpdateFn = Arc< + dyn Fn( + &mut StubPolicyData, + ResolverUpdate, + Option<&LbConfig>, + &mut dyn ChannelController, + ) -> Result<(), Box<dyn Error + Send + Sync>> + + Send + + Sync, +>;
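(Because ResolverUpdateFn is now an Arc'd closure rather than a plain fn pointer, tests can capture state and receive &mut StubPolicyData. A hypothetical construction, with invented names and a string stored as the type-erased test payload:)

```rust
// Hypothetical test snippet, not part of this diff.
fn example_stub_callback() {
    let on_resolver_update: ResolverUpdateFn = Arc::new(
        |data: &mut StubPolicyData,
         _update: ResolverUpdate,
         _config: Option<&LbConfig>,
         _controller: &mut dyn ChannelController| {
            // Record that the stub policy saw an update.
            data.test_data = Some(Box::new(String::from("saw resolver_update")));
            Ok(())
        },
    );
}
```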
// The callback to invoke when subchannel_update is invoked on the stub policy. -type SubchannelUpdateFn = fn(Arc<dyn Subchannel>, &SubchannelState, &mut dyn ChannelController); +type SubchannelUpdateFn = Arc< + dyn Fn(&mut StubPolicyData, Arc<dyn Subchannel>, &SubchannelState, &mut dyn ChannelController) + + Send + + Sync, +>; /// This struct holds `LbPolicy` trait stub functions that tests are expected to /// implement. -#[derive(Clone, Default)] -pub struct StubPolicyFuncs { +#[derive(Clone)] +pub(crate) struct StubPolicyFuncs { pub resolver_update: Option<ResolverUpdateFn>, pub subchannel_update: Option<SubchannelUpdateFn>, } +impl Debug for StubPolicyFuncs { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "stub funcs") + } +} + +/// Holds test data that is passed to all functions in StubPolicyFuncs. +#[derive(Debug)] +pub(crate) struct StubPolicyData { + pub test_data: Option<Box<dyn Any + Send + Sync>>, +} + +impl StubPolicyData { + /// Creates an instance of StubPolicyData. + pub fn new() -> Self { + Self { test_data: None } + } +} + /// The stub `LbPolicy` that calls the provided functions. -pub struct StubPolicy { +#[derive(Debug)] +pub(crate) struct StubPolicy { funcs: StubPolicyFuncs, + data: StubPolicyData, } impl LbPolicy for StubPolicy { @@ -178,8 +211,8 @@ impl LbPolicy for StubPolicy { config: Option<&LbConfig>, channel_controller: &mut dyn ChannelController, ) -> Result<(), Box<dyn Error + Send + Sync>> { - if let Some(f) = &self.funcs.resolver_update { - return f(update, config, channel_controller); + if let Some(f) = &mut self.funcs.resolver_update { + return f(&mut self.data, update, config, channel_controller); } Ok(()) } @@ -191,7 +224,7 @@ impl LbPolicy for StubPolicy { channel_controller: &mut dyn ChannelController, ) { if let Some(f) = &self.funcs.subchannel_update { - f(subchannel, state, channel_controller); + f(&mut self.data, subchannel, state, channel_controller); } } @@ -205,15 +238,24 @@ impl LbPolicy for StubPolicy { }
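(test_data is deliberately type-erased; a test that stores typed state gets it back by downcasting. A small hypothetical round trip, assuming the Option<Box<dyn Any + Send + Sync>> field type reconstructed above:)

```rust
// Hypothetical: stash a counter in StubPolicyData and bump it from a
// stub callback, e.g. to count subchannel_update invocations.
fn example_counter() {
    let mut data = StubPolicyData::new();
    data.test_data = Some(Box::new(0u32));
    if let Some(calls) = data
        .test_data
        .as_mut()
        .and_then(|b| b.downcast_mut::<u32>())
    {
        *calls += 1;
    }
}
```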
/// StubPolicyBuilder builds a StubPolicy. -pub struct StubPolicyBuilder { +#[derive(Debug)] +pub(crate) struct StubPolicyBuilder { name: &'static str, funcs: StubPolicyFuncs, } +#[derive(Serialize, Deserialize, Debug)] +#[serde(rename_all = "camelCase")] +pub(super) struct MockConfig { + shuffle_address_list: Option<bool>, +} + impl LbPolicyBuilder for StubPolicyBuilder { fn build(&self, options: LbPolicyOptions) -> Box<dyn LbPolicy> { + let data = StubPolicyData::new(); Box::new(StubPolicy { funcs: self.funcs.clone(), + data, }) } @@ -223,12 +265,18 @@ impl LbPolicyBuilder for StubPolicyBuilder { fn parse_config( &self, - _config: &ParsedJsonLbConfig, + config: &ParsedJsonLbConfig, ) -> Result<Option<LbConfig>, Box<dyn Error + Send + Sync>> { - todo!("Implement parse_config in StubPolicyBuilder") + let cfg: MockConfig = match config.convert_to() { + Ok(c) => c, + Err(e) => { + return Err(format!("failed to parse JSON config: {}", e).into()); + } + }; + Ok(Some(LbConfig::new(cfg))) } } -pub fn reg_stub_policy(name: &'static str, funcs: StubPolicyFuncs) { +pub(crate) fn reg_stub_policy(name: &'static str, funcs: StubPolicyFuncs) { super::GLOBAL_LB_REGISTRY.add_builder(StubPolicyBuilder { name, funcs }) } diff --git a/grpc/src/client/mod.rs b/grpc/src/client/mod.rs index e896412ae..b141e7f2d 100644 --- a/grpc/src/client/mod.rs +++ b/grpc/src/client/mod.rs @@ -25,14 +25,15 @@ use std::fmt::Display; pub mod channel; -pub(crate) mod load_balancing; -pub(crate) mod name_resolution; pub mod service_config; mod subchannel; -pub(crate) mod transport; pub use channel::Channel; pub use channel::ChannelOptions; +pub(crate) mod load_balancing; +pub(crate) mod name_resolution; +pub(crate) mod transport; + /// A representation of the current state of a gRPC channel, also used for the /// state of subchannels (individual connections within the channel). /// @@ -44,7 +45,7 @@ pub use channel::ChannelOptions; /// /// Channels may re-enter the Idle state if they are unused for longer than /// their configured idleness timeout. -#[derive(Copy, Clone, PartialEq, Debug)] +#[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum ConnectivityState { Idle, Connecting, diff --git a/grpc/src/client/name_resolution/backoff.rs b/grpc/src/client/name_resolution/backoff.rs index 8725706e2..154fad47f 100644 --- a/grpc/src/client/name_resolution/backoff.rs +++ b/grpc/src/client/name_resolution/backoff.rs @@ -26,7 +26,7 @@ use rand::Rng; use std::time::Duration; #[derive(Clone)] -pub struct BackoffConfig { +pub(crate) struct BackoffConfig { /// The amount of time to backoff after the first failure. pub base_delay: Duration, @@ -41,7 +41,7 @@ pub struct BackoffConfig { pub max_delay: Duration, } -pub struct ExponentialBackoff { +pub(crate) struct ExponentialBackoff { config: BackoffConfig, /// The delay for the next retry, without the random jitter. Store as f64
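(For context on the fields above: base_delay is the starting point, multiplier grows the un-jittered delay after each failure up to max_delay, and jitter randomizes the delay actually used. A sketch of one step, assumed from the field docs rather than copied from the implementation:)

```rust
use rand::Rng;
use std::time::Duration;

// Sketch: one exponential-backoff step driven by BackoffConfig-style
// options (formula assumed; the crate's code may differ in detail).
fn next_delay(unjittered_secs: &mut f64, multiplier: f64, jitter: f64, max: Duration) -> Duration {
    let current = *unjittered_secs;
    // Grow the stored, un-jittered delay, capping it at max_delay.
    *unjittered_secs = (current * multiplier).min(max.as_secs_f64());
    // Randomize within +/- jitter of the nominal delay (0.2 => +/- 20%).
    let factor = 1.0 + jitter * rand::thread_rng().gen_range(-1.0..=1.0);
    Duration::from_secs_f64(current * factor)
}
```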
@@ -54,7 +54,7 @@ /// /// This should be useful for callers who want to configure backoff with /// non-default values only for a subset of the options. -pub const DEFAULT_EXPONENTIAL_CONFIG: BackoffConfig = BackoffConfig { +pub(crate) const DEFAULT_EXPONENTIAL_CONFIG: BackoffConfig = BackoffConfig { base_delay: Duration::from_secs(1), multiplier: 1.6, jitter: 0.2, diff --git a/grpc/src/client/name_resolution/dns/mod.rs b/grpc/src/client/name_resolution/dns/mod.rs index 6475d62c3..dddb62492 100644 --- a/grpc/src/client/name_resolution/dns/mod.rs +++ b/grpc/src/client/name_resolution/dns/mod.rs @@ -83,7 +83,7 @@ fn get_resolving_timeout() -> Duration { /// premature timeouts during resolution, while setting it too high may lead to /// unnecessary delays in service discovery. Choose a value appropriate for your /// specific needs and network environment. -pub fn set_resolving_timeout(duration: Duration) { +pub(crate) fn set_resolving_timeout(duration: Duration) { RESOLVING_TIMEOUT_MS.store(duration.as_millis() as u64, Ordering::Relaxed); } @@ -96,11 +96,11 @@ fn get_min_resolution_interval() -> Duration { /// /// It must be called only at application startup, before any gRPC calls are /// made. -pub fn set_min_resolution_interval(duration: Duration) { +pub(crate) fn set_min_resolution_interval(duration: Duration) { MIN_RESOLUTION_INTERVAL_MS.store(duration.as_millis() as u64, Ordering::Relaxed); } -pub fn reg() { +pub(crate) fn reg() { global_registry().add_builder(Box::new(Builder {})); } diff --git a/grpc/src/client/name_resolution/dns/test.rs b/grpc/src/client/name_resolution/dns/test.rs index 135e8ccfa..f787a8df0 100644 --- a/grpc/src/client/name_resolution/dns/test.rs +++ b/grpc/src/client/name_resolution/dns/test.rs @@ -48,7 +48,7 @@ use super::{DnsOptions, ParseResult}; const DEFAULT_TEST_SHORT_TIMEOUT: Duration = Duration::from_millis(10); #[test] -pub fn target_parsing() { +pub(crate) fn target_parsing() { struct TestCase { input: &'static str, want_result: Result, @@ -191,7 +191,7 @@ impl ChannelController for FakeChannelController { } #[tokio::test] -pub async fn dns_basic() { +pub(crate) async fn dns_basic() { reg(); let builder = global_registry().get("dns").unwrap(); let target = &"dns:///localhost:1234".parse().unwrap(); @@ -220,7 +220,7 @@ pub async fn dns_basic() { } #[tokio::test] -pub async fn invalid_target() { +pub(crate) async fn invalid_target() { reg(); let builder = global_registry().get("dns").unwrap(); let target = &"dns:///:1234".parse().unwrap(); @@ -252,7 +252,7 @@ pub async fn invalid_target() { .contains(&target.to_string())); } -#[derive(Clone)] +#[derive(Clone, Debug)] struct FakeDns { latency: Duration, lookup_result: Result<Vec<IpAddr>, String>, } @@ -270,6 +270,7 @@ impl rt::DnsResolver for FakeDns { } } +#[derive(Debug)] struct FakeRuntime { inner: TokioRuntime, dns: FakeDns, } @@ -301,7 +302,7 @@ impl rt::Runtime for FakeRuntime { } #[tokio::test] -pub async fn dns_lookup_error() { +pub(crate) async fn dns_lookup_error() { reg(); let builder = global_registry().get("dns").unwrap(); let target = &"dns:///grpc.io:1234".parse().unwrap(); @@ -337,7 +338,7 @@ pub async fn dns_lookup_error() { } #[tokio::test] -pub async fn dns_lookup_timeout() { +pub(crate) async fn dns_lookup_timeout() { let (work_tx, mut work_rx) = mpsc::unbounded_channel(); let work_scheduler = Arc::new(FakeWorkScheduler { work_tx: work_tx.clone(), @@ -379,7 +380,7 @@ pub async fn dns_lookup_timeout() { } #[tokio::test] -pub async fn rate_limit() { +pub(crate) async fn rate_limit() { let (work_tx, mut work_rx) = mpsc::unbounded_channel(); let work_scheduler = Arc::new(FakeWorkScheduler { work_tx: work_tx.clone(), @@ -429,7 +430,7 @@ pub async fn 
rate_limit() { } #[tokio::test] -pub async fn re_resolution_after_success() { +pub(crate) async fn re_resolution_after_success() { let (work_tx, mut work_rx) = mpsc::unbounded_channel(); let work_scheduler = Arc::new(FakeWorkScheduler { work_tx: work_tx.clone(), @@ -473,7 +474,7 @@ pub async fn re_resolution_after_success() { } #[tokio::test] -pub async fn backoff_on_error() { +pub(crate) async fn backoff_on_error() { let (work_tx, mut work_rx) = mpsc::unbounded_channel(); let work_scheduler = Arc::new(FakeWorkScheduler { work_tx: work_tx.clone(), diff --git a/grpc/src/client/name_resolution/mod.rs b/grpc/src/client/name_resolution/mod.rs index 28f6ba91f..76123cbde 100644 --- a/grpc/src/client/name_resolution/mod.rs +++ b/grpc/src/client/name_resolution/mod.rs @@ -41,7 +41,7 @@ use std::{ mod backoff; mod dns; mod registry; -pub use registry::global_registry; +pub(crate) use registry::global_registry; use url::Url; /// Target represents a target for gRPC, as specified in: @@ -55,7 +55,7 @@ use url::Url; /// (i.e. no corresponding resolver available to resolve the endpoint), we will /// apply the default scheme, and will attempt to reparse it. #[derive(Debug, Clone)] -pub struct Target { +pub(crate) struct Target { url: Url, } @@ -133,7 +133,7 @@ impl Display for Target { /// A name resolver factory that produces Resolver instances used by the channel /// to resolve network addresses for the target URI. -pub trait ResolverBuilder: Send + Sync { +pub(crate) trait ResolverBuilder: Send + Sync { /// Builds a name resolver instance. /// /// Note that build must not fail. Instead, an erroring Resolver may be @@ -164,7 +164,7 @@ pub trait ResolverBuilder: Send + Sync { /// A collection of data configured on the channel that is constructing this /// name resolver. #[non_exhaustive] -pub struct ResolverOptions { +pub(crate) struct ResolverOptions { /// The authority that will be used for the channel by default. This refers /// to the `:authority` value sent in HTTP/2 requests — the dataplane /// authority — and not the authority portion of the target URI, which is @@ -184,7 +184,7 @@ pub struct ResolverOptions { } /// Used to asynchronously request a call into the Resolver's work method. -pub trait WorkScheduler: Send + Sync { +pub(crate) trait WorkScheduler: Send + Sync { // Schedules a call into the Resolver's work method. If there is already a // pending work call that has not yet started, this may not schedule another // call. @@ -196,7 +196,7 @@ pub trait WorkScheduler: Send + Sync { // This trait may not need the Sync sub-trait if the channel implementation can // ensure that the resolver is accessed serially. The sub-trait can be removed // in that case. -pub trait Resolver: Send + Sync { +pub(crate) trait Resolver: Send + Sync { /// Asks the resolver to obtain an updated resolver result, if applicable. /// /// This is useful for polling resolvers to decide when to re-resolve. @@ -215,7 +215,7 @@ pub trait Resolver: Send + Sync { /// The `ChannelController` trait provides the resolver with functionality /// to interact with the channel. -pub trait ChannelController: Send + Sync { +pub(crate) trait ChannelController: Send + Sync { /// Notifies the channel about the current state of the name resolver. If /// an error value is returned, the name resolver should attempt to /// re-resolve, if possible. 
The resolver is responsible for applying an @@ -232,7 +232,7 @@ #[non_exhaustive] /// ResolverUpdate contains the current Resolver state relevant to the /// channel. -pub struct ResolverUpdate { +pub(crate) struct ResolverUpdate { /// Attributes contains arbitrary data about the resolver intended for /// consumption by the load balancing policy. pub attributes: Attributes, @@ -271,7 +271,7 @@ impl Default for ResolverUpdate { /// which the server can be reached, e.g. via IPv4 and IPv6 addresses. #[derive(Debug, Default, Clone, PartialEq, Eq)] #[non_exhaustive] -pub struct Endpoint { +pub(crate) struct Endpoint { /// Addresses contains a list of addresses used to access this endpoint. pub addresses: Vec<Address>,
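(A hypothetical, crate-internal construction of these types, using the Address struct defined just below; values are invented, and both structs derive Default as shown:)

```rust
// Hypothetical example, not part of this diff.
fn example_endpoint() {
    let addr = Address {
        network_type: TCP_IP_NETWORK_TYPE,
        // Remaining Address fields elided via the derived Default.
        ..Default::default()
    };
    let endpoint = Endpoint {
        addresses: vec![addr],
        ..Default::default()
    };
}
```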
@@ -289,7 +289,7 @@ impl Hash for Endpoint { /// An Address is an identifier that indicates how to connect to a server. #[non_exhaustive] #[derive(Debug, Clone, Default, Ord, PartialOrd)] -pub struct Address { +pub(crate) struct Address { /// The network type is used to identify what kind of transport to create /// when connecting to this address. Typically TCP_IP_ADDRESS_TYPE. pub network_type: &'static str, @@ -327,7 +327,7 @@ impl Display for Address { /// Indicates the address is an IPv4 or IPv6 address that should be connected to /// via TCP/IP. -pub static TCP_IP_NETWORK_TYPE: &str = "tcp"; +pub(crate) static TCP_IP_NETWORK_TYPE: &str = "tcp"; // A resolver that returns the same result every time its work method is called. // It can be used to return an error to the channel when a resolver fails to diff --git a/grpc/src/client/name_resolution/registry.rs b/grpc/src/client/name_resolution/registry.rs index eb216538c..ca2c2035a 100644 --- a/grpc/src/client/name_resolution/registry.rs +++ b/grpc/src/client/name_resolution/registry.rs @@ -34,7 +34,7 @@ static GLOBAL_RESOLVER_REGISTRY: OnceLock<ResolverRegistry> = OnceLock::new(); /// A registry to store and retrieve name resolvers. Resolvers are indexed by /// the URI scheme they are intended to handle. #[derive(Default)] -pub struct ResolverRegistry { +pub(crate) struct ResolverRegistry { inner: Arc<Mutex<HashMap<String, Box<dyn ResolverBuilder>>>>, } @@ -90,6 +90,6 @@ impl ResolverRegistry { } /// Global registry for resolver builders. -pub fn global_registry() -> &'static ResolverRegistry { +pub(crate) fn global_registry() -> &'static ResolverRegistry { GLOBAL_RESOLVER_REGISTRY.get_or_init(ResolverRegistry::new) } diff --git a/grpc/src/client/service_config.rs b/grpc/src/client/service_config.rs index da268ca33..e729976f3 100644 --- a/grpc/src/client/service_config.rs +++ b/grpc/src/client/service_config.rs @@ -21,7 +21,7 @@ * IN THE SOFTWARE. * */ -use std::{any::Any, error::Error, sync::Arc}; +use std::{any::Any, sync::Arc}; /// An in-memory representation of a service config, usually provided to gRPC as /// a JSON object. @@ -29,26 +29,21 @@ use std::{any::Any, sync::Arc}; pub(crate) struct ServiceConfig; /// A convenience wrapper for an LB policy's configuration object. -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) struct LbConfig { config: Arc<dyn Any + Send + Sync>, } impl LbConfig { /// Create a new LbConfig wrapper containing the provided config. - pub fn new<T: Any + Send + Sync>(config: T) -> Self { + pub fn new(config: impl Any + Send + Sync) -> Self { LbConfig { config: Arc::new(config), } } /// Convenience method to extract the LB policy's configuration object. - pub fn convert_to<T: Any + Send + Sync>( - &self, - ) -> Result<Arc<T>, Box<dyn Error + Send + Sync>> { - match self.config.clone().downcast::<T>() { - Ok(c) => Ok(c), - Err(e) => Err("failed to downcast to config type".into()), - } + pub fn convert_to<T: Any + Send + Sync>(&self) -> Option<Arc<T>> { + self.config.clone().downcast::<T>().ok() } } diff --git a/grpc/src/client/transport/registry.rs b/grpc/src/client/transport/registry.rs index 0b4f614ef..9fe246e92 100644 --- a/grpc/src/client/transport/registry.rs +++ b/grpc/src/client/transport/registry.rs @@ -48,5 +48,5 @@ impl TransportRegistry { /// The registry used if a local registry is not provided to a channel or if it /// does not exist in the local registry.
-pub static GLOBAL_TRANSPORT_REGISTRY: LazyLock<TransportRegistry> = +pub(crate) static GLOBAL_TRANSPORT_REGISTRY: LazyLock<TransportRegistry> = LazyLock::new(TransportRegistry::new); diff --git a/grpc/src/client/transport/tonic/mod.rs b/grpc/src/client/transport/tonic/mod.rs index 11fcd24e1..57f695bd6 100644 --- a/grpc/src/client/transport/tonic/mod.rs +++ b/grpc/src/client/transport/tonic/mod.rs @@ -264,7 +264,7 @@ impl GrpcService for TonicService { /// A future that resolves to an HTTP response. /// /// This is returned by the `Service::call` on [`Channel`]. -pub struct ResponseFuture { +pub(crate) struct ResponseFuture { inner: BufferResponseFuture, BoxError>>>, } diff --git a/grpc/src/client/transport/tonic/test.rs b/grpc/src/client/transport/tonic/test.rs index 678280e34..2ccdc75da 100644 --- a/grpc/src/client/transport/tonic/test.rs +++ b/grpc/src/client/transport/tonic/test.rs @@ -21,7 +21,7 @@ const DEFAULT_TEST_SHORT_DURATION: Duration = Duration::from_millis(10); // Tests the tonic transport by creating a bi-di stream with a tonic server. #[tokio::test] -pub async fn tonic_transport_rpc() { +pub(crate) async fn tonic_transport_rpc() { super::reg(); let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); let addr = listener.local_addr().unwrap(); // get the assigned address @@ -110,7 +110,7 @@ pub async fn tonic_transport_rpc() { } #[derive(Debug)] -pub struct EchoService {} +pub(crate) struct EchoService {} #[async_trait] impl Echo for EchoService { diff --git a/grpc/src/rt/mod.rs b/grpc/src/rt/mod.rs index 81d22ff7c..4186fefff 100644 --- a/grpc/src/rt/mod.rs +++ b/grpc/src/rt/mod.rs @@ -23,6 +23,7 @@ */ use ::tokio::io::{AsyncRead, AsyncWrite}; +use std::fmt::Debug; use std::{future::Future, net::SocketAddr, pin::Pin, sync::Arc, time::Duration}; pub(crate) mod hyper_wrapper; @@ -39,7 +40,7 @@ pub(crate) type BoxedTaskHandle = Box<dyn TaskHandle>; /// time-based operations such as sleeping. It provides a uniform interface /// that can be implemented for various async runtimes, enabling pluggable /// and testable infrastructure. -pub(super) trait Runtime: Send + Sync { +pub(super) trait Runtime: Send + Sync + Debug { /// Spawns the given asynchronous task to run in the background. fn spawn(&self, task: Pin<Box<dyn Future<Output = ()> + Send + 'static>>) -> BoxedTaskHandle; @@ -98,7 +99,7 @@ pub(crate) trait TcpStream: AsyncRead + AsyncWrite + Send + Unpin {} /// # Panics /// /// Panics if any of its functions are called. -#[derive(Default)] +#[derive(Default, Debug)] pub(crate) struct NoOpRuntime {} impl Runtime for NoOpRuntime { diff --git a/grpc/src/rt/tokio/mod.rs b/grpc/src/rt/tokio/mod.rs index 8caec4cf3..860f0617a 100644 --- a/grpc/src/rt/tokio/mod.rs +++ b/grpc/src/rt/tokio/mod.rs @@ -64,6 +64,7 @@ impl DnsResolver for TokioDefaultDnsResolver { } } +#[derive(Debug)] pub(crate) struct TokioRuntime {} impl TaskHandle for JoinHandle<()> {
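(Putting the reworked LbConfig API together: new() type-erases any Any + Send + Sync value, and convert_to() now yields an Option instead of a Result. A sketch with an invented config type, matching the signatures reconstructed above:)

```rust
use std::sync::Arc;

// Hypothetical config type for illustration only.
#[derive(Debug)]
struct MyPolicyConfig {
    max_streams: u32,
}

fn example_round_trip() {
    let cfg = LbConfig::new(MyPolicyConfig { max_streams: 8 });
    // Downcast back to the concrete type; a mismatched type yields None.
    let typed: Option<Arc<MyPolicyConfig>> = cfg.convert_to();
    assert_eq!(typed.unwrap().max_streams, 8);
}
```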