Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a container-level rename_all attribute to FromSql/ToSql #952

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions postgres-derive-test/src/composites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,48 @@ fn name_overrides() {
);
}

#[test]
fn rename_all_overrides() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
#[postgres(name = "inventory_item", rename_all = "SCREAMING_SNAKE_CASE")]
struct InventoryItem {
name: String,
supplier_id: i32,
price: Option<f64>,
}

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.batch_execute(
"CREATE TYPE pg_temp.inventory_item AS (
\"NAME\" TEXT,
\"SUPPLIER_ID\" INT,
\"PRICE\" DOUBLE PRECISION
);",
)
.unwrap();

let item = InventoryItem {
name: "foobar".to_owned(),
supplier_id: 100,
price: Some(15.50),
};

let item_null = InventoryItem {
name: "foobar".to_owned(),
supplier_id: 100,
price: None,
};

test_type(
&mut conn,
"inventory_item",
&[
(item, "ROW('foobar', 100, 15.50)"),
(item_null, "ROW('foobar', 100, NULL)"),
],
);
}

#[test]
fn wrong_name() {
#[derive(FromSql, ToSql, Debug, PartialEq)]
Expand Down
29 changes: 29 additions & 0 deletions postgres-derive-test/src/enums.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,35 @@ fn name_overrides() {
);
}

#[test]
fn rename_all_overrides() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
#[postgres(name = "mood", rename_all = "snake_case")]
enum Mood {
Sad,
#[postgres(name = "okay")]
Ok,
Happy,
}

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute(
"CREATE TYPE pg_temp.mood AS ENUM ('sad', 'okay', 'happy')",
&[],
)
.unwrap();

test_type(
&mut conn,
"mood",
&[
(Mood::Sad, "'sad'"),
(Mood::Ok, "'okay'"),
(Mood::Happy, "'happy'"),
],
);
}

#[test]
fn wrong_name() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
Expand Down
4 changes: 4 additions & 0 deletions postgres-derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ proc-macro = true
test = false

[dependencies]
convert_case = "0.6"
itertools = "0.10"
syn = "1.0"
proc-macro2 = "1.0"
strum = "0.24"
strum_macros = "0.24"
quote = "1.0"
19 changes: 11 additions & 8 deletions postgres-derive/src/composites.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use syn::{
TypeParamBound,
};

use crate::overrides::Overrides;
use crate::{field_variant_overrides::FieldVariantOverrides, struct_overrides::StructOverrides};

