Skip to content

Commit 8af40ea

Browse files
committed
[NewOptimizer] Perform getfield elim for mutable structs
The algorithm works essentially the same as SSA renaming and reuses idf.
1 parent c5f1092 commit 8af40ea

File tree

8 files changed

+214
-30
lines changed

8 files changed

+214
-30
lines changed

base/compiler/optimize.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,9 @@ function effect_free(@nospecialize(e), src, mod::Module, allow_volatile::Bool)
924924
end
925925
fieldcount(typ) >= length(ea) - 1 || return false
926926
for fld_idx in 1:(length(ea) - 1)
927-
exprtype(ea[fld_idx + 1], src, mod) ⊑ fieldtype(typ, fld_idx) || return false
927+
eT = exprtype(ea[fld_idx + 1], src, mod)
928+
fT = fieldtype(typ, fld_idx)
929+
eT ⊑ fT || return false
928930
end
929931
# fall-through
930932
elseif head === :return

base/compiler/ssair/driver.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,11 @@ function run_passes(ci::CodeInfo, nargs::Int, linetable::Vector{LineInfoNode})
129129
IRCode(code, lines, cfg, argtypes, mod, meta)
130130
end
131131
ir = construct_ssa!(ci, ir, domtree, defuse_insts, nargs)
132+
# TODO: Domsorting can produce an updated domtree - no need to recompute here
133+
domtree = construct_domtree(cfg)
132134
ir = compact!(ir)
133135
verify_ir(ir)
134-
ir = getfield_elim_pass!(ir)
136+
ir = getfield_elim_pass!(ir, domtree)
135137
ir = compact!(ir)
136138
ir = type_lift_pass!(ir)
137139
ir = compact!(ir)

base/compiler/ssair/ir.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,13 @@ function getindex(x::IRCode, s::SSAValue)
151151
end
152152
end
153153

