Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sibling enum conversions #22

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ quote = "1.0.23"
syn = { version = "1.0.107", features = ["full", "extra-traits"] }
proc-macro2 = "1.0.51"
heck = "0.4.1"
itertools = "0.11.0"
2 changes: 1 addition & 1 deletion src/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fn add_bound(generics: &mut Generics, bound: TypeParamBound) {
// * Foo -> Foo
// * Foo(Bar, Baz) -> Foo(var1, var2)
// * Foo { x: i32, y: i32 } -> Foo { x, y }
fn variant_to_unary_pat(variant: &Variant) -> TokenStream2 {
pub(crate) fn variant_to_unary_pat(variant: &Variant) -> TokenStream2 {
let ident = &variant.ident;

match &variant.fields {
Expand Down
1 change: 1 addition & 0 deletions src/enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use syn::{punctuated::Punctuated, Generics, Ident, Token, TypeParamBound, Varian

use crate::{extractor::Extractor, iter::BoxedIter, param::Param, Derive};

#[derive(Clone)]
pub struct Enum {
pub ident: Ident,
pub variants: Punctuated<Variant, Token![,]>,
Expand Down
130 changes: 128 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@ mod extractor;
mod iter;
mod param;

use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

use derive::Derive;
use heck::ToSnakeCase;
use itertools::Itertools;
use proc_macro::TokenStream;
use proc_macro2::Ident;
use quote::quote;
use r#enum::Enum;
use syn::{
parse_macro_input, Attribute, AttributeArgs, DeriveInput, Field, Meta, NestedMeta, Path, Type,
parse_macro_input, Attribute, AttributeArgs, DeriveInput, Field, Meta, MetaList, MetaNameValue,
NestedMeta, Type, Variant, Path, punctuated::Punctuated,
};

const SUBENUM: &str = "subenum";
Expand Down Expand Up @@ -130,6 +132,128 @@ pub fn subenum(args: TokenStream, tokens: TokenStream) -> TokenStream {
e.compute_generics(&input.generics);
}

let mut sibling_conversions = Vec::new();
for (sibling1, sibling2) in enums.values().cloned().tuple_combinations() {
let sibling1_variants_hash_set: HashSet<Variant> = sibling1.variants.into_iter().collect();
let sibling2_variants_hash_set: HashSet<Variant> = sibling2.variants.into_iter().collect();

let intersection = sibling1_variants_hash_set.intersection(&sibling2_variants_hash_set).collect::<Vec<&Variant>>();
if intersection.is_empty() {
continue;
}

let sibling1_ident = sibling1.ident;
let (_, sibling1_ty, _) = sibling1.generics.split_for_impl();

let sibling2_ident = sibling2.ident;
let (_, sibling2_ty, _) = sibling2.generics.split_for_impl();

let mut combined_generics = sibling1.generics.params.clone().into_iter().collect::<HashSet<syn::GenericParam>>();
combined_generics.extend(sibling2.generics.params.clone().into_iter().collect::<HashSet<syn::GenericParam>>());

let combined_generics = syn::Generics {
lt_token: Some(syn::token::Lt::default()),
params: Punctuated::from_iter(combined_generics.into_iter()),
gt_token: Some(syn::token::Gt::default()),
where_clause: None,
};

let mut combined_where = sibling1.generics.where_clause.clone()
.map(|where_clause| where_clause.predicates.into_iter().collect::<HashSet<syn::WherePredicate>>()).unwrap_or_default();
combined_where.extend(sibling2.generics.where_clause.clone()
.map(|where_clause| where_clause.predicates.into_iter().collect::<HashSet<syn::WherePredicate>>()).unwrap_or_default());

let combined_where = Some(syn::WhereClause {
where_token: syn::token::Where::default(),
predicates: Punctuated::from_iter(combined_where.into_iter())
});

let pats: Vec<proc_macro2::TokenStream> = intersection.iter().map(|variant| build::variant_to_unary_pat(*variant)).collect();

let sibling1_to_sibling2 = if sibling1_variants_hash_set.len() == intersection.len() {
let from_sibling1_arms = pats
.iter()
.map(|pat| quote!(#sibling1_ident::#pat => #sibling2_ident::#pat));

quote! {
#[automatically_derived]
impl #combined_generics std::convert::From<#sibling1_ident #sibling1_ty> for #sibling2_ident #sibling2_ty #combined_where {
fn from(sibling: #sibling1_ident #sibling1_ty) -> Self {
match sibling {
#(#from_sibling1_arms),*
}
}
}
}
} else {
let try_from_sibling1_arms = pats
.iter()
.map(|pat| quote!(#sibling1_ident::#pat => Ok(#sibling2_ident::#pat)));

let error = quote::format_ident!("{sibling2_ident}ConvertError");

quote! {
#[automatically_derived]
impl #combined_generics std::convert::TryFrom<#sibling1_ident #sibling1_ty> for #sibling2_ident #sibling2_ty #combined_where {
type Error = #error;

fn try_from(sibling: #sibling1_ident #sibling1_ty) -> Result<Self, Self::Error> {
match sibling {
#(#try_from_sibling1_arms),*,
_ => Err(#error)
}
}
}
}
};

let sibling2_to_sibling1 = if sibling2_variants_hash_set.len() == intersection.len() {
let from_sibling2_arms = pats
.iter()
.map(|pat| quote!(#sibling2_ident::#pat => #sibling1_ident::#pat));

quote! {
#[automatically_derived]
impl #combined_generics std::convert::From<#sibling2_ident #sibling2_ty> for #sibling1_ident #sibling1_ty #combined_where {
fn from(sibling: #sibling2_ident #sibling2_ty) -> Self {
match sibling {
#(#from_sibling2_arms),*
}
}
}
}
} else {
let try_from_sibling2_arms = pats
.iter()
.map(|pat| quote!(#sibling2_ident::#pat => Ok(#sibling1_ident::#pat)));

let error = quote::format_ident!("{sibling1_ident}ConvertError");

quote! {
#[automatically_derived]
impl #combined_generics std::convert::TryFrom<#sibling2_ident #sibling2_ty> for #sibling1_ident #sibling1_ty #combined_where {
type Error = #error;

fn try_from(sibling: #sibling2_ident #sibling2_ty) -> Result<Self, Self::Error> {
match sibling {
#(#try_from_sibling2_arms),*,
_ => Err(#error)
}
}
}
}
};

sibling_conversions.push(
quote!{
#sibling1_to_sibling2

#sibling2_to_sibling1
}
);
}


let enums: Vec<_> = enums.into_values().map(|e| e.build(&input, data)).collect();

sanitize_input(&mut input);
Expand All @@ -138,6 +262,8 @@ pub fn subenum(args: TokenStream, tokens: TokenStream) -> TokenStream {
#input

#(#enums)*

#(#sibling_conversions)*
)
.into()
}
14 changes: 14 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use subenum::subenum;

#[subenum(EnumB, EnumC, EnumD)]
#[derive(PartialEq, Debug, Clone)]
enum EnumA<T, U> where
T: From<String>,
U: Copy {
#[subenum(EnumC, EnumD)]
A,
#[subenum(EnumB, EnumC)]
B(T),
#[subenum(EnumB)]
C(U)
}