Skip to content

Commit

Permalink
fix JuliaLang#29269, type intersection bug in union parameters with t…
Browse files Browse the repository at this point in the history
…ypevars (JuliaLang#29406)

also fixes JuliaLang#25752
  • Loading branch information
JeffBezanson committed Nov 27, 2018
1 parent 6778bef commit 4f8a7b9
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 40 deletions.
4 changes: 4 additions & 0 deletions base/missing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ nonmissingtype(::Type{Any}) = Any
for U in (:Nothing, :Missing)
@eval begin
promote_rule(::Type{$U}, ::Type{T}) where {T} = Union{T, $U}
promote_rule(::Type{Union{S,$U}}, ::Type{Any}) where {S} = Any
promote_rule(::Type{Union{S,$U}}, ::Type{T}) where {T,S} = Union{promote_type(T, S), $U}
promote_rule(::Type{Any}, ::Type{$U}) = Any
promote_rule(::Type{$U}, ::Type{Any}) = Any
Expand All @@ -37,13 +38,16 @@ end
promote_rule(::Type{Union{Nothing, Missing}}, ::Type{Any}) = Any
promote_rule(::Type{Union{Nothing, Missing}}, ::Type{T}) where {T} =
Union{Nothing, Missing, T}
promote_rule(::Type{Union{Nothing, Missing, S}}, ::Type{Any}) where {S} = Any
promote_rule(::Type{Union{Nothing, Missing, S}}, ::Type{T}) where {T,S} =
Union{Nothing, Missing, promote_type(T, S)}

convert(::Type{Union{T, Missing}}, x::Union{T, Missing}) where {T} = x
convert(::Type{Union{T, Missing}}, x) where {T} = convert(T, x)
# To fix ambiguities
convert(::Type{Missing}, ::Missing) = missing
convert(::Type{Union{Nothing, Missing}}, x::Union{Nothing, Missing}) = x
convert(::Type{Union{Nothing, Missing, T}}, x::Union{Nothing, Missing, T}) where {T} = x
convert(::Type{Union{Nothing, Missing}}, x) =
throw(MethodError(convert, (Union{Nothing, Missing}, x)))
# To print more appropriate message than "T not defined"
Expand Down
3 changes: 2 additions & 1 deletion base/some.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ promote_rule(::Type{Some{T}}, ::Type{Nothing}) where {T} = Union{Some{T}, Nothin
convert(::Type{Some{T}}, x::Some) where {T} = Some{T}(convert(T, x.value))
convert(::Type{Union{Some{T}, Nothing}}, x::Some) where {T} = convert(Some{T}, x)

convert(::Type{Union{T, Nothing}}, x::Union{T, Nothing}) where {T} = x
convert(::Type{Union{T, Nothing}}, x::Any) where {T} = convert(T, x)
convert(::Type{Nothing}, x::Any) = throw(MethodError(convert, (Nothing, x)))
convert(::Type{Nothing}, x::Nothing) = nothing
convert(::Type{Nothing}, x::Any) = throw(MethodError(convert, (Nothing, x)))

function show(io::IO, x::Some)
if get(io, :typeinfo, Any) == typeof(x)
Expand Down
71 changes: 35 additions & 36 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,7 @@ static int var_outside(jl_stenv_t *e, jl_tvar_t *x, jl_tvar_t *y)
return 0;
}

static jl_value_t *intersect_ufirst(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth);
static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth);

// check that type var `b` is <: `a`, and update b's upper bound.
static int var_lt(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int param)
Expand All @@ -539,7 +539,7 @@ static int var_lt(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int param)
// for this to work we need to compute issub(left,right) before issub(right,left),
// since otherwise the issub(a, bb.ub) check in var_gt becomes vacuous.
if (e->intersection) {
jl_value_t *ub = intersect_ufirst(bb->ub, a, e, bb->depth0);
jl_value_t *ub = intersect_aside(bb->ub, a, e, bb->depth0);
if (ub != (jl_value_t*)b)
bb->ub = ub;
}
Expand Down Expand Up @@ -1328,16 +1328,32 @@ JL_DLLEXPORT int jl_isa(jl_value_t *x, jl_value_t *t)

static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);

static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e);

// intersect in nested union environment, similar to subtype_ccheck
static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth)
{
jl_value_t *res;
int savedepth = e->invdepth;
jl_unionstate_t oldRunions = e->Runions;
e->invdepth = depth;

res = intersect_all(x, y, e);

e->Runions = oldRunions;
e->invdepth = savedepth;
return res;
}

