diff --git a/Cargo.toml b/Cargo.toml index f0fbc01..7e63d77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,3 +24,4 @@ quote = "1.0.23" syn = { version = "1.0.107", features = ["full", "extra-traits"] } proc-macro2 = "1.0.51" heck = "0.4.1" +itertools = "0.11.0" \ No newline at end of file diff --git a/src/build.rs b/src/build.rs index bfc782b..7df7966 100644 --- a/src/build.rs +++ b/src/build.rs @@ -25,7 +25,7 @@ fn add_bound(generics: &mut Generics, bound: TypeParamBound) { // * Foo -> Foo // * Foo(Bar, Baz) -> Foo(var1, var2) // * Foo { x: i32, y: i32 } -> Foo { x, y } -fn variant_to_unary_pat(variant: &Variant) -> TokenStream2 { +pub(crate) fn variant_to_unary_pat(variant: &Variant) -> TokenStream2 { let ident = &variant.ident; match &variant.fields { diff --git a/src/enum.rs b/src/enum.rs index a13bccd..22b8ef0 100644 --- a/src/enum.rs +++ b/src/enum.rs @@ -4,6 +4,7 @@ use syn::{punctuated::Punctuated, Generics, Ident, Token, TypeParamBound, Varian use crate::{extractor::Extractor, iter::BoxedIter, param::Param, Derive}; +#[derive(Clone)] pub struct Enum { pub ident: Ident, pub variants: Punctuated<Variant, Token![,]>, diff --git a/src/lib.rs b/src/lib.rs index 61ad220..f577e85 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,16 +7,18 @@ mod extractor; mod iter; mod param; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use derive::Derive; use heck::ToSnakeCase; +use itertools::Itertools; use proc_macro::TokenStream; use proc_macro2::Ident; use quote::quote; use r#enum::Enum; use syn::{ - parse_macro_input, Attribute, AttributeArgs, DeriveInput, Field, Meta, NestedMeta, Path, Type, + parse_macro_input, Attribute, AttributeArgs, DeriveInput, Field, Meta, MetaList, MetaNameValue, + NestedMeta, Type, Variant, Path, punctuated::Punctuated, }; const SUBENUM: &str = "subenum"; @@ -130,6 +132,128 @@ pub fn subenum(args: TokenStream, tokens: TokenStream) -> TokenStream { e.compute_generics(&input.generics); } + let mut sibling_conversions = Vec::new(); + for (sibling1, sibling2) in enums.values().cloned().tuple_combinations() { + let sibling1_variants_hash_set: HashSet<Variant> = sibling1.variants.into_iter().collect(); + let sibling2_variants_hash_set: HashSet<Variant> = sibling2.variants.into_iter().collect(); + + let intersection = sibling1_variants_hash_set.intersection(&sibling2_variants_hash_set).collect::<Vec<&Variant>>(); + if intersection.is_empty() { + continue; + } + + let sibling1_ident = sibling1.ident; + let (_, sibling1_ty, _) = sibling1.generics.split_for_impl(); + + let sibling2_ident = sibling2.ident; + let (_, sibling2_ty, _) = sibling2.generics.split_for_impl(); + + let mut combined_generics = sibling1.generics.params.clone().into_iter().collect::<HashSet<syn::GenericParam>>(); + combined_generics.extend(sibling2.generics.params.clone().into_iter().collect::<HashSet<syn::GenericParam>>()); + + let combined_generics = syn::Generics { + lt_token: Some(syn::token::Lt::default()), + params: Punctuated::from_iter(combined_generics.into_iter()), + gt_token: Some(syn::token::Gt::default()), + where_clause: None, + }; + + let mut combined_where = sibling1.generics.where_clause.clone() + .map(|where_clause| where_clause.predicates.into_iter().collect::<HashSet<syn::WherePredicate>>()).unwrap_or_default(); + combined_where.extend(sibling2.generics.where_clause.clone() + .map(|where_clause| where_clause.predicates.into_iter().collect::<HashSet<syn::WherePredicate>>()).unwrap_or_default()); + + let combined_where = Some(syn::WhereClause { + where_token: syn::token::Where::default(), + predicates: Punctuated::from_iter(combined_where.into_iter()) + }); + + let pats: Vec<proc_macro2::TokenStream> = intersection.iter().map(|variant| build::variant_to_unary_pat(*variant)).collect(); + + let sibling1_to_sibling2 = if sibling1_variants_hash_set.len() == intersection.len() { + let from_sibling1_arms = pats + .iter() + .map(|pat| quote!(#sibling1_ident::#pat => #sibling2_ident::#pat)); + + quote! { + #[automatically_derived] + impl #combined_generics std::convert::From<#sibling1_ident #sibling1_ty> for #sibling2_ident #sibling2_ty #combined_where { + fn from(sibling: #sibling1_ident #sibling1_ty) -> Self { + match sibling { + #(#from_sibling1_arms),* + } + } + } + } + } else { + let try_from_sibling1_arms = pats + .iter() + .map(|pat| quote!(#sibling1_ident::#pat => Ok(#sibling2_ident::#pat))); + + let error = quote::format_ident!("{sibling2_ident}ConvertError"); + + quote! { + #[automatically_derived] + impl #combined_generics std::convert::TryFrom<#sibling1_ident #sibling1_ty> for #sibling2_ident #sibling2_ty #combined_where { + type Error = #error; + + fn try_from(sibling: #sibling1_ident #sibling1_ty) -> Result<Self, Self::Error> { + match sibling { + #(#try_from_sibling1_arms),*, + _ => Err(#error) + } + } + } + } + }; + + let sibling2_to_sibling1 = if sibling2_variants_hash_set.len() == intersection.len() { + let from_sibling2_arms = pats + .iter() + .map(|pat| quote!(#sibling2_ident::#pat => #sibling1_ident::#pat)); + + quote! { + #[automatically_derived] + impl #combined_generics std::convert::From<#sibling2_ident #sibling2_ty> for #sibling1_ident #sibling1_ty #combined_where { + fn from(sibling: #sibling2_ident #sibling2_ty) -> Self { + match sibling { + #(#from_sibling2_arms),* + } + } + } + } + } else { + let try_from_sibling2_arms = pats + .iter() + .map(|pat| quote!(#sibling2_ident::#pat => Ok(#sibling1_ident::#pat))); + + let error = quote::format_ident!("{sibling1_ident}ConvertError"); + + quote! { + #[automatically_derived] + impl #combined_generics std::convert::TryFrom<#sibling2_ident #sibling2_ty> for #sibling1_ident #sibling1_ty #combined_where { + type Error = #error; + + fn try_from(sibling: #sibling2_ident #sibling2_ty) -> Result<Self, Self::Error> { + match sibling { + #(#try_from_sibling2_arms),*, + _ => Err(#error) + } + } + } + } + }; + + sibling_conversions.push( + quote!{ + #sibling1_to_sibling2 + + #sibling2_to_sibling1 + } + ); + } + + let enums: Vec<_> = enums.into_values().map(|e| e.build(&input, data)).collect(); sanitize_input(&mut input); @@ -138,6 +262,8 @@ pub fn subenum(args: TokenStream, tokens: TokenStream) -> TokenStream { #input #(#enums)* + + #(#sibling_conversions)* ) .into() } diff --git a/tests/test.rs b/tests/test.rs new file mode 100644 index 0000000..0e19c06 --- /dev/null +++ b/tests/test.rs @@ -0,0 +1,14 @@ +use subenum::subenum; + +#[subenum(EnumB, EnumC, EnumD)] +#[derive(PartialEq, Debug, Clone)] +enum EnumA<T, U> where +T: From<String>, +U: Copy { + #[subenum(EnumC, EnumD)] + A, + #[subenum(EnumB, EnumC)] + B(T), + #[subenum(EnumB)] + C(U) +} \ No newline at end of file