Skip to content

Commit

Permalink
setindex: disallow breaking the object model
Browse files Browse the repository at this point in the history
This was written fairly carefully to be safe, assuming it was not improperly optimized.
But others are not as careful when copying this code. And it is just better not to break the object model and attempt to mutate constant values.
  • Loading branch information
vtjnash committed Dec 23, 2019
1 parent 4a8ea8c commit 77624c0
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 62 deletions.
37 changes: 26 additions & 11 deletions base/deepcopy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,35 @@ end

function deepcopy_internal(@nospecialize(x), stackdict::IdDict)
T = typeof(x)::DataType
isbitstype(T) && return x
if haskey(stackdict, x)
return stackdict[x]
end
y = ccall(:jl_new_struct_uninit, Any, (Any,), T)
nf = nfields(x)
if T.mutable
if haskey(stackdict, x)
return stackdict[x]
end
y = ccall(:jl_new_struct_uninit, Any, (Any,), T)
stackdict[x] = y
end
for i in 1:nfields(x)
if isdefined(x,i)
xi = getfield(x, i)
xi = deepcopy_internal(xi, stackdict)::typeof(xi)
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi)
for i in 1:nf
if isdefined(x, i)
xi = getfield(x, i)
xi = deepcopy_internal(xi, stackdict)::typeof(xi)
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi)
end
end
elseif nf == 0 || isbitstype(T)
y = x
else
flds = Vector{Any}(undef, nf)
for i in 1:nf
if isdefined(x, i)
xi = getfield(x, i)
xi = deepcopy_internal(xi, stackdict)::typeof(xi)
flds[i] = xi
else
nf = i - 1 # rest of tail must be undefined values
break
end
end
y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, flds, nf)
end
return y::T
end
Expand Down
17 changes: 12 additions & 5 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -717,15 +717,22 @@ JL_CALLABLE(jl_f_tuple)
tt = jl_inst_concrete_tupletype(types);
JL_GC_POP();
}
return jl_new_structv(tt, args, nargs);
if (tt->instance != NULL)
return tt->instance;
jl_ptls_t ptls = jl_get_ptls_states();
jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(tt), tt);
for (i = 0; i < nargs; i++)
set_nth_field(tt, (void*)jv, i, args[i]);
return jv;
}

JL_CALLABLE(jl_f_svec)
{
size_t i;
if (nargs == 0) return (jl_value_t*)jl_emptysvec;
if (nargs == 0)
return (jl_value_t*)jl_emptysvec;
jl_svec_t *t = jl_alloc_svec_uninit(nargs);
for(i=0; i < nargs; i++) {
for (i = 0; i < nargs; i++) {
jl_svecset(t, i, args[i]);
}
return (jl_value_t*)t;
Expand Down Expand Up @@ -785,11 +792,11 @@ JL_CALLABLE(jl_f_setfield)
JL_TYPECHK(setfield!, symbol, args[1]);
idx = jl_field_index(st, (jl_sym_t*)args[1], 1);
}
jl_value_t *ft = jl_field_type(st,idx);
jl_value_t *ft = jl_field_type(st, idx);
if (!jl_isa(args[2], ft)) {
jl_type_error("setfield!", ft, args[2]);
}
jl_set_nth_field(v, idx, args[2]);
set_nth_field(st, (void*)v, idx, args[2]);
return args[2];
}

