inference: implement an opt-in interface to cache generated sources (JuliaLang#54916)

In Cassette-like systems, where inference has to infer many calls to
`@generated` functions whose generators involve complex code
transformations, the overhead from code generation itself can become
significant. This is because the results of code generation are not
cached, leading to duplicated code generation in the following contexts:
- `method_for_inference_heuristics` for regular inference on cached
`@generated` function calls (since
`method_for_inference_limit_heuristics` isn't stored in cached optimized
sources, but is attached to generated unoptimized sources).
- `retrieve_code_info` for constant propagation on cached `@generated`
function calls.

That said, caching unoptimized sources generated by `@generated`
functions is not a good tradeoff in the general case, considering the
memory consumed (and the resulting image bloat). The code generation
performed by generators like the `GeneratedFunctionStub`s produced by
the front end is generally very simple, and the first kind of
duplicated code generation listed above does not occur for
`GeneratedFunctionStub` in any case.

So this unoptimized-source caching should be enabled on an opt-in
basis.
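
For instance, a typical front-end generator only splices argument types
into a small expression, so regenerating it is cheap and caching its
output would mostly waste memory. A hypothetical illustration (not part
of this commit):

```julia
# An ordinary `@generated` function: the front end wraps its body in a
# `GeneratedFunctionStub`. Generation is a cheap expression splice, so
# caching the generated (unoptimized) source would not pay off.
@generated function tuplesum(xs::Tuple)
    # inside the generator, `xs` is bound to the tuple *type*,
    # e.g. `Tuple{Int, Int, Int}` for a 3-element call
    return Expr(:call, :+, (:(xs[$i]) for i in 1:fieldcount(xs))...)
end

tuplesum((1, 2, 3)) # == 6
```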

Based on this idea, this commit defines the trait `abstract type
Core.CachedGenerator end` as an interface for external systems to opt
in. If a generator is a subtype of this trait, inference caches the
generated unoptimized code, sacrificing memory space to improve the
performance of subsequent inferences. Specifically, the caching
mechanism reuses the infrastructure implemented in JuliaLang#54362:
since that PR partitioned the cache for generated functions by world
age, the existing invalidation system invalidates a cached unoptimized
source as expected.
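
Opting in looks roughly like the following minimal sketch
(`MyGenerator` and `myfunc` are hypothetical names; the pattern mirrors
the new test added in `test/staged.jl` below):

```julia
# Subtyping `Core.CachedGenerator` opts this generator into caching of
# the unoptimized source it produces.
struct MyGenerator <: Core.CachedGenerator end

# The generator callable receives the world age, the `LineNumberNode`
# of the definition, and the argument types, and returns generated code
# through a `Core.GeneratedFunctionStub`.
function (::MyGenerator)(world::UInt, source::LineNumberNode, args...)
    stub = Core.GeneratedFunctionStub(identity, Core.svec(:myfunc, :f, :x), Core.svec())
    # a real Cassette-like system would run its expensive code transformation here
    return stub(world, source, :(f(x)))
end

# Attach the generator to a function via the `:generated` meta.
@eval function myfunc(f, x)
    $(Expr(:meta, :generated, MyGenerator()))
    $(Expr(:meta, :generated_only))
end
```

With this, inference caches the unoptimized source generated for each
`myfunc` specialization instead of regenerating it on every query.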

In JuliaDebug/CassetteOverlay.jl#56, the following benchmark results
showed that an approximately 1.5–3x inference speedup is achieved by
opting into this feature:

## Setup
```julia
using CassetteOverlay, BaseBenchmarks, BenchmarkTools

@MethodTable table;
pass = @overlaypass table;
BaseBenchmarks.load!("inference");
benchfunc1() = sin(42)
benchfunc2(xs, x) = findall(>(x), abs.(xs))
interp = BaseBenchmarks.InferenceBenchmarks.InferenceBenchmarker()

# benchmark inference on entire call graphs from scratch
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc1)
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc2, rand(10), 0.5)

# benchmark inference on the call graphs with most of them cached
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc1)
@benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc2, rand(10), 0.5)
```

## Benchmark inference on entire call graphs from scratch
> on master
```
julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc1)
BenchmarkTools.Trial: 61 samples with 1 evaluation.
 Range (min … max):  78.574 ms … 87.653 ms  ┊ GC (min … max): 0.00% … 8.81%
 Time  (median):     83.149 ms              ┊ GC (median):    4.85%
 Time  (mean ± σ):   82.138 ms ±  2.366 ms  ┊ GC (mean ± σ):  3.36% ± 2.65%

  ▂ ▂▂ █     ▂                     █ ▅     ▅
  █▅██▅█▅▁█▁▁█▁▁▁▁▅▁▁▁▁▁▁▁▁▅▁▁▅██▅▅█████████▁█▁▅▁▁▁▁▁▁▁▁▁▁▁▁▅ ▁
  78.6 ms         Histogram: frequency by time        86.8 ms <

 Memory estimate: 52.32 MiB, allocs estimate: 1201192.

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc2, rand(10), 0.5)
BenchmarkTools.Trial: 4 samples with 1 evaluation.
 Range (min … max):  1.345 s …  1.369 s  ┊ GC (min … max): 2.45% … 3.39%
 Time  (median):     1.355 s             ┊ GC (median):    2.98%
 Time  (mean ± σ):   1.356 s ± 9.847 ms  ┊ GC (mean ± σ):  2.96% ± 0.41%

  █                   █     █                            █
  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  1.35 s        Histogram: frequency by time        1.37 s <

 Memory estimate: 637.96 MiB, allocs estimate: 15159639.
```
> with this PR
```
julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc1)
BenchmarkTools.Trial: 230 samples with 1 evaluation.
 Range (min … max):  19.339 ms … 82.521 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     19.938 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   21.665 ms ±  4.666 ms  ┊ GC (mean ± σ):  6.72% ± 8.80%

  ▃▇█▇▄                     ▂▂▃▃▄
  █████▇█▇▆▅▅▆▅▅▁▅▁▁▁▁▁▁▁▁▁██████▆▁█▁▅▇▆▁▅▁▁▅▁▅▁▁▁▁▁▁▅▁▁▁▁▁▁▅ ▆
  19.3 ms      Histogram: log(frequency) by time      29.4 ms <

 Memory estimate: 28.67 MiB, allocs estimate: 590138.

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call pass(benchfunc2, rand(10), 0.5)
BenchmarkTools.Trial: 14 samples with 1 evaluation.
 Range (min … max):  354.585 ms … 390.400 ms  ┊ GC (min … max): 0.00% … 7.01%
 Time  (median):     368.778 ms               ┊ GC (median):    3.74%
 Time  (mean ± σ):   368.824 ms ±   8.853 ms  ┊ GC (mean ± σ):  3.70% ± 1.89%

             ▃            █
  ▇▁▁▁▁▁▁▁▁▁▁█▁▇▇▁▁▁▁▇▁▁▁▁█▁▁▁▁▇▁▁▇▁▁▇▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁
  355 ms           Histogram: frequency by time          390 ms <

 Memory estimate: 227.86 MiB, allocs estimate: 4689830.
```

## Benchmark inference on the call graphs with most of them cached
> on master
```
julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc1)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  45.166 μs …  9.799 ms  ┊ GC (min … max): 0.00% … 98.96%
 Time  (median):     46.792 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   48.339 μs ± 97.539 μs  ┊ GC (mean ± σ):  2.01% ±  0.99%

    ▁▂▄▆▆▇███▇▆▅▄▃▄▄▂▂▂▁▁▁   ▁▁▂▂▁ ▁ ▂▁ ▁                     ▃
  ▃▇██████████████████████▇████████████████▇█▆▇▇▆▆▆▅▆▆▆▇▆▅▅▅▆ █
  45.2 μs      Histogram: log(frequency) by time        55 μs <

 Memory estimate: 25.27 KiB, allocs estimate: 614.

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc2, rand(10), 0.5)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  303.375 μs …  16.582 ms  ┊ GC (min … max): 0.00% … 97.38%
 Time  (median):     317.625 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   338.772 μs ± 274.164 μs  ┊ GC (mean ± σ):  5.44% ±  7.56%

       ▃▆██▇▅▂▁
  ▂▂▄▅██████████▇▆▅▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▁▂▁▁▂▁▂▂ ▃
  303 μs           Histogram: frequency by time          394 μs <

 Memory estimate: 412.80 KiB, allocs estimate: 6224.
```
> with this PR
```
julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc1)
BenchmarkTools.Trial: 10000 samples with 6 evaluations.
 Range (min … max):  5.444 μs …  1.808 ms  ┊ GC (min … max): 0.00% … 99.01%
 Time  (median):     5.694 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   6.228 μs ± 25.393 μs  ┊ GC (mean ± σ):  5.73% ±  1.40%

      ▄█▇▄
  ▁▂▄█████▇▄▃▃▃▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  5.44 μs        Histogram: frequency by time        7.47 μs <

 Memory estimate: 8.72 KiB, allocs estimate: 196.

julia> @benchmark BaseBenchmarks.InferenceBenchmarks.@inf_call interp=interp pass(benchfunc2, rand(10), 0.5)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  211.000 μs …  36.187 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     223.000 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   280.025 μs ± 750.097 μs  ┊ GC (mean ± σ):  6.86% ± 7.16%

  █▆▄▂▁                                                         ▁
  ███████▇▇▇▆▆▆▅▆▅▅▅▅▅▄▅▄▄▄▅▅▁▄▅▃▄▄▄▃▄▄▃▅▄▁▁▃▄▁▃▁▁▁▃▄▃▁▃▁▁▁▃▃▁▃ █
  211 μs        Histogram: log(frequency) by time       1.46 ms <

 Memory estimate: 374.17 KiB, allocs estimate: 5269.
```
aviatesk committed Jun 28, 2024
1 parent ca0b2a8 commit b5d0b90
Showing 4 changed files with 72 additions and 30 deletions.
7 changes: 7 additions & 0 deletions base/boot.jl

@@ -750,6 +750,13 @@ function (g::GeneratedFunctionStub)(world::UInt, source::LineNumberNode, @nospec
     end
 end
 
+# If the generator is a subtype of this trait, inference caches the generated unoptimized
+# code, sacrificing memory space to improve the performance of subsequent inferences.
+# This tradeoff is not appropriate in general cases (e.g., for `GeneratedFunctionStub`s
+# generated from the front end), but it can be justified for generators involving complex
+# code transformations, such as a Cassette-like system.
+abstract type CachedGenerator end
+
 NamedTuple() = NamedTuple{(),Tuple{}}(())
 
 eval(Core, :(NamedTuple{names}(args::Tuple) where {names} =
16 changes: 15 additions & 1 deletion base/compiler/utilities.jl

@@ -138,13 +138,27 @@ use_const_api(li::CodeInstance) = invoke_api(li) == 2
 
 function get_staged(mi::MethodInstance, world::UInt)
     may_invoke_generator(mi) || return nothing
+    cache_ci = (mi.def::Method).generator isa Core.CachedGenerator ?
+        RefValue{CodeInstance}() : nothing
     try
-        return ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{Cvoid}), mi, world, C_NULL)
+        return call_get_staged(mi, world, cache_ci)
     catch # user code might throw errors – ignore them
         return nothing
     end
 end
 
+# enable caching of unoptimized generated code if the generator is `CachedGenerator`
+function call_get_staged(mi::MethodInstance, world::UInt, cache_ci::RefValue{CodeInstance})
+    token = @_gc_preserve_begin cache_ci
+    cache_ci_ptr = pointer_from_objref(cache_ci)
+    src = ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{CodeInstance}), mi, world, cache_ci_ptr)
+    @_gc_preserve_end token
+    return src
+end
+function call_get_staged(mi::MethodInstance, world::UInt, ::Nothing)
+    return ccall(:jl_code_for_staged, Ref{CodeInfo}, (Any, UInt, Ptr{Cvoid}), mi, world, C_NULL)
+end
+
 function get_cached_uninferred(mi::MethodInstance, world::UInt)
     ccall(:jl_cached_uninferred, Any, (Any, UInt), mi.cache, world)::CodeInstance
 end
54 changes: 25 additions & 29 deletions test/compiler/contextual.jl

@@ -10,12 +10,9 @@ module MiniCassette
     # A minimal demonstration of the cassette mechanism. Doesn't support all the
     # fancy features, but sufficient to exercise this code path in the compiler.
 
-    using Core.Compiler: retrieve_code_info, CodeInfo,
-        MethodInstance, SSAValue, GotoNode, GotoIfNot, ReturnNode, SlotNumber, quoted,
-        signature_type, anymap
-    using Base: _methods_by_ftype
+    using Core.IR
+    using Core.Compiler: retrieve_code_info, quoted, signature_type, anymap
     using Base.Meta: isexpr
-    using Test
 
     export Ctx, overdub
 
@@ -75,38 +72,36 @@
         end
     end
 
-    function overdub_generator(world::UInt, source, self, c, f, args)
+    function overdub_generator(world::UInt, source, self, ctx, f, args)
         @nospecialize
         if !Base.issingletontype(f)
             # (c, f, args..) -> f(args...)
-            code_info = :(return f(args...))
-            return Core.GeneratedFunctionStub(identity, Core.svec(:overdub, :c, :f, :args), Core.svec())(world, source, code_info)
+            ex = :(return f(args...))
+            return Core.GeneratedFunctionStub(identity, Core.svec(:overdub, :ctx, :f, :args), Core.svec())(world, source, ex)
         end
 
         tt = Tuple{f, args...}
         match = Base._which(tt; world)
         mi = Core.Compiler.specialize_method(match)
         # Unsupported in this mini-cassette
         @assert !mi.def.isva
-        code_info = retrieve_code_info(mi, world)
-        @assert isa(code_info, CodeInfo)
-        code_info = copy(code_info)
-        @assert code_info.edges === nothing
-        code_info.edges = MethodInstance[mi]
-        transform!(mi, code_info, length(args), match.sparams)
+        src = retrieve_code_info(mi, world)
+        @assert isa(src, CodeInfo)
+        src = copy(src)
+        @assert src.edges === nothing
+        src.edges = MethodInstance[mi]
+        transform!(mi, src, length(args), match.sparams)
         # TODO: this is mandatory: code_info.min_world = max(code_info.min_world, min_world[])
         # TODO: this is mandatory: code_info.max_world = min(code_info.max_world, max_world[])
         # Match the generator, since that's what our transform! does
-        code_info.nargs = 4
-        code_info.isva = true
-        return code_info
+        src.nargs = 4
+        src.isva = true
+        return src
     end
 
-    @inline function overdub(c::Ctx, f::Union{Core.Builtin, Core.IntrinsicFunction}, args...)
-        f(args...)
-    end
+    @inline overdub(::Ctx, f::Union{Core.Builtin, Core.IntrinsicFunction}, args...) = f(args...)
 
-    @eval function overdub(c::Ctx, f, args...)
+    @eval function overdub(ctx::Ctx, f, args...)
         $(Expr(:meta, :generated_only))
         $(Expr(:meta, :generated, overdub_generator))
     end
@@ -149,14 +144,15 @@
 
 end # module OverlayModule
 
-methods = Base._methods_by_ftype(Tuple{typeof(sin), Float64}, nothing, 1, Base.get_world_counter())
-@test only(methods).method.module === Base.Math
-
-methods = Base._methods_by_ftype(Tuple{typeof(sin), Float64}, OverlayModule.mt, 1, Base.get_world_counter())
-@test only(methods).method.module === OverlayModule
-
-methods = Base._methods_by_ftype(Tuple{typeof(sin), Int}, OverlayModule.mt, 1, Base.get_world_counter())
-@test isempty(methods)
+let ms = Base._methods_by_ftype(Tuple{typeof(sin), Float64}, nothing, 1, Base.get_world_counter())
+    @test only(ms).method.module === Base.Math
+end
+let ms = Base._methods_by_ftype(Tuple{typeof(sin), Float64}, OverlayModule.mt, 1, Base.get_world_counter())
+    @test only(ms).method.module === OverlayModule
+end
+let ms = Base._methods_by_ftype(Tuple{typeof(sin), Int}, OverlayModule.mt, 1, Base.get_world_counter())
+    @test isempty(ms)
+end
 
 # precompilation
 
25 changes: 25 additions & 0 deletions test/staged.jl

@@ -377,3 +377,28 @@ let
     ir, _ = Base.code_ircode(f_unreachable, ()) |> only
     @test length(ir.cfg.blocks) == 1
 end
+
+# Test that `Core.CachedGenerator` works as expected
+struct Generator54916 <: Core.CachedGenerator end
+function (::Generator54916)(world::UInt, source::LineNumberNode, args...)
+    stub = Core.GeneratedFunctionStub(identity, Core.svec(:doit54916, :func, :arg), Core.svec())
+    return stub(world, source, :(func(arg)))
+end
+@eval function doit54916(func, arg)
+    $(Expr(:meta, :generated, Generator54916()))
+    $(Expr(:meta, :generated_only))
+end
+@test doit54916(sin, 1) == sin(1)
+let mi = only(methods(doit54916)).specializations
+    ci = mi.cache::Core.CodeInstance
+    found = false
+    while true
+        if ci.owner === :uninferred && ci.inferred isa Core.CodeInfo
+            found = true
+            break
+        end
+        isdefined(ci, :next) || break
+        ci = ci.next
+    end
+    @test found
+end