static jl_value_t *intersect_union(jl_value_t *x, jl_uniontype_t *u, jl_stenv_t *e, int8_t R, int param)
{
if (param == 2 || (!jl_has_free_typevars(x) && !jl_has_free_typevars((jl_value_t*)u))) {
jl_value_t *a=NULL, *b=NULL, *save=NULL; jl_savedenv_t se;
JL_GC_PUSH3(&a, &b, &save);
save_env(e, &save, &se);
a = R ? intersect(x, u->a, e, param) : intersect(u->a, x, e, param);
restore_env(e, NULL, &se);
b = R ? intersect(x, u->b, e, param) : intersect(u->b, x, e, param);
free(se.buf);
jl_value_t *a=NULL, *b=NULL;
JL_GC_PUSH2(&a, &b);
jl_unionstate_t oldRunions = e->Runions;
a = R ? intersect_all(x, u->a, e) : intersect_all(u->a, x, e);
b = R ? intersect_all(x, u->b, e) : intersect_all(u->b, x, e);
e->Runions = oldRunions;
jl_value_t *i = simple_join(a,b);
JL_GC_POP();
return i;
Expand All @@ -1347,21 +1363,6 @@ static jl_value_t *intersect_union(jl_value_t *x, jl_uniontype_t *u, jl_stenv_t
return R ? intersect(x, choice, e, param) : intersect(choice, x, e, param);
}

static jl_value_t *intersect_ufirst(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth)
{
jl_value_t *res;
int savedepth = e->invdepth;
e->invdepth = depth;
if (jl_is_uniontype(x) && jl_is_typevar(y))
res = intersect_union(y, (jl_uniontype_t*)x, e, 0, 0);
else if (jl_is_typevar(x) && jl_is_uniontype(y))
res = intersect_union(x, (jl_uniontype_t*)y, e, 1, 0);
else
res = intersect(x, y, e, 0);
e->invdepth = savedepth;
return res;
}

// set a variable to a non-type constant
static jl_value_t *set_var_to_const(jl_varbinding_t *bb, jl_value_t *v JL_MAYBE_UNROOTED, jl_varbinding_t *othervar)
{
Expand All @@ -1386,13 +1387,11 @@ static jl_value_t *set_var_to_const(jl_varbinding_t *bb, jl_value_t *v JL_MAYBE_

static int try_subtype_in_env(jl_value_t *a, jl_value_t *b, jl_stenv_t *e)
{
jl_value_t *root=NULL; jl_savedenv_t se; int ret=0;
jl_value_t *root=NULL; jl_savedenv_t se;
JL_GC_PUSH1(&root);
save_env(e, &root, &se);
if (subtype_in_env(a, b, e))
ret = 1;
else
restore_env(e, root, &se);
int ret = subtype_in_env(a, b, e);
restore_env(e, root, &se);
free(se.buf);
JL_GC_POP();
return ret;
Expand All @@ -1402,15 +1401,15 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
{
jl_varbinding_t *bb = lookup(e, b);
if (bb == NULL)
return R ? intersect_ufirst(a, b->ub, e, 0) : intersect_ufirst(b->ub, a, e, 0);
return R ? intersect_aside(a, b->ub, e, 0) : intersect_aside(b->ub, a, e, 0);
if (bb->lb == bb->ub && jl_is_typevar(bb->lb))
return intersect(a, bb->lb, e, param);
if (!jl_is_type(a) && !jl_is_typevar(a))
return set_var_to_const(bb, a, NULL);
int d = bb->depth0;
jl_value_t *root=NULL; jl_savedenv_t se;
if (param == 2) {
jl_value_t *ub = R ? intersect_ufirst(a, bb->ub, e, d) : intersect_ufirst(bb->ub, a, e, d);
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d);
JL_GC_PUSH2(&ub, &root);
if (!jl_has_free_typevars(ub) && !jl_has_free_typevars(bb->lb)) {
save_env(e, &root, &se);
Expand Down Expand Up @@ -1450,10 +1449,10 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
if (try_subtype_in_env(bb->ub, a, e))
return (jl_value_t*)b;
}
return R ? intersect_ufirst(a, bb->ub, e, d) : intersect_ufirst(bb->ub, a, e, d);
return R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d);
}
else if (bb->concrete || bb->constraintkind == 1) {
jl_value_t *ub = R ? intersect_ufirst(a, bb->ub, e, d) : intersect_ufirst(bb->ub, a, e, d);
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d);
JL_GC_PUSH1(&ub);
if (ub == jl_bottom_type || !subtype_in_env(bb->lb, a, e)) {
JL_GC_POP();
Expand All @@ -1473,7 +1472,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
return a;
}
assert(bb->constraintkind == 3);
jl_value_t *ub = R ? intersect_ufirst(a, bb->ub, e, d) : intersect_ufirst(bb->ub, a, e, d);
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, d) : intersect_aside(bb->ub, a, e, d);
if (ub == jl_bottom_type)
return jl_bottom_type;
if (jl_is_typevar(a))
Expand All @@ -1494,7 +1493,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
root = NULL;
JL_GC_PUSH2(&root, &ub);
save_env(e, &root, &se);
jl_value_t *ii = R ? intersect_ufirst(a, bb->lb, e, d) : intersect_ufirst(bb->lb, a, e, d);
jl_value_t *ii = R ? intersect_aside(a, bb->lb, e, d) : intersect_aside(bb->lb, a, e, d);
if (ii == jl_bottom_type) {
restore_env(e, root, &se);
ii = (jl_value_t*)b;
Expand Down Expand Up @@ -2050,7 +2049,7 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
return jl_bottom_type;
jl_value_t *ub=NULL, *lb=NULL;
JL_GC_PUSH2(&lb, &ub);
ub = intersect_ufirst(xub, yub, e, xx ? xx->depth0 : 0);
ub = intersect_aside(xub, yub, e, xx ? xx->depth0 : 0);
lb = simple_join(xlb, ylb);
if (yy) {
if (lb != y)
Expand Down
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,9 @@ mul!(out::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, in::StridedMatrix) = o
*(transD::Transpose{<:Any,<:Diagonal}, transA::Transpose{<:Any,<:RealHermSymComplexSym}) = transD * transA.parent
*(adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}, adjD::Adjoint{<:Any,<:Diagonal}) = adjA.parent * adjD
*(adjD::Adjoint{<:Any,<:Diagonal}, adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}) = adjD * adjA.parent
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:RealHermSym}) = mul!(C, A, B.parent)
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:RealHermSymComplexHerm}) = mul!(C, A, B.parent)
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::Transpose{<:Any,<:RealHermSym}) = mul!(C, A, B.parent)
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::Transpose{<:Any,<:RealHermSymComplexSym}) = mul!(C, A, B.parent)
mul!(C::AbstractMatrix, A::Adjoint{<:Any,<:Diagonal}, B::Adjoint{<:Any,<:RealHermSymComplexSym}) = C .= adjoint.(A.parent.diag) .* B
mul!(C::AbstractMatrix, A::Transpose{<:Any,<:Diagonal}, B::Transpose{<:Any,<:RealHermSymComplexHerm}) = C .= transpose.(A.parent.diag) .* B
Expand Down
9 changes: 8 additions & 1 deletion stdlib/LinearAlgebra/src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ end
convert(T::Type{<:Symmetric}, m::Union{Symmetric,Hermitian}) = m isa T ? m : T(m)
convert(T::Type{<:Hermitian}, m::Union{Symmetric,Hermitian}) = m isa T ? m : T(m)

