Skip to content

Commit

Permalink
Fix invdepth within existential subtyping. (#49049)
Browse files Browse the repository at this point in the history
* Remove `Rinvdepth` from `stenv`
  • Loading branch information
N5N3 committed Mar 21, 2023
1 parent 826674c commit ceffaee
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 67 deletions.
92 changes: 31 additions & 61 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ typedef struct jl_stenv_t {
jl_value_t **envout; // for passing caller the computed bounds of right-side variables
int envsz; // length of envout
int envidx; // current index in envout
int invdepth; // # of invariant constructors we're nested in on the left
int Rinvdepth; // # of invariant constructors we're nested in on the right
int invdepth; // current number of invariant constructors we're nested in
int ignore_free; // treat free vars as black boxes; used during intersection
int intersection; // true iff subtype is being called from intersection
int emptiness_only; // true iff intersection only needs to test for emptiness
Expand Down Expand Up @@ -658,7 +657,7 @@ static void record_var_occurrence(jl_varbinding_t *vb, jl_stenv_t *e, int param)
vb->occurs = 1;
if (vb != NULL && param) {
// saturate counters at 2; we don't need values bigger than that
if (param == 2 && (vb->right ? e->Rinvdepth : e->invdepth) > vb->depth0) {
if (param == 2 && e->invdepth > vb->depth0) {
if (vb->occurs_inv < 2)
vb->occurs_inv++;
}
Expand All @@ -680,7 +679,7 @@ static int var_outside(jl_stenv_t *e, jl_tvar_t *x, jl_tvar_t *y)
return 0;
}

static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int R, int d);
static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth);

static int reachable_var(jl_value_t *x, jl_tvar_t *y, jl_stenv_t *e);

Expand All @@ -700,7 +699,7 @@ static int var_lt(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int param)
// for this to work we need to compute issub(left,right) before issub(right,left),
// since otherwise the issub(a, bb.ub) check in var_gt becomes vacuous.
if (e->intersection) {
jl_value_t *ub = intersect_aside(bb->ub, a, e, 0, bb->depth0);
jl_value_t *ub = intersect_aside(a, bb->ub, e, bb->depth0);
JL_GC_PUSH1(&ub);
if (ub != (jl_value_t*)b && (!jl_is_typevar(ub) || !reachable_var(ub, b, e)))
bb->ub = ub;
Expand Down Expand Up @@ -849,7 +848,7 @@ static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8
{
u = unalias_unionall(u, e);
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
e->vars = &vb;
int ans;
Expand Down Expand Up @@ -949,10 +948,8 @@ static int check_vararg_length(jl_value_t *v, ssize_t n, jl_stenv_t *e)
jl_value_t *nn = jl_box_long(n);
JL_GC_PUSH1(&nn);
e->invdepth++;
e->Rinvdepth++;
int ans = subtype(nn, N, e, 2) && subtype(N, nn, e, 0);
e->invdepth--;
e->Rinvdepth--;
JL_GC_POP();
if (!ans)
return 0;
Expand Down Expand Up @@ -1049,16 +1046,13 @@ static int subtype_tuple_varargs(
// set lb to Any. Since `intvalued` is set, we'll interpret that
// appropriately.
e->invdepth++;
e->Rinvdepth++;
int ans = subtype((jl_value_t*)jl_any_type, yp1, e, 2);
e->invdepth--;
e->Rinvdepth--;
return ans;
}

// Vararg{T,N} <: Vararg{T2,N2}; equate N and N2
e->invdepth++;
e->Rinvdepth++;
JL_GC_PUSH2(&xp1, &yp1);
if (xp1 && jl_is_long(xp1) && vx != 1)
xp1 = jl_box_long(jl_unbox_long(xp1) - vx + 1);
Expand All @@ -1067,7 +1061,6 @@ static int subtype_tuple_varargs(
int ans = forall_exists_equal(xp1, yp1, e);
JL_GC_POP();
e->invdepth--;
e->Rinvdepth--;
return ans;
}

Expand Down Expand Up @@ -1354,10 +1347,7 @@ static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
// The answer is true iff `T` has full bounds (as in `Type`), but this needs to
// be checked at the same depth where `Type{T}` occurs --- the depth of the LHS
// doesn't matter because it (e.g. `DataType`) doesn't actually contain the variable.
int saved = e->invdepth;
e->invdepth = e->Rinvdepth;
int issub = subtype((jl_value_t*)jl_type_type, y, e, param);
e->invdepth = saved;
return issub;
}
while (xd != jl_any_type && xd->name != yd->name) {
Expand All @@ -1373,15 +1363,13 @@ static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
size_t i, np = jl_nparams(xd);
int ans = 1;
e->invdepth++;
e->Rinvdepth++;
for (i=0; i < np; i++) {
jl_value_t *xi = jl_tparam(xd, i), *yi = jl_tparam(yd, i);
if (!(xi == yi || forall_exists_equal(xi, yi, e))) {
ans = 0; break;
}
}
e->invdepth--;
e->Rinvdepth--;
return ans;
}
if (jl_is_type(y))
Expand Down Expand Up @@ -1573,7 +1561,7 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
if (envsz)
memset(env, 0, envsz*sizeof(void*));
e->envidx = 0;
e->invdepth = e->Rinvdepth = 0;
e->invdepth = 0;
e->ignore_free = 0;
e->intersection = 0;
e->emptiness_only = 0;
Expand Down Expand Up @@ -2028,31 +2016,20 @@ JL_DLLEXPORT int jl_subtype_env(jl_value_t *x, jl_value_t *y, jl_value_t **env,
return subtype;
}

static int subtype_in_env_(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int invdepth, int Rinvdepth)
static int subtype_in_env(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
jl_stenv_t e2;
init_stenv(&e2, NULL, 0);
e2.vars = e->vars;
e2.intersection = e->intersection;
e2.ignore_free = e->ignore_free;
e2.invdepth = invdepth;
e2.Rinvdepth = Rinvdepth;
e2.invdepth = e->invdepth;
e2.envsz = e->envsz;
e2.envout = e->envout;
e2.envidx = e->envidx;
return forall_exists_subtype(x, y, &e2, 0);
}

static int subtype_in_env(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
return subtype_in_env_(x, y, e, e->invdepth, e->Rinvdepth);
}

static int subtype_bounds_in_env(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int R, int d)
{
return subtype_in_env_(x, y, e, R ? e->invdepth : d, R ? d : e->Rinvdepth);
}

JL_DLLEXPORT int jl_subtype(jl_value_t *x, jl_value_t *y)
{
return jl_subtype_env(x, y, NULL, 0);
Expand Down Expand Up @@ -2259,27 +2236,23 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e);

// intersect in nested union environment, similar to subtype_ccheck
static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int R, int d)
static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int depth)
{
// band-aid for #30335
if (x == (jl_value_t*)jl_any_type && !jl_is_typevar(y))
return y;
if (y == (jl_value_t*)jl_any_type && !jl_is_typevar(x))
return x;
// band-aid for #46736
if (jl_egal(x, y))
if (obviously_egal(x, y))
return x;

jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
int savedepth = e->invdepth, Rsavedepth = e->Rinvdepth;
// TODO: this doesn't quite make sense
e->invdepth = e->Rinvdepth = d;

int savedepth = e->invdepth;
e->invdepth = depth;
jl_value_t *res = intersect_all(x, y, e);

pop_unionstate(&e->Runions, &oldRunions);
e->invdepth = savedepth;
e->Rinvdepth = Rsavedepth;
pop_unionstate(&e->Runions, &oldRunions);
return res;
}

