Skip to content

Commit

Permalink
Add test macro for relaxed inference of small unions
Browse files Browse the repository at this point in the history
  • Loading branch information
haampie committed Dec 21, 2018
1 parent 7c3904a commit fa19651
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 35 deletions.
65 changes: 47 additions & 18 deletions stdlib/Test/src/Test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import Distributed: myid
using Random
using Random: AbstractRNG, GLOBAL_RNG
using InteractiveUtils: gen_call_with_extracted_types
using Core.Compiler: typesubtract

#-----------------------------------------------------------------------

Expand Down Expand Up @@ -1262,39 +1263,65 @@ end
_args_and_call(args...; kwargs...) = (args[1:end-1], kwargs, args[end](args[1:end-1]...; kwargs...))
_materialize_broadcasted(f, args...) = Broadcast.materialize(Broadcast.broadcasted(f, args...))
"""
@inferred f(x)
@inferred [AllowedType] f(x)
Tests that the call expression `f(x)` returns a value of the same type
inferred by the compiler. It is useful to check for type stability.
Tests that the call expression `f(x)` returns a value of the same type inferred by the
compiler. It is useful to check for type stability.
`f(x)` can be any call expression.
Returns the result of `f(x)` if the types match,
and an `Error` `Result` if it finds different types.
`f(x)` can be any call expression. Returns the result of `f(x)` if the types match, and an
`Error` `Result` if it finds different types.
Optionally, `AllowedType` relaxes the test, by making it pass when either the type of `f(x)`
matches the inferred type modulo `AllowedType`, or when the return type is a subtype of
`AllowedType`. This is useful when testing type stability of functions returning a small
union such as `Union{Nothing, T}` or `Union{Missing, T}`.
```jldoctest; setup = :(using InteractiveUtils), filter = r"begin\\n(.|\\n)*end"
julia> f(a, b, c) = b > 1 ? 1 : 1.0
julia> f(a) = a > 1 ? 1 : 1.0
f (generic function with 1 method)
julia> typeof(f(1, 2, 3))
julia> typeof(f(2))
Int64
julia> @code_warntype f(1, 2, 3)
julia> @code_warntype f(2)
Body::UNION{FLOAT64, INT64}
1 ─ %1 = (Base.slt_int)(1, b)::Bool
└── goto #3 if not %1
2 ─ return 1
3 ─ return 1.0
1 1 ─ %1 = (Base.slt_int)(1, a)::Bool
└── goto #3 if not %1
2 ─ return 1
3 ─ return 1.0
julia> @inferred f(1, 2, 3)
julia> @inferred f(2)
ERROR: return type Int64 does not match inferred return type Union{Float64, Int64}
Stacktrace:
[...]
julia> @inferred max(1, 2)
2
julia> g(a) = a < 10 ? missing : 1.0
g (generic function with 1 method)
julia> @inferred g(20)
ERROR: return type Float64 does not match inferred return type Union{Missing, Float64}
[...]
julia> @inferred Missing g(20)
1.0
julia> h(a) = a < 10 ? missing : f(a)
h (generic function with 1 method)
julia> @inferred Missing h(20)
ERROR: return type Int64 does not match inferred return type Union{Missing, Float64, Int64}
[...]
```
"""
macro inferred(ex)
_inferred(ex, __module__)
end
macro inferred(allow, ex)
_inferred(ex, __module__, allow)
end
function _inferred(ex, mod, allow = :(Union{}))
if Meta.isexpr(ex, :ref)
ex = Expr(:call, :getindex, ex.args...)
end
Expand All @@ -1307,13 +1334,15 @@ macro inferred(ex)
end
Base.remove_linenums!(quote
let
allow = $(esc(allow))
allow isa Type || throw(ArgumentError("@inferred requires a type as second argument"))
$(if any(a->(Meta.isexpr(a, :kw) || Meta.isexpr(a, :parameters)), ex.args)
# Has keywords
args = gensym()
kwargs = gensym()
quote
$(esc(args)), $(esc(kwargs)), result = $(esc(Expr(:call, _args_and_call, ex.args[2:end]..., ex.args[1])))
inftypes = $(gen_call_with_extracted_types(__module__, Base.return_types, :($(ex.args[1])($(args)...; $(kwargs)...))))
inftypes = $(gen_call_with_extracted_types(mod, Base.return_types, :($(ex.args[1])($(args)...; $(kwargs)...))))
end
else
# No keywords
Expand All @@ -1324,8 +1353,8 @@ macro inferred(ex)
end
end)
@assert length(inftypes) == 1
rettype = isa(result, Type) ? Type{result} : typeof(result)
rettype == inftypes[1] || error("return type $rettype does not match inferred return type $(inftypes[1])")
rettype = result isa Type ? Type{result} : typeof(result)
rettype <: allow || rettype == typesubtract(inftypes[1], allow) || error("return type $rettype does not match inferred return type $(inftypes[1])")
result
end
end)
Expand Down
32 changes: 15 additions & 17 deletions stdlib/Test/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -491,13 +491,15 @@ for i in 1:6
end

# test @inferred
function uninferrable_function(i)
q = [1, "1"]
return q[i]
end

uninferrable_function(i) = (1, "1")[i]
uninferrable_small_union(i) = (1, nothing)[i]
@test_throws ErrorException @inferred(uninferrable_function(1))
@test @inferred(identity(1)) == 1
@test @inferred(Nothing, uninferrable_small_union(1)) === 1
@test @inferred(Nothing, uninferrable_small_union(2)) === nothing
@test_throws ErrorException @inferred(Missing, uninferrable_small_union(1))
@test_throws ErrorException @inferred(Missing, uninferrable_small_union(2))
@test_throws ArgumentError @inferred(nothing, uninferrable_small_union(1))

# Ensure @inferred only evaluates the arguments once
inferred_test_global = 0
Expand All @@ -512,8 +514,8 @@ end
struct SillyArray <: AbstractArray{Float64,1} end
Base.getindex(a::SillyArray, i) = rand() > 0.5 ? 0 : false
@testset "@inferred works with A[i] expressions" begin
@test @inferred((1:3)[2]) == 2
test_result = @test_throws ErrorException @inferred(SillyArray()[2])
@test (@inferred (1:3)[2]) == 2
test_result = @test_throws ErrorException (@inferred SillyArray()[2])
@test occursin("Bool", test_result.value.msg)
end
# Issue #14928
Expand All @@ -522,16 +524,12 @@ end

# Issue #17105
# @inferred with kwargs
function inferrable_kwtest(x; y=1)
2x
end
function uninferrable_kwtest(x; y=1)
2x+y
end
@test @inferred(inferrable_kwtest(1)) == 2
@test @inferred(inferrable_kwtest(1; y=1)) == 2
@test @inferred(uninferrable_kwtest(1)) == 3
@test @inferred(uninferrable_kwtest(1; y=2)) == 4
inferrable_kwtest(x; y=1) = 2x
uninferrable_kwtest(x; y=1) = 2x+y
@test (@inferred inferrable_kwtest(1)) == 2
@test (@inferred inferrable_kwtest(1; y=1)) == 2
@test (@inferred uninferrable_kwtest(1)) == 3
@test (@inferred uninferrable_kwtest(1; y=2)) == 4

@test_throws ErrorException @testset "$(error())" for i in 1:10
end
Expand Down

0 comments on commit fa19651

Please sign in to comment.