Skip to content

Commit 5a47896

Browse files
committed
TEMP: special kernel for DoloYAML
1 parent 815f722 commit 5a47896

File tree

1 file changed

+144
-1
lines changed

1 file changed

+144
-1
lines changed

src/compiler.jl

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ function clean_unused(code)
137137
if l.head==:(=)
138138
rhs = l.args[2]
139139
tt = []
140-
MacroTools.postwalk(x-> ( ((x isa(Symbol)) && string(x)[end]=='_' ) ? push!(tt, x) : nothing ) , rhs)
140+
MacroTools.postwalk(x-> ( ((x isa(Symbol)) ) ? push!(tt, x) : nothing ) , rhs)
141141
push!(uses, tt)
142142
push!(defs, l.args[1])
143143
end
@@ -162,6 +162,149 @@ function clean_unused(code)
162162

163163
end
164164

165+
166+
"""
167+
Create a non allocating kernel from the function factory.
168+
169+
`fff`: assumed to be a `FlatFunctionFactory` object with empty preamble.
170+
`diff`: index of variables to differentiate with or list of indices of variables positions.
171+
172+
The generated kernel looks like (diff=[0, 1])
173+
```
174+
function myfun(x::SVector{1, Float64}, y::SVector{3, Float64}, z::SVector{2, Float64}, p::SVector{1, Float64})
175+
_a_m1_ = x[1]
176+
_a__0_ = y[1]
177+
_b__0_ = y[2]
178+
_c__0_ = y[3]
179+
_c__1_ = z[1]
180+
_d__1_ = z[2]
181+
_u_ = p[1]
182+
_foo__0_ = log(_a__0_) + _b__0_ / (_a_m1_ / (1 - _c__0_))
183+
_bar__0_ = _c__1_ + _u_ * _d__1_
184+
d__foo__0__d__a_m1_ = (-(1 / (1 - _c__0_)) * _b__0_) / (_a_m1_ / (1 - _c__0_)) ^ 2
185+
d__bar__0__d__a_m1_ = 0
186+
oo_0_ = SVector(_foo__0_, _bar__0_)
187+
oo_1_ = SMatrix{2, 1}(d__foo__0__d__a_m1_, d__bar__0__d__a_m1_)
188+
res_ = (oo_0_, oo_1_)
189+
return res_
190+
end
191+
```
192+
193+
If diff is a scalar, the result of the kernel is a static vector (or a static matrix).
194+
If diff is a list, the result is a tuple.
195+
"""
196+
function gen_kernel2(fff::FlatFunctionFactory, diff::Union{Int, Vector{Int}}; funname=fff.funname, arguments=fff.arguments, dispatch=nothing)
197+
198+
targets = [keys(fff.equations)...]
199+
equations = [values(fff.equations)...]
200+
# names of symbols to output
201+
output_names = []
202+
for d in diff
203+
if d == 0
204+
push!(output_names, targets)
205+
else
206+
diff_args = collect(values(fff.arguments))[d]
207+
p = length(targets)
208+
q = length(diff_args)
209+
mat = Matrix{Symbol}(undef, p, q)
210+
for i in 1:p
211+
for j in 1:q
212+
mat[i, j] = diff_symbol(targets[i], diff_args[j])
213+
end
214+
end
215+
push!(output_names, mat)
216+
end
217+
end
218+
219+
220+
argnames = collect(keys(arguments))
221+
222+
all_eqs = cat(values(fff.preamble)..., equations, dims=1)
223+
all_args = cat(values(fff.arguments)..., dims=1)
224+
if maximum(diff)>0
225+
jac_args = cat([collect(values(fff.arguments))[i] for i in diff if i!=0]..., dims=1)
226+
else
227+
jac_args = []
228+
end
229+
# jac_args = Symbol.(jac_args) # strange type of output can by Any[]
230+
jac_args = Symbol[Symbol(e) for e in jac_args]
231+
232+
# concatenate preamble and equations (doesn't make much sense...)
233+
dd = OrderedDict()
234+
for (k, v) in (fff.preamble)
235+
dd[k] = v
236+
end
237+
for (target, eq) in zip(targets, equations)
238+
dd[target] = eq
239+
end
240+
# compute all equations to write
241+
diff_eqs = add_derivatives(dd, jac_args)
242+
for out in output_names
243+
for k in out
244+
if !(haskey(diff_eqs, k))
245+
diff_eqs[k] = 0.0
246+
end
247+
end
248+
end
249+
250+
diff_eqs = reorder_triangular_block(diff_eqs)
251+
252+
# create function block
253+
code = []
254+
255+
push!(code, :(T=getprecision(model)))
256+
257+
for (k, args) in zip(argnames, values(arguments))
258+
for (i, a) in enumerate(args)
259+
push!(code, :($a = ($k)[$i]))
260+
end
261+
end
262+
for (k, v) in diff_eqs
263+
push!(code, :($k=$v))
264+
end
265+
266+
return_args = []
267+
for (d, names) in enumerate(output_names)
268+
outname = Symbol("oo_", d, "_")
269+
push!(code, :($outname = $(_sym_sarray(names))))
270+
push!(return_args, outname)
271+
end
272+
273+
# this is to make inserting the resulting code in another function easier
274+
if typeof(diff) <: Int
275+
push!(code, :(res_ = $(return_args[1])))
276+
push!(code, :(return res_))
277+
else
278+
push!(code, :(res_ = $(Expr(:tuple, return_args...))))
279+
push!(code, :(return res_))
280+
end
281+
282+
cast_scalars = u->MacroTools.postwalk(x -> ((x isa Real) & !(x isa Integer)) ? :(convert(T,$x)) : x, u)
283+
284+
code = cast_scalars.(code)
285+
# now we construct the function
286+
# typed_args = [:($k::SVector{$(length(v)), Float64}) for (k, v) in arguments]
287+
288+
typed_args = [:($k::SVector{$(length(v))}) for (k, v) in arguments]
289+
# typed_args[end]= :(p::P)
290+
291+
# if !(dispatch isa Nothing)
292+
# prepend!(typed_args, :(::$dispatch))
293+
# end
294+
295+
# n_p = length(arguments[[keys(arguments)...][end]][2])
296+
297+
typed_args = [:(model::$dispatch), typed_args...]
298+
299+
# fun_args = Expr(:call, funname, typed_args...)
300+
fun_args = :(($funname)($(typed_args...)))
301+
ncode = Expr(:function, fun_args, Expr(:block, code...))
302+
303+
return clean_unused(ncode)
304+
305+
end
306+
307+
165308
"""
166309
Create a non allocating kernel from the function factory.
167310

0 commit comments

Comments
 (0)