Skip to content

Commit 998cb27

Browse files
authored
inference: reinfer and track missing code for inlining (#59413)
When code is potentially needed for inlining, but missing for any reason, be sure to regenerate it during inference with the correct `ci_meets_requirement` flags (SOURCE_MODE_GET_SOURCE instead of NOT_REQUIRED) so it is prepared for the optimizer if needed. This was supposed to be the correct fix for someone else's bug with missing inlining when expected, but I don't remember what bug it was anymore. It should however improve handling and tracking of code (and worlds of code) during inference, which should help for building upon this base. ``` julia> @atomic Base.method_instance(promote, (Float64,Int)).cache.inferred = 0x22 # mark with inlining cost, deleting code julia> @code_typed 1.0+1 # this should regenerate the code, but on master instead results in: CodeInfo( @ promotion.jl:433 within `+` 1 ─ %1 = invoke Base.promote(x::Float64, y::Int64)::Tuple{Float64, Float64} │ %2 = builtin Core.getfield(%1, 1)::Float64 │ %3 = builtin Core.getfield(%1, 2)::Float64 │ @ promotion.jl:433 within `+` @ float.jl:492 │ %4 = intrinsic Base.add_float(%2, %3)::Float64 └── return %4 ) => Float64 ```
1 parent 5fcc944 commit 998cb27

24 files changed

+737
-656
lines changed

Compiler/src/abstractinterpretation.jl

Lines changed: 111 additions & 102 deletions
Large diffs are not rendered by default.

Compiler/src/cicache.jl

Lines changed: 23 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,5 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3-
"""
4-
struct InternalCodeCache
5-
6-
Internally, each `MethodInstance` keep a unique global cache of code instances
7-
that have been created for the given method instance, stratified by world age
8-
ranges. This struct abstracts over access to this cache.
9-
"""
10-
struct InternalCodeCache
11-
owner::Any # `jl_egal` is used for comparison
12-
end
13-
14-
function setindex!(cache::InternalCodeCache, ci::CodeInstance, mi::MethodInstance)
15-
@assert ci.owner === cache.owner
16-
m = mi.def
17-
if isa(m, Method)
18-
ccall(:jl_push_newly_inferred, Cvoid, (Any,), ci)
19-
end
20-
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), mi, ci)
21-
return cache
22-
end
23-
243
struct WorldRange
254
min_world::UInt
265
max_world::UInt
@@ -49,39 +28,43 @@ function union(a::WorldRange, b::WorldRange)
4928
end
5029

5130
"""
52-
struct WorldView
31+
struct InternalCodeCache
5332
54-
Takes a given cache and provides access to the cache contents for the given
55-
range of world ages, rather than defaulting to the current active world age.
33+
Internally, each `MethodInstance` keep a unique global cache of code instances
34+
that have been created for the given method instance, stratified by world age
35+
ranges. This struct abstracts over access to this cache.
5636
"""
57-
struct WorldView{Cache}
58-
cache::Cache
37+
struct InternalCodeCache
38+
owner::Any # `jl_egal` is used for comparison
5939
worlds::WorldRange
60-
WorldView(cache::Cache, range::WorldRange) where Cache = new{Cache}(cache, range)
40+
InternalCodeCache(@nospecialize(owner), wr::WorldRange) = new(owner, wr)
41+
InternalCodeCache(@nospecialize(owner), args...) = new(owner, WorldRange(args...))
6142
end
62-
WorldView(cache, args...) = WorldView(cache, WorldRange(args...))
63-
WorldView(wvc::WorldView, wr::WorldRange) = WorldView(wvc.cache, wr)
64-
WorldView(wvc::WorldView, args...) = WorldView(wvc.cache, args...)
6543

