Skip to content

Commit

Permalink
Expand more Vararg elements during re-intersection if valid. (JuliaLang#46604)
Browse files Browse the repository at this point in the history

Our type intersection "prefers" `Tuple`s with more parameters.
This PR tries to replace `Tuple{Vararg{T,N}}` with
`Tuple{T,T,T,Vararg{T,N}}` during re-intersection if we can prove that
`N >= 3` and that `N` is used only as a Vararg length.
  • Loading branch information
N5N3 committed Dec 14, 2023
1 parent 9147437 commit d69bb97
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 60 deletions.
147 changes: 96 additions & 51 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ typedef struct jl_varbinding_t {
int8_t occurs_inv; // occurs in invariant position
int8_t occurs_cov; // # of occurrences in covariant position
int8_t concrete; // 1 if another variable has a constraint forcing this one to be concrete
int8_t max_offset; // record the maximum positive offset of the variable (up to 32)
// max_offset < 0 if this variable occurs outside VarargNum.
// constraintkind: in covariant position, we try three different ways to compute var ∩ type:
// let ub = var.ub ∩ type
// 0 - var.ub <: type ? var : ub
Expand All @@ -77,6 +79,7 @@ typedef struct jl_varbinding_t {
int8_t constraintkind;
int8_t intvalued; // intvalued: must be integer-valued; i.e. occurs as N in Vararg{_,N}
int8_t limited;
int8_t intersected; // whether this variable has been intersected
int16_t depth0; // # of invariant constructors nested around the UnionAll type for this var
// array of typevars that our bounds depend on, whose UnionAlls need to be
// moved outside ours.
Expand Down Expand Up @@ -168,9 +171,9 @@ static int current_env_length(jl_stenv_t *e)
typedef struct {
int8_t *buf;
int rdepth;
int8_t _space[24]; // == 8 * 3
int8_t _space[32]; // == 8 * 4
jl_gcframe_t gcframe;
jl_value_t *roots[24];
jl_value_t *roots[24]; // == 8 * 3
} jl_savedenv_t;

static void re_save_env(jl_stenv_t *e, jl_savedenv_t *se, int root)
Expand Down Expand Up @@ -200,6 +203,7 @@ static void re_save_env(jl_stenv_t *e, jl_savedenv_t *se, int root)
se->buf[j++] = v->occurs;
se->buf[j++] = v->occurs_inv;
se->buf[j++] = v->occurs_cov;
se->buf[j++] = v->max_offset;
v = v->prev;
}
assert(i == nroots); (void)nroots;
Expand Down Expand Up @@ -231,7 +235,7 @@ static void alloc_env(jl_stenv_t *e, jl_savedenv_t *se, int root)
ct->gcstack = &se->gcframe;
}
}
se->buf = (len > 8 ? (int8_t*)malloc_s(len * 3) : se->_space);
se->buf = (len > 8 ? (int8_t*)malloc_s(len * 4) : se->_space);
#ifdef __clang_gcanalyzer__
memset(se->buf, 0, len * 3);
#endif
Expand Down Expand Up @@ -281,6 +285,7 @@ static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPO
v->occurs = se->buf[j++];
v->occurs_inv = se->buf[j++];
v->occurs_cov = se->buf[j++];
v->max_offset = se->buf[j++];
v = v->prev;
}
assert(i == nroots); (void)nroots;
Expand Down Expand Up @@ -677,6 +682,10 @@ static void record_var_occurrence(jl_varbinding_t *vb, jl_stenv_t *e, int param)
else if (vb->occurs_cov < 2) {
vb->occurs_cov++;
}
// Always set `max_offset` to `-1` during the 1st round intersection.
// Would be recovered in `intersect_varargs`/`subtype_tuple_varargs` if needed.
if (!vb->intersected)
vb->max_offset = -1;
}
}

