diff --git a/examples/github/examples/github.rs b/examples/github/examples/github.rs index 707d79f0..634b5a4e 100644 --- a/examples/github/examples/github.rs +++ b/examples/github/examples/github.rs @@ -45,6 +45,7 @@ fn main() -> Result<(), anyhow::Error> { let variables = repo_view::Variables { owner: owner.to_string(), name: name.to_string(), + with_issues: true, }; let client = Client::builder() @@ -81,6 +82,7 @@ fn main() -> Result<(), anyhow::Error> { .repository .expect("missing repository") .issues + .unwrap() .nodes .expect("issue nodes is null") .iter() diff --git a/examples/github/examples/query_1.graphql b/examples/github/examples/query_1.graphql index f134c7c7..a7ae2645 100644 --- a/examples/github/examples/query_1.graphql +++ b/examples/github/examples/query_1.graphql @@ -1,10 +1,10 @@ -query RepoView($owner: String!, $name: String!) { +query RepoView($owner: String!, $name: String!, $withIssues: Boolean!) { repository(owner: $owner, name: $name) { homepageUrl stargazers { totalCount } - issues(first: 20, states: OPEN) { + issues(first: 20, states: OPEN) @include(if: $withIssues) { nodes { title comments { diff --git a/graphql_client_codegen/src/codegen.rs b/graphql_client_codegen/src/codegen.rs index e33bb1b1..797607e9 100644 --- a/graphql_client_codegen/src/codegen.rs +++ b/graphql_client_codegen/src/codegen.rs @@ -199,10 +199,14 @@ fn render_variable_field_type( let safe_name = shared::keyword_replace(normalized_name.clone()); let full_name = Ident::new(safe_name.as_ref(), Span::call_site()); - decorate_type(&full_name, &variable.r#type.qualifiers) + decorate_type(&full_name, &variable.r#type.qualifiers, false) } -fn decorate_type(ident: &Ident, qualifiers: &[GraphqlTypeQualifier]) -> TokenStream { +fn decorate_type( + ident: &Ident, + qualifiers: &[GraphqlTypeQualifier], + skip_or_include: bool, +) -> TokenStream { let mut qualified = quote!(#ident); let mut non_null = false; @@ -233,7 +237,8 @@ fn decorate_type(ident: &Ident, qualifiers: &[GraphqlTypeQualifier]) -> TokenStr // If we are in nullable context at the end of the iteration, we wrap the whole // type with an Option. - if !non_null { + // This can also happen if the field has a @skip or @include directive + if !non_null || skip_or_include { qualified = quote!(Option<#qualified>); } @@ -345,3 +350,26 @@ where #(#fields,)* }) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn decorate_type_emits_optional_when_skip_or_include() { + let ident = Ident::new("Test", Span::call_site()); + let qualifiers = [GraphqlTypeQualifier::Required, GraphqlTypeQualifier::List]; + let rendered_type = decorate_type(&ident, &qualifiers, true).to_string(); + + assert_eq!(rendered_type, "Option < Vec < Option < Test >> >"); + } + + #[test] + fn decorate_type_emits_required_when_no_skip_or_include() { + let ident = Ident::new("Test", Span::call_site()); + let qualifiers = [GraphqlTypeQualifier::Required, GraphqlTypeQualifier::List]; + let rendered_type = decorate_type(&ident, &qualifiers, false).to_string(); + + assert_eq!(rendered_type, "Vec < Option < Test >>"); + } +} diff --git a/graphql_client_codegen/src/codegen/inputs.rs b/graphql_client_codegen/src/codegen/inputs.rs index d8cc1080..fd58736e 100644 --- a/graphql_client_codegen/src/codegen/inputs.rs +++ b/graphql_client_codegen/src/codegen/inputs.rs @@ -74,7 +74,7 @@ fn generate_struct( None }; let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site()); - let field_type_tokens = super::decorate_type(&type_name, &field_type.qualifiers); + let field_type_tokens = super::decorate_type(&type_name, &field_type.qualifiers, false); let field_type = if field_type .id .as_input_id() @@ -127,7 +127,7 @@ fn generate_enum( let mut qualifiers = vec![GraphqlTypeQualifier::Required]; qualifiers.extend(field_type.qualifiers.iter().cloned()); - let field_type_tokens = super::decorate_type(&type_name, &qualifiers); + let field_type_tokens = super::decorate_type(&type_name, &qualifiers, false); let field_type = if field_type .id .as_input_id() diff --git a/graphql_client_codegen/src/codegen/selection.rs b/graphql_client_codegen/src/codegen/selection.rs index ec1703b8..63258b32 100644 --- a/graphql_client_codegen/src/codegen/selection.rs +++ b/graphql_client_codegen/src/codegen/selection.rs @@ -12,8 +12,7 @@ use crate::{ }, schema::{Schema, TypeId}, type_qualifiers::GraphqlTypeQualifier, - GraphQLClientCodegenOptions, - GeneralError, + GeneralError, GraphQLClientCodegenOptions, }; use heck::*; use proc_macro2::{Ident, Span, TokenStream}; @@ -43,12 +42,27 @@ pub(crate) fn render_response_data_fields<'a>( if let Some(custom_response_type) = options.custom_response_type() { if operation.selection_set.len() == 1 { let selection_id = operation.selection_set[0]; - let selection_field = query.query.get_selection(selection_id).as_selected_field() - .ok_or_else(|| GeneralError(format!("Custom response type {custom_response_type} will only work on fields")))?; - calculate_custom_response_type_selection(&mut expanded_selection, response_data_type_id, custom_response_type, selection_id, selection_field); + let selection_field = query + .query + .get_selection(selection_id) + .as_selected_field() + .ok_or_else(|| { + GeneralError(format!( + "Custom response type {custom_response_type} will only work on fields" + )) + })?; + calculate_custom_response_type_selection( + &mut expanded_selection, + response_data_type_id, + custom_response_type, + selection_id, + selection_field, + ); return Ok(expanded_selection); } else { - return Err(GeneralError(format!("Custom response type {custom_response_type} requires single selection field"))); + return Err(GeneralError(format!( + "Custom response type {custom_response_type} requires single selection field" + ))); } } @@ -68,8 +82,8 @@ fn calculate_custom_response_type_selection<'a>( struct_id: ResponseTypeId, custom_response_type: &'a String, selection_id: SelectionId, - field: &'a SelectedField) -{ + field: &'a SelectedField, +) { let (graphql_name, rust_name) = context.field_name(field); let struct_name_string = full_path_prefix(selection_id, context.query); let field = context.query.schema.get_field(field.field_id); @@ -82,6 +96,7 @@ fn calculate_custom_response_type_selection<'a>( flatten: false, boxed: false, deprecation: field.deprecation(), + skip_or_include: false, }); let struct_id = context.push_type(ExpandedType { @@ -127,7 +142,7 @@ pub(super) fn render_fragment<'a>( /// A sub-selection set (spread) on one of the variants of a union or interface. enum VariantSelection<'a> { InlineFragment(&'a InlineFragment), - FragmentSpread((ResolvedFragmentId, &'a ResolvedFragment)), + FragmentSpread((ResolvedFragmentId, &'a ResolvedFragment), bool), } impl<'a> VariantSelection<'a> { @@ -141,7 +156,7 @@ impl<'a> VariantSelection<'a> { Selection::InlineFragment(inline_fragment) => { Some(VariantSelection::InlineFragment(inline_fragment)) } - Selection::FragmentSpread(fragment_id) => { + Selection::FragmentSpread(fragment_id, has_skip_or_include) => { let fragment = query.query.get_fragment(*fragment_id); if fragment.on == type_id { @@ -149,7 +164,10 @@ impl<'a> VariantSelection<'a> { None } else { // The selection is on one of the variants of the type. - Some(VariantSelection::FragmentSpread((*fragment_id, fragment))) + Some(VariantSelection::FragmentSpread( + (*fragment_id, fragment), + *has_skip_or_include, + )) } } Selection::Field(_) | Selection::Typename => None, @@ -159,7 +177,7 @@ impl<'a> VariantSelection<'a> { fn variant_type_id(&self) -> TypeId { match self { VariantSelection::InlineFragment(f) => f.type_id, - VariantSelection::FragmentSpread((_id, f)) => f.on, + VariantSelection::FragmentSpread((_id, f), _) => f.on, } } } @@ -174,7 +192,7 @@ fn calculate_selection<'a>( // If the selection only contains a fragment, replace the selection with // that fragment. if selection_set.len() == 1 { - if let Selection::FragmentSpread(fragment_id) = + if let Selection::FragmentSpread(fragment_id, _) = context.query.query.get_selection(selection_set[0]) { let fragment = context.query.query.get_fragment(*fragment_id); @@ -252,7 +270,7 @@ fn calculate_selection<'a>( let struct_id = context.push_type(expanded_type); if variant_selections.len() == 1 { - if let VariantSelection::FragmentSpread((fragment_id, fragment)) = + if let VariantSelection::FragmentSpread((fragment_id, fragment), _) = variant_selections[0].2 { context.push_type_alias(TypeAlias { @@ -275,17 +293,20 @@ fn calculate_selection<'a>( options, ); } - VariantSelection::FragmentSpread((fragment_id, fragment)) => context - .push_field(ExpandedField { - field_type: fragment.name.as_str().into(), - field_type_qualifiers: &[GraphqlTypeQualifier::Required], - flatten: true, - graphql_name: None, - rust_name: fragment.name.to_snake_case().into(), - struct_id, - deprecation: None, - boxed: fragment_is_recursive(*fragment_id, context.query.query), - }), + VariantSelection::FragmentSpread( + (fragment_id, fragment), + has_skip_or_include, + ) => context.push_field(ExpandedField { + field_type: fragment.name.as_str().into(), + field_type_qualifiers: &[GraphqlTypeQualifier::Required], + flatten: true, + graphql_name: None, + rust_name: fragment.name.to_snake_case().into(), + struct_id, + deprecation: None, + boxed: fragment_is_recursive(*fragment_id, context.query.query), + skip_or_include: *has_skip_or_include, + }), } } } else { @@ -331,6 +352,7 @@ fn calculate_selection<'a>( flatten: false, deprecation: schema_field.deprecation(), boxed: false, + skip_or_include: field.skip_or_include, }); } TypeId::Scalar(scalar) => { @@ -348,6 +370,7 @@ fn calculate_selection<'a>( flatten: false, deprecation: schema_field.deprecation(), boxed: false, + skip_or_include: field.skip_or_include, }); } TypeId::Object(_) | TypeId::Interface(_) | TypeId::Union(_) => { @@ -362,6 +385,7 @@ fn calculate_selection<'a>( flatten: false, boxed: false, deprecation: schema_field.deprecation(), + skip_or_include: field.skip_or_include, }); let type_id = context.push_type(ExpandedType { @@ -381,7 +405,7 @@ fn calculate_selection<'a>( } Selection::Typename => (), Selection::InlineFragment(_inline) => (), - Selection::FragmentSpread(fragment_id) => { + Selection::FragmentSpread(fragment_id, has_skip_or_include) => { // Here we only render fragments that are directly on the type // itself, and not on one of its variants. @@ -407,6 +431,7 @@ fn calculate_selection<'a>( flatten: true, deprecation: None, boxed: fragment_is_recursive(*fragment_id, context.query.query), + skip_or_include: *has_skip_or_include, }); // We stop here, because the structs for the fragments are generated separately, to @@ -434,6 +459,7 @@ struct ExpandedField<'a> { flatten: bool, deprecation: Option>, boxed: bool, + skip_or_include: bool, } impl ExpandedField<'_> { @@ -442,6 +468,7 @@ impl ExpandedField<'_> { let qualified_type = decorate_type( &Ident::new(&self.field_type, Span::call_site()), self.field_type_qualifiers, + self.skip_or_include, ); let qualified_type = if self.boxed { diff --git a/graphql_client_codegen/src/query.rs b/graphql_client_codegen/src/query.rs index 71d0798f..53781a7a 100644 --- a/graphql_client_codegen/src/query.rs +++ b/graphql_client_codegen/src/query.rs @@ -18,6 +18,7 @@ use crate::{ StoredInputType, StoredScalar, TypeId, UnionId, }, }; +use graphql_parser::query::Directive; use std::{ collections::{BTreeMap, BTreeSet}, fmt::Display, @@ -279,7 +280,13 @@ where )) })?; - let id = query.push_selection(Selection::FragmentSpread(fragment_id), parent); + let id = query.push_selection( + Selection::FragmentSpread( + fragment_id, + has_skip_or_include(fragment_spread.directives.as_slice()), + ), + parent, + ); parent.add_to_selection_set(query, id); } @@ -323,6 +330,7 @@ where alias: field.alias.as_ref().map(|alias| alias.as_ref().into()), field_id, selection_set: Vec::with_capacity(selection_set.items.len()), + skip_or_include: has_skip_or_include(field.directives.as_slice()), }), parent, ); @@ -352,7 +360,13 @@ where )) })?; - let id = query.push_selection(Selection::FragmentSpread(fragment_id), parent); + let id = query.push_selection( + Selection::FragmentSpread( + fragment_id, + has_skip_or_include(fragment_spread.directives.as_slice()), + ), + parent, + ); parent.add_to_selection_set(query, id); } @@ -362,6 +376,15 @@ where Ok(()) } +fn has_skip_or_include<'doc, T>(directives: &[Directive<'doc, T>]) -> bool +where + T: graphql_parser::query::Text<'doc>, +{ + directives + .iter() + .any(|directive| ["skip", "include"].contains(&directive.name.as_ref())) +} + fn resolve_selection<'doc, T>( ctx: &mut Query, on: TypeId, @@ -725,7 +748,7 @@ pub(crate) fn all_used_types(operation_id: OperationId, query: &BoundQuery<'_>) pub(crate) fn full_path_prefix(selection_id: SelectionId, query: &BoundQuery<'_>) -> String { let mut path = match query.query.get_selection(selection_id) { - Selection::FragmentSpread(_) | Selection::InlineFragment(_) => Vec::new(), + Selection::FragmentSpread(..) | Selection::InlineFragment(_) => Vec::new(), selection => vec![selection.to_path_segment(query)], }; diff --git a/graphql_client_codegen/src/query/selection.rs b/graphql_client_codegen/src/query/selection.rs index 018e6f22..871e8bdb 100644 --- a/graphql_client_codegen/src/query/selection.rs +++ b/graphql_client_codegen/src/query/selection.rs @@ -14,7 +14,7 @@ pub(super) fn validate_type_conditions( let selection = query.query.get_selection(selection_id); let selected_type = match selection { - Selection::FragmentSpread(fragment_id) => query.query.get_fragment(*fragment_id).on, + Selection::FragmentSpread(fragment_id, _) => query.query.get_fragment(*fragment_id).on, Selection::InlineFragment(inline_fragment) => inline_fragment.type_id, _ => return Ok(()), }; @@ -146,7 +146,7 @@ impl SelectionParent { pub(crate) enum Selection { Field(SelectedField), InlineFragment(InlineFragment), - FragmentSpread(ResolvedFragmentId), + FragmentSpread(ResolvedFragmentId, bool), Typename, } @@ -184,7 +184,7 @@ impl Selection { selection.collect_used_types(used_types, query); } } - Selection::FragmentSpread(fragment_id) => { + Selection::FragmentSpread(fragment_id, _) => { // This is necessary to avoid infinite recursion. if used_types.fragments.contains(fragment_id) { return; @@ -204,7 +204,7 @@ impl Selection { pub(crate) fn contains_fragment(&self, fragment_id: ResolvedFragmentId, query: &Query) -> bool { match self { - Selection::FragmentSpread(id) => *id == fragment_id, + Selection::FragmentSpread(id, _) => *id == fragment_id, _ => self.subselection().iter().any(|selection_id| { query .get_selection(*selection_id) @@ -257,6 +257,7 @@ pub(crate) struct SelectedField { pub(crate) alias: Option, pub(crate) field_id: StoredFieldId, pub(crate) selection_set: Vec, + pub(crate) skip_or_include: bool, } impl SelectedField { diff --git a/graphql_client_codegen/src/query/validation.rs b/graphql_client_codegen/src/query/validation.rs index ee0a48f5..63722f92 100644 --- a/graphql_client_codegen/src/query/validation.rs +++ b/graphql_client_codegen/src/query/validation.rs @@ -56,7 +56,7 @@ fn selection_set_contains_type_name( match selection { Selection::Typename => return true, - Selection::FragmentSpread(fragment_id) => { + Selection::FragmentSpread(fragment_id, _) => { let fragment = query.get_fragment(*fragment_id); if fragment.on == parent_type_id && selection_set_contains_type_name(fragment.on, &fragment.selection_set, query) diff --git a/graphql_client_codegen/src/tests/mod.rs b/graphql_client_codegen/src/tests/mod.rs index aaed3e5d..5856da44 100644 --- a/graphql_client_codegen/src/tests/mod.rs +++ b/graphql_client_codegen/src/tests/mod.rs @@ -8,6 +8,9 @@ const KEYWORDS_SCHEMA_PATH: &str = "keywords_schema.graphql"; const FOOBARS_QUERY: &str = include_str!("foobars_query.graphql"); const FOOBARS_SCHEMA_PATH: &str = "foobars_schema.graphql"; +const POSTS_QUERY: &str = include_str!("posts_query.graphql"); +const POSTS_SCHEMA_PATH: &str = "posts_schema.graphql"; + fn build_schema_path(path: &str) -> PathBuf { std::env::current_dir() .unwrap() @@ -62,7 +65,8 @@ fn blended_custom_types_works() { match r { Ok(_) => { // Variables and returns should be replaced with custom types - assert!(generated_code.contains("pub type SearchQuerySearch = external_crate :: Transaction")); + assert!(generated_code + .contains("pub type SearchQuerySearch = external_crate :: Transaction")); assert!(generated_code.contains("pub type extern_ = external_crate :: ID")); } Err(e) => { @@ -154,3 +158,30 @@ fn skip_serializing_none_should_generate_serde_skip_serializing() { } }; } + +#[test] +fn generate_option_for_skip_and_include() { + let query_string = POSTS_QUERY; + let schema_path = build_schema_path(POSTS_SCHEMA_PATH); + + let options = GraphQLClientCodegenOptions::new(CodegenMode::Cli); + + let generated_tokens = + generate_module_token_stream_from_string(query_string, &schema_path, options) + .expect("Generate posts module"); + + let generated_code = generated_tokens.to_string(); + + let r: syn::parse::Result = syn::parse2(generated_tokens); + + match r { + Ok(_) => { + println!("{}", generated_code); + let expected_type = "pub struct UserQueryUser { pub name : String , pub email : Option < String > , pub friends : Option < Vec < UserQueryUserFriends > > , # [serde (flatten)] pub with_post_fragment : Option < WithPostFragment > , }"; + assert!(generated_code.contains(expected_type)); + } + Err(e) => { + panic!("Error: {}\n Generated content: {}\n", e, &generated_code); + } + } +} diff --git a/graphql_client_codegen/src/tests/posts_query.graphql b/graphql_client_codegen/src/tests/posts_query.graphql new file mode 100644 index 00000000..675c0d0b --- /dev/null +++ b/graphql_client_codegen/src/tests/posts_query.graphql @@ -0,0 +1,22 @@ +query UserQuery( + $id: ID!, + $includeEmail: Boolean!, + $skipFriends: Boolean!, + $skipPosts: Boolean! +) { + user(id: $id) { + name + email @include(if: $includeEmail) + friends @skip(if: $skipFriends) { + name + } + ...WithPostFragment @skip(if: $skipPosts) + } +} + +fragment WithPostFragment on User { + posts { + title + body + } +} diff --git a/graphql_client_codegen/src/tests/posts_schema.graphql b/graphql_client_codegen/src/tests/posts_schema.graphql new file mode 100644 index 00000000..65298dcd --- /dev/null +++ b/graphql_client_codegen/src/tests/posts_schema.graphql @@ -0,0 +1,32 @@ +schema { + query: Query + mutation: Mutation +} + +type Query { + user(id: ID!): User + posts(first: Int, after: ID): [Post!]! +} + +type User { + id: ID! + name: String! + email: String! + friends: [User!]! + posts: [Post!]! +} + +type Post { + id: ID! + title: String! + body: String! + author: User! + comments: [Comment!]! +} + +type Comment { + id: ID! + title: String! + body: String! + author: User! +}