Skip to content

Add support for custom variable and response types #536

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions graphql_client_cli/src/generate.rs
Original file line number Diff line number Diff line change
@@ -23,6 +23,8 @@ pub(crate) struct CliCodegenParams {
pub custom_scalars_module: Option<String>,
pub fragments_other_variant: bool,
pub external_enums: Option<Vec<String>>,
pub custom_variable_types: Option<String>,
pub custom_response_type: Option<String>,
}

const WARNING_SUPPRESSION: &str = "#![allow(clippy::all, warnings)]";
@@ -41,6 +43,8 @@ pub(crate) fn generate_code(params: CliCodegenParams) -> CliResult<()> {
custom_scalars_module,
fragments_other_variant,
external_enums,
custom_variable_types,
custom_response_type,
} = params;

let deprecation_strategy = deprecation_strategy.as_ref().and_then(|s| s.parse().ok());
@@ -89,6 +93,14 @@ pub(crate) fn generate_code(params: CliCodegenParams) -> CliResult<()> {

options.set_custom_scalars_module(custom_scalars_module);
}

if let Some(custom_variable_types) = custom_variable_types {
options.set_custom_variable_types(custom_variable_types.split(",").map(String::from).collect());
}

if let Some(custom_response_type) = custom_response_type {
options.set_custom_response_type(custom_response_type);
}

let gen = generate_module_token_stream(query_path.clone(), &schema_path, options)
.map_err(|err| Error::message(format!("Error generating module code: {}", err)))?;
14 changes: 13 additions & 1 deletion graphql_client_cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -94,6 +94,14 @@ enum Cli {
/// List of externally defined enum types. Type names must match those used in the schema exactly
#[clap(long = "external-enums", num_args(0..), action(clap::ArgAction::Append))]
external_enums: Option<Vec<String>>,
/// Custom variable types to use
/// --custom-variable-types='external_crate::MyStruct,external_crate::MyStruct2'
#[clap(long = "custom-variable_types")]
custom_variable_types: Option<String>,
/// Custom response type to use
/// --custom-response-type='external_crate::MyResponse'
#[clap(long = "custom-response-type")]
custom_response_type: Option<String>,
},
}

