Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/github/examples/github.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ fn main() -> Result<(), anyhow::Error> {
let variables = repo_view::Variables {
owner: owner.to_string(),
name: name.to_string(),
with_issues: true,
};

let client = Client::builder()
Expand Down Expand Up @@ -81,6 +82,7 @@ fn main() -> Result<(), anyhow::Error> {
.repository
.expect("missing repository")
.issues
.unwrap()
.nodes
.expect("issue nodes is null")
.iter()
Expand Down
4 changes: 2 additions & 2 deletions examples/github/examples/query_1.graphql
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
query RepoView($owner: String!, $name: String!) {
query RepoView($owner: String!, $name: String!, $withIssues: Boolean!) {
repository(owner: $owner, name: $name) {
homepageUrl
stargazers {
totalCount
}
issues(first: 20, states: OPEN) {
issues(first: 20, states: OPEN) @include(if: $withIssues) {
nodes {
title
comments {
Expand Down
34 changes: 31 additions & 3 deletions graphql_client_codegen/src/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,14 @@ fn render_variable_field_type(
let safe_name = shared::keyword_replace(normalized_name.clone());
let full_name = Ident::new(safe_name.as_ref(), Span::call_site());

decorate_type(&full_name, &variable.r#type.qualifiers)
decorate_type(&full_name, &variable.r#type.qualifiers, false)
}

fn decorate_type(ident: &Ident, qualifiers: &[GraphqlTypeQualifier]) -> TokenStream {
fn decorate_type(
ident: &Ident,
qualifiers: &[GraphqlTypeQualifier],
skip_or_include: bool,
) -> TokenStream {
let mut qualified = quote!(#ident);

let mut non_null = false;
Expand Down Expand Up @@ -233,7 +237,8 @@ fn decorate_type(ident: &Ident, qualifiers: &[GraphqlTypeQualifier]) -> TokenStr

// If we are in nullable context at the end of the iteration, we wrap the whole
// type with an Option.
if !non_null {
// This can also happen if the field has a @skip or @include directive
if !non_null || skip_or_include {
qualified = quote!(Option<#qualified>);
}

Expand Down Expand Up @@ -345,3 +350,26 @@ where
#(#fields,)*
})
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn decorate_type_emits_optional_when_skip_or_include() {
let ident = Ident::new("Test", Span::call_site());
let qualifiers = [GraphqlTypeQualifier::Required, GraphqlTypeQualifier::List];
let rendered_type = decorate_type(&ident, &qualifiers, true).to_string();

assert_eq!(rendered_type, "Option < Vec < Option < Test >> >");
}

#[test]
fn decorate_type_emits_required_when_no_skip_or_include() {
let ident = Ident::new("Test", Span::call_site());
let qualifiers = [GraphqlTypeQualifier::Required, GraphqlTypeQualifier::List];
let rendered_type = decorate_type(&ident, &qualifiers, false).to_string();

assert_eq!(rendered_type, "Vec < Option < Test >>");
}
}
4 changes: 2 additions & 2 deletions graphql_client_codegen/src/codegen/inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ fn generate_struct(
None
};
let type_name = Ident::new(normalized_field_type_name.as_ref(), Span::call_site());
let field_type_tokens = super::decorate_type(&type_name, &field_type.qualifiers);
let field_type_tokens = super::decorate_type(&type_name, &field_type.qualifiers, false);
let field_type = if field_type
.id
.as_input_id()
Expand Down Expand Up @@ -127,7 +127,7 @@ fn generate_enum(
let mut qualifiers = vec![GraphqlTypeQualifier::Required];
qualifiers.extend(field_type.qualifiers.iter().cloned());

let field_type_tokens = super::decorate_type(&type_name, &qualifiers);
let field_type_tokens = super::decorate_type(&type_name, &qualifiers, false);
let field_type = if field_type
.id
.as_input_id()
Expand Down
79 changes: 53 additions & 26 deletions graphql_client_codegen/src/codegen/selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ use crate::{
},
schema::{Schema, TypeId},
type_qualifiers::GraphqlTypeQualifier,
GraphQLClientCodegenOptions,
GeneralError,
GeneralError, GraphQLClientCodegenOptions,
};
use heck::*;
use proc_macro2::{Ident, Span, TokenStream};
Expand Down Expand Up @@ -43,12 +42,27 @@ pub(crate) fn render_response_data_fields<'a>(
if let Some(custom_response_type) = options.custom_response_type() {
if operation.selection_set.len() == 1 {
let selection_id = operation.selection_set[0];
let selection_field = query.query.get_selection(selection_id).as_selected_field()
.ok_or_else(|| GeneralError(format!("Custom response type {custom_response_type} will only work on fields")))?;
calculate_custom_response_type_selection(&mut expanded_selection, response_data_type_id, custom_response_type, selection_id, selection_field);
let selection_field = query
.query
.get_selection(selection_id)
.as_selected_field()
.ok_or_else(|| {
GeneralError(format!(
"Custom response type {custom_response_type} will only work on fields"
))
})?;
calculate_custom_response_type_selection(
&mut expanded_selection,
response_data_type_id,
custom_response_type,
selection_id,
selection_field,
);
return Ok(expanded_selection);
} else {
return Err(GeneralError(format!("Custom response type {custom_response_type} requires single selection field")));
return Err(GeneralError(format!(
"Custom response type {custom_response_type} requires single selection field"
)));
}
}

Expand All @@ -68,8 +82,8 @@ fn calculate_custom_response_type_selection<'a>(
struct_id: ResponseTypeId,
custom_response_type: &'a String,
selection_id: SelectionId,
field: &'a SelectedField)
{
field: &'a SelectedField,
) {
let (graphql_name, rust_name) = context.field_name(field);
let struct_name_string = full_path_prefix(selection_id, context.query);
let field = context.query.schema.get_field(field.field_id);
Expand All @@ -82,6 +96,7 @@ fn calculate_custom_response_type_selection<'a>(
flatten: false,
boxed: false,
deprecation: field.deprecation(),
skip_or_include: false,
});

let struct_id = context.push_type(ExpandedType {
Expand Down Expand Up @@ -127,7 +142,7 @@ pub(super) fn render_fragment<'a>(
/// A sub-selection set (spread) on one of the variants of a union or interface.
enum VariantSelection<'a> {
InlineFragment(&'a InlineFragment),
FragmentSpread((ResolvedFragmentId, &'a ResolvedFragment)),
FragmentSpread((ResolvedFragmentId, &'a ResolvedFragment), bool),
}

impl<'a> VariantSelection<'a> {
Expand All @@ -141,15 +156,18 @@ impl<'a> VariantSelection<'a> {
Selection::InlineFragment(inline_fragment) => {
Some(VariantSelection::InlineFragment(inline_fragment))
}
Selection::FragmentSpread(fragment_id) => {
Selection::FragmentSpread(fragment_id, has_skip_or_include) => {
let fragment = query.query.get_fragment(*fragment_id);

if fragment.on == type_id {
// The selection is on the type itself.
None
} else {
// The selection is on one of the variants of the type.
Some(VariantSelection::FragmentSpread((*fragment_id, fragment)))
Some(VariantSelection::FragmentSpread(
(*fragment_id, fragment),
*has_skip_or_include,
))
}
}
Selection::Field(_) | Selection::Typename => None,
Expand All @@ -159,7 +177,7 @@ impl<'a> VariantSelection<'a> {
fn variant_type_id(&self) -> TypeId {
match self {
VariantSelection::InlineFragment(f) => f.type_id,
VariantSelection::FragmentSpread((_id, f)) => f.on,
VariantSelection::FragmentSpread((_id, f), _) => f.on,
}
}
}
Expand All @@ -174,7 +192,7 @@ fn calculate_selection<'a>(
// If the selection only contains a fragment, replace the selection with
// that fragment.
if selection_set.len() == 1 {
if let Selection::FragmentSpread(fragment_id) =
if let Selection::FragmentSpread(fragment_id, _) =
context.query.query.get_selection(selection_set[0])
{
let fragment = context.query.query.get_fragment(*fragment_id);
Expand Down Expand Up @@ -252,7 +270,7 @@ fn calculate_selection<'a>(
let struct_id = context.push_type(expanded_type);

if variant_selections.len() == 1 {
if let VariantSelection::FragmentSpread((fragment_id, fragment)) =
if let VariantSelection::FragmentSpread((fragment_id, fragment), _) =
variant_selections[0].2
{
context.push_type_alias(TypeAlias {
Expand All @@ -275,17 +293,20 @@ fn calculate_selection<'a>(
options,
);
}
VariantSelection::FragmentSpread((fragment_id, fragment)) => context
.push_field(ExpandedField {
field_type: fragment.name.as_str().into(),
field_type_qualifiers: &[GraphqlTypeQualifier::Required],
flatten: true,
graphql_name: None,
rust_name: fragment.name.to_snake_case().into(),
struct_id,
deprecation: None,
boxed: fragment_is_recursive(*fragment_id, context.query.query),
}),
VariantSelection::FragmentSpread(
(fragment_id, fragment),
has_skip_or_include,
) => context.push_field(ExpandedField {
field_type: fragment.name.as_str().into(),
field_type_qualifiers: &[GraphqlTypeQualifier::Required],
flatten: true,
graphql_name: None,
rust_name: fragment.name.to_snake_case().into(),
struct_id,
deprecation: None,
boxed: fragment_is_recursive(*fragment_id, context.query.query),
skip_or_include: *has_skip_or_include,
}),
}
}
} else {
Expand Down Expand Up @@ -331,6 +352,7 @@ fn calculate_selection<'a>(
flatten: false,
deprecation: schema_field.deprecation(),
boxed: false,
skip_or_include: field.skip_or_include,
});
}
TypeId::Scalar(scalar) => {
Expand All @@ -348,6 +370,7 @@ fn calculate_selection<'a>(
flatten: false,
deprecation: schema_field.deprecation(),
boxed: false,
skip_or_include: field.skip_or_include,
});
}
TypeId::Object(_) | TypeId::Interface(_) | TypeId::Union(_) => {
Expand All @@ -362,6 +385,7 @@ fn calculate_selection<'a>(
flatten: false,
boxed: false,
deprecation: schema_field.deprecation(),
skip_or_include: field.skip_or_include,
});

let type_id = context.push_type(ExpandedType {
Expand All @@ -381,7 +405,7 @@ fn calculate_selection<'a>(
}
Selection::Typename => (),
Selection::InlineFragment(_inline) => (),
Selection::FragmentSpread(fragment_id) => {
Selection::FragmentSpread(fragment_id, has_skip_or_include) => {
// Here we only render fragments that are directly on the type
// itself, and not on one of its variants.

Expand All @@ -407,6 +431,7 @@ fn calculate_selection<'a>(
flatten: true,
deprecation: None,
boxed: fragment_is_recursive(*fragment_id, context.query.query),
skip_or_include: *has_skip_or_include,
});

// We stop here, because the structs for the fragments are generated separately, to
Expand Down Expand Up @@ -434,6 +459,7 @@ struct ExpandedField<'a> {
flatten: bool,
deprecation: Option<Option<&'a str>>,
boxed: bool,
skip_or_include: bool,
}

impl ExpandedField<'_> {
Expand All @@ -442,6 +468,7 @@ impl ExpandedField<'_> {
let qualified_type = decorate_type(
&Ident::new(&self.field_type, Span::call_site()),
self.field_type_qualifiers,
self.skip_or_include,
);

let qualified_type = if self.boxed {
Expand Down
29 changes: 26 additions & 3 deletions graphql_client_codegen/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::{
StoredInputType, StoredScalar, TypeId, UnionId,
},
};
use graphql_parser::query::Directive;
use std::{
collections::{BTreeMap, BTreeSet},
fmt::Display,
Expand Down Expand Up @@ -279,7 +280,13 @@ where
))
})?;

let id = query.push_selection(Selection::FragmentSpread(fragment_id), parent);
let id = query.push_selection(
Selection::FragmentSpread(
fragment_id,
has_skip_or_include(fragment_spread.directives.as_slice()),
),
parent,
);

parent.add_to_selection_set(query, id);
}
Expand Down Expand Up @@ -323,6 +330,7 @@ where
alias: field.alias.as_ref().map(|alias| alias.as_ref().into()),
field_id,
selection_set: Vec::with_capacity(selection_set.items.len()),
skip_or_include: has_skip_or_include(field.directives.as_slice()),
}),
parent,
);
Expand Down Expand Up @@ -352,7 +360,13 @@ where
))
})?;

let id = query.push_selection(Selection::FragmentSpread(fragment_id), parent);
let id = query.push_selection(
Selection::FragmentSpread(
fragment_id,
has_skip_or_include(fragment_spread.directives.as_slice()),
),
parent,
);

parent.add_to_selection_set(query, id);
}
Expand All @@ -362,6 +376,15 @@ where
Ok(())
}

fn has_skip_or_include<'doc, T>(directives: &[Directive<'doc, T>]) -> bool
where
T: graphql_parser::query::Text<'doc>,
{
directives
.iter()
.any(|directive| ["skip", "include"].contains(&directive.name.as_ref()))
}

fn resolve_selection<'doc, T>(
ctx: &mut Query,
on: TypeId,
Expand Down Expand Up @@ -725,7 +748,7 @@ pub(crate) fn all_used_types(operation_id: OperationId, query: &BoundQuery<'_>)

pub(crate) fn full_path_prefix(selection_id: SelectionId, query: &BoundQuery<'_>) -> String {
let mut path = match query.query.get_selection(selection_id) {
Selection::FragmentSpread(_) | Selection::InlineFragment(_) => Vec::new(),
Selection::FragmentSpread(..) | Selection::InlineFragment(_) => Vec::new(),
selection => vec![selection.to_path_segment(query)],
};

Expand Down
Loading