Skip to content

Commit 91f2f77

Browse files
committed
[NewOptimizer] Add back getfield elim pass
This adds back a (fixed version of) the NewOptimizer getfield elim pass. This passes all getfield elim (and `@allocated`) tests in base, with the exception of two tests that check for getfield elim on mutables. That part is planned in the design of this pass, but will be part of a follow up PR.
1 parent 46dcb35 commit 91f2f77

File tree

4 files changed

+120
-5
lines changed

4 files changed

+120
-5
lines changed

base/compiler/ssair/driver.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,10 @@ function run_passes(ci::CodeInfo, nargs::Int, linetable::Vector{LineInfoNode})
129129
IRCode(code, lines, cfg, argtypes, mod, meta)
130130
end
131131
ir = construct_ssa!(ci, ir, domtree, defuse_insts, nargs)
132-
domtree = construct_domtree(ir.cfg)
133132
ir = compact!(ir)
134133
verify_ir(ir)
134+
ir = getfield_elim_pass!(ir)
135+
ir = compact!(ir)
135136
ir = type_lift_pass!(ir)
136137
ir = compact!(ir)
137138
verify_ir(ir)

base/compiler/ssair/ir.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,10 @@ function getindex(compact::IncrementalCompact, idx)
355355
end
356356
end
357357

358+
function getindex(view::TypesView, v::OldSSAValue)
359+
return view.ir.ir.types[v.id]
360+
end
361+
358362
function setindex!(compact::IncrementalCompact, v, idx)
359363
if idx < compact.result_idx
360364
# Kill count for current uses
@@ -372,8 +376,8 @@ end
372376

