From f713a4b592bb9abda1007dbfdb1bc1ec22f621ef Mon Sep 17 00:00:00 2001 From: Harmen Stoppels Date: Wed, 2 Jan 2019 09:59:13 +0100 Subject: [PATCH] Add test macro for relaxed inference of small unions (#27516) --- stdlib/Test/src/Test.jl | 59 +++++++++++++++++++++++++++--------- stdlib/Test/test/runtests.jl | 32 +++++++++---------- 2 files changed, 59 insertions(+), 32 deletions(-) diff --git a/stdlib/Test/src/Test.jl b/stdlib/Test/src/Test.jl index cefef17004517..4818412fe5a29 100644 --- a/stdlib/Test/src/Test.jl +++ b/stdlib/Test/src/Test.jl @@ -29,6 +29,7 @@ import Distributed: myid using Random using Random: AbstractRNG, GLOBAL_RNG using InteractiveUtils: gen_call_with_extracted_types +using Core.Compiler: typesubtract #----------------------------------------------------------------------- @@ -1262,39 +1263,65 @@ end _args_and_call(args...; kwargs...) = (args[1:end-1], kwargs, args[end](args[1:end-1]...; kwargs...)) _materialize_broadcasted(f, args...) = Broadcast.materialize(Broadcast.broadcasted(f, args...)) """ - @inferred f(x) + @inferred [AllowedType] f(x) -Tests that the call expression `f(x)` returns a value of the same type -inferred by the compiler. It is useful to check for type stability. +Tests that the call expression `f(x)` returns a value of the same type inferred by the +compiler. It is useful to check for type stability. -`f(x)` can be any call expression. -Returns the result of `f(x)` if the types match, -and an `Error` `Result` if it finds different types. +`f(x)` can be any call expression. Returns the result of `f(x)` if the types match, and an +`Error` `Result` if it finds different types. + +Optionally, `AllowedType` relaxes the test, by making it pass when either the type of `f(x)` +matches the inferred type modulo `AllowedType`, or when the return type is a subtype of +`AllowedType`. This is useful when testing type stability of functions returning a small +union such as `Union{Nothing, T}` or `Union{Missing, T}`. ```jldoctest; setup = :(using InteractiveUtils), filter = r"begin\\n(.|\\n)*end" -julia> f(a, b, c) = b > 1 ? 1 : 1.0 +julia> f(a) = a > 1 ? 1 : 1.0 f (generic function with 1 method) -julia> typeof(f(1, 2, 3)) +julia> typeof(f(2)) Int64 -julia> @code_warntype f(1, 2, 3) +julia> @code_warntype f(2) Body::UNION{FLOAT64, INT64} -1 ─ %1 = (Base.slt_int)(1, b)::Bool +1 ─ %1 = (Base.slt_int)(1, a)::Bool └── goto #3 if not %1 2 ─ return 1 3 ─ return 1.0 -julia> @inferred f(1, 2, 3) +julia> @inferred f(2) ERROR: return type Int64 does not match inferred return type Union{Float64, Int64} -Stacktrace: [...] julia> @inferred max(1, 2) 2 + +julia> g(a) = a < 10 ? missing : 1.0 +g (generic function with 1 method) + +julia> @inferred g(20) +ERROR: return type Float64 does not match inferred return type Union{Missing, Float64} +[...] + +julia> @inferred Missing g(20) +1.0 + +julia> h(a) = a < 10 ? missing : f(a) +h (generic function with 1 method) + +julia> @inferred Missing h(20) +ERROR: return type Int64 does not match inferred return type Union{Missing, Float64, Int64} +[...] ``` """ macro inferred(ex) + _inferred(ex, __module__) +end +macro inferred(allow, ex) + _inferred(ex, __module__, allow) +end +function _inferred(ex, mod, allow = :(Union{})) if Meta.isexpr(ex, :ref) ex = Expr(:call, :getindex, ex.args...) end @@ -1307,13 +1334,15 @@ macro inferred(ex) end Base.remove_linenums!(quote let + allow = $(esc(allow)) + allow isa Type || throw(ArgumentError("@inferred requires a type as second argument")) $(if any(a->(Meta.isexpr(a, :kw) || Meta.isexpr(a, :parameters)), ex.args) # Has keywords args = gensym() kwargs = gensym() quote $(esc(args)), $(esc(kwargs)), result = $(esc(Expr(:call, _args_and_call, ex.args[2:end]..., ex.args[1]))) - inftypes = $(gen_call_with_extracted_types(__module__, Base.return_types, :($(ex.args[1])($(args)...; $(kwargs)...)))) + inftypes = $(gen_call_with_extracted_types(mod, Base.return_types, :($(ex.args[1])($(args)...; $(kwargs)...)))) end else # No keywords @@ -1324,8 +1353,8 @@ macro inferred(ex) end end) @assert length(inftypes) == 1 - rettype = isa(result, Type) ? Type{result} : typeof(result) - rettype == inftypes[1] || error("return type $rettype does not match inferred return type $(inftypes[1])") + rettype = result isa Type ? Type{result} : typeof(result) + rettype <: allow || rettype == typesubtract(inftypes[1], allow) || error("return type $rettype does not match inferred return type $(inftypes[1])") result end end) diff --git a/stdlib/Test/test/runtests.jl b/stdlib/Test/test/runtests.jl index 5809ec8756239..ad7d01dd39459 100644 --- a/stdlib/Test/test/runtests.jl +++ b/stdlib/Test/test/runtests.jl @@ -491,13 +491,15 @@ for i in 1:6 end # test @inferred -function uninferrable_function(i) - q = [1, "1"] - return q[i] -end - +uninferrable_function(i) = (1, "1")[i] +uninferrable_small_union(i) = (1, nothing)[i] @test_throws ErrorException @inferred(uninferrable_function(1)) @test @inferred(identity(1)) == 1 +@test @inferred(Nothing, uninferrable_small_union(1)) === 1 +@test @inferred(Nothing, uninferrable_small_union(2)) === nothing +@test_throws ErrorException @inferred(Missing, uninferrable_small_union(1)) +@test_throws ErrorException @inferred(Missing, uninferrable_small_union(2)) +@test_throws ArgumentError @inferred(nothing, uninferrable_small_union(1)) # Ensure @inferred only evaluates the arguments once inferred_test_global = 0 @@ -512,8 +514,8 @@ end struct SillyArray <: AbstractArray{Float64,1} end Base.getindex(a::SillyArray, i) = rand() > 0.5 ? 0 : false @testset "@inferred works with A[i] expressions" begin - @test @inferred((1:3)[2]) == 2 - test_result = @test_throws ErrorException @inferred(SillyArray()[2]) + @test (@inferred (1:3)[2]) == 2 + test_result = @test_throws ErrorException (@inferred SillyArray()[2]) @test occursin("Bool", test_result.value.msg) end # Issue #14928 @@ -522,16 +524,12 @@ end # Issue #17105 # @inferred with kwargs -function inferrable_kwtest(x; y=1) - 2x -end -function uninferrable_kwtest(x; y=1) - 2x+y -end -@test @inferred(inferrable_kwtest(1)) == 2 -@test @inferred(inferrable_kwtest(1; y=1)) == 2 -@test @inferred(uninferrable_kwtest(1)) == 3 -@test @inferred(uninferrable_kwtest(1; y=2)) == 4 +inferrable_kwtest(x; y=1) = 2x +uninferrable_kwtest(x; y=1) = 2x+y +@test (@inferred inferrable_kwtest(1)) == 2 +@test (@inferred inferrable_kwtest(1; y=1)) == 2 +@test (@inferred uninferrable_kwtest(1)) == 3 +@test (@inferred uninferrable_kwtest(1; y=2)) == 4 @test_throws ErrorException @testset "$(error())" for i in 1:10 end