@@ -86,27 +86,23 @@ mod llvm_enzyme {
86
86
ecx : & mut ExtCtxt < ' _ > ,
87
87
meta_item : & ThinVec < MetaItemInner > ,
88
88
has_ret : bool ,
89
+ mode : DiffMode ,
89
90
) -> AutoDiffAttrs {
90
91
let dcx = ecx. sess . dcx ( ) ;
91
- let mode = name ( & meta_item[ 1 ] ) ;
92
- let Ok ( mode) = DiffMode :: from_str ( & mode) else {
93
- dcx. emit_err ( errors:: AutoDiffInvalidMode { span : meta_item[ 1 ] . span ( ) , mode } ) ;
94
- return AutoDiffAttrs :: error ( ) ;
95
- } ;
96
92
97
93
// Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
98
94
// If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
99
- let mut first_activity = 2 ;
95
+ let mut first_activity = 1 ;
100
96
101
- let width = if let [ _, _ , x, ..] = & meta_item[ ..]
97
+ let width = if let [ _, x, ..] = & meta_item[ ..]
102
98
&& let Some ( x) = width ( x)
103
99
{
104
- first_activity = 3 ;
100
+ first_activity = 2 ;
105
101
match x. try_into ( ) {
106
102
Ok ( x) => x,
107
103
Err ( _) => {
108
104
dcx. emit_err ( errors:: AutoDiffInvalidWidth {
109
- span : meta_item[ 2 ] . span ( ) ,
105
+ span : meta_item[ 1 ] . span ( ) ,
110
106
width : x,
111
107
} ) ;
112
108
return AutoDiffAttrs :: error ( ) ;
@@ -165,6 +161,24 @@ mod llvm_enzyme {
165
161
ts. push ( TokenTree :: Token ( comma. clone ( ) , Spacing :: Alone ) ) ;
166
162
}
167
163
164
+ pub ( crate ) fn expand_forward (
165
+ ecx : & mut ExtCtxt < ' _ > ,
166
+ expand_span : Span ,
167
+ meta_item : & ast:: MetaItem ,
168
+ item : Annotatable ,
169
+ ) -> Vec < Annotatable > {
170
+ expand_with_mode ( ecx, expand_span, meta_item, item, DiffMode :: Forward )
171
+ }
172
+
173
+ pub ( crate ) fn expand_reverse (
174
+ ecx : & mut ExtCtxt < ' _ > ,
175
+ expand_span : Span ,
176
+ meta_item : & ast:: MetaItem ,
177
+ item : Annotatable ,
178
+ ) -> Vec < Annotatable > {
179
+ expand_with_mode ( ecx, expand_span, meta_item, item, DiffMode :: Reverse )
180
+ }
181
+
168
182
/// We expand the autodiff macro to generate a new placeholder function which passes
169
183
/// type-checking and can be called by users. The function body of the placeholder function will
170
184
/// later be replaced on LLVM-IR level, so the design of the body is less important and for now
@@ -198,11 +212,12 @@ mod llvm_enzyme {
198
212
/// ```
199
213
/// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
200
214
/// in CI.
201
- pub ( crate ) fn expand (
215
+ pub ( crate ) fn expand_with_mode (
202
216
ecx : & mut ExtCtxt < ' _ > ,
203
217
expand_span : Span ,
204
218
meta_item : & ast:: MetaItem ,
205
219
mut item : Annotatable ,
220
+ mode : DiffMode ,
206
221
) -> Vec < Annotatable > {
207
222
if cfg ! ( not( llvm_enzyme) ) {
208
223
ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffSupportNotBuild { span : meta_item. span } ) ;
@@ -245,29 +260,41 @@ mod llvm_enzyme {
245
260
// create TokenStream from vec elemtents:
246
261
// meta_item doesn't have a .tokens field
247
262
let mut ts: Vec < TokenTree > = vec ! [ ] ;
248
- if meta_item_vec. len ( ) < 2 {
249
- // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
250
- // input and output args.
263
+ if meta_item_vec. len ( ) < 1 {
264
+ // At the bare minimum, we need a fnc name.
251
265
dcx. emit_err ( errors:: AutoDiffMissingConfig { span : item. span ( ) } ) ;
252
266
return vec ! [ item] ;
253
267
}
254
268
255
- meta_item_inner_to_ts ( & meta_item_vec[ 1 ] , & mut ts) ;
269
+ let mode_symbol = match mode {
270
+ DiffMode :: Forward => sym:: Forward ,
271
+ DiffMode :: Reverse => sym:: Reverse ,
272
+ _ => unreachable ! ( "Unsupported mode: {:?}" , mode) ,
273
+ } ;
274
+
275
+ // Insert mode token
276
+ let mode_token = Token :: new ( TokenKind :: Ident ( mode_symbol, false . into ( ) ) , Span :: default ( ) ) ;
277
+ ts. insert ( 0 , TokenTree :: Token ( mode_token, Spacing :: Joint ) ) ;
278
+ ts. insert (
279
+ 1 ,
280
+ TokenTree :: Token ( Token :: new ( TokenKind :: Comma , Span :: default ( ) ) , Spacing :: Alone ) ,
281
+ ) ;
256
282
257
283
// Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
258
284
// If it is not given, we default to 1 (scalar mode).
259
285
let start_position;
260
286
let kind: LitKind = LitKind :: Integer ;
261
287
let symbol;
262
- if meta_item_vec. len ( ) >= 3
263
- && let Some ( width) = width ( & meta_item_vec[ 2 ] )
288
+ if meta_item_vec. len ( ) >= 2
289
+ && let Some ( width) = width ( & meta_item_vec[ 1 ] )
264
290
{
265
- start_position = 3 ;
291
+ start_position = 2 ;
266
292
symbol = Symbol :: intern ( & width. to_string ( ) ) ;
267
293
} else {
268
- start_position = 2 ;
294
+ start_position = 1 ;
269
295
symbol = sym:: integer ( 1 ) ;
270
296
}
297
+
271
298
let l: Lit = Lit { kind, symbol, suffix : None } ;
272
299
let t = Token :: new ( TokenKind :: Literal ( l) , Span :: default ( ) ) ;
273
300
let comma = Token :: new ( TokenKind :: Comma , Span :: default ( ) ) ;
@@ -289,7 +316,7 @@ mod llvm_enzyme {
289
316
ts. pop ( ) ;
290
317
let ts: TokenStream = TokenStream :: from_iter ( ts) ;
291
318
292
- let x: AutoDiffAttrs = from_ast ( ecx, & meta_item_vec, has_ret) ;
319
+ let x: AutoDiffAttrs = from_ast ( ecx, & meta_item_vec, has_ret, mode ) ;
293
320
if !x. is_active ( ) {
294
321
// We encountered an error, so we return the original item.
295
322
// This allows us to potentially parse other attributes.
@@ -1017,4 +1044,4 @@ mod llvm_enzyme {
1017
1044
}
1018
1045
}
1019
1046
1020
- pub ( crate ) use llvm_enzyme:: expand ;
1047
+ pub ( crate ) use llvm_enzyme:: { expand_forward , expand_reverse } ;
0 commit comments