@@ -137,7 +137,7 @@ function clean_unused(code)
137137        if  l. head== :(= )
138138        rhs =  l. args[2 ] 
139139        tt =  []
140-         MacroTools. postwalk (x->  ( ((x isa (Symbol)) &&   string (x)[ end ] == ' _ ' ?  push! (tt, x) :  nothing  ) , rhs)
140+         MacroTools. postwalk (x->  ( ((x isa (Symbol))  ) ?  push! (tt, x) :  nothing  ) , rhs)
141141        push! (uses, tt)
142142        push! (defs, l. args[1 ])
143143        end 
@@ -162,6 +162,149 @@ function clean_unused(code)
162162
163163end 
164164
165+ 
166+ """ 
167+ Create a non allocating kernel from the function factory. 
168+ 
169+ `fff`: assumed to be a `FlatFunctionFactory` object with empty preamble. 
170+ `diff`: index of variables to differentiate with or list of indices of variables positions. 
171+ 
172+ The generated kernel looks like (diff=[0, 1]) 
173+ ``` 
174+ function myfun(x::SVector{1, Float64}, y::SVector{3, Float64}, z::SVector{2, Float64}, p::SVector{1, Float64}) 
175+     _a_m1_ = x[1] 
176+     _a__0_ = y[1] 
177+     _b__0_ = y[2] 
178+     _c__0_ = y[3] 
179+     _c__1_ = z[1] 
180+     _d__1_ = z[2] 
181+     _u_ = p[1] 
182+     _foo__0_ = log(_a__0_) + _b__0_ / (_a_m1_ / (1 - _c__0_)) 
183+     _bar__0_ = _c__1_ + _u_ * _d__1_ 
184+     d__foo__0__d__a_m1_ = (-(1 / (1 - _c__0_)) * _b__0_) / (_a_m1_ / (1 - _c__0_)) ^ 2 
185+     d__bar__0__d__a_m1_ = 0 
186+     oo_0_ = SVector(_foo__0_, _bar__0_) 
187+     oo_1_ = SMatrix{2, 1}(d__foo__0__d__a_m1_, d__bar__0__d__a_m1_) 
188+     res_ = (oo_0_, oo_1_) 
189+     return res_ 
190+ end 
191+ ``` 
192+ 
193+ If diff is a scalar, the result of the kernel is a static vector (or a static matrix). 
194+ If diff is a list, the result is a tuple. 
195+ """ 
196+ function  gen_kernel2 (fff:: FlatFunctionFactory , diff:: Union{Int, Vector{Int}} ; funname= fff. funname, arguments= fff. arguments, dispatch= nothing )
197+ 
198+     targets =  [keys (fff. equations)... ]
199+     equations =  [values (fff. equations)... ]
200+     #  names of symbols to output
201+     output_names =  []
202+     for  d in  diff
203+         if  d ==  0 
204+             push! (output_names, targets)
205+         else 
206+             diff_args =  collect (values (fff. arguments))[d]
207+             p =  length (targets)
208+             q =  length (diff_args)
209+             mat =  Matrix {Symbol} (undef, p, q)
210+             for  i in  1 : p
211+                 for  j in  1 : q
212+                     mat[i, j] =  diff_symbol (targets[i], diff_args[j])
213+                 end 
214+             end 
215+             push! (output_names, mat)
216+         end 
217+     end 
218+ 
219+ 
220+     argnames =  collect (keys (arguments))
221+ 
222+     all_eqs =  cat (values (fff. preamble)... , equations, dims= 1 )
223+     all_args =  cat (values (fff. arguments)... , dims= 1 )
224+     if  maximum (diff)> 0 
225+         jac_args =  cat ([collect (values (fff. arguments))[i] for  i in  diff if  i!= 0 ]. .. , dims= 1 )
226+     else 
227+         jac_args =  []
228+     end 
229+     #  jac_args = Symbol.(jac_args) # strange type of output can by Any[]
230+     jac_args =  Symbol[Symbol (e) for  e in  jac_args]
231+ 
232+     #  concatenate preamble and equations (doesn't make much sense...)
233+     dd =  OrderedDict ()
234+     for  (k, v) in  (fff. preamble)
235+         dd[k] =  v
236+     end 
237+     for  (target, eq) in  zip (targets, equations)
238+         dd[target] =  eq
239+     end 
240+     #  compute all equations to write
241+     diff_eqs =  add_derivatives (dd, jac_args)
242+     for  out in  output_names
243+         for  k in  out
244+             if  ! (haskey (diff_eqs, k))
245+                 diff_eqs[k] =  0.0 
246+             end 
247+         end 
248+     end 
249+ 
250+     diff_eqs =  reorder_triangular_block (diff_eqs)
251+ 
252+     #  create function block
253+     code =  []
254+ 
255+     push! (code, :(T= getprecision (model)))
256+ 
257+     for  (k, args) in  zip (argnames, values (arguments))
258+         for  (i, a) in  enumerate (args)
259+             push! (code, :($ a =  ($ k)[$ i]))
260+         end 
261+     end 
262+     for  (k, v) in  diff_eqs
263+         push! (code, :($ k= $ v))
264+     end 
265+ 
266+     return_args =  []
267+     for  (d, names) in  enumerate (output_names)
268+         outname =  Symbol (" oo_" " _" 
269+         push! (code, :($ outname =  $ (_sym_sarray (names))))
270+         push! (return_args, outname)
271+     end 
272+ 
273+     #  this is to make inserting the resulting code in another function easier
274+     if  typeof (diff) <:  Int 
275+         push! (code, :(res_ =  $ (return_args[1 ])))
276+         push! (code, :(return  res_))
277+     else 
278+         push! (code, :(res_ =  $ (Expr (:tuple , return_args... ))))
279+         push! (code, :(return  res_))
280+     end 
281+ 
282+     cast_scalars =  u-> MacroTools. postwalk (x ->  ((x isa  Real) &  ! (x isa  Integer)) ?  :(convert (T,$ x)) :  x, u)
283+ 
284+     code =  cast_scalars .(code)
285+     #  now we construct the function
286+     #  typed_args = [:($k::SVector{$(length(v)), Float64}) for (k, v) in arguments]
287+ 
288+     typed_args =  [:($ k:: SVector{$(length(v))} ) for  (k, v) in  arguments]
289+     #  typed_args[end]= :(p::P)
290+ 
291+     #  if !(dispatch isa Nothing)
292+     #      prepend!(typed_args, :(::$dispatch))
293+     #  end
294+ 
295+     #  n_p = length(arguments[[keys(arguments)...][end]][2])
296+ 
297+     typed_args =  [:(model:: $dispatch ), typed_args... ]
298+     
299+     #  fun_args = Expr(:call, funname, typed_args...)
300+     fun_args =  :(($ funname)($ (typed_args... )))
301+     ncode =  Expr (:function , fun_args, Expr (:block , code... ))
302+ 
303+     return  clean_unused (ncode)
304+ 
305+ end 
306+ 
307+ 
165308""" 
166309Create a non allocating kernel from the function factory. 
167310
0 commit comments