Expand Down Expand Up @@ -2380,14 +2353,14 @@ static int try_subtype_by_bounds(jl_value_t *a, jl_value_t *b, jl_stenv_t *e)
return 0;
}

static int try_subtype_in_env(jl_value_t *a, jl_value_t *b, jl_stenv_t *e, int R, int d)
static int try_subtype_in_env(jl_value_t *a, jl_value_t *b, jl_stenv_t *e, int flip)
{
if (a == jl_bottom_type || b == (jl_value_t *)jl_any_type || try_subtype_by_bounds(a, b, e))
return 1;
jl_value_t *root=NULL; jl_savedenv_t se;
JL_GC_PUSH1(&root);
save_env(e, &root, &se);
int ret = subtype_bounds_in_env(a, b, e, R, d);
int ret = subtype_in_env(a, b, e);
restore_env(e, root, &se);
free_env(&se);
JL_GC_POP();
Expand All @@ -2409,7 +2382,7 @@ static void set_bound(jl_value_t **bound, jl_value_t *val, jl_tvar_t *v, jl_sten
}

// subtype, treating all vars as existential
static int subtype_in_env_existential(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int R, int d)
static int subtype_in_env_existential(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int flip)
{
jl_varbinding_t *v = e->vars;
int len = 0;
Expand All @@ -2428,7 +2401,7 @@ static int subtype_in_env_existential(jl_value_t *x, jl_value_t *y, jl_stenv_t *
v->right = 1;
v = v->prev;
}
int issub = subtype_bounds_in_env(x, y, e, R, d);
int issub = subtype_in_env(x, y, e);
n = 0; v = e->vars;
while (n < len) {
assert(v != NULL);
Expand Down Expand Up @@ -2506,25 +2479,23 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
{
jl_varbinding_t *bb = lookup(e, b);
if (bb == NULL)
return R ? intersect_aside(a, b->ub, e, 1, 0) : intersect_aside(b->ub, a, e, 0, 0);
return R ? intersect_aside(a, b->ub, e, 0) : intersect_aside(b->ub, a, e, 0);
if (reachable_var(bb->lb, b, e) || reachable_var(bb->ub, b, e))
return a;
if (bb->lb == bb->ub && jl_is_typevar(bb->lb)) {
return intersect(a, bb->lb, e, param);
}
if (bb->lb == bb->ub && jl_is_typevar(bb->lb))
return R ? intersect(a, bb->lb, e, param) : intersect(bb->lb, a, e, param);
if (!jl_is_type(a) && !jl_is_typevar(a))
return set_var_to_const(bb, a, NULL);
int d = bb->depth0;
jl_value_t *root=NULL; jl_savedenv_t se;
if (param == 2) {
jl_value_t *ub = NULL;
JL_GC_PUSH2(&ub, &root);
if (!jl_has_free_typevars(a)) {
save_env(e, &root, &se);
int issub = subtype_in_env_existential(bb->lb, a, e, 0, d);
int issub = subtype_in_env_existential(bb->lb, a, e, R);
restore_env(e, root, &se);
if (issub) {
issub = subtype_in_env_existential(a, bb->ub, e, 1, d);
issub = subtype_in_env_existential(a, bb->ub, e, !R);
restore_env(e, root, &se);
}
free_env(&se);
Expand All @@ -2536,10 +2507,10 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
}
else {
e->triangular++;
ub = R ? intersect_aside(a, bb->ub, e, 1, d) : intersect_aside(bb->ub, a, e, 0, d);
ub = R ? intersect_aside(a, bb->ub, e, bb->depth0) : intersect_aside(bb->ub, a, e, bb->depth0);
e->triangular--;
save_env(e, &root, &se);
int issub = subtype_in_env_existential(bb->lb, ub, e, 0, d);
int issub = subtype_in_env_existential(bb->lb, ub, e, R);
restore_env(e, root, &se);
free_env(&se);
if (!issub) {
Expand Down Expand Up @@ -2570,7 +2541,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
JL_GC_POP();
return ub;
}
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, 1, d) : intersect_aside(bb->ub, a, e, 0, d);
jl_value_t *ub = R ? intersect_aside(a, bb->ub, e, bb->depth0) : intersect_aside(bb->ub, a, e, bb->depth0);
if (ub == jl_bottom_type)
return jl_bottom_type;
if (bb->constraintkind == 1 || e->triangular) {
Expand All @@ -2581,7 +2552,7 @@ static jl_value_t *intersect_var(jl_tvar_t *b, jl_value_t *a, jl_stenv_t *e, int
}
else if (bb->constraintkind == 0) {
JL_GC_PUSH1(&ub);
if (!jl_is_typevar(a) && try_subtype_in_env(bb->ub, a, e, 0, d)) {
if (!jl_is_typevar(a) && try_subtype_in_env(bb->ub, a, e, R)) {
JL_GC_POP();
return (jl_value_t*)b;
}
Expand Down Expand Up @@ -2911,7 +2882,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
jl_value_t *res=NULL, *save=NULL;
jl_savedenv_t se;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH5(&res, &vb.lb, &vb.ub, &save, &vb.innervars);
save_env(e, &save, &se);
res = intersect_unionall_(t, u, e, R, param, &vb);
Expand Down Expand Up @@ -3122,10 +3093,8 @@ static jl_value_t *intersect_invariant(jl_value_t *x, jl_value_t *y, jl_stenv_t
return (jl_subtype(x,y) && jl_subtype(y,x)) ? y : NULL;
}
e->invdepth++;
e->Rinvdepth++;
jl_value_t *ii = intersect(x, y, e, 2);
e->invdepth--;
e->Rinvdepth--;
// Skip the following subtype check if `ii` was returned from `set_vat_to_const`.
// As `var_gt`/`var_lt` might not handle `Vararg` length offset correctly.
// TODO: fix this on subtype side and remove this branch.
Expand All @@ -3148,11 +3117,11 @@ static jl_value_t *intersect_invariant(jl_value_t *x, jl_value_t *y, jl_stenv_t
jl_savedenv_t se;
JL_GC_PUSH2(&ii, &root);
save_env(e, &root, &se);
if (!subtype_in_env_existential(x, y, e, 0, e->invdepth))
if (!subtype_in_env_existential(x, y, e, 0))
ii = NULL;
else {
restore_env(e, root, &se);
if (!subtype_in_env_existential(y, x, e, 0, e->invdepth))
if (!subtype_in_env_existential(y, x, e, 1))
ii = NULL;
}
restore_env(e, root, &se);
Expand Down Expand Up @@ -3314,7 +3283,8 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa
}
jl_value_t *ub=NULL, *lb=NULL;
JL_GC_PUSH2(&lb, &ub);
ub = intersect_aside(xub, yub, e, 0, xx ? xx->depth0 : 0);
int d = xx ? xx->depth0 : yy ? yy->depth0 : 0;
ub = R ? intersect_aside(yub, xub, e, d) : intersect_aside(xub, yub, e, d);
if (reachable_var(xlb, (jl_tvar_t*)y, e))
lb = ylb;
else
Expand Down
16 changes: 10 additions & 6 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1804,8 +1804,14 @@ end
#end

# issue #32386
@test typeintersect(Type{S} where S<:(Vector{Pair{_A,N} where N} where _A),
Type{Vector{T}} where T) == Type{Vector{Pair{_A,N} where N}} where _A
@testintersect(Type{S} where S<:(Vector{Pair{_A,N} where N} where _A),
Type{Vector{T}} where T,
Type{Vector{Pair{_A,N} where N}} where _A)

# pr #49049
@testintersect(Tuple{Type{Pair{T, A} where {T, A<:Array{T}}}, Int, Any},
Tuple{Type{F}, Any, Int} where {F<:(Pair{T, A} where {T, A<:Array{T}})},
Tuple{Type{Pair{T, A} where {T, A<:(Array{T})}}, Int, Int})

# issue #32488
struct S32488{S <: Tuple, T, N, L}
Expand Down Expand Up @@ -2431,11 +2437,9 @@ abstract type MyAbstract47877{C}; end
struct MyType47877{A,B} <: MyAbstract47877{A} end
let A = Tuple{Type{T}, T} where T,
B = Tuple{Type{MyType47877{W, V} where V<:Union{Base.BitInteger, MyAbstract47877{W}}}, MyAbstract47877{<:Base.BitInteger}} where W
C = Tuple{Type{MyType47877{W, V} where V<:Union{MyAbstract47877{W1}, Base.BitInteger}}, MyType47877{W, V} where V<:Union{MyAbstract47877{W1}, Base.BitInteger}} where {W<:Base.BitInteger, W1<:Base.BitInteger}
# ensure that merge_env for innervars does not blow up (the large Unions ensure this will take excessive memory if it does)
@test typeintersect(A, B) == C # suboptimal, but acceptable
C = Tuple{Type{MyType47877{W, V} where V<:Union{MyAbstract47877{W}, Base.BitInteger}}, MyType47877{W, V} where V<:Union{MyAbstract47877{W}, Base.BitInteger}} where W<:Base.BitInteger
@test typeintersect(B, A) == C
# ensure that merge_env for innervars does not blow up (the large Unions ensure this will take excessive memory if it does)
@testintersect(A, B, C)
end

let a = (isodd(i) ? Pair{Char, String} : Pair{String, String} for i in 1:2000)
Expand Down

0 comments on commit ceffaee

Please sign in to comment.