Skip to content

Commit

Permalink
Fix (and test) for a deadlock that can happen while waiting for proto…
Browse files Browse the repository at this point in the history
…col info (#12633)

# Description

The local socket PR introduced a `Waitable` type, which could either
hold a value or be waited on until a value is available. Unlike a
channel, it would always return that value once set.

However, one issue with this design was that there was no way to detect
whether a value would ever be written. This splits the writer into a
different type `WaitableMut`, so that when it is dropped, waiting
threads can fail (because they'll never get a value).

# Tests + Formatting

A test has been added to `stress_internals` to make sure this fails in
the right way.

- 🟢 `toolkit fmt`
- 🟢 `toolkit clippy`
- 🟢 `toolkit test`
- 🟢 `toolkit test stdlib`
  • Loading branch information
devyn committed Apr 24, 2024
1 parent 0f645b3 commit c52884b
Show file tree
Hide file tree
Showing 8 changed files with 182 additions and 55 deletions.
10 changes: 7 additions & 3 deletions crates/nu-plugin/src/plugin/interface/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::{
PluginCall, PluginCallId, PluginCallResponse, PluginCustomValue, PluginInput, PluginOption,
PluginOutput, ProtocolInfo,
},
util::Waitable,
util::{Waitable, WaitableMut},
};
use nu_protocol::{
engine::Closure, Config, IntoInterruptiblePipelineData, LabeledError, ListStream, PipelineData,
Expand Down Expand Up @@ -85,6 +85,8 @@ impl std::fmt::Debug for EngineInterfaceState {
pub struct EngineInterfaceManager {
/// Shared state
state: Arc<EngineInterfaceState>,
/// The writer for protocol info
protocol_info_mut: WaitableMut<Arc<ProtocolInfo>>,
/// Channel to send received PluginCalls to. This is removed after `Goodbye` is received.
plugin_call_sender: Option<mpsc::Sender<ReceivedPluginCall>>,
/// Receiver for PluginCalls. This is usually taken after initialization
Expand All @@ -103,15 +105,17 @@ impl EngineInterfaceManager {
pub(crate) fn new(writer: impl PluginWrite<PluginOutput> + 'static) -> EngineInterfaceManager {
let (plug_tx, plug_rx) = mpsc::channel();
let (subscription_tx, subscription_rx) = mpsc::channel();
let protocol_info_mut = WaitableMut::new();

EngineInterfaceManager {
state: Arc::new(EngineInterfaceState {
protocol_info: Waitable::new(),
protocol_info: protocol_info_mut.reader(),
engine_call_id_sequence: Sequence::default(),
stream_id_sequence: Sequence::default(),
engine_call_subscription_sender: subscription_tx,
writer: Box::new(writer),
}),
protocol_info_mut,
plugin_call_sender: Some(plug_tx),
plugin_call_receiver: Some(plug_rx),
engine_call_subscriptions: BTreeMap::new(),
Expand Down Expand Up @@ -233,7 +237,7 @@ impl InterfaceManager for EngineInterfaceManager {
match input {
PluginInput::Hello(info) => {
let info = Arc::new(info);
self.state.protocol_info.set(info.clone())?;
self.protocol_info_mut.set(info.clone())?;

let local_info = ProtocolInfo::default();
if local_info.is_compatible_with(&info)? {
Expand Down
3 changes: 1 addition & 2 deletions crates/nu-plugin/src/plugin/interface/engine/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,8 +300,7 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(),

fn set_default_protocol_info(manager: &mut EngineInterfaceManager) -> Result<(), ShellError> {
manager
.state
.protocol_info
.protocol_info_mut
.set(Arc::new(ProtocolInfo::default()))
}

Expand Down
19 changes: 15 additions & 4 deletions crates/nu-plugin/src/plugin/interface/plugin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
PluginOutput, ProtocolInfo, StreamId, StreamMessage,
},
sequence::Sequence,
util::{with_custom_values_in, Waitable},
util::{with_custom_values_in, Waitable, WaitableMut},
};
use nu_protocol::{
ast::Operator, CustomValue, IntoInterruptiblePipelineData, IntoSpanned, ListStream,
Expand Down Expand Up @@ -138,6 +138,8 @@ impl Drop for PluginCallState {
pub struct PluginInterfaceManager {
/// Shared state
state: Arc<PluginInterfaceState>,
/// The writer for protocol info
protocol_info_mut: WaitableMut<Arc<ProtocolInfo>>,
/// Manages stream messages and state
stream_manager: StreamManager,
/// State related to plugin calls
Expand All @@ -159,18 +161,20 @@ impl PluginInterfaceManager {
writer: impl PluginWrite<PluginInput> + 'static,
) -> PluginInterfaceManager {
let (subscription_tx, subscription_rx) = mpsc::channel();
let protocol_info_mut = WaitableMut::new();

PluginInterfaceManager {
state: Arc::new(PluginInterfaceState {
source,
process: pid.map(PluginProcess::new),
protocol_info: Waitable::new(),
protocol_info: protocol_info_mut.reader(),
plugin_call_id_sequence: Sequence::default(),
stream_id_sequence: Sequence::default(),
plugin_call_subscription_sender: subscription_tx,
error: OnceLock::new(),
writer: Box::new(writer),
}),
protocol_info_mut,
stream_manager: StreamManager::new(),
plugin_call_states: BTreeMap::new(),
plugin_call_subscription_receiver: subscription_rx,
Expand Down Expand Up @@ -464,7 +468,7 @@ impl InterfaceManager for PluginInterfaceManager {
match input {
PluginOutput::Hello(info) => {
let info = Arc::new(info);
self.state.protocol_info.set(info.clone())?;
self.protocol_info_mut.set(info.clone())?;

let local_info = ProtocolInfo::default();
if local_info.is_compatible_with(&info)? {
Expand Down Expand Up @@ -631,7 +635,14 @@ impl PluginInterface {

/// Get the protocol info for the plugin. Will block to receive `Hello` if not received yet.
pub fn protocol_info(&self) -> Result<Arc<ProtocolInfo>, ShellError> {
self.state.protocol_info.get()
self.state.protocol_info.get().and_then(|info| {
info.ok_or_else(|| ShellError::PluginFailedToLoad {
msg: format!(
"Failed to get protocol info (`Hello` message) from the `{}` plugin",
self.state.source.identity.name()
),
})
})
}

/// Write the protocol info. This should be done after initialization
Expand Down
3 changes: 1 addition & 2 deletions crates/nu-plugin/src/plugin/interface/plugin/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,7 @@ fn manager_consume_errors_on_sending_other_messages_before_hello() -> Result<(),

fn set_default_protocol_info(manager: &mut PluginInterfaceManager) -> Result<(), ShellError> {
manager
.state
.protocol_info
.protocol_info_mut
.set(Arc::new(ProtocolInfo::default()))
}

Expand Down
2 changes: 1 addition & 1 deletion crates/nu-plugin/src/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@ mod waitable;
mod with_custom_values_in;

pub(crate) use mutable_cow::*;
pub use waitable::Waitable;
pub use waitable::*;
pub use with_custom_values_in::*;
167 changes: 124 additions & 43 deletions crates/nu-plugin/src/util/waitable.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,36 @@
use std::sync::{
atomic::{AtomicBool, Ordering},
Condvar, Mutex, MutexGuard, PoisonError,
Arc, Condvar, Mutex, MutexGuard, PoisonError,
};

use nu_protocol::ShellError;

/// A container that may be empty, and allows threads to block until it has a value.
#[derive(Debug)]
/// A shared container that may be empty, and allows threads to block until it has a value.
///
/// This side is read-only - use [`WaitableMut`] on threads that might write a value.
#[derive(Debug, Clone)]
pub struct Waitable<T: Clone + Send> {
shared: Arc<WaitableShared<T>>,
}

#[derive(Debug)]
pub struct WaitableMut<T: Clone + Send> {
shared: Arc<WaitableShared<T>>,
}

#[derive(Debug)]
struct WaitableShared<T: Clone + Send> {
is_set: AtomicBool,
mutex: Mutex<Option<T>>,
mutex: Mutex<SyncState<T>>,
condvar: Condvar,
}

#[derive(Debug)]
struct SyncState<T: Clone + Send> {
writers: usize,
value: Option<T>,
}

#[track_caller]
fn fail_if_poisoned<'a, T>(
result: Result<MutexGuard<'a, T>, PoisonError<MutexGuard<'a, T>>>,
Expand All @@ -26,75 +44,138 @@ fn fail_if_poisoned<'a, T>(
}
}

impl<T: Clone + Send> Waitable<T> {
/// Create a new empty `Waitable`.
pub fn new() -> Waitable<T> {
impl<T: Clone + Send> WaitableMut<T> {
/// Create a new empty `WaitableMut`. Call [`.reader()`] to get [`Waitable`].
pub fn new() -> WaitableMut<T> {
WaitableMut {
shared: Arc::new(WaitableShared {
is_set: AtomicBool::new(false),
mutex: Mutex::new(SyncState {
writers: 1,
value: None,
}),
condvar: Condvar::new(),
}),
}
}

pub fn reader(&self) -> Waitable<T> {
Waitable {
is_set: AtomicBool::new(false),
mutex: Mutex::new(None),
condvar: Condvar::new(),
shared: self.shared.clone(),
}
}

/// Set the value and let waiting threads know.
#[track_caller]
pub fn set(&self, value: T) -> Result<(), ShellError> {
let mut sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
self.shared.is_set.store(true, Ordering::SeqCst);
sync_state.value = Some(value);
self.shared.condvar.notify_all();
Ok(())
}
}

impl<T: Clone + Send> Default for WaitableMut<T> {
fn default() -> Self {
Self::new()
}
}

impl<T: Clone + Send> Clone for WaitableMut<T> {
fn clone(&self) -> Self {
let shared = self.shared.clone();
shared
.mutex
.lock()
.expect("failed to lock mutex to increment writers")
.writers += 1;
WaitableMut { shared }
}
}

impl<T: Clone + Send> Drop for WaitableMut<T> {
fn drop(&mut self) {
// Decrement writers...
if let Ok(mut sync_state) = self.shared.mutex.lock() {
sync_state.writers = sync_state
.writers
.checked_sub(1)
.expect("would decrement writers below zero");
}
// and notify waiting threads so they have a chance to see it.
self.shared.condvar.notify_all();
}
}

impl<T: Clone + Send> Waitable<T> {
/// Wait for a value to be available and then clone it.
///
/// Returns `Ok(None)` if there are no writers left that could possibly place a value.
#[track_caller]
pub fn get(&self) -> Result<T, ShellError> {
let guard = fail_if_poisoned(self.mutex.lock())?;
if let Some(value) = (*guard).clone() {
Ok(value)
pub fn get(&self) -> Result<Option<T>, ShellError> {
let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
if let Some(value) = sync_state.value.clone() {
Ok(Some(value))
} else if sync_state.writers == 0 {
// There can't possibly be a value written, so no point in waiting.
Ok(None)
} else {
let guard = fail_if_poisoned(self.condvar.wait_while(guard, |g| g.is_none()))?;
Ok((*guard)
.clone()
.expect("checked already for Some but it was None"))
let sync_state = fail_if_poisoned(
self.shared
.condvar
.wait_while(sync_state, |g| g.writers > 0 && g.value.is_none()),
)?;
Ok(sync_state.value.clone())
}
}

/// Clone the value if one is available, but don't wait if not.
#[track_caller]
pub fn try_get(&self) -> Result<Option<T>, ShellError> {
let guard = fail_if_poisoned(self.mutex.lock())?;
Ok((*guard).clone())
let sync_state = fail_if_poisoned(self.shared.mutex.lock())?;
Ok(sync_state.value.clone())
}

/// Returns true if value is available.
#[track_caller]
pub fn is_set(&self) -> bool {
self.is_set.load(Ordering::SeqCst)
}

/// Set the value and let waiting threads know.
#[track_caller]
pub fn set(&self, value: T) -> Result<(), ShellError> {
let mut guard = fail_if_poisoned(self.mutex.lock())?;
self.is_set.store(true, Ordering::SeqCst);
*guard = Some(value);
self.condvar.notify_all();
Ok(())
}
}

impl<T: Clone + Send> Default for Waitable<T> {
fn default() -> Self {
Self::new()
self.shared.is_set.load(Ordering::SeqCst)
}
}

#[test]
fn set_from_other_thread() -> Result<(), ShellError> {
use std::sync::Arc;

let waitable = Arc::new(Waitable::new());
let waitable_clone = waitable.clone();
let waitable_mut = WaitableMut::new();
let waitable = waitable_mut.reader();

assert!(!waitable.is_set());

std::thread::spawn(move || {
waitable_clone.set(42).expect("error on set");
waitable_mut.set(42).expect("error on set");
});

assert_eq!(42, waitable.get()?);
assert_eq!(Some(42), waitable.get()?);
assert_eq!(Some(42), waitable.try_get()?);
assert!(waitable.is_set());
Ok(())
}

#[test]
fn dont_deadlock_if_waiting_without_writer() {
use std::time::Duration;

let (tx, rx) = std::sync::mpsc::channel();
let writer = WaitableMut::<()>::new();
let waitable = writer.reader();
// Ensure there are no writers
drop(writer);
std::thread::spawn(move || {
let _ = tx.send(waitable.get());
});
let result = rx
.recv_timeout(Duration::from_secs(10))
.expect("timed out")
.expect("error");
assert!(result.is_none());
}
7 changes: 7 additions & 0 deletions crates/nu_plugin_stress_internals/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use serde_json::{json, Value};
struct Options {
refuse_local_socket: bool,
advertise_local_socket: bool,
exit_before_hello: bool,
exit_early: bool,
wrong_version: bool,
local_socket_path: Option<String>,
Expand All @@ -28,6 +29,7 @@ pub fn main() -> Result<(), Box<dyn Error>> {
let mut opts = Options {
refuse_local_socket: has_env("STRESS_REFUSE_LOCAL_SOCKET"),
advertise_local_socket: has_env("STRESS_ADVERTISE_LOCAL_SOCKET"),
exit_before_hello: has_env("STRESS_EXIT_BEFORE_HELLO"),
exit_early: has_env("STRESS_EXIT_EARLY"),
wrong_version: has_env("STRESS_WRONG_VERSION"),
local_socket_path: None,
Expand Down Expand Up @@ -75,6 +77,11 @@ pub fn main() -> Result<(), Box<dyn Error>> {
output.flush()?;
}

// Test exiting without `Hello`
if opts.exit_before_hello {
std::process::exit(1)
}

// Send `Hello` message
write(
&mut output,
Expand Down
Loading

0 comments on commit c52884b

Please sign in to comment.