154+
function setindex!(x::IRCode, repl, s::SSAValue)
155+
@assert s.id <= length(x.stmts)
156+
x.stmts[s.id] = repl
157+
nothing
158+
end
159+
160+
154161
struct OldSSAValue
155162
id::Int
156163
end
@@ -504,7 +511,7 @@ function next(compact::IncrementalCompact, (idx, active_bb, old_result_idx)::Tup
504511
end
505512

506513
function maybe_erase_unused!(extra_worklist, compact, idx)
507-
effect_free = stmt_effect_free(compact.result[idx], compact.ir, compact.ir.mod)
514+
effect_free = stmt_effect_free(compact.result[idx], compact, compact.ir.mod)
508515
if effect_free
509516
for ops in userefs(compact.result[idx])
510517
val = ops[]

base/compiler/ssair/passes.jl

Lines changed: 186 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,100 @@ function compact_exprtype(compact, value)
77
exprtype(value, compact.ir, compact.ir.mod)
88
end
99

10-
function getfield_elim_pass!(ir::IRCode)
10+
struct SSADefUse
11+
uses::Vector{Int}
12+
defs::Vector{Int}
13+
end
14+
SSADefUse() = SSADefUse(Int[], Int[])
15+
16+
function try_compute_fieldidx(typ, use_expr)
17+
field = use_expr.args[3]
18+
isa(field, QuoteNode) && (field = field.value)
19+
isa(field, Union{Int, Symbol}) || return nothing
20+
if isa(field, Symbol)
21+
field = fieldindex(typ, field, false)
22+
field == 0 && return nothing
23+
elseif isa(field, Integer)
24+
(1 <= field <= fieldcount(typ)) || return nothing
25+
end
26+
return field
27+
end
28+
29+
function lift_defuse(cfg::CFG, ssa::SSADefUse)
30+
SSADefUse(
31+
Int[block_for_inst(cfg, x) for x in ssa.uses],
32+
Int[block_for_inst(cfg, x) for x in ssa.defs])
33+
end
34+
35+
function find_curblock(domtree, allblocks, curblock)
36+
# TODO: This can be much faster by looking at current level and only
37+
# searching for those blocks in a sorted order
38+
while !(curblock in allblocks)
39+
curblock = domtree.idoms[curblock]
40+
end
41+
curblock
42+
end
43+
44+
function val_for_def_expr(ir, def, fidx)
45+
if isexpr(ir[SSAValue(def)], :new)
46+
return ir[SSAValue(def)].args[1+fidx]
47+
else
48+
# The use is whatever the setfield was
49+
return ir[SSAValue(def)].args[4]
50+
end
51+
end
52+
53+
function compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, curblock)
54+
curblock = find_curblock(domtree, allblocks, curblock)
55+
def = reduce(max, 0, stmt for stmt in du.defs if block_for_inst(ir.cfg, stmt) == curblock)
56+
def == 0 ? phinodes[curblock] : val_for_def_expr(ir, def, fidx)
57+
end
58+
59+
function compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, use_idx)
60+
# Find the first dominating def
61+
curblock = stmtblock = block_for_inst(ir.cfg, use_idx)
62+
curblock = find_curblock(domtree, allblocks, curblock)
63+
defblockdefs = [stmt for stmt in du.defs if block_for_inst(ir.cfg, stmt) == curblock]
64+
def = 0
65+
if !isempty(defblockdefs)
66+
if curblock != stmtblock
67+
# Find the last def in this block
68+
def = maximum(defblockdefs)
69+
else
70+
# Find the last def before our use
71+
def = mapreduce(x->x >= use_idx ? 0 : x, max, defblockdefs)
72+
end
73+
end
74+
if def == 0
75+
if !haskey(phinodes, curblock)
76+
# If this happens, we need to search the predecessors for defs. Which
77+
# one doesn't matter - if it did, we'd have had a phinode
78+
return compute_value_for_block(ir, domtree, allblocks, du, phinodes, fidx, first(ir.cfg.blocks[stmtblock].preds))
79+
end
80+
# The use is the phinode
81+
return phinodes[curblock]
82+
else
83+
return val_for_def_expr(ir, def, fidx)
84+
end
85+
end
86+
87+
function getfield_elim_pass!(ir::IRCode, domtree)
1188
compact = IncrementalCompact(ir)
1289
insertions = Vector{Any}()
90+
defuses = IdDict{Int, Tuple{IdSet{Int}, SSADefUse}}()
1391
for (idx, stmt) in compact
14-
# Step 1: Check whether the statement we're looking at is a getfield
1592
isa(stmt, Expr) || continue
16-
is_known_call(stmt, getfield, ir, ir.mod) || continue
93+
is_getfield = false
94+
# Step 1: Check whether the statement we're looking at is a getfield/setfield!
95+
if is_known_call(stmt, setfield!, ir, ir.mod)
96+
is_setfield = true
97+
elseif is_known_call(stmt, getfield, ir, ir.mod)
98+
is_getfield = true
99+
else
100+
continue
101+
end
17102
isa(stmt.args[2], SSAValue) || continue
18-
## Normalize the field argument to getfield
103+
## Normalize the field argument to getfield/setfield
19104
field = stmt.args[3]
20105
isa(field, QuoteNode) && (field = field.value)
21106
isa(field, Union{Int, Symbol}) || continue
@@ -26,8 +111,13 @@ function getfield_elim_pass!(ir::IRCode)
26111
typeconstraint = types(compact)[defidx]
27112
phi_locs = Tuple{Int, Int}[]
28113
## Track definitions through PiNode/PhiNode
114+
found_def = false
115+
## Track which PhiNodes, SSAValue intermediaries
116+
## we forwarded through.
117+
intermediaries = IdSet{Int}()
29118
while true
30119
if isa(def, PiNode)
120+
push!(intermediaries, defidx)
31121
typeconstraint = typeintersect(typeconstraint, def.typ)
32122
if isa(def.val, SSAValue)
33123
defidx = def.val.id
@@ -37,6 +127,8 @@ function getfield_elim_pass!(ir::IRCode)
37127
end
38128
continue
39129
elseif isa(def, PhiNode)
130+
# For now, we don't track setfields structs through phi nodes
131+
is_getfield || break
40132
possible_predecessors = collect(Iterators.filter(1:length(def.edges)) do n
41133
isassigned(def.values, n) || return false
42134
value = def.values[n]
@@ -62,9 +154,22 @@ function getfield_elim_pass!(ir::IRCode)
62154
end
63155
continue
64156
end
157+
elseif isa(def, SSAValue)
158+
push!(intermediaries, defidx)
159+
defidx = def.id
160+
def = compact[def.id]
161+
continue
65162
end
163+
found_def = true
66164
break
67165
end
166+
found_def || continue
167+
if !is_getfield
168+
mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse()))
169+
push!(defuse.defs, idx)
170+
union!(mid, intermediaries)
171+
continue
172+
end
68173
# Step 3: Check if the definition we eventually end up at is either
69174
# a tuple(...) call or Expr(:new) and perform replacement.
70175
if isa(def, Expr) && is_known_call(def, tuple, ir, ir.mod) && isa(field, Int) && 1 <= field < length(def.args)
@@ -75,13 +180,14 @@ function getfield_elim_pass!(ir::IRCode)
75180
typ = unwrap_unionall(typ)
76181
end
77182
isa(typ, DataType) || continue
78-
!typ.mutable || continue
79-
if isa(field, Symbol)
80-
field = fieldindex(typ, field, false)
81-
field == 0 && continue
82-
elseif isa(field, Integer)
83-
(1 <= field <= fieldcount(typ)) || continue
183+
if typ.mutable
184+
mid, defuse = get!(defuses, defidx, (IdSet{Int}(), SSADefUse()))
185+
push!(defuse.uses, idx)
186+
union!(mid, intermediaries)
187+
continue
84188
end
189+
field = try_compute_fieldidx(typ, stmt)
190+
field === nothing && continue
85191
forwarded = def.args[1+field]
86192
else
87193
continue
@@ -95,6 +201,76 @@ function getfield_elim_pass!(ir::IRCode)
95201
compact[idx] = forwarded
96202
end
97203
ir = finish(compact)
204+
@Base.show length(defuses)
205+
# Now go through any mutable structs and see which ones we can eliminate
206+
for (idx, (intermediaries, defuse)) in defuses
207+
intermediaries = collect(intermediaries)
208+
# Check if there are any uses we did not account for. If so, the variable
209+
# escapes and we cannot eliminate the allocation. This works, because we're guaranteed
210+
# not to include any intermediaries that have dead uses. As a result, missing uses will only ever
211+
# show up in the nuses_total count.
212+
nleaves = length(defuse.uses) + length(defuse.defs)
213+
nuses_total = compact.used_ssas[idx] + mapreduce(idx->compact.used_ssas[idx], +, 0, intermediaries) - length(intermediaries)
214+
@Base.show (nleaves, nuses_total)
215+
nleaves == nuses_total || continue
216+
# Find the type for this allocation
217+
defexpr = ir[SSAValue(idx)]
218+
isexpr(defexpr, :new) || continue
219+
typ = defexpr.typ
220+
if isa(typ, UnionAll)
221+
typ = unwrap_unionall(typ)
222+
end
223+
# Could still end up here if we tried to setfield! and immutable, which would
224+
# error at runtime, but is not illegal to have in the IR.
225+
typ.mutable || continue
226+
# Partition defuses by field
227+
fielddefuse = SSADefUse[SSADefUse() for _ = 1:fieldcount(typ)]
228+
ok = true
229+
for use in defuse.uses
230+
field = try_compute_fieldidx(typ, ir[SSAValue(use)])
231+
field === nothing && (ok = false; break)
232+
push!(fielddefuse[field].uses, use)
233+
end
234+
ok || continue
235+
for use in defuse.defs
236+
field = try_compute_fieldidx(typ, ir[SSAValue(use)])
237+
field === nothing && (ok = false; break)
238+
push!(fielddefuse[field].defs, use)
239+
end
240+
ok || continue
241+
# Everything accounted for. Go field by field and perform idf
242+
for (fidx, du) in pairs(fielddefuse)
243+
ftyp = fieldtype(typ, fidx)
244+
if !isempty(du.uses)
245+
push!(du.defs, idx)
246+
ldu = lift_defuse(ir.cfg, du)
247+
phiblocks = idf(ir.cfg, ldu, domtree)
248+
phinodes = IdDict{Int, SSAValue}()
249+
for b in phiblocks
250+
n = PhiNode()
251+
phinodes[b] = insert_node!(ir, first(ir.cfg.blocks[b].stmts), ftyp, n)
252+
end
253+
# Now go through all uses and rewrite them
254+
allblocks = sort(vcat(phiblocks, ldu.defs))
255+
for stmt in du.uses
256+
ir[SSAValue(stmt)] = compute_value_for_use(ir, domtree, allblocks, du, phinodes, fidx, stmt)
257+
end
258+
for b in phiblocks
259+
for p in ir.cfg.blocks[b].preds
260+
n = ir[phinodes[b]]
261+
push!(n.edges, p)
262+
push!(n.values, compute_value_for_block(ir, domtree,
263+
allblocks, du, phinodes, fidx, p))
264+
end
265+
end
266+
end
267+
for stmt in du.defs
268+
stmt == idx && continue
269+
ir[SSAValue(stmt)] = nothing
270+
end
271+
continue
272+
end
273+
end
98274
for (idx, phi_locs) in insertions
99275
# For non-dominating load-store forward, we may have to insert extra phi nodes
100276
# TODO: Can use the domtree to eliminate unnecessary phis, but ok for now

