diff --git a/graphql_client_cli/src/generate.rs b/graphql_client_cli/src/generate.rs index 2871887e..1a36d0cf 100644 --- a/graphql_client_cli/src/generate.rs +++ b/graphql_client_cli/src/generate.rs @@ -23,6 +23,8 @@ pub(crate) struct CliCodegenParams { pub custom_scalars_module: Option, pub fragments_other_variant: bool, pub external_enums: Option>, + pub custom_variable_types: Option, + pub custom_response_type: Option, } const WARNING_SUPPRESSION: &str = "#![allow(clippy::all, warnings)]"; @@ -41,6 +43,8 @@ pub(crate) fn generate_code(params: CliCodegenParams) -> CliResult<()> { custom_scalars_module, fragments_other_variant, external_enums, + custom_variable_types, + custom_response_type, } = params; let deprecation_strategy = deprecation_strategy.as_ref().and_then(|s| s.parse().ok()); @@ -89,6 +93,14 @@ pub(crate) fn generate_code(params: CliCodegenParams) -> CliResult<()> { options.set_custom_scalars_module(custom_scalars_module); } + + if let Some(custom_variable_types) = custom_variable_types { + options.set_custom_variable_types(custom_variable_types.split(",").map(String::from).collect()); + } + + if let Some(custom_response_type) = custom_response_type { + options.set_custom_response_type(custom_response_type); + } let gen = generate_module_token_stream(query_path.clone(), &schema_path, options) .map_err(|err| Error::message(format!("Error generating module code: {}", err)))?; diff --git a/graphql_client_cli/src/main.rs b/graphql_client_cli/src/main.rs index f7292935..6721953b 100644 --- a/graphql_client_cli/src/main.rs +++ b/graphql_client_cli/src/main.rs @@ -94,6 +94,14 @@ enum Cli { /// List of externally defined enum types. Type names must match those used in the schema exactly #[clap(long = "external-enums", num_args(0..), action(clap::ArgAction::Append))] external_enums: Option>, + /// Custom variable types to use + /// --custom-variable-types='external_crate::MyStruct,external_crate::MyStruct2' + #[clap(long = "custom-variable_types")] + custom_variable_types: Option, + /// Custom response type to use + /// --custom-response-type='external_crate::MyResponse' + #[clap(long = "custom-response-type")] + custom_response_type: Option, }, } @@ -131,7 +139,9 @@ fn main() -> CliResult<()> { selected_operation, custom_scalars_module, fragments_other_variant, - external_enums, + external_enums, + custom_variable_types, + custom_response_type, } => generate::generate_code(generate::CliCodegenParams { query_path, schema_path, @@ -145,6 +155,8 @@ fn main() -> CliResult<()> { custom_scalars_module, fragments_other_variant, external_enums, + custom_variable_types, + custom_response_type, }), } } diff --git a/graphql_client_codegen/src/codegen.rs b/graphql_client_codegen/src/codegen.rs index 007dce86..e33bb1b1 100644 --- a/graphql_client_codegen/src/codegen.rs +++ b/graphql_client_codegen/src/codegen.rs @@ -42,7 +42,7 @@ pub(crate) fn response_for_query( generate_variables_struct(operation_id, &variable_derives, options, &query); let definitions = - render_response_data_fields(operation_id, options, &query).render(&response_derives); + render_response_data_fields(operation_id, options, &query)?.render(&response_derives); let q = quote! { use #serde::{Serialize, Deserialize}; diff --git a/graphql_client_codegen/src/codegen/inputs.rs b/graphql_client_codegen/src/codegen/inputs.rs index adf0b2db..d8cc1080 100644 --- a/graphql_client_codegen/src/codegen/inputs.rs +++ b/graphql_client_codegen/src/codegen/inputs.rs @@ -15,10 +15,18 @@ pub(super) fn generate_input_object_definitions( variable_derives: &impl quote::ToTokens, query: &BoundQuery<'_>, ) -> Vec { + let custom_variable_types = options.custom_variable_types(); all_used_types .inputs(query.schema) - .map(|(_input_id, input)| { - if input.is_one_of { + .map(|(input_id, input)| { + let custom_variable_type = query.query.variables.iter() + .enumerate() + .find(|(_, v) | v.r#type.id.as_input_id().is_some_and(|i| i == input_id)) + .map(|(index, _)| custom_variable_types.get(index)) + .flatten(); + if let Some(custom_type) = custom_variable_type { + generate_type_def(input, options, custom_type) + } else if input.is_one_of { generate_enum(input, options, variable_derives, query) } else { generate_struct(input, options, variable_derives, query) @@ -27,6 +35,18 @@ pub(super) fn generate_input_object_definitions( .collect() } +fn generate_type_def( + input: &StoredInputType, + options: &GraphQLClientCodegenOptions, + custom_type: &String, +) -> TokenStream { + let custom_type = syn::parse_str::(custom_type).unwrap(); + let normalized_name = options.normalization().input_name(input.name.as_str()); + let safe_name = keyword_replace(normalized_name); + let struct_name = Ident::new(safe_name.as_ref(), Span::call_site()); + quote!(pub type #struct_name = #custom_type;) +} + fn generate_struct( input: &StoredInputType, options: &GraphQLClientCodegenOptions, diff --git a/graphql_client_codegen/src/codegen/selection.rs b/graphql_client_codegen/src/codegen/selection.rs index 6b6677dd..ec1703b8 100644 --- a/graphql_client_codegen/src/codegen/selection.rs +++ b/graphql_client_codegen/src/codegen/selection.rs @@ -13,17 +13,19 @@ use crate::{ schema::{Schema, TypeId}, type_qualifiers::GraphqlTypeQualifier, GraphQLClientCodegenOptions, + GeneralError, }; use heck::*; use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens}; use std::borrow::Cow; +use syn::Path; pub(crate) fn render_response_data_fields<'a>( operation_id: OperationId, options: &'a GraphQLClientCodegenOptions, query: &'a BoundQuery<'a>, -) -> ExpandedSelection<'a> { +) -> Result, GeneralError> { let operation = query.query.get_operation(operation_id); let mut expanded_selection = ExpandedSelection { query, @@ -38,6 +40,18 @@ pub(crate) fn render_response_data_fields<'a>( name: Cow::Borrowed("ResponseData"), }); + if let Some(custom_response_type) = options.custom_response_type() { + if operation.selection_set.len() == 1 { + let selection_id = operation.selection_set[0]; + let selection_field = query.query.get_selection(selection_id).as_selected_field() + .ok_or_else(|| GeneralError(format!("Custom response type {custom_response_type} will only work on fields")))?; + calculate_custom_response_type_selection(&mut expanded_selection, response_data_type_id, custom_response_type, selection_id, selection_field); + return Ok(expanded_selection); + } else { + return Err(GeneralError(format!("Custom response type {custom_response_type} requires single selection field"))); + } + } + calculate_selection( &mut expanded_selection, &operation.selection_set, @@ -46,7 +60,38 @@ pub(crate) fn render_response_data_fields<'a>( options, ); - expanded_selection + Ok(expanded_selection) +} + +fn calculate_custom_response_type_selection<'a>( + context: &mut ExpandedSelection<'a>, + struct_id: ResponseTypeId, + custom_response_type: &'a String, + selection_id: SelectionId, + field: &'a SelectedField) +{ + let (graphql_name, rust_name) = context.field_name(field); + let struct_name_string = full_path_prefix(selection_id, context.query); + let field = context.query.schema.get_field(field.field_id); + context.push_field(ExpandedField { + struct_id, + graphql_name: Some(graphql_name), + rust_name, + field_type_qualifiers: &field.r#type.qualifiers, + field_type: struct_name_string.clone().into(), + flatten: false, + boxed: false, + deprecation: field.deprecation(), + }); + + let struct_id = context.push_type(ExpandedType { + name: struct_name_string.into(), + }); + context.push_type_alias(TypeAlias { + name: custom_response_type.as_str(), + struct_id, + boxed: false, + }); } pub(super) fn render_fragment<'a>( @@ -557,14 +602,14 @@ impl<'a> ExpandedSelection<'a> { // If the type is aliased, stop here. if let Some(alias) = self.aliases.iter().find(|alias| alias.struct_id == type_id) { - let fragment_name = Ident::new(alias.name, Span::call_site()); - let fragment_name = if alias.boxed { - quote!(Box<#fragment_name>) + let type_name = syn::parse_str::(alias.name).unwrap(); + let type_name = if alias.boxed { + quote!(Box<#type_name>) } else { - quote!(#fragment_name) + quote!(#type_name) }; let item = quote! { - pub type #struct_name = #fragment_name; + pub type #struct_name = #type_name; }; items.push(item); continue; diff --git a/graphql_client_codegen/src/codegen_options.rs b/graphql_client_codegen/src/codegen_options.rs index 006ae727..7b3d8d73 100644 --- a/graphql_client_codegen/src/codegen_options.rs +++ b/graphql_client_codegen/src/codegen_options.rs @@ -49,6 +49,10 @@ pub struct GraphQLClientCodegenOptions { skip_serializing_none: bool, /// Path to the serde crate. serde_path: syn::Path, + /// list of custom type paths to use for input variables + custom_variable_types: Option>, + /// Custom response type path + custom_response_type: Option, } impl GraphQLClientCodegenOptions { @@ -71,6 +75,8 @@ impl GraphQLClientCodegenOptions { fragments_other_variant: Default::default(), skip_serializing_none: Default::default(), serde_path: syn::parse_quote!(::serde), + custom_variable_types: Default::default(), + custom_response_type: Default::default(), } } @@ -138,6 +144,26 @@ impl GraphQLClientCodegenOptions { self.response_derives = Some(response_derives); } + /// Type use as the response type + pub fn custom_response_type(&self) -> Option<&String> { + self.custom_response_type.as_ref() + } + + /// Type use as the response type + pub fn set_custom_response_type(&mut self, response_type: String) { + self.custom_response_type = Some(response_type); + } + + /// list of custom type paths to use for input variables + pub fn custom_variable_types(&self) -> Vec { + self.custom_variable_types.clone().unwrap_or_default() + } + + /// list of custom type paths to use for input variables + pub fn set_custom_variable_types(&mut self, variables_types: Vec) { + self.custom_variable_types = Some(variables_types); + } + /// The deprecation strategy to adopt. pub fn set_deprecation_strategy(&mut self, deprecation_strategy: DeprecationStrategy) { self.deprecation_strategy = Some(deprecation_strategy); diff --git a/graphql_client_codegen/src/query.rs b/graphql_client_codegen/src/query.rs index bb2b6318..71d0798f 100644 --- a/graphql_client_codegen/src/query.rs +++ b/graphql_client_codegen/src/query.rs @@ -58,7 +58,7 @@ pub(crate) struct ResolvedFragmentId(u32); #[allow(dead_code)] #[derive(Debug, Clone, Copy)] -pub(crate) struct VariableId(u32); +pub(crate) struct VariableId(pub u32); pub(crate) fn resolve<'doc, T>( schema: &Schema, @@ -512,7 +512,7 @@ pub(crate) struct Query { operations: Vec, selection_parent_idx: BTreeMap, selections: Vec, - variables: Vec, + pub(crate) variables: Vec, } impl Query { diff --git a/graphql_client_codegen/src/tests/mod.rs b/graphql_client_codegen/src/tests/mod.rs index 263001c6..aaed3e5d 100644 --- a/graphql_client_codegen/src/tests/mod.rs +++ b/graphql_client_codegen/src/tests/mod.rs @@ -42,6 +42,35 @@ fn schema_with_keywords_works() { }; } +#[test] +fn blended_custom_types_works() { + let query_string = KEYWORDS_QUERY; + let schema_path = build_schema_path(KEYWORDS_SCHEMA_PATH); + + let mut options = GraphQLClientCodegenOptions::new(CodegenMode::Cli); + options.set_custom_response_type("external_crate::Transaction".to_string()); + options.set_custom_variable_types(vec!["external_crate::ID".to_string()]); + + let generated_tokens = + generate_module_token_stream_from_string(query_string, &schema_path, options) + .expect("Generate keywords module"); + + let generated_code = generated_tokens.to_string(); + + // Parse generated code. Variables and returns should be replaced with custom types + let r: syn::parse::Result = syn::parse2(generated_tokens); + match r { + Ok(_) => { + // Variables and returns should be replaced with custom types + assert!(generated_code.contains("pub type SearchQuerySearch = external_crate :: Transaction")); + assert!(generated_code.contains("pub type extern_ = external_crate :: ID")); + } + Err(e) => { + panic!("Error: {}\n Generated content: {}\n", e, &generated_code); + } + }; +} + #[test] fn fragments_other_variant_should_generate_unknown_other_variant() { let query_string = FOOBARS_QUERY; diff --git a/graphql_query_derive/src/attributes.rs b/graphql_query_derive/src/attributes.rs index 99158217..535914fb 100644 --- a/graphql_query_derive/src/attributes.rs +++ b/graphql_query_derive/src/attributes.rs @@ -295,4 +295,44 @@ mod test { vec!["Direction", "DistanceUnit"], ); } + + #[test] + fn test_custom_variable_types() { + let input = r#" + #[derive(Serialize, Deserialize, Debug)] + #[derive(GraphQLQuery)] + #[graphql( + schema_path = "x", + query_path = "x", + variable_types("extern_crate::Var1", "extern_crate::Var2"), + )] + struct MyQuery; + "#; + let parsed: syn::DeriveInput = syn::parse_str(input).unwrap(); + + assert_eq!( + extract_attr_list(&parsed, "variable_types").ok().unwrap(), + vec!["extern_crate::Var1", "extern_crate::Var2"], + ); + } + + #[test] + fn test_custom_response_type() { + let input = r#" + #[derive(Serialize, Deserialize, Debug)] + #[derive(GraphQLQuery)] + #[graphql( + schema_path = "x", + query_path = "x", + response_type = "extern_crate::Resp", + )] + struct MyQuery; + "#; + let parsed: syn::DeriveInput = syn::parse_str(input).unwrap(); + + assert_eq!( + extract_attr(&parsed, "response_type").ok().unwrap(), + "extern_crate::Resp", + ); + } } diff --git a/graphql_query_derive/src/lib.rs b/graphql_query_derive/src/lib.rs index 0eea2c16..c6a7eca3 100644 --- a/graphql_query_derive/src/lib.rs +++ b/graphql_query_derive/src/lib.rs @@ -65,6 +65,8 @@ fn build_graphql_client_derive_options( let extern_enums = attributes::extract_attr_list(input, "extern_enums").ok(); let fragments_other_variant: bool = attributes::extract_fragments_other_variant(input); let skip_serializing_none: bool = attributes::extract_skip_serializing_none(input); + let custom_variable_types = attributes::extract_attr_list(input, "variable_types").ok(); + let custom_response_type = attributes::extract_attr(input, "response_type").ok(); let mut options = GraphQLClientCodegenOptions::new(CodegenMode::Derive); options.set_query_file(query_path); @@ -101,6 +103,14 @@ fn build_graphql_client_derive_options( options.set_extern_enums(extern_enums); } + if let Some(custom_variable_types) = custom_variable_types { + options.set_custom_variable_types(custom_variable_types); + } + + if let Some(custom_response_type) = custom_response_type { + options.set_custom_response_type(custom_response_type); + } + options.set_struct_ident(input.ident.clone()); options.set_module_visibility(input.vis.clone()); options.set_operation_name(input.ident.to_string());