Expand Down
24 changes: 13 additions & 11 deletions src/datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ JL_DLLEXPORT jl_value_t *jl_new_struct(jl_datatype_t *type, ...)
va_start(args, type);
jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(type), type);
for (size_t i = 0; i < nf; i++) {
jl_set_nth_field(jv, i, va_arg(args, jl_value_t*));
set_nth_field(type, (void*)jv, i, va_arg(args, jl_value_t*));
}
va_end(args);
return jv;
Expand All @@ -906,14 +906,15 @@ JL_DLLEXPORT jl_value_t *jl_new_struct(jl_datatype_t *type, ...)
static void init_struct_tail(jl_datatype_t *type, jl_value_t *jv, size_t na)
{
size_t nf = jl_datatype_nfields(type);
for(size_t i=na; i < nf; i++) {
char *data = (char*)jl_data_ptr(jv);
for (size_t i = na; i < nf; i++) {
if (jl_field_isptr(type, i)) {
*(jl_value_t**)((char*)jl_data_ptr(jv)+jl_field_offset(type,i)) = NULL;
*(jl_value_t**)(data + jl_field_offset(type, i)) = NULL;
}
else {
jl_value_t *ft = jl_field_type(type, i);
if (jl_is_uniontype(ft)) {
uint8_t *psel = &((uint8_t *)jv)[jl_field_offset(type, i) + jl_field_size(type, i) - 1];
uint8_t *psel = &((uint8_t *)data)[jl_field_offset(type, i) + jl_field_size(type, i) - 1];
*psel = 0;
}
}
Expand All @@ -923,6 +924,10 @@ static void init_struct_tail(jl_datatype_t *type, jl_value_t *jv, size_t na)
JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args, uint32_t na)
{
jl_ptls_t ptls = jl_get_ptls_states();
if (!jl_is_datatype(type) || type->layout == NULL)
jl_type_error("new", (jl_value_t*)jl_datatype_type, (jl_value_t*)type);
if ((type->ninitialized > na && !type->mutabl) || na > jl_datatype_nfields(type))
jl_error("invalid struct allocation");
if (type->instance != NULL) {
for (size_t i = 0; i < na; i++) {
jl_value_t *ft = jl_field_type(type, i);
Expand All @@ -931,15 +936,13 @@ JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args,
}
return type->instance;
}
if (type->layout == NULL)
jl_type_error("new", (jl_value_t*)jl_datatype_type, (jl_value_t*)type);
jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(type), type);
JL_GC_PUSH1(&jv);
for (size_t i = 0; i < na; i++) {
jl_value_t *ft = jl_field_type(type, i);
if (!jl_isa(args[i], ft))
jl_type_error("new", ft, args[i]);
jl_set_nth_field(jv, i, args[i]);
set_nth_field(type, (void*)jv, i, args[i]);
}
init_struct_tail(type, jv, na);
JL_GC_POP();
Expand All @@ -951,7 +954,7 @@ JL_DLLEXPORT jl_value_t *jl_new_structt(jl_datatype_t *type, jl_value_t *tup)
jl_ptls_t ptls = jl_get_ptls_states();
if (!jl_is_tuple(tup))
jl_type_error("new", (jl_value_t*)jl_tuple_type, tup);
if (type->layout == NULL)
if (!jl_is_datatype(type) || type->layout == NULL)
jl_type_error("new", (jl_value_t *)jl_datatype_type, (jl_value_t *)type);
size_t nargs = jl_nfields(tup);
size_t nf = jl_datatype_nfields(type);
Expand All @@ -975,7 +978,7 @@ JL_DLLEXPORT jl_value_t *jl_new_structt(jl_datatype_t *type, jl_value_t *tup)
fi = jl_get_nth_field(tup, i);
if (!jl_isa(fi, ft))
jl_type_error("new", ft, fi);
jl_set_nth_field(jv, i, fi);
set_nth_field(type, (void*)jv, i, fi);
}
JL_GC_POP();
return jv;
Expand Down Expand Up @@ -1074,9 +1077,8 @@ JL_DLLEXPORT jl_value_t *jl_get_nth_field_checked(jl_value_t *v, size_t i)
return undefref_check((jl_datatype_t*)ty, jl_new_bits(ty, (char*)v + offs));
}