base/compiler/ssair/queries.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1-
function stmt_effect_free(@nospecialize(stmt), src::IRCode, mod::Module)
1+
function stmt_effect_free(@nospecialize(stmt), src, mod::Module)
22
isa(stmt, Union{PiNode, PhiNode}) && return true
33
isa(stmt, Union{ReturnNode, GotoNode, GotoIfNot}) && return false
4-
return statement_effect_free(stmt, src, mod)
4+
return effect_free(stmt, src, mod, true)
55
end
66

77
function abstract_eval_ssavalue(s::SSAValue, src::IRCode)
88
return src.types[s.id]
99
end
10+
11+
function abstract_eval_ssavalue(s::SSAValue, src::IncrementalCompact)
12+
return types(src)[s]
13+
end

base/compiler/ssair/slot2ssa.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,9 @@ function typ_for_val(@nospecialize(val), ci::CodeInfo)
220220
end
221221

222222
# Run iterated dominance frontier
223-
function idf(cfg::CFG, defuse, domtree::DomTree, slot::Int)
223+
function idf(cfg::CFG, defuse, domtree::DomTree)
224224
# This should be a priority queue, but TODO - sorted array for now
225-
defs = defuse[slot].defs
225+
defs = defuse.defs
226226
pq = Tuple{Int, Int}[(defs[i], domtree.nodes[defs[i]].level) for i in 1:length(defs)]
227227
sort!(pq, by=x->x[2])
228228
phiblocks = Int[]
@@ -241,7 +241,7 @@ function idf(cfg::CFG, defuse, domtree::DomTree, slot::Int)
241241
push!(processed, succ)
242242
# <- TODO: Use liveness here
243243
push!(phiblocks, succ)
244-
if !(succ in defuse[slot].defs)
244+
if !(succ in defs)
245245
push!(pq, (succ, succ_level))
246246
sort!(pq, by=x->x[2])
247247
end
@@ -451,7 +451,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse, narg
451451
continue
452452
end
453453
# TODO: Perform liveness here to eliminate dead phi nodes
454-
phiblocks = idf(cfg, defuse_blocks, domtree, idx)
454+
phiblocks = idf(cfg, defuse_blocks[idx], domtree)
455455
for block in phiblocks
456456
push!(phi_slots[block], idx)
457457
node = PhiNode()

