Skip to content

Commit 5842cc9

Browse files
hlopkocopybara-github
authored andcommitted
Implement support for messages as map values
PiperOrigin-RevId: 605581725
1 parent 7da29c6 commit 5842cc9

File tree

11 files changed

+675
-23
lines changed

11 files changed

+675
-23
lines changed

rust/cpp.rs

+7
Original file line numberDiff line numberDiff line change
@@ -402,10 +402,16 @@ pub struct InnerMapMut<'msg> {
402402
_phantom: PhantomData<&'msg ()>,
403403
}
404404

405+
#[doc(hidden)]
405406
impl<'msg> InnerMapMut<'msg> {
406407
pub fn new(_private: Private, raw: RawMap) -> Self {
407408
InnerMapMut { raw, _phantom: PhantomData }
408409
}
410+
411+
#[doc(hidden)]
412+
pub fn as_raw(&self, _private: Private) -> RawMap {
413+
self.raw
414+
}
409415
}
410416

411417
/// An untyped iterator in a map, produced via `.cbegin()` on a typed map.
@@ -547,6 +553,7 @@ macro_rules! impl_ProxiedInMapValue_for_non_generated_value_types {
547553
let ffi_key = $to_ffi_key(key);
548554
let mut ffi_value = MaybeUninit::uninit();
549555
let found = unsafe { [< __rust_proto_thunk__Map_ $key_t _ $t _get >](map.as_raw(Private), ffi_key, ffi_value.as_mut_ptr()) };
556+
550557
if !found {
551558
return None;
552559
}

rust/map.rs

+14
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,12 @@ pub struct MapMut<'msg, K: ?Sized, V: ?Sized> {
4242
_phantom: PhantomData<(&'msg mut K, &'msg mut V)>,
4343
}
4444

45+
impl<'msg, K: ?Sized, V: ?Sized> MapMut<'msg, K, V> {
46+
pub fn inner(&self, _private: Private) -> InnerMapMut {
47+
self.inner
48+
}
49+
}
50+
4551
unsafe impl<'msg, K: ?Sized, V: ?Sized> Sync for MapMut<'msg, K, V> {}
4652

4753
impl<'msg, K: ?Sized, V: ?Sized> std::fmt::Debug for MapMut<'msg, K, V> {
@@ -178,6 +184,14 @@ where
178184
pub unsafe fn from_inner(_private: Private, inner: InnerMapMut<'static>) -> Self {
179185
Self { inner, _phantom: PhantomData }
180186
}
187+
188+
pub fn as_raw(&self, _private: Private) -> RawMap {
189+
self.inner.as_raw(Private)
190+
}
191+
192+
pub fn inner(&self, _private: Private) -> InnerMapMut<'static> {
193+
self.inner
194+
}
181195
}
182196

183197
#[doc(hidden)]

rust/test/shared/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,7 @@ rust_test(
446446
],
447447
deps = [
448448
"//rust/test:map_unittest_cc_rust_proto",
449+
"//rust/test:unittest_cc_rust_proto",
449450
"@crate_index//:googletest",
450451
],
451452
)
@@ -462,6 +463,7 @@ rust_test(
462463
],
463464
deps = [
464465
"//rust/test:map_unittest_upb_rust_proto",
466+
"//rust/test:unittest_upb_rust_proto",
465467
"@crate_index//:googletest",
466468
],
467469
)

rust/test/shared/accessors_map_test.rs

+152-1
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
// https://developers.google.com/open-source/licenses/bsd
77

88
use googletest::prelude::*;
9-
use map_unittest_proto::TestMap;
9+
use map_unittest_proto::{TestMap, TestMapWithMessages};
1010
use paste::paste;
1111
use std::collections::HashMap;
12+
use unittest_proto::TestAllTypes;
1213

