diff --git a/postgres-derive-test/src/composites.rs b/postgres-derive-test/src/composites.rs index a1b76345f..c59ff0a7b 100644 --- a/postgres-derive-test/src/composites.rs +++ b/postgres-derive-test/src/composites.rs @@ -89,6 +89,48 @@ fn name_overrides() { ); } +#[test] +fn rename_all_overrides() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + #[postgres(name = "inventory_item", rename_all = "SCREAMING_SNAKE_CASE")] + struct InventoryItem { + name: String, + supplier_id: i32, + price: Option, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.batch_execute( + "CREATE TYPE pg_temp.inventory_item AS ( + \"NAME\" TEXT, + \"SUPPLIER_ID\" INT, + \"PRICE\" DOUBLE PRECISION + );", + ) + .unwrap(); + + let item = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: Some(15.50), + }; + + let item_null = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: None, + }; + + test_type( + &mut conn, + "inventory_item", + &[ + (item, "ROW('foobar', 100, 15.50)"), + (item_null, "ROW('foobar', 100, NULL)"), + ], + ); +} + #[test] fn wrong_name() { #[derive(FromSql, ToSql, Debug, PartialEq)] diff --git a/postgres-derive-test/src/enums.rs b/postgres-derive-test/src/enums.rs index a7039ca05..e44f37616 100644 --- a/postgres-derive-test/src/enums.rs +++ b/postgres-derive-test/src/enums.rs @@ -53,6 +53,35 @@ fn name_overrides() { ); } +#[test] +fn rename_all_overrides() { + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "mood", rename_all = "snake_case")] + enum Mood { + Sad, + #[postgres(name = "okay")] + Ok, + Happy, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute( + "CREATE TYPE pg_temp.mood AS ENUM ('sad', 'okay', 'happy')", + &[], + ) + .unwrap(); + + test_type( + &mut conn, + "mood", + &[ + (Mood::Sad, "'sad'"), + (Mood::Ok, "'okay'"), + (Mood::Happy, "'happy'"), + ], + ); +} + #[test] fn wrong_name() { #[derive(Debug, ToSql, FromSql, PartialEq)] diff --git a/postgres-derive/Cargo.toml b/postgres-derive/Cargo.toml index 22b50b707..912c83bfc 100644 --- a/postgres-derive/Cargo.toml +++ b/postgres-derive/Cargo.toml @@ -12,6 +12,10 @@ proc-macro = true test = false [dependencies] +convert_case = "0.6" +itertools = "0.10" syn = "1.0" proc-macro2 = "1.0" +strum = "0.24" +strum_macros = "0.24" quote = "1.0" diff --git a/postgres-derive/src/composites.rs b/postgres-derive/src/composites.rs index 15bfabc13..80cf6aa6c 100644 --- a/postgres-derive/src/composites.rs +++ b/postgres-derive/src/composites.rs @@ -4,7 +4,7 @@ use syn::{ TypeParamBound, }; -use crate::overrides::Overrides; +use crate::{field_variant_overrides::FieldVariantOverrides, struct_overrides::StructOverrides}; pub struct Field { pub name: String, @@ -13,17 +13,20 @@ pub struct Field { } impl Field { - pub fn parse(raw: &syn::Field) -> Result { - let overrides = Overrides::extract(&raw.attrs)?; + pub fn parse(struct_overrides: &StructOverrides, raw: &syn::Field) -> Result { + use convert_case::Casing; + + let mut overrides = FieldVariantOverrides::extract(&raw.attrs)?; let ident = raw.ident.as_ref().unwrap().clone(); Ok(Field { - name: overrides.name.unwrap_or_else(|| { + name: overrides.name.take().unwrap_or_else(|| { let name = ident.to_string(); - match name.strip_prefix("r#") { - Some(name) => name.to_string(), - None => name, - } + let name = name.strip_prefix("r#").map(String::from).unwrap_or(name); + struct_overrides + .rename_all + .map(|case| name.to_case(case)) + .unwrap_or(name) }), ident, type_: raw.ty.clone(), diff --git a/postgres-derive/src/enums.rs b/postgres-derive/src/enums.rs index 3c6bc7113..4f4be4dfc 100644 --- a/postgres-derive/src/enums.rs +++ b/postgres-derive/src/enums.rs @@ -1,6 +1,6 @@ use syn::{Error, Fields, Ident}; -use crate::overrides::Overrides; +use crate::{field_variant_overrides::FieldVariantOverrides, struct_overrides::StructOverrides}; pub struct Variant { pub ident: Ident, @@ -8,7 +8,9 @@ pub struct Variant { } impl Variant { - pub fn parse(raw: &syn::Variant) -> Result { + pub fn parse(struct_overrides: &StructOverrides, raw: &syn::Variant) -> Result { + use convert_case::Casing; + match raw.fields { Fields::Unit => {} _ => { @@ -19,10 +21,17 @@ impl Variant { } } - let overrides = Overrides::extract(&raw.attrs)?; + let mut overrides = FieldVariantOverrides::extract(&raw.attrs)?; Ok(Variant { ident: raw.ident.clone(), - name: overrides.name.unwrap_or_else(|| raw.ident.to_string()), + name: overrides.name.take().unwrap_or_else(|| { + let name = raw.ident.to_string(); + let name = name.strip_prefix("r#").map(String::from).unwrap_or(name); + struct_overrides + .rename_all + .map(|case| name.to_case(case)) + .unwrap_or(name) + }), }) } } diff --git a/postgres-derive/src/field_variant_overrides.rs b/postgres-derive/src/field_variant_overrides.rs new file mode 100644 index 000000000..370f338cc --- /dev/null +++ b/postgres-derive/src/field_variant_overrides.rs @@ -0,0 +1,49 @@ +use syn::{Attribute, Error, Lit, Meta, NestedMeta}; + +#[derive(Default)] +pub struct FieldVariantOverrides { + pub name: Option, +} + +impl FieldVariantOverrides { + pub fn extract(attrs: &[Attribute]) -> Result { + let mut overrides: FieldVariantOverrides = Default::default(); + + for attr in attrs { + let attr = attr.parse_meta()?; + + if !attr.path().is_ident("postgres") { + continue; + } + + let list = match attr { + Meta::List(ref list) => list, + bad => return Err(Error::new_spanned(bad, "expected a #[postgres(...)]")), + }; + + for item in &list.nested { + match item { + NestedMeta::Meta(Meta::NameValue(meta)) => { + if meta.path.is_ident("name") { + let value = match &meta.lit { + Lit::Str(s) => s.value(), + bad => { + return Err(Error::new_spanned( + bad, + "expected a string literal", + )) + } + }; + overrides.name = Some(value); + } else { + return Err(Error::new_spanned(&meta.path, "unknown override")); + } + } + bad => return Err(Error::new_spanned(bad, "unknown attribute")), + } + } + } + + Ok(overrides) + } +} diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index f458c6e3d..5eb1167a6 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -12,10 +12,10 @@ use crate::accepts; use crate::composites::Field; use crate::composites::{append_generic_bound, new_derive_path}; use crate::enums::Variant; -use crate::overrides::Overrides; +use crate::struct_overrides::StructOverrides; pub fn expand_derive_fromsql(input: DeriveInput) -> Result { - let overrides = Overrides::extract(&input.attrs)?; + let mut overrides = StructOverrides::extract(&input.attrs)?; if overrides.name.is_some() && overrides.transparent { return Err(Error::new_spanned( @@ -24,7 +24,10 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { )); } - let name = overrides.name.unwrap_or_else(|| input.ident.to_string()); + let name = overrides + .name + .take() + .unwrap_or_else(|| input.ident.to_string()); let (accepts_body, to_sql_body) = if overrides.transparent { match input.data { @@ -51,7 +54,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { let variants = data .variants .iter() - .map(Variant::parse) + .map(|variant| Variant::parse(&overrides, variant)) .collect::, _>>()?; ( accepts::enum_body(&name, &variants), @@ -75,7 +78,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { let fields = fields .named .iter() - .map(Field::parse) + .map(|field| Field::parse(&overrides, field)) .collect::, _>>()?; ( accepts::composite_body(&name, "FromSql", &fields), diff --git a/postgres-derive/src/lib.rs b/postgres-derive/src/lib.rs index 98e6add24..69244c24c 100644 --- a/postgres-derive/src/lib.rs +++ b/postgres-derive/src/lib.rs @@ -9,8 +9,10 @@ use syn::parse_macro_input; mod accepts; mod composites; mod enums; +mod field_variant_overrides; mod fromsql; -mod overrides; +mod rename_rule; +mod struct_overrides; mod tosql; #[proc_macro_derive(ToSql, attributes(postgres))] diff --git a/postgres-derive/src/overrides.rs b/postgres-derive/src/overrides.rs deleted file mode 100644 index c00d5a94b..000000000 --- a/postgres-derive/src/overrides.rs +++ /dev/null @@ -1,57 +0,0 @@ -use syn::{Attribute, Error, Lit, Meta, NestedMeta}; - -pub struct Overrides { - pub name: Option, - pub transparent: bool, -} - -impl Overrides { - pub fn extract(attrs: &[Attribute]) -> Result { - let mut overrides = Overrides { - name: None, - transparent: false, - }; - - for attr in attrs { - let attr = attr.parse_meta()?; - - if !attr.path().is_ident("postgres") { - continue; - } - - let list = match attr { - Meta::List(ref list) => list, - bad => return Err(Error::new_spanned(bad, "expected a #[postgres(...)]")), - }; - - for item in &list.nested { - match item { - NestedMeta::Meta(Meta::NameValue(meta)) => { - if !meta.path.is_ident("name") { - return Err(Error::new_spanned(&meta.path, "unknown override")); - } - - let value = match &meta.lit { - Lit::Str(s) => s.value(), - bad => { - return Err(Error::new_spanned(bad, "expected a string literal")) - } - }; - - overrides.name = Some(value); - } - NestedMeta::Meta(Meta::Path(ref path)) => { - if !path.is_ident("transparent") { - return Err(Error::new_spanned(path, "unknown override")); - } - - overrides.transparent = true; - } - bad => return Err(Error::new_spanned(bad, "unknown attribute")), - } - } - } - - Ok(overrides) - } -} diff --git a/postgres-derive/src/rename_rule.rs b/postgres-derive/src/rename_rule.rs new file mode 100644 index 000000000..a6c5bb19b --- /dev/null +++ b/postgres-derive/src/rename_rule.rs @@ -0,0 +1,36 @@ +use strum_macros::{Display, EnumIter, EnumString}; + +#[derive(Clone, Copy, Display, EnumIter, EnumString)] +pub enum RenameRule { + #[strum(serialize = "camelCase")] + Camel, + #[strum(serialize = "kebab-case")] + Kebab, + #[strum(serialize = "lowercase")] + Lower, + #[strum(serialize = "PascalCase")] + Pascal, + #[strum(serialize = "SCREAMING-KEBAB-CASE")] + ScreamingKebab, + #[strum(serialize = "SCREAMING_SNAKE_CASE")] + ScreamingSnake, + #[strum(serialize = "snake_case")] + Snake, + #[strum(serialize = "UPPERCASE")] + Upper, +} + +impl From for convert_case::Case { + fn from(rule: RenameRule) -> Self { + match rule { + RenameRule::Camel => Self::Camel, + RenameRule::Kebab => Self::Kebab, + RenameRule::Lower => Self::Lower, + RenameRule::Pascal => Self::Pascal, + RenameRule::ScreamingKebab => Self::UpperKebab, + RenameRule::ScreamingSnake => Self::UpperSnake, + RenameRule::Snake => Self::Snake, + RenameRule::Upper => Self::Upper, + } + } +} diff --git a/postgres-derive/src/struct_overrides.rs b/postgres-derive/src/struct_overrides.rs new file mode 100644 index 000000000..33fa18864 --- /dev/null +++ b/postgres-derive/src/struct_overrides.rs @@ -0,0 +1,77 @@ +use syn::{Attribute, Error, Lit, Meta, NestedMeta}; + +use crate::rename_rule::RenameRule; + +#[derive(Default)] +pub struct StructOverrides { + pub name: Option, + pub transparent: bool, + pub rename_all: Option, +} + +impl StructOverrides { + pub fn extract(attrs: &[Attribute]) -> Result { + use itertools::Itertools; + use strum::IntoEnumIterator; + + let mut overrides: StructOverrides = Default::default(); + + for attr in attrs { + let attr = attr.parse_meta()?; + + if !attr.path().is_ident("postgres") { + continue; + } + + let list = match attr { + Meta::List(ref list) => list, + bad => return Err(Error::new_spanned(bad, "expected a #[postgres(...)]")), + }; + + for item in &list.nested { + match item { + NestedMeta::Meta(Meta::NameValue(meta)) => { + if meta.path.is_ident("name") { + let value = match &meta.lit { + Lit::Str(s) => s.value(), + bad => { + return Err(Error::new_spanned( + bad, + "expected a string literal", + )) + } + }; + overrides.name = Some(value); + } else if meta.path.is_ident("rename_all") { + let rename_rule: RenameRule = match &meta.lit { + Lit::Str(s) => s.value().parse().ok(), + _other => None, + } + .ok_or_else(|| { + let all_variants = RenameRule::iter() + .map(|variant| format!("\"{variant}\"")) + .join(", "); + Error::new_spanned( + &meta.lit, + format!("expected one of: {all_variants}"), + ) + })?; + overrides.rename_all = Some(rename_rule.into()); + } else { + return Err(Error::new_spanned(&meta.path, "unknown override")); + } + } + NestedMeta::Meta(Meta::Path(ref path)) => { + if !path.is_ident("transparent") { + return Err(Error::new_spanned(path, "unknown override")); + } + overrides.transparent = true; + } + bad => return Err(Error::new_spanned(bad, "unknown attribute")), + } + } + } + + Ok(overrides) + } +} diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index e51acc7fd..a34c09300 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -10,10 +10,10 @@ use crate::accepts; use crate::composites::Field; use crate::composites::{append_generic_bound, new_derive_path}; use crate::enums::Variant; -use crate::overrides::Overrides; +use crate::struct_overrides::StructOverrides; pub fn expand_derive_tosql(input: DeriveInput) -> Result { - let overrides = Overrides::extract(&input.attrs)?; + let mut overrides = StructOverrides::extract(&input.attrs)?; if overrides.name.is_some() && overrides.transparent { return Err(Error::new_spanned( @@ -22,7 +22,10 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { )); } - let name = overrides.name.unwrap_or_else(|| input.ident.to_string()); + let name = overrides + .name + .take() + .unwrap_or_else(|| input.ident.to_string()); let (accepts_body, to_sql_body) = if overrides.transparent { match input.data { @@ -47,7 +50,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { let variants = data .variants .iter() - .map(Variant::parse) + .map(|variant| Variant::parse(&overrides, variant)) .collect::, _>>()?; ( accepts::enum_body(&name, &variants), @@ -69,7 +72,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { let fields = fields .named .iter() - .map(Field::parse) + .map(|field| Field::parse(&overrides, field)) .collect::, _>>()?; ( accepts::composite_body(&name, "ToSql", &fields),