Expand Down Expand Up @@ -888,7 +897,7 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e)
static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
{
u = unalias_unionall(u, e);
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0,
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0, 0,
e->invdepth, NULL, e->vars };
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
e->vars = &vb;
Expand Down Expand Up @@ -1008,39 +1017,30 @@ static int subtype_tuple_varargs(
jl_value_t *xp0 = jl_unwrap_vararg(vtx); jl_value_t *xp1 = jl_unwrap_vararg_num(vtx);
jl_value_t *yp0 = jl_unwrap_vararg(vty); jl_value_t *yp1 = jl_unwrap_vararg_num(vty);

jl_varbinding_t *xlv = NULL, *ylv = NULL;
if (xp1 && jl_is_typevar(xp1))
xlv = lookup(e, (jl_tvar_t*)xp1);
if (yp1 && jl_is_typevar(yp1))
ylv = lookup(e, (jl_tvar_t*)yp1);

int8_t max_offsetx = xlv ? xlv->max_offset : 0;
int8_t max_offsety = ylv ? ylv->max_offset : 0;

jl_value_t *xl = xlv ? xlv->lb : xp1;
jl_value_t *yl = ylv ? ylv->lb : yp1;

if (!xp1) {
jl_value_t *yl = yp1;
if (yl) {
// Unconstrained on the left, constrained on the right
if (jl_is_typevar(yl)) {
jl_varbinding_t *ylv = lookup(e, (jl_tvar_t*)yl);
if (ylv)
yl = ylv->lb;
}
if (jl_is_long(yl)) {
return 0;
}
}
// Unconstrained on the left, constrained on the right
if (yl && jl_is_long(yl))
return 0;
}
else {
jl_value_t *xl = jl_unwrap_vararg_num(vtx);
if (jl_is_typevar(xl)) {
jl_varbinding_t *xlv = lookup(e, (jl_tvar_t*)xl);
if (xlv)
xl = xlv->lb;
}
if (jl_is_long(xl)) {
if (jl_unbox_long(xl) + 1 == vx) {
// LHS is exhausted. We're a subtype if the RHS is either
// exhausted as well or unbounded (in which case we need to
// set it to 0).
jl_value_t *yl = jl_unwrap_vararg_num(vty);
if (yl) {
if (jl_is_typevar(yl)) {
jl_varbinding_t *ylv = lookup(e, (jl_tvar_t*)yl);
if (ylv)
yl = ylv->lb;
}
if (jl_is_long(yl)) {
return jl_unbox_long(yl) + 1 == vy;
}
Expand Down Expand Up @@ -1090,6 +1090,8 @@ static int subtype_tuple_varargs(
// appropriately.
e->invdepth++;
int ans = subtype((jl_value_t*)jl_any_type, yp1, e, 2);
if (ylv && !ylv->intersected)
ylv->max_offset = max_offsety;
e->invdepth--;
return ans;
}
Expand Down Expand Up @@ -1130,6 +1132,10 @@ static int subtype_tuple_varargs(
e->Loffset = 0;
}
JL_GC_POP();
if (ylv && !ylv->intersected)
ylv->max_offset = max_offsety;
if (xlv && !xlv->intersected)
xlv->max_offset = max_offsetx;
e->invdepth--;
return ans;
}
Expand Down Expand Up @@ -3134,14 +3140,15 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
{
jl_value_t *res = NULL;
jl_savedenv_t se;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0,
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0, 0, 0,
e->invdepth, NULL, e->vars };
JL_GC_PUSH4(&res, &vb.lb, &vb.ub, &vb.innervars);
save_env(e, &se, 1);
int noinv = !var_occurs_invariant(u->body, u->var);
if (is_leaf_typevar(u->var) && noinv && always_occurs_cov(u->body, u->var, param))
vb.constraintkind = 1;
res = intersect_unionall_(t, u, e, R, param, &vb);
vb.intersected = 1;
if (vb.limited) {
// if the environment got too big, avoid tree recursion and propagate the flag
if (e->vars)
Expand Down Expand Up @@ -3218,17 +3225,20 @@ static jl_value_t *intersect_varargs(jl_vararg_t *vmx, jl_vararg_t *vmy, ssize_t
assert(e->Loffset == 0);
e->Loffset = offset;
jl_varbinding_t *xb = NULL, *yb = NULL;
int8_t max_offsetx = 0, max_offsety = 0;
if (xp2) {
assert(jl_is_typevar(xp2));
xb = lookup(e, (jl_tvar_t*)xp2);
if (xb) xb->intvalued = 1;
if (xb) max_offsetx = xb->max_offset;
if (!yp2)
i2 = bound_var_below((jl_tvar_t*)xp2, xb, e, 0);
}
if (yp2) {
assert(jl_is_typevar(yp2));
yb = lookup(e, (jl_tvar_t*)yp2);
if (yb) yb->intvalued = 1;
if (yb) max_offsety = yb->max_offset;
if (!xp2)
i2 = bound_var_below((jl_tvar_t*)yp2, yb, e, 1);
}
Expand All @@ -3243,14 +3253,27 @@ static jl_value_t *intersect_varargs(jl_vararg_t *vmx, jl_vararg_t *vmy, ssize_t
}
assert(e->Loffset == offset);
e->Loffset = 0;
if (i2 == jl_bottom_type)
if (i2 == jl_bottom_type) {
ii = (jl_value_t*)jl_bottom_type;
else if (xp2 && obviously_egal(xp1, ii) && obviously_egal(xp2, i2))
ii = (jl_value_t*)vmx;
else if (yp2 && obviously_egal(yp1, ii) && obviously_egal(yp2, i2))
ii = (jl_value_t*)vmy;
else
ii = (jl_value_t*)jl_wrap_vararg(ii, i2, 1);
}
else {
if (xb && !xb->intersected) {
xb->max_offset = max_offsetx;
if (offset > xb->max_offset && xb->max_offset >= 0)
xb->max_offset = offset > 32 ? 32 : offset;
}
if (yb && !yb->intersected) {
yb->max_offset = max_offsety;
if (-offset > yb->max_offset && yb->max_offset >= 0)
yb->max_offset = -offset > 32 ? 32 : -offset;
}
if (xp2 && obviously_egal(xp1, ii) && obviously_egal(xp2, i2))
ii = (jl_value_t*)vmx;
else if (yp2 && obviously_egal(yp1, ii) && obviously_egal(yp2, i2))
ii = (jl_value_t*)vmy;
else
ii = (jl_value_t*)jl_wrap_vararg(ii, i2, 1);
}
JL_GC_POP();
return ii;
}
Expand All @@ -3269,6 +3292,24 @@ static jl_value_t *intersect_tuple(jl_datatype_t *xd, jl_datatype_t *yd, jl_sten
llx += jl_unbox_long(jl_unwrap_vararg_num((jl_vararg_t *)jl_tparam(xd, lx-1))) - 1;
if (vvy == JL_VARARG_INT)
lly += jl_unbox_long(jl_unwrap_vararg_num((jl_vararg_t *)jl_tparam(yd, ly-1))) - 1;
if (vvx == JL_VARARG_BOUND && (vvy == JL_VARARG_BOUND || vvy == JL_VARARG_UNBOUND)) {
jl_value_t *xlen = jl_unwrap_vararg_num((jl_vararg_t*)jl_tparam(xd, lx-1));
assert(xlen && jl_is_typevar(xlen));
jl_varbinding_t *xb = lookup(e, (jl_tvar_t*)xlen);
if (xb && xb->intersected && xb->max_offset > 0) {
assert(xb->max_offset <= 32);
llx += xb->max_offset;
}
}
if (vvy == JL_VARARG_BOUND && (vvx == JL_VARARG_BOUND || vvx == JL_VARARG_UNBOUND)) {
jl_value_t *ylen = jl_unwrap_vararg_num((jl_vararg_t*)jl_tparam(yd, ly-1));
assert(ylen && jl_is_typevar(ylen));
jl_varbinding_t *yb = lookup(e, (jl_tvar_t*)ylen);
if (yb && yb->intersected && yb->max_offset > 0) {
assert(yb->max_offset <= 32);
lly += yb->max_offset;
}
}

if ((vvx == JL_VARARG_NONE || vvx == JL_VARARG_INT) &&
(vvy == JL_VARARG_NONE || vvy == JL_VARARG_INT)) {
Expand Down Expand Up @@ -3301,8 +3342,8 @@ static jl_value_t *intersect_tuple(jl_datatype_t *xd, jl_datatype_t *yd, jl_sten
assert(i == j && i == np);
break;
}
if (xi && jl_is_vararg(xi)) vx = vvx != JL_VARARG_INT;
if (yi && jl_is_vararg(yi)) vy = vvy != JL_VARARG_INT;
if (xi && jl_is_vararg(xi)) vx = vvx == JL_VARARG_UNBOUND || (vvx == JL_VARARG_BOUND && i == llx - 1);
if (yi && jl_is_vararg(yi)) vy = vvy == JL_VARARG_UNBOUND || (vvy == JL_VARARG_BOUND && j == lly - 1);
if (xi == NULL || yi == NULL) {
if (vx && intersect_vararg_length(xi, lly+1-llx, e, 0)) {
np = j;
Expand Down Expand Up @@ -3845,15 +3886,15 @@ static int merge_env(jl_stenv_t *e, jl_savedenv_t *se, int count)
roots = se->roots;
nroots = se->gcframe.nroots >> 2;
}
int n = 0;
int m = 0, n = 0;
jl_varbinding_t *v = e->vars;
v = e->vars;
while (v != NULL) {
if (count == 0) {
// need to initialize this
se->buf[n] = 0;
se->buf[n+1] = 0;
se->buf[n+2] = 0;
se->buf[m] = 0;
se->buf[m+1] = 0;
se->buf[m+2] = 0;
se->buf[m+3] = v->max_offset;
}
if (v->occurs) {
// only merge lb/ub/innervars if this var occurs.
Expand All @@ -3879,13 +3920,17 @@ static int merge_env(jl_stenv_t *e, jl_savedenv_t *se, int count)
roots[n+2] = b2;
}
// record the meeted vars.
se->buf[n] = 1;
se->buf[m] = 1;
}
// always merge occurs_inv/cov by max (never decrease)
if (v->occurs_inv > se->buf[n+1])
se->buf[n+1] = v->occurs_inv;
if (v->occurs_cov > se->buf[n+2])
se->buf[n+2] = v->occurs_cov;
if (v->occurs_inv > se->buf[m+1])
se->buf[m+1] = v->occurs_inv;
if (v->occurs_cov > se->buf[m+2])
se->buf[m+2] = v->occurs_cov;
// always merge max_offset by min
if (!v->intersected && v->max_offset < se->buf[m+3])
se->buf[m+3] = v->max_offset;
m = m + 4;
n = n + 3;
v = v->prev;
}
Expand Down Expand Up @@ -3917,7 +3962,7 @@ static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se)
}
assert(nroots == current_env_length(e) * 3);
assert(nroots % 3 == 0);
for (int n = 0; n < nroots; n = n + 3) {
for (int n = 0, m = 0; n < nroots; n += 3, m += 4) {
if (merged[n] == NULL)
merged[n] = saved[n];
if (merged[n+1] == NULL)
Expand All @@ -3933,7 +3978,7 @@ static void final_merge_env(jl_stenv_t *e, jl_savedenv_t *me, jl_savedenv_t *se)
else
merged[n+2] = b2;
}
me->buf[n] |= se->buf[n];
me->buf[m] |= se->buf[m];
}
}

Expand Down Expand Up @@ -4489,7 +4534,7 @@ static jl_value_t *_widen_diagonal(jl_value_t *t, jl_varbinding_t *troot) {

static jl_value_t *widen_diagonal(jl_value_t *t, jl_unionall_t *u, jl_varbinding_t *troot)
{
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
jl_varbinding_t vb = { u->var, NULL, NULL, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, NULL, troot };
jl_value_t *nt;
JL_GC_PUSH2(&vb.innervars, &nt);
if (jl_is_unionall(u->body))
Expand Down
23 changes: 14 additions & 9 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2210,13 +2210,19 @@ let A = Tuple{NTuple{N, Int}, NTuple{N, Int}} where N,
Bs = (Tuple{Tuple{Int, Vararg{Any}}, Tuple{Int, Int, Vararg{Any}}},
Tuple{Tuple{Int, Vararg{Any,N1}}, Tuple{Int, Int, Vararg{Any,N2}}} where {N1,N2},
Tuple{Tuple{Int, Vararg{Any,N}} where {N}, Tuple{Int, Int, Vararg{Any,N}} where {N}})
Cerr = Tuple{Tuple{Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
C = Tuple{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
for B in Bs
C = typeintersect(A, B)
@test C == typeintersect(B, A) != Union{}
@test C != Cerr
# TODO: The ideal result is Tuple{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
@test_broken C != Tuple{Tuple{Int, Vararg{Int}}, Tuple{Int, Int, Vararg{Int}}}
@testintersect(A, B, C)
end
A = Tuple{NTuple{N, Int}, Tuple{Int, Vararg{Int, N}}} where N
C = Tuple{Tuple{Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
for B in Bs
@testintersect(A, B, C)
end
A = Tuple{Tuple{Int, Vararg{Int, N}}, NTuple{N, Int}} where N
C = Tuple{Tuple{Int, Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
for B in Bs
@testintersect(A, B, C)
end
end

Expand All @@ -2229,9 +2235,8 @@ let A = Pair{NTuple{N, Int}, NTuple{N, Int}} where N,
Bs = (Pair{<:Tuple{Int, Vararg{Int}}, <:Tuple{Int, Int, Vararg{Int}}},
Pair{Tuple{Int, Vararg{Int,N1}}, Tuple{Int, Int, Vararg{Int,N2}}} where {N1,N2},
Pair{<:Tuple{Int, Vararg{Int,N}} where {N}, <:Tuple{Int, Int, Vararg{Int,N}} where {N}})
Cs = (Bs[2], Bs[2], Bs[3])
for (B, C) in zip(Bs, Cs)
# TODO: The ideal result is Pair{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
C = Pair{Tuple{Int, Int, Vararg{Int, N}}, Tuple{Int, Int, Vararg{Int, N}}} where {N}
for B in Bs
@testintersect(A, B, C)
end
end
Expand Down

0 comments on commit d69bb97

Please sign in to comment.