const HermOrSym{T,S} = Union{Hermitian{T,S}, Symmetric{T,S}}
const HermOrSym{T, S} = Union{Hermitian{T,S}, Symmetric{T,S}}
const RealHermSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}}
const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S}}
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}

Expand Down Expand Up @@ -427,11 +428,17 @@ mul!(C::StridedMatrix{T}, A::StridedMatrix{T}, B::Hermitian{T,<:StridedMatrix})
*(A::AbstractMatrix, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = A * adjB.parent

# ambiguities with transposed AbstractMatrix methods in linalg/matmul.jl
*(transA::Transpose{<:Any,<:RealHermSym}, transB::Transpose{<:Any,<:RealHermSym}) = transA * transB.parent
*(transA::Transpose{<:Any,<:RealHermSym}, transB::Transpose{<:Any,<:RealHermSymComplexSym}) = transA * transB.parent
*(transA::Transpose{<:Any,<:RealHermSymComplexSym}, transB::Transpose{<:Any,<:RealHermSymComplexSym}) = transA.parent * transB.parent
*(transA::Transpose{<:Any,<:RealHermSymComplexSym}, transB::Transpose{<:Any,<:RealHermSym}) = transA.parent * transB
*(transA::Transpose{<:Any,<:RealHermSymComplexSym}, transB::Transpose{<:Any,<:RealHermSymComplexHerm}) = transA.parent * transB
*(transA::Transpose{<:Any,<:RealHermSymComplexHerm}, transB::Transpose{<:Any,<:RealHermSymComplexSym}) = transA * transB.parent
*(adjA::Adjoint{<:Any,<:RealHermSym}, adjB::Adjoint{<:Any,<:RealHermSym}) = adjA * adjB.parent
*(adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = adjA.parent * adjB.parent
*(adjA::Adjoint{<:Any,<:RealHermSym}, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = adjA * adjB.parent
*(adjA::Adjoint{<:Any,<:RealHermSymComplexSym}, adjB::Adjoint{<:Any,<:RealHermSymComplexHerm}) = adjA * adjB.parent
*(adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}, adjB::Adjoint{<:Any,<:RealHermSym}) = adjA.parent * adjB
*(adjA::Adjoint{<:Any,<:RealHermSymComplexHerm}, adjB::Adjoint{<:Any,<:RealHermSymComplexSym}) = adjA.parent * adjB

# ambiguities with AbstractTriangular
Expand Down
7 changes: 7 additions & 0 deletions test/ambiguous.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ end
pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, (Type{Union{T, Nothing}} where T, Core.Compiler.Some)))
pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{}}))
pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{Int8}}))
pop!(need_to_handle_undef_sparam, which(Core.Compiler.convert, Tuple{Type{Union{Nothing,T}},Union{Nothing,T}} where T))
@test need_to_handle_undef_sparam == Set()
end
let need_to_handle_undef_sparam =
Expand All @@ -299,6 +300,12 @@ end
pop!(need_to_handle_undef_sparam, which(Base.convert, (Type{Union{T, Nothing}} where T, Some)))
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{}}))
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Tuple{Vararg{Int}}}, Tuple{Int8}}))
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Union{Nothing,T}},Union{Nothing,T}} where T))
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Union{Missing,T}},Union{Missing,T}} where T))
pop!(need_to_handle_undef_sparam, which(Base.convert, Tuple{Type{Union{Missing,Nothing,T}},Union{Missing,Nothing,T}} where T))
pop!(need_to_handle_undef_sparam, which(Base.promote_rule, Tuple{Type{Union{Nothing,T}},Type{Any}} where T))
pop!(need_to_handle_undef_sparam, which(Base.promote_rule, Tuple{Type{Union{Missing,T}},Type{Any}} where T))
pop!(need_to_handle_undef_sparam, which(Base.promote_rule, Tuple{Type{Union{Missing,Nothing,T}},Type{Any}} where T))
@test need_to_handle_undef_sparam == Set()
end
end
Expand Down
19 changes: 17 additions & 2 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -847,10 +847,14 @@ function test_intersection()
@testintersect(Ref{@UnionAll T @UnionAll S Tuple{T,S}},
Ref{@UnionAll T Tuple{T,T}}, Bottom)