base/compiler/utilities.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ function exprtype(@nospecialize(x), src, mod::Module)
182182
elseif isa(x, SSAValue)
183183
return abstract_eval_ssavalue(x::SSAValue, src)
184184
elseif isa(x, Argument)
185-
return src.argtypes[x.n]
185+
return isa(src, IncrementalCompact) ? src.ir.argtypes[x.n] : src.argtypes[x.n]
186186
elseif isa(x, Symbol)
187187
return abstract_eval_global(mod, x::Symbol)
188188
elseif isa(x, QuoteNode)

base/iterators.jl

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,11 +1045,7 @@ mutable struct Stateful{T, VS}
10451045
@inline function Stateful(itr::T) where {T}
10461046
state = start(itr)
10471047
VS = fixpoint_iter_type(T, Union{}, typeof(state))
1048-
if done(itr, state)
1049-
new{T, VS}(itr, nothing, 0)
1050-
else
1051-
new{T, VS}(itr, next(itr, state)::VS, 0)
1052-
end
1048+
new{T, VS}(itr, done(itr, state) ? nothing : next(itr, state)::VS, 0)
10531049
end
10541050
end
10551051

@@ -1094,11 +1090,8 @@ convert(::Type{Stateful}, itr) = Stateful(itr)
10941090
throw(EOFError())
10951091
else
10961092
val, state = vs
1097-
if done(s.itr, state)
1098-
s.nextvalstate = nothing
1099-
else
1100-
s.nextvalstate = next(s.itr, state)
1101-
end
1093+
# Until the optimizer can handle setproperty! better here, use explicit setfield!
1094+
setfield!(s, :nextvalstate, done(s.itr, state) ? nothing : next(s.itr, state))
11021095
s.taken += 1
11031096
return val
11041097
end

0 commit comments

Comments
 (0)