373377
function getindex(view::TypesView, idx)
374378
isa(idx, SSAValue) && (idx = idx.id)
375-
if isa(view.ir, IncrementalCompact) && idx < view.compact.result_idx
376-
return view.compact.result_types[idx]
379+
if isa(view.ir, IncrementalCompact) && idx < view.ir.result_idx
380+
return view.ir.result_types[idx]
377381
else
378382
ir = isa(view.ir, IncrementalCompact) ? view.ir.ir : view.ir
379383
if idx <= length(ir.types)
@@ -500,7 +504,8 @@ function next(compact::IncrementalCompact, (idx, active_bb, old_result_idx)::Tup
500504
end
501505

502506
function maybe_erase_unused!(extra_worklist, compact, idx)
503-
if stmt_effect_free(compact.result[idx], compact.ir, compact.ir.mod)
507+
effect_free = stmt_effect_free(compact.result[idx], compact.ir, compact.ir.mod)
508+
if effect_free
504509
for ops in userefs(compact.result[idx])
505510
val = ops[]
506511
if isa(val, SSAValue)

base/compiler/ssair/passes.jl

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,110 @@
1+
function compact_exprtype(compact, value)
2+
if isa(value, Union{SSAValue, OldSSAValue})
3+
return types(compact)[value]
4+
elseif isa(value, Argument)
5+
return compact.ir.argtypes[value.n]
6+
end
7+
exprtype(value, compact.ir, compact.ir.mod)
8+
end
9+
10+
function getfield_elim_pass!(ir::IRCode)
11+
compact = IncrementalCompact(ir)
12+
insertions = Vector{Any}()
13+
for (idx, stmt) in compact
14+
isa(stmt, Expr) || continue
15+
is_known_call(stmt, getfield, ir, ir.mod) || continue
16+
isa(stmt.args[2], SSAValue) || continue
17+
field = stmt.args[3]
18+
isa(field, QuoteNode) && (field = field.value)
19+
isa(field, Union{Int, Symbol}) || continue
20+
orig_defidx = defidx = stmt.args[2].id
21+
def = compact[defidx]
22+
typeconstraint = types(compact)[defidx]
23+
phi_locs = Tuple{Int, Int}[]
24+
while true
25+
if isa(def, PiNode)
26+
typeconstraint = typeintersect(typeconstraint, def.typ)
27+
if isa(def.val, SSAValue)
28+
defidx = def.val.id
29+
def = compact[defidx]
30+
else
31+
def = def.val
32+
end
33+
continue
34+
elseif isa(def, PhiNode)
35+
possible_predecessors = collect(Iterators.filter(1:length(def.edges)) do n
36+
isassigned(def.values, n) || return false
37+
value = def.values[n]
38+
edge_typ = compact_exprtype(compact, value)
39+
return edge_typ typeconstraint
40+
end)
41+
# For now, only look at unique predecessors
42+
if length(possible_predecessors) == 1
43+
n = possible_predecessors[1]
44+
pred = def.edges[n]
45+
val = def.values[n]
46+
if isa(val, SSAValue)
47+
push!(phi_locs, (pred, defidx))
48+
defidx = val.id
49+
def = compact[defidx]
50+
elseif def == val
51+
# This shouldn't really ever happen, but
52+
# patterns like this can occur in dead code,
53+
# so bail out.
54+
break
55+
else
56+
def = val
57+
end
58+
continue
59+
end
60+
end
61+
break
62+
end
63+
if isa(def, Expr) && is_known_call(def, tuple, ir, ir.mod) && isa(field, Int) && 1 <= field < length(def.args)
64+
forwarded = def.args[1+field]
65+
elseif isexpr(def, :new)
66+
typ = def.typ
67+
if isa(typ, UnionAll)
68+
typ = unwrap_unionall(typ)
69+
end
70+
isa(typ, DataType) || continue
71+
!typ.mutable || continue
72+
if isa(field, Symbol)
73+
field = fieldindex(typ, field, false)
74+
field == 0 && continue
75+
elseif isa(field, Integer)
76+
(1 <= field <= fieldcount(typ)) || continue
77+
end
78+
forwarded = def.args[1+field]
79+
else
80+
continue
81+
end
82+
if !isempty(phi_locs) && isa(forwarded, SSAValue)
83+
# TODO: We have have to use BB ids for phi_locs
84+
# to avoid index invalidation.
85+
push!(insertions, (idx, phi_locs))
86+
end
87+
compact[idx] = forwarded
88+
end
89+
ir = finish(compact)
90+
for (idx, phi_locs) in insertions
91+
# For non-dominating load-store forward, we may have to insert extra phi nodes
92+
# TODO: Can use the domtree to eliminate unnecessary phis, but ok for now
93+
forwarded = ir.stmts[idx]
94+
if isa(forwarded, SSAValue)
95+
forwarded_typ = ir.types[forwarded.id]
96+
for (pred, pos) in reverse!(phi_locs)
97+
node = PhiNode()
98+
push!(node.edges, pred)
99+
push!(node.values, forwarded)
100+
forwarded = insert_node!(ir, pos, forwarded_typ, node)
101+
end
102+
end
103+
ir.stmts[idx] = forwarded
104+
end
105+
ir
106+
end
107+
1108
function type_lift_pass!(ir::IRCode)
2109
type_ctx_uses = Vector{Vector{Int}}[]
3110
has_non_type_ctx_uses = IdSet{Int}()
@@ -81,4 +188,4 @@ function type_lift_pass!(ir::IRCode)
81188
end
82189
end
83190
ir
84-
end
191+
end

base/compiler/utilities.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,8 @@ function exprtype(@nospecialize(x), src, mod::Module)
181181
return (x::TypedSlot).typ
182182
elseif isa(x, SSAValue)
183183
return abstract_eval_ssavalue(x::SSAValue, src)
184+
elseif isa(x, Argument)
185+
return src.argtypes[x.n]
184186
elseif isa(x, Symbol)
185187
return abstract_eval_global(mod, x::Symbol)
186188
elseif isa(x, QuoteNode)

0 commit comments

Comments
 (0)