# both of these answers seem acceptable
#@testintersect(Tuple{T,T} where T<:Union{UpperTriangular, UnitUpperTriangular},
# Tuple{AbstractArray{T,N}, AbstractArray{T,N}} where N where T,
# Union{Tuple{T,T} where T<:UpperTriangular,
# Tuple{T,T} where T<:UnitUpperTriangular})
@testintersect(Tuple{T,T} where T<:Union{UpperTriangular, UnitUpperTriangular},
Tuple{AbstractArray{T,N}, AbstractArray{T,N}} where N where T,
Union{Tuple{T,T} where T<:UpperTriangular,
Tuple{T,T} where T<:UnitUpperTriangular})
Tuple{T,T} where T<:Union{UpperTriangular, UnitUpperTriangular})

@testintersect(DataType, Type, DataType)
@testintersect(DataType, Type{T} where T<:Integer, Type{T} where T<:Integer)
Expand Down Expand Up @@ -1358,6 +1362,17 @@ end
Tuple{Val{2}, Vararg{Val{3}}},
Union{})

# issue #25752
@testintersect(Base.RefValue, Ref{Union{Int,T}} where T,
Base.RefValue{Union{Int,T}} where T)
# issue #29269
@testintersect((Tuple{Int, Array{T}} where T),
(Tuple{Any, Vector{Union{Missing,T}}} where T),
(Tuple{Int, Vector{Union{Missing,T}}} where T))
@testintersect((Tuple{Int, Array{T}} where T),
(Tuple{Any, Vector{Union{Missing,Nothing,T}}} where T),
(Tuple{Int, Vector{Union{Missing,Nothing,T}}} where T))

# issue #29955
struct M29955{T, TV<:AbstractVector{T}}
end
Expand Down

0 comments on commit 4f8a7b9

Please sign in to comment.