diff --git a/tokio/src/io/util/mem.rs b/tokio/src/io/util/mem.rs index 5b404c21940..96676e64cff 100644 --- a/tokio/src/io/util/mem.rs +++ b/tokio/src/io/util/mem.rs @@ -124,6 +124,18 @@ impl AsyncWrite for DuplexStream { Pin::new(&mut *self.write.lock()).poll_write(cx, buf) } + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut *self.write.lock()).poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + true + } + #[allow(unused_mut)] fn poll_flush( mut self: Pin<&mut Self>, @@ -224,6 +236,37 @@ impl Pipe { } Poll::Ready(Ok(len)) } + + fn poll_write_vectored_internal( + mut self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + if self.is_closed { + return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); + } + let avail = self.max_buf_size - self.buffer.len(); + if avail == 0 { + self.write_waker = Some(cx.waker().clone()); + return Poll::Pending; + } + + let mut rem = avail; + for buf in bufs { + if rem == 0 { + break; + } + + let len = buf.len().min(rem); + self.buffer.extend_from_slice(&buf[..len]); + rem -= len; + } + + if let Some(waker) = self.read_waker.take() { + waker.wake(); + } + Poll::Ready(Ok(avail - rem)) + } } impl AsyncRead for Pipe { @@ -285,6 +328,38 @@ impl AsyncWrite for Pipe { } } + cfg_coop! { + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + ready!(crate::trace::trace_leaf(cx)); + let coop = ready!(crate::runtime::coop::poll_proceed(cx)); + + let ret = self.poll_write_vectored_internal(cx, bufs); + if ret.is_ready() { + coop.made_progress(); + } + ret + } + } + + cfg_not_coop! { + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut task::Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + ready!(crate::trace::trace_leaf(cx)); + self.poll_write_vectored_internal(cx, bufs) + } + } + + fn is_write_vectored(&self) -> bool { + true + } + fn poll_flush(self: Pin<&mut Self>, _: &mut task::Context<'_>) -> Poll> { Poll::Ready(Ok(())) } diff --git a/tokio/tests/duplex_stream.rs b/tokio/tests/duplex_stream.rs new file mode 100644 index 00000000000..64111fb960d --- /dev/null +++ b/tokio/tests/duplex_stream.rs @@ -0,0 +1,47 @@ +#![warn(rust_2018_idioms)] +#![cfg(feature = "full")] + +use std::io::IoSlice; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +const HELLO: &[u8] = b"hello world..."; + +#[tokio::test] +async fn write_vectored() { + let (mut client, mut server) = tokio::io::duplex(64); + + let ret = client + .write_vectored(&[IoSlice::new(HELLO), IoSlice::new(HELLO)]) + .await + .unwrap(); + assert_eq!(ret, HELLO.len() * 2); + + client.flush().await.unwrap(); + drop(client); + + let mut buf = Vec::with_capacity(HELLO.len() * 2); + let bytes_read = server.read_to_end(&mut buf).await.unwrap(); + + assert_eq!(bytes_read, HELLO.len() * 2); + assert_eq!(buf, [HELLO, HELLO].concat()); +} + +#[tokio::test] +async fn write_vectored_and_shutdown() { + let (mut client, mut server) = tokio::io::duplex(64); + + let ret = client + .write_vectored(&[IoSlice::new(HELLO), IoSlice::new(HELLO)]) + .await + .unwrap(); + assert_eq!(ret, HELLO.len() * 2); + + client.shutdown().await.unwrap(); + drop(client); + + let mut buf = Vec::with_capacity(HELLO.len() * 2); + let bytes_read = server.read_to_end(&mut buf).await.unwrap(); + + assert_eq!(bytes_read, HELLO.len() * 2); + assert_eq!(buf, [HELLO, HELLO].concat()); +}