185 changes: 178 additions & 7 deletions grpc/src/client/load_balancing/child_manager.rs
@@ -27,7 +27,7 @@

// TODO: This is mainly provided as a fairly complex example of the current LB
// policy in use. Complete tests must be written before it can be used in
-// production. Also, support for the work scheduler is missing.
// production.

use std::collections::HashSet;
use std::fmt::Debug;
@@ -53,6 +53,7 @@ pub(crate) struct ChildManager<T: Debug, S: ResolverUpdateSharder<T>> {
pending_work: Arc<Mutex<HashSet<usize>>>,
runtime: Arc<dyn Runtime>,
updated: bool, // Set when any child updates its picker; cleared when accessed.
work_scheduler: Arc<dyn WorkScheduler>,
}

#[non_exhaustive]
Expand Down Expand Up @@ -98,13 +99,18 @@ where
{
/// Creates a new ChildManager LB policy. shard_update is called whenever a
/// resolver_update operation occurs.
-pub fn new(update_sharder: S, runtime: Arc<dyn Runtime>) -> Self {
pub fn new(
update_sharder: S,
runtime: Arc<dyn Runtime>,
work_scheduler: Arc<dyn WorkScheduler>,
) -> Self {
Self {
update_sharder,
subchannel_to_child_idx: Default::default(),
children: Default::default(),
pending_work: Default::default(),
runtime,
work_scheduler,
updated: false,
}
}
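
For callers, the new constructor argument means the channel's real work scheduler now has to be threaded through to the ChildManager. A minimal wiring sketch, assuming the WorkScheduler trait exposes the schedule_work(&self) method used elsewhere in this diff; NoopWorkScheduler is a hypothetical stand-in, not part of this change:

use std::sync::Arc;

// Hypothetical stand-in for the channel's real scheduler (illustration only).
#[derive(Debug)]
struct NoopWorkScheduler;

impl WorkScheduler for NoopWorkScheduler {
    fn schedule_work(&self) {
        // A real channel would queue a call to the LB policy's work() method.
    }
}

// let manager = ChildManager::new(my_sharder, default_runtime(), Arc::new(NoopWorkScheduler));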
Expand Down Expand Up @@ -272,6 +278,7 @@ where
let work_scheduler = Arc::new(ChildWorkScheduler {
pending_work: self.pending_work.clone(),
idx: Mutex::new(Some(new_idx)),
work_scheduler: self.work_scheduler.clone(),
});
let policy = builder.build(LbPolicyOptions {
work_scheduler: work_scheduler.clone(),
Expand Down Expand Up @@ -395,8 +402,9 @@ impl ChannelController for WrappedController<'_> {

#[derive(Debug)]
struct ChildWorkScheduler {
work_scheduler: Arc<dyn WorkScheduler>, // The real work scheduler of the channel.
pending_work: Arc<Mutex<HashSet<usize>>>, // Must be taken first for correctness.
idx: Mutex<Option<usize>>, // None if the child is deleted.
}

impl WorkScheduler for ChildWorkScheduler {
@@ -405,6 +413,12 @@ impl WorkScheduler for ChildWorkScheduler {
if let Some(idx) = *self.idx.lock().unwrap() {
pending_work.insert(idx);
}
// Call the real work scheduler with the lock held. This avoids the
// scenario where we schedule work and the channel calls us back before
// the lock can be taken, and the scenario where the child is called
// before the schedule_work call is done due to a concurrent call to
// ChildManager::work().
self.work_scheduler.schedule_work();
}
}
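
The ordering constraint in the comment above is subtle; the following sketch (a hypothetical timeline in comments, not code from this change) shows the interleavings that holding the lock across the schedule_work call rules out:

// If ChildWorkScheduler released pending_work's lock before forwarding:
//
//   child thread                       channel thread
//   ------------                       --------------
//   lock(pending_work); insert(idx);
//   unlock(pending_work);
//                                      ChildManager::work() runs (scheduled
//                                      earlier), drains pending_work, and
//                                      calls the child's work().
//   work_scheduler.schedule_work();    // wakes the channel again for work
//                                      // that has already been done.
//
// Conversely, forwarding before inserting idx could let ChildManager::work()
// run against an empty set and drop the child's request. Holding the lock
// makes "record idx" and "wake the channel" atomic with respect to
// ChildManager::work().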

@@ -414,7 +428,7 @@ mod test {
ChildManager, ChildUpdate, ResolverUpdateSharder,
};
use crate::client::load_balancing::test_utils::{
-self, StubPolicyData, StubPolicyFuncs, TestChannelController, TestEvent,
self, StubPolicyFuncs, TestChannelController, TestEvent, TestWorkScheduler,
};
use crate::client::load_balancing::{
ChannelController, LbPolicy, LbPolicyBuilder, LbState, QueuingPicker, Subchannel,
@@ -424,9 +438,11 @@
use crate::client::service_config::LbConfig;
use crate::client::ConnectivityState;
use crate::rt::default_runtime;
use std::collections::HashMap;
use std::error::Error;
use std::panic;
use std::sync::Arc;
use std::sync::Mutex;
use tokio::sync::mpsc;

// TODO: This needs to be moved to a common place that can be shared between
@@ -492,10 +508,16 @@
) {
test_utils::reg_stub_policy(test_name, funcs);
let (tx_events, rx_events) = mpsc::unbounded_channel::<TestEvent>();
-let tcc = Box::new(TestChannelController { tx_events });
let tcc = Box::new(TestChannelController {
tx_events: tx_events.clone(),
});
let builder: Arc<dyn LbPolicyBuilder> = GLOBAL_LB_REGISTRY.get_policy(test_name).unwrap();
let endpoint_sharder = EndpointSharder { builder };
-let child_manager = ChildManager::new(endpoint_sharder, default_runtime());
let child_manager = ChildManager::new(
endpoint_sharder,
default_runtime(),
Arc::new(TestWorkScheduler { tx_events }),
);
(rx_events, child_manager, tcc)
}

@@ -567,7 +589,6 @@
// Defines the functions resolver_update and subchannel_update to test
// aggregate_states.
fn create_verifying_funcs_for_aggregate_tests() -> StubPolicyFuncs {
-let data = StubPolicyData::new();
StubPolicyFuncs {
// Closure for resolver_update. resolver_update should only receive
// one endpoint and create one subchannel for the endpoint it
@@ -590,6 +611,7 @@
});
},
)),
work: None,
}
}

@@ -759,4 +781,153 @@
ConnectivityState::TransientFailure
);
}

struct ScheduleWorkStubData {
requested_work: bool,
}

fn create_funcs_for_schedule_work_tests(name: &'static str) -> StubPolicyFuncs {
StubPolicyFuncs {
resolver_update: Some(Arc::new(move |data, _update, lbcfg, _controller| {
if data.test_data.is_none() {
data.test_data = Some(Box::new(ScheduleWorkStubData {
requested_work: false,
}));
}
let stubdata = data
.test_data
.as_mut()
.unwrap()
.downcast_mut::<ScheduleWorkStubData>()
.unwrap();
assert!(!stubdata.requested_work);
if lbcfg
.unwrap()
.convert_to::<Mutex<HashMap<&'static str, ()>>>()
.unwrap()
.lock()
.unwrap()
.contains_key(name)
{
stubdata.requested_work = true;
data.lb_policy_options.work_scheduler.schedule_work();
}
Ok(())
})),
subchannel_update: None,
work: Some(Arc::new(move |data, _controller| {
println!("work called for {name}");
let stubdata = data
.test_data
.as_mut()
.unwrap()
.downcast_mut::<ScheduleWorkStubData>()
.unwrap();
stubdata.requested_work = false;
})),
}
}

#[derive(Debug)]
struct ScheduleWorkSharder {
names: Vec<&'static str>,
}

impl ResolverUpdateSharder<()> for ScheduleWorkSharder {
fn shard_update(
&mut self,
resolver_update: ResolverUpdate,
update: Option<&LbConfig>,
) -> Result<impl Iterator<Item = ChildUpdate<()>>, Box<dyn Error + Send + Sync>> {
let mut res = Vec::with_capacity(self.names.len());
for name in &self.names {
let child_policy_builder: Arc<dyn LbPolicyBuilder> =
GLOBAL_LB_REGISTRY.get_policy(name).unwrap();
res.push(ChildUpdate {
child_identifier: (),
child_policy_builder,
child_update: Some((ResolverUpdate::default(), update.cloned())),
});
}
Ok(res.into_iter())
}
}

// Tests that the child manager properly delegates to the children that
// called schedule_work when work is called.
#[tokio::test]
async fn childmanager_schedule_work_works() {
let name1 = "childmanager_schedule_work_works-one";
let name2 = "childmanager_schedule_work_works-two";
test_utils::reg_stub_policy(name1, create_funcs_for_schedule_work_tests(name1));
test_utils::reg_stub_policy(name2, create_funcs_for_schedule_work_tests(name2));

let (tx_events, mut rx_events) = mpsc::unbounded_channel::<TestEvent>();
let mut tcc = TestChannelController {
tx_events: tx_events.clone(),
};

let sharder = ScheduleWorkSharder {
names: vec![name1, name2],
};
let mut child_manager = ChildManager::new(
sharder,
default_runtime(),
Arc::new(TestWorkScheduler { tx_events }),
);

// Configure child one to request work.
let cfg = LbConfig::new(Mutex::new(HashMap::<&'static str, ()>::new()));
let children = cfg
.convert_to::<Mutex<HashMap<&'static str, ()>>>()
.unwrap();
children.lock().unwrap().insert(name1, ());

child_manager
.resolver_update(ResolverUpdate::default(), Some(&cfg), &mut tcc)
.unwrap();

// Confirm that child one has requested work.
match rx_events.recv().await.unwrap() {
TestEvent::ScheduleWork => {}
other => panic!("unexpected event {:?}", other),
};
assert_eq!(child_manager.pending_work.lock().unwrap().len(), 1);
let idx = *child_manager
.pending_work
.lock()
.unwrap()
.iter()
.next()
.unwrap();
assert_eq!(child_manager.children[idx].builder.name(), name1);

// Perform the work call and assert the pending_work set is empty.
child_manager.work(&mut tcc);
assert_eq!(child_manager.pending_work.lock().unwrap().len(), 0);

// Now have both children request work.
children.lock().unwrap().insert(name2, ());

child_manager
.resolver_update(ResolverUpdate::default(), Some(&cfg), &mut tcc)
.unwrap();

// Confirm that both children requested work.
match rx_events.recv().await.unwrap() {
TestEvent::ScheduleWork => {}
other => panic!("unexpected event {:?}", other),
};
assert_eq!(child_manager.pending_work.lock().unwrap().len(), 2);

// Perform the work call and assert the pending_work set is empty.
child_manager.work(&mut tcc);
assert_eq!(child_manager.pending_work.lock().unwrap().len(), 0);

// Perform one final call to resolver_update; the stub's closure asserts
// that both child policies had their work methods called (requested_work
// was reset).
child_manager
.resolver_update(ResolverUpdate::default(), Some(&cfg), &mut tcc)
.unwrap();
}
}
14 changes: 9 additions & 5 deletions grpc/src/client/load_balancing/graceful_switch.rs
@@ -3,7 +3,7 @@
};
use crate::client::load_balancing::{
ChannelController, LbConfig, LbPolicy, LbPolicyBuilder, LbState, ParsedJsonLbConfig,
-Subchannel, SubchannelState, GLOBAL_LB_REGISTRY,
Subchannel, SubchannelState, WorkScheduler, GLOBAL_LB_REGISTRY,
};
use crate::client::name_resolution::ResolverUpdate;
use crate::client::ConnectivityState;
@@ -113,7 +113,7 @@
config: Option<&LbConfig>,
channel_controller: &mut dyn ChannelController,
) -> Result<(), Box<dyn Error + Send + Sync>> {
let res = self
.child_manager
.resolver_update(update, config, channel_controller)?;
[GitHub Actions / clippy warning, line 116: this let-binding has unit value]
self.update_picker(channel_controller);
@@ -150,9 +150,9 @@

impl GracefulSwitchPolicy {
/// Creates a new Graceful Switch policy.
-pub fn new(runtime: Arc<dyn Runtime>) -> Self {
pub fn new(runtime: Arc<dyn Runtime>, work_scheduler: Arc<dyn WorkScheduler>) -> Self {
GracefulSwitchPolicy {
-child_manager: ChildManager::new(UpdateSharder::new(), runtime),
child_manager: ChildManager::new(UpdateSharder::new(), runtime, work_scheduler),
last_update: LbState::initial(),
}
}
@@ -248,7 +248,7 @@
self.child_manager
.resolver_update(ResolverUpdate::default(), Some(config), channel_controller)
.expect("resolver_update with an empty update should not fail");
return Some(state);
[GitHub Actions / clippy warning, line 251: unneeded `return` statement]
}
}

@@ -372,6 +372,7 @@
});
},
)),
work: None,
}
}

@@ -400,9 +401,12 @@
tx_events: tx_events.clone(),
});

-let tcc = Box::new(TestChannelController { tx_events });
let tcc = Box::new(TestChannelController {
tx_events: tx_events.clone(),
});

-let graceful_switch = GracefulSwitchPolicy::new(default_runtime());
let graceful_switch =
GracefulSwitchPolicy::new(default_runtime(), Arc::new(TestWorkScheduler { tx_events }));
(rx_events, Box::new(graceful_switch), tcc)
}

@@ -423,7 +427,7 @@
) -> Arc<dyn Subchannel> {
match rx_events.recv().await.unwrap() {
TestEvent::NewSubchannel(sc) => {
return sc;
[GitHub Actions / clippy warning, line 430: unneeded `return` statement]
}
other => panic!("unexpected event {:?}", other),
};
1 change: 1 addition & 0 deletions grpc/src/client/load_balancing/mod.rs
@@ -59,6 +59,7 @@ pub(crate) use registry::GLOBAL_LB_REGISTRY.

/// A collection of data configured on the channel that is constructing this
/// LbPolicy.
#[derive(Debug)]
pub(crate) struct LbPolicyOptions {
/// A hook into the channel's work scheduler that allows the LbPolicy to
/// request the ability to perform operations on the ChannelController.
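
As a hedged sketch of how a policy typically uses this hook (MyPolicy and on_background_event are illustrative names, not from this codebase):

use std::sync::Arc;

struct MyPolicy {
    // Stored from LbPolicyOptions at build time.
    work_scheduler: Arc<dyn WorkScheduler>,
}

impl MyPolicy {
    fn on_background_event(&self) {
        // Ask the channel to call this policy's work() method soon, with a
        // ChannelController it can use to act on the channel.
        self.work_scheduler.schedule_work();
    }
}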
19 changes: 14 additions & 5 deletions grpc/src/client/load_balancing/test_utils.rs
@@ -170,12 +170,15 @@ type SubchannelUpdateFn = Arc<
+ Sync,
>;

type WorkFn = Arc<dyn Fn(&mut StubPolicyData, &mut dyn ChannelController) + Send + Sync>;

/// This struct holds `LbPolicy` trait stub functions that tests are expected to
/// implement.
#[derive(Clone)]
pub(crate) struct StubPolicyFuncs {
pub resolver_update: Option<ResolverUpdateFn>,
pub subchannel_update: Option<SubchannelUpdateFn>,
pub work: Option<WorkFn>,
}

impl Debug for StubPolicyFuncs {
@@ -187,13 +190,17 @@ impl Debug for StubPolicyFuncs {
/// StubPolicyData holds test data that will be passed to all functions in StubPolicyFuncs.
#[derive(Debug)]
pub(crate) struct StubPolicyData {
pub lb_policy_options: LbPolicyOptions,
pub test_data: Option<Box<dyn Any + Send + Sync>>,
}

impl StubPolicyData {
/// Creates an instance of StubPolicyData.
-pub fn new() -> Self {
-Self { test_data: None }
pub fn new(lb_policy_options: LbPolicyOptions) -> Self {
Self {
test_data: None,
lb_policy_options,
}
}
}

@@ -232,8 +239,10 @@ impl LbPolicy for StubPolicy {
todo!("Implement exit_idle for StubPolicy")
}

-fn work(&mut self, _channel_controller: &mut dyn ChannelController) {
-todo!("Implement work for StubPolicy")
fn work(&mut self, channel_controller: &mut dyn ChannelController) {
if let Some(f) = &self.funcs.work {
f(&mut self.data, channel_controller);
}
}
}
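
With the work hook in place, a test can exercise the stub's work() path directly, mirroring the pattern already used by the tests above. A usage sketch; the closure body is illustrative:

let funcs = StubPolicyFuncs {
    resolver_update: None,
    subchannel_update: None,
    work: Some(Arc::new(|data, _controller| {
        // Runs when the channel dispatches work() to this stub policy; tests
        // can inspect or mutate the per-policy test_data here.
        assert!(data.test_data.is_some());
    })),
};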

@@ -252,7 +261,7 @@

impl LbPolicyBuilder for StubPolicyBuilder {
fn build(&self, options: LbPolicyOptions) -> Box<dyn LbPolicy> {
-let data = StubPolicyData::new();
let data = StubPolicyData::new(options);
Box::new(StubPolicy {
funcs: self.funcs.clone(),
data,