Skip to content

Commit c710afa

Browse files
authored
Rework leaf node extensions to work via parameters rather than as a c… (#196)
* Rework leaf node extensions to work via parameters rather than as a client configuration * Rework key package extensions to work via parameters rather than as a client configuration * Address clippy issues * Apply formatting changes * Remove TODO on Renit Client as WONT DO is the conclusion * Fix unit tests breaking due to grease
1 parent 1a1fa84 commit c710afa

30 files changed

+311
-196
lines changed

mls-rs-uniffi/src/lib.rs

+10-3
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,10 @@ impl Client {
382382
/// See [`mls_rs::Client::generate_key_package_message`] for
383383
/// details.
384384
pub async fn generate_key_package_message(&self) -> Result<Message, Error> {
385-
let message = self.inner.generate_key_package_message().await?;
385+
let message = self
386+
.inner
387+
.generate_key_package_message(Default::default(), Default::default())
388+
.await?;
386389
Ok(message.into())
387390
}
388391

@@ -403,10 +406,14 @@ impl Client {
403406
let inner = match group_id {
404407
Some(group_id) => {
405408
self.inner
406-
.create_group_with_id(group_id, extensions)
409+
.create_group_with_id(group_id, extensions, Default::default())
410+
.await?
411+
}
412+
None => {
413+
self.inner
414+
.create_group(extensions, Default::default())
407415
.await?
408416
}
409-
None => self.inner.create_group(extensions).await?,
410417
};
411418
Ok(Group {
412419
inner: Arc::new(Mutex::new(inner)),

mls-rs/benches/group_add.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ use mls_rs_crypto_openssl::OpensslCryptoProvider;
1616

1717
fn bench(c: &mut Criterion) {
1818
let alice = make_client("alice")
19-
.create_group(Default::default())
19+
.create_group(Default::default(), Default::default())
2020
.unwrap();
2121

2222
const MAX_ADD_COUNT: usize = 1000;
2323

2424
let key_packages = (0..MAX_ADD_COUNT)
2525
.map(|i| {
2626
make_client(&format!("bob-{i}"))
27-
.generate_key_package_message()
27+
.generate_key_package_message(Default::default(), Default::default())
2828
.unwrap()
2929
})
3030
.collect::<Vec<_>>();

mls-rs/examples/basic_server_usage.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,9 @@ fn main() -> Result<(), MlsError> {
140140
let bob = make_client("bob")?;
141141

142142
// Alice creates a group with bob
143-
let mut alice_group = alice.create_group(ExtensionList::default())?;
144-
let bob_key_package = bob.generate_key_package_message()?;
143+
let mut alice_group = alice.create_group(ExtensionList::default(), Default::default())?;
144+
let bob_key_package =
145+
bob.generate_key_package_message(Default::default(), Default::default())?;
145146

146147
let welcome = &alice_group
147148
.commit_builder()

mls-rs/examples/basic_usage.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ fn main() -> Result<(), MlsError> {
4444
let bob = make_client(crypto_provider.clone(), "bob")?;
4545

4646
// Alice creates a new group.
47-
let mut alice_group = alice.create_group(ExtensionList::default())?;
47+
let mut alice_group = alice.create_group(ExtensionList::default(), Default::default())?;
4848

4949
// Bob generates a key package that Alice needs to add Bob to the group.
50-
let bob_key_package = bob.generate_key_package_message()?;
50+
let bob_key_package =
51+
bob.generate_key_package_message(Default::default(), Default::default())?;
5152

5253
// Alice issues a commit that adds Bob to the group.
5354
let alice_commit = alice_group

mls-rs/examples/custom.rs

+6-3
Original file line numberDiff line numberDiff line change
@@ -369,11 +369,13 @@ fn main() -> Result<(), CustomError> {
369369
let roster = vec![alice.credential];
370370
context_extensions.set_from(RosterExtension { roster })?;
371371

372-
let mut alice_tablet_group = make_client(alice_tablet)?.create_group(context_extensions)?;
372+
let mut alice_tablet_group =
373+
make_client(alice_tablet)?.create_group(context_extensions, Default::default())?;
373374

374375
// Alice can add her other device
375376
let alice_pc_client = make_client(alice_pc)?;
376-
let key_package = alice_pc_client.generate_key_package_message()?;
377+
let key_package =
378+
alice_pc_client.generate_key_package_message(Default::default(), Default::default())?;
377379

378380
let welcome = alice_tablet_group
379381
.commit_builder()
@@ -387,7 +389,8 @@ fn main() -> Result<(), CustomError> {
387389

388390
// Alice cannot add bob's devices yet
389391
let bob_tablet_client = make_client(bob_tablet)?;
390-
let key_package = bob_tablet_client.generate_key_package_message()?;
392+
let key_package =
393+
bob_tablet_client.generate_key_package_message(Default::default(), Default::default())?;
391394

392395
let res = alice_tablet_group
393396
.commit_builder()

mls-rs/examples/large_group.rs

+6-4
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,16 @@ fn make_groups_best_case<P: CryptoProvider + Clone>(
5858
) -> Result<Vec<Group<impl MlsConfig>>, MlsError> {
5959
let bob_client = make_client(crypto_provider.clone(), &make_name(0))?;
6060

61-
let bob_group = bob_client.create_group(Default::default())?;
61+
let bob_group = bob_client.create_group(Default::default(), Default::default())?;
6262

6363
let mut groups = vec![bob_group];
6464

6565
for i in 0..(num_groups - 1) {
6666
let bob_client = make_client(crypto_provider.clone(), &make_name(i + 1))?;
6767

6868
// The new client generates a key package.
69-
let bob_kpkg = bob_client.generate_key_package_message()?;
69+
let bob_kpkg =
70+
bob_client.generate_key_package_message(Default::default(), Default::default())?;
7071

7172
// Last group sends a commit adding the new client to the group.
7273
let commit = groups
@@ -100,7 +101,7 @@ fn make_groups_worst_case<P: CryptoProvider + Clone>(
100101
) -> Result<Vec<Group<impl MlsConfig>>, MlsError> {
101102
let alice_client = make_client(crypto_provider.clone(), &make_name(0))?;
102103

103-
let mut alice_group = alice_client.create_group(Default::default())?;
104+
let mut alice_group = alice_client.create_group(Default::default(), Default::default())?;
104105

105106
let bob_clients = (0..(num_groups - 1))
106107
.map(|i| make_client(crypto_provider.clone(), &make_name(i + 1)))
@@ -110,7 +111,8 @@ fn make_groups_worst_case<P: CryptoProvider + Clone>(
110111
let mut commit_builder = alice_group.commit_builder();
111112

112113
for bob_client in &bob_clients {
113-
let bob_kpkg = bob_client.generate_key_package_message()?;
114+
let bob_kpkg =
115+
bob_client.generate_key_package_message(Default::default(), Default::default())?;
114116
commit_builder = commit_builder.add_member(bob_kpkg)?;
115117
}
116118

mls-rs/examples/x509.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ fn main() {
3131
.signing_identity(signing_identity, secret_key, CIPHERSUITE)
3232
.build();
3333

34-
let mut alice_group = alice_client.create_group(Default::default()).unwrap();
34+
let mut alice_group = alice_client
35+
.create_group(Default::default(), Default::default())
36+
.unwrap();
3537

3638
alice_group.commit(Vec::new()).unwrap();
3739
alice_group.apply_pending_commit().unwrap();

mls-rs/src/client.rs

+51-15
Original file line numberDiff line numberDiff line change
@@ -429,12 +429,23 @@ where
429429
///
430430
/// A key package message may only be used once.
431431
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
432-
pub async fn generate_key_package_message(&self) -> Result<MlsMessage, MlsError> {
433-
Ok(self.generate_key_package().await?.key_package_message())
432+
pub async fn generate_key_package_message(
433+
&self,
434+
key_package_extensions: ExtensionList,
435+
leaf_node_extensions: ExtensionList,
436+
) -> Result<MlsMessage, MlsError> {
437+
Ok(self
438+
.generate_key_package(key_package_extensions, leaf_node_extensions)
439+
.await?
440+
.key_package_message())
434441
}
435442

436443
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
437-
async fn generate_key_package(&self) -> Result<KeyPackageGeneration, MlsError> {
444+
async fn generate_key_package(
445+
&self,
446+
key_package_extensions: ExtensionList,
447+
leaf_node_extensions: ExtensionList,
448+
) -> Result<KeyPackageGeneration, MlsError> {
438449
let (signing_identity, cipher_suite) = self.signing_identity()?;
439450

440451
let cipher_suite_provider = self
@@ -454,8 +465,8 @@ where
454465
.generate(
455466
self.config.lifetime(),
456467
self.config.capabilities(),
457-
self.config.key_package_extensions(),
458-
self.config.leaf_node_extensions(),
468+
key_package_extensions,
469+
leaf_node_extensions,
459470
)
460471
.await?;
461472

@@ -486,6 +497,7 @@ where
486497
&self,
487498
group_id: Vec<u8>,
488499
group_context_extensions: ExtensionList,
500+
leaf_node_extensions: ExtensionList,
489501
) -> Result<Group<C>, MlsError> {
490502
let (signing_identity, cipher_suite) = self.signing_identity()?;
491503

@@ -496,6 +508,7 @@ where
496508
self.version,
497509
signing_identity.clone(),
498510
group_context_extensions,
511+
leaf_node_extensions,
499512
self.signer()?.clone(),
500513
)
501514
.await
@@ -510,6 +523,7 @@ where
510523
pub async fn create_group(
511524
&self,
512525
group_context_extensions: ExtensionList,
526+
leaf_node_extensions: ExtensionList,
513527
) -> Result<Group<C>, MlsError> {
514528
let (signing_identity, cipher_suite) = self.signing_identity()?;
515529

@@ -520,6 +534,7 @@ where
520534
self.version,
521535
signing_identity.clone(),
522536
group_context_extensions,
537+
leaf_node_extensions,
523538
self.signer()?.clone(),
524539
)
525540
.await
@@ -674,6 +689,8 @@ where
674689
group_info: &MlsMessage,
675690
tree_data: Option<crate::group::ExportedTree<'_>>,
676691
authenticated_data: Vec<u8>,
692+
key_package_extensions: ExtensionList,
693+
leaf_node_extensions: ExtensionList,
677694
) -> Result<MlsMessage, MlsError> {
678695
let protocol_version = group_info.version;
679696

@@ -702,7 +719,10 @@ where
702719
)
703720
.await?;
704721

705-
let key_package = self.generate_key_package().await?.key_package;
722+
let key_package = self
723+
.generate_key_package(key_package_extensions, leaf_node_extensions)
724+
.await?
725+
.key_package;
706726

707727
(key_package.cipher_suite == cipher_suite)
708728
.then_some(())
@@ -745,11 +765,6 @@ where
745765
.ok_or(MlsError::SignerNotFound)
746766
}
747767

748-
/// Returns key package extensions used by this client
749-
pub fn key_package_extensions(&self) -> ExtensionList {
750-
self.config.key_package_extensions()
751-
}
752-
753768
/// The [KeyPackageStorage] that this client was configured to use.
754769
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
755770
pub fn key_package_store(&self) -> <C as ClientConfig>::KeyPackageRepository {
@@ -793,14 +808,24 @@ pub(crate) mod test_utils {
793808
cipher_suite: CipherSuite,
794809
identity: &str,
795810
) -> (Client<TestClientConfig>, MlsMessage) {
796-
test_client_with_key_pkg_custom(protocol_version, cipher_suite, identity, |_| {}).await
811+
test_client_with_key_pkg_custom(
812+
protocol_version,
813+
cipher_suite,
814+
identity,
815+
Default::default(),
816+
Default::default(),
817+
|_| {},
818+
)
819+
.await
797820
}
798821

799822
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
800823
pub async fn test_client_with_key_pkg_custom<F>(
801824
protocol_version: ProtocolVersion,
802825
cipher_suite: CipherSuite,
803826
identity: &str,
827+
key_package_extensions: ExtensionList,
828+
leaf_node_extensions: ExtensionList,
804829
mut config: F,
805830
) -> (Client<TestClientConfig>, MlsMessage)
806831
where
@@ -816,7 +841,10 @@ pub(crate) mod test_utils {
816841

817842
config(&mut client.config);
818843

819-
let key_package = client.generate_key_package_message().await.unwrap();
844+
let key_package = client
845+
.generate_key_package_message(key_package_extensions, leaf_node_extensions)
846+
.await
847+
.unwrap();
820848

821849
(client, key_package)
822850
}
@@ -863,7 +891,10 @@ mod tests {
863891
.build();
864892

865893
// TODO: Tests around extensions
866-
let key_package = client.generate_key_package_message().await.unwrap();
894+
let key_package = client
895+
.generate_key_package_message(Default::default(), Default::default())
896+
.await
897+
.unwrap();
867898

868899
assert_eq!(key_package.version, protocol_version);
869900

@@ -902,6 +933,8 @@ mod tests {
902933
&alice_group.group_info_message(true).await.unwrap(),
903934
None,
904935
vec![],
936+
Default::default(),
937+
Default::default(),
905938
)
906939
.await
907940
.unwrap();
@@ -1047,7 +1080,10 @@ mod tests {
10471080
.signing_identity(alice_identity.clone(), secret_key, TEST_CIPHER_SUITE)
10481081
.build();
10491082

1050-
let msg = alice.generate_key_package_message().await.unwrap();
1083+
let msg = alice
1084+
.generate_key_package_message(Default::default(), Default::default())
1085+
.await
1086+
.unwrap();
10511087
let res = alice.commit_external(msg).await.map(|_| ());
10521088

10531089
assert_matches!(res, Err(MlsError::UnexpectedMessageType));

0 commit comments

Comments
 (0)