66-
function haskey(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
67-
return ccall(:jl_rettype_inferred, Any, (Any, Any, UInt, UInt), wvc.cache.owner, mi, first(wvc.worlds), last(wvc.worlds)) !== nothing
44+
function setindex!(cache::InternalCodeCache, ci::CodeInstance, mi::MethodInstance)
45+
@assert ci.owner === cache.owner
46+
m = mi.def
47+
if isa(m, Method)
48+
ccall(:jl_push_newly_inferred, Cvoid, (Any,), ci)
49+
end
50+
ccall(:jl_mi_cache_insert, Cvoid, (Any, Any), mi, ci)
51+
return cache
6852
end
6953

70-
function get(wvc::WorldView{InternalCodeCache}, mi::MethodInstance, default)
71-
r = ccall(:jl_rettype_inferred, Any, (Any, Any, UInt, UInt), wvc.cache.owner, mi, first(wvc.worlds), last(wvc.worlds))
54+
function haskey(wvc::InternalCodeCache, mi::MethodInstance)
55+
return ccall(:jl_rettype_inferred, Any, (Any, Any, UInt, UInt), wvc.owner, mi, first(wvc.worlds), last(wvc.worlds)) !== nothing
56+
end
57+
58+
function get(wvc::InternalCodeCache, mi::MethodInstance, default)
59+
r = ccall(:jl_rettype_inferred, Any, (Any, Any, UInt, UInt), wvc.owner, mi, first(wvc.worlds), last(wvc.worlds))
7260
if r === nothing
7361
return default
7462
end
7563
return r::CodeInstance
7664
end
7765

78-
function getindex(wvc::WorldView{InternalCodeCache}, mi::MethodInstance)
66+
function getindex(wvc::InternalCodeCache, mi::MethodInstance)
7967
r = get(wvc, mi, nothing)
8068
r === nothing && throw(KeyError(mi))
8169
return r::CodeInstance
8270
end
83-
84-
function setindex!(wvc::WorldView{InternalCodeCache}, ci::CodeInstance, mi::MethodInstance)
85-
setindex!(wvc.cache, ci, mi)
86-
return wvc
87-
end

Compiler/src/inferenceresult.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ function is_argtype_match(𝕃::AbstractLattice,
8989
end
9090
end
9191

92-
function va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, nargs::UInt, isva::Bool)
92+
function va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any}, nargs::UInt, isva::Bool, mi::MethodInstance)
9393
nargs = Int(nargs)
9494
if isva || (!isempty(given_argtypes) && isvarargtype(given_argtypes[end]))
9595
isva_given_argtypes = Vector{Any}(undef, nargs)
@@ -120,7 +120,10 @@ function va_process_argtypes(𝕃::AbstractLattice, given_argtypes::Vector{Any},
120120
end
121121
return isva_given_argtypes
122122
end
123-
@assert length(given_argtypes) == nargs "invalid `given_argtypes` for `mi`"
123+
if length(given_argtypes) != nargs
124+
println(given_argtypes, " != ", nargs, " for ", mi)
125+
throw(AssertionError("invalid `given_argtypes` for `mi`"))
126+
end
124127
return given_argtypes
125128
end
126129

@@ -178,16 +181,17 @@ function elim_free_typevars(@nospecialize t)
178181
end
179182
end
180183

181-
function cache_lookup(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes::Vector{Any},
182-
cache::Vector{InferenceResult})
184+
function constprop_cache_lookup(𝕃::AbstractLattice, mi::MethodInstance, given_argtypes::Vector{Any}, cache::Vector{InferenceResult})
183185
method = mi.def::Method
184186
nargtypes = length(given_argtypes)
185187
for cached_result in cache
186188
cached_result.tombstone && continue # ignore deleted entries (due to LimitedAccuracy)
187189
cached_result.linfo === mi || continue
188190
cache_argtypes = cached_result.argtypes
189191
@assert length(cache_argtypes) == nargtypes "invalid `cache_argtypes` for `mi`"
190-
cache_overridden_by_const = cached_result.overridden_by_const::BitVector
192+
cache_overridden_by_const = cached_result.overridden_by_const
193+
cache_overridden_by_const === nothing && continue
194+
cache_overridden_by_const = cache_overridden_by_const::BitVector
191195
for i in 1:nargtypes
192196
if !is_argtype_match(𝕃, given_argtypes[i], cache_argtypes[i], cache_overridden_by_const[i])
193197
@goto next_cache

Compiler/src/inferencestate.jl

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ end
217217
const CACHE_MODE_NULL = 0x00 # not cached, optimization optional
218218
const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization required
219219
const CACHE_MODE_LOCAL = 0x01 << 1 # cached locally, optimization required
220-
const CACHE_MODE_VOLATILE = 0x01 << 2 # not cached, optimization required
221220

222221
abstract type Handler end
223222
get_enter_idx(handler::Handler) = get_enter_idx_impl(handler)::Int
@@ -262,7 +261,7 @@ intersect(world::WorldWithRange, valid_worlds::WorldRange) =
262261
mutable struct InferenceState
263262
#= information about this method instance =#
264263
linfo::MethodInstance
265-
world::WorldWithRange
264+
valid_worlds::WorldRange
266265
mod::Module
267266
sptypes::Vector{VarState}
268267
slottypes::Vector{Any}
@@ -349,7 +348,7 @@ mutable struct InferenceState
349348
bb_vartable1 = bb_vartables[1] = VarTable(undef, nslots)
350349
argtypes = result.argtypes
351350

352-
argtypes = va_process_argtypes(typeinf_lattice(interp), argtypes, src.nargs, src.isva)
351+
argtypes = va_process_argtypes(typeinf_lattice(interp), argtypes, src.nargs, src.isva, mi)
353352

354353
nargtypes = length(argtypes)
355354
for i = 1:nslots
@@ -392,7 +391,7 @@ mutable struct InferenceState
392391
parentid = frameid = cycleid = 0
393392

394393
this = new(
395-
mi, WorldWithRange(world, valid_worlds), mod, sptypes, slottypes, src, cfg, spec_info,
394+
mi, valid_worlds, mod, sptypes, slottypes, src, cfg, spec_info,
396395
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, bb_saw_latestworld, ssavaluetypes, ssaflags, edges, stmt_info,
397396
tasks, pclimitations, limitations, cycle_backedges, callstack, parentid, frameid, cycleid,
398397
result, unreachable, bestguess, exc_bestguess, ipo_effects,
@@ -401,9 +400,6 @@ mutable struct InferenceState
401400
interp)
402401

403402
# some more setups
404-
if !iszero(cache_mode & CACHE_MODE_LOCAL)
405-
push!(get_inference_cache(interp), result)
406-
end
407403
if !iszero(cache_mode & CACHE_MODE_GLOBAL)
408404
push!(callstack, this)
409405
this.cycleid = this.frameid = length(callstack)
@@ -412,7 +408,7 @@ mutable struct InferenceState
412408
# Apply generated function restrictions
413409
if src.min_world != 1 || src.max_world != typemax(UInt)
414410
# From generated functions
415-
update_valid_age!(this, WorldRange(src.min_world, src.max_world))
411+
update_valid_age!(this, world, WorldRange(src.min_world, src.max_world))
416412
end
417413

418414
return this
@@ -615,8 +611,6 @@ function convert_cache_mode(cache_mode::Symbol)
615611
return CACHE_MODE_GLOBAL
616612
elseif cache_mode === :local
617613
return CACHE_MODE_LOCAL
618-
elseif cache_mode === :volatile
619-
return CACHE_MODE_VOLATILE
620614
elseif cache_mode === :no
621615
return CACHE_MODE_NULL
622616
end
@@ -821,7 +815,7 @@ mutable struct IRInterpretationState
821815
const spec_info::SpecInfo
822816
const ir::IRCode
823817
const mi::MethodInstance
824-
world::WorldWithRange
818+
valid_worlds::WorldRange
825819
curridx::Int
826820
time_caches::Float64
827821
time_paused::UInt64
@@ -836,9 +830,10 @@ mutable struct IRInterpretationState
836830
frameid::Int
837831
parentid::Int
838832

839-
function IRInterpretationState(interp::AbstractInterpreter,
840-
spec_info::SpecInfo, ir::IRCode, mi::MethodInstance, argtypes::Vector{Any},
841-
world::UInt, min_world::UInt, max_world::UInt)
833+
function IRInterpretationState(
834+
interp::AbstractInterpreter, spec_info::SpecInfo, ir::IRCode,
835+
mi::MethodInstance, argtypes::Vector{Any}, min_world::UInt, max_world::UInt
836+
)
842837
curridx = 1
843838
given_argtypes = Vector{Any}(undef, length(argtypes))
844839
for i = 1:length(given_argtypes)
@@ -856,28 +851,32 @@ mutable struct IRInterpretationState
856851
ssa_refined = BitSet()
857852
lazyreachability = LazyCFGReachability(ir)
858853
valid_worlds = WorldRange(min_world, max_world == typemax(UInt) ? get_world_counter() : max_world)
854+
if !(get_inference_world(interp) in valid_worlds)
855+
error("invalid age range update")
856+
end
859857
tasks = WorkThunk[]
860858
edges = Any[]
861859
callstack = AbsIntState[]
862-
return new(spec_info, ir, mi, WorldWithRange(world, valid_worlds),
860+
return new(spec_info, ir, mi, valid_worlds,
863861
curridx, 0.0, 0, argtypes_refined, ir.sptypes, tpdum,
864862
ssa_refined, lazyreachability, tasks, edges, callstack, 0, 0)
865863
end
866864
end
867865

868-
function IRInterpretationState(interp::AbstractInterpreter,
869-
codeinst::CodeInstance, mi::MethodInstance, argtypes::Vector{Any}, world::UInt)
866+
function IRInterpretationState(
867+
interp::AbstractInterpreter, codeinst::CodeInstance, mi::MethodInstance,
868+
argtypes::Vector{Any}, @nospecialize(src)
869+
)
870870
@assert get_ci_mi(codeinst) === mi "method instance is not synced with code instance"
871-
src = @atomic :monotonic codeinst.inferred
872871
if isa(src, String)
873872
src = _uncompressed_ir(codeinst, src)
874873
else
875874
isa(src, CodeInfo) || return nothing
876875
end
877876
spec_info = SpecInfo(src)
878877
ir = inflate_ir(src, mi)
879-
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, src.nargs, src.isva)
880-
return IRInterpretationState(interp, spec_info, ir, mi, argtypes, world,
878+
argtypes = va_process_argtypes(optimizer_lattice(interp), argtypes, src.nargs, src.isva, mi)
879+
return IRInterpretationState(interp, spec_info, ir, mi, argtypes,
881880
codeinst.min_world, codeinst.max_world)
882881
end
883882

@@ -900,7 +899,7 @@ function print_callstack(frame::AbsIntState)
900899
end
901900
print("] ")
902901
print(frame_instance(sv))
903-
is_cached(sv) || print(" [uncached]")
902+
is_cached(sv) || print(" [not globally cached]")
904903
sv.parentid == idx - 1 || print(" [parent=", sv.parentid, "]")
905904
isempty(callers_in_cycle(sv)) || print(" [cycle=", sv.cycleid, "]")
906905
println()
@@ -964,9 +963,6 @@ spec_info(sv::IRInterpretationState) = sv.spec_info
964963
propagate_inbounds(sv::AbsIntState) = spec_info(sv).propagate_inbounds
965964
method_for_inference_limit_heuristics(sv::AbsIntState) = spec_info(sv).method_for_inference_limit_heuristics
966965

967-
frame_world(sv::InferenceState) = sv.world.this
968-
frame_world(sv::IRInterpretationState) = sv.world.this
969-
970966
function is_effect_overridden(sv::AbsIntState, effect::Symbol)
971967
if is_effect_overridden(frame_instance(sv), effect)
972968
return true
@@ -986,9 +982,13 @@ has_conditional(𝕃::AbstractLattice, ::InferenceState) = has_conditional(𝕃)
986982
has_conditional(::AbstractLattice, ::IRInterpretationState) = false
987983

988984
# work towards converging the valid age range for sv
989-
function update_valid_age!(sv::AbsIntState, valid_worlds::WorldRange)
990-
sv.world = intersect(sv.world, valid_worlds)
991-
return sv.world.valid_worlds
985+
function update_valid_age!(sv::AbsIntState, world, valid_worlds::WorldRange)
986+
valid_worlds = intersect(sv.valid_worlds, valid_worlds)
987+
if !(world in valid_worlds)
988+
error("invalid age range update")
989+
end
990+
sv.valid_worlds = valid_worlds
991+
return valid_worlds
992992
end
993993

994994
"""

Compiler/src/optimize.jl

Lines changed: 15 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,9 @@ end
9999

100100
const TOP_TUPLE = GlobalRef(Core, :tuple)
101101

102-
# This corresponds to the type of `CodeInfo`'s `inlining_cost` field
103-
const InlineCostType = UInt16
104-
const MAX_INLINE_COST = typemax(InlineCostType)
105-
const MIN_INLINE_COST = InlineCostType(10)
106-
const MaybeCompressed = Union{CodeInfo, String}
107-
108-
is_inlineable(@nospecialize src::MaybeCompressed) =
109-
ccall(:jl_ir_inlining_cost, InlineCostType, (Any,), src) != MAX_INLINE_COST
102+
inlining_cost(@nospecialize src) =
103+
src isa Union{MaybeCompressed,UInt8} ? ccall(:jl_ir_inlining_cost, InlineCostType, (Any,), src) : MAX_INLINE_COST
104+
is_inlineable(@nospecialize src) = inlining_cost(src) != MAX_INLINE_COST
110105
set_inlineable!(src::CodeInfo, val::Bool) =
111106
src.inlining_cost = (val ? MIN_INLINE_COST : MAX_INLINE_COST)
112107

@@ -158,47 +153,38 @@ end
158153

159154
struct InliningState{Interp<:AbstractInterpreter}
160155
edges::Vector{Any}
161-
world::UInt
162156
interp::Interp
163157
opt_cache::IdDict{MethodInstance,CodeInstance}
164158
end
165159
function InliningState(sv::InferenceState, interp::AbstractInterpreter,
166160
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
167-
return InliningState(sv.edges, frame_world(sv), interp, opt_cache)
161+
return InliningState(sv.edges, interp, opt_cache)
168162
end
169163
function InliningState(interp::AbstractInterpreter,
170164
opt_cache::IdDict{MethodInstance,CodeInstance}=IdDict{MethodInstance,CodeInstance}())
171-
return InliningState(Any[], get_inference_world(interp), interp, opt_cache)
165+
return InliningState(Any[], interp, opt_cache)
172166
end
173167

174168
struct OptimizerCache{CodeCache}
175-
wvc::WorldView{CodeCache}
176-
owner
169+
cache::CodeCache
177170
opt_cache::IdDict{MethodInstance,CodeInstance}
178171
function OptimizerCache(
179-
wvc::WorldView{CodeCache},
180-
@nospecialize(owner),
172+
cache::CodeCache,
181173
opt_cache::IdDict{MethodInstance,CodeInstance}) where CodeCache
182-
new{CodeCache}(wvc, owner, opt_cache)
174+
return new{CodeCache}(cache, opt_cache)
183175
end
184176
end
185-
function get((; wvc, owner, opt_cache)::OptimizerCache, mi::MethodInstance, default)
177+
function get((; cache, opt_cache)::OptimizerCache, mi::MethodInstance, default)
186178
if haskey(opt_cache, mi)
187-
codeinst = opt_cache[mi]
188-
@assert codeinst.min_world wvc.worlds.min_world &&
189-
wvc.worlds.max_world codeinst.max_world &&
190-
codeinst.owner === owner
191-
@assert isdefined(codeinst, :inferred) && codeinst.inferred === nothing
192-
return codeinst
179+
return opt_cache[mi] # this is incomplete right now, but will be finished (by finish_cycle) before caching anything
193180
end
194-
return get(wvc, mi, default)
181+
return get(cache, mi, default)
195182
end
196183

197184
# get `code_cache(::AbstractInterpreter)` from `state::InliningState`
198185
function code_cache(state::InliningState)
199-
cache = WorldView(code_cache(state.interp), state.world)
200-
owner = cache_owner(state.interp)
201-
return OptimizerCache(cache, owner, state.opt_cache)
186+
cache = code_cache(state.interp)
187+
return OptimizerCache(cache, state.opt_cache)
202188
end
203189

204190
mutable struct OptimizationResult
@@ -678,15 +664,15 @@ GetNativeEscapeCache(interp::AbstractInterpreter) = GetNativeEscapeCache(code_ca
678664
function ((; code_cache)::GetNativeEscapeCache)(codeinst::Union{CodeInstance,MethodInstance})
679665
if codeinst isa MethodInstance
680666
codeinst = get(code_cache, codeinst, nothing)
681-
codeinst isa CodeInstance || return false
667+
codeinst === nothing && return false
682668
end
683669
argescapes = traverse_analysis_results(codeinst) do @nospecialize result
684670
return result isa EscapeAnalysis.ArgEscapeCache ? result : nothing
685671
end
686672
if argescapes !== nothing
687673
return argescapes
688674
end
689-
effects = decode_effects(codeinst.ipo_purity_bits)
675+
effects = codeinst isa CodeInstance ? decode_effects(codeinst.ipo_purity_bits) : codeinst.ipo_effects
690676
if is_effect_free(effects) && is_inaccessiblememonly(effects)
691677
# We might not have run EA on simple frames without any escapes (e.g. when optimization
692678
# is skipped when result is constant-folded by abstract interpretation). If those

0 commit comments

Comments
 (0)