Skip to content

Commit

Permalink
Symbol clean up and strlen fixes (#37149)
Browse files Browse the repository at this point in the history
* Require the symbol to fit in ssize_t

  This is the limit we have on arrays and I don't think anyone is using
  symbols that is as long as half the full address space.

* Remove some unnecessary embeded NUL byte checks when constructing symbols

  Almost all use of `jl_symbol_n` have the caller checked for embedded NUL byte already.
  This technically introduces a C API/ABI breakage but it shouldn't matter on all platforms we support.

* Fix ccall of `strlen`.
  • Loading branch information
yuyichao committed Aug 25, 2020
1 parent 5ef6d0f commit 98c68a2
Show file tree
Hide file tree
Showing 11 changed files with 39 additions and 23 deletions.
5 changes: 2 additions & 3 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@ Generates a symbol which will not conflict with other variable names.
"""
gensym() = ccall(:jl_gensym, Ref{Symbol}, ())

gensym(s::String) = ccall(:jl_tagged_gensym, Ref{Symbol}, (Ptr{UInt8}, Int32), s, sizeof(s))
gensym(s::String) = ccall(:jl_tagged_gensym, Ref{Symbol}, (Ptr{UInt8}, Csize_t), s, sizeof(s))

gensym(ss::String...) = map(gensym, ss)
gensym(s::Symbol) =
ccall(:jl_tagged_gensym, Ref{Symbol}, (Ptr{UInt8}, Int32), s, ccall(:strlen, Csize_t, (Ptr{UInt8},), s))
gensym(s::Symbol) = ccall(:jl_tagged_gensym, Ref{Symbol}, (Ptr{UInt8}, Csize_t), s, -1 % Csize_t)

"""
@gensym
Expand Down
4 changes: 2 additions & 2 deletions base/gcutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ directly to `ccall` which counts as an explicit use.)
julia> let
x = "Hello"
p = pointer(x)
GC.@preserve x @ccall strlen(p::Cstring)::Cint
Int(GC.@preserve x @ccall strlen(p::Cstring)::Csize_t)
# Preferred alternative
@ccall strlen(x::Cstring)::Cint
Int(@ccall strlen(x::Cstring)::Csize_t)
end
5
```
Expand Down
2 changes: 1 addition & 1 deletion base/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ end

function write(io::IO, s::Symbol)
pname = unsafe_convert(Ptr{UInt8}, s)
return unsafe_write(io, pname, Int(ccall(:strlen, Csize_t, (Cstring,), pname)))
return unsafe_write(io, pname, ccall(:strlen, Csize_t, (Cstring,), pname))
end

function write(to::IO, from::IO)
Expand Down
2 changes: 1 addition & 1 deletion base/secretbuffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ function SecretBuffer!(d::Vector{UInt8})
s
end

unsafe_SecretBuffer!(s::Cstring) = unsafe_SecretBuffer!(convert(Ptr{UInt8}, s), ccall(:strlen, Cint, (Cstring,), s))
unsafe_SecretBuffer!(s::Cstring) = unsafe_SecretBuffer!(convert(Ptr{UInt8}, s), Int(ccall(:strlen, Csize_t, (Cstring,), s)))
function unsafe_SecretBuffer!(p::Ptr{UInt8}, len=1)
s = SecretBuffer(sizehint=len)
for i in 1:len
Expand Down
4 changes: 2 additions & 2 deletions src/datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ jl_sym_t *jl_demangle_typename(jl_sym_t *s) JL_NOTSAFEPOINT
else
len = (end-n) - 1; // extract `f` from `#f#...`
if (is10digit(n[1]))
return jl_symbol_n(n, len+1);
return jl_symbol_n(&n[1], len);
return _jl_symbol(n, len+1);
return _jl_symbol(&n[1], len);
}