pub struct Field {
pub name: String,
Expand All @@ -13,17 +13,20 @@ pub struct Field {
}

impl Field {
pub fn parse(raw: &syn::Field) -> Result<Field, Error> {
let overrides = Overrides::extract(&raw.attrs)?;
pub fn parse(struct_overrides: &StructOverrides, raw: &syn::Field) -> Result<Field, Error> {
use convert_case::Casing;

let mut overrides = FieldVariantOverrides::extract(&raw.attrs)?;

let ident = raw.ident.as_ref().unwrap().clone();
Ok(Field {
name: overrides.name.unwrap_or_else(|| {
name: overrides.name.take().unwrap_or_else(|| {
let name = ident.to_string();
match name.strip_prefix("r#") {
Some(name) => name.to_string(),
None => name,
}
let name = name.strip_prefix("r#").map(String::from).unwrap_or(name);
struct_overrides
.rename_all
.map(|case| name.to_case(case))
.unwrap_or(name)
}),
ident,
type_: raw.ty.clone(),
Expand Down
17 changes: 13 additions & 4 deletions postgres-derive/src/enums.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
use syn::{Error, Fields, Ident};

use crate::overrides::Overrides;
use crate::{field_variant_overrides::FieldVariantOverrides, struct_overrides::StructOverrides};

pub struct Variant {
pub ident: Ident,
pub name: String,
}

impl Variant {
pub fn parse(raw: &syn::Variant) -> Result<Variant, Error> {
pub fn parse(struct_overrides: &StructOverrides, raw: &syn::Variant) -> Result<Variant, Error> {
use convert_case::Casing;

match raw.fields {
Fields::Unit => {}
_ => {
Expand All @@ -19,10 +21,17 @@ impl Variant {
}
}

let overrides = Overrides::extract(&raw.attrs)?;
let mut overrides = FieldVariantOverrides::extract(&raw.attrs)?;
Ok(Variant {
ident: raw.ident.clone(),
name: overrides.name.unwrap_or_else(|| raw.ident.to_string()),
name: overrides.name.take().unwrap_or_else(|| {
let name = raw.ident.to_string();
let name = name.strip_prefix("r#").map(String::from).unwrap_or(name);
struct_overrides
.rename_all
.map(|case| name.to_case(case))
.unwrap_or(name)
}),
})
}
}
49 changes: 49 additions & 0 deletions postgres-derive/src/field_variant_overrides.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
use syn::{Attribute, Error, Lit, Meta, NestedMeta};

#[derive(Default)]
pub struct FieldVariantOverrides {
pub name: Option<String>,
}

impl FieldVariantOverrides {
pub fn extract(attrs: &[Attribute]) -> Result<FieldVariantOverrides, Error> {
let mut overrides: FieldVariantOverrides = Default::default();

for attr in attrs {
let attr = attr.parse_meta()?;

if !attr.path().is_ident("postgres") {
continue;
}

let list = match attr {
Meta::List(ref list) => list,
bad => return Err(Error::new_spanned(bad, "expected a #[postgres(...)]")),
};

for item in &list.nested {
match item {
NestedMeta::Meta(Meta::NameValue(meta)) => {
if meta.path.is_ident("name") {
let value = match &meta.lit {
Lit::Str(s) => s.value(),
bad => {
return Err(Error::new_spanned(
bad,
"expected a string literal",
))
}
};
overrides.name = Some(value);
} else {
return Err(Error::new_spanned(&meta.path, "unknown override"));
}
}
bad => return Err(Error::new_spanned(bad, "unknown attribute")),
}
}
}

Ok(overrides)
}
}
13 changes: 8 additions & 5 deletions postgres-derive/src/fromsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ use crate::accepts;
use crate::composites::Field;
use crate::composites::{append_generic_bound, new_derive_path};
use crate::enums::Variant;
use crate::overrides::Overrides;
use crate::struct_overrides::StructOverrides;

pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
let overrides = Overrides::extract(&input.attrs)?;
let mut overrides = StructOverrides::extract(&input.attrs)?;

if overrides.name.is_some() && overrides.transparent {
return Err(Error::new_spanned(
Expand All @@ -24,7 +24,10 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
));
}

let name = overrides.name.unwrap_or_else(|| input.ident.to_string());
let name = overrides
.name
.take()
.unwrap_or_else(|| input.ident.to_string());

let (accepts_body, to_sql_body) = if overrides.transparent {
match input.data {
Expand All @@ -51,7 +54,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
let variants = data
.variants
.iter()
.map(Variant::parse)
.map(|variant| Variant::parse(&overrides, variant))
.collect::<Result<Vec<_>, _>>()?;
(
accepts::enum_body(&name, &variants),
Expand All @@ -75,7 +78,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
let fields = fields
.named
.iter()
.map(Field::parse)
.map(|field| Field::parse(&overrides, field))
.collect::<Result<Vec<_>, _>>()?;
(
accepts::composite_body(&name, "FromSql", &fields),
Expand Down
4 changes: 3 additions & 1 deletion postgres-derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ use syn::parse_macro_input;
mod accepts;
mod composites;
mod enums;
mod field_variant_overrides;
mod fromsql;
mod overrides;
mod rename_rule;
mod struct_overrides;
mod tosql;

#[proc_macro_derive(ToSql, attributes(postgres))]
Expand Down
57 changes: 0 additions & 57 deletions postgres-derive/src/overrides.rs

This file was deleted.

36 changes: 36 additions & 0 deletions postgres-derive/src/rename_rule.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
use strum_macros::{Display, EnumIter, EnumString};

#[derive(Clone, Copy, Display, EnumIter, EnumString)]
pub enum RenameRule {
#[strum(serialize = "camelCase")]
Camel,
#[strum(serialize = "kebab-case")]
Kebab,
#[strum(serialize = "lowercase")]
Lower,
#[strum(serialize = "PascalCase")]
Pascal,
#[strum(serialize = "SCREAMING-KEBAB-CASE")]
ScreamingKebab,
#[strum(serialize = "SCREAMING_SNAKE_CASE")]
ScreamingSnake,
#[strum(serialize = "snake_case")]
Snake,
#[strum(serialize = "UPPERCASE")]
Upper,
}

impl From<RenameRule> for convert_case::Case {
fn from(rule: RenameRule) -> Self {
match rule {
RenameRule::Camel => Self::Camel,
RenameRule::Kebab => Self::Kebab,
RenameRule::Lower => Self::Lower,
RenameRule::Pascal => Self::Pascal,
RenameRule::ScreamingKebab => Self::UpperKebab,
RenameRule::ScreamingSnake => Self::UpperSnake,
RenameRule::Snake => Self::Snake,
RenameRule::Upper => Self::Upper,
}
}
}
Loading