Skip to content

Commit

Permalink
Subtype: Improve simple_meet resolution for Union inputs (JuliaLang#49376)
Browse files Browse the repository at this point in the history

* Improve `simple_meet` resolution.

* Fix for many-to-one cases.

* Test disjoint via `jl_has_empty_intersection`
  • Loading branch information
N5N3 committed Apr 22, 2023
1 parent 4044096 commit 6b79e8c
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 26 deletions.
135 changes: 121 additions & 14 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -485,19 +485,19 @@ static int union_sort_cmp(jl_value_t *a, jl_value_t *b) JL_NOTSAFEPOINT
}
}

static int count_union_components(jl_value_t **types, size_t n)
static int count_union_components(jl_value_t **types, size_t n, int widen)
{
size_t i, c = 0;
for (i = 0; i < n; i++) {
jl_value_t *e = types[i];
while (jl_is_uniontype(e)) {
jl_uniontype_t *u = (jl_uniontype_t*)e;
c += count_union_components(&u->a, 1);
c += count_union_components(&u->a, 1, widen);
e = u->b;
}
if (jl_is_unionall(e) && jl_is_uniontype(jl_unwrap_unionall(e))) {
if (widen && jl_is_unionall(e) && jl_is_uniontype(jl_unwrap_unionall(e))) {
jl_uniontype_t *u = (jl_uniontype_t*)jl_unwrap_unionall(e);
c += count_union_components(&u->a, 2);
c += count_union_components(&u->a, 2, widen);
}
else {
c++;
Expand All @@ -506,21 +506,21 @@ static int count_union_components(jl_value_t **types, size_t n)
return c;
}

static void flatten_type_union(jl_value_t **types, size_t n, jl_value_t **out, size_t *idx)
static void flatten_type_union(jl_value_t **types, size_t n, jl_value_t **out, size_t *idx, int widen)
{
size_t i;
for (i = 0; i < n; i++) {
jl_value_t *e = types[i];
while (jl_is_uniontype(e)) {
jl_uniontype_t *u = (jl_uniontype_t*)e;
flatten_type_union(&u->a, 1, out, idx);
flatten_type_union(&u->a, 1, out, idx, widen);
e = u->b;
}
if (jl_is_unionall(e) && jl_is_uniontype(jl_unwrap_unionall(e))) {
if (widen && jl_is_unionall(e) && jl_is_uniontype(jl_unwrap_unionall(e))) {
// flatten this UnionAll into place by switching the union and unionall
jl_uniontype_t *u = (jl_uniontype_t*)jl_unwrap_unionall(e);
size_t old_idx = 0;
flatten_type_union(&u->a, 2, out, idx);
flatten_type_union(&u->a, 2, out, idx, widen);
for (; old_idx < *idx; old_idx++)
out[old_idx] = jl_rewrap_unionall(out[old_idx], e);
}
Expand Down Expand Up @@ -560,11 +560,11 @@ JL_DLLEXPORT jl_value_t *jl_type_union(jl_value_t **ts, size_t n)
if (n == 1)
return ts[0];

size_t nt = count_union_components(ts, n);
size_t nt = count_union_components(ts, n, 1);
jl_value_t **temp;
JL_GC_PUSHARGS(temp, nt+1);
size_t count = 0;
flatten_type_union(ts, n, temp, &count);
flatten_type_union(ts, n, temp, &count, 1);
assert(count == nt);
size_t j;
for (i = 0; i < nt; i++) {
Expand Down Expand Up @@ -641,14 +641,14 @@ static int simple_subtype2(jl_value_t *a, jl_value_t *b, int hasfree)

jl_value_t *simple_union(jl_value_t *a, jl_value_t *b)
{
size_t nta = count_union_components(&a, 1);
size_t ntb = count_union_components(&b, 1);
size_t nta = count_union_components(&a, 1, 1);
size_t ntb = count_union_components(&b, 1, 1);
size_t nt = nta + ntb;
jl_value_t **temp;
JL_GC_PUSHARGS(temp, nt+1);
size_t count = 0;
flatten_type_union(&a, 1, temp, &count);
flatten_type_union(&b, 1, temp, &count);
flatten_type_union(&a, 1, temp, &count, 1);
flatten_type_union(&b, 1, temp, &count, 1);
assert(count == nt);
size_t i, j;
size_t ra = nta, rb = ntb;
Expand Down Expand Up @@ -717,6 +717,113 @@ jl_value_t *simple_union(jl_value_t *a, jl_value_t *b)
return tu;
}

int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity);

// Report whether `a` can be shown to have no overlap with `b`.
//
// A `Union` on the `b` side is split recursively: `a` is disjoint from the
// union only when it is disjoint from both halves. For a non-union `b`, the
// full `jl_has_empty_intersection` query is used when neither side carries
// free type variables (`hasfree` is the caller's answer for `a`); otherwise
// the check falls back to `obviously_disjoint`.
//
// Returns nonzero when disjointness was established, 0 otherwise.
static int simple_disjoint(jl_value_t *a, jl_value_t *b, int hasfree)
{
    if (!jl_is_uniontype(b)) {
        // `hasfree` is tested first so `b` is only inspected when `a` is closed.
        int anyfree = hasfree || jl_has_free_typevars(b);
        return anyfree ? obviously_disjoint(a, b, 0)
                       : jl_has_empty_intersection(a, b);
    }
    jl_uniontype_t *ub = (jl_uniontype_t *)b;
    jl_value_t *lhs = ub->a, *rhs = ub->b;
    JL_GC_PUSH2(&lhs, &rhs);
    int disjoint = 0;
    // The second half is examined only if the first half is already disjoint.
    if (simple_disjoint(a, lhs, hasfree))
        disjoint = simple_disjoint(a, rhs, hasfree) != 0;
    JL_GC_POP();
    return disjoint;
}

// Compute a cheap approximation of the intersection of `a` and `b` by
// working on their flattened `Union` components.
//
// Steps:
//   1. Flatten `a` and `b` into `temp` (components of `a` first, then `b`).
//   2. Drop every component that `simple_disjoint` proves disjoint from the
//      whole opposite input.
//   3. Classify the surviving components pairwise with `simple_subtype2`.
//   4. Pick the result:
//      - `a` itself if every component of `a` survived and is equal to, or a
//        subtype of, some component of `b`; otherwise `b` under the symmetric
//        condition;
//      - `jl_bottom_type` if `overesi` is 0 and neither side could be fully
//        resolved;
//      - otherwise a sorted, right-nested `Union` of the components that are
//        kept (see the comments at the selection code below).
//
// `overesi`: nonzero asks for an over-estimate when the relation between the
// two sides cannot be fully proven; zero asks for `Union{}` in that case.
jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi)
{
// Unlike `Union`, we don't unwrap `UnionAll` here to avoid possible widening.
// (That is what the trailing `0` passed as `widen` requests.)
size_t nta = count_union_components(&a, 1, 0);
size_t ntb = count_union_components(&b, 1, 0);
size_t nt = nta + ntb;
jl_value_t **temp;
// Reserve `nt+1` rooted slots: `nt` for the components plus room for the
// result accumulator `temp[nt]` that is written further below.
JL_GC_PUSHARGS(temp, nt+1);
size_t count = 0;
// temp[0 .. nta) holds the components of `a`, temp[nta .. nt) those of `b`.
flatten_type_union(&a, 1, temp, &count, 0);
flatten_type_union(&b, 1, temp, &count, 0);
assert(count == nt);
size_t i, j;
// first remove disjoint elements.
// Each component is tested against the *whole* opposite input; removed
// components are marked by a NULL slot.
for (i = 0; i < nt; i++) {
if (simple_disjoint(temp[i], (i < nta ? b : a), jl_has_free_typevars(temp[i])))
temp[i] = NULL;
}
// then check subtyping.
// stemp[k] == -1 : ∃i temp[k] >:ₛ temp[i]
// stemp[k] == 1 : ∃i temp[k] == temp[i]
// stemp[k] == 2 : ∃i temp[k] <:ₛ temp[i]
// stemp[k] == 0 : no relation to any component of the other side was found.
// The assignments below make -1 sticky: once set it is never overwritten,
// 2 only replaces 0 or 1, and 1 only replaces 0.
int8_t *stemp = (int8_t *)alloca(count);
memset(stemp, 0, count);
for (i = 0; i < nta; i++) {
if (temp[i] == NULL) continue;
int hasfree = jl_has_free_typevars(temp[i]);
for (j = nta; j < nt; j++) {
if (temp[j] == NULL) continue;
// NOTE(review): judging by the names below, bit 0 of the result reports
// temp[i] <: temp[j] and bit 1 the reverse direction; confirm against
// `simple_subtype2`.
int subs = simple_subtype2(temp[i], temp[j], hasfree || jl_has_free_typevars(temp[j]));
int subab = subs & 1, subba = subs >> 1;
if (subba && !subab) {
// temp[j] is strictly below temp[i]
stemp[i] = -1;
if (stemp[j] >= 0) stemp[j] = 2;
}
else if (subab && !subba) {
// temp[i] is strictly below temp[j]
stemp[j] = -1;
if (stemp[i] >= 0) stemp[i] = 2;
}
else if (subs) {
// both directions hold: treat the two components as equal
if (stemp[i] == 0) stemp[i] = 1;
if (stemp[j] == 0) stemp[j] = 1;
}
}
}
// Per-side summary, index 0 for `a` and index 1 for `b`:
//   subs[s]: every component of side s was either removed as disjoint or is
//            equal to / a strict subtype of some component of the other side.
//   rs[s]:   same, but additionally no component of side s was removed, so
//            the original input can be returned unchanged.
int subs[2] = {1, 1}, rs[2] = {1, 1};
for (i = 0; i < nt; i++) {
subs[i >= nta] &= (temp[i] == NULL || stemp[i] > 0);
rs[i >= nta] &= (temp[i] != NULL && stemp[i] > 0);
}
// return `a` (resp. `b`) unchanged if all of its components were shown to be
// covered by the other side
if (rs[0]) {
JL_GC_POP();
return a;
}
if (rs[1]) {
JL_GC_POP();
return b;
}
// return `Union{}` for `merge_env` if we can't prove `<:` or `>:`
if (!overesi && !subs[0] && !subs[1]) {
JL_GC_POP();
return jl_bottom_type;
}
// Select the slice [i, nt) of `temp` the result is built from:
//   subs[0]            -> the surviving components of `a`: [0, nta)
//   subs[1] (only)     -> the surviving components of `b`: [nta, nt)
//   neither (overesi)  -> both sides: [0, nt), filtered just below
nt = subs[0] ? nta : subs[1] ? nt : nt;
i = subs[0] ? 0 : subs[1] ? nta : 0;
count = nt - i;
if (!subs[0] && !subs[1]) {
// prepare for over estimation
// only preserve `a` with strict <:, but preserve `b` without strict >:
// i.e. keep a-components with stemp == 2 and b-components with stemp >= 0.
for (j = 0; j < nt; j++) {
if (stemp[j] < (j < nta ? 2 : 0))
temp[j] = NULL;
}
}
isort_union(&temp[i], count);
// Fold the remaining non-NULL entries, last to first, into a right-nested
// `Union`, accumulating in temp[nt]; NULL slots are skipped.
temp[nt] = jl_bottom_type;
size_t k;
for (k = nt; k-- > i; ) {
if (temp[k] != NULL) {
if (temp[nt] == jl_bottom_type)
temp[nt] = temp[k];
else
temp[nt] = jl_new_struct(jl_uniontype_type, temp[k], temp[nt]);
}
}
assert(temp[nt] != NULL);
jl_value_t *tu = temp[nt];
JL_GC_POP();
return tu;
}

// unionall types -------------------------------------------------------------

Expand Down
15 changes: 3 additions & 12 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ static int obviously_in_union(jl_value_t *u, jl_value_t *x)
return obviously_egal(u, x);
}

static int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity)
int obviously_disjoint(jl_value_t *a, jl_value_t *b, int specificity)
{
if (a == b || a == (jl_value_t*)jl_any_type || b == (jl_value_t*)jl_any_type)
return 0;
Expand Down Expand Up @@ -559,6 +559,7 @@ static jl_value_t *simple_join(jl_value_t *a, jl_value_t *b)
return simple_union(a, b);
}

jl_value_t *simple_intersect(jl_value_t *a, jl_value_t *b, int overesi);
// Compute a greatest lower bound of `a` and `b`
// For the subtype path, we need to over-estimate this by returning `b` in many cases.
// But for `merge_env`, we'd better under-estimate and return a `Union{}`
Expand All @@ -570,10 +571,6 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi)
return a;
if (!(jl_is_type(a) || jl_is_typevar(a)) || !(jl_is_type(b) || jl_is_typevar(b)))
return jl_bottom_type;
if (jl_is_uniontype(a) && obviously_in_union(a, b))
return b;
if (jl_is_uniontype(b) && obviously_in_union(b, a))
return a;
if (jl_is_kind(a) && jl_is_type_type(b) && jl_typeof(jl_tparam0(b)) == a)
return b;
if (jl_is_kind(b) && jl_is_type_type(a) && jl_typeof(jl_tparam0(a)) == b)
Expand All @@ -582,13 +579,7 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi)
return a;
if (jl_is_typevar(b) && obviously_egal(a, ((jl_tvar_t*)b)->ub))
return b;
if (obviously_disjoint(a, b, 0))
return jl_bottom_type;
if (!jl_has_free_typevars(a) && !jl_has_free_typevars(b)) {
if (jl_subtype(a, b)) return a;
if (jl_subtype(b, a)) return b;
}
return overesi ? b : jl_bottom_type;
return simple_intersect(a, b, overesi);
}