JL_DLLEXPORT void jl_set_nth_field(jl_value_t *v, size_t i, jl_value_t *rhs) JL_NOTSAFEPOINT
void set_nth_field(jl_datatype_t *st, void *v, size_t i, jl_value_t *rhs) JL_NOTSAFEPOINT
{
jl_datatype_t *st = (jl_datatype_t*)jl_typeof(v);
size_t offs = jl_field_offset(st, i);
if (jl_field_isptr(st, i)) {
*(jl_value_t**)((char*)v + offs) = rhs;
Expand Down
4 changes: 2 additions & 2 deletions src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -2112,7 +2112,7 @@ static jl_value_t *jl_deserialize_value(jl_serializer_state *s, jl_value_t **loc
v = jl_new_struct_uninit(tag == TAG_GOTONODE ? jl_gotonode_type : jl_quotenode_type);
if (usetable)
arraylist_push(&backref_list, v);
jl_set_nth_field(v, 0, jl_deserialize_value(s, NULL));
set_nth_field(tag == TAG_GOTONODE ? jl_gotonode_type : jl_quotenode_type, (void*)v, 0, jl_deserialize_value(s, NULL));
return v;
case TAG_UNIONALL:
pos = backref_list.len;
Expand Down Expand Up @@ -2228,7 +2228,7 @@ static jl_value_t *jl_deserialize_value(jl_serializer_state *s, jl_value_t **loc
arraylist_push(&backref_list, v);
for (i = 0; i < jl_datatype_nfields(jl_lineinfonode_type); i++) {
size_t offs = jl_field_offset(jl_lineinfonode_type, i);
jl_set_nth_field(v, i, jl_deserialize_value(s, (jl_value_t**)((char*)v + offs)));
set_nth_field(jl_lineinfonode_type, (void*)v, i, jl_deserialize_value(s, (jl_value_t**)((char*)v + offs)));
}
return v;
case TAG_DATATYPE:
Expand Down
2 changes: 0 additions & 2 deletions src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,6 @@ static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s)
JL_GC_PUSHARGS(argv, nargs);
for (size_t i = 0; i < nargs; i++)
argv[i] = eval_value(args[i], s);
assert(jl_is_structtype(argv[0]));
jl_value_t *v = jl_new_structv((jl_datatype_t*)argv[0], &argv[1], nargs - 1);
JL_GC_POP();
return v;
Expand All @@ -519,7 +518,6 @@ static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s)
JL_GC_PUSHARGS(argv, 2);
argv[0] = eval_value(args[0], s);
argv[1] = eval_value(args[1], s);
assert(jl_is_structtype(argv[0]));
jl_value_t *v = jl_new_structt((jl_datatype_t*)argv[0], argv[1]);
JL_GC_POP();
return v;
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ void jl_precompute_memoized_dt(jl_datatype_t *dt);
jl_datatype_t *jl_wrap_Type(jl_value_t *t); // x -> Type{x}
jl_value_t *jl_wrap_vararg(jl_value_t *t, jl_value_t *n);
void jl_assign_bits(void *dest, jl_value_t *bits) JL_NOTSAFEPOINT;
void set_nth_field(jl_datatype_t *st, void *v, size_t i, jl_value_t *rhs) JL_NOTSAFEPOINT;
jl_expr_t *jl_exprn(jl_sym_t *head, size_t n);
jl_function_t *jl_new_generic_function(jl_sym_t *name, jl_module_t *module);
jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, jl_module_t *module, jl_datatype_t *st);
Expand Down
16 changes: 16 additions & 0 deletions src/rtutils.c
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,22 @@ JL_DLLEXPORT jl_value_t *jl_value_ptr(jl_value_t *a)
return a;
}

// optimization of setfield which bypasses boxing of the idx (and checking field type validity)
JL_DLLEXPORT void jl_set_nth_field(jl_value_t *v, size_t idx0, jl_value_t *rhs)
{
jl_datatype_t *st = (jl_datatype_t*)jl_typeof(v);
if (!st->mutabl)
jl_errorf("setfield! immutable struct of type %s cannot be changed", jl_symbol_name(st->name->name));
if (idx0 >= jl_datatype_nfields(st))
jl_bounds_error_int(v, idx0 + 1);
//jl_value_t *ft = jl_field_type(st, idx0);
//if (!jl_isa(rhs, ft)) {
// jl_type_error("setfield!", ft, rhs);
//}
set_nth_field(st, (void*)v, idx0, rhs);
}


// parsing --------------------------------------------------------------------

int substr_isspace(char *p, char *pend)
Expand Down
41 changes: 18 additions & 23 deletions stdlib/Serialization/src/Serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@ function serialize_any(s::AbstractSerializer, @nospecialize(x))
serialize_type(s, t)
write(s.io, x)
else
if t.mutable && nf > 0
if t.mutable
serialize_cycle(s, x) && return
serialize_type(s, t, true)
else
Expand Down Expand Up @@ -1288,36 +1288,31 @@ function deserialize(s::AbstractSerializer, t::DataType)
if nf == 0 && t.size > 0
# bits type
return read(s.io, t)
end
if nf == 0
return ccall(:jl_new_struct, Any, (Any,Any...), t)
elseif isbitstype(t)
if nf == 1
f1 = deserialize(s)
return ccall(:jl_new_struct, Any, (Any,Any...), t, f1)
elseif nf == 2
f1 = deserialize(s)
f2 = deserialize(s)
return ccall(:jl_new_struct, Any, (Any,Any...), t, f1, f2)
elseif nf == 3
f1 = deserialize(s)
f2 = deserialize(s)
f3 = deserialize(s)
return ccall(:jl_new_struct, Any, (Any,Any...), t, f1, f2, f3)
else
flds = Any[ deserialize(s) for i = 1:nf ]
return ccall(:jl_new_structv, Any, (Any,Ptr{Cvoid},UInt32), t, flds, nf)
end
else
elseif t.mutable
x = ccall(:jl_new_struct_uninit, Any, (Any,), t)
t.mutable && deserialize_cycle(s, x)
deserialize_cycle(s, x)
for i in 1:nf
tag = Int32(read(s.io, UInt8)::UInt8)
if tag != UNDEFREF_TAG
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), x, i-1, handle_deserialize(s, tag))
end
end
return x
elseif nf == 0
return ccall(:jl_new_struct_uninit, Any, (Any,), t)
else
na = nf
vflds = Vector{Any}(undef, nf)
for i in 1:nf
tag = Int32(read(s.io, UInt8)::UInt8)
if tag != UNDEFREF_TAG
f = handle_deserialize(s, tag)
na >= i && (vflds[i] = f)
else
na >= i && (na = i - 1) # rest of tail must be undefined values
end
end
return ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), t, vflds, na)
end
end

