From 98c68a2519aa1063bb12c0cc550d01643dab20c6 Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Tue, 25 Aug 2020 01:36:43 -0400 Subject: [PATCH] Symbol clean up and strlen fixes (#37149) * Require the symbol to fit in ssize_t This is the limit we have on arrays and I don't think anyone is using symbols that is as long as half the full address space. * Remove some unnecessary embeded NUL byte checks when constructing symbols Almost all use of `jl_symbol_n` have the caller checked for embedded NUL byte already. This technically introduces a C API/ABI breakage but it shouldn't matter on all platforms we support. * Fix ccall of `strlen`. --- base/expr.jl | 5 ++--- base/gcutils.jl | 4 ++-- base/io.jl | 2 +- base/secretbuffer.jl | 2 +- src/datatype.c | 4 ++-- src/dump.c | 2 +- src/ircode.c | 4 ++-- src/julia.h | 2 +- src/julia_internal.h | 2 ++ src/staticdata.c | 2 +- src/symbol.c | 33 ++++++++++++++++++++++++--------- 11 files changed, 39 insertions(+), 23 deletions(-) diff --git a/base/expr.jl b/base/expr.jl index 11d620bb3b2eb..04bbb9c5859e2 100644 --- a/base/expr.jl +++ b/base/expr.jl @@ -9,11 +9,10 @@ Generates a symbol which will not conflict with other variable names. """ gensym() = ccall(:jl_gensym, Ref{Symbol}, ()) -gensym(s::String) = ccall(:jl_tagged_gensym, Ref{Symbol}, (Ptr{UInt8}, Int32), s, sizeof(s)) +gensym(s::String) = ccall(:jl_tagged_gensym, Ref{Symbol}, (Ptr{UInt8}, Csize_t), s, sizeof(s)) gensym(ss::String...) = map(gensym, ss) -gensym(s::Symbol) = - ccall(:jl_tagged_gensym, Ref{Symbol}, (Ptr{UInt8}, Int32), s, ccall(:strlen, Csize_t, (Ptr{UInt8},), s)) +gensym(s::Symbol) = ccall(:jl_tagged_gensym, Ref{Symbol}, (Ptr{UInt8}, Csize_t), s, -1 % Csize_t) """ @gensym diff --git a/base/gcutils.jl b/base/gcutils.jl index c403e4f20f626..51e3943877444 100644 --- a/base/gcutils.jl +++ b/base/gcutils.jl @@ -147,9 +147,9 @@ directly to `ccall` which counts as an explicit use.) julia> let x = "Hello" p = pointer(x) - GC.@preserve x @ccall strlen(p::Cstring)::Cint + Int(GC.@preserve x @ccall strlen(p::Cstring)::Csize_t) # Preferred alternative - @ccall strlen(x::Cstring)::Cint + Int(@ccall strlen(x::Cstring)::Csize_t) end 5 ``` diff --git a/base/io.jl b/base/io.jl index fdf802ad29860..421c2d9a60058 100644 --- a/base/io.jl +++ b/base/io.jl @@ -707,7 +707,7 @@ end function write(io::IO, s::Symbol) pname = unsafe_convert(Ptr{UInt8}, s) - return unsafe_write(io, pname, Int(ccall(:strlen, Csize_t, (Cstring,), pname))) + return unsafe_write(io, pname, ccall(:strlen, Csize_t, (Cstring,), pname)) end function write(to::IO, from::IO) diff --git a/base/secretbuffer.jl b/base/secretbuffer.jl index dae5a697a5bd7..02a133be088f0 100644 --- a/base/secretbuffer.jl +++ b/base/secretbuffer.jl @@ -79,7 +79,7 @@ function SecretBuffer!(d::Vector{UInt8}) s end -unsafe_SecretBuffer!(s::Cstring) = unsafe_SecretBuffer!(convert(Ptr{UInt8}, s), ccall(:strlen, Cint, (Cstring,), s)) +unsafe_SecretBuffer!(s::Cstring) = unsafe_SecretBuffer!(convert(Ptr{UInt8}, s), Int(ccall(:strlen, Csize_t, (Cstring,), s))) function unsafe_SecretBuffer!(p::Ptr{UInt8}, len=1) s = SecretBuffer(sizehint=len) for i in 1:len diff --git a/src/datatype.c b/src/datatype.c index dc09a4bf74bff..f0c643af11548 100644 --- a/src/datatype.c +++ b/src/datatype.c @@ -36,8 +36,8 @@ jl_sym_t *jl_demangle_typename(jl_sym_t *s) JL_NOTSAFEPOINT else len = (end-n) - 1; // extract `f` from `#f#...` if (is10digit(n[1])) - return jl_symbol_n(n, len+1); - return jl_symbol_n(&n[1], len); + return _jl_symbol(n, len+1); + return _jl_symbol(&n[1], len); } JL_DLLEXPORT jl_methtable_t *jl_new_method_table(jl_sym_t *name, jl_module_t *module) diff --git a/src/dump.c b/src/dump.c index 86a8ac0fde58e..fe1f7f68700c7 100644 --- a/src/dump.c +++ b/src/dump.c @@ -1932,7 +1932,7 @@ static jl_value_t *read_verify_mod_list(ios_t *s, jl_array_t *mod_list) uuid.hi = read_uint64(s); uuid.lo = read_uint64(s); uint64_t build_id = read_uint64(s); - jl_sym_t *sym = jl_symbol_n(name, len); + jl_sym_t *sym = _jl_symbol(name, len); jl_module_t *m = (jl_module_t*)jl_array_ptr_ref(mod_list, i); if (!m || !jl_is_module(m) || m->uuid.hi != uuid.hi || m->uuid.lo != uuid.lo || m->name != sym || m->build_id != build_id) { return jl_get_exceptionf(jl_errorexception_type, diff --git a/src/ircode.c b/src/ircode.c index a3f602dd2f2e1..c7ee418f42edf 100644 --- a/src/ircode.c +++ b/src/ircode.c @@ -924,7 +924,7 @@ JL_DLLEXPORT jl_array_t *jl_uncompress_argnames(jl_value_t *syms) JL_GC_PUSH1(&names); for (i = 0; i < len; i++) { size_t namelen = strlen(namestr); - jl_sym_t *name = jl_symbol_n(namestr, namelen); + jl_sym_t *name = _jl_symbol(namestr, namelen); jl_array_ptr_set(names, i, name); namestr += namelen + 1; } @@ -940,7 +940,7 @@ JL_DLLEXPORT jl_value_t *jl_uncompress_argname_n(jl_value_t *syms, size_t i) while (remaining) { size_t namelen = strlen(namestr); if (i-- == 0) { - jl_sym_t *name = jl_symbol_n(namestr, namelen); + jl_sym_t *name = _jl_symbol(namestr, namelen); return (jl_value_t*)name; } namestr += namelen + 1; diff --git a/src/julia.h b/src/julia.h index 166c6b59f2806..5738661c53710 100644 --- a/src/julia.h +++ b/src/julia.h @@ -1320,7 +1320,7 @@ JL_DLLEXPORT jl_sym_t *jl_symbol(const char *str) JL_NOTSAFEPOINT; JL_DLLEXPORT jl_sym_t *jl_symbol_lookup(const char *str) JL_NOTSAFEPOINT; JL_DLLEXPORT jl_sym_t *jl_symbol_n(const char *str, size_t len) JL_NOTSAFEPOINT; JL_DLLEXPORT jl_sym_t *jl_gensym(void); -JL_DLLEXPORT jl_sym_t *jl_tagged_gensym(const char *str, int32_t len); +JL_DLLEXPORT jl_sym_t *jl_tagged_gensym(const char *str, size_t len); JL_DLLEXPORT jl_sym_t *jl_get_root_symbol(void); JL_DLLEXPORT jl_value_t *jl_generic_function_def(jl_sym_t *name, jl_module_t *module, diff --git a/src/julia_internal.h b/src/julia_internal.h index 5c2a0f22f943c..4754c534fb3cb 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -1177,6 +1177,8 @@ void jl_register_fptrs(uint64_t sysimage_base, const struct _jl_sysimg_fptrs_t * # define jl_unreachable() ((void)jl_assume(0)) #endif +jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT; + // Tools for locally disabling spurious compiler warnings // // Particular calls which are used elsewhere in the code include: diff --git a/src/staticdata.c b/src/staticdata.c index 5b8e2c816d5b5..72a9908e864dc 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -949,7 +949,7 @@ static void jl_read_symbols(jl_serializer_state *s) const char *str = (const char*)base; base += len + 1; //printf("symbol %3d: %s\n", len, str); - jl_sym_t *sym = jl_symbol_n(str, len); + jl_sym_t *sym = _jl_symbol(str, len); arraylist_push(&deser_sym, (void*)sym); } } diff --git a/src/symbol.c b/src/symbol.c index e76f5b5ed2a30..f1a4343a39e8e 100644 --- a/src/symbol.c +++ b/src/symbol.c @@ -17,6 +17,8 @@ extern "C" { static jl_sym_t *symtab = NULL; +#define MAX_SYM_LEN ((size_t)INTPTR_MAX - sizeof(jl_taggedvalue_t) - sizeof(jl_sym_t) - 1) + static uintptr_t hash_symbol(const char *str, size_t len) JL_NOTSAFEPOINT { return memhash(str, len) ^ ~(uintptr_t)0/3*2; @@ -71,8 +73,15 @@ static jl_sym_t *symtab_lookup(jl_sym_t **ptree, const char *str, size_t len, jl return node; } -static jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT +jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT // (or throw) { +#ifndef __clang_analyzer__ + // Hide the error throwing from the analyser since there isn't a way to express + // "safepoint only when throwing error" currently. + if (len > MAX_SYM_LEN) + jl_exceptionf(jl_argumenterror_type, "Symbol name too long"); +#endif + assert(!memchr(str, 0, len)); jl_sym_t **slot; jl_sym_t *node = symtab_lookup(&symtab, str, len, &slot); if (node == NULL) { @@ -89,12 +98,12 @@ static jl_sym_t *_jl_symbol(const char *str, size_t len) JL_NOTSAFEPOINT return node; } -JL_DLLEXPORT jl_sym_t *jl_symbol(const char *str) +JL_DLLEXPORT jl_sym_t *jl_symbol(const char *str) JL_NOTSAFEPOINT // (or throw) { return _jl_symbol(str, strlen(str)); } -JL_DLLEXPORT jl_sym_t *jl_symbol_lookup(const char *str) +JL_DLLEXPORT jl_sym_t *jl_symbol_lookup(const char *str) JL_NOTSAFEPOINT { return symtab_lookup(&symtab, str, strlen(str), NULL); } @@ -125,13 +134,19 @@ JL_DLLEXPORT jl_sym_t *jl_gensym(void) return jl_symbol(n); } -JL_DLLEXPORT jl_sym_t *jl_tagged_gensym(const char *str, int32_t len) +JL_DLLEXPORT jl_sym_t *jl_tagged_gensym(const char *str, size_t len) { - char gs_name[14]; - if (memchr(str, 0, len)) + if (len == (size_t)-1) { + len = strlen(str); + } + else if (memchr(str, 0, len)) { jl_exceptionf(jl_argumenterror_type, "Symbol name may not contain \\0"); - char *name = (char*) (len >= 256 ? malloc_s(sizeof(gs_name) + len + 3) : - alloca(sizeof(gs_name) + len + 3)); + } + char gs_name[14]; + size_t alloc_len = sizeof(gs_name) + len + 3; + if (len > MAX_SYM_LEN || alloc_len > MAX_SYM_LEN) + jl_exceptionf(jl_argumenterror_type, "Symbol name too long"); + char *name = (char*)(len >= 256 ? malloc_s(alloc_len) : alloca(alloc_len)); char *n; name[0] = '#'; name[1] = '#'; @@ -140,7 +155,7 @@ JL_DLLEXPORT jl_sym_t *jl_tagged_gensym(const char *str, int32_t len) uint32_t ctr = jl_atomic_fetch_add(&gs_ctr, 1); n = uint2str(gs_name, sizeof(gs_name), ctr, 10); memcpy(name + 3 + len, n, sizeof(gs_name) - (n - gs_name)); - jl_sym_t *sym = _jl_symbol(name, len + 3 + sizeof(gs_name) - (n - gs_name)- 1); + jl_sym_t *sym = _jl_symbol(name, alloc_len - (n - gs_name)- 1); if (len >= 256) free(name); return sym;