JL_DLLEXPORT jl_methtable_t *jl_new_method_table(jl_sym_t *name, jl_module_t *module)
Expand Down
2 changes: 1 addition & 1 deletion src/dump.c
Original file line number Diff line number Diff line change
Expand Up @@ -1932,7 +1932,7 @@ static jl_value_t *read_verify_mod_list(ios_t *s, jl_array_t *mod_list)
uuid.hi = read_uint64(s);
uuid.lo = read_uint64(s);
uint64_t build_id = read_uint64(s);
jl_sym_t *sym = jl_symbol_n(name, len);
jl_sym_t *sym = _jl_symbol(name, len);
jl_module_t *m = (jl_module_t*)jl_array_ptr_ref(mod_list, i);
if (!m || !jl_is_module(m) || m->uuid.hi != uuid.hi || m->uuid.lo != uuid.lo || m->name != sym || m->build_id != build_id) {
return jl_get_exceptionf(jl_errorexception_type,
Expand Down
4 changes: 2 additions & 2 deletions src/ircode.c
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ JL_DLLEXPORT jl_array_t *jl_uncompress_argnames(jl_value_t *syms)
JL_GC_PUSH1(&names);
for (i = 0; i < len; i++) {
size_t namelen = strlen(namestr);
jl_sym_t *name = jl_symbol_n(namestr, namelen);
jl_sym_t *name = _jl_symbol(namestr, namelen);
jl_array_ptr_set(names, i, name);
namestr += namelen + 1;
}
Expand All @@ -940,7 +940,7 @@ JL_DLLEXPORT jl_value_t *jl_uncompress_argname_n(jl_value_t *syms, size_t i)
while (remaining) {
size_t namelen = strlen(namestr);
if (i-- == 0) {
jl_sym_t *name = jl_symbol_n(namestr, namelen);
jl_sym_t *name = _jl_symbol(namestr, namelen);
return (jl_value_t*)name;
}
namestr += namelen + 1;
Expand Down
2 changes: 1 addition & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1320,7 +1320,7 @@ JL_DLLEXPORT jl_sym_t *jl_symbol(const char *str) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_sym_t *jl_symbol_lookup(const char *str) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_sym_t *jl_symbol_n(const char *str, size_t len) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_sym_t *jl_gensym(void);
JL_DLLEXPORT jl_sym_t *jl_tagged_gensym(const char *str, int32_t len);
JL_DLLEXPORT jl_sym_t *jl_tagged_gensym(const char *str, size_t len);
JL_DLLEXPORT jl_sym_t *jl_get_root_symbol(void);
JL_DLLEXPORT jl_value_t *jl_generic_function_def(jl_sym_t *name,
jl_module_t *module,
Expand Down
2 changes: 2 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,8 @@ void jl_register_fptrs(uint64_t sysimage_base, const struct _jl_sysimg_fptrs_t *
# define jl_unreachable() ((void)jl_assume(0))
#endif

jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT;

// Tools for locally disabling spurious compiler warnings
//
// Particular calls which are used elsewhere in the code include:
Expand Down
2 changes: 1 addition & 1 deletion src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -949,7 +949,7 @@ static void jl_read_symbols(jl_serializer_state *s)
const char *str = (const char*)base;
base += len + 1;
//printf("symbol %3d: %s\n", len, str);
jl_sym_t *sym = jl_symbol_n(str, len);
jl_sym_t *sym = _jl_symbol(str, len);
arraylist_push(&deser_sym, (void*)sym);
}
}
Expand Down
33 changes: 24 additions & 9 deletions src/symbol.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ extern "C" {

static jl_sym_t *symtab = NULL;

#define MAX_SYM_LEN ((size_t)INTPTR_MAX - sizeof(jl_taggedvalue_t) - sizeof(jl_sym_t) - 1)

static uintptr_t hash_symbol(const char *str, size_t len) JL_NOTSAFEPOINT
{
return memhash(str, len) ^ ~(uintptr_t)0/3*2;
Expand Down Expand Up @@ -71,8 +73,15 @@ static jl_sym_t *symtab_lookup(jl_sym_t **ptree, const char *str, size_t len, jl
return node;
}

static jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT
jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT // (or throw)
{
#ifndef __clang_analyzer__
// Hide the error throwing from the analyser since there isn't a way to express
// "safepoint only when throwing error" currently.
if (len > MAX_SYM_LEN)
jl_exceptionf(jl_argumenterror_type, "Symbol name too long");
#endif
assert(!memchr(str, 0, len));
jl_sym_t **slot;
jl_sym_t *node = symtab_lookup(&symtab, str, len, &slot);
if (node == NULL) {
Expand All @@ -89,12 +98,12 @@ static jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT
return node;
}

JL_DLLEXPORT jl_sym_t *jl_symbol(const char *str)
JL_DLLEXPORT jl_sym_t *jl_symbol(const char *str) JL_NOTSAFEPOINT // (or throw)
{
return _jl_symbol(str, strlen(str));
}

JL_DLLEXPORT jl_sym_t *jl_symbol_lookup(const char *str)
JL_DLLEXPORT jl_sym_t *jl_symbol_lookup(const char *str) JL_NOTSAFEPOINT
{
return symtab_lookup(&symtab, str, strlen(str), NULL);
}
Expand Down Expand Up @@ -125,13 +134,19 @@ JL_DLLEXPORT jl_sym_t *jl_gensym(void)
return jl_symbol(n);
}

JL_DLLEXPORT jl_sym_t *jl_tagged_gensym(const char *str, int32_t len)
JL_DLLEXPORT jl_sym_t *jl_tagged_gensym(const char *str, size_t len)
{
char gs_name[14];
if (memchr(str, 0, len))
if (len == (size_t)-1) {
len = strlen(str);
}
else if (memchr(str, 0, len)) {
jl_exceptionf(jl_argumenterror_type, "Symbol name may not contain \\0");
char *name = (char*) (len >= 256 ? malloc_s(sizeof(gs_name) + len + 3) :
alloca(sizeof(gs_name) + len + 3));
}
char gs_name[14];
size_t alloc_len = sizeof(gs_name) + len + 3;
if (len > MAX_SYM_LEN || alloc_len > MAX_SYM_LEN)
jl_exceptionf(jl_argumenterror_type, "Symbol name too long");
char *name = (char*)(len >= 256 ? malloc_s(alloc_len) : alloca(alloc_len));
char *n;
name[0] = '#';
name[1] = '#';
Expand All @@ -140,7 +155,7 @@ JL_DLLEXPORT jl_sym_t *jl_tagged_gensym(const char *str, int32_t len)
uint32_t ctr = jl_atomic_fetch_add(&gs_ctr, 1);
n = uint2str(gs_name, sizeof(gs_name), ctr, 10);
memcpy(name + 3 + len, n, sizeof(gs_name) - (n - gs_name));
jl_sym_t *sym = _jl_symbol(name, len + 3 + sizeof(gs_name) - (n - gs_name)- 1);
jl_sym_t *sym = _jl_symbol(name, alloc_len - (n - gs_name)- 1);
if (len >= 256)
free(name);
return sym;
Expand Down

0 comments on commit 98c68a2

Please sign in to comment.