// main subtyping algorithm
Expand Down
8 changes: 8 additions & 0 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2491,6 +2491,14 @@ end
@test !<:(Tuple{Type{Int}, Int}, Tuple{Type{Union{Int, T}}, T} where T<:Union{Int8,Int16})
@test <:(Tuple{Type{Int}, Int}, Tuple{Type{Union{Int, T}}, T} where T<:Union{Int8,Int})

#issue #49354 (requires assertions enabled)
@test !<:(Tuple{Type{Union{Int, Val{1}}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Val)
@test !<:(Tuple{Type{Union{Int, Val{1}}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Union{Val,Pair})
@test <:(Tuple{Type{Union{Int, Val{1}}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Union{Integer,Val})
@test <:(Tuple{Type{Union{Int, Int8}}, Int}, Tuple{Type{Union{Int, T1}}, T1} where T1<:Integer)
@test !<:(Tuple{Type{Union{Pair{Int, Any}, Pair{Int, Int}}}, Pair{Int, Any}},
Tuple{Type{Union{Pair{Int, Any}, T1}}, T1} where T1<:(Pair{T,T} where {T}))

let A = Tuple{Type{T}, T, Val{T}} where T,
B = Tuple{Type{S}, Val{S}, Val{S}} where S
@test_broken typeintersect(A, B) != Union{}
Expand Down

0 comments on commit 6b79e8c

Please sign in to comment.