@@ -131,7 +139,9 @@ fn main() -> CliResult<()> {
selected_operation,
custom_scalars_module,
fragments_other_variant,
external_enums,
external_enums,
custom_variable_types,
custom_response_type,
} => generate::generate_code(generate::CliCodegenParams {
query_path,
schema_path,
@@ -145,6 +155,8 @@ fn main() -> CliResult<()> {
custom_scalars_module,
fragments_other_variant,
external_enums,
custom_variable_types,
custom_response_type,
}),
}
}
2 changes: 1 addition & 1 deletion graphql_client_codegen/src/codegen.rs
Original file line number Diff line number Diff line change
@@ -42,7 +42,7 @@ pub(crate) fn response_for_query(
generate_variables_struct(operation_id, &variable_derives, options, &query);

let definitions =
render_response_data_fields(operation_id, options, &query).render(&response_derives);
render_response_data_fields(operation_id, options, &query)?.render(&response_derives);

let q = quote! {
use #serde::{Serialize, Deserialize};
24 changes: 22 additions & 2 deletions graphql_client_codegen/src/codegen/inputs.rs
Original file line number Diff line number Diff line change
@@ -15,10 +15,18 @@ pub(super) fn generate_input_object_definitions(
variable_derives: &impl quote::ToTokens,
query: &BoundQuery<'_>,
) -> Vec<TokenStream> {
let custom_variable_types = options.custom_variable_types();
all_used_types
.inputs(query.schema)
.map(|(_input_id, input)| {
if input.is_one_of {
.map(|(input_id, input)| {
let custom_variable_type = query.query.variables.iter()
.enumerate()
.find(|(_, v) | v.r#type.id.as_input_id().is_some_and(|i| i == input_id))
.map(|(index, _)| custom_variable_types.get(index))
.flatten();
if let Some(custom_type) = custom_variable_type {
generate_type_def(input, options, custom_type)
} else if input.is_one_of {
generate_enum(input, options, variable_derives, query)
} else {
generate_struct(input, options, variable_derives, query)
@@ -27,6 +35,18 @@ pub(super) fn generate_input_object_definitions(
.collect()
}

fn generate_type_def(
input: &StoredInputType,
options: &GraphQLClientCodegenOptions,
custom_type: &String,
) -> TokenStream {
let custom_type = syn::parse_str::<syn::Path>(custom_type).unwrap();
let normalized_name = options.normalization().input_name(input.name.as_str());
let safe_name = keyword_replace(normalized_name);
let struct_name = Ident::new(safe_name.as_ref(), Span::call_site());
quote!(pub type #struct_name = #custom_type;)
}

fn generate_struct(
input: &StoredInputType,
options: &GraphQLClientCodegenOptions,
59 changes: 52 additions & 7 deletions graphql_client_codegen/src/codegen/selection.rs
Original file line number Diff line number Diff line change
@@ -13,17 +13,19 @@ use crate::{
schema::{Schema, TypeId},
type_qualifiers::GraphqlTypeQualifier,
GraphQLClientCodegenOptions,
GeneralError,
};
use heck::*;
use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, ToTokens};
use std::borrow::Cow;
use syn::Path;

pub(crate) fn render_response_data_fields<'a>(
operation_id: OperationId,
options: &'a GraphQLClientCodegenOptions,
query: &'a BoundQuery<'a>,
) -> ExpandedSelection<'a> {
) -> Result<ExpandedSelection<'a>, GeneralError> {
let operation = query.query.get_operation(operation_id);
let mut expanded_selection = ExpandedSelection {
query,
@@ -38,6 +40,18 @@ pub(crate) fn render_response_data_fields<'a>(
name: Cow::Borrowed("ResponseData"),
});

if let Some(custom_response_type) = options.custom_response_type() {
if operation.selection_set.len() == 1 {
let selection_id = operation.selection_set[0];
let selection_field = query.query.get_selection(selection_id).as_selected_field()
.ok_or_else(|| GeneralError(format!("Custom response type {custom_response_type} will only work on fields")))?;
calculate_custom_response_type_selection(&mut expanded_selection, response_data_type_id, custom_response_type, selection_id, selection_field);
return Ok(expanded_selection);
} else {
return Err(GeneralError(format!("Custom response type {custom_response_type} requires single selection field")));
}
}

calculate_selection(
&mut expanded_selection,
&operation.selection_set,
@@ -46,7 +60,38 @@ pub(crate) fn render_response_data_fields<'a>(
options,
);

expanded_selection
Ok(expanded_selection)
}

fn calculate_custom_response_type_selection<'a>(
context: &mut ExpandedSelection<'a>,
struct_id: ResponseTypeId,
custom_response_type: &'a String,
selection_id: SelectionId,
field: &'a SelectedField)
{
let (graphql_name, rust_name) = context.field_name(field);
let struct_name_string = full_path_prefix(selection_id, context.query);
let field = context.query.schema.get_field(field.field_id);
context.push_field(ExpandedField {
struct_id,
graphql_name: Some(graphql_name),
rust_name,
field_type_qualifiers: &field.r#type.qualifiers,
field_type: struct_name_string.clone().into(),
flatten: false,
boxed: false,
deprecation: field.deprecation(),
});

let struct_id = context.push_type(ExpandedType {
name: struct_name_string.into(),
});
context.push_type_alias(TypeAlias {
name: custom_response_type.as_str(),
struct_id,
boxed: false,
});
}

pub(super) fn render_fragment<'a>(
@@ -557,14 +602,14 @@ impl<'a> ExpandedSelection<'a> {

// If the type is aliased, stop here.
if let Some(alias) = self.aliases.iter().find(|alias| alias.struct_id == type_id) {
let fragment_name = Ident::new(alias.name, Span::call_site());
let fragment_name = if alias.boxed {
quote!(Box<#fragment_name>)
let type_name = syn::parse_str::<Path>(alias.name).unwrap();
let type_name = if alias.boxed {
quote!(Box<#type_name>)
} else {
quote!(#fragment_name)
quote!(#type_name)
};
let item = quote! {
pub type #struct_name = #fragment_name;
pub type #struct_name = #type_name;
};
items.push(item);
continue;
26 changes: 26 additions & 0 deletions graphql_client_codegen/src/codegen_options.rs
Original file line number Diff line number Diff line change
@@ -49,6 +49,10 @@ pub struct GraphQLClientCodegenOptions {
skip_serializing_none: bool,
/// Path to the serde crate.
serde_path: syn::Path,
/// list of custom type paths to use for input variables
custom_variable_types: Option<Vec<String>>,
/// Custom response type path
custom_response_type: Option<String>,
}

impl GraphQLClientCodegenOptions {
@@ -71,6 +75,8 @@ impl GraphQLClientCodegenOptions {
fragments_other_variant: Default::default(),
skip_serializing_none: Default::default(),
serde_path: syn::parse_quote!(::serde),
custom_variable_types: Default::default(),
custom_response_type: Default::default(),
}
}

@@ -138,6 +144,26 @@ impl GraphQLClientCodegenOptions {
self.response_derives = Some(response_derives);
}

/// Type use as the response type
pub fn custom_response_type(&self) -> Option<&String> {
self.custom_response_type.as_ref()
}

/// Type use as the response type
pub fn set_custom_response_type(&mut self, response_type: String) {
self.custom_response_type = Some(response_type);
}

/// list of custom type paths to use for input variables
pub fn custom_variable_types(&self) -> Vec<String> {
self.custom_variable_types.clone().unwrap_or_default()
}

/// list of custom type paths to use for input variables
pub fn set_custom_variable_types(&mut self, variables_types: Vec<String>) {
self.custom_variable_types = Some(variables_types);
}

/// The deprecation strategy to adopt.
pub fn set_deprecation_strategy(&mut self, deprecation_strategy: DeprecationStrategy) {
self.deprecation_strategy = Some(deprecation_strategy);
4 changes: 2 additions & 2 deletions graphql_client_codegen/src/query.rs
Original file line number Diff line number Diff line change
@@ -58,7 +58,7 @@ pub(crate) struct ResolvedFragmentId(u32);

#[allow(dead_code)]
#[derive(Debug, Clone, Copy)]
pub(crate) struct VariableId(u32);
pub(crate) struct VariableId(pub u32);

pub(crate) fn resolve<'doc, T>(
schema: &Schema,
@@ -512,7 +512,7 @@ pub(crate) struct Query {
operations: Vec<ResolvedOperation>,
selection_parent_idx: BTreeMap<SelectionId, SelectionParent>,
selections: Vec<Selection>,
variables: Vec<ResolvedVariable>,
pub(crate) variables: Vec<ResolvedVariable>,
}

impl Query {
29 changes: 29 additions & 0 deletions graphql_client_codegen/src/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -42,6 +42,35 @@ fn schema_with_keywords_works() {
};
}

#[test]
fn blended_custom_types_works() {
let query_string = KEYWORDS_QUERY;
let schema_path = build_schema_path(KEYWORDS_SCHEMA_PATH);

let mut options = GraphQLClientCodegenOptions::new(CodegenMode::Cli);
options.set_custom_response_type("external_crate::Transaction".to_string());
options.set_custom_variable_types(vec!["external_crate::ID".to_string()]);

let generated_tokens =
generate_module_token_stream_from_string(query_string, &schema_path, options)
.expect("Generate keywords module");

let generated_code = generated_tokens.to_string();

// Parse generated code. Variables and returns should be replaced with custom types
let r: syn::parse::Result<proc_macro2::TokenStream> = syn::parse2(generated_tokens);
match r {
Ok(_) => {
// Variables and returns should be replaced with custom types
assert!(generated_code.contains("pub type SearchQuerySearch = external_crate :: Transaction"));
assert!(generated_code.contains("pub type extern_ = external_crate :: ID"));
}
Err(e) => {
panic!("Error: {}\n Generated content: {}\n", e, &generated_code);
}
};
}

#[test]
fn fragments_other_variant_should_generate_unknown_other_variant() {
let query_string = FOOBARS_QUERY;
40 changes: 40 additions & 0 deletions graphql_query_derive/src/attributes.rs
Original file line number Diff line number Diff line change
@@ -295,4 +295,44 @@ mod test {
vec!["Direction", "DistanceUnit"],
);
}

#[test]
fn test_custom_variable_types() {
let input = r#"
#[derive(Serialize, Deserialize, Debug)]
#[derive(GraphQLQuery)]
#[graphql(
schema_path = "x",
query_path = "x",
variable_types("extern_crate::Var1", "extern_crate::Var2"),
)]
struct MyQuery;
"#;
let parsed: syn::DeriveInput = syn::parse_str(input).unwrap();

assert_eq!(
extract_attr_list(&parsed, "variable_types").ok().unwrap(),
vec!["extern_crate::Var1", "extern_crate::Var2"],
);
}

#[test]
fn test_custom_response_type() {
let input = r#"
#[derive(Serialize, Deserialize, Debug)]
#[derive(GraphQLQuery)]
#[graphql(
schema_path = "x",
query_path = "x",
response_type = "extern_crate::Resp",
)]
struct MyQuery;
"#;
let parsed: syn::DeriveInput = syn::parse_str(input).unwrap();

assert_eq!(
extract_attr(&parsed, "response_type").ok().unwrap(),
"extern_crate::Resp",
);
}
}
10 changes: 10 additions & 0 deletions graphql_query_derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -65,6 +65,8 @@ fn build_graphql_client_derive_options(
let extern_enums = attributes::extract_attr_list(input, "extern_enums").ok();
let fragments_other_variant: bool = attributes::extract_fragments_other_variant(input);
let skip_serializing_none: bool = attributes::extract_skip_serializing_none(input);
let custom_variable_types = attributes::extract_attr_list(input, "variable_types").ok();
let custom_response_type = attributes::extract_attr(input, "response_type").ok();

let mut options = GraphQLClientCodegenOptions::new(CodegenMode::Derive);
options.set_query_file(query_path);
@@ -101,6 +103,14 @@ fn build_graphql_client_derive_options(
options.set_extern_enums(extern_enums);
}

if let Some(custom_variable_types) = custom_variable_types {
options.set_custom_variable_types(custom_variable_types);
}

if let Some(custom_response_type) = custom_response_type {
options.set_custom_response_type(custom_response_type);
}

options.set_struct_ident(input.ident.clone());
options.set_module_visibility(input.vis.clone());
options.set_operation_name(input.ident.to_string());