Expand Down
11 changes: 11 additions & 0 deletions stdlib/Serialization/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,17 @@ create_serialization_stream() do s
@test C[1] === C[2]
end

mutable struct MSingle end
create_serialization_stream() do s
x = MSingle()
A = [x, x, MSingle()]
serialize(s, A)
seekstart(s)
C = deserialize(s)
@test A[1] === x === A[2] !== A[3]
@test x !== C[1] === C[2] !== C[3]
end

# Regex
create_serialization_stream() do s
r1 = r"a?b.*"
Expand Down
33 changes: 25 additions & 8 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1085,8 +1085,8 @@ let
@test_throws BoundsError(z, -1) getfield(z, -1)
@test_throws BoundsError(z, 0) getfield(z, 0)
@test_throws BoundsError(z, 3) getfield(z, 3)

strct = LoadError("yofile", 0, "bad")
end
let strct = LoadError("yofile", 0, "bad")
@test nfields(strct) == 3 # sanity test
@test_throws BoundsError(strct, 10) getfield(strct, 10)
@test_throws ErrorException("setfield! immutable struct of type LoadError cannot be changed") setfield!(strct, 0, "")
Expand All @@ -1098,8 +1098,8 @@ let
@test getfield(strct, 1) == "yofile"
@test getfield(strct, 2) === 0
@test getfield(strct, 3) == "bad"

mstrct = TestMutable("melm", 1, nothing)
end
let mstrct = TestMutable("melm", 1, nothing)
@test Base.setproperty!(mstrct, :line, 8.0) === 8
@test mstrct.line === 8
@test_throws TypeError(:setfield!, "", Int, 8.0) setfield!(mstrct, :line, 8.0)
Expand All @@ -1112,6 +1112,14 @@ let
@test_throws BoundsError(mstrct, 0) setfield!(mstrct, 0, "")
@test_throws BoundsError(mstrct, 4) setfield!(mstrct, 4, "")
end
let strct = LoadError("yofile", 0, "bad")
@test_throws(ErrorException("setfield! immutable struct of type LoadError cannot be changed"),
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), strct, 0, ""))
end
let mstrct = TestMutable("melm", 1, nothing)
@test_throws(BoundsError(mstrct, 4),
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), mstrct, 3, ""))
end

# test getfield-overloading
function Base.getproperty(mstrct::TestMutable, p::Symbol)
Expand Down Expand Up @@ -3614,10 +3622,19 @@ end
return nothing
end
end
@test_throws TypeError f1()
@test_throws TypeError f2()
@test_throws TypeError f3()
@test_throws TypeError eval(Expr(:new, B, 1))
@test_throws TypeError("new", A, 1) f1()
@test_throws TypeError("new", A, 1) f2()
@test_throws TypeError("new", A, 1) f3()
@test_throws TypeError("new", A, 1) eval(Expr(:new, B, 1))

# some tests for handling of malformed syntax--these cases should not be possible in normal code
@test eval(Expr(:new, B, A())) == B(A())
@test_throws ErrorException("invalid struct allocation") eval(Expr(:new, B))
@test_throws ErrorException("invalid struct allocation") eval(Expr(:new, B, A(), A()))
@test_throws TypeError("new", DataType, Complex) eval(Expr(:new, Complex))
@test_throws TypeError("new", DataType, Complex.body) eval(Expr(:new, Complex.body))
@test_throws TypeError("new", DataType, Complex) eval(Expr(:splatnew, Complex, ()))
@test_throws TypeError("new", DataType, Complex.body) eval(Expr(:splatnew, Complex.body, ()))

end

Expand Down

0 comments on commit 77624c0

Please sign in to comment.