Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 217 additions & 0 deletions chalk-derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
extern crate proc_macro;

use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{parse_macro_input, Data, DeriveInput, GenericParam, Ident, TypeParamBound};

Expand Down Expand Up @@ -339,3 +340,219 @@ fn bounded_by_trait<'p>(param: &'p GenericParam, name: &str) -> Option<&'p Ident
_ => None,
}
}

/// Derives Visit for structs and enums for which one of the following is true:
/// - It has a `#[has_interner(TheInterner)]` attribute
/// - There is a single parameter `T: HasInterner` (does not have to be named `T`)
/// - There is a single parameter `I: Interner` (does not have to be named `I`)
#[proc_macro_derive(Visit, attributes(has_interner))]
pub fn derive_visit(item: TokenStream) -> TokenStream {
let trait_name = Ident::new("Visit", Span::call_site());
let method_name = Ident::new("visit_with", Span::call_site());
derive_any_visit(item, trait_name, method_name)
}

/// Same as Visit, but derives SuperVisit instead
#[proc_macro_derive(SuperVisit, attributes(has_interner))]
pub fn derive_super_visit(item: TokenStream) -> TokenStream {
let trait_name = Ident::new("SuperVisit", Span::call_site());
let method_name = Ident::new("super_visit_with", Span::call_site());
derive_any_visit(item, trait_name, method_name)
}

fn derive_any_visit(item: TokenStream, trait_name: Ident, method_name: Ident) -> TokenStream {
let input = parse_macro_input!(item as DeriveInput);
let (impl_generics, ty_generics, where_clause_ref) = input.generics.split_for_impl();

let type_name = input.ident;
let body = derive_visit_body(&type_name, input.data);

if let Some(attr) = input.attrs.iter().find(|a| a.path.is_ident("has_interner")) {
// Hardcoded interner:
//
// impl Visit<ChalkIr> for Type {
//
// }
let arg = attr
.parse_args::<proc_macro2::TokenStream>()
.expect("Expected has_interner argument");

return TokenStream::from(quote! {
impl #impl_generics #trait_name < #arg > for #type_name #ty_generics #where_clause_ref {
fn #method_name <'i, R: VisitResult>(
&self,
visitor: &mut dyn Visitor < 'i, #arg, Result = R >,
outer_binder: DebruijnIndex,
) -> R
where
I: 'i
{
#body
}
}
});
}

match input.generics.params.len() {
1 => {}

0 => {
panic!("Visit derive requires a single type parameter or a `#[has_interner]` attr");
}

_ => {
panic!("Visit derive only works with a single type parameter");
}
};

let generic_param0 = &input.generics.params[0];

if let Some(param) = has_interner(&generic_param0) {
// HasInterner bound:
//
// Example:
//
// impl<T, _I> Visit<_I> for Binders<T>
// where
// T: HasInterner<Interner = _I>,
// {
// }

let mut impl_generics = input.generics.clone();
impl_generics.params.extend(vec![GenericParam::Type(
syn::parse(quote! { _I: Interner }.into()).unwrap(),
)]);

let mut where_clause = where_clause_ref
.cloned()
.unwrap_or_else(|| syn::parse2(quote![where]).unwrap());
where_clause
.predicates
.push(syn::parse2(quote! { #param: HasInterner<Interner = _I> }).unwrap());
where_clause
.predicates
.push(syn::parse2(quote! { #param: Visit<_I> }).unwrap());

return TokenStream::from(quote! {
impl #impl_generics #trait_name < _I > for #type_name < #param >
#where_clause
{
fn #method_name <'i, R: VisitResult>(
&self,
visitor: &mut dyn Visitor < 'i, _I, Result = R >,
outer_binder: DebruijnIndex,
) -> R
where
_I: 'i
{
#body
}
}
});
}

// Interner bound:
//
// Example:
//
// impl<I> Visit<I> for Foo<I>
// where
// I: Interner,
// {
// }

if let Some(i) = is_interner(&generic_param0) {
let impl_generics = &input.generics;

return TokenStream::from(quote! {
impl #impl_generics #trait_name < #i > for #type_name < #i >
#where_clause_ref
{
fn #method_name <'i, R: VisitResult>(
&self,
visitor: &mut dyn Visitor < 'i, #i, Result = R >,
outer_binder: DebruijnIndex,
) -> R
where
I: 'i
{
#body
}
}
});
}

panic!(
"derive({}) requires a parameter that implements HasInterner or Interner",
trait_name
);
}

/// Generates the body of the Visit impl
fn derive_visit_body(type_name: &Ident, data: Data) -> proc_macro2::TokenStream {
match data {
Data::Struct(s) => {
let fields = s.fields.into_iter().map(|f| {
let name = f.ident.as_ref().expect("Unnamed field in a struct");
quote! {
result = result.combine(self.#name.visit_with(visitor, outer_binder));
if result.return_early() { return result; }
}
});
quote! {
let mut result = R::new();
#(#fields)*

result
}
}
Data::Enum(e) => {
let matches = e.variants.into_iter().map(|v| {
let variant = v.ident;
match &v.fields {
syn::Fields::Named(fields) => {
let fnames: &Vec<_> = &fields.named.iter().map(|f| &f.ident).collect();
quote! {
#type_name :: #variant { #(#fnames),* } => {
let mut result = R::new();
#(
result = result.combine(#fnames.visit_with(visitor, outer_binder));
if result.return_early() { return result; }
)*
result
}
}
}

syn::Fields::Unnamed(_fields) => {
let names: Vec<_> = (0..v.fields.iter().count())
.map(|index| format_ident!("a{}", index))
.collect();
quote! {
#type_name::#variant( #(ref #names),* ) => {
let mut result = R::new();
#(
result = result.combine(#names.visit_with(visitor, outer_binder));
if result.return_early() { return result; }
)*
result
}
}
}

syn::Fields::Unit => {
quote! {
#type_name::#variant => R::new(),
}
}
}
});
quote! {
match *self {
#(#matches)*
}
}
}
Data::Union(..) => panic!("Visit can not be derived for unions"),
}
}
Loading