diff --git a/Cargo.toml b/Cargo.toml index cac2b00b..25b48d3c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ keywords = ["deep-learning", "language", "model", "rwkv"] license = "MIT OR Apache-2.0" repository = "https://github.com/cryscan/web-rwkv" rust-version = "1.83.0" -version = "0.10.16" +version = "0.10.17" [package] authors.workspace = true @@ -24,23 +24,26 @@ name = "web-rwkv" repository.workspace = true version.workspace = true +[workspace.dependencies] +serde = "1.0" + [dependencies] ahash = "0.8" -bytemuck = { version = "1.23", features = ["extern_crate_alloc"] } +bytemuck = { version = "1.24", features = ["extern_crate_alloc"] } derive-getters = "0.5" document-features = "0.2.8" embed-doc-image = "0.1.4" flume = "0.11" futures = "0.3" gpp = "0.6.2" -half = { version = "2.2", features = ["bytemuck", "serde"] } +half = { version = "2.7", features = ["bytemuck", "serde"] } instant = { version = "0.1", features = ["inaccurate", "wasm-bindgen"] } itertools = "0.14" log = "0.4" -regex = "1.11" +regex = "1.12" rustc-hash = "2.1" safetensors = "0.6" -serde = { version = "1.0", features = ["derive", "rc"] } +serde = { workspace = true, features = ["derive", "rc"] } serde_bytes = "0.11" serde_json = "1.0" serde_variant = "0.1.3" @@ -52,7 +55,7 @@ tracing-tracy = { version = "0.11.4", optional = true } trait-variant = "0.1" uid = "0.1" wasm-bindgen = "0.2" -wgpu = "26.0" +wgpu = "27.0" [dependencies.web-rwkv-derive] path = "crates/web-rwkv-derive" @@ -62,7 +65,7 @@ version = "0.10" default-features = false features = ["macros", "rt", "sync", "time"] optional = true -version = "1.47" +version = "1.48" [dev-dependencies] anyhow = "1.0" @@ -73,7 +76,7 @@ dialoguer = "0.12.0" fastrand = "2.3" memmap2 = "0.9" ratatui = { version = "0.29", features = ["all-widgets"] } -simple_logger = { version = "5.0.0", features = ["stderr"] } +simple_logger = { version = "5.1", features = ["stderr"] } tokio = { version = "1.41", features = ["full"] } [features] diff --git a/crates/web-rwkv-derive/Cargo.toml b/crates/web-rwkv-derive/Cargo.toml index 9b290657..a0bb7952 100644 --- a/crates/web-rwkv-derive/Cargo.toml +++ b/crates/web-rwkv-derive/Cargo.toml @@ -17,5 +17,15 @@ proc-macro2 = "1" quote = "1" syn = "2" +[dependencies.serde] +workspace = true + +[build-dependencies] +cargo_metadata = "0.23" + +[features] +default = [] +deserialize_in_place = [] + [lib] proc-macro = true diff --git a/crates/web-rwkv-derive/build.rs b/crates/web-rwkv-derive/build.rs new file mode 100644 index 00000000..9d5f0de8 --- /dev/null +++ b/crates/web-rwkv-derive/build.rs @@ -0,0 +1,75 @@ +use cargo_metadata::MetadataCommand; +use std::{ + collections::{HashSet, VecDeque}, + path::Path, +}; + +fn main() { + let metadata = MetadataCommand::new() + .exec() + .expect("failed to obtain cargo metadata"); + + // locate the current package using CARGO_MANIFEST_DIR + let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR not set"); + let manifest_path = Path::new(&manifest_dir).join("Cargo.toml"); + let manifest_path_str = manifest_path.to_str().expect("path not UTF-8"); + + let current_package = metadata + .packages + .iter() + .find(|p| p.manifest_path.as_str() == manifest_path_str) + .expect("current package not found in metadata"); + + let resolve = metadata.resolve.as_ref().expect("resolve graph missing"); + + let current_node = resolve + .nodes + .iter() + .find(|node| node.id == current_package.id) + .expect("current node not found in resolve graph"); + + // perform BFS to find the first occurrence of the "serde" crate + let mut queue = VecDeque::new(); + let mut visited = HashSet::new(); + queue.push_back(current_node); + visited.insert(¤t_node.id); + + let serde_version = loop { + let node = queue.pop_front().expect("dependency graph exhausted"); + + // check dependencies of the current node + let mut found = None; + for id in &node.dependencies { + if visited.contains(id) { + continue; + } + visited.insert(id); + + let node = resolve + .nodes + .iter() + .find(|n| &n.id == id) + .expect("dependency node not found"); + let package = metadata + .packages + .iter() + .find(|p| p.id == node.id) + .expect("dependency package not found"); + + if package.name == "serde" { + found = Some(package.version.clone()); + break; + } + queue.push_back(node); + } + + if let Some(version) = found { + break version; + } + }; + + println!( + "cargo:rustc-env=SERDE_PATCH_VERSION={}", + serde_version.patch + ); +} diff --git a/crates/web-rwkv-derive/src/serde/bound.rs b/crates/web-rwkv-derive/src/serde/bound.rs index 4a6ea6da..41c96399 100644 --- a/crates/web-rwkv-derive/src/serde/bound.rs +++ b/crates/web-rwkv-derive/src/serde/bound.rs @@ -144,7 +144,7 @@ pub fn with_bound( fn visit_type(&mut self, ty: &'ast syn::Type) { match ty { - #![cfg_attr(all(test), deny(non_exhaustive_omitted_patterns))] + #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))] syn::Type::Array(ty) => self.visit_type(&ty.elem), syn::Type::BareFn(ty) => { for arg in &ty.inputs { @@ -196,7 +196,7 @@ pub fn with_bound( syn::PathArguments::AngleBracketed(arguments) => { for arg in &arguments.args { match arg { - #![cfg_attr(all(test), deny(non_exhaustive_omitted_patterns))] + #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))] syn::GenericArgument::Type(arg) => self.visit_type(arg), syn::GenericArgument::AssocType(arg) => self.visit_type(&arg.ty), syn::GenericArgument::Lifetime(_) @@ -225,9 +225,11 @@ pub fn with_bound( fn visit_type_param_bound(&mut self, bound: &'ast syn::TypeParamBound) { match bound { - #![cfg_attr(all(test), deny(non_exhaustive_omitted_patterns))] + #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))] syn::TypeParamBound::Trait(bound) => self.visit_path(&bound.path), - syn::TypeParamBound::Lifetime(_) | syn::TypeParamBound::Verbatim(_) => {} + syn::TypeParamBound::Lifetime(_) + | syn::TypeParamBound::PreciseCapture(_) + | syn::TypeParamBound::Verbatim(_) => {} _ => {} } } diff --git a/crates/web-rwkv-derive/src/serde/de.rs b/crates/web-rwkv-derive/src/serde/de.rs index 39a49332..044a9aa3 100644 --- a/crates/web-rwkv-derive/src/serde/de.rs +++ b/crates/web-rwkv-derive/src/serde/de.rs @@ -1,8 +1,10 @@ -use crate::serde::fragment::{Expr, Fragment, Match, Stmts}; +use crate::serde::deprecated::allow_deprecated; +use crate::serde::fragment::{Expr, Fragment, Stmts}; use crate::serde::internals::ast::{Container, Data, Field, Style, Variant}; +use crate::serde::internals::name::Name; use crate::serde::internals::{attr, replace_receiver, ungroup, Ctxt, Derive}; -use crate::serde::{bound, dummy, this}; -use proc_macro2::{Literal, Span, TokenStream}; +use crate::serde::{bound, dummy, pretend, private, this}; +use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned, ToTokens}; use std::collections::BTreeSet; use std::ptr; @@ -10,23 +12,21 @@ use syn::punctuated::Punctuated; use syn::spanned::Spanned; use syn::{parse_quote, Ident, Index, Member}; -macro_rules! quote_expr { - ($($tt:tt)*) => { - $crate::serde::fragment::Fragment::Expr(quote!($($tt)*)) - } -} - -macro_rules! quote_block { - ($($tt:tt)*) => { - $crate::serde::fragment::Fragment::Block(quote!($($tt)*)) - } -} +mod enum_; +mod enum_adjacently; +mod enum_externally; +mod enum_internally; +mod enum_untagged; +mod identifier; +mod struct_; +mod tuple; +mod unit; pub fn expand_derive_deserialize(input: &mut syn::DeriveInput) -> syn::Result { replace_receiver(input); let ctxt = Ctxt::new(); - let cont = match Container::from_ast(&ctxt, input, Derive::Deserialize) { + let cont = match Container::from_ast(&ctxt, input, Derive::Deserialize, &private.ident()) { Some(cont) => cont, None => return Err(ctxt.check().unwrap_err()), }; @@ -35,25 +35,47 @@ pub fn expand_derive_deserialize(input: &mut syn::DeriveInput) -> syn::Result(self, __deserializer: __D) -> _serde::#private::Result<#remote #ty_generics, __D::Error> + where + __D: _serde::Deserializer<#delife>, + { + #used + #body + } + } + } + } else { + let fn_deserialize_in_place = deserialize_in_place_body(&cont, ¶ms); + quote! { #[automatically_derived] impl #de_impl_generics #serde::de::DeserializeSeed<#delife> for #seed<#delife, #context, #ident #ty_generics> #where_clause { type Value = #ident #ty_generics; - fn deserialize<__D>(self, __deserializer: __D) -> #serde::__private::Result + fn deserialize<__D>(self, __deserializer: __D) -> #serde::#private::Result where __D: #serde::Deserializer<#delife>, { #body } + + #fn_deserialize_in_place } } }; @@ -160,6 +182,23 @@ impl Parameters { fn type_name(&self) -> String { self.this_type.segments.last().unwrap().ident.to_string() } + + /// Split the data structure's generics into the pieces to use for its + /// `Deserialize` impl, augmented with an additional `'de` lifetime for use + /// as the `Deserialize` trait's lifetime. + fn generics_with_de_lifetime( + &self, + ) -> ( + DeImplGenerics, + DeTypeGenerics, + syn::TypeGenerics, + Option<&syn::WhereClause>, + ) { + let de_impl_generics = DeImplGenerics(self); + let de_ty_generics = DeTypeGenerics(self); + let (_, ty_generics, where_clause) = self.generics.split_for_impl(); + (de_impl_generics, de_ty_generics, ty_generics, where_clause) + } } // All the generics in the input, plus a bound `T: Deserialize` for each generic @@ -180,7 +219,7 @@ fn build_generics(cont: &Container, borrowed: &BorrowedLifetimes) -> syn::Generi attr::Default::Default => bound::with_self_bound( cont, &generics, - &parse_quote!(_serde::__private::Default), + &parse_quote!(_serde::#private::Default), ), attr::Default::None | attr::Default::Path(_) => generics, }; @@ -197,7 +236,7 @@ fn build_generics(cont: &Container, borrowed: &BorrowedLifetimes) -> syn::Generi cont, &generics, requires_default, - &parse_quote!(_serde::__private::Default), + &parse_quote!(_serde::#private::Default), ) } } @@ -223,7 +262,11 @@ fn needs_deserialize_bound(field: &attr::Field, variant: Option<&attr::Variant>) // Fields with a `default` attribute (not `default=...`), and fields with a // `skip_deserializing` attribute that do not also have `default=...`. fn requires_default(field: &attr::Field, _variant: Option<&attr::Variant>) -> bool { - matches!(*field.default(), attr::Default::Default) + if let attr::Default::Default = *field.default() { + true + } else { + false + } } enum BorrowedLifetimes { @@ -284,23 +327,74 @@ fn deserialize_body(cont: &Container, params: &Parameters) -> Fragment { deserialize_try_from(type_try_from) } else if let attr::Identifier::No = cont.attrs.identifier() { match &cont.data { - Data::Enum(variants) => deserialize_enum(params, variants, &cont.attrs), + Data::Enum(variants) => enum_::deserialize(params, variants, &cont.attrs), Data::Struct(Style::Struct, fields) => { - deserialize_struct(params, fields, &cont.attrs, StructForm::Struct) + struct_::deserialize(params, fields, &cont.attrs, StructForm::Struct) } Data::Struct(Style::Tuple, fields) | Data::Struct(Style::Newtype, fields) => { - deserialize_tuple(params, fields, &cont.attrs, TupleForm::Tuple) + tuple::deserialize(params, fields, &cont.attrs, TupleForm::Tuple) } - Data::Struct(Style::Unit, _) => deserialize_unit_struct(params, &cont.attrs), + Data::Struct(Style::Unit, _) => unit::deserialize(params, &cont.attrs), } } else { match &cont.data { - Data::Enum(variants) => deserialize_custom_identifier(params, variants, &cont.attrs), + Data::Enum(variants) => identifier::deserialize_custom(params, variants, &cont.attrs), Data::Struct(_, _) => unreachable!("checked in serde_derive_internals"), } } } +#[cfg(feature = "deserialize_in_place")] +fn deserialize_in_place_body(cont: &Container, params: &Parameters) -> Option { + // Only remote derives have getters, and we do not generate + // deserialize_in_place for remote derives. + assert!(!params.has_getter); + + if cont.attrs.transparent() + || cont.attrs.type_from().is_some() + || cont.attrs.type_try_from().is_some() + || cont.attrs.identifier().is_some() + || cont + .data + .all_fields() + .all(|f| f.attrs.deserialize_with().is_some()) + { + return None; + } + + let code = match &cont.data { + Data::Struct(Style::Struct, fields) => { + struct_::deserialize_in_place(params, fields, &cont.attrs)? + } + Data::Struct(Style::Tuple, fields) | Data::Struct(Style::Newtype, fields) => { + tuple::deserialize_in_place(params, fields, &cont.attrs) + } + Data::Enum(_) | Data::Struct(Style::Unit, _) => { + return None; + } + }; + + let delife = params.borrowed.de_lifetime(); + let stmts = Stmts(code); + + let fn_deserialize_in_place = quote_block! { + fn deserialize_in_place<__D>(__deserializer: __D, __place: &mut Self) -> _serde::#private::Result<(), __D::Error> + where + __D: _serde::Deserializer<#delife>, + { + #stmts + } + }; + + Some(Stmts(fn_deserialize_in_place)) +} + +#[cfg(not(feature = "deserialize_in_place"))] +fn deserialize_in_place_body(_cont: &Container, _params: &Parameters) -> Option { + None +} + +/// Generates `Deserialize::deserialize` body for a type with `#[serde(transparent)]` attribute fn deserialize_transparent(cont: &Container, params: &Parameters) -> Fragment { let fields = match &cont.data { Data::Struct(_, fields) => fields, @@ -324,79 +418,40 @@ fn deserialize_transparent(cont: &Container, params: &Parameters) -> Fragment { quote!(#member: __transparent) } else { let value = match field.attrs.default() { - attr::Default::Default => quote!(_serde::__private::Default::default()), - attr::Default::Path(path) => quote!(#path()), - attr::Default::None => quote!(_serde::__private::PhantomData), + attr::Default::Default => quote!(_serde::#private::Default::default()), + // If #path returns wrong type, error will be reported here (^^^^^). + // We attach span of the path to the function so it will be reported + // on the #[serde(default = "...")] + // ^^^^^ + attr::Default::Path(path) => quote_spanned!(path.span()=> #path()), + attr::Default::None => quote!(_serde::#private::PhantomData), }; quote!(#member: #value) } }); quote_block! { - _serde::__private::Result::map( + _serde::#private::Result::map( #path(__deserializer), |__transparent| #this_value { #(#assign),* }) } } +/// Generates `Deserialize::deserialize` body for a type with `#[serde(from)]` attribute fn deserialize_from(type_from: &syn::Type) -> Fragment { quote_block! { - _serde::__private::Result::map( + _serde::#private::Result::map( <#type_from as _serde::Deserialize>::deserialize(__deserializer), - _serde::__private::From::from) + _serde::#private::From::from) } } +/// Generates `Deserialize::deserialize` body for a type with `#[serde(try_from)]` attribute fn deserialize_try_from(type_try_from: &syn::Type) -> Fragment { quote_block! { - _serde::__private::Result::and_then( + _serde::#private::Result::and_then( <#type_try_from as _serde::Deserialize>::deserialize(__deserializer), - |v| _serde::__private::TryFrom::try_from(v).map_err(_serde::de::Error::custom)) - } -} - -fn deserialize_unit_struct(params: &Parameters, cattrs: &attr::Container) -> Fragment { - let this_type = ¶ms.this_type; - let this_value = ¶ms.this_value; - let type_name = cattrs.name().deserialize_name(); - let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = - split_with_de_lifetime(params); - let delife = params.borrowed.de_lifetime(); - - let expecting = format!("unit struct {}", params.type_name()); - let expecting = cattrs.expecting().unwrap_or(&expecting); - - quote_block! { - #[doc(hidden)] - struct __Visitor #de_impl_generics #where_clause { - marker: _serde::__private::PhantomData<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData<&#delife ()>, - } - - impl #de_impl_generics _serde::de::Visitor<#delife> for __Visitor #de_ty_generics #where_clause { - type Value = #this_type #ty_generics; - - fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { - _serde::__private::Formatter::write_str(__formatter, #expecting) - } - - #[inline] - fn visit_unit<__E>(self) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(#this_value) - } - } - - _serde::Deserializer::deserialize_unit_struct( - __deserializer, - #type_name, - __Visitor { - marker: _serde::__private::PhantomData::<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData, - }, - ) + |v| _serde::#private::TryFrom::try_from(v).map_err(_serde::de::Error::custom)) } } @@ -404,127 +459,8 @@ enum TupleForm<'a> { Tuple, /// Contains a variant name ExternallyTagged(&'a syn::Ident), - /// Contains a variant name and an intermediate deserializer from which actual - /// deserialization will be performed - Untagged(&'a syn::Ident, TokenStream), -} - -fn deserialize_tuple( - params: &Parameters, - fields: &[Field], - cattrs: &attr::Container, - form: TupleForm, -) -> Fragment { - assert!(!cattrs.has_flatten()); - - let field_count = fields - .iter() - .filter(|field| !field.attrs.skip_deserializing()) - .count(); - - let this_type = ¶ms.this_type; - let this_value = ¶ms.this_value; - let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = - split_with_de_lifetime(params); - let delife = params.borrowed.de_lifetime(); - - // If there are getters (implying private fields), construct the local type - // and use an `Into` conversion to get the remote type. If there are no - // getters then construct the target type directly. - let construct = if params.has_getter { - let local = ¶ms.local; - quote!(#local) - } else { - quote!(#this_value) - }; - - let type_path = match form { - TupleForm::Tuple => construct, - TupleForm::ExternallyTagged(variant_ident) | TupleForm::Untagged(variant_ident, _) => { - quote!(#construct::#variant_ident) - } - }; - let expecting = match form { - TupleForm::Tuple => format!("tuple struct {}", params.type_name()), - TupleForm::ExternallyTagged(variant_ident) | TupleForm::Untagged(variant_ident, _) => { - format!("tuple variant {}::{}", params.type_name(), variant_ident) - } - }; - let expecting = cattrs.expecting().unwrap_or(&expecting); - - let nfields = fields.len(); - - let visit_newtype_struct = match form { - TupleForm::Tuple if nfields == 1 => { - Some(deserialize_newtype_struct(&type_path, params, &fields[0])) - } - _ => None, - }; - - let visit_seq = Stmts(deserialize_seq( - &type_path, params, fields, false, cattrs, expecting, - )); - - let visitor_expr = quote! { - __Visitor { - marker: _serde::__private::PhantomData::<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData, - } - }; - let dispatch = match form { - TupleForm::Tuple if nfields == 1 => { - let type_name = cattrs.name().deserialize_name(); - quote! { - _serde::Deserializer::deserialize_newtype_struct(__deserializer, #type_name, #visitor_expr) - } - } - TupleForm::Tuple => { - let type_name = cattrs.name().deserialize_name(); - quote! { - _serde::Deserializer::deserialize_tuple_struct(__deserializer, #type_name, #field_count, #visitor_expr) - } - } - TupleForm::ExternallyTagged(_) => quote! { - _serde::de::VariantAccess::tuple_variant(__variant, #field_count, #visitor_expr) - }, - TupleForm::Untagged(_, deserializer) => quote! { - _serde::Deserializer::deserialize_tuple(#deserializer, #field_count, #visitor_expr) - }, - }; - - let visitor_var = if field_count == 0 { - quote!(_) - } else { - quote!(mut __seq) - }; - - quote_block! { - #[doc(hidden)] - struct __Visitor #de_impl_generics #where_clause { - marker: _serde::__private::PhantomData<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData<&#delife ()>, - } - - impl #de_impl_generics _serde::de::Visitor<#delife> for __Visitor #de_ty_generics #where_clause { - type Value = #this_type #ty_generics; - - fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { - _serde::__private::Formatter::write_str(__formatter, #expecting) - } - - #visit_newtype_struct - - #[inline] - fn visit_seq<__A>(self, #visitor_var: __A) -> _serde::__private::Result - where - __A: _serde::de::SeqAccess<#delife>, - { - #visit_seq - } - } - - #dispatch - } + /// Contains a variant name + Untagged(&'a syn::Ident), } fn deserialize_seq( @@ -542,9 +478,9 @@ fn deserialize_seq( .filter(|field| !field.attrs.skip_deserializing()) .count(); let expecting = if deserialized_count == 1 { - format!("{expecting} with 1 element") + format!("{} with 1 element", expecting) } else { - format!("{expecting} with {deserialized_count} elements") + format!("{} with {} elements", expecting, deserialized_count) }; let expecting = cattrs.expecting().unwrap_or(&expecting); @@ -572,9 +508,8 @@ fn deserialize_seq( let context = ¶ms.context; quote!({ #wrapper - _serde::__private::Option::map( - _serde::de::SeqAccess::next_element_seed::<#seed<#context, #wrapper_ty>>( - &mut __seq, #seed::new(self.context))?, + _serde::#private::Option::map( + _serde::de::SeqAccess::next_element_seed::<#seed<#context, #wrapper_ty>>(&mut __seq, #seed::new(self.context))?, |__wrap| __wrap.value) }) } @@ -582,8 +517,8 @@ fn deserialize_seq( let value_if_none = expr_is_missing_seq(None, index_in_seq, field, cattrs, expecting); let assign = quote! { let #var = match #visit { - _serde::__private::Some(__value) => __value, - _serde::__private::None => #value_if_none, + _serde::#private::Some(__value) => __value, + _serde::#private::None => #value_if_none, }; }; index_in_seq += 1; @@ -606,15 +541,19 @@ fn deserialize_seq( let this_type = ¶ms.this_type; let (_, ty_generics, _) = params.generics.split_for_impl(); result = quote! { - _serde::__private::Into::<#this_type #ty_generics>::into(#result) + _serde::#private::Into::<#this_type #ty_generics>::into(#result) }; } let let_default = match cattrs.default() { attr::Default::Default => Some(quote!( - let __default: Self::Value = _serde::__private::Default::default(); + let __default: Self::Value = _serde::#private::Default::default(); )), - attr::Default::Path(path) => Some(quote!( + // If #path returns wrong type, error will be reported here (^^^^^). + // We attach span of the path to the function so it will be reported + // on the #[serde(default = "...")] + // ^^^^^ + attr::Default::Path(path) => Some(quote_spanned!(path.span()=> let __default: Self::Value = #path(); )), attr::Default::None => { @@ -627,51 +566,93 @@ fn deserialize_seq( quote_block! { #let_default #(#let_values)* - _serde::__private::Ok(#result) + _serde::#private::Ok(#result) } } -fn deserialize_newtype_struct( - type_path: &TokenStream, +#[cfg(feature = "deserialize_in_place")] +fn deserialize_seq_in_place( params: &Parameters, - field: &Field, -) -> TokenStream { - let delife = params.borrowed.de_lifetime(); - let field_ty = field.ty; + fields: &[Field], + cattrs: &attr::Container, + expecting: &str, +) -> Fragment { + let deserialized_count = fields + .iter() + .filter(|field| !field.attrs.skip_deserializing()) + .count(); + let expecting = if deserialized_count == 1 { + format!("{} with 1 element", expecting) + } else { + format!("{} with {} elements", expecting, deserialized_count) + }; + let expecting = cattrs.expecting().unwrap_or(&expecting); - let value = match field.attrs.deserialize_with() { - None => { - let span = field.original.span(); - let func = quote_spanned!(span=> <#field_ty as _serde::Deserialize>::deserialize); + let mut index_in_seq = 0usize; + let write_values = fields.iter().map(|field| { + let member = &field.member; + + if field.attrs.skip_deserializing() { + let default = Expr(expr_is_missing(field, cattrs)); quote! { - #func(__e)? + self.place.#member = #default; } + } else { + let value_if_none = expr_is_missing_seq(Some(quote!(self.place.#member = )), index_in_seq, field, cattrs, expecting); + let write = match field.attrs.deserialize_with() { + None => { + quote! { + if let _serde::#private::None = _serde::de::SeqAccess::next_element_seed(&mut __seq, + _serde::#private::de::InPlaceSeed(&mut self.place.#member))? + { + #value_if_none; + } + } + } + Some(path) => { + let (wrapper, wrapper_ty) = wrap_deserialize_field_with(params, field.ty, path); + quote!({ + #wrapper + match _serde::de::SeqAccess::next_element::<#wrapper_ty>(&mut __seq)? { + _serde::#private::Some(__wrap) => { + self.place.#member = __wrap.value; + } + _serde::#private::None => { + #value_if_none; + } + } + }) + } + }; + index_in_seq += 1; + write } - Some(path) => { - quote! { - #path(__e)? - } + }); + + let this_type = ¶ms.this_type; + let (_, ty_generics, _) = params.generics.split_for_impl(); + let let_default = match cattrs.default() { + attr::Default::Default => Some(quote!( + let __default: #this_type #ty_generics = _serde::#private::Default::default(); + )), + // If #path returns wrong type, error will be reported here (^^^^^). + // We attach span of the path to the function so it will be reported + // on the #[serde(default = "...")] + // ^^^^^ + attr::Default::Path(path) => Some(quote_spanned!(path.span()=> + let __default: #this_type #ty_generics = #path(); + )), + attr::Default::None => { + // We don't need the default value, to prevent an unused variable warning + // we'll leave the line empty. + None } }; - let mut result = quote!(#type_path(__field0)); - if params.has_getter { - let this_type = ¶ms.this_type; - let (_, ty_generics, _) = params.generics.split_for_impl(); - result = quote! { - _serde::__private::Into::<#this_type #ty_generics>::into(#result) - }; - } - - quote! { - #[inline] - fn visit_newtype_struct<__E>(self, __e: __E) -> _serde::__private::Result - where - __E: _serde::Deserializer<#delife>, - { - let __field0: #field_ty = #value; - _serde::__private::Ok(#result) - } + quote_block! { + #let_default + #(#write_values)* + _serde::#private::Ok(()) } } @@ -679,1776 +660,85 @@ enum StructForm<'a> { Struct, /// Contains a variant name ExternallyTagged(&'a syn::Ident), - /// Contains a variant name and an intermediate deserializer from which actual - /// deserialization will be performed - InternallyTagged(&'a syn::Ident, TokenStream), - /// Contains a variant name and an intermediate deserializer from which actual - /// deserialization will be performed - Untagged(&'a syn::Ident, TokenStream), + /// Contains a variant name + InternallyTagged(&'a syn::Ident), + /// Contains a variant name + Untagged(&'a syn::Ident), +} + +struct FieldWithAliases<'a> { + ident: Ident, + aliases: &'a BTreeSet, +} + +fn field_i(i: usize) -> Ident { + Ident::new(&format!("__field{}", i), Span::call_site()) } -fn deserialize_struct( +/// This function wraps the expression in `#[serde(deserialize_with = "...")]` +/// in a trait to prevent it from accessing the internal `Deserialize` state. +fn wrap_deserialize_with( params: &Parameters, - fields: &[Field], - cattrs: &attr::Container, - form: StructForm, -) -> Fragment { + value_ty: &TokenStream, + deserialize_with: &syn::ExprPath, +) -> (TokenStream, TokenStream) { let this_type = ¶ms.this_type; - let this_value = ¶ms.this_value; let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = - split_with_de_lifetime(params); + params.generics_with_de_lifetime(); let delife = params.borrowed.de_lifetime(); - let context = ¶ms.context; - - // If there are getters (implying private fields), construct the local type - // and use an `Into` conversion to get the remote type. If there are no - // getters then construct the target type directly. - let construct = if params.has_getter { - let local = ¶ms.local; - quote!(#local) - } else { - quote!(#this_value) - }; + let deserializer_var = quote!(__deserializer); - let type_path = match form { - StructForm::Struct => construct, - StructForm::ExternallyTagged(variant_ident) - | StructForm::InternallyTagged(variant_ident, _) - | StructForm::Untagged(variant_ident, _) => quote!(#construct::#variant_ident), + // If #deserialize_with returns wrong type, error will be reported here (^^^^^). + // We attach span of the path to the function so it will be reported + // on the #[serde(with = "...")] + // ^^^^^ + let value = quote_spanned! {deserialize_with.span()=> + #deserialize_with(#deserializer_var)? }; - let expecting = match form { - StructForm::Struct => format!("struct {}", params.type_name()), - StructForm::ExternallyTagged(variant_ident) - | StructForm::InternallyTagged(variant_ident, _) - | StructForm::Untagged(variant_ident, _) => { - format!("struct variant {}::{}", params.type_name(), variant_ident) + let wrapper = quote! { + #[doc(hidden)] + struct __DeserializeWith #de_impl_generics #where_clause { + value: #value_ty, + phantom: _serde::#private::PhantomData<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData<&#delife ()>, } - }; - let expecting = cattrs.expecting().unwrap_or(&expecting); - - let field_names_idents: Vec<_> = fields - .iter() - .enumerate() - // Skip fields that shouldn't be deserialized or that were flattened, - // so they don't appear in the storage in their literal form - .filter(|&(_, field)| !field.attrs.skip_deserializing() && !field.attrs.flatten()) - .map(|(i, field)| { - ( - field.attrs.name().deserialize_name(), - field_i(i), - field.attrs.aliases(), - ) - }) - .collect(); - let field_visitor = deserialize_field_identifier(&field_names_idents, cattrs); - - // untagged struct variants do not get a visit_seq method. The same applies to - // structs that only have a map representation. - let visit_seq = match form { - StructForm::Untagged(..) => None, - _ if cattrs.has_flatten() => None, - _ => { - let mut_seq = if field_names_idents.is_empty() { - quote!(_) - } else { - quote!(mut __seq) - }; - - let visit_seq = Stmts(deserialize_seq( - &type_path, params, fields, true, cattrs, expecting, - )); - Some(quote! { - #[inline] - fn visit_seq<__A>(self, #mut_seq: __A) -> _serde::__private::Result - where - __A: _serde::de::SeqAccess<#delife>, - { - #visit_seq - } - }) + #[automatically_derived] + impl #de_impl_generics _serde::Deserialize<#delife> for __DeserializeWith #de_ty_generics #where_clause { + fn deserialize<__D>(#deserializer_var: __D) -> _serde::#private::Result + where + __D: _serde::Deserializer<#delife>, + { + _serde::#private::Ok(__DeserializeWith { + value: #value, + phantom: _serde::#private::PhantomData, + lifetime: _serde::#private::PhantomData, + }) + } } }; - let visit_map = Stmts(deserialize_map(&type_path, params, fields, cattrs)); - let visitor_seed = match form { - StructForm::ExternallyTagged(..) if cattrs.has_flatten() => Some(quote! { - impl #de_impl_generics _serde::de::DeserializeSeed<#delife> for __Visitor #de_ty_generics #where_clause { - type Value = #this_type #ty_generics; + let wrapper_ty = quote!(__DeserializeWith #de_ty_generics); - fn deserialize<__D>(self, __deserializer: __D) -> _serde::__private::Result - where - __D: _serde::Deserializer<#delife>, - { - _serde::Deserializer::deserialize_map(__deserializer, self) - } - } - }), - _ => None, - }; + (wrapper, wrapper_ty) +} - let fields_stmt = if cattrs.has_flatten() { - None - } else { - let field_names = field_names_idents - .iter() - .flat_map(|&(_, _, aliases)| aliases); +fn wrap_deserialize_field_with( + params: &Parameters, + field_ty: &syn::Type, + deserialize_with: &syn::ExprPath, +) -> (TokenStream, TokenStream) { + wrap_deserialize_with(params, "e!(#field_ty), deserialize_with) +} - Some(quote! { - #[doc(hidden)] - const FIELDS: &'static [&'static str] = &[ #(#field_names),* ]; - }) - }; - - let visitor_expr = quote! { - __Visitor { - context: self.context, - marker: _serde::__private::PhantomData::<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData, - } - }; - let dispatch = match form { - StructForm::Struct if cattrs.has_flatten() => quote! { - _serde::Deserializer::deserialize_map(__deserializer, #visitor_expr) - }, - StructForm::Struct => { - let type_name = cattrs.name().deserialize_name(); - quote! { - _serde::Deserializer::deserialize_struct(__deserializer, #type_name, FIELDS, #visitor_expr) - } - } - StructForm::ExternallyTagged(_) if cattrs.has_flatten() => quote! { - _serde::de::VariantAccess::newtype_variant_seed(__variant, #visitor_expr) - }, - StructForm::ExternallyTagged(_) => quote! { - _serde::de::VariantAccess::struct_variant(__variant, FIELDS, #visitor_expr) - }, - StructForm::InternallyTagged(_, deserializer) => quote! { - _serde::Deserializer::deserialize_any(#deserializer, #visitor_expr) - }, - StructForm::Untagged(_, deserializer) => quote! { - _serde::Deserializer::deserialize_any(#deserializer, #visitor_expr) - }, - }; - - quote_block! { - #field_visitor - - #[doc(hidden)] - struct __Visitor #de_impl_generics #where_clause { - context: &#delife #context, - marker: _serde::__private::PhantomData<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData<&#delife ()>, - } - - impl #de_impl_generics _serde::de::Visitor<#delife> for __Visitor #de_ty_generics #where_clause { - type Value = #this_type #ty_generics; - - fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { - _serde::__private::Formatter::write_str(__formatter, #expecting) - } - - #visit_seq - - #[inline] - fn visit_map<__A>(self, mut __map: __A) -> _serde::__private::Result - where - __A: _serde::de::MapAccess<#delife>, - { - #visit_map - } - } - - #visitor_seed - - #fields_stmt - - #dispatch - } -} - -fn deserialize_enum( - params: &Parameters, - variants: &[Variant], - cattrs: &attr::Container, -) -> Fragment { - // The variants have already been checked (in ast.rs) that all untagged variants appear at the end - match variants.iter().position(|var| var.attrs.untagged()) { - Some(variant_idx) => { - let (tagged, untagged) = variants.split_at(variant_idx); - let tagged_frag = Expr(deserialize_homogeneous_enum(params, tagged, cattrs)); - deserialize_untagged_enum_after(params, untagged, cattrs, Some(tagged_frag)) - } - None => deserialize_homogeneous_enum(params, variants, cattrs), - } -} - -fn deserialize_homogeneous_enum( - params: &Parameters, - variants: &[Variant], - cattrs: &attr::Container, -) -> Fragment { - match cattrs.tag() { - attr::TagType::External => deserialize_externally_tagged_enum(params, variants, cattrs), - attr::TagType::Internal { tag } => { - deserialize_internally_tagged_enum(params, variants, cattrs, tag) - } - attr::TagType::Adjacent { tag, content } => { - deserialize_adjacently_tagged_enum(params, variants, cattrs, tag, content) - } - attr::TagType::None => deserialize_untagged_enum(params, variants, cattrs), - } -} - -fn prepare_enum_variant_enum( - variants: &[Variant], - cattrs: &attr::Container, -) -> (TokenStream, Stmts) { - let mut deserialized_variants = variants - .iter() - .enumerate() - .filter(|&(_, variant)| !variant.attrs.skip_deserializing()); - - let variant_names_idents: Vec<_> = deserialized_variants - .clone() - .map(|(i, variant)| { - ( - variant.attrs.name().deserialize_name(), - field_i(i), - variant.attrs.aliases(), - ) - }) - .collect(); - - let fallthrough = deserialized_variants - .position(|(_, variant)| variant.attrs.other()) - .map(|other_idx| { - let ignore_variant = variant_names_idents[other_idx].1.clone(); - quote!(_serde::__private::Ok(__Field::#ignore_variant)) - }); - - let variants_stmt = { - let variant_names = variant_names_idents.iter().map(|(name, _, _)| name); - quote! { - #[doc(hidden)] - const VARIANTS: &'static [&'static str] = &[ #(#variant_names),* ]; - } - }; - - let variant_visitor = Stmts(deserialize_generated_identifier( - &variant_names_idents, - cattrs, - true, - None, - fallthrough, - )); - - (variants_stmt, variant_visitor) -} - -fn deserialize_externally_tagged_enum( - params: &Parameters, - variants: &[Variant], - cattrs: &attr::Container, -) -> Fragment { - let this_type = ¶ms.this_type; - let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = - split_with_de_lifetime(params); - let delife = params.borrowed.de_lifetime(); - let context = ¶ms.context; - - let type_name = cattrs.name().deserialize_name(); - let expecting = format!("enum {}", params.type_name()); - let expecting = cattrs.expecting().unwrap_or(&expecting); - - let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs); - - // Match arms to extract a variant from a string - let variant_arms = variants - .iter() - .enumerate() - .filter(|&(_, variant)| !variant.attrs.skip_deserializing()) - .map(|(i, variant)| { - let variant_name = field_i(i); - - let block = Match(deserialize_externally_tagged_variant( - params, variant, cattrs, - )); - - quote! { - (__Field::#variant_name, __variant) => #block - } - }); - - let all_skipped = variants - .iter() - .all(|variant| variant.attrs.skip_deserializing()); - let match_variant = if all_skipped { - // This is an empty enum like `enum Impossible {}` or an enum in which - // all variants have `#[serde(skip_deserializing)]`. - quote! { - // FIXME: Once feature(exhaustive_patterns) is stable: - // let _serde::__private::Err(__err) = _serde::de::EnumAccess::variant::<__Field>(__data); - // _serde::__private::Err(__err) - _serde::__private::Result::map( - _serde::de::EnumAccess::variant::<__Field>(__data), - |(__impossible, _)| match __impossible {}) - } - } else { - quote! { - match _serde::de::EnumAccess::variant(__data)? { - #(#variant_arms)* - } - } - }; - - quote_block! { - #variant_visitor - - #[doc(hidden)] - struct __Visitor #de_impl_generics #where_clause { - context: &#delife #context, - marker: _serde::__private::PhantomData<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData<&#delife ()>, - } - - impl #de_impl_generics _serde::de::Visitor<#delife> for __Visitor #de_ty_generics #where_clause { - type Value = #this_type #ty_generics; - - fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { - _serde::__private::Formatter::write_str(__formatter, #expecting) - } - - fn visit_enum<__A>(self, __data: __A) -> _serde::__private::Result - where - __A: _serde::de::EnumAccess<#delife>, - { - #match_variant - } - } - - #variants_stmt - - _serde::Deserializer::deserialize_enum( - __deserializer, - #type_name, - VARIANTS, - __Visitor { - context: self.context, - marker: _serde::__private::PhantomData::<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData, - }, - ) - } -} - -fn deserialize_internally_tagged_enum( - params: &Parameters, - variants: &[Variant], - cattrs: &attr::Container, - tag: &str, -) -> Fragment { - let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs); - - // Match arms to extract a variant from a string - let variant_arms = variants - .iter() - .enumerate() - .filter(|&(_, variant)| !variant.attrs.skip_deserializing()) - .map(|(i, variant)| { - let variant_name = field_i(i); - - let block = Match(deserialize_internally_tagged_variant( - params, - variant, - cattrs, - quote!(__deserializer), - )); - - quote! { - __Field::#variant_name => #block - } - }); - - let expecting = format!("internally tagged enum {}", params.type_name()); - let expecting = cattrs.expecting().unwrap_or(&expecting); - - quote_block! { - #variant_visitor - - #variants_stmt - - let (__tag, __content) = _serde::Deserializer::deserialize_any( - __deserializer, - _serde::__private::de::TaggedContentVisitor::<__Field>::new(#tag, #expecting))?; - let __deserializer = _serde::__private::de::ContentDeserializer::<__D::Error>::new(__content); - - match __tag { - #(#variant_arms)* - } - } -} - -fn deserialize_adjacently_tagged_enum( - params: &Parameters, - variants: &[Variant], - cattrs: &attr::Container, - tag: &str, - content: &str, -) -> Fragment { - let this_type = ¶ms.this_type; - let this_value = ¶ms.this_value; - let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = - split_with_de_lifetime(params); - let delife = params.borrowed.de_lifetime(); - let context = ¶ms.context; - - let (variants_stmt, variant_visitor) = prepare_enum_variant_enum(variants, cattrs); - - let variant_arms: &Vec<_> = &variants - .iter() - .enumerate() - .filter(|&(_, variant)| !variant.attrs.skip_deserializing()) - .map(|(i, variant)| { - let variant_index = field_i(i); - - let block = Match(deserialize_untagged_variant( - params, - variant, - cattrs, - quote!(__deserializer), - )); - - quote! { - __Field::#variant_index => #block - } - }) - .collect(); - - let rust_name = params.type_name(); - let expecting = format!("adjacently tagged enum {rust_name}"); - let expecting = cattrs.expecting().unwrap_or(&expecting); - let type_name = cattrs.name().deserialize_name(); - let deny_unknown_fields = cattrs.deny_unknown_fields(); - - // If unknown fields are allowed, we pick the visitor that can step over - // those. Otherwise we pick the visitor that fails on unknown keys. - let field_visitor_ty = if deny_unknown_fields { - quote! { _serde::__private::de::TagOrContentFieldVisitor } - } else { - quote! { _serde::__private::de::TagContentOtherFieldVisitor } - }; - - let tag_or_content = quote! { - #field_visitor_ty { - tag: #tag, - content: #content, - } - }; - - let variant_seed = quote! { - _serde::__private::de::AdjacentlyTaggedEnumVariantSeed::<__Field> { - enum_name: #rust_name, - variants: VARIANTS, - fields_enum: _serde::__private::PhantomData - } - }; - - let mut missing_content = quote! { - _serde::__private::Err(<__A::Error as _serde::de::Error>::missing_field(#content)) - }; - let mut missing_content_fallthrough = quote!(); - let missing_content_arms = variants - .iter() - .enumerate() - .filter(|&(_, variant)| !variant.attrs.skip_deserializing()) - .filter_map(|(i, variant)| { - let variant_index = field_i(i); - let variant_ident = &variant.ident; - - let arm = match variant.style { - Style::Unit => quote! { - _serde::__private::Ok(#this_value::#variant_ident) - }, - Style::Newtype if variant.attrs.deserialize_with().is_none() => { - let span = variant.original.span(); - let func = quote_spanned!(span=> _serde::__private::de::missing_field); - quote! { - #func(#content).map(#this_value::#variant_ident) - } - } - _ => { - missing_content_fallthrough = quote!(_ => #missing_content); - return None; - } - }; - Some(quote! { - __Field::#variant_index => #arm, - }) - }) - .collect::>(); - if !missing_content_arms.is_empty() { - missing_content = quote! { - match __field { - #(#missing_content_arms)* - #missing_content_fallthrough - } - }; - } - - // Advance the map by one key, returning early in case of error. - let next_key = quote! { - _serde::de::MapAccess::next_key_seed(&mut __map, #tag_or_content)? - }; - - let variant_from_map = quote! { - _serde::de::MapAccess::next_value_seed(&mut __map, #variant_seed)? - }; - - // When allowing unknown fields, we want to transparently step through keys - // we don't care about until we find `tag`, `content`, or run out of keys. - let next_relevant_key = if deny_unknown_fields { - next_key - } else { - quote!({ - let mut __rk : _serde::__private::Option<_serde::__private::de::TagOrContentField> = _serde::__private::None; - while let _serde::__private::Some(__k) = #next_key { - match __k { - _serde::__private::de::TagContentOtherField::Other => { - let _ = _serde::de::MapAccess::next_value::<_serde::de::IgnoredAny>(&mut __map)?; - continue; - }, - _serde::__private::de::TagContentOtherField::Tag => { - __rk = _serde::__private::Some(_serde::__private::de::TagOrContentField::Tag); - break; - } - _serde::__private::de::TagContentOtherField::Content => { - __rk = _serde::__private::Some(_serde::__private::de::TagOrContentField::Content); - break; - } - } - } - - __rk - }) - }; - - // Step through remaining keys, looking for duplicates of previously-seen - // keys. When unknown fields are denied, any key that isn't a duplicate will - // at this point immediately produce an error. - let visit_remaining_keys = quote! { - match #next_relevant_key { - _serde::__private::Some(_serde::__private::de::TagOrContentField::Tag) => { - _serde::__private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#tag)) - } - _serde::__private::Some(_serde::__private::de::TagOrContentField::Content) => { - _serde::__private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#content)) - } - _serde::__private::None => _serde::__private::Ok(__ret), - } - }; - - let finish_content_then_tag = if variant_arms.is_empty() { - quote! { - match #variant_from_map {} - } - } else { - quote! { - let __ret = match #variant_from_map { - // Deserialize the buffered content now that we know the variant. - #(#variant_arms)* - }?; - // Visit remaining keys, looking for duplicates. - #visit_remaining_keys - } - }; - - quote_block! { - #variant_visitor - - #variants_stmt - - #[doc(hidden)] - struct __Seed #de_impl_generics #where_clause { - field: __Field, - marker: _serde::__private::PhantomData<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData<&#delife ()>, - } - - impl #de_impl_generics _serde::de::DeserializeSeed<#delife> for __Seed #de_ty_generics #where_clause { - type Value = #this_type #ty_generics; - - fn deserialize<__D>(self, __deserializer: __D) -> _serde::__private::Result - where - __D: _serde::Deserializer<#delife>, - { - match self.field { - #(#variant_arms)* - } - } - } - - #[doc(hidden)] - struct __Visitor #de_impl_generics #where_clause { - context: &#delife #context, - marker: _serde::__private::PhantomData<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData<&#delife ()>, - } - - impl #de_impl_generics _serde::de::Visitor<#delife> for __Visitor #de_ty_generics #where_clause { - type Value = #this_type #ty_generics; - - fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { - _serde::__private::Formatter::write_str(__formatter, #expecting) - } - - fn visit_map<__A>(self, mut __map: __A) -> _serde::__private::Result - where - __A: _serde::de::MapAccess<#delife>, - { - // Visit the first relevant key. - match #next_relevant_key { - // First key is the tag. - _serde::__private::Some(_serde::__private::de::TagOrContentField::Tag) => { - // Parse the tag. - let __field = #variant_from_map; - // Visit the second key. - match #next_relevant_key { - // Second key is a duplicate of the tag. - _serde::__private::Some(_serde::__private::de::TagOrContentField::Tag) => { - _serde::__private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#tag)) - } - // Second key is the content. - _serde::__private::Some(_serde::__private::de::TagOrContentField::Content) => { - let __ret = _serde::de::MapAccess::next_value_seed(&mut __map, - __Seed { - field: __field, - marker: _serde::__private::PhantomData, - lifetime: _serde::__private::PhantomData, - })?; - // Visit remaining keys, looking for duplicates. - #visit_remaining_keys - } - // There is no second key; might be okay if the we have a unit variant. - _serde::__private::None => #missing_content - } - } - // First key is the content. - _serde::__private::Some(_serde::__private::de::TagOrContentField::Content) => { - // Buffer up the content. - let __content = _serde::de::MapAccess::next_value::<_serde::__private::de::Content>(&mut __map)?; - // Visit the second key. - match #next_relevant_key { - // Second key is the tag. - _serde::__private::Some(_serde::__private::de::TagOrContentField::Tag) => { - let __deserializer = _serde::__private::de::ContentDeserializer::<__A::Error>::new(__content); - #finish_content_then_tag - } - // Second key is a duplicate of the content. - _serde::__private::Some(_serde::__private::de::TagOrContentField::Content) => { - _serde::__private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#content)) - } - // There is no second key. - _serde::__private::None => { - _serde::__private::Err(<__A::Error as _serde::de::Error>::missing_field(#tag)) - } - } - } - // There is no first key. - _serde::__private::None => { - _serde::__private::Err(<__A::Error as _serde::de::Error>::missing_field(#tag)) - } - } - } - - fn visit_seq<__A>(self, mut __seq: __A) -> _serde::__private::Result - where - __A: _serde::de::SeqAccess<#delife>, - { - // Visit the first element - the tag. - match _serde::de::SeqAccess::next_element(&mut __seq)? { - _serde::__private::Some(__field) => { - // Visit the second element - the content. - match _serde::de::SeqAccess::next_element_seed( - &mut __seq, - __Seed { - field: __field, - marker: _serde::__private::PhantomData, - lifetime: _serde::__private::PhantomData, - }, - )? { - _serde::__private::Some(__ret) => _serde::__private::Ok(__ret), - // There is no second element. - _serde::__private::None => { - _serde::__private::Err(_serde::de::Error::invalid_length(1, &self)) - } - } - } - // There is no first element. - _serde::__private::None => { - _serde::__private::Err(_serde::de::Error::invalid_length(0, &self)) - } - } - } - } - - #[doc(hidden)] - const FIELDS: &'static [&'static str] = &[#tag, #content]; - _serde::Deserializer::deserialize_struct( - __deserializer, - #type_name, - FIELDS, - __Visitor { - context: self.context, - marker: _serde::__private::PhantomData::<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData, - }, - ) - } -} - -fn deserialize_untagged_enum( - params: &Parameters, - variants: &[Variant], - cattrs: &attr::Container, -) -> Fragment { - let first_attempt = None; - deserialize_untagged_enum_after(params, variants, cattrs, first_attempt) -} - -fn deserialize_untagged_enum_after( - params: &Parameters, - variants: &[Variant], - cattrs: &attr::Container, - first_attempt: Option, -) -> Fragment { - let attempts = variants - .iter() - .filter(|variant| !variant.attrs.skip_deserializing()) - .map(|variant| { - Expr(deserialize_untagged_variant( - params, - variant, - cattrs, - quote!(__deserializer), - )) - }); - // TODO this message could be better by saving the errors from the failed - // attempts. The heuristic used by TOML was to count the number of fields - // processed before an error, and use the error that happened after the - // largest number of fields. I'm not sure I like that. Maybe it would be - // better to save all the errors and combine them into one message that - // explains why none of the variants matched. - let fallthrough_msg = format!( - "data did not match any variant of untagged enum {}", - params.type_name() - ); - let fallthrough_msg = cattrs.expecting().unwrap_or(&fallthrough_msg); - - // Ignore any error associated with non-untagged deserialization so that we - // can fall through to the untagged variants. This may be infallible so we - // need to provide the error type. - let first_attempt = first_attempt.map(|expr| { - quote! { - if let _serde::__private::Result::<_, __D::Error>::Ok(__ok) = (|| #expr)() { - return _serde::__private::Ok(__ok); - } - } - }); - - quote_block! { - let __content = <_serde::__private::de::Content as _serde::Deserialize>::deserialize(__deserializer)?; - let __deserializer = _serde::__private::de::ContentRefDeserializer::<__D::Error>::new(&__content); - - #first_attempt - - #( - if let _serde::__private::Ok(__ok) = #attempts { - return _serde::__private::Ok(__ok); - } - )* - - _serde::__private::Err(_serde::de::Error::custom(#fallthrough_msg)) - } -} - -fn deserialize_externally_tagged_variant( - params: &Parameters, - variant: &Variant, - cattrs: &attr::Container, -) -> Fragment { - if let Some(path) = variant.attrs.deserialize_with() { - let (wrapper, wrapper_ty, unwrap_fn) = wrap_deserialize_variant_with(params, variant, path); - return quote_block! { - #wrapper - _serde::__private::Result::map( - _serde::de::VariantAccess::newtype_variant::<#wrapper_ty>(__variant), #unwrap_fn) - }; - } - - let variant_ident = &variant.ident; - - match variant.style { - Style::Unit => { - let this_value = ¶ms.this_value; - quote_block! { - _serde::de::VariantAccess::unit_variant(__variant)?; - _serde::__private::Ok(#this_value::#variant_ident) - } - } - Style::Newtype => deserialize_externally_tagged_newtype_variant( - variant_ident, - params, - &variant.fields[0], - cattrs, - ), - Style::Tuple => deserialize_tuple( - params, - &variant.fields, - cattrs, - TupleForm::ExternallyTagged(variant_ident), - ), - Style::Struct => deserialize_struct( - params, - &variant.fields, - cattrs, - StructForm::ExternallyTagged(variant_ident), - ), - } -} - -// Generates significant part of the visit_seq and visit_map bodies of visitors -// for the variants of internally tagged enum. -fn deserialize_internally_tagged_variant( - params: &Parameters, - variant: &Variant, - cattrs: &attr::Container, - deserializer: TokenStream, -) -> Fragment { - if variant.attrs.deserialize_with().is_some() { - return deserialize_untagged_variant(params, variant, cattrs, deserializer); - } - - let variant_ident = &variant.ident; - - match effective_style(variant) { - Style::Unit => { - let this_value = ¶ms.this_value; - let type_name = params.type_name(); - let variant_name = variant.ident.to_string(); - let default = variant.fields.first().map(|field| { - let default = Expr(expr_is_missing(field, cattrs)); - quote!((#default)) - }); - quote_block! { - _serde::Deserializer::deserialize_any(#deserializer, _serde::__private::de::InternallyTaggedUnitVisitor::new(#type_name, #variant_name))?; - _serde::__private::Ok(#this_value::#variant_ident #default) - } - } - Style::Newtype => deserialize_untagged_newtype_variant( - variant_ident, - params, - &variant.fields[0], - &deserializer, - ), - Style::Struct => deserialize_struct( - params, - &variant.fields, - cattrs, - StructForm::InternallyTagged(variant_ident, deserializer), - ), - Style::Tuple => unreachable!("checked in serde_derive_internals"), - } -} - -fn deserialize_untagged_variant( - params: &Parameters, - variant: &Variant, - cattrs: &attr::Container, - deserializer: TokenStream, -) -> Fragment { - if let Some(path) = variant.attrs.deserialize_with() { - let unwrap_fn = unwrap_to_variant_closure(params, variant, false); - return quote_block! { - _serde::__private::Result::map(#path(#deserializer), #unwrap_fn) - }; - } - - let variant_ident = &variant.ident; - - match effective_style(variant) { - Style::Unit => { - let this_value = ¶ms.this_value; - let type_name = params.type_name(); - let variant_name = variant.ident.to_string(); - let default = variant.fields.first().map(|field| { - let default = Expr(expr_is_missing(field, cattrs)); - quote!((#default)) - }); - quote_expr! { - match _serde::Deserializer::deserialize_any( - #deserializer, - _serde::__private::de::UntaggedUnitVisitor::new(#type_name, #variant_name) - ) { - _serde::__private::Ok(()) => _serde::__private::Ok(#this_value::#variant_ident #default), - _serde::__private::Err(__err) => _serde::__private::Err(__err), - } - } - } - Style::Newtype => deserialize_untagged_newtype_variant( - variant_ident, - params, - &variant.fields[0], - &deserializer, - ), - Style::Tuple => deserialize_tuple( - params, - &variant.fields, - cattrs, - TupleForm::Untagged(variant_ident, deserializer), - ), - Style::Struct => deserialize_struct( - params, - &variant.fields, - cattrs, - StructForm::Untagged(variant_ident, deserializer), - ), - } -} - -fn deserialize_externally_tagged_newtype_variant( - variant_ident: &syn::Ident, - params: &Parameters, - field: &Field, - cattrs: &attr::Container, -) -> Fragment { - let this_value = ¶ms.this_value; - - if field.attrs.skip_deserializing() { - let default = Expr(expr_is_missing(field, cattrs)); - return quote_block! { - _serde::de::VariantAccess::unit_variant(__variant)?; - _serde::__private::Ok(#this_value::#variant_ident(#default)) - }; - } - - match field.attrs.deserialize_with() { - None => { - let field_ty = field.ty; - let span = field.original.span(); - let seed = ¶ms.seed; - let context = ¶ms.context; - let func = quote_spanned!(span=> _serde::de::VariantAccess::newtype_variant_seed::<#seed<#context, #field_ty>>); - quote_expr! { - _serde::__private::Result::map(#func(__variant, #seed::new(self.context)), #this_value::#variant_ident) - } - } - Some(path) => { - let (wrapper, wrapper_ty) = wrap_deserialize_field_with(params, field.ty, path); - let seed = ¶ms.seed; - let context = ¶ms.context; - quote_block! { - #wrapper - _serde::__private::Result::map( - _serde::de::VariantAccess::newtype_variant::<#seed<#context, #wrapper_ty>>( - __variant, #seed::new(self.context)), - |__wrapper| #this_value::#variant_ident(__wrapper.value)) - } - } - } -} - -fn deserialize_untagged_newtype_variant( - variant_ident: &syn::Ident, - params: &Parameters, - field: &Field, - deserializer: &TokenStream, -) -> Fragment { - let this_value = ¶ms.this_value; - let field_ty = field.ty; - match field.attrs.deserialize_with() { - None => { - let span = field.original.span(); - let func = quote_spanned!(span=> <#field_ty as _serde::Deserialize>::deserialize); - quote_expr! { - _serde::__private::Result::map(#func(#deserializer), #this_value::#variant_ident) - } - } - Some(path) => { - quote_block! { - let __value: _serde::__private::Result<#field_ty, _> = #path(#deserializer); - _serde::__private::Result::map(__value, #this_value::#variant_ident) - } - } - } -} - -fn deserialize_generated_identifier( - fields: &[(&str, Ident, &BTreeSet)], - cattrs: &attr::Container, - is_variant: bool, - ignore_variant: Option, - fallthrough: Option, -) -> Fragment { - let this_value = quote!(__Field); - let field_idents: &Vec<_> = &fields.iter().map(|(_, ident, _)| ident).collect(); - - let visitor_impl = Stmts(deserialize_identifier( - &this_value, - fields, - is_variant, - fallthrough, - None, - !is_variant && cattrs.has_flatten(), - None, - )); - - let lifetime = if !is_variant && cattrs.has_flatten() { - Some(quote!(<'de>)) - } else { - None - }; - - quote_block! { - #[allow(non_camel_case_types)] - #[doc(hidden)] - enum __Field #lifetime { - #(#field_idents,)* - #ignore_variant - } - - #[doc(hidden)] - struct __FieldVisitor; - - impl<'de> _serde::de::Visitor<'de> for __FieldVisitor { - type Value = __Field #lifetime; - - #visitor_impl - } - - impl<'de> _serde::Deserialize<'de> for __Field #lifetime { - #[inline] - fn deserialize<__D>(__deserializer: __D) -> _serde::__private::Result - where - __D: _serde::Deserializer<'de>, - { - _serde::Deserializer::deserialize_identifier(__deserializer, __FieldVisitor) - } - } - } -} - -/// Generates enum and its `Deserialize` implementation that represents each -/// non-skipped field of the struct -fn deserialize_field_identifier( - fields: &[(&str, Ident, &BTreeSet)], - cattrs: &attr::Container, -) -> Stmts { - let (ignore_variant, fallthrough) = if cattrs.has_flatten() { - let ignore_variant = quote!(__other(_serde::__private::de::Content<'de>),); - let fallthrough = quote!(_serde::__private::Ok(__Field::__other(__value))); - (Some(ignore_variant), Some(fallthrough)) - } else if cattrs.deny_unknown_fields() { - (None, None) - } else { - let ignore_variant = quote!(__ignore,); - let fallthrough = quote!(_serde::__private::Ok(__Field::__ignore)); - (Some(ignore_variant), Some(fallthrough)) - }; - - Stmts(deserialize_generated_identifier( - fields, - cattrs, - false, - ignore_variant, - fallthrough, - )) -} - -// Generates `Deserialize::deserialize` body for an enum with -// `serde(field_identifier)` or `serde(variant_identifier)` attribute. -fn deserialize_custom_identifier( - params: &Parameters, - variants: &[Variant], - cattrs: &attr::Container, -) -> Fragment { - let is_variant = match cattrs.identifier() { - attr::Identifier::Variant => true, - attr::Identifier::Field => false, - attr::Identifier::No => unreachable!(), - }; - - let this_type = params.this_type.to_token_stream(); - let this_value = params.this_value.to_token_stream(); - - let (ordinary, fallthrough, fallthrough_borrowed) = if let Some(last) = variants.last() { - let last_ident = &last.ident; - if last.attrs.other() { - // Process `serde(other)` attribute. It would always be found on the - // last variant (checked in `check_identifier`), so all preceding - // are ordinary variants. - let ordinary = &variants[..variants.len() - 1]; - let fallthrough = quote!(_serde::__private::Ok(#this_value::#last_ident)); - (ordinary, Some(fallthrough), None) - } else if let Style::Newtype = last.style { - let ordinary = &variants[..variants.len() - 1]; - let fallthrough = |value| { - quote! { - _serde::__private::Result::map( - _serde::Deserialize::deserialize( - _serde::__private::de::IdentifierDeserializer::from(#value) - ), - #this_value::#last_ident) - } - }; - ( - ordinary, - Some(fallthrough(quote!(__value))), - Some(fallthrough(quote!(_serde::__private::de::Borrowed( - __value - )))), - ) - } else { - (variants, None, None) - } - } else { - (variants, None, None) - }; - - let names_idents: Vec<_> = ordinary - .iter() - .map(|variant| { - ( - variant.attrs.name().deserialize_name(), - variant.ident.clone(), - variant.attrs.aliases(), - ) - }) - .collect(); - - let names = names_idents.iter().flat_map(|&(_, _, aliases)| aliases); - - let names_const = if fallthrough.is_some() { - None - } else if is_variant { - let variants = quote! { - #[doc(hidden)] - const VARIANTS: &'static [&'static str] = &[ #(#names),* ]; - }; - Some(variants) - } else { - let fields = quote! { - #[doc(hidden)] - const FIELDS: &'static [&'static str] = &[ #(#names),* ]; - }; - Some(fields) - }; - - let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = - split_with_de_lifetime(params); - let delife = params.borrowed.de_lifetime(); - let visitor_impl = Stmts(deserialize_identifier( - &this_value, - &names_idents, - is_variant, - fallthrough, - fallthrough_borrowed, - false, - cattrs.expecting(), - )); - - quote_block! { - #names_const - - #[doc(hidden)] - struct __FieldVisitor #de_impl_generics #where_clause { - marker: _serde::__private::PhantomData<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData<&#delife ()>, - } - - impl #de_impl_generics _serde::de::Visitor<#delife> for __FieldVisitor #de_ty_generics #where_clause { - type Value = #this_type #ty_generics; - - #visitor_impl - } - - let __visitor = __FieldVisitor { - marker: _serde::__private::PhantomData::<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData, - }; - _serde::Deserializer::deserialize_identifier(__deserializer, __visitor) - } -} - -fn deserialize_identifier( - this_value: &TokenStream, - fields: &[(&str, Ident, &BTreeSet)], - is_variant: bool, - fallthrough: Option, - fallthrough_borrowed: Option, - collect_other_fields: bool, - expecting: Option<&str>, -) -> Fragment { - let str_mapping = fields.iter().map(|(_, ident, aliases)| { - // `aliases` also contains a main name - quote!(#(#aliases)|* => _serde::__private::Ok(#this_value::#ident)) - }); - let bytes_mapping = fields.iter().map(|(_, ident, aliases)| { - // `aliases` also contains a main name - let aliases = aliases - .iter() - .map(|alias| Literal::byte_string(alias.as_bytes())); - quote!(#(#aliases)|* => _serde::__private::Ok(#this_value::#ident)) - }); - - let expecting = expecting.unwrap_or(if is_variant { - "variant identifier" - } else { - "field identifier" - }); - - let bytes_to_str = if fallthrough.is_some() || collect_other_fields { - None - } else { - Some(quote! { - let __value = &_serde::__private::from_utf8_lossy(__value); - }) - }; - - let ( - value_as_str_content, - value_as_borrowed_str_content, - value_as_bytes_content, - value_as_borrowed_bytes_content, - ) = if collect_other_fields { - ( - Some(quote! { - let __value = _serde::__private::de::Content::String(_serde::__private::ToString::to_string(__value)); - }), - Some(quote! { - let __value = _serde::__private::de::Content::Str(__value); - }), - Some(quote! { - let __value = _serde::__private::de::Content::ByteBuf(__value.to_vec()); - }), - Some(quote! { - let __value = _serde::__private::de::Content::Bytes(__value); - }), - ) - } else { - (None, None, None, None) - }; - - let fallthrough_arm_tokens; - let fallthrough_arm = if let Some(fallthrough) = &fallthrough { - fallthrough - } else if is_variant { - fallthrough_arm_tokens = quote! { - _serde::__private::Err(_serde::de::Error::unknown_variant(__value, VARIANTS)) - }; - &fallthrough_arm_tokens - } else { - fallthrough_arm_tokens = quote! { - _serde::__private::Err(_serde::de::Error::unknown_field(__value, FIELDS)) - }; - &fallthrough_arm_tokens - }; - - let visit_other = if collect_other_fields { - quote! { - fn visit_bool<__E>(self, __value: bool) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::Bool(__value))) - } - - fn visit_i8<__E>(self, __value: i8) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I8(__value))) - } - - fn visit_i16<__E>(self, __value: i16) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I16(__value))) - } - - fn visit_i32<__E>(self, __value: i32) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I32(__value))) - } - - fn visit_i64<__E>(self, __value: i64) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::I64(__value))) - } - - fn visit_u8<__E>(self, __value: u8) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U8(__value))) - } - - fn visit_u16<__E>(self, __value: u16) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U16(__value))) - } - - fn visit_u32<__E>(self, __value: u32) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U32(__value))) - } - - fn visit_u64<__E>(self, __value: u64) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::U64(__value))) - } - - fn visit_f32<__E>(self, __value: f32) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::F32(__value))) - } - - fn visit_f64<__E>(self, __value: f64) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::F64(__value))) - } - - fn visit_char<__E>(self, __value: char) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::Char(__value))) - } - - fn visit_unit<__E>(self) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - _serde::__private::Ok(__Field::__other(_serde::__private::de::Content::Unit)) - } - } - } else { - let u64_mapping = fields.iter().enumerate().map(|(i, (_, ident, _))| { - let i = i as u64; - quote!(#i => _serde::__private::Ok(#this_value::#ident)) - }); - - let u64_fallthrough_arm_tokens; - let u64_fallthrough_arm = if let Some(fallthrough) = &fallthrough { - fallthrough - } else { - let index_expecting = if is_variant { "variant" } else { "field" }; - let fallthrough_msg = format!("{} index 0 <= i < {}", index_expecting, fields.len()); - u64_fallthrough_arm_tokens = quote! { - _serde::__private::Err(_serde::de::Error::invalid_value( - _serde::de::Unexpected::Unsigned(__value), - &#fallthrough_msg, - )) - }; - &u64_fallthrough_arm_tokens - }; - - quote! { - fn visit_u64<__E>(self, __value: u64) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - match __value { - #(#u64_mapping,)* - _ => #u64_fallthrough_arm, - } - } - } - }; - - let visit_borrowed = if fallthrough_borrowed.is_some() || collect_other_fields { - let str_mapping = str_mapping.clone(); - let bytes_mapping = bytes_mapping.clone(); - let fallthrough_borrowed_arm = fallthrough_borrowed.as_ref().unwrap_or(fallthrough_arm); - Some(quote! { - fn visit_borrowed_str<__E>(self, __value: &'de str) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - match __value { - #(#str_mapping,)* - _ => { - #value_as_borrowed_str_content - #fallthrough_borrowed_arm - } - } - } - - fn visit_borrowed_bytes<__E>(self, __value: &'de [u8]) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - match __value { - #(#bytes_mapping,)* - _ => { - #bytes_to_str - #value_as_borrowed_bytes_content - #fallthrough_borrowed_arm - } - } - } - }) - } else { - None - }; - - quote_block! { - fn expecting(&self, __formatter: &mut _serde::__private::Formatter) -> _serde::__private::fmt::Result { - _serde::__private::Formatter::write_str(__formatter, #expecting) - } - - #visit_other - - fn visit_str<__E>(self, __value: &str) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - match __value { - #(#str_mapping,)* - _ => { - #value_as_str_content - #fallthrough_arm - } - } - } - - fn visit_bytes<__E>(self, __value: &[u8]) -> _serde::__private::Result - where - __E: _serde::de::Error, - { - match __value { - #(#bytes_mapping,)* - _ => { - #bytes_to_str - #value_as_bytes_content - #fallthrough_arm - } - } - } - - #visit_borrowed - } -} - -fn deserialize_map( - struct_path: &TokenStream, - params: &Parameters, - fields: &[Field], - cattrs: &attr::Container, -) -> Fragment { - // Create the field names for the fields. - let fields_names: Vec<_> = fields - .iter() - .enumerate() - .map(|(i, field)| (field, field_i(i))) - .collect(); - - // Declare each field that will be deserialized. - let let_values = fields_names - .iter() - .filter(|&&(field, _)| !field.attrs.skip_deserializing() && !field.attrs.flatten()) - .map(|(field, name)| { - let field_ty = field.ty; - quote! { - let mut #name: _serde::__private::Option<#field_ty> = _serde::__private::None; - } - }); - - // Collect contents for flatten fields into a buffer - let let_collect = if cattrs.has_flatten() { - Some(quote! { - let mut __collect = _serde::__private::Vec::<_serde::__private::Option<( - _serde::__private::de::Content, - _serde::__private::de::Content - )>>::new(); - }) - } else { - None - }; - - // Match arms to extract a value for a field. - let value_arms = fields_names - .iter() - .filter(|&&(field, _)| !field.attrs.skip_deserializing() && !field.attrs.flatten()) - .map(|(field, name)| { - let deser_name = field.attrs.name().deserialize_name(); - - let visit = match field.attrs.deserialize_with() { - None => { - let field_ty = field.ty; - let span = field.original.span(); - let seed = ¶ms.seed; - let context = ¶ms.context; - let func = - quote_spanned!(span=> _serde::de::MapAccess::next_value_seed::<#seed<#context, #field_ty>>); - quote! { - #func(&mut __map, #seed::new(self.context))? - } - } - Some(path) => { - let (wrapper, wrapper_ty) = wrap_deserialize_field_with(params, field.ty, path); - let seed = ¶ms.seed; - let context = ¶ms.context; - quote!({ - #wrapper - match _serde::de::MapAccess::next_value_seed::<#seed<#context, #wrapper_ty>>(&mut __map, #seed::new(self.context)) { - _serde::__private::Ok(__wrapper) => __wrapper.value, - _serde::__private::Err(__err) => { - return _serde::__private::Err(__err); - } - } - }) - } - }; - quote! { - __Field::#name => { - if _serde::__private::Option::is_some(&#name) { - return _serde::__private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#deser_name)); - } - #name = _serde::__private::Some(#visit); - } - } - }); - - // Visit ignored values to consume them - let ignored_arm = if cattrs.has_flatten() { - Some(quote! { - __Field::__other(__name) => { - __collect.push(_serde::__private::Some(( - __name, - _serde::de::MapAccess::next_value(&mut __map)?))); - } - }) - } else if cattrs.deny_unknown_fields() { - None - } else { - Some(quote! { - _ => { let _ = _serde::de::MapAccess::next_value::<_serde::de::IgnoredAny>(&mut __map)?; } - }) - }; - - let all_skipped = fields.iter().all(|field| field.attrs.skip_deserializing()); - let match_keys = if cattrs.deny_unknown_fields() && all_skipped { - quote! { - // FIXME: Once feature(exhaustive_patterns) is stable: - // let _serde::__private::None::<__Field> = _serde::de::MapAccess::next_key(&mut __map)?; - _serde::__private::Option::map( - _serde::de::MapAccess::next_key::<__Field>(&mut __map)?, - |__impossible| match __impossible {}); - } - } else { - quote! { - while let _serde::__private::Some(__key) = _serde::de::MapAccess::next_key::<__Field>(&mut __map)? { - match __key { - #(#value_arms)* - #ignored_arm - } - } - } - }; - - let extract_values = fields_names - .iter() - .filter(|&&(field, _)| !field.attrs.skip_deserializing() && !field.attrs.flatten()) - .map(|(field, name)| { - let missing_expr = Match(expr_is_missing(field, cattrs)); - - quote! { - let #name = match #name { - _serde::__private::Some(#name) => #name, - _serde::__private::None => #missing_expr - }; - } - }); - - let extract_collected = fields_names - .iter() - .filter(|&&(field, _)| field.attrs.flatten() && !field.attrs.skip_deserializing()) - .map(|(field, name)| { - let field_ty = field.ty; - let func = match field.attrs.deserialize_with() { - None => { - let span = field.original.span(); - quote_spanned!(span=> _serde::de::Deserialize::deserialize) - } - Some(path) => quote!(#path), - }; - quote! { - let #name: #field_ty = #func( - _serde::__private::de::FlatMapDeserializer( - &mut __collect, - _serde::__private::PhantomData))?; - } - }); - - let collected_deny_unknown_fields = if cattrs.has_flatten() && cattrs.deny_unknown_fields() { - Some(quote! { - if let _serde::__private::Some(_serde::__private::Some((__key, _))) = - __collect.into_iter().filter(_serde::__private::Option::is_some).next() - { - if let _serde::__private::Some(__key) = __key.as_str() { - return _serde::__private::Err( - _serde::de::Error::custom(format_args!("unknown field `{}`", &__key))); - } else { - return _serde::__private::Err( - _serde::de::Error::custom(format_args!("unexpected map key"))); - } - } - }) - } else { - None - }; - - let result = fields_names.iter().map(|(field, name)| { - let member = &field.member; - if field.attrs.skip_deserializing() { - let value = Expr(expr_is_missing(field, cattrs)); - quote!(#member: #value) - } else { - quote!(#member: #name) - } - }); - - let let_default = match cattrs.default() { - attr::Default::Default => Some(quote!( - let __default: Self::Value = _serde::__private::Default::default(); - )), - attr::Default::Path(path) => Some(quote!( - let __default: Self::Value = #path(); - )), - attr::Default::None => { - // We don't need the default value, to prevent an unused variable warning - // we'll leave the line empty. - None - } - }; - - let mut result = quote!(#struct_path { #(#result),* }); - if params.has_getter { - let this_type = ¶ms.this_type; - let (_, ty_generics, _) = params.generics.split_for_impl(); - result = quote! { - _serde::__private::Into::<#this_type #ty_generics>::into(#result) - }; - } - - quote_block! { - #(#let_values)* - - #let_collect - - #match_keys - - #let_default - - #(#extract_values)* - - #(#extract_collected)* - - #collected_deny_unknown_fields - - _serde::__private::Ok(#result) - } -} - -fn field_i(i: usize) -> Ident { - Ident::new(&format!("__field{i}"), Span::call_site()) -} - -/// This function wraps the expression in `#[serde(deserialize_with = "...")]` -/// in a trait to prevent it from accessing the internal `Deserialize` state. -fn wrap_deserialize_with( - params: &Parameters, - value_ty: &TokenStream, - deserialize_with: &syn::ExprPath, -) -> (TokenStream, TokenStream) { - let this_type = ¶ms.this_type; - let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = - split_with_de_lifetime(params); - let delife = params.borrowed.de_lifetime(); - - let wrapper = quote! { - #[doc(hidden)] - struct __DeserializeWith #de_impl_generics #where_clause { - value: #value_ty, - phantom: _serde::__private::PhantomData<#this_type #ty_generics>, - lifetime: _serde::__private::PhantomData<&#delife ()>, - } - - impl #de_impl_generics _serde::Deserialize<#delife> for __DeserializeWith #de_ty_generics #where_clause { - fn deserialize<__D>(__deserializer: __D) -> _serde::__private::Result - where - __D: _serde::Deserializer<#delife>, - { - _serde::__private::Ok(__DeserializeWith { - value: #deserialize_with(__deserializer)?, - phantom: _serde::__private::PhantomData, - lifetime: _serde::__private::PhantomData, - }) - } - } - }; - - let wrapper_ty = quote!(__DeserializeWith #de_ty_generics); - - (wrapper, wrapper_ty) -} - -fn wrap_deserialize_field_with( - params: &Parameters, - field_ty: &syn::Type, - deserialize_with: &syn::ExprPath, -) -> (TokenStream, TokenStream) { - wrap_deserialize_with(params, "e!(#field_ty), deserialize_with) -} - -fn wrap_deserialize_variant_with( - params: &Parameters, - variant: &Variant, - deserialize_with: &syn::ExprPath, -) -> (TokenStream, TokenStream, TokenStream) { - let field_tys = variant.fields.iter().map(|field| field.ty); - let (wrapper, wrapper_ty) = - wrap_deserialize_with(params, "e!((#(#field_tys),*)), deserialize_with); - - let unwrap_fn = unwrap_to_variant_closure(params, variant, true); - - (wrapper, wrapper_ty, unwrap_fn) -} - -// Generates closure that converts single input parameter to the final value. -fn unwrap_to_variant_closure( - params: &Parameters, - variant: &Variant, - with_wrapper: bool, -) -> TokenStream { - let this_value = ¶ms.this_value; - let variant_ident = &variant.ident; +// Generates closure that converts single input parameter to the final value. +fn unwrap_to_variant_closure( + params: &Parameters, + variant: &Variant, + with_wrapper: bool, +) -> TokenStream { + let this_value = ¶ms.this_value; + let variant_ident = &variant.ident; let (arg, wrapper) = if with_wrapper { (quote! { __wrap }, quote! { __wrap.value }) @@ -2493,11 +783,15 @@ fn expr_is_missing(field: &Field, cattrs: &attr::Container) -> Fragment { match field.attrs.default() { attr::Default::Default => { let span = field.original.span(); - let func = quote_spanned!(span=> _serde::__private::Default::default); + let func = quote_spanned!(span=> _serde::#private::Default::default); return quote_expr!(#func()); } attr::Default::Path(path) => { - return quote_expr!(#path()); + // If #path returns wrong type, error will be reported here (^^^^^). + // We attach span of the path to the function so it will be reported + // on the #[serde(default = "...")] + // ^^^^^ + return Fragment::Expr(quote_spanned!(path.span()=> #path())); } attr::Default::None => { /* below */ } } @@ -2514,19 +808,19 @@ fn expr_is_missing(field: &Field, cattrs: &attr::Container) -> Fragment { // match field.attrs.deserialize_with() { // None => { // let span = field.original.span(); - // let func = quote_spanned!(span=> _serde::__private::de::missing_field); + // let func = quote_spanned!(span=> _serde::#private::de::missing_field); // quote_expr! { // #func(#name)? // } // } // Some(_) => { // quote_expr! { - // return _serde::__private::Err(<__A::Error as _serde::de::Error>::missing_field(#name)) + // return _serde::#private::Err(<__A::Error as _serde::de::Error>::missing_field(#name)) // } // } // } quote_expr! { - return _serde::__private::Err(<__A::Error as _serde::de::Error>::missing_field(#name)) + return _serde::#private::Err(<__A::Error as _serde::de::Error>::missing_field(#name)) } } @@ -2540,9 +834,13 @@ fn expr_is_missing_seq( match field.attrs.default() { attr::Default::Default => { let span = field.original.span(); - return quote_spanned!(span=> #assign_to _serde::__private::Default::default()); + return quote_spanned!(span=> #assign_to _serde::#private::Default::default()); } attr::Default::Path(path) => { + // If #path returns wrong type, error will be reported here (^^^^^). + // We attach span of the path to the function so it will be reported + // on the #[serde(default = "...")] + // ^^^^^ return quote_spanned!(path.span()=> #assign_to #path()); } attr::Default::None => { /* below */ } @@ -2554,7 +852,7 @@ fn expr_is_missing_seq( quote!(#assign_to __default.#member) } attr::Default::None => quote!( - return _serde::__private::Err(_serde::de::Error::invalid_length(#index, &#expecting)) + return _serde::#private::Err(_serde::de::Error::invalid_length(#index, &#expecting)) ), } } @@ -2566,11 +864,56 @@ fn effective_style(variant: &Variant) -> Style { } } +/// True if there is any field with a `#[serde(flatten)]` attribute, other than +/// fields which are skipped. +fn has_flatten(fields: &[Field]) -> bool { + fields + .iter() + .any(|field| field.attrs.flatten() && !field.attrs.skip_deserializing()) +} + struct DeImplGenerics<'a>(&'a Parameters); +#[cfg(feature = "deserialize_in_place")] +struct InPlaceImplGenerics<'a>(&'a Parameters); + +impl<'a> ToTokens for DeImplGenerics<'a> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let mut generics = self.0.generics.clone(); + if let Some(de_lifetime) = self.0.borrowed.de_lifetime_param() { + generics.params = Some(syn::GenericParam::Lifetime(de_lifetime)) + .into_iter() + .chain(generics.params) + .collect(); + } + let (impl_generics, _, _) = generics.split_for_impl(); + impl_generics.to_tokens(tokens); + } +} -impl ToTokens for DeImplGenerics<'_> { +#[cfg(feature = "deserialize_in_place")] +impl<'a> ToTokens for InPlaceImplGenerics<'a> { fn to_tokens(&self, tokens: &mut TokenStream) { + let place_lifetime = place_lifetime(); let mut generics = self.0.generics.clone(); + + // Add lifetime for `&'place mut Self, and `'a: 'place` + for param in &mut generics.params { + match param { + syn::GenericParam::Lifetime(param) => { + param.bounds.push(place_lifetime.lifetime.clone()); + } + syn::GenericParam::Type(param) => { + param.bounds.push(syn::TypeParamBound::Lifetime( + place_lifetime.lifetime.clone(), + )); + } + syn::GenericParam::Const(_) => {} + } + } + generics.params = Some(syn::GenericParam::Lifetime(place_lifetime)) + .into_iter() + .chain(generics.params) + .collect(); if let Some(de_lifetime) = self.0.borrowed.de_lifetime_param() { generics.params = Some(syn::GenericParam::Lifetime(de_lifetime)) .into_iter() @@ -2582,7 +925,16 @@ impl ToTokens for DeImplGenerics<'_> { } } +#[cfg(feature = "deserialize_in_place")] +impl<'a> DeImplGenerics<'a> { + fn in_place(self) -> InPlaceImplGenerics<'a> { + InPlaceImplGenerics(self.0) + } +} + struct DeTypeGenerics<'a>(&'a Parameters); +#[cfg(feature = "deserialize_in_place")] +struct InPlaceTypeGenerics<'a>(&'a Parameters); fn de_type_generics_to_tokens( mut generics: syn::Generics, @@ -2606,22 +958,38 @@ fn de_type_generics_to_tokens( ty_generics.to_tokens(tokens); } -impl ToTokens for DeTypeGenerics<'_> { +impl<'a> ToTokens for DeTypeGenerics<'a> { fn to_tokens(&self, tokens: &mut TokenStream) { de_type_generics_to_tokens(self.0.generics.clone(), &self.0.borrowed, tokens); } } -fn split_with_de_lifetime( - params: &Parameters, -) -> ( - DeImplGenerics<'_>, - DeTypeGenerics<'_>, - syn::TypeGenerics<'_>, - Option<&syn::WhereClause>, -) { - let de_impl_generics = DeImplGenerics(params); - let de_ty_generics = DeTypeGenerics(params); - let (_, ty_generics, where_clause) = params.generics.split_for_impl(); - (de_impl_generics, de_ty_generics, ty_generics, where_clause) +#[cfg(feature = "deserialize_in_place")] +impl<'a> ToTokens for InPlaceTypeGenerics<'a> { + fn to_tokens(&self, tokens: &mut TokenStream) { + let mut generics = self.0.generics.clone(); + generics.params = Some(syn::GenericParam::Lifetime(place_lifetime())) + .into_iter() + .chain(generics.params) + .collect(); + + de_type_generics_to_tokens(generics, &self.0.borrowed, tokens); + } +} + +#[cfg(feature = "deserialize_in_place")] +impl<'a> DeTypeGenerics<'a> { + fn in_place(self) -> InPlaceTypeGenerics<'a> { + InPlaceTypeGenerics(self.0) + } +} + +#[cfg(feature = "deserialize_in_place")] +fn place_lifetime() -> syn::LifetimeParam { + syn::LifetimeParam { + attrs: Vec::new(), + lifetime: syn::Lifetime::new("'place", Span::call_site()), + colon_token: None, + bounds: Punctuated::new(), + } } diff --git a/crates/web-rwkv-derive/src/serde/de/enum_.rs b/crates/web-rwkv-derive/src/serde/de/enum_.rs new file mode 100644 index 00000000..e9761620 --- /dev/null +++ b/crates/web-rwkv-derive/src/serde/de/enum_.rs @@ -0,0 +1,96 @@ +use crate::serde::de::enum_adjacently; +use crate::serde::de::enum_externally; +use crate::serde::de::enum_internally; +use crate::serde::de::enum_untagged; +use crate::serde::de::identifier; +use crate::serde::de::{field_i, FieldWithAliases, Parameters}; +use crate::serde::fragment::{Expr, Fragment, Stmts}; +use crate::serde::internals::ast::Variant; +use crate::serde::internals::attr; +use crate::serde::private; +use proc_macro2::TokenStream; +use quote::quote; + +/// Generates `Deserialize::deserialize` body for an `enum Enum {...}` +pub(super) fn deserialize( + params: &Parameters, + variants: &[Variant], + cattrs: &attr::Container, +) -> Fragment { + // The variants have already been checked (in ast.rs) that all untagged variants appear at the end + match variants.iter().position(|var| var.attrs.untagged()) { + Some(variant_idx) => { + let (tagged, untagged) = variants.split_at(variant_idx); + let tagged_frag = Expr(deserialize_homogeneous_enum(params, tagged, cattrs)); + // Ignore any error associated with non-untagged deserialization so that we + // can fall through to the untagged variants. This may be infallible so we + // need to provide the error type. + let first_attempt = quote! { + if let _serde::#private::Result::<_, __D::Error>::Ok(__ok) = (|| #tagged_frag)() { + return _serde::#private::Ok(__ok); + } + }; + enum_untagged::deserialize(params, untagged, cattrs, Some(first_attempt)) + } + None => deserialize_homogeneous_enum(params, variants, cattrs), + } +} + +fn deserialize_homogeneous_enum( + params: &Parameters, + variants: &[Variant], + cattrs: &attr::Container, +) -> Fragment { + match cattrs.tag() { + attr::TagType::External => enum_externally::deserialize(params, variants, cattrs), + attr::TagType::Internal { tag } => { + enum_internally::deserialize(params, variants, cattrs, tag) + } + attr::TagType::Adjacent { tag, content } => { + enum_adjacently::deserialize(params, variants, cattrs, tag, content) + } + attr::TagType::None => enum_untagged::deserialize(params, variants, cattrs, None), + } +} + +pub fn prepare_enum_variant_enum(variants: &[Variant]) -> (TokenStream, Stmts) { + let deserialized_variants = variants + .iter() + .enumerate() + .filter(|&(_i, variant)| !variant.attrs.skip_deserializing()); + + let fallthrough = deserialized_variants + .clone() + .find(|(_i, variant)| variant.attrs.other()) + .map(|(i, _variant)| { + let ignore_variant = field_i(i); + quote!(_serde::#private::Ok(__Field::#ignore_variant)) + }); + + let variants_stmt = { + let variant_names = deserialized_variants + .clone() + .flat_map(|(_i, variant)| variant.attrs.aliases()); + quote! { + #[doc(hidden)] + const VARIANTS: &'static [&'static str] = &[ #(#variant_names),* ]; + } + }; + + let deserialized_variants: Vec<_> = deserialized_variants + .map(|(i, variant)| FieldWithAliases { + ident: field_i(i), + aliases: variant.attrs.aliases(), + }) + .collect(); + + let variant_visitor = Stmts(identifier::deserialize_generated( + &deserialized_variants, + false, // variant identifiers do not depend on the presence of flatten fields + true, + None, + fallthrough, + )); + + (variants_stmt, variant_visitor) +} diff --git a/crates/web-rwkv-derive/src/serde/de/enum_adjacently.rs b/crates/web-rwkv-derive/src/serde/de/enum_adjacently.rs new file mode 100644 index 00000000..48943765 --- /dev/null +++ b/crates/web-rwkv-derive/src/serde/de/enum_adjacently.rs @@ -0,0 +1,323 @@ +//! Deserialization for adjacently tagged enums: +//! +//! ```ignore +//! #[serde(tag = "...", content = "...")] +//! enum Enum {} +//! ``` + +use crate::serde::de::enum_; +use crate::serde::de::enum_untagged; +use crate::serde::de::{field_i, Parameters}; +use crate::serde::fragment::{Fragment, Match}; +use crate::serde::internals::ast::{Style, Variant}; +use crate::serde::internals::attr; +use crate::serde::private; +use quote::{quote, quote_spanned}; +use syn::spanned::Spanned; + +/// Generates `Deserialize::deserialize` body for an `enum Enum {...}` with `#[serde(tag, content)]` attributes +pub(super) fn deserialize( + params: &Parameters, + variants: &[Variant], + cattrs: &attr::Container, + tag: &str, + content: &str, +) -> Fragment { + let this_type = ¶ms.this_type; + let this_value = ¶ms.this_value; + let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = + params.generics_with_de_lifetime(); + let delife = params.borrowed.de_lifetime(); + + let (variants_stmt, variant_visitor) = enum_::prepare_enum_variant_enum(variants); + + let variant_arms: &Vec<_> = &variants + .iter() + .enumerate() + .filter(|&(_, variant)| !variant.attrs.skip_deserializing()) + .map(|(i, variant)| { + let variant_index = field_i(i); + + let block = Match(enum_untagged::deserialize_variant(params, variant, cattrs)); + + quote! { + __Field::#variant_index => #block + } + }) + .collect(); + + let rust_name = params.type_name(); + let expecting = format!("adjacently tagged enum {}", rust_name); + let expecting = cattrs.expecting().unwrap_or(&expecting); + let type_name = cattrs.name().deserialize_name(); + let deny_unknown_fields = cattrs.deny_unknown_fields(); + + // If unknown fields are allowed, we pick the visitor that can step over + // those. Otherwise we pick the visitor that fails on unknown keys. + let field_visitor_ty = if deny_unknown_fields { + quote! { _serde::#private::de::TagOrContentFieldVisitor } + } else { + quote! { _serde::#private::de::TagContentOtherFieldVisitor } + }; + + let mut missing_content = quote! { + _serde::#private::Err(<__A::Error as _serde::de::Error>::missing_field(#content)) + }; + let mut missing_content_fallthrough = quote!(); + let missing_content_arms = variants + .iter() + .enumerate() + .filter(|&(_, variant)| !variant.attrs.skip_deserializing()) + .filter_map(|(i, variant)| { + let variant_index = field_i(i); + let variant_ident = &variant.ident; + + let arm = match variant.style { + Style::Unit => quote! { + _serde::#private::Ok(#this_value::#variant_ident) + }, + Style::Newtype if variant.attrs.deserialize_with().is_none() => { + let span = variant.original.span(); + let func = quote_spanned!(span=> _serde::#private::de::missing_field); + quote! { + #func(#content).map(#this_value::#variant_ident) + } + } + _ => { + missing_content_fallthrough = quote!(_ => #missing_content); + return None; + } + }; + Some(quote! { + __Field::#variant_index => #arm, + }) + }) + .collect::>(); + if !missing_content_arms.is_empty() { + missing_content = quote! { + match __field { + #(#missing_content_arms)* + #missing_content_fallthrough + } + }; + } + + // Advance the map by one key, returning early in case of error. + let next_key = quote! { + _serde::de::MapAccess::next_key_seed(&mut __map, #field_visitor_ty { + tag: #tag, + content: #content, + })? + }; + + let variant_from_map = quote! { + _serde::de::MapAccess::next_value_seed(&mut __map, _serde::#private::de::AdjacentlyTaggedEnumVariantSeed::<__Field> { + enum_name: #rust_name, + variants: VARIANTS, + fields_enum: _serde::#private::PhantomData + })? + }; + + // When allowing unknown fields, we want to transparently step through keys + // we don't care about until we find `tag`, `content`, or run out of keys. + let next_relevant_key = if deny_unknown_fields { + next_key + } else { + quote!({ + let mut __rk : _serde::#private::Option<_serde::#private::de::TagOrContentField> = _serde::#private::None; + while let _serde::#private::Some(__k) = #next_key { + match __k { + _serde::#private::de::TagContentOtherField::Other => { + let _ = _serde::de::MapAccess::next_value::<_serde::de::IgnoredAny>(&mut __map)?; + continue; + }, + _serde::#private::de::TagContentOtherField::Tag => { + __rk = _serde::#private::Some(_serde::#private::de::TagOrContentField::Tag); + break; + } + _serde::#private::de::TagContentOtherField::Content => { + __rk = _serde::#private::Some(_serde::#private::de::TagOrContentField::Content); + break; + } + } + } + + __rk + }) + }; + + // Step through remaining keys, looking for duplicates of previously-seen + // keys. When unknown fields are denied, any key that isn't a duplicate will + // at this point immediately produce an error. + let visit_remaining_keys = quote! { + match #next_relevant_key { + _serde::#private::Some(_serde::#private::de::TagOrContentField::Tag) => { + _serde::#private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#tag)) + } + _serde::#private::Some(_serde::#private::de::TagOrContentField::Content) => { + _serde::#private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#content)) + } + _serde::#private::None => _serde::#private::Ok(__ret), + } + }; + + let finish_content_then_tag = if variant_arms.is_empty() { + quote! { + match #variant_from_map {} + } + } else { + quote! { + let __seed = __Seed { + variant: #variant_from_map, + marker: _serde::#private::PhantomData, + lifetime: _serde::#private::PhantomData, + }; + let __deserializer = _serde::#private::de::ContentDeserializer::<__A::Error>::new(__content); + let __ret = _serde::de::DeserializeSeed::deserialize(__seed, __deserializer)?; + // Visit remaining keys, looking for duplicates. + #visit_remaining_keys + } + }; + + quote_block! { + #variant_visitor + + #variants_stmt + + #[doc(hidden)] + struct __Seed #de_impl_generics #where_clause { + variant: __Field, + marker: _serde::#private::PhantomData<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData<&#delife ()>, + } + + #[automatically_derived] + impl #de_impl_generics _serde::de::DeserializeSeed<#delife> for __Seed #de_ty_generics #where_clause { + type Value = #this_type #ty_generics; + + fn deserialize<__D>(self, __deserializer: __D) -> _serde::#private::Result + where + __D: _serde::Deserializer<#delife>, + { + match self.variant { + #(#variant_arms)* + } + } + } + + #[doc(hidden)] + struct __Visitor #de_impl_generics #where_clause { + marker: _serde::#private::PhantomData<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData<&#delife ()>, + } + + #[automatically_derived] + impl #de_impl_generics _serde::de::Visitor<#delife> for __Visitor #de_ty_generics #where_clause { + type Value = #this_type #ty_generics; + + fn expecting(&self, __formatter: &mut _serde::#private::Formatter) -> _serde::#private::fmt::Result { + _serde::#private::Formatter::write_str(__formatter, #expecting) + } + + fn visit_map<__A>(self, mut __map: __A) -> _serde::#private::Result + where + __A: _serde::de::MapAccess<#delife>, + { + // Visit the first relevant key. + match #next_relevant_key { + // First key is the tag. + _serde::#private::Some(_serde::#private::de::TagOrContentField::Tag) => { + // Parse the tag. + let __field = #variant_from_map; + // Visit the second key. + match #next_relevant_key { + // Second key is a duplicate of the tag. + _serde::#private::Some(_serde::#private::de::TagOrContentField::Tag) => { + _serde::#private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#tag)) + } + // Second key is the content. + _serde::#private::Some(_serde::#private::de::TagOrContentField::Content) => { + let __ret = _serde::de::MapAccess::next_value_seed(&mut __map, + __Seed { + variant: __field, + marker: _serde::#private::PhantomData, + lifetime: _serde::#private::PhantomData, + })?; + // Visit remaining keys, looking for duplicates. + #visit_remaining_keys + } + // There is no second key; might be okay if the we have a unit variant. + _serde::#private::None => #missing_content + } + } + // First key is the content. + _serde::#private::Some(_serde::#private::de::TagOrContentField::Content) => { + // Buffer up the content. + let __content = _serde::de::MapAccess::next_value_seed(&mut __map, _serde::#private::de::ContentVisitor::new())?; + // Visit the second key. + match #next_relevant_key { + // Second key is the tag. + _serde::#private::Some(_serde::#private::de::TagOrContentField::Tag) => { + #finish_content_then_tag + } + // Second key is a duplicate of the content. + _serde::#private::Some(_serde::#private::de::TagOrContentField::Content) => { + _serde::#private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#content)) + } + // There is no second key. + _serde::#private::None => { + _serde::#private::Err(<__A::Error as _serde::de::Error>::missing_field(#tag)) + } + } + } + // There is no first key. + _serde::#private::None => { + _serde::#private::Err(<__A::Error as _serde::de::Error>::missing_field(#tag)) + } + } + } + + fn visit_seq<__A>(self, mut __seq: __A) -> _serde::#private::Result + where + __A: _serde::de::SeqAccess<#delife>, + { + // Visit the first element - the tag. + match _serde::de::SeqAccess::next_element(&mut __seq)? { + _serde::#private::Some(__variant) => { + // Visit the second element - the content. + match _serde::de::SeqAccess::next_element_seed( + &mut __seq, + __Seed { + variant: __variant, + marker: _serde::#private::PhantomData, + lifetime: _serde::#private::PhantomData, + }, + )? { + _serde::#private::Some(__ret) => _serde::#private::Ok(__ret), + // There is no second element. + _serde::#private::None => { + _serde::#private::Err(_serde::de::Error::invalid_length(1, &self)) + } + } + } + // There is no first element. + _serde::#private::None => { + _serde::#private::Err(_serde::de::Error::invalid_length(0, &self)) + } + } + } + } + + #[doc(hidden)] + const FIELDS: &'static [&'static str] = &[#tag, #content]; + _serde::Deserializer::deserialize_struct( + __deserializer, + #type_name, + FIELDS, + __Visitor { + marker: _serde::#private::PhantomData::<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData, + }, + ) + } +} diff --git a/crates/web-rwkv-derive/src/serde/de/enum_externally.rs b/crates/web-rwkv-derive/src/serde/de/enum_externally.rs new file mode 100644 index 00000000..64a42f3f --- /dev/null +++ b/crates/web-rwkv-derive/src/serde/de/enum_externally.rs @@ -0,0 +1,220 @@ +//! Deserialization for externally tagged enums: +//! +//! ```ignore +//! enum Enum {} +//! ``` + +use crate::serde::de::enum_; +use crate::serde::de::struct_; +use crate::serde::de::tuple; +use crate::serde::de::{ + expr_is_missing, field_i, unwrap_to_variant_closure, wrap_deserialize_field_with, + wrap_deserialize_with, Parameters, StructForm, TupleForm, +}; +use crate::serde::fragment::{Expr, Fragment, Match}; +use crate::serde::internals::ast::{Field, Style, Variant}; +use crate::serde::internals::attr; +use crate::serde::private; +use proc_macro2::TokenStream; +use quote::{quote, quote_spanned}; +use syn::spanned::Spanned; + +/// Generates `Deserialize::deserialize` body for an `enum Enum {...}` without additional attributes +pub(super) fn deserialize( + params: &Parameters, + variants: &[Variant], + cattrs: &attr::Container, +) -> Fragment { + let this_type = ¶ms.this_type; + let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = + params.generics_with_de_lifetime(); + let delife = params.borrowed.de_lifetime(); + let context = ¶ms.context; + + let type_name = cattrs.name().deserialize_name(); + let expecting = format!("enum {}", params.type_name()); + let expecting = cattrs.expecting().unwrap_or(&expecting); + + let (variants_stmt, variant_visitor) = enum_::prepare_enum_variant_enum(variants); + + // Match arms to extract a variant from a string + let variant_arms = variants + .iter() + .enumerate() + .filter(|&(_, variant)| !variant.attrs.skip_deserializing()) + .map(|(i, variant)| { + let variant_name = field_i(i); + + let block = Match(deserialize_externally_tagged_variant( + params, variant, cattrs, + )); + + quote! { + (__Field::#variant_name, __variant) => #block + } + }); + + let all_skipped = variants + .iter() + .all(|variant| variant.attrs.skip_deserializing()); + let match_variant = if all_skipped { + // This is an empty enum like `enum Impossible {}` or an enum in which + // all variants have `#[serde(skip_deserializing)]`. + quote! { + // FIXME: Once feature(exhaustive_patterns) is stable: + // let _serde::#private::Err(__err) = _serde::de::EnumAccess::variant::<__Field>(__data); + // _serde::#private::Err(__err) + _serde::#private::Result::map( + _serde::de::EnumAccess::variant::<__Field>(__data), + |(__impossible, _)| match __impossible {}) + } + } else { + quote! { + match _serde::de::EnumAccess::variant(__data)? { + #(#variant_arms)* + } + } + }; + + quote_block! { + #variant_visitor + + #[doc(hidden)] + struct __Visitor #de_impl_generics #where_clause { + context: &#delife #context, + marker: _serde::#private::PhantomData<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData<&#delife ()>, + } + + #[automatically_derived] + impl #de_impl_generics _serde::de::Visitor<#delife> for __Visitor #de_ty_generics #where_clause { + type Value = #this_type #ty_generics; + + fn expecting(&self, __formatter: &mut _serde::#private::Formatter) -> _serde::#private::fmt::Result { + _serde::#private::Formatter::write_str(__formatter, #expecting) + } + + fn visit_enum<__A>(self, __data: __A) -> _serde::#private::Result + where + __A: _serde::de::EnumAccess<#delife>, + { + #match_variant + } + } + + #variants_stmt + + _serde::Deserializer::deserialize_enum( + __deserializer, + #type_name, + VARIANTS, + __Visitor { + context: self.context, + marker: _serde::#private::PhantomData::<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData, + }, + ) + } +} + +fn deserialize_externally_tagged_variant( + params: &Parameters, + variant: &Variant, + cattrs: &attr::Container, +) -> Fragment { + if let Some(path) = variant.attrs.deserialize_with() { + let (wrapper, wrapper_ty, unwrap_fn) = wrap_deserialize_variant_with(params, variant, path); + let seed = ¶ms.seed; + let context = ¶ms.context; + return quote_block! { + #wrapper + _serde::#private::Result::map( + _serde::de::VariantAccess::newtype_variant_seed::<#seed<#context, #wrapper_ty>>(__variant, #seed::new(self.context)), #unwrap_fn) + }; + } + + let variant_ident = &variant.ident; + + match variant.style { + Style::Unit => { + let this_value = ¶ms.this_value; + quote_block! { + _serde::de::VariantAccess::unit_variant(__variant)?; + _serde::#private::Ok(#this_value::#variant_ident) + } + } + Style::Newtype => deserialize_externally_tagged_newtype_variant( + variant_ident, + params, + &variant.fields[0], + cattrs, + ), + Style::Tuple => tuple::deserialize( + params, + &variant.fields, + cattrs, + TupleForm::ExternallyTagged(variant_ident), + ), + Style::Struct => struct_::deserialize( + params, + &variant.fields, + cattrs, + StructForm::ExternallyTagged(variant_ident), + ), + } +} + +fn wrap_deserialize_variant_with( + params: &Parameters, + variant: &Variant, + deserialize_with: &syn::ExprPath, +) -> (TokenStream, TokenStream, TokenStream) { + let field_tys = variant.fields.iter().map(|field| field.ty); + let (wrapper, wrapper_ty) = + wrap_deserialize_with(params, "e!((#(#field_tys),*)), deserialize_with); + + let unwrap_fn = unwrap_to_variant_closure(params, variant, true); + + (wrapper, wrapper_ty, unwrap_fn) +} + +fn deserialize_externally_tagged_newtype_variant( + variant_ident: &syn::Ident, + params: &Parameters, + field: &Field, + cattrs: &attr::Container, +) -> Fragment { + let this_value = ¶ms.this_value; + + if field.attrs.skip_deserializing() { + let default = Expr(expr_is_missing(field, cattrs)); + return quote_block! { + _serde::de::VariantAccess::unit_variant(__variant)?; + _serde::#private::Ok(#this_value::#variant_ident(#default)) + }; + } + + match field.attrs.deserialize_with() { + None => { + let field_ty = field.ty; + let span = field.original.span(); + let seed = ¶ms.seed; + let context = ¶ms.context; + let func = quote_spanned!(span=> _serde::de::VariantAccess::newtype_variant_seed::<#seed<#context, #field_ty>>); + quote_expr! { + _serde::#private::Result::map(#func(__variant, #seed::new(self.context)), #this_value::#variant_ident) + } + } + Some(path) => { + let (wrapper, wrapper_ty) = wrap_deserialize_field_with(params, field.ty, path); + let seed = ¶ms.seed; + let context = ¶ms.context; + quote_block! { + #wrapper + _serde::#private::Result::map( + _serde::de::VariantAccess::newtype_variant_seed::<#seed<#context, #wrapper_ty>>(__variant, #seed::new(self.context)), + |__wrapper| #this_value::#variant_ident(__wrapper.value)) + } + } + } +} diff --git a/crates/web-rwkv-derive/src/serde/de/enum_internally.rs b/crates/web-rwkv-derive/src/serde/de/enum_internally.rs new file mode 100644 index 00000000..5006b4bc --- /dev/null +++ b/crates/web-rwkv-derive/src/serde/de/enum_internally.rs @@ -0,0 +1,106 @@ +//! Deserialization for internally tagged enums: +//! +//! ```ignore +//! #[serde(tag = "...")] +//! enum Enum {} +//! ``` + +use crate::serde::de::enum_; +use crate::serde::de::enum_untagged; +use crate::serde::de::struct_; +use crate::serde::de::{ + effective_style, expr_is_missing, field_i, unwrap_to_variant_closure, Parameters, StructForm, +}; +use crate::serde::fragment::{Expr, Fragment, Match}; +use crate::serde::internals::ast::{Style, Variant}; +use crate::serde::internals::attr; +use crate::serde::private; +use quote::quote; + +/// Generates `Deserialize::deserialize` body for an `enum Enum {...}` with `#[serde(tag)]` attribute +pub(super) fn deserialize( + params: &Parameters, + variants: &[Variant], + cattrs: &attr::Container, + tag: &str, +) -> Fragment { + let (variants_stmt, variant_visitor) = enum_::prepare_enum_variant_enum(variants); + + // Match arms to extract a variant from a string + let variant_arms = variants + .iter() + .enumerate() + .filter(|&(_, variant)| !variant.attrs.skip_deserializing()) + .map(|(i, variant)| { + let variant_name = field_i(i); + + let block = Match(deserialize_internally_tagged_variant( + params, variant, cattrs, + )); + + quote! { + __Field::#variant_name => #block + } + }); + + let expecting = format!("internally tagged enum {}", params.type_name()); + let expecting = cattrs.expecting().unwrap_or(&expecting); + + quote_block! { + #variant_visitor + + #variants_stmt + + let (__tag, __content) = _serde::Deserializer::deserialize_any( + __deserializer, + _serde::#private::de::TaggedContentVisitor::<__Field>::new(#tag, #expecting))?; + let __deserializer = _serde::#private::de::ContentDeserializer::<__D::Error>::new(__content); + + match __tag { + #(#variant_arms)* + } + } +} + +// Generates significant part of the visit_seq and visit_map bodies of visitors +// for the variants of internally tagged enum. +fn deserialize_internally_tagged_variant( + params: &Parameters, + variant: &Variant, + cattrs: &attr::Container, +) -> Fragment { + if let Some(path) = variant.attrs.deserialize_with() { + let unwrap_fn = unwrap_to_variant_closure(params, variant, false); + return quote_block! { + _serde::#private::Result::map(#path(__deserializer), #unwrap_fn) + }; + } + + let variant_ident = &variant.ident; + + match effective_style(variant) { + Style::Unit => { + let this_value = ¶ms.this_value; + let type_name = params.type_name(); + let variant_name = variant.ident.to_string(); + let default = variant.fields.first().map(|field| { + let default = Expr(expr_is_missing(field, cattrs)); + quote!((#default)) + }); + quote_block! { + _serde::Deserializer::deserialize_any(__deserializer, _serde::#private::de::InternallyTaggedUnitVisitor::new(#type_name, #variant_name))?; + _serde::#private::Ok(#this_value::#variant_ident #default) + } + } + Style::Newtype => { + enum_untagged::deserialize_newtype_variant(variant_ident, params, &variant.fields[0]) + } + Style::Struct => struct_::deserialize( + params, + &variant.fields, + cattrs, + StructForm::InternallyTagged(variant_ident), + ), + Style::Tuple => unreachable!("checked in serde_derive_internals"), + } +} diff --git a/crates/web-rwkv-derive/src/serde/de/enum_untagged.rs b/crates/web-rwkv-derive/src/serde/de/enum_untagged.rs new file mode 100644 index 00000000..81a1ad86 --- /dev/null +++ b/crates/web-rwkv-derive/src/serde/de/enum_untagged.rs @@ -0,0 +1,135 @@ +//! Deserialization for untagged enums: +//! +//! ```ignore +//! #[serde(untagged)] +//! enum Enum {} +//! ``` + +use crate::serde::de::struct_; +use crate::serde::de::tuple; +use crate::serde::de::{ + effective_style, expr_is_missing, unwrap_to_variant_closure, Parameters, StructForm, TupleForm, +}; +use crate::serde::fragment::{Expr, Fragment}; +use crate::serde::internals::ast::{Field, Style, Variant}; +use crate::serde::internals::attr; +use crate::serde::private; +use proc_macro2::TokenStream; +use quote::{quote, quote_spanned}; +use syn::spanned::Spanned; + +/// Generates `Deserialize::deserialize` body for an `enum Enum {...}` with `#[serde(untagged)]` attribute +pub(super) fn deserialize( + params: &Parameters, + variants: &[Variant], + cattrs: &attr::Container, + first_attempt: Option, +) -> Fragment { + let attempts = variants + .iter() + .filter(|variant| !variant.attrs.skip_deserializing()) + .map(|variant| Expr(deserialize_variant(params, variant, cattrs))); + // TODO this message could be better by saving the errors from the failed + // attempts. The heuristic used by TOML was to count the number of fields + // processed before an error, and use the error that happened after the + // largest number of fields. I'm not sure I like that. Maybe it would be + // better to save all the errors and combine them into one message that + // explains why none of the variants matched. + let fallthrough_msg = format!( + "data did not match any variant of untagged enum {}", + params.type_name() + ); + let fallthrough_msg = cattrs.expecting().unwrap_or(&fallthrough_msg); + + let private2 = private; + quote_block! { + let __content = _serde::de::DeserializeSeed::deserialize(_serde::#private::de::ContentVisitor::new(), __deserializer)?; + let __deserializer = _serde::#private::de::ContentRefDeserializer::<__D::Error>::new(&__content); + + #first_attempt + + #( + if let _serde::#private2::Ok(__ok) = #attempts { + return _serde::#private2::Ok(__ok); + } + )* + + _serde::#private::Err(_serde::de::Error::custom(#fallthrough_msg)) + } +} + +// Also used by adjacently tagged enums +pub(super) fn deserialize_variant( + params: &Parameters, + variant: &Variant, + cattrs: &attr::Container, +) -> Fragment { + if let Some(path) = variant.attrs.deserialize_with() { + let unwrap_fn = unwrap_to_variant_closure(params, variant, false); + return quote_block! { + _serde::#private::Result::map(#path(__deserializer), #unwrap_fn) + }; + } + + let variant_ident = &variant.ident; + + match effective_style(variant) { + Style::Unit => { + let this_value = ¶ms.this_value; + let type_name = params.type_name(); + let variant_name = variant.ident.to_string(); + let default = variant.fields.first().map(|field| { + let default = Expr(expr_is_missing(field, cattrs)); + quote!((#default)) + }); + quote_expr! { + match _serde::Deserializer::deserialize_any( + __deserializer, + _serde::#private::de::UntaggedUnitVisitor::new(#type_name, #variant_name) + ) { + _serde::#private::Ok(()) => _serde::#private::Ok(#this_value::#variant_ident #default), + _serde::#private::Err(__err) => _serde::#private::Err(__err), + } + } + } + Style::Newtype => deserialize_newtype_variant(variant_ident, params, &variant.fields[0]), + Style::Tuple => tuple::deserialize( + params, + &variant.fields, + cattrs, + TupleForm::Untagged(variant_ident), + ), + Style::Struct => struct_::deserialize( + params, + &variant.fields, + cattrs, + StructForm::Untagged(variant_ident), + ), + } +} + +// Also used by internally tagged enums +// Implicitly (via `generate_variant`) used by adjacently tagged enums +pub(super) fn deserialize_newtype_variant( + variant_ident: &syn::Ident, + params: &Parameters, + field: &Field, +) -> Fragment { + let this_value = ¶ms.this_value; + let field_ty = field.ty; + match field.attrs.deserialize_with() { + None => { + let span = field.original.span(); + let func = quote_spanned!(span=> <#field_ty as _serde::Deserialize>::deserialize); + quote_expr! { + _serde::#private::Result::map(#func(__deserializer), #this_value::#variant_ident) + } + } + Some(path) => { + quote_block! { + let __value: _serde::#private::Result<#field_ty, _> = #path(__deserializer); + _serde::#private::Result::map(__value, #this_value::#variant_ident) + } + } + } +} diff --git a/crates/web-rwkv-derive/src/serde/de/identifier.rs b/crates/web-rwkv-derive/src/serde/de/identifier.rs new file mode 100644 index 00000000..281655d3 --- /dev/null +++ b/crates/web-rwkv-derive/src/serde/de/identifier.rs @@ -0,0 +1,477 @@ +//! Deserialization of struct field identifiers and enum variant identifiers by +//! way of a Rust enum. + +use crate::serde::de::{FieldWithAliases, Parameters}; +use crate::serde::fragment::{Fragment, Stmts}; +use crate::serde::internals::ast::{Style, Variant}; +use crate::serde::internals::attr; +use crate::serde::private; +use proc_macro2::{Literal, TokenStream}; +use quote::{quote, ToTokens}; + +// Generates `Deserialize::deserialize` body for an enum with +// `serde(field_identifier)` or `serde(variant_identifier)` attribute. +pub(super) fn deserialize_custom( + params: &Parameters, + variants: &[Variant], + cattrs: &attr::Container, +) -> Fragment { + let is_variant = match cattrs.identifier() { + attr::Identifier::Variant => true, + attr::Identifier::Field => false, + attr::Identifier::No => unreachable!(), + }; + + let this_type = params.this_type.to_token_stream(); + let this_value = params.this_value.to_token_stream(); + + let (ordinary, fallthrough, fallthrough_borrowed) = if let Some(last) = variants.last() { + let last_ident = &last.ident; + if last.attrs.other() { + // Process `serde(other)` attribute. It would always be found on the + // last variant (checked in `check_identifier`), so all preceding + // are ordinary variants. + let ordinary = &variants[..variants.len() - 1]; + let fallthrough = quote!(_serde::#private::Ok(#this_value::#last_ident)); + (ordinary, Some(fallthrough), None) + } else if let Style::Newtype = last.style { + let ordinary = &variants[..variants.len() - 1]; + let fallthrough = |value| { + quote! { + _serde::#private::Result::map( + _serde::Deserialize::deserialize( + _serde::#private::de::IdentifierDeserializer::from(#value) + ), + #this_value::#last_ident) + } + }; + ( + ordinary, + Some(fallthrough(quote!(__value))), + Some(fallthrough(quote!(_serde::#private::de::Borrowed( + __value + )))), + ) + } else { + (variants, None, None) + } + } else { + (variants, None, None) + }; + + let idents_aliases: Vec<_> = ordinary + .iter() + .map(|variant| FieldWithAliases { + ident: variant.ident.clone(), + aliases: variant.attrs.aliases(), + }) + .collect(); + + let names = idents_aliases.iter().flat_map(|variant| variant.aliases); + + let names_const = if fallthrough.is_some() { + None + } else if is_variant { + let variants = quote! { + #[doc(hidden)] + const VARIANTS: &'static [&'static str] = &[ #(#names),* ]; + }; + Some(variants) + } else { + let fields = quote! { + #[doc(hidden)] + const FIELDS: &'static [&'static str] = &[ #(#names),* ]; + }; + Some(fields) + }; + + let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = + params.generics_with_de_lifetime(); + let delife = params.borrowed.de_lifetime(); + let visitor_impl = Stmts(deserialize_identifier( + &this_value, + &idents_aliases, + is_variant, + fallthrough, + fallthrough_borrowed, + false, + cattrs.expecting(), + )); + + quote_block! { + #names_const + + #[doc(hidden)] + struct __FieldVisitor #de_impl_generics #where_clause { + marker: _serde::#private::PhantomData<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData<&#delife ()>, + } + + #[automatically_derived] + impl #de_impl_generics _serde::de::Visitor<#delife> for __FieldVisitor #de_ty_generics #where_clause { + type Value = #this_type #ty_generics; + + #visitor_impl + } + + let __visitor = __FieldVisitor { + marker: _serde::#private::PhantomData::<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData, + }; + _serde::Deserializer::deserialize_identifier(__deserializer, __visitor) + } +} + +pub(super) fn deserialize_generated( + deserialized_fields: &[FieldWithAliases], + has_flatten: bool, + is_variant: bool, + ignore_variant: Option, + fallthrough: Option, +) -> Fragment { + let this_value = quote!(__Field); + let field_idents: &Vec<_> = &deserialized_fields + .iter() + .map(|field| &field.ident) + .collect(); + + let visitor_impl = Stmts(deserialize_identifier( + &this_value, + deserialized_fields, + is_variant, + fallthrough, + None, + !is_variant && has_flatten, + None, + )); + + let lifetime = if !is_variant && has_flatten { + Some(quote!(<'de>)) + } else { + None + }; + + quote_block! { + #[allow(non_camel_case_types)] + #[doc(hidden)] + enum __Field #lifetime { + #(#field_idents,)* + #ignore_variant + } + + #[doc(hidden)] + struct __FieldVisitor; + + #[automatically_derived] + impl<'de> _serde::de::Visitor<'de> for __FieldVisitor { + type Value = __Field #lifetime; + + #visitor_impl + } + + #[automatically_derived] + impl<'de> _serde::Deserialize<'de> for __Field #lifetime { + #[inline] + fn deserialize<__D>(__deserializer: __D) -> _serde::#private::Result + where + __D: _serde::Deserializer<'de>, + { + _serde::Deserializer::deserialize_identifier(__deserializer, __FieldVisitor) + } + } + } +} + +fn deserialize_identifier( + this_value: &TokenStream, + deserialized_fields: &[FieldWithAliases], + is_variant: bool, + fallthrough: Option, + fallthrough_borrowed: Option, + collect_other_fields: bool, + expecting: Option<&str>, +) -> Fragment { + let str_mapping = deserialized_fields.iter().map(|field| { + let ident = &field.ident; + let aliases = field.aliases; + let private2 = private; + // `aliases` also contains a main name + quote! { + #( + #aliases => _serde::#private2::Ok(#this_value::#ident), + )* + } + }); + let bytes_mapping = deserialized_fields.iter().map(|field| { + let ident = &field.ident; + // `aliases` also contains a main name + let aliases = field + .aliases + .iter() + .map(|alias| Literal::byte_string(alias.value.as_bytes())); + let private2 = private; + quote! { + #( + #aliases => _serde::#private2::Ok(#this_value::#ident), + )* + } + }); + + let expecting = expecting.unwrap_or(if is_variant { + "variant identifier" + } else { + "field identifier" + }); + + let bytes_to_str = if fallthrough.is_some() || collect_other_fields { + None + } else { + Some(quote! { + let __value = &_serde::#private::from_utf8_lossy(__value); + }) + }; + + let ( + value_as_str_content, + value_as_borrowed_str_content, + value_as_bytes_content, + value_as_borrowed_bytes_content, + ) = if collect_other_fields { + ( + Some(quote! { + let __value = _serde::#private::de::Content::String(_serde::#private::ToString::to_string(__value)); + }), + Some(quote! { + let __value = _serde::#private::de::Content::Str(__value); + }), + Some(quote! { + let __value = _serde::#private::de::Content::ByteBuf(__value.to_vec()); + }), + Some(quote! { + let __value = _serde::#private::de::Content::Bytes(__value); + }), + ) + } else { + (None, None, None, None) + }; + + let fallthrough_arm_tokens; + let fallthrough_arm = if let Some(fallthrough) = &fallthrough { + fallthrough + } else if is_variant { + fallthrough_arm_tokens = quote! { + _serde::#private::Err(_serde::de::Error::unknown_variant(__value, VARIANTS)) + }; + &fallthrough_arm_tokens + } else { + fallthrough_arm_tokens = quote! { + _serde::#private::Err(_serde::de::Error::unknown_field(__value, FIELDS)) + }; + &fallthrough_arm_tokens + }; + + let visit_other = if collect_other_fields { + quote! { + fn visit_bool<__E>(self, __value: bool) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(__Field::__other(_serde::#private::de::Content::Bool(__value))) + } + + fn visit_i8<__E>(self, __value: i8) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(__Field::__other(_serde::#private::de::Content::I8(__value))) + } + + fn visit_i16<__E>(self, __value: i16) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(__Field::__other(_serde::#private::de::Content::I16(__value))) + } + + fn visit_i32<__E>(self, __value: i32) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(__Field::__other(_serde::#private::de::Content::I32(__value))) + } + + fn visit_i64<__E>(self, __value: i64) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(__Field::__other(_serde::#private::de::Content::I64(__value))) + } + + fn visit_u8<__E>(self, __value: u8) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(__Field::__other(_serde::#private::de::Content::U8(__value))) + } + + fn visit_u16<__E>(self, __value: u16) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(__Field::__other(_serde::#private::de::Content::U16(__value))) + } + + fn visit_u32<__E>(self, __value: u32) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(__Field::__other(_serde::#private::de::Content::U32(__value))) + } + + fn visit_u64<__E>(self, __value: u64) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(__Field::__other(_serde::#private::de::Content::U64(__value))) + } + + fn visit_f32<__E>(self, __value: f32) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(__Field::__other(_serde::#private::de::Content::F32(__value))) + } + + fn visit_f64<__E>(self, __value: f64) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(__Field::__other(_serde::#private::de::Content::F64(__value))) + } + + fn visit_char<__E>(self, __value: char) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(__Field::__other(_serde::#private::de::Content::Char(__value))) + } + + fn visit_unit<__E>(self) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(__Field::__other(_serde::#private::de::Content::Unit)) + } + } + } else { + let u64_mapping = deserialized_fields.iter().enumerate().map(|(i, field)| { + let i = i as u64; + let ident = &field.ident; + quote!(#i => _serde::#private::Ok(#this_value::#ident)) + }); + + let u64_fallthrough_arm_tokens; + let u64_fallthrough_arm = if let Some(fallthrough) = &fallthrough { + fallthrough + } else { + let index_expecting = if is_variant { "variant" } else { "field" }; + let fallthrough_msg = format!( + "{} index 0 <= i < {}", + index_expecting, + deserialized_fields.len(), + ); + u64_fallthrough_arm_tokens = quote! { + _serde::#private::Err(_serde::de::Error::invalid_value( + _serde::de::Unexpected::Unsigned(__value), + &#fallthrough_msg, + )) + }; + &u64_fallthrough_arm_tokens + }; + + quote! { + fn visit_u64<__E>(self, __value: u64) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + match __value { + #(#u64_mapping,)* + _ => #u64_fallthrough_arm, + } + } + } + }; + + let visit_borrowed = if fallthrough_borrowed.is_some() || collect_other_fields { + let str_mapping = str_mapping.clone(); + let bytes_mapping = bytes_mapping.clone(); + let fallthrough_borrowed_arm = fallthrough_borrowed.as_ref().unwrap_or(fallthrough_arm); + Some(quote! { + fn visit_borrowed_str<__E>(self, __value: &'de str) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + match __value { + #(#str_mapping)* + _ => { + #value_as_borrowed_str_content + #fallthrough_borrowed_arm + } + } + } + + fn visit_borrowed_bytes<__E>(self, __value: &'de [u8]) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + match __value { + #(#bytes_mapping)* + _ => { + #bytes_to_str + #value_as_borrowed_bytes_content + #fallthrough_borrowed_arm + } + } + } + }) + } else { + None + }; + + quote_block! { + fn expecting(&self, __formatter: &mut _serde::#private::Formatter) -> _serde::#private::fmt::Result { + _serde::#private::Formatter::write_str(__formatter, #expecting) + } + + #visit_other + + fn visit_str<__E>(self, __value: &str) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + match __value { + #(#str_mapping)* + _ => { + #value_as_str_content + #fallthrough_arm + } + } + } + + fn visit_bytes<__E>(self, __value: &[u8]) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + match __value { + #(#bytes_mapping)* + _ => { + #bytes_to_str + #value_as_bytes_content + #fallthrough_arm + } + } + } + + #visit_borrowed + } +} diff --git a/crates/web-rwkv-derive/src/serde/de/struct_.rs b/crates/web-rwkv-derive/src/serde/de/struct_.rs new file mode 100644 index 00000000..72ae7f76 --- /dev/null +++ b/crates/web-rwkv-derive/src/serde/de/struct_.rs @@ -0,0 +1,704 @@ +use crate::serde::de::identifier; +use crate::serde::de::{ + deserialize_seq, expr_is_missing, field_i, has_flatten, wrap_deserialize_field_with, + FieldWithAliases, Parameters, StructForm, +}; +#[cfg(feature = "deserialize_in_place")] +use crate::serde::de::{deserialize_seq_in_place, place_lifetime}; +use crate::serde::fragment::{Expr, Fragment, Match, Stmts}; +use crate::serde::internals::ast::Field; +use crate::serde::internals::attr; +use crate::serde::private; +use proc_macro2::TokenStream; +use quote::{quote, quote_spanned}; +use syn::spanned::Spanned; + +/// Generates `Deserialize::deserialize` body for a `struct Struct {...}` +pub(super) fn deserialize( + params: &Parameters, + fields: &[Field], + cattrs: &attr::Container, + form: StructForm, +) -> Fragment { + let this_type = ¶ms.this_type; + let this_value = ¶ms.this_value; + let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = + params.generics_with_de_lifetime(); + let delife = params.borrowed.de_lifetime(); + let context = ¶ms.context; + + // If there are getters (implying private fields), construct the local type + // and use an `Into` conversion to get the remote type. If there are no + // getters then construct the target type directly. + let construct = if params.has_getter { + let local = ¶ms.local; + quote!(#local) + } else { + quote!(#this_value) + }; + + let type_path = match form { + StructForm::Struct => construct, + StructForm::ExternallyTagged(variant_ident) + | StructForm::InternallyTagged(variant_ident) + | StructForm::Untagged(variant_ident) => quote!(#construct::#variant_ident), + }; + let expecting = match form { + StructForm::Struct => format!("struct {}", params.type_name()), + StructForm::ExternallyTagged(variant_ident) + | StructForm::InternallyTagged(variant_ident) + | StructForm::Untagged(variant_ident) => { + format!("struct variant {}::{}", params.type_name(), variant_ident) + } + }; + let expecting = cattrs.expecting().unwrap_or(&expecting); + + let deserialized_fields: Vec<_> = fields + .iter() + .enumerate() + // Skip fields that shouldn't be deserialized or that were flattened, + // so they don't appear in the storage in their literal form + .filter(|&(_, field)| !field.attrs.skip_deserializing() && !field.attrs.flatten()) + .map(|(i, field)| FieldWithAliases { + ident: field_i(i), + aliases: field.attrs.aliases(), + }) + .collect(); + + let has_flatten = has_flatten(fields); + let field_visitor = deserialize_field_identifier(&deserialized_fields, cattrs, has_flatten); + + // untagged struct variants do not get a visit_seq method. The same applies to + // structs that only have a map representation. + let visit_seq = match form { + StructForm::Untagged(_) => None, + _ if has_flatten => None, + _ => { + let mut_seq = if deserialized_fields.is_empty() { + quote!(_) + } else { + quote!(mut __seq) + }; + + let visit_seq = Stmts(deserialize_seq( + &type_path, params, fields, true, cattrs, expecting, + )); + + Some(quote! { + #[inline] + fn visit_seq<__A>(self, #mut_seq: __A) -> _serde::#private::Result + where + __A: _serde::de::SeqAccess<#delife>, + { + #visit_seq + } + }) + } + }; + let visit_map = Stmts(deserialize_map( + &type_path, + params, + fields, + cattrs, + has_flatten, + )); + + let visitor_seed = match form { + StructForm::ExternallyTagged(..) if has_flatten => Some(quote! { + #[automatically_derived] + impl #de_impl_generics _serde::de::DeserializeSeed<#delife> for __Visitor #de_ty_generics #where_clause { + type Value = #this_type #ty_generics; + + fn deserialize<__D>(self, __deserializer: __D) -> _serde::#private::Result + where + __D: _serde::Deserializer<#delife>, + { + _serde::Deserializer::deserialize_map(__deserializer, self) + } + } + }), + _ => None, + }; + + let fields_stmt = if has_flatten { + None + } else { + let field_names = deserialized_fields.iter().flat_map(|field| field.aliases); + + Some(quote! { + #[doc(hidden)] + const FIELDS: &'static [&'static str] = &[ #(#field_names),* ]; + }) + }; + + let visitor_expr = quote! { + __Visitor { + context: self.context, + marker: _serde::#private::PhantomData::<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData, + } + }; + let dispatch = match form { + StructForm::Struct if has_flatten => quote! { + _serde::Deserializer::deserialize_map(__deserializer, #visitor_expr) + }, + StructForm::Struct => { + let type_name = cattrs.name().deserialize_name(); + quote! { + _serde::Deserializer::deserialize_struct(__deserializer, #type_name, FIELDS, #visitor_expr) + } + } + StructForm::ExternallyTagged(_) if has_flatten => quote! { + _serde::de::VariantAccess::newtype_variant_seed(__variant, #visitor_expr) + }, + StructForm::ExternallyTagged(_) => quote! { + _serde::de::VariantAccess::struct_variant(__variant, FIELDS, #visitor_expr) + }, + StructForm::InternallyTagged(_) => quote! { + _serde::Deserializer::deserialize_any(__deserializer, #visitor_expr) + }, + StructForm::Untagged(_) => quote! { + _serde::Deserializer::deserialize_any(__deserializer, #visitor_expr) + }, + }; + + quote_block! { + #field_visitor + + #[doc(hidden)] + struct __Visitor #de_impl_generics #where_clause { + context: &#delife #context, + marker: _serde::#private::PhantomData<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData<&#delife ()>, + } + + #[automatically_derived] + impl #de_impl_generics _serde::de::Visitor<#delife> for __Visitor #de_ty_generics #where_clause { + type Value = #this_type #ty_generics; + + fn expecting(&self, __formatter: &mut _serde::#private::Formatter) -> _serde::#private::fmt::Result { + _serde::#private::Formatter::write_str(__formatter, #expecting) + } + + #visit_seq + + #[inline] + fn visit_map<__A>(self, mut __map: __A) -> _serde::#private::Result + where + __A: _serde::de::MapAccess<#delife>, + { + #visit_map + } + } + + #visitor_seed + + #fields_stmt + + #dispatch + } +} + +fn deserialize_map( + struct_path: &TokenStream, + params: &Parameters, + fields: &[Field], + cattrs: &attr::Container, + has_flatten: bool, +) -> Fragment { + // Create the field names for the fields. + let fields_names: Vec<_> = fields + .iter() + .enumerate() + .map(|(i, field)| (field, field_i(i))) + .collect(); + + // Declare each field that will be deserialized. + let let_values = fields_names + .iter() + .filter(|&&(field, _)| !field.attrs.skip_deserializing() && !field.attrs.flatten()) + .map(|(field, name)| { + let field_ty = field.ty; + quote! { + let mut #name: _serde::#private::Option<#field_ty> = _serde::#private::None; + } + }); + + // Collect contents for flatten fields into a buffer + let let_collect = if has_flatten { + Some(quote! { + let mut __collect = _serde::#private::Vec::<_serde::#private::Option<( + _serde::#private::de::Content, + _serde::#private::de::Content + )>>::new(); + }) + } else { + None + }; + + // Match arms to extract a value for a field. + let value_arms = fields_names + .iter() + .filter(|&&(field, _)| !field.attrs.skip_deserializing() && !field.attrs.flatten()) + .map(|(field, name)| { + let deser_name = field.attrs.name().deserialize_name(); + + let visit = match field.attrs.deserialize_with() { + None => { + let field_ty = field.ty; + let span = field.original.span(); + let seed = ¶ms.seed; + let context = ¶ms.context; + let func = + quote_spanned!(span=> _serde::de::MapAccess::next_value_seed::<#seed<#context, #field_ty>>); + quote! { + #func(&mut __map, #seed::new(self.context))? + } + } + Some(path) => { + let (wrapper, wrapper_ty) = wrap_deserialize_field_with(params, field.ty, path); + let seed = ¶ms.seed; + let context = ¶ms.context; + quote!({ + #wrapper + match _serde::de::MapAccess::next_value_seed::<#seed<#context, #wrapper_ty>>(&mut __map, #seed::new(self.context)) { + _serde::#private::Ok(__wrapper) => __wrapper.value, + _serde::#private::Err(__err) => { + return _serde::#private::Err(__err); + } + } + }) + } + }; + quote! { + __Field::#name => { + if _serde::#private::Option::is_some(&#name) { + return _serde::#private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#deser_name)); + } + #name = _serde::#private::Some(#visit); + } + } + }); + + // Visit ignored values to consume them + let ignored_arm = if has_flatten { + Some(quote! { + __Field::__other(__name) => { + __collect.push(_serde::#private::Some(( + __name, + _serde::de::MapAccess::next_value_seed(&mut __map, _serde::#private::de::ContentVisitor::new())?))); + } + }) + } else if cattrs.deny_unknown_fields() { + None + } else { + Some(quote! { + _ => { let _ = _serde::de::MapAccess::next_value::<_serde::de::IgnoredAny>(&mut __map)?; } + }) + }; + + let all_skipped = fields.iter().all(|field| field.attrs.skip_deserializing()); + let match_keys = if cattrs.deny_unknown_fields() && all_skipped { + quote! { + // FIXME: Once feature(exhaustive_patterns) is stable: + // let _serde::#private::None::<__Field> = _serde::de::MapAccess::next_key(&mut __map)?; + _serde::#private::Option::map( + _serde::de::MapAccess::next_key::<__Field>(&mut __map)?, + |__impossible| match __impossible {}); + } + } else { + quote! { + while let _serde::#private::Some(__key) = _serde::de::MapAccess::next_key::<__Field>(&mut __map)? { + match __key { + #(#value_arms)* + #ignored_arm + } + } + } + }; + + let extract_values = fields_names + .iter() + .filter(|&&(field, _)| !field.attrs.skip_deserializing() && !field.attrs.flatten()) + .map(|(field, name)| { + let missing_expr = Match(expr_is_missing(field, cattrs)); + + quote! { + let #name = match #name { + _serde::#private::Some(#name) => #name, + _serde::#private::None => #missing_expr + }; + } + }); + + let extract_collected = fields_names + .iter() + .filter(|&&(field, _)| field.attrs.flatten() && !field.attrs.skip_deserializing()) + .map(|(field, name)| { + let field_ty = field.ty; + let func = match field.attrs.deserialize_with() { + None => { + let span = field.original.span(); + quote_spanned!(span=> _serde::de::Deserialize::deserialize) + } + Some(path) => quote!(#path), + }; + quote! { + let #name: #field_ty = #func( + _serde::#private::de::FlatMapDeserializer( + &mut __collect, + _serde::#private::PhantomData))?; + } + }); + + let collected_deny_unknown_fields = if has_flatten && cattrs.deny_unknown_fields() { + Some(quote! { + if let _serde::#private::Some(_serde::#private::Some((__key, _))) = + __collect.into_iter().filter(_serde::#private::Option::is_some).next() + { + if let _serde::#private::Some(__key) = _serde::#private::de::content_as_str(&__key) { + return _serde::#private::Err( + _serde::de::Error::custom(format_args!("unknown field `{}`", &__key))); + } else { + return _serde::#private::Err( + _serde::de::Error::custom(format_args!("unexpected map key"))); + } + } + }) + } else { + None + }; + + let result = fields_names.iter().map(|(field, name)| { + let member = &field.member; + if field.attrs.skip_deserializing() { + let value = Expr(expr_is_missing(field, cattrs)); + quote!(#member: #value) + } else { + quote!(#member: #name) + } + }); + + let let_default = match cattrs.default() { + attr::Default::Default => Some(quote!( + let __default: Self::Value = _serde::#private::Default::default(); + )), + // If #path returns wrong type, error will be reported here (^^^^^). + // We attach span of the path to the function so it will be reported + // on the #[serde(default = "...")] + // ^^^^^ + attr::Default::Path(path) => Some(quote_spanned!(path.span()=> + let __default: Self::Value = #path(); + )), + attr::Default::None => { + // We don't need the default value, to prevent an unused variable warning + // we'll leave the line empty. + None + } + }; + + let mut result = quote!(#struct_path { #(#result),* }); + if params.has_getter { + let this_type = ¶ms.this_type; + let (_, ty_generics, _) = params.generics.split_for_impl(); + result = quote! { + _serde::#private::Into::<#this_type #ty_generics>::into(#result) + }; + } + + quote_block! { + #(#let_values)* + + #let_collect + + #match_keys + + #let_default + + #(#extract_values)* + + #(#extract_collected)* + + #collected_deny_unknown_fields + + _serde::#private::Ok(#result) + } +} + +/// Generates `Deserialize::deserialize_in_place` body for a `struct Struct {...}` +#[cfg(feature = "deserialize_in_place")] +pub(super) fn deserialize_in_place( + params: &Parameters, + fields: &[Field], + cattrs: &attr::Container, +) -> Option { + // for now we do not support in_place deserialization for structs that + // are represented as map. + if has_flatten(fields) { + return None; + } + + let this_type = ¶ms.this_type; + let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = + params.generics_with_de_lifetime(); + let delife = params.borrowed.de_lifetime(); + + let expecting = format!("struct {}", params.type_name()); + let expecting = cattrs.expecting().unwrap_or(&expecting); + + let deserialized_fields: Vec<_> = fields + .iter() + .enumerate() + .filter(|&(_, field)| !field.attrs.skip_deserializing()) + .map(|(i, field)| FieldWithAliases { + ident: field_i(i), + aliases: field.attrs.aliases(), + }) + .collect(); + + let field_visitor = deserialize_field_identifier(&deserialized_fields, cattrs, false); + + let mut_seq = if deserialized_fields.is_empty() { + quote!(_) + } else { + quote!(mut __seq) + }; + let visit_seq = Stmts(deserialize_seq_in_place(params, fields, cattrs, expecting)); + let visit_map = Stmts(deserialize_map_in_place(params, fields, cattrs)); + let field_names = deserialized_fields.iter().flat_map(|field| field.aliases); + let type_name = cattrs.name().deserialize_name(); + + let in_place_impl_generics = de_impl_generics.in_place(); + let in_place_ty_generics = de_ty_generics.in_place(); + let place_life = place_lifetime(); + + Some(quote_block! { + #field_visitor + + #[doc(hidden)] + struct __Visitor #in_place_impl_generics #where_clause { + place: &#place_life mut #this_type #ty_generics, + lifetime: _serde::#private::PhantomData<&#delife ()>, + } + + #[automatically_derived] + impl #in_place_impl_generics _serde::de::Visitor<#delife> for __Visitor #in_place_ty_generics #where_clause { + type Value = (); + + fn expecting(&self, __formatter: &mut _serde::#private::Formatter) -> _serde::#private::fmt::Result { + _serde::#private::Formatter::write_str(__formatter, #expecting) + } + + #[inline] + fn visit_seq<__A>(self, #mut_seq: __A) -> _serde::#private::Result + where + __A: _serde::de::SeqAccess<#delife>, + { + #visit_seq + } + + #[inline] + fn visit_map<__A>(self, mut __map: __A) -> _serde::#private::Result + where + __A: _serde::de::MapAccess<#delife>, + { + #visit_map + } + } + + #[doc(hidden)] + const FIELDS: &'static [&'static str] = &[ #(#field_names),* ]; + + _serde::Deserializer::deserialize_struct(__deserializer, #type_name, FIELDS, __Visitor { + place: __place, + lifetime: _serde::#private::PhantomData, + }) + }) +} + +#[cfg(feature = "deserialize_in_place")] +fn deserialize_map_in_place( + params: &Parameters, + fields: &[Field], + cattrs: &attr::Container, +) -> Fragment { + assert!( + !has_flatten(fields), + "inplace deserialization of maps does not support flatten fields" + ); + + // Create the field names for the fields. + let fields_names: Vec<_> = fields + .iter() + .enumerate() + .map(|(i, field)| (field, field_i(i))) + .collect(); + + // For deserialize_in_place, declare booleans for each field that will be + // deserialized. + let let_flags = fields_names + .iter() + .filter(|&&(field, _)| !field.attrs.skip_deserializing()) + .map(|(_, name)| { + quote! { + let mut #name: bool = false; + } + }); + + // Match arms to extract a value for a field. + let value_arms_from = fields_names + .iter() + .filter(|&&(field, _)| !field.attrs.skip_deserializing()) + .map(|(field, name)| { + let deser_name = field.attrs.name().deserialize_name(); + let member = &field.member; + + let visit = match field.attrs.deserialize_with() { + None => { + quote! { + _serde::de::MapAccess::next_value_seed(&mut __map, _serde::#private::de::InPlaceSeed(&mut self.place.#member))? + } + } + Some(path) => { + let (wrapper, wrapper_ty) = wrap_deserialize_field_with(params, field.ty, path); + quote!({ + #wrapper + self.place.#member = match _serde::de::MapAccess::next_value::<#wrapper_ty>(&mut __map) { + _serde::#private::Ok(__wrapper) => __wrapper.value, + _serde::#private::Err(__err) => { + return _serde::#private::Err(__err); + } + }; + }) + } + }; + quote! { + __Field::#name => { + if #name { + return _serde::#private::Err(<__A::Error as _serde::de::Error>::duplicate_field(#deser_name)); + } + #visit; + #name = true; + } + } + }); + + // Visit ignored values to consume them + let ignored_arm = if cattrs.deny_unknown_fields() { + None + } else { + Some(quote! { + _ => { let _ = _serde::de::MapAccess::next_value::<_serde::de::IgnoredAny>(&mut __map)?; } + }) + }; + + let all_skipped = fields.iter().all(|field| field.attrs.skip_deserializing()); + + let match_keys = if cattrs.deny_unknown_fields() && all_skipped { + quote! { + // FIXME: Once feature(exhaustive_patterns) is stable: + // let _serde::#private::None::<__Field> = _serde::de::MapAccess::next_key(&mut __map)?; + _serde::#private::Option::map( + _serde::de::MapAccess::next_key::<__Field>(&mut __map)?, + |__impossible| match __impossible {}); + } + } else { + quote! { + while let _serde::#private::Some(__key) = _serde::de::MapAccess::next_key::<__Field>(&mut __map)? { + match __key { + #(#value_arms_from)* + #ignored_arm + } + } + } + }; + + let check_flags = fields_names + .iter() + .filter(|&&(field, _)| !field.attrs.skip_deserializing()) + .map(|(field, name)| { + let missing_expr = expr_is_missing(field, cattrs); + // If missing_expr unconditionally returns an error, don't try + // to assign its value to self.place. + if field.attrs.default().is_none() + && cattrs.default().is_none() + && field.attrs.deserialize_with().is_some() + { + let missing_expr = Stmts(missing_expr); + quote! { + if !#name { + #missing_expr; + } + } + } else { + let member = &field.member; + let missing_expr = Expr(missing_expr); + quote! { + if !#name { + self.place.#member = #missing_expr; + }; + } + } + }); + + let this_type = ¶ms.this_type; + let (_, ty_generics, _) = params.generics.split_for_impl(); + + let let_default = match cattrs.default() { + attr::Default::Default => Some(quote!( + let __default: #this_type #ty_generics = _serde::#private::Default::default(); + )), + // If #path returns wrong type, error will be reported here (^^^^^). + // We attach span of the path to the function so it will be reported + // on the #[serde(default = "...")] + // ^^^^^ + attr::Default::Path(path) => Some(quote_spanned!(path.span()=> + let __default: #this_type #ty_generics = #path(); + )), + attr::Default::None => { + // We don't need the default value, to prevent an unused variable warning + // we'll leave the line empty. + None + } + }; + + quote_block! { + #(#let_flags)* + + #match_keys + + #let_default + + #(#check_flags)* + + _serde::#private::Ok(()) + } +} + +/// Generates enum and its `Deserialize` implementation that represents each +/// non-skipped field of the struct +fn deserialize_field_identifier( + deserialized_fields: &[FieldWithAliases], + cattrs: &attr::Container, + has_flatten: bool, +) -> Stmts { + let (ignore_variant, fallthrough) = if has_flatten { + let ignore_variant = quote!(__other(_serde::#private::de::Content<'de>),); + let fallthrough = quote!(_serde::#private::Ok(__Field::__other(__value))); + (Some(ignore_variant), Some(fallthrough)) + } else if cattrs.deny_unknown_fields() { + (None, None) + } else { + let ignore_variant = quote!(__ignore,); + let fallthrough = quote!(_serde::#private::Ok(__Field::__ignore)); + (Some(ignore_variant), Some(fallthrough)) + }; + + Stmts(identifier::deserialize_generated( + deserialized_fields, + has_flatten, + false, + ignore_variant, + fallthrough, + )) +} diff --git a/crates/web-rwkv-derive/src/serde/de/tuple.rs b/crates/web-rwkv-derive/src/serde/de/tuple.rs new file mode 100644 index 00000000..237e0871 --- /dev/null +++ b/crates/web-rwkv-derive/src/serde/de/tuple.rs @@ -0,0 +1,283 @@ +use crate::serde::de::{deserialize_seq, has_flatten, Parameters, TupleForm}; +#[cfg(feature = "deserialize_in_place")] +use crate::serde::de::{deserialize_seq_in_place, place_lifetime}; +use crate::serde::fragment::{Fragment, Stmts}; +use crate::serde::internals::ast::Field; +use crate::serde::internals::attr; +use crate::serde::private; +use proc_macro2::TokenStream; +use quote::{quote, quote_spanned}; +use syn::spanned::Spanned; + +/// Generates `Deserialize::deserialize` body for a `struct Tuple(...);` including `struct Newtype(T);` +pub(super) fn deserialize( + params: &Parameters, + fields: &[Field], + cattrs: &attr::Container, + form: TupleForm, +) -> Fragment { + assert!( + !has_flatten(fields), + "tuples and tuple variants cannot have flatten fields" + ); + + let field_count = fields + .iter() + .filter(|field| !field.attrs.skip_deserializing()) + .count(); + + let this_type = ¶ms.this_type; + let this_value = ¶ms.this_value; + let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = + params.generics_with_de_lifetime(); + let delife = params.borrowed.de_lifetime(); + + // If there are getters (implying private fields), construct the local type + // and use an `Into` conversion to get the remote type. If there are no + // getters then construct the target type directly. + let construct = if params.has_getter { + let local = ¶ms.local; + quote!(#local) + } else { + quote!(#this_value) + }; + + let type_path = match form { + TupleForm::Tuple => construct, + TupleForm::ExternallyTagged(variant_ident) | TupleForm::Untagged(variant_ident) => { + quote!(#construct::#variant_ident) + } + }; + let expecting = match form { + TupleForm::Tuple => format!("tuple struct {}", params.type_name()), + TupleForm::ExternallyTagged(variant_ident) | TupleForm::Untagged(variant_ident) => { + format!("tuple variant {}::{}", params.type_name(), variant_ident) + } + }; + let expecting = cattrs.expecting().unwrap_or(&expecting); + + let nfields = fields.len(); + + let visit_newtype_struct = match form { + TupleForm::Tuple if nfields == 1 => { + Some(deserialize_newtype_struct(&type_path, params, &fields[0])) + } + _ => None, + }; + + let visit_seq = Stmts(deserialize_seq( + &type_path, params, fields, false, cattrs, expecting, + )); + + let visitor_expr = quote! { + __Visitor { + marker: _serde::#private::PhantomData::<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData, + } + }; + let dispatch = match form { + TupleForm::Tuple if nfields == 1 => { + let type_name = cattrs.name().deserialize_name(); + quote! { + _serde::Deserializer::deserialize_newtype_struct(__deserializer, #type_name, #visitor_expr) + } + } + TupleForm::Tuple => { + let type_name = cattrs.name().deserialize_name(); + quote! { + _serde::Deserializer::deserialize_tuple_struct(__deserializer, #type_name, #field_count, #visitor_expr) + } + } + TupleForm::ExternallyTagged(_) => quote! { + _serde::de::VariantAccess::tuple_variant(__variant, #field_count, #visitor_expr) + }, + TupleForm::Untagged(_) => quote! { + _serde::Deserializer::deserialize_tuple(__deserializer, #field_count, #visitor_expr) + }, + }; + + let visitor_var = if field_count == 0 { + quote!(_) + } else { + quote!(mut __seq) + }; + + quote_block! { + #[doc(hidden)] + struct __Visitor #de_impl_generics #where_clause { + marker: _serde::#private::PhantomData<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData<&#delife ()>, + } + + #[automatically_derived] + impl #de_impl_generics _serde::de::Visitor<#delife> for __Visitor #de_ty_generics #where_clause { + type Value = #this_type #ty_generics; + + fn expecting(&self, __formatter: &mut _serde::#private::Formatter) -> _serde::#private::fmt::Result { + _serde::#private::Formatter::write_str(__formatter, #expecting) + } + + #visit_newtype_struct + + #[inline] + fn visit_seq<__A>(self, #visitor_var: __A) -> _serde::#private::Result + where + __A: _serde::de::SeqAccess<#delife>, + { + #visit_seq + } + } + + #dispatch + } +} + +fn deserialize_newtype_struct( + type_path: &TokenStream, + params: &Parameters, + field: &Field, +) -> TokenStream { + let delife = params.borrowed.de_lifetime(); + let field_ty = field.ty; + let deserializer_var = quote!(__e); + + let value = match field.attrs.deserialize_with() { + None => { + let span = field.original.span(); + let func = quote_spanned!(span=> <#field_ty as _serde::Deserialize>::deserialize); + quote! { + #func(#deserializer_var)? + } + } + Some(path) => { + // If #path returns wrong type, error will be reported here (^^^^^). + // We attach span of the path to the function so it will be reported + // on the #[serde(with = "...")] + // ^^^^^ + quote_spanned! {path.span()=> + #path(#deserializer_var)? + } + } + }; + + let mut result = quote!(#type_path(__field0)); + if params.has_getter { + let this_type = ¶ms.this_type; + let (_, ty_generics, _) = params.generics.split_for_impl(); + result = quote! { + _serde::#private::Into::<#this_type #ty_generics>::into(#result) + }; + } + + quote! { + #[inline] + fn visit_newtype_struct<__E>(self, #deserializer_var: __E) -> _serde::#private::Result + where + __E: _serde::Deserializer<#delife>, + { + let __field0: #field_ty = #value; + _serde::#private::Ok(#result) + } + } +} + +/// Generates `Deserialize::deserialize_in_place` body for a `struct Tuple(...);` including `struct Newtype(T);` +#[cfg(feature = "deserialize_in_place")] +pub(super) fn deserialize_in_place( + params: &Parameters, + fields: &[Field], + cattrs: &attr::Container, +) -> Fragment { + assert!( + !has_flatten(fields), + "tuples and tuple variants cannot have flatten fields" + ); + + let field_count = fields + .iter() + .filter(|field| !field.attrs.skip_deserializing()) + .count(); + + let this_type = ¶ms.this_type; + let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = + params.generics_with_de_lifetime(); + let delife = params.borrowed.de_lifetime(); + + let expecting = format!("tuple struct {}", params.type_name()); + let expecting = cattrs.expecting().unwrap_or(&expecting); + + let nfields = fields.len(); + + let visit_newtype_struct = if nfields == 1 { + // We do not generate deserialize_in_place if every field has a + // deserialize_with. + assert!(fields[0].attrs.deserialize_with().is_none()); + + Some(quote! { + #[inline] + fn visit_newtype_struct<__E>(self, __e: __E) -> _serde::#private::Result + where + __E: _serde::Deserializer<#delife>, + { + _serde::Deserialize::deserialize_in_place(__e, &mut self.place.0) + } + }) + } else { + None + }; + + let visit_seq = Stmts(deserialize_seq_in_place(params, fields, cattrs, expecting)); + + let visitor_expr = quote! { + __Visitor { + place: __place, + lifetime: _serde::#private::PhantomData, + } + }; + + let type_name = cattrs.name().deserialize_name(); + let dispatch = if nfields == 1 { + quote!(_serde::Deserializer::deserialize_newtype_struct(__deserializer, #type_name, #visitor_expr)) + } else { + quote!(_serde::Deserializer::deserialize_tuple_struct(__deserializer, #type_name, #field_count, #visitor_expr)) + }; + + let visitor_var = if field_count == 0 { + quote!(_) + } else { + quote!(mut __seq) + }; + + let in_place_impl_generics = de_impl_generics.in_place(); + let in_place_ty_generics = de_ty_generics.in_place(); + let place_life = place_lifetime(); + + quote_block! { + #[doc(hidden)] + struct __Visitor #in_place_impl_generics #where_clause { + place: &#place_life mut #this_type #ty_generics, + lifetime: _serde::#private::PhantomData<&#delife ()>, + } + + #[automatically_derived] + impl #in_place_impl_generics _serde::de::Visitor<#delife> for __Visitor #in_place_ty_generics #where_clause { + type Value = (); + + fn expecting(&self, __formatter: &mut _serde::#private::Formatter) -> _serde::#private::fmt::Result { + _serde::#private::Formatter::write_str(__formatter, #expecting) + } + + #visit_newtype_struct + + #[inline] + fn visit_seq<__A>(self, #visitor_var: __A) -> _serde::#private::Result + where + __A: _serde::de::SeqAccess<#delife>, + { + #visit_seq + } + } + + #dispatch + } +} diff --git a/crates/web-rwkv-derive/src/serde/de/unit.rs b/crates/web-rwkv-derive/src/serde/de/unit.rs new file mode 100644 index 00000000..9e63f91a --- /dev/null +++ b/crates/web-rwkv-derive/src/serde/de/unit.rs @@ -0,0 +1,52 @@ +use crate::serde::de::Parameters; +use crate::serde::fragment::Fragment; +use crate::serde::internals::attr; +use crate::serde::private; +use quote::quote; + +/// Generates `Deserialize::deserialize` body for a `struct Unit;` +pub(super) fn deserialize(params: &Parameters, cattrs: &attr::Container) -> Fragment { + let this_type = ¶ms.this_type; + let this_value = ¶ms.this_value; + let type_name = cattrs.name().deserialize_name(); + let (de_impl_generics, de_ty_generics, ty_generics, where_clause) = + params.generics_with_de_lifetime(); + let delife = params.borrowed.de_lifetime(); + + let expecting = format!("unit struct {}", params.type_name()); + let expecting = cattrs.expecting().unwrap_or(&expecting); + + quote_block! { + #[doc(hidden)] + struct __Visitor #de_impl_generics #where_clause { + marker: _serde::#private::PhantomData<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData<&#delife ()>, + } + + #[automatically_derived] + impl #de_impl_generics _serde::de::Visitor<#delife> for __Visitor #de_ty_generics #where_clause { + type Value = #this_type #ty_generics; + + fn expecting(&self, __formatter: &mut _serde::#private::Formatter) -> _serde::#private::fmt::Result { + _serde::#private::Formatter::write_str(__formatter, #expecting) + } + + #[inline] + fn visit_unit<__E>(self) -> _serde::#private::Result + where + __E: _serde::de::Error, + { + _serde::#private::Ok(#this_value) + } + } + + _serde::Deserializer::deserialize_unit_struct( + __deserializer, + #type_name, + __Visitor { + marker: _serde::#private::PhantomData::<#this_type #ty_generics>, + lifetime: _serde::#private::PhantomData, + }, + ) + } +} diff --git a/crates/web-rwkv-derive/src/serde/deprecated.rs b/crates/web-rwkv-derive/src/serde/deprecated.rs new file mode 100644 index 00000000..3abdc1b2 --- /dev/null +++ b/crates/web-rwkv-derive/src/serde/deprecated.rs @@ -0,0 +1,56 @@ +use proc_macro2::TokenStream; +use quote::quote; + +pub fn allow_deprecated(input: &syn::DeriveInput) -> Option { + if should_allow_deprecated(input) { + Some(quote! { #[allow(deprecated)] }) + } else { + None + } +} + +/// Determine if an `#[allow(deprecated)]` should be added to the derived impl. +/// +/// This should happen if the derive input or an enum variant it contains has +/// one of: +/// - `#[deprecated]` +/// - `#[allow(deprecated)]` +fn should_allow_deprecated(input: &syn::DeriveInput) -> bool { + if contains_deprecated(&input.attrs) { + return true; + } + if let syn::Data::Enum(data_enum) = &input.data { + for variant in &data_enum.variants { + if contains_deprecated(&variant.attrs) { + return true; + } + } + } + false +} + +/// Check whether the given attributes contains one of: +/// - `#[deprecated]` +/// - `#[allow(deprecated)]` +fn contains_deprecated(attrs: &[syn::Attribute]) -> bool { + for attr in attrs { + if attr.path().is_ident("deprecated") { + return true; + } + if let syn::Meta::List(meta_list) = &attr.meta { + if meta_list.path.is_ident("allow") { + let mut allow_deprecated = false; + let _ = meta_list.parse_nested_meta(|meta| { + if meta.path.is_ident("deprecated") { + allow_deprecated = true; + } + Ok(()) + }); + if allow_deprecated { + return true; + } + } + } + } + false +} diff --git a/crates/web-rwkv-derive/src/serde/dummy.rs b/crates/web-rwkv-derive/src/serde/dummy.rs index 095f950f..e0bca647 100644 --- a/crates/web-rwkv-derive/src/serde/dummy.rs +++ b/crates/web-rwkv-derive/src/serde/dummy.rs @@ -14,9 +14,17 @@ pub fn wrap_in_const(serde_path: Option<&syn::Path>, code: TokenStream) -> Token quote! { #[doc(hidden)] - #[allow(non_upper_case_globals, unused_attributes, unused_qualifications)] + #[allow( + non_upper_case_globals, + unused_attributes, + unused_qualifications, + clippy::absolute_paths, + )] const _: () = { #use_serde + + _serde::__require_serde_not_serde_core!(); + #code }; } diff --git a/crates/web-rwkv-derive/src/serde/fragment.rs b/crates/web-rwkv-derive/src/serde/fragment.rs index dff8e22d..a24d0dc3 100644 --- a/crates/web-rwkv-derive/src/serde/fragment.rs +++ b/crates/web-rwkv-derive/src/serde/fragment.rs @@ -10,6 +10,18 @@ pub enum Fragment { Block(TokenStream), } +macro_rules! quote_expr { + ($($tt:tt)*) => { + $crate::serde::fragment::Fragment::Expr(quote!($($tt)*)) + } +} + +macro_rules! quote_block { + ($($tt:tt)*) => { + $crate::serde::fragment::Fragment::Block(quote!($($tt)*)) + } +} + /// Interpolate a fragment in place of an expression. This involves surrounding /// Block fragments in curly braces. pub struct Expr(pub Fragment); diff --git a/crates/web-rwkv-derive/src/serde/internals/ast.rs b/crates/web-rwkv-derive/src/serde/internals/ast.rs index 8729b46b..5d28b057 100644 --- a/crates/web-rwkv-derive/src/serde/internals/ast.rs +++ b/crates/web-rwkv-derive/src/serde/internals/ast.rs @@ -1,6 +1,7 @@ //! A Serde ast, parsed from the Syn ast and ready to generate Rust code. use crate::serde::internals::{attr, check, Ctxt, Derive}; +use proc_macro2::Ident; use syn::punctuated::Punctuated; use syn::Token; @@ -62,13 +63,17 @@ impl<'a> Container<'a> { cx: &Ctxt, item: &'a syn::DeriveInput, derive: Derive, + private: &Ident, ) -> Option> { - let mut attrs = attr::Container::from_ast(cx, item); + let attrs = attr::Container::from_ast(cx, item); let mut data = match &item.data { - syn::Data::Enum(data) => Data::Enum(enum_from_ast(cx, &data.variants, attrs.default())), + syn::Data::Enum(data) => { + Data::Enum(enum_from_ast(cx, &data.variants, attrs.default(), private)) + } syn::Data::Struct(data) => { - let (style, fields) = struct_from_ast(cx, &data.fields, None, attrs.default()); + let (style, fields) = + struct_from_ast(cx, &data.fields, None, attrs.default(), private); Data::Struct(style, fields) } syn::Data::Union(_) => { @@ -77,15 +82,11 @@ impl<'a> Container<'a> { } }; - let mut has_flatten = false; match &mut data { Data::Enum(variants) => { for variant in variants { variant.attrs.rename_by_rules(attrs.rename_all_rules()); for field in &mut variant.fields { - if field.attrs.flatten() { - has_flatten = true; - } field.attrs.rename_by_rules( variant .attrs @@ -97,18 +98,11 @@ impl<'a> Container<'a> { } Data::Struct(_, fields) => { for field in fields { - if field.attrs.flatten() { - has_flatten = true; - } field.attrs.rename_by_rules(attrs.rename_all_rules()); } } } - if has_flatten { - attrs.mark_has_flatten(); - } - let mut item = Container { ident: item.ident.clone(), attrs, @@ -140,13 +134,19 @@ fn enum_from_ast<'a>( cx: &Ctxt, variants: &'a Punctuated, container_default: &attr::Default, + private: &Ident, ) -> Vec> { let variants: Vec = variants .iter() .map(|variant| { let attrs = attr::Variant::from_ast(cx, variant); - let (style, fields) = - struct_from_ast(cx, &variant.fields, Some(&attrs), container_default); + let (style, fields) = struct_from_ast( + cx, + &variant.fields, + Some(&attrs), + container_default, + private, + ); Variant { ident: variant.ident.clone(), attrs, @@ -176,19 +176,20 @@ fn struct_from_ast<'a>( fields: &'a syn::Fields, attrs: Option<&attr::Variant>, container_default: &attr::Default, + private: &Ident, ) -> (Style, Vec>) { match fields { syn::Fields::Named(fields) => ( Style::Struct, - fields_from_ast(cx, &fields.named, attrs, container_default), + fields_from_ast(cx, &fields.named, attrs, container_default, private), ), syn::Fields::Unnamed(fields) if fields.unnamed.len() == 1 => ( Style::Newtype, - fields_from_ast(cx, &fields.unnamed, attrs, container_default), + fields_from_ast(cx, &fields.unnamed, attrs, container_default, private), ), syn::Fields::Unnamed(fields) => ( Style::Tuple, - fields_from_ast(cx, &fields.unnamed, attrs, container_default), + fields_from_ast(cx, &fields.unnamed, attrs, container_default, private), ), syn::Fields::Unit => (Style::Unit, Vec::new()), } @@ -199,6 +200,7 @@ fn fields_from_ast<'a>( fields: &'a Punctuated, attrs: Option<&attr::Variant>, container_default: &attr::Default, + private: &Ident, ) -> Vec> { fields .iter() @@ -208,7 +210,7 @@ fn fields_from_ast<'a>( Some(ident) => syn::Member::Named(ident.clone()), None => syn::Member::Unnamed(i.into()), }, - attrs: attr::Field::from_ast(cx, i, field, attrs, container_default), + attrs: attr::Field::from_ast(cx, i, field, attrs, container_default, private), ty: &field.ty, original: field, }) diff --git a/crates/web-rwkv-derive/src/serde/internals/attr.rs b/crates/web-rwkv-derive/src/serde/internals/attr.rs index 9f78ef7e..a5d8d5d0 100644 --- a/crates/web-rwkv-derive/src/serde/internals/attr.rs +++ b/crates/web-rwkv-derive/src/serde/internals/attr.rs @@ -1,3 +1,4 @@ +use crate::serde::internals::name::{MultiName, Name}; use crate::serde::internals::symbol::*; use crate::serde::internals::{ungroup, Ctxt}; use proc_macro2::{Spacing, Span, TokenStream, TokenTree}; @@ -8,6 +9,7 @@ use std::iter::FromIterator; use syn::meta::ParseNestedMeta; use syn::parse::ParseStream; use syn::punctuated::Punctuated; +use syn::spanned::Spanned; use syn::{parse_quote, token, Ident, Lifetime, Token}; // This module handles parsing of `#[serde(...)]` attributes. The entrypoints @@ -20,7 +22,7 @@ use syn::{parse_quote, token, Ident, Lifetime, Token}; pub use crate::serde::internals::case::RenameRule; -struct Attr<'c, T> { +pub(crate) struct Attr<'c, T> { cx: &'c Ctxt, name: Symbol, tokens: TokenStream, @@ -61,7 +63,7 @@ impl<'c, T> Attr<'c, T> { } } - fn get(self) -> Option { + pub(crate) fn get(self) -> Option { self.value } @@ -89,7 +91,7 @@ impl<'c> BoolAttr<'c> { } } -struct VecAttr<'c, T> { +pub(crate) struct VecAttr<'c, T> { cx: &'c Ctxt, name: Symbol, first_dup_tokens: TokenStream, @@ -124,69 +126,19 @@ impl<'c, T> VecAttr<'c, T> { } } - fn get(self) -> Vec { + pub(crate) fn get(self) -> Vec { self.values } } -pub struct Name { - serialize: String, - serialize_renamed: bool, - deserialize: String, - deserialize_renamed: bool, - deserialize_aliases: BTreeSet, -} - -fn unraw(ident: &Ident) -> String { - ident.to_string().trim_start_matches("r#").to_owned() -} - -impl Name { - fn from_attrs( - source_name: String, - ser_name: Attr, - de_name: Attr, - de_aliases: Option>, - ) -> Name { - let mut alias_set = BTreeSet::new(); - if let Some(de_aliases) = de_aliases { - for alias_name in de_aliases.get() { - alias_set.insert(alias_name); - } - } - - let ser_name = ser_name.get(); - let ser_renamed = ser_name.is_some(); - let de_name = de_name.get(); - let de_renamed = de_name.is_some(); - Name { - serialize: ser_name.unwrap_or_else(|| source_name.clone()), - serialize_renamed: ser_renamed, - deserialize: de_name.unwrap_or(source_name), - deserialize_renamed: de_renamed, - deserialize_aliases: alias_set, - } - } - - /// Return the container name for the container when serializing. - pub fn serialize_name(&self) -> &str { - &self.serialize - } - - /// Return the container name for the container when deserializing. - pub fn deserialize_name(&self) -> &str { - &self.deserialize - } - - fn deserialize_aliases(&self) -> &BTreeSet { - &self.deserialize_aliases - } +fn unraw(ident: &Ident) -> Ident { + Ident::new(ident.to_string().trim_start_matches("r#"), ident.span()) } #[derive(Copy, Clone)] pub struct RenameAllRules { - serialize: RenameRule, - deserialize: RenameRule, + pub serialize: RenameRule, + pub deserialize: RenameRule, } impl RenameAllRules { @@ -204,7 +156,7 @@ impl RenameAllRules { pub struct Container { seed: syn::Path, context: syn::Path, - name: Name, + name: MultiName, transparent: bool, deny_unknown_fields: bool, default: Default, @@ -275,6 +227,7 @@ pub enum Identifier { } impl Identifier { + #[cfg(feature = "deserialize_in_place")] pub fn is_some(self) -> bool { match self { Identifier::No => false, @@ -340,8 +293,8 @@ impl Container { // #[serde(rename = "foo")] // #[serde(rename(serialize = "foo", deserialize = "bar"))] let (ser, de) = get_renames(cx, RENAME, &meta)?; - ser_name.set_opt(&meta.path, ser.as_ref().map(syn::LitStr::value)); - de_name.set_opt(&meta.path, de.as_ref().map(syn::LitStr::value)); + ser_name.set_opt(&meta.path, ser.as_ref().map(Name::from)); + de_name.set_opt(&meta.path, de.as_ref().map(Name::from)); } else if meta.path == RENAME_ALL { // #[serde(rename_all = "foo")] // #[serde(rename_all(serialize = "foo", deserialize = "bar"))] @@ -556,7 +509,7 @@ impl Container { } else { let path = meta.path.to_token_stream().to_string().replace(' ', ""); return Err( - meta.error(format_args!("unknown serde container attribute `{path}`")) + meta.error(format_args!("unknown serde container attribute `{}`", path)) ); } Ok(()) @@ -587,7 +540,7 @@ impl Container { Container { seed, context, - name: Name::from_attrs(unraw(&item.ident), ser_name, de_name, None), + name: MultiName::from_attrs(Name::from(&unraw(&item.ident)), ser_name, de_name, None), transparent: transparent.get(), deny_unknown_fields: deny_unknown_fields.get(), default: default.get().unwrap_or(Default::None), @@ -623,7 +576,7 @@ impl Container { &self.context } - pub fn name(&self) -> &Name { + pub fn name(&self) -> &MultiName { &self.name } @@ -818,7 +771,7 @@ fn decide_identifier( /// Represents variant attribute information pub struct Variant { - name: Name, + name: MultiName, rename_all_rules: RenameAllRules, ser_bound: Option>, de_bound: Option>, @@ -854,7 +807,7 @@ impl Variant { let mut untagged = BoolAttr::none(cx, UNTAGGED); for attr in &variant.attrs { - if attr.path() != SERDE_SEED { + if attr.path() != SERDE { continue; } @@ -869,15 +822,15 @@ impl Variant { // #[serde(rename = "foo")] // #[serde(rename(serialize = "foo", deserialize = "bar"))] let (ser, de) = get_multiple_renames(cx, &meta)?; - ser_name.set_opt(&meta.path, ser.as_ref().map(syn::LitStr::value)); + ser_name.set_opt(&meta.path, ser.as_ref().map(Name::from)); for de_value in de { - de_name.set_if_none(de_value.value()); - de_aliases.insert(&meta.path, de_value.value()); + de_name.set_if_none(Name::from(&de_value)); + de_aliases.insert(&meta.path, Name::from(&de_value)); } } else if meta.path == ALIAS { // #[serde(alias = "foo")] if let Some(s) = get_lit_str(cx, ALIAS, &meta)? { - de_aliases.insert(&meta.path, s.value()); + de_aliases.insert(&meta.path, Name::from(&s)); } } else if meta.path == RENAME_ALL { // #[serde(rename_all = "foo")] @@ -926,13 +879,13 @@ impl Variant { ser_path .path .segments - .push(Ident::new("serialize", Span::call_site()).into()); + .push(Ident::new("serialize", ser_path.span()).into()); serialize_with.set(&meta.path, ser_path); let mut de_path = path; de_path .path .segments - .push(Ident::new("deserialize", Span::call_site()).into()); + .push(Ident::new("deserialize", de_path.span()).into()); deserialize_with.set(&meta.path, de_path); } } else if meta.path == SERIALIZE_WITH { @@ -974,7 +927,7 @@ impl Variant { } else { let path = meta.path.to_token_stream().to_string().replace(' ', ""); return Err( - meta.error(format_args!("unknown serde variant attribute `{path}`")) + meta.error(format_args!("unknown serde variant attribute `{}`", path)) ); } Ok(()) @@ -984,7 +937,12 @@ impl Variant { } Variant { - name: Name::from_attrs(unraw(&variant.ident), ser_name, de_name, Some(de_aliases)), + name: MultiName::from_attrs( + Name::from(&unraw(&variant.ident)), + ser_name, + de_name, + Some(de_aliases), + ), rename_all_rules: RenameAllRules { serialize: rename_all_ser_rule.get().unwrap_or(RenameRule::None), deserialize: rename_all_de_rule.get().unwrap_or(RenameRule::None), @@ -1001,20 +959,23 @@ impl Variant { } } - pub fn name(&self) -> &Name { + pub fn name(&self) -> &MultiName { &self.name } - pub fn aliases(&self) -> &BTreeSet { + pub fn aliases(&self) -> &BTreeSet { self.name.deserialize_aliases() } pub fn rename_by_rules(&mut self, rules: RenameAllRules) { if !self.name.serialize_renamed { - self.name.serialize = rules.serialize.apply_to_variant(&self.name.serialize); + self.name.serialize.value = + rules.serialize.apply_to_variant(&self.name.serialize.value); } if !self.name.deserialize_renamed { - self.name.deserialize = rules.deserialize.apply_to_variant(&self.name.deserialize); + self.name.deserialize.value = rules + .deserialize + .apply_to_variant(&self.name.deserialize.value); } self.name .deserialize_aliases @@ -1060,7 +1021,7 @@ impl Variant { /// Represents field attribute information pub struct Field { - name: Name, + name: MultiName, skip_serializing: bool, skip_deserializing: bool, skip_serializing_if: Option, @@ -1076,7 +1037,6 @@ pub struct Field { } /// Represents the default to use for a field when deserializing. -#[allow(clippy::enum_variant_names)] pub enum Default { /// Field must always be specified because it does not have a default. None, @@ -1103,6 +1063,7 @@ impl Field { field: &syn::Field, attrs: Option<&Variant>, container_default: &Default, + private: &Ident, ) -> Self { let mut ser_name = Attr::none(cx, RENAME); let mut de_name = Attr::none(cx, RENAME); @@ -1120,8 +1081,11 @@ impl Field { let mut flatten = BoolAttr::none(cx, FLATTEN); let ident = match &field.ident { - Some(ident) => unraw(ident), - None => index.to_string(), + Some(ident) => Name::from(&unraw(ident)), + None => Name { + value: index.to_string(), + span: Span::call_site(), + }, }; if let Some(borrow_attribute) = attrs.and_then(|variant| variant.borrow.as_ref()) { @@ -1129,7 +1093,8 @@ impl Field { if let Some(lifetimes) = &borrow_attribute.lifetimes { for lifetime in lifetimes { if !borrowable.contains(lifetime) { - let msg = format!("field `{ident}` does not have lifetime {lifetime}"); + let msg = + format!("field `{}` does not have lifetime {}", ident, lifetime); cx.error_spanned_by(field, msg); } } @@ -1141,7 +1106,7 @@ impl Field { } for attr in &field.attrs { - if attr.path() != SERDE_SEED { + if attr.path() != SERDE { continue; } @@ -1156,15 +1121,15 @@ impl Field { // #[serde(rename = "foo")] // #[serde(rename(serialize = "foo", deserialize = "bar"))] let (ser, de) = get_multiple_renames(cx, &meta)?; - ser_name.set_opt(&meta.path, ser.as_ref().map(syn::LitStr::value)); + ser_name.set_opt(&meta.path, ser.as_ref().map(Name::from)); for de_value in de { - de_name.set_if_none(de_value.value()); - de_aliases.insert(&meta.path, de_value.value()); + de_name.set_if_none(Name::from(&de_value)); + de_aliases.insert(&meta.path, Name::from(&de_value)); } } else if meta.path == ALIAS { // #[serde(alias = "foo")] if let Some(s) = get_lit_str(cx, ALIAS, &meta)? { - de_aliases.insert(&meta.path, s.value()); + de_aliases.insert(&meta.path, Name::from(&s)); } } else if meta.path == DEFAULT { if meta.input.peek(Token![=]) { @@ -1208,13 +1173,13 @@ impl Field { ser_path .path .segments - .push(Ident::new("serialize", Span::call_site()).into()); + .push(Ident::new("serialize", ser_path.span()).into()); serialize_with.set(&meta.path, ser_path); let mut de_path = path; de_path .path .segments - .push(Ident::new("deserialize", Span::call_site()).into()); + .push(Ident::new("deserialize", de_path.span()).into()); deserialize_with.set(&meta.path, de_path); } } else if meta.path == BOUND { @@ -1231,7 +1196,8 @@ impl Field { for lifetime in &lifetimes { if !borrowable.contains(lifetime) { let msg = format!( - "field `{ident}` does not have lifetime {lifetime}", + "field `{}` does not have lifetime {}", + ident, lifetime, ); cx.error_spanned_by(field, msg); } @@ -1254,7 +1220,9 @@ impl Field { flatten.set_true(&meta.path); } else { let path = meta.path.to_token_stream().to_string().replace(' ', ""); - return Err(meta.error(format_args!("unknown serde field attribute `{path}`"))); + return Err( + meta.error(format_args!("unknown serde field attribute `{}`", path)) + ); } Ok(()) }) { @@ -1289,7 +1257,7 @@ impl Field { }; let span = Span::call_site(); path.segments.push(Ident::new("_serde", span).into()); - path.segments.push(Ident::new("__private", span).into()); + path.segments.push(private.clone().into()); path.segments.push(Ident::new("de", span).into()); path.segments .push(Ident::new("borrow_cow_str", span).into()); @@ -1306,7 +1274,7 @@ impl Field { }; let span = Span::call_site(); path.segments.push(Ident::new("_serde", span).into()); - path.segments.push(Ident::new("__private", span).into()); + path.segments.push(private.clone().into()); path.segments.push(Ident::new("de", span).into()); path.segments .push(Ident::new("borrow_cow_bytes", span).into()); @@ -1324,7 +1292,7 @@ impl Field { } Field { - name: Name::from_attrs(ident, ser_name, de_name, Some(de_aliases)), + name: MultiName::from_attrs(ident, ser_name, de_name, Some(de_aliases)), skip_serializing: skip_serializing.get(), skip_deserializing: skip_deserializing.get(), skip_serializing_if: skip_serializing_if.get(), @@ -1340,20 +1308,22 @@ impl Field { } } - pub fn name(&self) -> &Name { + pub fn name(&self) -> &MultiName { &self.name } - pub fn aliases(&self) -> &BTreeSet { + pub fn aliases(&self) -> &BTreeSet { self.name.deserialize_aliases() } pub fn rename_by_rules(&mut self, rules: RenameAllRules) { if !self.name.serialize_renamed { - self.name.serialize = rules.serialize.apply_to_field(&self.name.serialize); + self.name.serialize.value = rules.serialize.apply_to_field(&self.name.serialize.value); } if !self.name.deserialize_renamed { - self.name.deserialize = rules.deserialize.apply_to_field(&self.name.deserialize); + self.name.deserialize.value = rules + .deserialize + .apply_to_field(&self.name.deserialize.value); } self.name .deserialize_aliases @@ -1447,7 +1417,8 @@ where } } else { return Err(meta.error(format_args!( - "malformed {attr_name} attribute, expected `{attr_name}(serialize = ..., deserialize = ...)`", + "malformed {0} attribute, expected `{0}(serialize = ..., deserialize = ...)`", + attr_name, ))); } Ok(()) @@ -1512,7 +1483,7 @@ fn get_lit_str2( if !suffix.is_empty() { cx.error_spanned_by( lit, - format!("unexpected suffix `{suffix}` on string literal"), + format!("unexpected suffix `{}` on string literal", suffix), ); } Ok(Some(lit.clone())) @@ -1520,7 +1491,8 @@ fn get_lit_str2( cx.error_spanned_by( expr, format!( - "expected serde {attr_name} attribute to be a string: `{meta_item_name} = \"...\"`" + "expected serde {} attribute to be a string: `{} = \"...\"`", + attr_name, meta_item_name ), ); Ok(None) @@ -1631,7 +1603,10 @@ fn parse_lit_into_lifetimes( while !input.is_empty() { let lifetime: Lifetime = input.parse()?; if !set.insert(lifetime.clone()) { - cx.error_spanned_by(&string, format!("duplicate borrowed lifetime `{lifetime}`")); + cx.error_spanned_by( + &string, + format!("duplicate borrowed lifetime `{}`", lifetime), + ); } if input.is_empty() { break; @@ -1798,13 +1773,13 @@ fn is_primitive_path(path: &syn::Path, primitive: &str) -> bool { // attribute on the field so there must be at least one borrowable lifetime. fn borrowable_lifetimes( cx: &Ctxt, - name: &str, + name: &Name, field: &syn::Field, ) -> Result, ()> { let mut lifetimes = BTreeSet::new(); collect_lifetimes(&field.ty, &mut lifetimes); if lifetimes.is_empty() { - let msg = format!("field `{name}` has no lifetimes to borrow"); + let msg = format!("field `{}` has no lifetimes to borrow", name); cx.error_spanned_by(field, msg); Err(()) } else { @@ -1814,7 +1789,7 @@ fn borrowable_lifetimes( fn collect_lifetimes(ty: &syn::Type, out: &mut BTreeSet) { match ty { - #![cfg_attr(all(test), deny(non_exhaustive_omitted_patterns))] + #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))] syn::Type::Slice(ty) => { collect_lifetimes(&ty.elem, out); } @@ -1852,8 +1827,8 @@ fn collect_lifetimes(ty: &syn::Type, out: &mut BTreeSet) { } syn::GenericArgument::Const(_) | syn::GenericArgument::AssocConst(_) - | syn::GenericArgument::Constraint(_) => {} - _ => {} + | syn::GenericArgument::Constraint(_) + | _ => {} } } } diff --git a/crates/web-rwkv-derive/src/serde/internals/case.rs b/crates/web-rwkv-derive/src/serde/internals/case.rs index 635a7155..8c8c02e7 100644 --- a/crates/web-rwkv-derive/src/serde/internals/case.rs +++ b/crates/web-rwkv-derive/src/serde/internals/case.rs @@ -42,7 +42,7 @@ static RENAME_RULES: &[(&str, RenameRule)] = &[ ]; impl RenameRule { - pub fn from_str(rename_all_str: &str) -> Result> { + pub fn from_str(rename_all_str: &str) -> Result { for (name, rule) in RENAME_RULES { if rename_all_str == *name { return Ok(*rule); @@ -121,7 +121,7 @@ pub struct ParseError<'a> { unknown: &'a str, } -impl Display for ParseError<'_> { +impl<'a> Display for ParseError<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str("unknown rename rule `rename_all = ")?; Debug::fmt(self.unknown, f)?; diff --git a/crates/web-rwkv-derive/src/serde/internals/check.rs b/crates/web-rwkv-derive/src/serde/internals/check.rs index e0bcc76a..ad68b3a4 100644 --- a/crates/web-rwkv-derive/src/serde/internals/check.rs +++ b/crates/web-rwkv-derive/src/serde/internals/check.rs @@ -38,7 +38,7 @@ fn check_default_on_tuple(cx: &Ctxt, cont: &Container) { if let Some(first) = first_default_index { cx.error_spanned_by( field.ty, - format!("field must have #[serde(default)] because previous field {first} has #[serde(default)]"), + format!("field must have #[serde(default)] because previous field {} has #[serde(default)]", first), ); } continue; @@ -311,7 +311,7 @@ fn check_internal_tag_field_name_conflict(cx: &Ctxt, cont: &Container) { let diagnose_conflict = || { cx.error_spanned_by( cont.original, - format!("variant field name `{tag}` conflicts with internal tag"), + format!("variant field name `{}` conflicts with internal tag", tag), ); }; @@ -329,13 +329,13 @@ fn check_internal_tag_field_name_conflict(cx: &Ctxt, cont: &Container) { let name = field.attrs.name(); let ser_name = name.serialize_name(); - if check_ser && ser_name == tag { + if check_ser && ser_name.value == tag { diagnose_conflict(); return; } for de_name in field.attrs.aliases() { - if check_de && de_name == tag { + if check_de && de_name.value == tag { diagnose_conflict(); return; } @@ -358,7 +358,10 @@ fn check_adjacent_tag_conflict(cx: &Ctxt, cont: &Container) { if type_tag == content_tag { cx.error_spanned_by( cont.original, - format!("enum tags `{type_tag}` for type and content conflict with each other"), + format!( + "enum tags `{}` for type and content conflict with each other", + type_tag + ), ); } } @@ -444,7 +447,7 @@ fn check_transparent(cx: &Ctxt, cont: &mut Container, derive: Derive) { fn member_message(member: &Member) -> String { match member { - Member::Named(ident) => format!("`{ident}`"), + Member::Named(ident) => format!("`{}`", ident), Member::Unnamed(i) => format!("#{}", i.index), } } diff --git a/crates/web-rwkv-derive/src/serde/internals/mod.rs b/crates/web-rwkv-derive/src/serde/internals/mod.rs index f98ef08e..cd1e8105 100644 --- a/crates/web-rwkv-derive/src/serde/internals/mod.rs +++ b/crates/web-rwkv-derive/src/serde/internals/mod.rs @@ -1,5 +1,6 @@ pub mod ast; pub mod attr; +pub mod name; mod case; mod check; diff --git a/crates/web-rwkv-derive/src/serde/internals/name.rs b/crates/web-rwkv-derive/src/serde/internals/name.rs new file mode 100644 index 00000000..b88304f2 --- /dev/null +++ b/crates/web-rwkv-derive/src/serde/internals/name.rs @@ -0,0 +1,113 @@ +use crate::serde::internals::attr::{Attr, VecAttr}; +use proc_macro2::{Ident, Span, TokenStream}; +use quote::ToTokens; +use std::cmp::Ordering; +use std::collections::BTreeSet; +use std::fmt::{self, Display}; +use syn::LitStr; + +pub struct MultiName { + pub(crate) serialize: Name, + pub(crate) serialize_renamed: bool, + pub(crate) deserialize: Name, + pub(crate) deserialize_renamed: bool, + pub(crate) deserialize_aliases: BTreeSet, +} + +impl MultiName { + pub(crate) fn from_attrs( + source_name: Name, + ser_name: Attr, + de_name: Attr, + de_aliases: Option>, + ) -> Self { + let mut alias_set = BTreeSet::new(); + if let Some(de_aliases) = de_aliases { + for alias_name in de_aliases.get() { + alias_set.insert(alias_name); + } + } + + let ser_name = ser_name.get(); + let ser_renamed = ser_name.is_some(); + let de_name = de_name.get(); + let de_renamed = de_name.is_some(); + MultiName { + serialize: ser_name.unwrap_or_else(|| source_name.clone()), + serialize_renamed: ser_renamed, + deserialize: de_name.unwrap_or(source_name), + deserialize_renamed: de_renamed, + deserialize_aliases: alias_set, + } + } + + /// Return the container name for the container when serializing. + pub fn serialize_name(&self) -> &Name { + &self.serialize + } + + /// Return the container name for the container when deserializing. + pub fn deserialize_name(&self) -> &Name { + &self.deserialize + } + + pub(crate) fn deserialize_aliases(&self) -> &BTreeSet { + &self.deserialize_aliases + } +} + +#[derive(Clone)] +pub struct Name { + pub value: String, + pub span: Span, +} + +impl ToTokens for Name { + fn to_tokens(&self, tokens: &mut TokenStream) { + LitStr::new(&self.value, self.span).to_tokens(tokens); + } +} + +impl Ord for Name { + fn cmp(&self, other: &Self) -> Ordering { + Ord::cmp(&self.value, &other.value) + } +} + +impl PartialOrd for Name { + fn partial_cmp(&self, other: &Self) -> Option { + Some(Ord::cmp(self, other)) + } +} + +impl Eq for Name {} + +impl PartialEq for Name { + fn eq(&self, other: &Self) -> bool { + self.value == other.value + } +} + +impl From<&Ident> for Name { + fn from(ident: &Ident) -> Self { + Name { + value: ident.to_string(), + span: ident.span(), + } + } +} + +impl From<&LitStr> for Name { + fn from(lit: &LitStr) -> Self { + Name { + value: lit.value(), + span: lit.span(), + } + } +} + +impl Display for Name { + fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + Display::fmt(&self.value, formatter) + } +} diff --git a/crates/web-rwkv-derive/src/serde/internals/receiver.rs b/crates/web-rwkv-derive/src/serde/internals/receiver.rs index ade5266b..1cedc80f 100644 --- a/crates/web-rwkv-derive/src/serde/internals/receiver.rs +++ b/crates/web-rwkv-derive/src/serde/internals/receiver.rs @@ -83,7 +83,7 @@ impl ReplaceReceiver<'_> { self.visit_type_mut_impl(ty); return; }; - *ty = self.self_ty(span).into(); + *ty = Type::Path(self.self_ty(span)); } // `Self::Assoc` -> `::Assoc` @@ -106,7 +106,7 @@ impl ReplaceReceiver<'_> { fn visit_type_mut_impl(&mut self, ty: &mut Type) { match ty { - #![cfg_attr(all(test), deny(non_exhaustive_omitted_patterns))] + #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))] Type::Array(ty) => { self.visit_type_mut(&mut ty.elem); self.visit_expr_mut(&mut ty.len); @@ -177,7 +177,7 @@ impl ReplaceReceiver<'_> { PathArguments::AngleBracketed(arguments) => { for arg in &mut arguments.args { match arg { - #![cfg_attr(all(test), deny(non_exhaustive_omitted_patterns))] + #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))] GenericArgument::Type(arg) => self.visit_type_mut(arg), GenericArgument::AssocType(arg) => self.visit_type_mut(&mut arg.ty), GenericArgument::Lifetime(_) @@ -206,9 +206,11 @@ impl ReplaceReceiver<'_> { fn visit_type_param_bound_mut(&mut self, bound: &mut TypeParamBound) { match bound { - #![cfg_attr(all(test), deny(non_exhaustive_omitted_patterns))] + #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))] TypeParamBound::Trait(bound) => self.visit_path_mut(&mut bound.path), - TypeParamBound::Lifetime(_) | TypeParamBound::Verbatim(_) => {} + TypeParamBound::Lifetime(_) + | TypeParamBound::PreciseCapture(_) + | TypeParamBound::Verbatim(_) => {} _ => {} } } @@ -227,7 +229,7 @@ impl ReplaceReceiver<'_> { if let Some(where_clause) = &mut generics.where_clause { for predicate in &mut where_clause.predicates { match predicate { - #![cfg_attr(all(test), deny(non_exhaustive_omitted_patterns))] + #![cfg_attr(all(test, exhaustive), deny(non_exhaustive_omitted_patterns))] WherePredicate::Type(predicate) => { self.visit_type_mut(&mut predicate.bounded_ty); for bound in &mut predicate.bounds { diff --git a/crates/web-rwkv-derive/src/serde/internals/symbol.rs b/crates/web-rwkv-derive/src/serde/internals/symbol.rs index 28bed9f5..7da8df6f 100644 --- a/crates/web-rwkv-derive/src/serde/internals/symbol.rs +++ b/crates/web-rwkv-derive/src/serde/internals/symbol.rs @@ -28,6 +28,7 @@ pub const RENAME_ALL: Symbol = Symbol("rename_all"); pub const RENAME_ALL_FIELDS: Symbol = Symbol("rename_all_fields"); pub const REPR: Symbol = Symbol("repr"); pub const SEED: Symbol = Symbol("seed"); +pub const SERDE: Symbol = Symbol("serde"); pub const SERDE_SEED: Symbol = Symbol("serde_seed"); pub const SERIALIZE: Symbol = Symbol("serialize"); pub const SERIALIZE_WITH: Symbol = Symbol("serialize_with"); diff --git a/crates/web-rwkv-derive/src/serde/mod.rs b/crates/web-rwkv-derive/src/serde/mod.rs index 8b67fc0e..0e6129ab 100644 --- a/crates/web-rwkv-derive/src/serde/mod.rs +++ b/crates/web-rwkv-derive/src/serde/mod.rs @@ -1,8 +1,92 @@ +//! This crate provides Serde's derive macro for `DeserializeSeed`. + +#![allow(unexpected_cfgs)] +// Ignored clippy lints +#![allow( + // clippy false positive: https://github.com/rust-lang/rust-clippy/issues/7054 + clippy::branches_sharing_code, + clippy::cognitive_complexity, + // clippy bug: https://github.com/rust-lang/rust-clippy/issues/7575 + clippy::collapsible_match, + clippy::derive_partial_eq_without_eq, + clippy::enum_variant_names, + // clippy bug: https://github.com/rust-lang/rust-clippy/issues/6797 + clippy::manual_map, + clippy::match_like_matches_macro, + clippy::needless_lifetimes, + clippy::needless_pass_by_value, + clippy::too_many_arguments, + clippy::trivially_copy_pass_by_ref, + clippy::used_underscore_binding, + clippy::wildcard_in_or_patterns, + // clippy bug: https://github.com/rust-lang/rust-clippy/issues/5704 + clippy::unnested_or_patterns, +)] +// Ignored clippy_pedantic lints +#![allow( + clippy::cast_possible_truncation, + clippy::checked_conversions, + clippy::doc_markdown, + clippy::elidable_lifetime_names, + clippy::enum_glob_use, + clippy::indexing_slicing, + clippy::items_after_statements, + clippy::let_underscore_untyped, + clippy::manual_assert, + clippy::map_err_ignore, + clippy::match_same_arms, + // clippy bug: https://github.com/rust-lang/rust-clippy/issues/6984 + clippy::match_wildcard_for_single_variants, + clippy::module_name_repetitions, + clippy::must_use_candidate, + clippy::similar_names, + clippy::single_match_else, + clippy::struct_excessive_bools, + clippy::too_many_lines, + clippy::uninlined_format_args, + clippy::unseparated_literal_suffix, + clippy::unused_self, + clippy::use_self, + clippy::wildcard_imports +)] +#![allow(unknown_lints, mismatched_lifetime_syntaxes)] #![allow(dead_code)] -pub mod bound; +extern crate proc_macro2; +extern crate quote; +extern crate syn; + +extern crate proc_macro; + +use proc_macro2::{Ident, Span}; +use quote::{ToTokens, TokenStreamExt as _}; + +mod internals; +#[macro_use] +mod bound; +#[macro_use] +mod fragment; +mod deprecated; +mod dummy; +mod pretend; +mod this; + pub mod de; -pub mod dummy; -pub mod fragment; -pub mod internals; -pub mod this; + +#[allow(non_camel_case_types)] +struct private; + +impl private { + fn ident(&self) -> Ident { + Ident::new( + concat!("__private", env!("SERDE_PATCH_VERSION")), + Span::call_site(), + ) + } +} + +impl ToTokens for private { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + tokens.append(self.ident()); + } +} diff --git a/crates/web-rwkv-derive/src/serde/pretend.rs b/crates/web-rwkv-derive/src/serde/pretend.rs new file mode 100644 index 00000000..22926ad1 --- /dev/null +++ b/crates/web-rwkv-derive/src/serde/pretend.rs @@ -0,0 +1,188 @@ +use crate::serde::internals::ast::{Container, Data, Field, Style, Variant}; +use crate::serde::private; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; + +// Suppress dead_code warnings that would otherwise appear when using a remote +// derive. Other than this pretend code, a struct annotated with remote derive +// never has its fields referenced and an enum annotated with remote derive +// never has its variants constructed. +// +// warning: field is never used: `i` +// --> src/main.rs:4:20 +// | +// 4 | struct StructDef { i: i32 } +// | ^^^^^^ +// +// warning: variant is never constructed: `V` +// --> src/main.rs:8:16 +// | +// 8 | enum EnumDef { V } +// | ^ +// +pub fn pretend_used(cont: &Container, is_packed: bool) -> TokenStream { + let pretend_fields = pretend_fields_used(cont, is_packed); + let pretend_variants = pretend_variants_used(cont); + + quote! { + #pretend_fields + #pretend_variants + } +} + +// For structs with named fields, expands to: +// +// match None::<&T> { +// Some(T { a: __v0, b: __v1 }) => {} +// _ => {} +// } +// +// For packed structs on sufficiently new rustc, expands to: +// +// match None::<&T> { +// Some(__v @ T { a: _, b: _ }) => { +// let _ = addr_of!(__v.a); +// let _ = addr_of!(__v.b); +// } +// _ => {} +// } +// +// For packed structs on older rustc, we assume Sized and !Drop, and expand to: +// +// match None:: { +// Some(T { a: __v0, b: __v1 }) => {} +// _ => {} +// } +// +// For enums, expands to the following but only including struct variants: +// +// match None::<&T> { +// Some(T::A { a: __v0 }) => {} +// Some(T::B { b: __v0 }) => {} +// _ => {} +// } +// +fn pretend_fields_used(cont: &Container, is_packed: bool) -> TokenStream { + match &cont.data { + Data::Enum(variants) => pretend_fields_used_enum(cont, variants), + Data::Struct(Style::Struct | Style::Tuple | Style::Newtype, fields) => { + if is_packed { + pretend_fields_used_struct_packed(cont, fields) + } else { + pretend_fields_used_struct(cont, fields) + } + } + Data::Struct(Style::Unit, _) => quote!(), + } +} + +fn pretend_fields_used_struct(cont: &Container, fields: &[Field]) -> TokenStream { + let type_ident = &cont.ident; + let (_, ty_generics, _) = cont.generics.split_for_impl(); + + let members = fields.iter().map(|field| &field.member); + let placeholders = (0usize..).map(|i| format_ident!("__v{}", i)); + + quote! { + match _serde::#private::None::<&#type_ident #ty_generics> { + _serde::#private::Some(#type_ident { #(#members: #placeholders),* }) => {} + _ => {} + } + } +} + +fn pretend_fields_used_struct_packed(cont: &Container, fields: &[Field]) -> TokenStream { + let type_ident = &cont.ident; + let (_, ty_generics, _) = cont.generics.split_for_impl(); + + let members = fields.iter().map(|field| &field.member).collect::>(); + + let private2 = private; + quote! { + match _serde::#private::None::<&#type_ident #ty_generics> { + _serde::#private::Some(__v @ #type_ident { #(#members: _),* }) => { + #( + let _ = _serde::#private2::ptr::addr_of!(__v.#members); + )* + } + _ => {} + } + } +} + +fn pretend_fields_used_enum(cont: &Container, variants: &[Variant]) -> TokenStream { + let type_ident = &cont.ident; + let (_, ty_generics, _) = cont.generics.split_for_impl(); + + let patterns = variants + .iter() + .filter_map(|variant| match variant.style { + Style::Struct | Style::Tuple | Style::Newtype => { + let variant_ident = &variant.ident; + let members = variant.fields.iter().map(|field| &field.member); + let placeholders = (0usize..).map(|i| format_ident!("__v{}", i)); + Some(quote!(#type_ident::#variant_ident { #(#members: #placeholders),* })) + } + Style::Unit => None, + }) + .collect::>(); + + let private2 = private; + quote! { + match _serde::#private::None::<&#type_ident #ty_generics> { + #( + _serde::#private2::Some(#patterns) => {} + )* + _ => {} + } + } +} + +// Expands to one of these per enum variant: +// +// match None { +// Some((__v0, __v1,)) => { +// let _ = E::V { a: __v0, b: __v1 }; +// } +// _ => {} +// } +// +fn pretend_variants_used(cont: &Container) -> TokenStream { + let variants = match &cont.data { + Data::Enum(variants) => variants, + Data::Struct(_, _) => { + return quote!(); + } + }; + + let type_ident = &cont.ident; + let (_, ty_generics, _) = cont.generics.split_for_impl(); + let turbofish = ty_generics.as_turbofish(); + + let cases = variants.iter().map(|variant| { + let variant_ident = &variant.ident; + let placeholders = &(0..variant.fields.len()) + .map(|i| format_ident!("__v{}", i)) + .collect::>(); + + let pat = match variant.style { + Style::Struct => { + let members = variant.fields.iter().map(|field| &field.member); + quote!({ #(#members: #placeholders),* }) + } + Style::Tuple | Style::Newtype => quote!(( #(#placeholders),* )), + Style::Unit => quote!(), + }; + + quote! { + match _serde::#private::None { + _serde::#private::Some((#(#placeholders,)*)) => { + let _ = #type_ident::#variant_ident #turbofish #pat; + } + _ => {} + } + } + }); + + quote!(#(#cases)*) +} diff --git a/src/context.rs b/src/context.rs index 99bbdf80..d6892eb8 100644 --- a/src/context.rs +++ b/src/context.rs @@ -8,9 +8,9 @@ use wgpu::{ util::{BufferInitDescriptor, DeviceExt}, Adapter, BindGroup, BindGroupDescriptor, BindGroupEntry, BindGroupLayout, BindGroupLayoutDescriptor, BindGroupLayoutEntry, Buffer, BufferDescriptor, BufferUsages, - ComputePipeline, ComputePipelineDescriptor, Device, DeviceDescriptor, Features, Instance, - Limits, MemoryHints, PipelineLayoutDescriptor, PowerPreference, Queue, RequestAdapterOptions, - ShaderModuleDescriptor, + ComputePipeline, ComputePipelineDescriptor, Device, DeviceDescriptor, ExperimentalFeatures, + Features, Instance, Limits, MemoryHints, PipelineLayoutDescriptor, PowerPreference, Queue, + RequestAdapterOptions, ShaderModuleDescriptor, Trace, }; use crate::tensor::{ @@ -69,7 +69,10 @@ impl Drop for Context { if self.event.sender_count() <= 1 { self.clear_buffers(); self.queue.submit(None); - _ = self.device.poll(wgpu::PollType::Wait); + _ = self.device.poll(wgpu::PollType::Wait { + submission_index: None, + timeout: None, + }); } } } @@ -120,7 +123,8 @@ impl ContextBuilder { required_features: features, required_limits: limits, memory_hints: MemoryHints::Performance, - trace: wgpu::Trace::Off, + trace: Trace::Off, + experimental_features: ExperimentalFeatures::disabled(), }) .await .map_err(|_| ContextError::RequestDeviceFailed)?; @@ -440,7 +444,10 @@ fn read_back_buffer(device: &Device, buffer: &Buffer) -> Box<[u8]> { let slice = buffer.slice(..); slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap()); - _ = device.poll(wgpu::PollType::Wait); + _ = device.poll(wgpu::PollType::Wait { + submission_index: None, + timeout: None, + }); receiver .recv() .expect("failed to receive read back buffer") diff --git a/src/runtime/v4.rs b/src/runtime/v4.rs index 3f2cc4aa..1b6b8446 100644 --- a/src/runtime/v4.rs +++ b/src/runtime/v4.rs @@ -506,7 +506,7 @@ impl Bundle { } fn turbo(num_token: usize) -> bool { - num_token % super::infer::rnn::MIN_TOKEN_CHUNK_SIZE == 0 + num_token.is_multiple_of(super::infer::rnn::MIN_TOKEN_CHUNK_SIZE) } fn hook_op( @@ -784,7 +784,7 @@ fn dispatch_layer( hook_op(Hook::PostFfn(index))?, ]); - if (index + 1) % rescale == 0 { + if (index + 1).is_multiple_of(rescale) { ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?); } @@ -863,8 +863,11 @@ impl ModelBuilder { w: Matrix::Fp16(loader.load_matrix_f16("head.weight")?), }; - context.queue.submit(None); - _ = context.device.poll(wgpu::PollType::Wait); + let submission_index = Some(context.queue.submit(None)); + _ = context.device.poll(wgpu::PollType::Wait { + submission_index, + timeout: None, + }); let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant); let load_matrix_discount = |name: String, quant: Quant, discount: f32| { @@ -917,8 +920,11 @@ impl ModelBuilder { w_v: load_matrix_discount(format!("{ffn}.value.weight"), quant, discount)?, }; - context.queue.submit(None); - _ = context.device.poll(wgpu::PollType::Wait); + let submission_index = Some(context.queue.submit(None)); + _ = context.device.poll(wgpu::PollType::Wait { + submission_index, + timeout: None, + }); layers.push(Layer { att_layer_norm, @@ -928,8 +934,11 @@ impl ModelBuilder { }) } - context.queue.submit(None); - _ = context.device.poll(wgpu::PollType::Wait); + let submission_index = Some(context.queue.submit(None)); + _ = context.device.poll(wgpu::PollType::Wait { + submission_index, + timeout: None, + }); let tensor = ModelTensor { embed, diff --git a/src/runtime/v5.rs b/src/runtime/v5.rs index cca8efe5..37a6b9bd 100644 --- a/src/runtime/v5.rs +++ b/src/runtime/v5.rs @@ -523,7 +523,7 @@ impl super::model::Bundle for Bundle { } fn turbo(num_token: usize) -> bool { - num_token % super::infer::rnn::MIN_TOKEN_CHUNK_SIZE == 0 + num_token.is_multiple_of(super::infer::rnn::MIN_TOKEN_CHUNK_SIZE) } fn hook_op( @@ -877,7 +877,7 @@ fn dispatch_layer( hook_op(Hook::PostFfn(index))?, ]); - if (index + 1) % rescale == 0 { + if (index + 1).is_multiple_of(rescale) { ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?); } @@ -956,8 +956,11 @@ impl ModelBuilder { w: Matrix::Fp16(loader.load_matrix_f16("head.weight")?), }; - context.queue.submit(None); - _ = context.device.poll(wgpu::PollType::Wait); + let submission_index = Some(context.queue.submit(None)); + _ = context.device.poll(wgpu::PollType::Wait { + submission_index, + timeout: None, + }); let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant); let load_matrix_discount = |name: String, quant: Quant, discount: f32| { @@ -1033,8 +1036,11 @@ impl ModelBuilder { w_v: load_matrix_discount(format!("{ffn}.value.weight"), quant, discount)?, }; - context.queue.submit(None); - _ = context.device.poll(wgpu::PollType::Wait); + let submission_index = Some(context.queue.submit(None)); + _ = context.device.poll(wgpu::PollType::Wait { + submission_index, + timeout: None, + }); layers.push(Layer { att_layer_norm, @@ -1044,8 +1050,11 @@ impl ModelBuilder { }) } - context.queue.submit(None); - _ = context.device.poll(wgpu::PollType::Wait); + let submission_index = Some(context.queue.submit(None)); + _ = context.device.poll(wgpu::PollType::Wait { + submission_index, + timeout: None, + }); let tensor = ModelTensor { embed, diff --git a/src/runtime/v6.rs b/src/runtime/v6.rs index 4995f1ae..91ca41af 100644 --- a/src/runtime/v6.rs +++ b/src/runtime/v6.rs @@ -566,7 +566,7 @@ impl super::model::Bundle for Bundle { } fn turbo(num_token: usize) -> bool { - num_token % super::infer::rnn::MIN_TOKEN_CHUNK_SIZE == 0 + num_token.is_multiple_of(super::infer::rnn::MIN_TOKEN_CHUNK_SIZE) } fn hook_op( @@ -950,7 +950,7 @@ fn dispatch_layer( hook_op(Hook::PostFfn(index))?, ]); - if (index + 1) % rescale == 0 { + if (index + 1).is_multiple_of(rescale) { ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?); } @@ -1029,8 +1029,11 @@ impl ModelBuilder { w: Matrix::Fp16(loader.load_matrix_f16_padded("head.weight")?), }; - context.queue.submit(None); - _ = context.device.poll(wgpu::PollType::Wait); + let submission_index = Some(context.queue.submit(None)); + _ = context.device.poll(wgpu::PollType::Wait { + submission_index, + timeout: None, + }); let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant); let load_matrix_discount = |name: String, quant: Quant, discount: f32| { @@ -1129,8 +1132,11 @@ impl ModelBuilder { w_v: load_matrix_discount(format!("{ffn}.value.weight"), quant, discount)?, }; - context.queue.submit(None); - _ = context.device.poll(wgpu::PollType::Wait); + let submission_index = Some(context.queue.submit(None)); + _ = context.device.poll(wgpu::PollType::Wait { + submission_index, + timeout: None, + }); layers.push(Layer { att_layer_norm, @@ -1140,8 +1146,11 @@ impl ModelBuilder { }) } - context.queue.submit(None); - _ = context.device.poll(wgpu::PollType::Wait); + let submission_index = Some(context.queue.submit(None)); + _ = context.device.poll(wgpu::PollType::Wait { + submission_index, + timeout: None, + }); let tensor = ModelTensor { embed, diff --git a/src/runtime/v7.rs b/src/runtime/v7.rs index 406c4512..e61a52c6 100644 --- a/src/runtime/v7.rs +++ b/src/runtime/v7.rs @@ -581,7 +581,7 @@ impl super::model::Bundle for Bundle { } fn turbo(num_token: usize) -> bool { - num_token % super::infer::rnn::MIN_TOKEN_CHUNK_SIZE == 0 + num_token.is_multiple_of(super::infer::rnn::MIN_TOKEN_CHUNK_SIZE) } fn hook_op( @@ -999,7 +999,7 @@ fn dispatch_layer( hook_op(Hook::PostFfn(index))?, ]); - if (index + 1) % rescale == 0 { + if (index + 1).is_multiple_of(rescale) { ops.push(TensorOp::affine(&buffer.x, 0.5, 0.0)?); } @@ -1073,8 +1073,11 @@ impl ModelBuilder { w: Matrix::Fp16(loader.load_matrix_f16_padded("head.weight")?), }; - context.queue.submit(None); - _ = context.device.poll(wgpu::PollType::Wait); + let submission_index = Some(context.queue.submit(None)); + _ = context.device.poll(wgpu::PollType::Wait { + submission_index, + timeout: None, + }); let load_matrix = |name: String, quant: Quant| loader.load_matrix(name, quant); let load_matrix_discount = |name: String, quant: Quant, discount: f32| { @@ -1183,8 +1186,11 @@ impl ModelBuilder { w_v: load_matrix_discount(format!("{ffn}.value.weight"), quant, discount)?, }; - context.queue.submit(None); - _ = context.device.poll(wgpu::PollType::Wait); + let submission_index = Some(context.queue.submit(None)); + _ = context.device.poll(wgpu::PollType::Wait { + submission_index, + timeout: None, + }); layers.push(Layer { att_ln, @@ -1194,8 +1200,11 @@ impl ModelBuilder { }) } - context.queue.submit(None); - _ = context.device.poll(wgpu::PollType::Wait); + let submission_index = Some(context.queue.submit(None)); + _ = context.device.poll(wgpu::PollType::Wait { + submission_index, + timeout: None, + }); let tensor = ModelTensor { embed, diff --git a/src/tensor/serialization.rs b/src/tensor/serialization.rs index 99416139..492b3061 100644 --- a/src/tensor/serialization.rs +++ b/src/tensor/serialization.rs @@ -27,7 +27,15 @@ impl From> for TensorBlob<'_> { impl From> for TensorCpu { fn from(value: TensorBlob) -> Self { let TensorBlob { shape, data } = value; - let data: Vec = bytemuck::cast_slice(&data).to_vec(); + let data = data.to_vec().into_boxed_slice(); + // let data: Vec = bytemuck::cast_slice(&data).to_vec(); + let data = Box::leak(data); + let data: Box<[T]> = unsafe { + let ptr = data.as_ptr() as *const T; + let len = data.len() / size_of::(); + let slice = core::slice::from_raw_parts(ptr, len); + Box::from(slice) + }; let data = data.into(); Self { shape,