1314
macro_rules! generate_map_primitives_tests {
1415
(
@@ -145,3 +146,153 @@ fn test_bytes_and_string_copied() {
145146
);
146147
assert_that!(msg.map_int32_bytes_mut().get(1).unwrap(), eq(b"world"));
147148
}
149+
150+
macro_rules! generate_map_with_msg_values_tests {
151+
(
152+
$(($k_field:ident, $k_nonzero:expr, $k_other:expr $(,)?)),*
153+
$(,)?
154+
) => {
155+
paste! { $(
156+
#[test]
157+
fn [< test_map_ $k_field _all_types >]() {
158+
// We need to cover the following upb/c++ thunks:
159+
// TODO - b/323883851: Add test once Map::new is public.
160+
// * new
161+
// * free (covered implicitly by drop)
162+
// * clear, size, insert, get, remove, iter, iter_next (all covered below)
163+
let mut msg = TestMapWithMessages::new();
164+
assert_that!(msg.[< map_ $k_field _all_types >]().len(), eq(0));
165+
assert_that!(msg.[< map_ $k_field _all_types >]().get($k_nonzero), none());
166+
// this block makes sure `insert` copies/moves, not borrows.
167+
{
168+
let mut msg_val = TestAllTypes::new();
169+
msg_val.optional_int32_mut().set(1001);
170+
assert_that!(
171+
msg
172+
.[< map_ $k_field _all_types_mut >]()
173+
.insert($k_nonzero, msg_val.as_view()),
174+
eq(true),
175+
"`insert` should return true when key was inserted."
176+
);
177+
assert_that!(
178+
msg
179+
.[< map_ $k_field _all_types_mut >]()
180+
.insert($k_nonzero, msg_val.as_view()),
181+
eq(false),
182+
"`insert` should return false when key was already present."
183+
184+
);
185+
}
186+
187+
assert_that!(
188+
msg.[< map_ $k_field _all_types >]().len(),
189+
eq(1),
190+
"`size` thunk should return correct len.");
191+
192+
assert_that!(
193+
msg.[< map_ $k_field _all_types >]().get($k_nonzero),
194+
some(anything()),
195+
"`get` should return Some when key present.");
196+
assert_that!(
197+
msg.[< map_ $k_field _all_types >]().get($k_nonzero).unwrap().optional_int32(),
198+
eq(1001));
199+
assert_that!(
200+
msg.[< map_ $k_field _all_types >]().get($k_other),
201+
none(),
202+
"`get` should return None when key missing.");
203+
204+
msg.[< map_ $k_field _all_types_mut >]().clear();
205+
assert_that!(
206+
msg.[< map_ $k_field _all_types >]().len(),
207+
eq(0),
208+
"`clear` should drop all elements.");
209+
210+
211+
assert_that!(
212+
msg.[< map_ $k_field _all_types_mut >]().insert($k_nonzero, TestAllTypes::new().as_view()),
213+
eq(true));
214+
assert_that!(
215+
msg.[< map_ $k_field _all_types_mut >]().remove($k_nonzero),
216+
eq(true),
217+
"`remove` should return true when key was present.");
218+
assert_that!(msg.[< map_ $k_field _all_types >]().len(), eq(0));
219+
assert_that!(
220+
msg.[< map_ $k_field _all_types_mut >]().remove($k_nonzero),
221+
eq(false),
222+
"`remove` should return false when key was missing.");
223+
224+
// empty iter
225+
// assert_that!(
226+
// msg.[< map_ $k_field _all_types_mut >]().iter().collect::<Vec<_>>(),
227+
// elements_are![],
228+
// "`iter` should work when empty."
229+
// );
230+
assert_that!(
231+
msg.[< map_ $k_field _all_types_mut >]().keys().collect::<Vec<_>>(),
232+
elements_are![],
233+
"`iter` should work when empty."
234+
);
235+
assert_that!(
236+
msg.[< map_ $k_field _all_types_mut >]().values().collect::<Vec<_>>(),
237+
elements_are![],
238+
"`iter` should work when empty."
239+
);
240+
241+
// single element iter
242+
assert_that!(
243+
msg.[< map_ $k_field _all_types_mut >]().insert($k_nonzero, TestAllTypes::new().as_view()),
244+
eq(true));
245+
// assert_that!(
246+
// msg.[< map_ $k_field _all_types >]().iter().collect::<Vec<_>>(),
247+
// unordered_elements_are![
248+
// eq(($k_nonzero, anything())),
249+
// ]
250+
// );
251+
assert_that!(
252+
msg.[< map_ $k_field _all_types >]().keys().collect::<Vec<_>>(),
253+
unordered_elements_are![eq($k_nonzero)]
254+
);
255+
assert_that!(
256+
msg.[< map_ $k_field _all_types >]().values().collect::<Vec<_>>().len(),
257+
eq(1));
258+
259+
260+
// 2 element iter
261+
assert_that!(
262+
msg
263+
.[< map_ $k_field _all_types_mut >]()
264+
.insert($k_other, TestAllTypes::new().as_view()),
265+
eq(true));
266+
267+
assert_that!(
268+
msg.[< map_ $k_field _all_types >]().iter().collect::<Vec<_>>().len(),
269+
eq(2)
270+
);
271+
assert_that!(
272+
msg.[< map_ $k_field _all_types >]().keys().collect::<Vec<_>>(),
273+
unordered_elements_are![eq($k_nonzero), eq($k_other)]
274+
);
275+
assert_that!(
276+
msg.[< map_ $k_field _all_types >]().values().collect::<Vec<_>>().len(),
277+
eq(2)
278+
);
279+
}
280+
)* }
281+
}
282+
}
283+
284+
generate_map_with_msg_values_tests!(
285+
(int32, 1i32, 2i32),
286+
(int64, 1i64, 2i64),
287+
(uint32, 1u32, 2u32),
288+
(uint64, 1u64, 2u64),
289+
(sint32, 1, 2),
290+
(sint64, 1, 2),
291+
(fixed32, 1u32, 2u32),
292+
(fixed64, 1u64, 2u64),
293+
(sfixed32, 1, 2),
294+
(sfixed64, 1, 2),
295+
// TODO - b/324468833: fix msan failure
296+
// (bool, true, false),
297+
(string, "foo", "bar"),
298+
);

rust/upb.rs

+25-14
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ pub struct Arena {
4646

4747
extern "C" {
4848
// `Option<NonNull<T: Sized>>` is ABI-compatible with `*mut T`
49-
fn upb_Arena_New() -> Option<RawArena>;
50-
fn upb_Arena_Free(arena: RawArena);
51-
fn upb_Arena_Malloc(arena: RawArena, size: usize) -> *mut u8;
52-
fn upb_Arena_Realloc(arena: RawArena, ptr: *mut u8, old: usize, new: usize) -> *mut u8;
49+
pub fn upb_Arena_New() -> Option<RawArena>;
50+
pub fn upb_Arena_Free(arena: RawArena);
51+
pub fn upb_Arena_Malloc(arena: RawArena, size: usize) -> *mut u8;
52+
pub fn upb_Arena_Realloc(arena: RawArena, ptr: *mut u8, old: usize, new: usize) -> *mut u8;
5353
}
5454

5555
impl Arena {
@@ -716,13 +716,24 @@ pub struct InnerMapMut<'msg> {
716716
_phantom: PhantomData<&'msg Arena>,
717717
}
718718

719+
#[doc(hidden)]
719720
impl<'msg> InnerMapMut<'msg> {
720721
pub fn new(_private: Private, raw: RawMap, raw_arena: RawArena) -> Self {
721722
InnerMapMut { raw, raw_arena, _phantom: PhantomData }
722723
}
724+
725+
#[doc(hidden)]
726+
pub fn as_raw(&self, _private: Private) -> RawMap {
727+
self.raw
728+
}
729+
730+
#[doc(hidden)]
731+
pub fn raw_arena(&self, _private: Private) -> RawArena {
732+
self.raw_arena
733+
}
723734
}
724735

725-
trait UpbTypeConversions: Proxied {
736+
pub trait UpbTypeConversions: Proxied {
726737
fn upb_type() -> UpbCType;
727738
fn to_message_value(val: View<'_, Self>) -> upb_MessageValue;
728739
fn empty_message_value() -> upb_MessageValue;
@@ -858,7 +869,7 @@ impl RawMapIter {
858869
/// # Safety
859870
/// - `self.map` must be valid, and remain valid while the return value is
860871
/// in use.
861-
pub(crate) unsafe fn next_unchecked(
872+
pub unsafe fn next_unchecked(
862873
&mut self,
863874
_private: Private,
864875
) -> Option<(upb_MessageValue, upb_MessageValue)> {
@@ -986,7 +997,7 @@ impl_ProxiedInMapValue_for_key_types!(i32, u32, i64, u64, bool, ProtoStr);
986997

987998
#[repr(C)]
988999
#[allow(dead_code)]
989-
enum upb_MapInsertStatus {
1000+
pub enum upb_MapInsertStatus {
9901001
Inserted = 0,
9911002
Replaced = 1,
9921003
OutOfMemory = 2,
@@ -1019,25 +1030,25 @@ pub unsafe fn upb_Map_InsertAndReturnIfInserted(
10191030
}
10201031

10211032
extern "C" {
1022-
fn upb_Map_New(arena: RawArena, key_type: UpbCType, value_type: UpbCType) -> RawMap;
1023-
fn upb_Map_Size(map: RawMap) -> usize;
1024-
fn upb_Map_Insert(
1033+
pub fn upb_Map_New(arena: RawArena, key_type: UpbCType, value_type: UpbCType) -> RawMap;
1034+
pub fn upb_Map_Size(map: RawMap) -> usize;
1035+
pub fn upb_Map_Insert(
10251036
map: RawMap,
10261037
key: upb_MessageValue,
10271038
value: upb_MessageValue,
10281039
arena: RawArena,
10291040
) -> upb_MapInsertStatus;
1030-
fn upb_Map_Get(map: RawMap, key: upb_MessageValue, value: *mut upb_MessageValue) -> bool;
1031-
fn upb_Map_Delete(
1041+
pub fn upb_Map_Get(map: RawMap, key: upb_MessageValue, value: *mut upb_MessageValue) -> bool;
1042+
pub fn upb_Map_Delete(
10321043
map: RawMap,
10331044
key: upb_MessageValue,
10341045
removed_value: *mut upb_MessageValue,
10351046
) -> bool;
1036-
fn upb_Map_Clear(map: RawMap);
1047+
pub fn upb_Map_Clear(map: RawMap);
10371048

10381049
static __rust_proto_kUpb_Map_Begin: usize;
10391050

1040-
fn upb_Map_Next(
1051+
pub fn upb_Map_Next(
10411052
map: RawMap,
10421053
key: *mut upb_MessageValue,
10431054
value: *mut upb_MessageValue,

src/google/protobuf/compiler/rust/accessors/accessors.cc

+1-3
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,8 @@ std::unique_ptr<AccessorGenerator> AccessorGeneratorFor(
3636
auto value_type = field.message_type()->map_value()->type();
3737
switch (value_type) {
3838
case FieldDescriptor::TYPE_ENUM:
39-
case FieldDescriptor::TYPE_MESSAGE:
4039
return std::make_unique<UnsupportedField>(
41-
"Maps with values of type enum and message are not "
42-
"supported");
40+
"Maps with values of type enum are not supported");
4341
default:
4442
return std::make_unique<Map>();
4543
}

src/google/protobuf/compiler/rust/accessors/map.cc

+15-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
// license that can be found in the LICENSE file or at
66
// https://developers.google.com/open-source/licenses/bsd
77

8+
#include <string>
9+
810
#include "google/protobuf/compiler/cpp/helpers.h"
911
#include "google/protobuf/compiler/rust/accessors/accessor_case.h"
1012
#include "google/protobuf/compiler/rust/accessors/accessor_generator.h"
@@ -115,13 +117,24 @@ void Map::InExternC(Context& ctx, const FieldDescriptor& field) const {
115117
)rs");
116118
}
117119

120+
std::string MapElementTypeName(FieldDescriptor::CppType cpp_type,
121+
const Descriptor* message_type) {
122+
if (cpp_type == FieldDescriptor::CPPTYPE_MESSAGE ||
123+
cpp_type == FieldDescriptor::CPPTYPE_ENUM) {
124+
return cpp::QualifiedClassName(message_type);
125+
}
126+
return cpp::PrimitiveTypeName(cpp_type);
127+
}
128+
118129
void Map::InThunkCc(Context& ctx, const FieldDescriptor& field) const {
119130
ctx.Emit(
120131
{{"field", cpp::FieldName(&field)},
121132
{"Key",
122-
cpp::PrimitiveTypeName(field.message_type()->map_key()->cpp_type())},
133+
MapElementTypeName(field.message_type()->map_key()->cpp_type(),
134+
field.message_type()->map_key()->message_type())},
123135
{"Value",
124-
cpp::PrimitiveTypeName(field.message_type()->map_value()->cpp_type())},
136+
MapElementTypeName(field.message_type()->map_value()->cpp_type(),
137+
field.message_type()->map_value()->message_type())},
125138
{"QualifiedMsg", cpp::QualifiedClassName(field.containing_type())},
126139
{"getter_thunk", ThunkName(ctx, field, "get")},
127140
{"getter_mut_thunk", ThunkName(ctx, field, "get_mut")},

0 commit comments

Comments
 (0)