[Distributed] make finalizer messages threadsafe
vchuravy committed Sep 15, 2021
1 parent f8d3bd2 commit f8b7064
Showing 3 changed files with 77 additions and 32 deletions.
11 changes: 8 additions & 3 deletions stdlib/Distributed/src/cluster.jl
@@ -95,9 +95,10 @@ end
@enum WorkerState W_CREATED W_CONNECTED W_TERMINATING W_TERMINATED
mutable struct Worker
id::Int
del_msgs::Array{Any,1}
msg_lock::Threads.ReentrantLock # Lock for del_msgs, add_msgs, and gcflag
del_msgs::Array{Any,1} # XXX: Could del_msgs and add_msgs be Channels?
add_msgs::Array{Any,1}
gcflag::Bool
gcflag::Bool # XXX: Make this atomic?
state::WorkerState
c_state::Condition # wait for state changes
ct_time::Float64 # creation time
@@ -133,7 +134,7 @@ mutable struct Worker
if haskey(map_pid_wrkr, id)
return map_pid_wrkr[id]
end
w=new(id, [], [], false, W_CREATED, Condition(), time(), conn_func)
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Condition(), time(), conn_func)
w.initialized = Event()
register_worker(w)
w
@@ -471,6 +472,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
# The `launch` method should add an object of type WorkerConfig for every
# worker launched. It provides information required on how to connect
# to it.

# FIXME: launched should be a Channel, launch_ntfy should be a Threads.Condition
# but both are part of the public interface. This means we currently can't use
# `Threads.@spawn` in the code below.
launched = WorkerConfig[]
launch_ntfy = Condition()

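To make the new locking convention in cluster.jl concrete, here is a minimal standalone sketch. ToyWorker and queue_del! are made-up names, not the real Distributed.Worker: the point is only that every mutation of del_msgs, add_msgs, or gcflag happens while msg_lock is held.

    # Toy illustration of the field-level lock introduced in the Worker struct above.
    mutable struct ToyWorker
        msg_lock::ReentrantLock   # guards del_msgs, add_msgs, and gcflag
        del_msgs::Vector{Any}
        add_msgs::Vector{Any}
        gcflag::Bool
    end

    ToyWorker() = ToyWorker(ReentrantLock(), Any[], Any[], false)

    # Producer side: any code that queues a message takes the lock first.
    function queue_del!(w::ToyWorker, msg)
        lock(w.msg_lock) do
            push!(w.del_msgs, msg)
            w.gcflag = true       # the flag is flipped under the same lock
        end
        return
    end

The diff spells the field type as Threads.ReentrantLock; the sketch uses the plain ReentrantLock exported by Base.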
43 changes: 25 additions & 18 deletions stdlib/Distributed/src/messages.jl
@@ -126,23 +126,30 @@ function flush_gc_msgs(w::Worker)
if !isdefined(w, :w_stream)
return
end
w.gcflag = false
new_array = Any[]
msgs = w.add_msgs
w.add_msgs = new_array
if !isempty(msgs)
remote_do(add_clients, w, msgs)
end
add_msgs = nothing
del_msgs = nothing
@lock w.msg_lock begin
if !w.gcflag # No work needed for this worker
return
end
w.gcflag = false
if !isempty(w.add_msgs)
add_msgs = w.add_msgs
w.add_msgs = Any[]
end

# del_msgs gets populated by finalizers, so be very careful here about ordering of allocations
# XXX: threading requires this to be atomic
new_array = Any[]
msgs = w.del_msgs
w.del_msgs = new_array
if !isempty(msgs)
#print("sending delete of $msgs\n")
remote_do(del_clients, w, msgs)
if !isempty(w.del_msgs)
del_msgs = w.del_msgs
w.del_msgs = Any[]
end
end
if add_msgs !== nothing
remote_do(add_clients, w, add_msgs)
end
if del_msgs !== nothing
remote_do(del_clients, w, del_msgs)
end
return
end

# Boundary inserted between messages on the wire, used for recovering
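The rewritten flush_gc_msgs above drains the buffers while holding msg_lock and only calls remote_do after the lock is released, so a finalizer-driven producer is never stuck behind network I/O. A hedged sketch of that drain-then-send shape, reusing the hypothetical ToyWorker type from the cluster.jl sketch (send stands in for remote_do):

    # Drain under the lock, send outside it.
    function flush!(w::ToyWorker, send::Function)
        del_msgs = nothing
        lock(w.msg_lock) do
            if w.gcflag
                w.gcflag = false
                if !isempty(w.del_msgs)
                    del_msgs = w.del_msgs   # take ownership of the full buffer...
                    w.del_msgs = Any[]      # ...and leave a fresh empty one behind
                end
            end
        end
        # The lock is no longer held here, so a slow or blocking send cannot
        # stall producers that are trying to queue more messages.
        del_msgs === nothing || send(del_msgs)
        return
    end

The commit applies the same treatment to add_msgs; the sketch covers only del_msgs for brevity.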
@@ -174,7 +181,7 @@ function send_msg_(w::Worker, header, msg, now::Bool)
invokelatest(serialize_msg, w.w_serializer, msg) # io is wrapped in w_serializer
write(io, MSG_BOUNDARY)

if !now && w.gcflag
if !now && w.gcflag # XXX: `w.gcflag` is used outside lock
flush_gc_msgs(w)
else
flush(io)
@@ -187,8 +194,8 @@ end
function flush_gc_msgs()
try
for w in (PGRP::ProcessGroup).workers
if isa(w,Worker) && w.gcflag && (w.state == W_CONNECTED)
flush_gc_msgs(w)
if isa(w,Worker) && (w.state == W_CONNECTED)
flush_gc_msgs(w) # checks w.gcflag after acquiring w.msg_lock
end
end
catch e
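Two of the notes above concern the unlocked read of w.gcflag in send_msg_ and the gcflag test dropped from the per-worker loop. The read outside the lock is only a fast-path hint: flush_gc_msgs(w) re-checks the flag after acquiring msg_lock, so a stale read at worst delays or duplicates a flush attempt. A tiny sketch of that double-check shape, again using the hypothetical ToyWorker and flush! helpers from the earlier sketches:

    # Unlocked fast-path check; correctness is still enforced under the lock.
    function maybe_flush!(w::ToyWorker, send::Function)
        if w.gcflag            # racy hint, read without msg_lock
            flush!(w, send)    # flush! re-checks gcflag after locking msg_lock
        end
        return
    end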
55 changes: 44 additions & 11 deletions stdlib/Distributed/src/remotecall.jl
@@ -256,14 +256,27 @@ function del_clients(pairs::Vector)
end
end

const any_gc_flag = Condition()
# The task below is coalescing the `flush_gc_msgs` call
# across multiple producers, see `send_del_client`,
# and `send_add_client`.
# XXX: Is this worth the additional complexity?
# `flush_gc_msgs` has to iterate over all connected workers.
const any_gc_flag = Threads.Condition()
function start_gc_msgs_task()
errormonitor(@async while true
wait(any_gc_flag)
flush_gc_msgs()
end)
errormonitor(
Threads.@spawn begin
while true
lock(any_gc_flag) do
# this might miss events
wait(any_gc_flag)
end
flush_gc_msgs() # handles throws internally
end
end
)
end

# Function can be called within a finalizer
function send_del_client(rr)
if rr.where == myid()
del_client(rr)
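The `# this might miss events` comment in start_gc_msgs_task above refers to the classic lost-wakeup problem: a notify(any_gc_flag) that fires while the task is busy inside flush_gc_msgs, rather than waiting, is not remembered by the Condition. The commit accepts that risk, since queued messages still go out on a later notification or a later send_msg_ that sees w.gcflag set. The conventional remedy, sketched here with made-up names (gc_cond, gc_pending, request_flush, gc_msgs_loop), pairs the condition with a pending flag checked under the lock before every wait:

    # Lost-wakeup-free shape: a predicate guarded by the condition's lock,
    # re-checked in a loop before waiting.
    const gc_cond = Threads.Condition()
    const gc_pending = Ref(false)          # only read/written while gc_cond is locked

    function request_flush()               # producer side (what notify callers would do)
        lock(gc_cond) do
            gc_pending[] = true
            notify(gc_cond)
        end
    end

    function gc_msgs_loop(flush::Function) # consumer side
        while true
            lock(gc_cond) do
                while !gc_pending[]        # a notify sent while we were flushing
                    wait(gc_cond)          # left gc_pending set, so we don't block
                end
                gc_pending[] = false
            end
            flush()                        # run the work outside the lock
        end
    end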
@@ -281,11 +294,27 @@ function send_del_client_no_lock(rr)
end
end

function publish_del_msg!(w::Worker, msg)
lock(w.msg_lock) do
push!(w.del_msgs, msg)
w.gcflag = true
end
lock(any_gc_flag) do
notify(any_gc_flag)
end
end

function process_worker(rr)
w = worker_from_id(rr.where)::Worker
push!(w.del_msgs, (remoteref_id(rr), myid()))
w.gcflag = true
notify(any_gc_flag)
msg = (remoteref_id(rr), myid())

# Needs to acquire a lock on the del_msgs queue
T = Threads.@spawn begin
publish_del_msg!($w, $msg)
end
Base.errormonitor(T)

return
end

function add_client(id, client)
@@ -310,9 +339,13 @@ function send_add_client(rr::AbstractRemoteRef, i)
# to the processor that owns the remote ref. it will add_client
# itself inside deserialize().
w = worker_from_id(rr.where)
push!(w.add_msgs, (remoteref_id(rr), i))
w.gcflag = true
notify(any_gc_flag)
lock(w.msg_lock) do
push!(w.add_msgs, (remoteref_id(rr), i))
w.gcflag = true
end
lock(any_gc_flag) do
notify(any_gc_flag)
end
end
end

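The heart of the remotecall.jl change is that send_del_client can run inside a finalizer, and a finalizer must not block on a lock (it can fire while unrelated code on the same thread already holds that lock). process_worker therefore only captures the message and hands the lock-taking publish_del_msg! call to a Threads.@spawn task, watched by errormonitor. A standalone sketch of that hand-off pattern, with made-up names (registry_lock, pending, publish!, on_finalize):

    # A finalizer only records what it needs and defers the locking work
    # to a freshly spawned task.
    const registry_lock = ReentrantLock()
    const pending = Any[]                  # guarded by registry_lock

    function publish!(msg)
        lock(registry_lock) do
            push!(pending, msg)
        end
        return
    end

    function on_finalize(obj)
        msg = objectid(obj)                # capture a value; take no locks here
        t = Threads.@spawn publish!($msg)  # the task may block on the lock safely
        Base.errormonitor(t)               # report failures instead of losing them
        return
    end

    # Usage: finalizer(on_finalize, some_object)

As in the diff, the `$` interpolation copies the captured value into the task at spawn time instead of closing over the local variable.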
