diff --git a/.cargo/config b/.cargo/config new file mode 100644 index 00000000000..df8858986f3 --- /dev/null +++ b/.cargo/config @@ -0,0 +1,2 @@ +# [build] +# rustflags = ["--cfg", "tokio_unstable"] \ No newline at end of file diff --git a/.circleci/config.yml b/.circleci/config.yml index 7e5b83fbdb9..4ec482174b6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -5,8 +5,8 @@ jobs: image: ubuntu-2004:202101-01 resource_class: arm.medium environment: - # Change to pin rust versino - RUST_STABLE: stable + # Change to pin rust version + RUST_STABLE: 1.60.0 steps: - checkout - run: @@ -22,4 +22,4 @@ jobs: workflows: ci: jobs: - - test-arm \ No newline at end of file + - test-arm diff --git a/.cirrus.yml b/.cirrus.yml index 5916acd96c4..6dcd8b19226 100644 --- a/.cirrus.yml +++ b/.cirrus.yml @@ -1,8 +1,8 @@ freebsd_instance: - image: freebsd-12-2-release-amd64 + image: freebsd-12-3-release-amd64 env: - RUST_STABLE: stable - RUST_NIGHTLY: nightly-2022-01-12 + RUST_STABLE: 1.60.0 + RUST_NIGHTLY: nightly-2022-03-21 RUSTFLAGS: -D warnings # Test FreeBSD in a full VM on cirrus-ci.com. Test the i686 target too, in the @@ -26,8 +26,8 @@ task: task: name: FreeBSD docs env: - RUSTFLAGS: --cfg docsrs - RUSTDOCFLAGS: --cfg docsrs -Dwarnings + RUSTFLAGS: --cfg docsrs --cfg tokio_unstable + RUSTDOCFLAGS: --cfg docsrs --cfg tokio_unstable -Dwarnings setup_script: - pkg install -y bash curl - curl https://sh.rustup.rs -sSf --output rustup.sh diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1bb8b52bb1d..ae99b17b7e0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,9 +10,9 @@ env: RUSTFLAGS: -Dwarnings RUST_BACKTRACE: 1 # Change to specific Rust release to pin - rust_stable: stable - rust_nightly: nightly-2022-01-12 - rust_clippy: 1.52.0 + rust_stable: 1.60.0 + rust_nightly: nightly-2022-03-21 + rust_clippy: 1.56.0 rust_min: 1.49.0 defaults: @@ -33,7 +33,6 @@ jobs: - features - minrust - fmt - - clippy - docs - valgrind - loom-compile @@ -174,6 +173,9 @@ jobs: working-directory: tokio env: RUSTFLAGS: --cfg tokio_unstable -Dwarnings + # in order to run doctests for unstable features, we must also pass + # the unstable cfg to RustDoc + RUSTDOCFLAGS: --cfg tokio_unstable miri: name: miri @@ -225,6 +227,7 @@ jobs: - powerpc64-unknown-linux-gnu - mips-unknown-linux-gnu - arm-linux-androideabi + - mipsel-unknown-linux-musl steps: - uses: actions/checkout@v2 - name: Install Rust ${{ env.rust_stable }} @@ -238,7 +241,14 @@ jobs: with: use-cross: true command: check - args: --workspace --target ${{ matrix.target }} + args: --workspace --all-features --target ${{ matrix.target }} + - uses: actions-rs/cargo@v1 + with: + use-cross: true + command: check + args: --workspace --all-features --target ${{ matrix.target }} + env: + RUSTFLAGS: --cfg tokio_unstable -Dwarnings features: name: features @@ -273,8 +283,9 @@ jobs: toolchain: ${{ env.rust_min }} override: true - uses: Swatinem/rust-cache@v1 - - name: "test --workspace --all-features" - run: cargo check --workspace --all-features + - name: "test --all-features" + run: cargo check --all-features + working-directory: tokio minimal-versions: name: minimal-versions @@ -329,22 +340,6 @@ jobs: exit 1 fi - clippy: - name: clippy - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - name: Install Rust ${{ env.rust_clippy }} - uses: actions-rs/toolchain@v1 - with: - toolchain: ${{ env.rust_clippy }} - override: true - components: clippy - - uses: Swatinem/rust-cache@v1 - # Run clippy - - name: "clippy --all" - run: cargo clippy --all --tests --all-features - docs: name: docs runs-on: ubuntu-latest @@ -359,8 +354,8 @@ jobs: - name: "doc --lib --all-features" run: cargo doc --lib --no-deps --all-features --document-private-items env: - RUSTFLAGS: --cfg docsrs - RUSTDOCFLAGS: --cfg docsrs -Dwarnings + RUSTFLAGS: --cfg docsrs --cfg tokio_unstable + RUSTDOCFLAGS: --cfg docsrs --cfg tokio_unstable -Dwarnings loom-compile: name: build loom tests diff --git a/.github/workflows/loom.yml b/.github/workflows/loom.yml index 83a6743cb17..29ed8c17994 100644 --- a/.github/workflows/loom.yml +++ b/.github/workflows/loom.yml @@ -11,7 +11,7 @@ env: RUSTFLAGS: -Dwarnings RUST_BACKTRACE: 1 # Change to specific Rust release to pin - rust_stable: stable + rust_stable: 1.60.0 jobs: loom: diff --git a/.github/workflows/stress-test.yml b/.github/workflows/stress-test.yml index a6a24fdea14..f48c7a01996 100644 --- a/.github/workflows/stress-test.yml +++ b/.github/workflows/stress-test.yml @@ -9,7 +9,7 @@ env: RUSTFLAGS: -Dwarnings RUST_BACKTRACE: 1 # Change to specific Rust release to pin - rust_stable: stable + rust_stable: 1.60.0 jobs: stess-test: diff --git a/.gitignore b/.gitignore index a9d37c560c6..1e656726650 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ target Cargo.lock + +.cargo/config.toml diff --git a/Cross.toml b/Cross.toml new file mode 100644 index 00000000000..050f2bdbd75 --- /dev/null +++ b/Cross.toml @@ -0,0 +1,4 @@ +[build.env] +passthrough = [ + "RUSTFLAGS", +] diff --git a/README.md b/README.md index 1cce34aeeff..46b1e089cfd 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ Make sure you activated the full features of the tokio crate on Cargo.toml: ```toml [dependencies] -tokio = { version = "1.17.0", features = ["full"] } +tokio = { version = "1.18.5", features = ["full"] } ``` Then, on your main.rs: diff --git a/examples/Cargo.toml b/examples/Cargo.toml index d2aca69d84a..a6f6eda3267 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -21,7 +21,6 @@ serde_derive = "1.0" serde_json = "1.0" httparse = "1.0" httpdate = "1.0" -once_cell = "1.5.2" rand = "0.8.3" [target.'cfg(windows)'.dev-dependencies.winapi] @@ -71,11 +70,6 @@ path = "udp-codec.rs" name = "tinyhttp" path = "tinyhttp.rs" -[[example]] -name = "custom-executor" -path = "custom-executor.rs" - - [[example]] name = "custom-executor-tokio-context" path = "custom-executor-tokio-context.rs" diff --git a/tests-build/tests/fail/macros_core_no_default.stderr b/tests-build/tests/fail/macros_core_no_default.stderr index 676acc8dbe3..c1a35af3c6e 100644 --- a/tests-build/tests/fail/macros_core_no_default.stderr +++ b/tests-build/tests/fail/macros_core_no_default.stderr @@ -1,5 +1,5 @@ error: The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled. - --> $DIR/macros_core_no_default.rs:3:1 + --> tests/fail/macros_core_no_default.rs:3:1 | 3 | #[tokio::main] | ^^^^^^^^^^^^^^ diff --git a/tests-build/tests/fail/macros_dead_code.stderr b/tests-build/tests/fail/macros_dead_code.stderr index 816c294bd31..19ea06e8f71 100644 --- a/tests-build/tests/fail/macros_dead_code.stderr +++ b/tests-build/tests/fail/macros_dead_code.stderr @@ -1,11 +1,11 @@ error: function is never used: `f` - --> $DIR/macros_dead_code.rs:6:10 + --> tests/fail/macros_dead_code.rs:6:10 | 6 | async fn f() {} | ^ | note: the lint level is defined here - --> $DIR/macros_dead_code.rs:1:9 + --> tests/fail/macros_dead_code.rs:1:9 | 1 | #![deny(dead_code)] | ^^^^^^^^^ diff --git a/tests-build/tests/fail/macros_invalid_input.rs b/tests-build/tests/fail/macros_invalid_input.rs index eb04eca76b6..3272757b9f0 100644 --- a/tests-build/tests/fail/macros_invalid_input.rs +++ b/tests-build/tests/fail/macros_invalid_input.rs @@ -1,3 +1,5 @@ +#![deny(duplicate_macro_attributes)] + use tests_build::tokio; #[tokio::main] @@ -33,6 +35,15 @@ async fn test_worker_threads_not_int() {} #[tokio::test(flavor = "current_thread", worker_threads = 4)] async fn test_worker_threads_and_current_thread() {} +#[tokio::test(crate = 456)] +async fn test_crate_not_ident_int() {} + +#[tokio::test(crate = "456")] +async fn test_crate_not_ident_invalid() {} + +#[tokio::test(crate = "abc::edf")] +async fn test_crate_not_ident_path() {} + #[tokio::test] #[test] async fn test_has_second_test_attr() {} diff --git a/tests-build/tests/fail/macros_invalid_input.stderr b/tests-build/tests/fail/macros_invalid_input.stderr index 11337a94fe5..cd294dddf36 100644 --- a/tests-build/tests/fail/macros_invalid_input.stderr +++ b/tests-build/tests/fail/macros_invalid_input.stderr @@ -1,71 +1,101 @@ error: the `async` keyword is missing from the function declaration - --> $DIR/macros_invalid_input.rs:4:1 + --> tests/fail/macros_invalid_input.rs:6:1 | -4 | fn main_is_not_async() {} +6 | fn main_is_not_async() {} | ^^ -error: Unknown attribute foo is specified; expected one of: `flavor`, `worker_threads`, `start_paused` - --> $DIR/macros_invalid_input.rs:6:15 +error: Unknown attribute foo is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate` + --> tests/fail/macros_invalid_input.rs:8:15 | -6 | #[tokio::main(foo)] +8 | #[tokio::main(foo)] | ^^^ error: Must have specified ident - --> $DIR/macros_invalid_input.rs:9:15 - | -9 | #[tokio::main(threadpool::bar)] - | ^^^^^^^^^^^^^^^ + --> tests/fail/macros_invalid_input.rs:11:15 + | +11 | #[tokio::main(threadpool::bar)] + | ^^^^^^^^^^^^^^^ error: the `async` keyword is missing from the function declaration - --> $DIR/macros_invalid_input.rs:13:1 + --> tests/fail/macros_invalid_input.rs:15:1 | -13 | fn test_is_not_async() {} +15 | fn test_is_not_async() {} | ^^ -error: Unknown attribute foo is specified; expected one of: `flavor`, `worker_threads`, `start_paused` - --> $DIR/macros_invalid_input.rs:15:15 +error: Unknown attribute foo is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate` + --> tests/fail/macros_invalid_input.rs:17:15 | -15 | #[tokio::test(foo)] +17 | #[tokio::test(foo)] | ^^^ -error: Unknown attribute foo is specified; expected one of: `flavor`, `worker_threads`, `start_paused` - --> $DIR/macros_invalid_input.rs:18:15 +error: Unknown attribute foo is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate` + --> tests/fail/macros_invalid_input.rs:20:15 | -18 | #[tokio::test(foo = 123)] +20 | #[tokio::test(foo = 123)] | ^^^^^^^^^ error: Failed to parse value of `flavor` as string. - --> $DIR/macros_invalid_input.rs:21:24 + --> tests/fail/macros_invalid_input.rs:23:24 | -21 | #[tokio::test(flavor = 123)] +23 | #[tokio::test(flavor = 123)] | ^^^ error: No such runtime flavor `foo`. The runtime flavors are `current_thread` and `multi_thread`. - --> $DIR/macros_invalid_input.rs:24:24 + --> tests/fail/macros_invalid_input.rs:26:24 | -24 | #[tokio::test(flavor = "foo")] +26 | #[tokio::test(flavor = "foo")] | ^^^^^ error: The `start_paused` option requires the `current_thread` runtime flavor. Use `#[tokio::test(flavor = "current_thread")]` - --> $DIR/macros_invalid_input.rs:27:55 + --> tests/fail/macros_invalid_input.rs:29:55 | -27 | #[tokio::test(flavor = "multi_thread", start_paused = false)] +29 | #[tokio::test(flavor = "multi_thread", start_paused = false)] | ^^^^^ error: Failed to parse value of `worker_threads` as integer. - --> $DIR/macros_invalid_input.rs:30:57 + --> tests/fail/macros_invalid_input.rs:32:57 | -30 | #[tokio::test(flavor = "multi_thread", worker_threads = "foo")] +32 | #[tokio::test(flavor = "multi_thread", worker_threads = "foo")] | ^^^^^ error: The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[tokio::test(flavor = "multi_thread")]` - --> $DIR/macros_invalid_input.rs:33:59 + --> tests/fail/macros_invalid_input.rs:35:59 | -33 | #[tokio::test(flavor = "current_thread", worker_threads = 4)] +35 | #[tokio::test(flavor = "current_thread", worker_threads = 4)] | ^ +error: Failed to parse value of `crate` as ident. + --> tests/fail/macros_invalid_input.rs:38:23 + | +38 | #[tokio::test(crate = 456)] + | ^^^ + +error: Failed to parse value of `crate` as ident: "456" + --> tests/fail/macros_invalid_input.rs:41:23 + | +41 | #[tokio::test(crate = "456")] + | ^^^^^ + +error: Failed to parse value of `crate` as ident: "abc::edf" + --> tests/fail/macros_invalid_input.rs:44:23 + | +44 | #[tokio::test(crate = "abc::edf")] + | ^^^^^^^^^^ + error: second test attribute is supplied - --> $DIR/macros_invalid_input.rs:37:1 + --> tests/fail/macros_invalid_input.rs:48:1 + | +48 | #[test] + | ^^^^^^^ + +error: duplicated attribute + --> tests/fail/macros_invalid_input.rs:48:1 | -37 | #[test] +48 | #[test] | ^^^^^^^ + | +note: the lint level is defined here + --> tests/fail/macros_invalid_input.rs:1:9 + | +1 | #![deny(duplicate_macro_attributes)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests-build/tests/fail/macros_type_mismatch.rs b/tests-build/tests/fail/macros_type_mismatch.rs index 0a5b9c4c727..c292ee68f66 100644 --- a/tests-build/tests/fail/macros_type_mismatch.rs +++ b/tests-build/tests/fail/macros_type_mismatch.rs @@ -23,4 +23,13 @@ async fn extra_semicolon() -> Result<(), ()> { Ok(()); } +// https://github.com/tokio-rs/tokio/issues/4635 +#[allow(redundant_semicolons)] +#[rustfmt::skip] +#[tokio::main] +async fn issue_4635() { + return 1; + ; +} + fn main() {} diff --git a/tests-build/tests/fail/macros_type_mismatch.stderr b/tests-build/tests/fail/macros_type_mismatch.stderr index f98031514ff..9e1bd5ca0c3 100644 --- a/tests-build/tests/fail/macros_type_mismatch.stderr +++ b/tests-build/tests/fail/macros_type_mismatch.stderr @@ -1,25 +1,19 @@ error[E0308]: mismatched types --> tests/fail/macros_type_mismatch.rs:5:5 | +4 | async fn missing_semicolon_or_return_type() { + | - possibly return type missing here? 5 | Ok(()) | ^^^^^^ expected `()`, found enum `Result` | = note: expected unit type `()` found enum `Result<(), _>` -help: consider using a semicolon here - | -5 | Ok(()); - | + -help: try adding a return type - | -4 | async fn missing_semicolon_or_return_type() -> Result<(), _> { - | ++++++++++++++++ error[E0308]: mismatched types --> tests/fail/macros_type_mismatch.rs:10:5 | 9 | async fn missing_return_type() { - | - help: try adding a return type: `-> Result<(), _>` + | - possibly return type missing here? 10 | return Ok(()); | ^^^^^^^^^^^^^^ expected `()`, found enum `Result` | @@ -37,8 +31,18 @@ error[E0308]: mismatched types | = note: expected enum `Result<(), ()>` found unit type `()` -help: try adding an expression at the end of the block +help: try wrapping the expression in a variant of `Result` | -23 ~ Ok(());; -24 + Ok(()) +23 | Ok(Ok(());) + | +++ + +23 | Err(Ok(());) + | ++++ + + +error[E0308]: mismatched types + --> tests/fail/macros_type_mismatch.rs:32:5 | +30 | async fn issue_4635() { + | - possibly return type missing here? +31 | return 1; +32 | ; + | ^ expected `()`, found integer diff --git a/tests-integration/Cargo.toml b/tests-integration/Cargo.toml index 5ab8c15db3d..a45c4deac0d 100644 --- a/tests-integration/Cargo.toml +++ b/tests-integration/Cargo.toml @@ -20,7 +20,7 @@ required-features = ["rt-process-signal"] # For mem check rt-net = ["tokio/rt", "tokio/rt-multi-thread", "tokio/net"] # For test-process-signal -rt-process-signal = ["rt", "tokio/process", "tokio/signal"] +rt-process-signal = ["rt-net", "tokio/process", "tokio/signal"] full = [ "macros", diff --git a/tokio-macros/src/entry.rs b/tokio-macros/src/entry.rs index 5cb4a49b430..68eb829176b 100644 --- a/tokio-macros/src/entry.rs +++ b/tokio-macros/src/entry.rs @@ -1,5 +1,5 @@ use proc_macro::TokenStream; -use proc_macro2::Span; +use proc_macro2::{Ident, Span}; use quote::{quote, quote_spanned, ToTokens}; use syn::parse::Parser; @@ -29,6 +29,7 @@ struct FinalConfig { flavor: RuntimeFlavor, worker_threads: Option, start_paused: Option, + crate_name: Option, } /// Config used in case of the attribute not being able to build a valid config @@ -36,6 +37,7 @@ const DEFAULT_ERROR_CONFIG: FinalConfig = FinalConfig { flavor: RuntimeFlavor::CurrentThread, worker_threads: None, start_paused: None, + crate_name: None, }; struct Configuration { @@ -45,6 +47,7 @@ struct Configuration { worker_threads: Option<(usize, Span)>, start_paused: Option<(bool, Span)>, is_test: bool, + crate_name: Option, } impl Configuration { @@ -59,6 +62,7 @@ impl Configuration { worker_threads: None, start_paused: None, is_test, + crate_name: None, } } @@ -104,6 +108,15 @@ impl Configuration { Ok(()) } + fn set_crate_name(&mut self, name: syn::Lit, span: Span) -> Result<(), syn::Error> { + if self.crate_name.is_some() { + return Err(syn::Error::new(span, "`crate` set multiple times.")); + } + let name_ident = parse_ident(name, span, "crate")?; + self.crate_name = Some(name_ident.to_string()); + Ok(()) + } + fn macro_name(&self) -> &'static str { if self.is_test { "tokio::test" @@ -151,6 +164,7 @@ impl Configuration { }; Ok(FinalConfig { + crate_name: self.crate_name.clone(), flavor, worker_threads, start_paused, @@ -185,6 +199,27 @@ fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result Result { + match lit { + syn::Lit::Str(s) => { + let err = syn::Error::new( + span, + format!( + "Failed to parse value of `{}` as ident: \"{}\"", + field, + s.value() + ), + ); + let path = s.parse::().map_err(|_| err.clone())?; + path.get_ident().cloned().ok_or(err) + } + _ => Err(syn::Error::new( + span, + format!("Failed to parse value of `{}` as ident.", field), + )), + } +} + fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result { match bool { syn::Lit::Bool(b) => Ok(b.value), @@ -243,9 +278,15 @@ fn build_config( let msg = "Attribute `core_threads` is renamed to `worker_threads`"; return Err(syn::Error::new_spanned(namevalue, msg)); } + "crate" => { + config.set_crate_name( + namevalue.lit.clone(), + syn::spanned::Spanned::span(&namevalue.lit), + )?; + } name => { let msg = format!( - "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`", + "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`", name, ); return Err(syn::Error::new_spanned(namevalue, msg)); @@ -275,7 +316,7 @@ fn build_config( format!("The `{}` attribute requires an argument.", name) } name => { - format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`", name) + format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `crate`", name) } }; return Err(syn::Error::new_spanned(path, msg)); @@ -313,12 +354,16 @@ fn parse_knobs(mut input: syn::ItemFn, is_test: bool, config: FinalConfig) -> To (start, end) }; + let crate_name = config.crate_name.as_deref().unwrap_or("tokio"); + + let crate_ident = Ident::new(crate_name, last_stmt_start_span); + let mut rt = match config.flavor { RuntimeFlavor::CurrentThread => quote_spanned! {last_stmt_start_span=> - tokio::runtime::Builder::new_current_thread() + #crate_ident::runtime::Builder::new_current_thread() }, RuntimeFlavor::Threaded => quote_spanned! {last_stmt_start_span=> - tokio::runtime::Builder::new_multi_thread() + #crate_ident::runtime::Builder::new_multi_thread() }, }; if let Some(v) = config.worker_threads { @@ -338,29 +383,17 @@ fn parse_knobs(mut input: syn::ItemFn, is_test: bool, config: FinalConfig) -> To let body = &input.block; let brace_token = input.block.brace_token; - let (tail_return, tail_semicolon) = match body.stmts.last() { - Some(syn::Stmt::Semi(syn::Expr::Return(_), _)) => (quote! { return }, quote! { ; }), - Some(syn::Stmt::Semi(..)) | Some(syn::Stmt::Local(..)) | None => { - match &input.sig.output { - syn::ReturnType::Type(_, ty) if matches!(&**ty, syn::Type::Tuple(ty) if ty.elems.is_empty()) => - { - (quote! {}, quote! { ; }) // unit - } - syn::ReturnType::Default => (quote! {}, quote! { ; }), // unit - syn::ReturnType::Type(..) => (quote! {}, quote! {}), // ! or another - } - } - _ => (quote! {}, quote! {}), - }; input.block = syn::parse2(quote_spanned! {last_stmt_end_span=> { let body = async #body; - #[allow(clippy::expect_used)] - #tail_return #rt - .enable_all() - .build() - .expect("Failed building the Runtime") - .block_on(body)#tail_semicolon + #[allow(clippy::expect_used, clippy::diverging_sub_expression)] + { + return #rt + .enable_all() + .build() + .expect("Failed building the Runtime") + .block_on(body); + } } }) .expect("Parsing failure"); diff --git a/tokio-macros/src/lib.rs b/tokio-macros/src/lib.rs index 38638a1df8a..b15fd3b7017 100644 --- a/tokio-macros/src/lib.rs +++ b/tokio-macros/src/lib.rs @@ -168,12 +168,32 @@ use proc_macro::TokenStream; /// /// Note that `start_paused` requires the `test-util` feature to be enabled. /// -/// ### NOTE: +/// ### Rename package /// -/// If you rename the Tokio crate in your dependencies this macro will not work. -/// If you must rename the current version of Tokio because you're also using an -/// older version of Tokio, you _must_ make the current version of Tokio -/// available as `tokio` in the module where this macro is expanded. +/// ```rust +/// use tokio as tokio1; +/// +/// #[tokio1::main(crate = "tokio1")] +/// async fn main() { +/// println!("Hello world"); +/// } +/// ``` +/// +/// Equivalent code not using `#[tokio::main]` +/// +/// ```rust +/// use tokio as tokio1; +/// +/// fn main() { +/// tokio1::runtime::Builder::new_multi_thread() +/// .enable_all() +/// .build() +/// .unwrap() +/// .block_on(async { +/// println!("Hello world"); +/// }) +/// } +/// ``` #[proc_macro_attribute] #[cfg(not(test))] // Work around for rust-lang/rust#62127 pub fn main(args: TokenStream, item: TokenStream) -> TokenStream { @@ -213,12 +233,32 @@ pub fn main(args: TokenStream, item: TokenStream) -> TokenStream { /// } /// ``` /// -/// ### NOTE: +/// ### Rename package +/// +/// ```rust +/// use tokio as tokio1; +/// +/// #[tokio1::main(crate = "tokio1")] +/// async fn main() { +/// println!("Hello world"); +/// } +/// ``` +/// +/// Equivalent code not using `#[tokio::main]` +/// +/// ```rust +/// use tokio as tokio1; /// -/// If you rename the Tokio crate in your dependencies this macro will not work. -/// If you must rename the current version of Tokio because you're also using an -/// older version of Tokio, you _must_ make the current version of Tokio -/// available as `tokio` in the module where this macro is expanded. +/// fn main() { +/// tokio1::runtime::Builder::new_multi_thread() +/// .enable_all() +/// .build() +/// .unwrap() +/// .block_on(async { +/// println!("Hello world"); +/// }) +/// } +/// ``` #[proc_macro_attribute] #[cfg(not(test))] // Work around for rust-lang/rust#62127 pub fn main_rt(args: TokenStream, item: TokenStream) -> TokenStream { @@ -260,12 +300,16 @@ pub fn main_rt(args: TokenStream, item: TokenStream) -> TokenStream { /// /// Note that `start_paused` requires the `test-util` feature to be enabled. /// -/// ### NOTE: +/// ### Rename package +/// +/// ```rust +/// use tokio as tokio1; /// -/// If you rename the Tokio crate in your dependencies this macro will not work. -/// If you must rename the current version of Tokio because you're also using an -/// older version of Tokio, you _must_ make the current version of Tokio -/// available as `tokio` in the module where this macro is expanded. +/// #[tokio1::test(crate = "tokio1")] +/// async fn my_test() { +/// println!("Hello world"); +/// } +/// ``` #[proc_macro_attribute] pub fn test(args: TokenStream, item: TokenStream) -> TokenStream { entry::test(args, item, true) @@ -281,13 +325,6 @@ pub fn test(args: TokenStream, item: TokenStream) -> TokenStream { /// assert!(true); /// } /// ``` -/// -/// ### NOTE: -/// -/// If you rename the Tokio crate in your dependencies this macro will not work. -/// If you must rename the current version of Tokio because you're also using an -/// older version of Tokio, you _must_ make the current version of Tokio -/// available as `tokio` in the module where this macro is expanded. #[proc_macro_attribute] pub fn test_rt(args: TokenStream, item: TokenStream) -> TokenStream { entry::test(args, item, false) diff --git a/tokio-stream/src/lib.rs b/tokio-stream/src/lib.rs index f600ccb8d36..bbd4cef03ed 100644 --- a/tokio-stream/src/lib.rs +++ b/tokio-stream/src/lib.rs @@ -77,6 +77,9 @@ pub mod wrappers; mod stream_ext; pub use stream_ext::{collect::FromStream, StreamExt}; +cfg_time! { + pub use stream_ext::timeout::{Elapsed, Timeout}; +} mod empty; pub use empty::{empty, Empty}; diff --git a/tokio-stream/src/stream_ext.rs b/tokio-stream/src/stream_ext.rs index b79883bd6e8..a7bc2044f95 100644 --- a/tokio-stream/src/stream_ext.rs +++ b/tokio-stream/src/stream_ext.rs @@ -56,7 +56,7 @@ mod try_next; use try_next::TryNext; cfg_time! { - mod timeout; + pub(crate) mod timeout; use timeout::Timeout; use tokio::time::Duration; mod throttle; diff --git a/tokio-stream/src/wrappers/broadcast.rs b/tokio-stream/src/wrappers/broadcast.rs index 2064973d733..10184bf9410 100644 --- a/tokio-stream/src/wrappers/broadcast.rs +++ b/tokio-stream/src/wrappers/broadcast.rs @@ -18,7 +18,7 @@ pub struct BroadcastStream { } /// An error returned from the inner stream of a [`BroadcastStream`]. -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] pub enum BroadcastStreamRecvError { /// The receiver lagged too far behind. Attempting to receive again will /// return the oldest message still retained by the channel. diff --git a/tokio-util/CHANGELOG.md b/tokio-util/CHANGELOG.md index 96a45374384..80722ddd74a 100644 --- a/tokio-util/CHANGELOG.md +++ b/tokio-util/CHANGELOG.md @@ -1,3 +1,25 @@ +# 0.7.1 (February 21, 2022) + +### Added + +- codec: add `length_field_type` to `LengthDelimitedCodec` builder ([#4508]) +- io: add `StreamReader::into_inner_with_chunk()` ([#4559]) + +### Changed + +- switch from log to tracing ([#4539]) + +### Fixed + +- sync: fix waker update condition in `CancellationToken` ([#4497]) +- bumped tokio dependency to 1.6 to satisfy minimum requirements ([#4490]) + +[#4490]: https://github.com/tokio-rs/tokio/pull/4490 +[#4497]: https://github.com/tokio-rs/tokio/pull/4497 +[#4508]: https://github.com/tokio-rs/tokio/pull/4508 +[#4539]: https://github.com/tokio-rs/tokio/pull/4539 +[#4559]: https://github.com/tokio-rs/tokio/pull/4559 + # 0.7.0 (February 9, 2022) ### Added diff --git a/tokio-util/Cargo.toml b/tokio-util/Cargo.toml index b4782d5d9a3..992ff234a8b 100644 --- a/tokio-util/Cargo.toml +++ b/tokio-util/Cargo.toml @@ -4,7 +4,7 @@ name = "tokio-util" # - Remove path dependencies # - Update CHANGELOG.md. # - Create "tokio-util-0.7.x" git tag. -version = "0.7.0" +version = "0.7.1" edition = "2018" rust-version = "1.49" authors = ["Tokio Contributors "] @@ -25,25 +25,27 @@ full = ["codec", "compat", "io-util", "time", "net", "rt"] net = ["tokio/net"] compat = ["futures-io",] -codec = [] +codec = ["tracing"] time = ["tokio/time","slab"] io = [] io-util = ["io", "tokio/rt", "tokio/io-util"] -rt = ["tokio/rt", "tokio/sync", "futures-util"] +rt = ["tokio/rt", "tokio/sync", "futures-util", "hashbrown"] __docs_rs = ["futures-util"] [dependencies] -tokio = { version = "1.6.0", path = "../tokio", features = ["sync"] } - +tokio = { version = "1.18.0", path = "../tokio", features = ["sync"] } bytes = "1.0.0" futures-core = "0.3.0" futures-sink = "0.3.0" futures-io = { version = "0.3.0", optional = true } futures-util = { version = "0.3.0", optional = true } -log = "0.4" pin-project-lite = "0.2.0" slab = { version = "0.4.4", optional = true } # Backs `DelayQueue` +tracing = { version = "0.1.25", default-features = false, features = ["std"], optional = true } + +[target.'cfg(tokio_unstable)'.dependencies] +hashbrown = { version = "0.12.0", optional = true } [dev-dependencies] tokio = { version = "1.0.0", path = "../tokio", features = ["full"] } diff --git a/tokio-util/src/codec/framed_impl.rs b/tokio-util/src/codec/framed_impl.rs index f932414e1e0..ce1a6db8739 100644 --- a/tokio-util/src/codec/framed_impl.rs +++ b/tokio-util/src/codec/framed_impl.rs @@ -7,12 +7,12 @@ use tokio::io::{AsyncRead, AsyncWrite}; use bytes::BytesMut; use futures_core::ready; use futures_sink::Sink; -use log::trace; use pin_project_lite::pin_project; use std::borrow::{Borrow, BorrowMut}; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; +use tracing::trace; pin_project! { #[derive(Debug)] @@ -278,7 +278,7 @@ where while !pinned.state.borrow_mut().buffer.is_empty() { let WriteFrame { buffer } = pinned.state.borrow_mut(); - trace!("writing; remaining={}", buffer.len()); + trace!(remaining = buffer.len(), "writing;"); let n = ready!(poll_write_buf(pinned.inner.as_mut(), cx, buffer))?; diff --git a/tokio-util/src/codec/length_delimited.rs b/tokio-util/src/codec/length_delimited.rs index de0eb4e9201..93d2f180d0f 100644 --- a/tokio-util/src/codec/length_delimited.rs +++ b/tokio-util/src/codec/length_delimited.rs @@ -84,7 +84,7 @@ //! # fn bind_read(io: T) { //! LengthDelimitedCodec::builder() //! .length_field_offset(0) // default value -//! .length_field_length(2) +//! .length_field_type::() //! .length_adjustment(0) // default value //! .num_skip(0) // Do not strip frame header //! .new_read(io); @@ -118,7 +118,7 @@ //! # fn bind_read(io: T) { //! LengthDelimitedCodec::builder() //! .length_field_offset(0) // default value -//! .length_field_length(2) +//! .length_field_type::() //! .length_adjustment(0) // default value //! // `num_skip` is not needed, the default is to skip //! .new_read(io); @@ -150,7 +150,7 @@ //! # fn bind_read(io: T) { //! LengthDelimitedCodec::builder() //! .length_field_offset(0) // default value -//! .length_field_length(2) +//! .length_field_type::() //! .length_adjustment(-2) // size of head //! .num_skip(0) //! .new_read(io); @@ -228,7 +228,7 @@ //! # fn bind_read(io: T) { //! LengthDelimitedCodec::builder() //! .length_field_offset(1) // length of hdr1 -//! .length_field_length(2) +//! .length_field_type::() //! .length_adjustment(1) // length of hdr2 //! .num_skip(3) // length of hdr1 + LEN //! .new_read(io); @@ -274,7 +274,7 @@ //! # fn bind_read(io: T) { //! LengthDelimitedCodec::builder() //! .length_field_offset(1) // length of hdr1 -//! .length_field_length(2) +//! .length_field_type::() //! .length_adjustment(-3) // length of hdr1 + LEN, negative //! .num_skip(3) //! .new_read(io); @@ -350,7 +350,7 @@ //! # fn write_frame(io: T) { //! # let _ = //! LengthDelimitedCodec::builder() -//! .length_field_length(2) +//! .length_field_type::() //! .new_write(io); //! # } //! # pub fn main() {} @@ -379,7 +379,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use std::error::Error as StdError; use std::io::{self, Cursor}; -use std::{cmp, fmt}; +use std::{cmp, fmt, mem}; /// Configure length delimited `LengthDelimitedCodec`s. /// @@ -629,6 +629,24 @@ impl Default for LengthDelimitedCodec { // ===== impl Builder ===== +mod builder { + /// Types that can be used with `Builder::length_field_type`. + pub trait LengthFieldType {} + + impl LengthFieldType for u8 {} + impl LengthFieldType for u16 {} + impl LengthFieldType for u32 {} + impl LengthFieldType for u64 {} + + #[cfg(any( + target_pointer_width = "8", + target_pointer_width = "16", + target_pointer_width = "32", + target_pointer_width = "64", + ))] + impl LengthFieldType for usize {} +} + impl Builder { /// Creates a new length delimited codec builder with default configuration /// values. @@ -642,7 +660,7 @@ impl Builder { /// # fn bind_read(io: T) { /// LengthDelimitedCodec::builder() /// .length_field_offset(0) - /// .length_field_length(2) + /// .length_field_type::() /// .length_adjustment(0) /// .num_skip(0) /// .new_read(io); @@ -777,6 +795,42 @@ impl Builder { self } + /// Sets the unsigned integer type used to represent the length field. + /// + /// The default type is [`u32`]. The max type is [`u64`] (or [`usize`] on + /// 64-bit targets). + /// + /// # Examples + /// + /// ``` + /// # use tokio::io::AsyncRead; + /// use tokio_util::codec::LengthDelimitedCodec; + /// + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_type::() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + /// + /// Unlike [`Builder::length_field_length`], this does not fail at runtime + /// and instead produces a compile error: + /// + /// ```compile_fail + /// # use tokio::io::AsyncRead; + /// # use tokio_util::codec::LengthDelimitedCodec; + /// # fn bind_read(io: T) { + /// LengthDelimitedCodec::builder() + /// .length_field_type::() + /// .new_read(io); + /// # } + /// # pub fn main() {} + /// ``` + pub fn length_field_type(&mut self) -> &mut Self { + self.length_field_length(mem::size_of::()) + } + /// Sets the number of bytes used to represent the length field /// /// The default value is `4`. The max value is `8`. @@ -878,7 +932,7 @@ impl Builder { /// # pub fn main() { /// LengthDelimitedCodec::builder() /// .length_field_offset(0) - /// .length_field_length(2) + /// .length_field_type::() /// .length_adjustment(0) /// .num_skip(0) /// .new_codec(); @@ -902,7 +956,7 @@ impl Builder { /// # fn bind_read(io: T) { /// LengthDelimitedCodec::builder() /// .length_field_offset(0) - /// .length_field_length(2) + /// .length_field_type::() /// .length_adjustment(0) /// .num_skip(0) /// .new_read(io); @@ -925,7 +979,7 @@ impl Builder { /// # use tokio_util::codec::LengthDelimitedCodec; /// # fn write_frame(io: T) { /// LengthDelimitedCodec::builder() - /// .length_field_length(2) + /// .length_field_type::() /// .new_write(io); /// # } /// # pub fn main() {} @@ -947,7 +1001,7 @@ impl Builder { /// # fn write_frame(io: T) { /// # let _ = /// LengthDelimitedCodec::builder() - /// .length_field_length(2) + /// .length_field_type::() /// .new_framed(io); /// # } /// # pub fn main() {} diff --git a/tokio-util/src/io/stream_reader.rs b/tokio-util/src/io/stream_reader.rs index 8fe09b9b655..05ae8865573 100644 --- a/tokio-util/src/io/stream_reader.rs +++ b/tokio-util/src/io/stream_reader.rs @@ -84,13 +84,24 @@ where } /// Do we have a chunk and is it non-empty? - fn has_chunk(self: Pin<&mut Self>) -> bool { - if let Some(chunk) = self.project().chunk { + fn has_chunk(&self) -> bool { + if let Some(ref chunk) = self.chunk { chunk.remaining() > 0 } else { false } } + + /// Consumes this `StreamReader`, returning a Tuple consisting + /// of the underlying stream and an Option of the interal buffer, + /// which is Some in case the buffer contains elements. + pub fn into_inner_with_chunk(self) -> (S, Option) { + if self.has_chunk() { + (self.inner, self.chunk) + } else { + (self.inner, None) + } + } } impl StreamReader { @@ -118,6 +129,10 @@ impl StreamReader { /// Consumes this `BufWriter`, returning the underlying stream. /// /// Note that any leftover data in the internal buffer is lost. + /// If you additionally want access to the internal buffer use + /// [`into_inner_with_chunk`]. + /// + /// [`into_inner_with_chunk`]: crate::io::StreamReader::into_inner_with_chunk pub fn into_inner(self) -> S { self.inner } diff --git a/tokio-util/src/sync/cancellation_token.rs b/tokio-util/src/sync/cancellation_token.rs index 8a89698c4fd..b856ab368a8 100644 --- a/tokio-util/src/sync/cancellation_token.rs +++ b/tokio-util/src/sync/cancellation_token.rs @@ -24,9 +24,9 @@ use guard::DropGuard; /// /// # Examples /// -/// ```ignore +/// ```no_run /// use tokio::select; -/// use tokio::scope::CancellationToken; +/// use tokio_util::sync::CancellationToken; /// /// #[tokio::main] /// async fn main() { @@ -172,9 +172,9 @@ impl CancellationToken { /// /// # Examples /// - /// ```ignore + /// ```no_run /// use tokio::select; - /// use tokio::scope::CancellationToken; + /// use tokio_util::sync::CancellationToken; /// /// #[tokio::main] /// async fn main() { diff --git a/tokio-util/src/task/join_map.rs b/tokio-util/src/task/join_map.rs new file mode 100644 index 00000000000..41c82f448ab --- /dev/null +++ b/tokio-util/src/task/join_map.rs @@ -0,0 +1,808 @@ +use hashbrown::hash_map::RawEntryMut; +use hashbrown::HashMap; +use std::borrow::Borrow; +use std::collections::hash_map::RandomState; +use std::fmt; +use std::future::Future; +use std::hash::{BuildHasher, Hash, Hasher}; +use tokio::runtime::Handle; +use tokio::task::{AbortHandle, Id, JoinError, JoinSet, LocalSet}; + +/// A collection of tasks spawned on a Tokio runtime, associated with hash map +/// keys. +/// +/// This type is very similar to the [`JoinSet`] type in `tokio::task`, with the +/// addition of a set of keys associated with each task. These keys allow +/// [cancelling a task][abort] or [multiple tasks][abort_matching] in the +/// `JoinMap` based on their keys, or [test whether a task corresponding to a +/// given key exists][contains] in the `JoinMap`. +/// +/// In addition, when tasks in the `JoinMap` complete, they will return the +/// associated key along with the value returned by the task, if any. +/// +/// A `JoinMap` can be used to await the completion of some or all of the tasks +/// in the map. The map is not ordered, and the tasks will be returned in the +/// order they complete. +/// +/// All of the tasks must have the same return type `V`. +/// +/// When the `JoinMap` is dropped, all tasks in the `JoinMap` are immediately aborted. +/// +/// **Note**: This type depends on Tokio's [unstable API][unstable]. See [the +/// documentation on unstable features][unstable] for details on how to enable +/// Tokio's unstable features. +/// +/// # Examples +/// +/// Spawn multiple tasks and wait for them: +/// +/// ``` +/// use tokio_util::task::JoinMap; +/// +/// #[tokio::main] +/// async fn main() { +/// let mut map = JoinMap::new(); +/// +/// for i in 0..10 { +/// // Spawn a task on the `JoinMap` with `i` as its key. +/// map.spawn(i, async move { /* ... */ }); +/// } +/// +/// let mut seen = [false; 10]; +/// +/// // When a task completes, `join_one` returns the task's key along +/// // with its output. +/// while let Some((key, res)) = map.join_one().await { +/// seen[key] = true; +/// assert!(res.is_ok(), "task {} completed successfully!", key); +/// } +/// +/// for i in 0..10 { +/// assert!(seen[i]); +/// } +/// } +/// ``` +/// +/// Cancel tasks based on their keys: +/// +/// ``` +/// use tokio_util::task::JoinMap; +/// +/// #[tokio::main] +/// async fn main() { +/// let mut map = JoinMap::new(); +/// +/// map.spawn("hello world", async move { /* ... */ }); +/// map.spawn("goodbye world", async move { /* ... */}); +/// +/// // Look up the "goodbye world" task in the map and abort it. +/// let aborted = map.abort("goodbye world"); +/// +/// // `JoinMap::abort` returns `true` if a task existed for the +/// // provided key. +/// assert!(aborted); +/// +/// while let Some((key, res)) = map.join_one().await { +/// if key == "goodbye world" { +/// // The aborted task should complete with a cancelled `JoinError`. +/// assert!(res.unwrap_err().is_cancelled()); +/// } else { +/// // Other tasks should complete normally. +/// assert!(res.is_ok()); +/// } +/// } +/// } +/// ``` +/// +/// [`JoinSet`]: tokio::task::JoinSet +/// [unstable]: tokio#unstable-features +/// [abort]: fn@Self::abort +/// [abort_matching]: fn@Self::abort_matching +/// [contains]: fn@Self::contains_key +#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] +pub struct JoinMap { + /// A map of the [`AbortHandle`]s of the tasks spawned on this `JoinMap`, + /// indexed by their keys and task IDs. + /// + /// The [`Key`] type contains both the task's `K`-typed key provided when + /// spawning tasks, and the task's IDs. The IDs are stored here to resolve + /// hash collisions when looking up tasks based on their pre-computed hash + /// (as stored in the `hashes_by_task` map). + tasks_by_key: HashMap, AbortHandle, S>, + + /// A map from task IDs to the hash of the key associated with that task. + /// + /// This map is used to perform reverse lookups of tasks in the + /// `tasks_by_key` map based on their task IDs. When a task terminates, the + /// ID is provided to us by the `JoinSet`, so we can look up the hash value + /// of that task's key, and then remove it from the `tasks_by_key` map using + /// the raw hash code, resolving collisions by comparing task IDs. + hashes_by_task: HashMap, + + /// The [`JoinSet`] that awaits the completion of tasks spawned on this + /// `JoinMap`. + tasks: JoinSet, +} + +/// A [`JoinMap`] key. +/// +/// This holds both a `K`-typed key (the actual key as seen by the user), _and_ +/// a task ID, so that hash collisions between `K`-typed keys can be resolved +/// using either `K`'s `Eq` impl *or* by checking the task IDs. +/// +/// This allows looking up a task using either an actual key (such as when the +/// user queries the map with a key), *or* using a task ID and a hash (such as +/// when removing completed tasks from the map). +#[derive(Debug)] +struct Key { + key: K, + id: Id, +} + +impl JoinMap { + /// Creates a new empty `JoinMap`. + /// + /// The `JoinMap` is initially created with a capacity of 0, so it will not + /// allocate until a task is first spawned on it. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// let map: JoinMap<&str, i32> = JoinMap::new(); + /// ``` + #[inline] + #[must_use] + pub fn new() -> Self { + Self::with_hasher(RandomState::new()) + } + + /// Creates an empty `JoinMap` with the specified capacity. + /// + /// The `JoinMap` will be able to hold at least `capacity` tasks without + /// reallocating. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// let map: JoinMap<&str, i32> = JoinMap::with_capacity(10); + /// ``` + #[inline] + #[must_use] + pub fn with_capacity(capacity: usize) -> Self { + JoinMap::with_capacity_and_hasher(capacity, Default::default()) + } +} + +impl JoinMap { + /// Creates an empty `JoinMap` which will use the given hash builder to hash + /// keys. + /// + /// The created map has the default initial capacity. + /// + /// Warning: `hash_builder` is normally randomly generated, and + /// is designed to allow `JoinMap` to be resistant to attacks that + /// cause many collisions and very poor performance. Setting it + /// manually using this function can expose a DoS attack vector. + /// + /// The `hash_builder` passed should implement the [`BuildHasher`] trait for + /// the `JoinMap` to be useful, see its documentation for details. + #[inline] + #[must_use] + pub fn with_hasher(hash_builder: S) -> Self { + Self::with_capacity_and_hasher(0, hash_builder) + } + + /// Creates an empty `JoinMap` with the specified capacity, using `hash_builder` + /// to hash the keys. + /// + /// The `JoinMap` will be able to hold at least `capacity` elements without + /// reallocating. If `capacity` is 0, the `JoinMap` will not allocate. + /// + /// Warning: `hash_builder` is normally randomly generated, and + /// is designed to allow HashMaps to be resistant to attacks that + /// cause many collisions and very poor performance. Setting it + /// manually using this function can expose a DoS attack vector. + /// + /// The `hash_builder` passed should implement the [`BuildHasher`] trait for + /// the `JoinMap`to be useful, see its documentation for details. + /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio_util::task::JoinMap; + /// use std::collections::hash_map::RandomState; + /// + /// let s = RandomState::new(); + /// let mut map = JoinMap::with_capacity_and_hasher(10, s); + /// map.spawn(1, async move { "hello world!" }); + /// # } + /// ``` + #[inline] + #[must_use] + pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self { + Self { + tasks_by_key: HashMap::with_capacity_and_hasher(capacity, hash_builder.clone()), + hashes_by_task: HashMap::with_capacity_and_hasher(capacity, hash_builder), + tasks: JoinSet::new(), + } + } + + /// Returns the number of tasks currently in the `JoinMap`. + pub fn len(&self) -> usize { + let len = self.tasks_by_key.len(); + debug_assert_eq!(len, self.hashes_by_task.len()); + len + } + + /// Returns whether the `JoinMap` is empty. + pub fn is_empty(&self) -> bool { + let empty = self.tasks_by_key.is_empty(); + debug_assert_eq!(empty, self.hashes_by_task.is_empty()); + empty + } + + /// Returns the number of tasks the map can hold without reallocating. + /// + /// This number is a lower bound; the `JoinMap` might be able to hold + /// more, but is guaranteed to be able to hold at least this many. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// + /// let map: JoinMap = JoinMap::with_capacity(100); + /// assert!(map.capacity() >= 100); + /// ``` + #[inline] + pub fn capacity(&self) -> usize { + let capacity = self.tasks_by_key.capacity(); + debug_assert_eq!(capacity, self.hashes_by_task.capacity()); + capacity + } +} + +impl JoinMap +where + K: Hash + Eq, + V: 'static, + S: BuildHasher, +{ + /// Spawn the provided task and store it in this `JoinMap` with the provided + /// key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_one`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// # Panics + /// + /// This method panics if called outside of a Tokio runtime. + /// + /// [`join_one`]: Self::join_one + pub fn spawn(&mut self, key: K, task: F) + where + F: Future, + F: Send + 'static, + V: Send, + { + let task = self.tasks.spawn(task); + self.insert(key, task) + } + + /// Spawn the provided task on the provided runtime and store it in this + /// `JoinMap` with the provided key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_one`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// [`join_one`]: Self::join_one + pub fn spawn_on(&mut self, key: K, task: F, handle: &Handle) + where + F: Future, + F: Send + 'static, + V: Send, + { + let task = self.tasks.spawn_on(task, handle); + self.insert(key, task); + } + + /// Spawn the provided task on the current [`LocalSet`] and store it in this + /// `JoinMap` with the provided key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_one`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// # Panics + /// + /// This method panics if it is called outside of a `LocalSet`. + /// + /// [`LocalSet`]: tokio::task::LocalSet + /// [`join_one`]: Self::join_one + pub fn spawn_local(&mut self, key: K, task: F) + where + F: Future, + F: 'static, + { + let task = self.tasks.spawn_local(task); + self.insert(key, task); + } + + /// Spawn the provided task on the provided [`LocalSet`] and store it in + /// this `JoinMap` with the provided key. + /// + /// If a task previously existed in the `JoinMap` for this key, that task + /// will be cancelled and replaced with the new one. The previous task will + /// be removed from the `JoinMap`; a subsequent call to [`join_one`] will + /// *not* return a cancelled [`JoinError`] for that task. + /// + /// [`LocalSet`]: tokio::task::LocalSet + /// [`join_one`]: Self::join_one + pub fn spawn_local_on(&mut self, key: K, task: F, local_set: &LocalSet) + where + F: Future, + F: 'static, + { + let task = self.tasks.spawn_local_on(task, local_set); + self.insert(key, task) + } + + fn insert(&mut self, key: K, abort: AbortHandle) { + let hash = self.hash(&key); + let id = abort.id(); + let map_key = Key { + id: id.clone(), + key, + }; + + // Insert the new key into the map of tasks by keys. + let entry = self + .tasks_by_key + .raw_entry_mut() + .from_hash(hash, |k| k.key == map_key.key); + match entry { + RawEntryMut::Occupied(mut occ) => { + // There was a previous task spawned with the same key! Cancel + // that task, and remove its ID from the map of hashes by task IDs. + let Key { id: prev_id, .. } = occ.insert_key(map_key); + occ.insert(abort).abort(); + let _prev_hash = self.hashes_by_task.remove(&prev_id); + debug_assert_eq!(Some(hash), _prev_hash); + } + RawEntryMut::Vacant(vac) => { + vac.insert(map_key, abort); + } + }; + + // Associate the key's hash with this task's ID, for looking up tasks by ID. + let _prev = self.hashes_by_task.insert(id, hash); + debug_assert!(_prev.is_none(), "no prior task should have had the same ID"); + } + + /// Waits until one of the tasks in the map completes and returns its + /// output, along with the key corresponding to that task. + /// + /// Returns `None` if the map is empty. + /// + /// # Cancel Safety + /// + /// This method is cancel safe. If `join_one` is used as the event in a [`tokio::select!`] + /// statement and some other branch completes first, it is guaranteed that no tasks were + /// removed from this `JoinMap`. + /// + /// # Returns + /// + /// This function returns: + /// + /// * `Some((key, Ok(value)))` if one of the tasks in this `JoinMap` has + /// completed. The `value` is the return value of that ask, and `key` is + /// the key associated with the task. + /// * `Some((key, Err(err))` if one of the tasks in this JoinMap` has + /// panicked or been aborted. `key` is the key associated with the task + /// that panicked or was aborted. + /// * `None` if the `JoinMap` is empty. + /// + /// [`tokio::select!`]: tokio::select + pub async fn join_one(&mut self) -> Option<(K, Result)> { + let (res, id) = match self.tasks.join_one_with_id().await { + Ok(task) => { + let (id, output) = task?; + (Ok(output), id) + } + Err(e) => { + let id = e.id(); + (Err(e), id) + } + }; + let key = self.remove_by_id(id)?; + Some((key, res)) + } + + /// Aborts all tasks and waits for them to finish shutting down. + /// + /// Calling this method is equivalent to calling [`abort_all`] and then calling [`join_one`] in + /// a loop until it returns `None`. + /// + /// This method ignores any panics in the tasks shutting down. When this call returns, the + /// `JoinMap` will be empty. + /// + /// [`abort_all`]: fn@Self::abort_all + /// [`join_one`]: fn@Self::join_one + pub async fn shutdown(&mut self) { + self.abort_all(); + while self.join_one().await.is_some() {} + } + + /// Abort the task corresponding to the provided `key`. + /// + /// If this `JoinMap` contains a task corresponding to `key`, this method + /// will abort that task and return `true`. Otherwise, if no task exists for + /// `key`, this method returns `false`. + /// + /// # Examples + /// + /// Aborting a task by key: + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut map = JoinMap::new(); + /// + /// map.spawn("hello world", async move { /* ... */ }); + /// map.spawn("goodbye world", async move { /* ... */}); + /// + /// // Look up the "goodbye world" task in the map and abort it. + /// map.abort("goodbye world"); + /// + /// while let Some((key, res)) = map.join_one().await { + /// if key == "goodbye world" { + /// // The aborted task should complete with a cancelled `JoinError`. + /// assert!(res.unwrap_err().is_cancelled()); + /// } else { + /// // Other tasks should complete normally. + /// assert!(res.is_ok()); + /// } + /// } + /// # } + /// ``` + /// + /// `abort` returns `true` if a task was aborted: + /// ``` + /// use tokio_util::task::JoinMap; + /// + /// # #[tokio::main] + /// # async fn main() { + /// let mut map = JoinMap::new(); + /// + /// map.spawn("hello world", async move { /* ... */ }); + /// map.spawn("goodbye world", async move { /* ... */}); + /// + /// // A task for the key "goodbye world" should exist in the map: + /// assert!(map.abort("goodbye world")); + /// + /// // Aborting a key that does not exist will return `false`: + /// assert!(!map.abort("goodbye universe")); + /// # } + /// ``` + pub fn abort(&mut self, key: &Q) -> bool + where + Q: Hash + Eq, + K: Borrow, + { + match self.get_by_key(key) { + Some((_, handle)) => { + handle.abort(); + true + } + None => false, + } + } + + /// Aborts all tasks with keys matching `predicate`. + /// + /// `predicate` is a function called with a reference to each key in the + /// map. If it returns `true` for a given key, the corresponding task will + /// be cancelled. + /// + /// # Examples + /// ``` + /// use tokio_util::task::JoinMap; + /// + /// # // use the current thread rt so that spawned tasks don't + /// # // complete in the background before they can be aborted. + /// # #[tokio::main(flavor = "current_thread")] + /// # async fn main() { + /// let mut map = JoinMap::new(); + /// + /// map.spawn("hello world", async move { + /// // ... + /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! + /// }); + /// map.spawn("goodbye world", async move { + /// // ... + /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! + /// }); + /// map.spawn("hello san francisco", async move { + /// // ... + /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! + /// }); + /// map.spawn("goodbye universe", async move { + /// // ... + /// # tokio::task::yield_now().await; // don't complete immediately, get aborted! + /// }); + /// + /// // Abort all tasks whose keys begin with "goodbye" + /// map.abort_matching(|key| key.starts_with("goodbye")); + /// + /// let mut seen = 0; + /// while let Some((key, res)) = map.join_one().await { + /// seen += 1; + /// if key.starts_with("goodbye") { + /// // The aborted task should complete with a cancelled `JoinError`. + /// assert!(res.unwrap_err().is_cancelled()); + /// } else { + /// // Other tasks should complete normally. + /// assert!(key.starts_with("hello")); + /// assert!(res.is_ok()); + /// } + /// } + /// + /// // All spawned tasks should have completed. + /// assert_eq!(seen, 4); + /// # } + /// ``` + pub fn abort_matching(&mut self, mut predicate: impl FnMut(&K) -> bool) { + // Note: this method iterates over the tasks and keys *without* removing + // any entries, so that the keys from aborted tasks can still be + // returned when calling `join_one` in the future. + for (Key { ref key, .. }, task) in &self.tasks_by_key { + if predicate(key) { + task.abort(); + } + } + } + + /// Returns `true` if this `JoinMap` contains a task for the provided key. + /// + /// If the task has completed, but its output hasn't yet been consumed by a + /// call to [`join_one`], this method will still return `true`. + /// + /// [`join_one`]: fn@Self::join_one + pub fn contains_key(&self, key: &Q) -> bool + where + Q: Hash + Eq, + K: Borrow, + { + self.get_by_key(key).is_some() + } + + /// Returns `true` if this `JoinMap` contains a task with the provided + /// [task ID]. + /// + /// If the task has completed, but its output hasn't yet been consumed by a + /// call to [`join_one`], this method will still return `true`. + /// + /// [`join_one`]: fn@Self::join_one + /// [task ID]: tokio::task::Id + pub fn contains_task(&self, task: &Id) -> bool { + self.get_by_id(task).is_some() + } + + /// Reserves capacity for at least `additional` more tasks to be spawned + /// on this `JoinMap` without reallocating for the map of task keys. The + /// collection may reserve more space to avoid frequent reallocations. + /// + /// Note that spawning a task will still cause an allocation for the task + /// itself. + /// + /// # Panics + /// + /// Panics if the new allocation size overflows [`usize`]. + /// + /// # Examples + /// + /// ``` + /// use tokio_util::task::JoinMap; + /// + /// let mut map: JoinMap<&str, i32> = JoinMap::new(); + /// map.reserve(10); + /// ``` + #[inline] + pub fn reserve(&mut self, additional: usize) { + self.tasks_by_key.reserve(additional); + self.hashes_by_task.reserve(additional); + } + + /// Shrinks the capacity of the `JoinMap` as much as possible. It will drop + /// down as much as possible while maintaining the internal rules + /// and possibly leaving some space in accordance with the resize policy. + /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio_util::task::JoinMap; + /// + /// let mut map: JoinMap = JoinMap::with_capacity(100); + /// map.spawn(1, async move { 2 }); + /// map.spawn(3, async move { 4 }); + /// assert!(map.capacity() >= 100); + /// map.shrink_to_fit(); + /// assert!(map.capacity() >= 2); + /// # } + /// ``` + #[inline] + pub fn shrink_to_fit(&mut self) { + self.hashes_by_task.shrink_to_fit(); + self.tasks_by_key.shrink_to_fit(); + } + + /// Shrinks the capacity of the map with a lower limit. It will drop + /// down no lower than the supplied limit while maintaining the internal rules + /// and possibly leaving some space in accordance with the resize policy. + /// + /// If the current capacity is less than the lower limit, this is a no-op. + /// + /// # Examples + /// + /// ``` + /// # #[tokio::main] + /// # async fn main() { + /// use tokio_util::task::JoinMap; + /// + /// let mut map: JoinMap = JoinMap::with_capacity(100); + /// map.spawn(1, async move { 2 }); + /// map.spawn(3, async move { 4 }); + /// assert!(map.capacity() >= 100); + /// map.shrink_to(10); + /// assert!(map.capacity() >= 10); + /// map.shrink_to(0); + /// assert!(map.capacity() >= 2); + /// # } + /// ``` + #[inline] + pub fn shrink_to(&mut self, min_capacity: usize) { + self.hashes_by_task.shrink_to(min_capacity); + self.tasks_by_key.shrink_to(min_capacity) + } + + /// Look up a task in the map by its key, returning the key and abort handle. + fn get_by_key<'map, Q: ?Sized>(&'map self, key: &Q) -> Option<(&'map Key, &'map AbortHandle)> + where + Q: Hash + Eq, + K: Borrow, + { + let hash = self.hash(key); + self.tasks_by_key + .raw_entry() + .from_hash(hash, |k| k.key.borrow() == key) + } + + /// Look up a task in the map by its task ID, returning the key and abort handle. + fn get_by_id<'map>(&'map self, id: &Id) -> Option<(&'map Key, &'map AbortHandle)> { + let hash = self.hashes_by_task.get(id)?; + self.tasks_by_key + .raw_entry() + .from_hash(*hash, |k| &k.id == id) + } + + /// Remove a task from the map by ID, returning the key for that task. + fn remove_by_id(&mut self, id: Id) -> Option { + // Get the hash for the given ID. + let hash = self.hashes_by_task.remove(&id)?; + + // Remove the entry for that hash. + let entry = self + .tasks_by_key + .raw_entry_mut() + .from_hash(hash, |k| k.id == id); + let (Key { id: _key_id, key }, handle) = match entry { + RawEntryMut::Occupied(entry) => entry.remove_entry(), + _ => return None, + }; + debug_assert_eq!(_key_id, id); + debug_assert_eq!(id, handle.id()); + self.hashes_by_task.remove(&id); + Some(key) + } + + /// Returns the hash for a given key. + #[inline] + fn hash(&self, key: &Q) -> u64 + where + Q: Hash, + { + let mut hasher = self.tasks_by_key.hasher().build_hasher(); + key.hash(&mut hasher); + hasher.finish() + } +} + +impl JoinMap +where + V: 'static, +{ + /// Aborts all tasks on this `JoinMap`. + /// + /// This does not remove the tasks from the `JoinMap`. To wait for the tasks to complete + /// cancellation, you should call `join_one` in a loop until the `JoinMap` is empty. + pub fn abort_all(&mut self) { + self.tasks.abort_all() + } + + /// Removes all tasks from this `JoinMap` without aborting them. + /// + /// The tasks removed by this call will continue to run in the background even if the `JoinMap` + /// is dropped. They may still be aborted by key. + pub fn detach_all(&mut self) { + self.tasks.detach_all(); + self.tasks_by_key.clear(); + self.hashes_by_task.clear(); + } +} + +// Hand-written `fmt::Debug` implementation in order to avoid requiring `V: +// Debug`, since no value is ever actually stored in the map. +impl fmt::Debug for JoinMap { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // format the task keys and abort handles a little nicer by just + // printing the key and task ID pairs, without format the `Key` struct + // itself or the `AbortHandle`, which would just format the task's ID + // again. + struct KeySet<'a, K: fmt::Debug, S>(&'a HashMap, AbortHandle, S>); + impl fmt::Debug for KeySet<'_, K, S> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_map() + .entries(self.0.keys().map(|Key { key, id }| (key, id))) + .finish() + } + } + + f.debug_struct("JoinMap") + // The `tasks_by_key` map is the only one that contains information + // that's really worth formatting for the user, since it contains + // the tasks' keys and IDs. The other fields are basically + // implementation details. + .field("tasks", &KeySet(&self.tasks_by_key)) + .finish() + } +} + +impl Default for JoinMap { + fn default() -> Self { + Self::new() + } +} + +// === impl Key === + +impl Hash for Key { + // Don't include the task ID in the hash. + #[inline] + fn hash(&self, hasher: &mut H) { + self.key.hash(hasher); + } +} + +// Because we override `Hash` for this type, we must also override the +// `PartialEq` impl, so that all instances with the same hash are equal. +impl PartialEq for Key { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.key == other.key + } +} + +impl Eq for Key {} diff --git a/tokio-util/src/task/mod.rs b/tokio-util/src/task/mod.rs index 5aa33df2dc0..7ba8ad9a218 100644 --- a/tokio-util/src/task/mod.rs +++ b/tokio-util/src/task/mod.rs @@ -1,4 +1,10 @@ //! Extra utilities for spawning tasks +#[cfg(tokio_unstable)] +mod join_map; mod spawn_pinned; pub use spawn_pinned::LocalPoolHandle; + +#[cfg(tokio_unstable)] +#[cfg_attr(docsrs, doc(cfg(all(tokio_unstable, feature = "rt"))))] +pub use join_map::JoinMap; diff --git a/tokio-util/tests/task_join_map.rs b/tokio-util/tests/task_join_map.rs new file mode 100644 index 00000000000..d5f87bfb185 --- /dev/null +++ b/tokio-util/tests/task_join_map.rs @@ -0,0 +1,275 @@ +#![warn(rust_2018_idioms)] +#![cfg(all(feature = "rt", tokio_unstable))] + +use tokio::sync::oneshot; +use tokio::time::Duration; +use tokio_util::task::JoinMap; + +use futures::future::FutureExt; + +fn rt() -> tokio::runtime::Runtime { + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap() +} + +#[tokio::test(start_paused = true)] +async fn test_with_sleep() { + let mut map = JoinMap::new(); + + for i in 0..10 { + map.spawn(i, async move { i }); + assert_eq!(map.len(), 1 + i); + } + map.detach_all(); + assert_eq!(map.len(), 0); + + assert!(matches!(map.join_one().await, None)); + + for i in 0..10 { + map.spawn(i, async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + i + }); + assert_eq!(map.len(), 1 + i); + } + + let mut seen = [false; 10]; + while let Some((k, res)) = map.join_one().await { + seen[k] = true; + assert_eq!(res.expect("task should have completed successfully"), k); + } + + for was_seen in &seen { + assert!(was_seen); + } + assert!(matches!(map.join_one().await, None)); + + // Do it again. + for i in 0..10 { + map.spawn(i, async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + i + }); + } + + let mut seen = [false; 10]; + while let Some((k, res)) = map.join_one().await { + seen[k] = true; + assert_eq!(res.expect("task should have completed successfully"), k); + } + + for was_seen in &seen { + assert!(was_seen); + } + assert!(matches!(map.join_one().await, None)); +} + +#[tokio::test] +async fn test_abort_on_drop() { + let mut map = JoinMap::new(); + + let mut recvs = Vec::new(); + + for i in 0..16 { + let (send, recv) = oneshot::channel::<()>(); + recvs.push(recv); + + map.spawn(i, async { + // This task will never complete on its own. + futures::future::pending::<()>().await; + drop(send); + }); + } + + drop(map); + + for recv in recvs { + // The task is aborted soon and we will receive an error. + assert!(recv.await.is_err()); + } +} + +#[tokio::test] +async fn alternating() { + let mut map = JoinMap::new(); + + assert_eq!(map.len(), 0); + map.spawn(1, async {}); + assert_eq!(map.len(), 1); + map.spawn(2, async {}); + assert_eq!(map.len(), 2); + + for i in 0..16 { + let (_, res) = map.join_one().await.unwrap(); + assert!(res.is_ok()); + assert_eq!(map.len(), 1); + map.spawn(i, async {}); + assert_eq!(map.len(), 2); + } +} + +#[tokio::test(start_paused = true)] +async fn abort_by_key() { + let mut map = JoinMap::new(); + let mut num_canceled = 0; + let mut num_completed = 0; + for i in 0..16 { + map.spawn(i, async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + }); + } + + for i in 0..16 { + if i % 2 != 0 { + // abort odd-numbered tasks. + map.abort(&i); + } + } + + while let Some((key, res)) = map.join_one().await { + match res { + Ok(()) => { + num_completed += 1; + assert_eq!(key % 2, 0); + assert!(!map.contains_key(&key)); + } + Err(e) => { + num_canceled += 1; + assert!(e.is_cancelled()); + assert_ne!(key % 2, 0); + assert!(!map.contains_key(&key)); + } + } + } + + assert_eq!(num_canceled, 8); + assert_eq!(num_completed, 8); +} + +#[tokio::test(start_paused = true)] +async fn abort_by_predicate() { + let mut map = JoinMap::new(); + let mut num_canceled = 0; + let mut num_completed = 0; + for i in 0..16 { + map.spawn(i, async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + }); + } + + // abort odd-numbered tasks. + map.abort_matching(|key| key % 2 != 0); + + while let Some((key, res)) = map.join_one().await { + match res { + Ok(()) => { + num_completed += 1; + assert_eq!(key % 2, 0); + assert!(!map.contains_key(&key)); + } + Err(e) => { + num_canceled += 1; + assert!(e.is_cancelled()); + assert_ne!(key % 2, 0); + assert!(!map.contains_key(&key)); + } + } + } + + assert_eq!(num_canceled, 8); + assert_eq!(num_completed, 8); +} + +#[test] +fn runtime_gone() { + let mut map = JoinMap::new(); + { + let rt = rt(); + map.spawn_on("key", async { 1 }, rt.handle()); + drop(rt); + } + + let (key, res) = rt().block_on(map.join_one()).unwrap(); + assert_eq!(key, "key"); + assert!(res.unwrap_err().is_cancelled()); +} + +// This ensures that `join_one` works correctly when the coop budget is +// exhausted. +#[tokio::test(flavor = "current_thread")] +async fn join_map_coop() { + // Large enough to trigger coop. + const TASK_NUM: u32 = 1000; + + static SEM: tokio::sync::Semaphore = tokio::sync::Semaphore::const_new(0); + + let mut map = JoinMap::new(); + + for i in 0..TASK_NUM { + map.spawn(i, async move { + SEM.add_permits(1); + i + }); + } + + // Wait for all tasks to complete. + // + // Since this is a `current_thread` runtime, there's no race condition + // between the last permit being added and the task completing. + let _ = SEM.acquire_many(TASK_NUM).await.unwrap(); + + let mut count = 0; + let mut coop_count = 0; + loop { + match map.join_one().now_or_never() { + Some(Some((key, Ok(i)))) => assert_eq!(key, i), + Some(Some((key, Err(err)))) => panic!("failed[{}]: {}", key, err), + None => { + coop_count += 1; + tokio::task::yield_now().await; + continue; + } + Some(None) => break, + } + + count += 1; + } + assert!(coop_count >= 1); + assert_eq!(count, TASK_NUM); +} + +#[tokio::test(start_paused = true)] +async fn abort_all() { + let mut map: JoinMap = JoinMap::new(); + + for i in 0..5 { + map.spawn(i, futures::future::pending()); + } + for i in 5..10 { + map.spawn(i, async { + tokio::time::sleep(Duration::from_secs(1)).await; + }); + } + + // The join map will now have 5 pending tasks and 5 ready tasks. + tokio::time::sleep(Duration::from_secs(2)).await; + + map.abort_all(); + assert_eq!(map.len(), 10); + + let mut count = 0; + let mut seen = [false; 10]; + while let Some((k, res)) = map.join_one().await { + seen[k] = true; + if let Err(err) = res { + assert!(err.is_cancelled()); + } + count += 1; + } + assert_eq!(count, 10); + assert_eq!(map.len(), 0); + for was_seen in &seen { + assert!(was_seen); + } +} diff --git a/tokio/CHANGELOG.md b/tokio/CHANGELOG.md index 3f69f09364f..05f4d8152e0 100644 --- a/tokio/CHANGELOG.md +++ b/tokio/CHANGELOG.md @@ -1,3 +1,124 @@ +# 1.18.5 (January 17, 2023) + +### Fixed + +- io: fix unsoundness in `ReadHalf::unsplit` ([#5375]) + +[#5375]: https://github.com/tokio-rs/tokio/pull/5375 + +# 1.18.4 (January 3, 2022) + +### Fixed + +- net: fix Windows named pipe server builder to maintain option when toggling + pipe mode ([#5336]). + +[#5336]: https://github.com/tokio-rs/tokio/pull/5336 + +# 1.18.3 (September 27, 2022) + +This release removes the dependency on the `once_cell` crate to restore the MSRV +of the 1.18.x LTS release. ([#5048]) + +[#5048]: https://github.com/tokio-rs/tokio/pull/5048 + +# 1.18.2 (May 5, 2022) + +Add missing features for the `winapi` dependency. ([#4663]) + +[#4663]: https://github.com/tokio-rs/tokio/pull/4663 + +# 1.18.1 (May 2, 2022) + +The 1.18.0 release broke the build for targets without 64-bit atomics when +building with `tokio_unstable`. This release fixes that. ([#4649]) + +[#4649]: https://github.com/tokio-rs/tokio/pull/4649 + +# 1.18.0 (April 27, 2022) + +This release adds a number of new APIs in `tokio::net`, `tokio::signal`, and +`tokio::sync`. In addition, it adds new unstable APIs to `tokio::task` (`Id`s +for uniquely identifying a task, and `AbortHandle` for remotely cancelling a +task), as well as a number of bugfixes. + +### Fixed + +- blocking: add missing `#[track_caller]` for `spawn_blocking` ([#4616]) +- macros: fix `select` macro to process 64 branches ([#4519]) +- net: fix `try_io` methods not calling Mio's `try_io` internally ([#4582]) +- runtime: recover when OS fails to spawn a new thread ([#4485]) + +### Added + +- macros: support setting a custom crate name for `#[tokio::main]` and + `#[tokio::test]` ([#4613]) +- net: add `UdpSocket::peer_addr` ([#4611]) +- net: add `try_read_buf` method for named pipes ([#4626]) +- signal: add `SignalKind` `Hash`/`Eq` impls and `c_int` conversion ([#4540]) +- signal: add support for signals up to `SIGRTMAX` ([#4555]) +- sync: add `watch::Sender::send_modify` method ([#4310]) +- sync: add `broadcast::Receiver::len` method ([#4542]) +- sync: add `watch::Receiver::same_channel` method ([#4581]) +- sync: implement `Clone` for `RecvError` types ([#4560]) + +### Changed + +- update `mio` to 0.8.1 ([#4582]) +- macros: rename `tokio::select!`'s internal `util` module ([#4543]) +- runtime: use `Vec::with_capacity` when building runtime ([#4553]) + +### Documented + +- improve docs for `tokio_unstable` ([#4524]) +- runtime: include more documentation for thread_pool/worker ([#4511]) +- runtime: update `Handle::current`'s docs to mention `EnterGuard` ([#4567]) +- time: clarify platform specific timer resolution ([#4474]) +- signal: document that `Signal::recv` is cancel-safe ([#4634]) +- sync: `UnboundedReceiver` close docs ([#4548]) + +### Unstable + +The following changes only apply when building with `--cfg tokio_unstable`: + +- task: add `task::Id` type ([#4630]) +- task: add `AbortHandle` type for cancelling tasks in a `JoinSet` ([#4530], + [#4640]) +- task: fix missing `doc(cfg(...))` attributes for `JoinSet` ([#4531]) +- task: fix broken link in `AbortHandle` RustDoc ([#4545]) +- metrics: add initial IO driver metrics ([#4507]) + + +[#4616]: https://github.com/tokio-rs/tokio/pull/4616 +[#4519]: https://github.com/tokio-rs/tokio/pull/4519 +[#4582]: https://github.com/tokio-rs/tokio/pull/4582 +[#4485]: https://github.com/tokio-rs/tokio/pull/4485 +[#4613]: https://github.com/tokio-rs/tokio/pull/4613 +[#4611]: https://github.com/tokio-rs/tokio/pull/4611 +[#4626]: https://github.com/tokio-rs/tokio/pull/4626 +[#4540]: https://github.com/tokio-rs/tokio/pull/4540 +[#4555]: https://github.com/tokio-rs/tokio/pull/4555 +[#4310]: https://github.com/tokio-rs/tokio/pull/4310 +[#4542]: https://github.com/tokio-rs/tokio/pull/4542 +[#4581]: https://github.com/tokio-rs/tokio/pull/4581 +[#4560]: https://github.com/tokio-rs/tokio/pull/4560 +[#4631]: https://github.com/tokio-rs/tokio/pull/4631 +[#4582]: https://github.com/tokio-rs/tokio/pull/4582 +[#4543]: https://github.com/tokio-rs/tokio/pull/4543 +[#4553]: https://github.com/tokio-rs/tokio/pull/4553 +[#4524]: https://github.com/tokio-rs/tokio/pull/4524 +[#4511]: https://github.com/tokio-rs/tokio/pull/4511 +[#4567]: https://github.com/tokio-rs/tokio/pull/4567 +[#4474]: https://github.com/tokio-rs/tokio/pull/4474 +[#4634]: https://github.com/tokio-rs/tokio/pull/4634 +[#4548]: https://github.com/tokio-rs/tokio/pull/4548 +[#4630]: https://github.com/tokio-rs/tokio/pull/4630 +[#4530]: https://github.com/tokio-rs/tokio/pull/4530 +[#4640]: https://github.com/tokio-rs/tokio/pull/4640 +[#4531]: https://github.com/tokio-rs/tokio/pull/4531 +[#4545]: https://github.com/tokio-rs/tokio/pull/4545 +[#4507]: https://github.com/tokio-rs/tokio/pull/4507 + # 1.17.0 (February 16, 2022) This release updates the minimum supported Rust version (MSRV) to 1.49, the @@ -23,7 +144,7 @@ performance improvements. - time: use bit manipulation instead of modulo to improve performance ([#4480]) - net: use `std::future::Ready` instead of our own `Ready` future ([#4271]) - replace deprecated `atomic::spin_loop_hint` with `hint::spin_loop` ([#4491]) -- fix miri failures in intrusive linked lists ([#4397]) +- fix miri failures in intrusive linked lists ([#4397]) ### Documented diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index ba165d2556c..3e6b95784b6 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -6,7 +6,7 @@ name = "tokio" # - README.md # - Update CHANGELOG.md. # - Create "v1.0.x" git tag. -version = "1.17.0" +version = "1.18.5" edition = "2018" rust-version = "1.49" authors = ["Tokio Contributors "] @@ -56,7 +56,6 @@ net = [ ] process = [ "bytes", - "once_cell", "libc", "mio/os-poll", "mio/os-ext", @@ -71,7 +70,6 @@ rt-multi-thread = [ "rt", ] signal = [ - "once_cell", "libc", "mio/os-poll", "mio/net", @@ -95,9 +93,8 @@ pin-project-lite = "0.2.0" # Everything else is optional... bytes = { version = "1.0.0", optional = true } -once_cell = { version = "1.5.2", optional = true } memchr = { version = "2.2", optional = true } -mio = { version = "0.8.0", optional = true } +mio = { version = "0.8.1", optional = true } socket2 = { version = "0.4.4", optional = true, features = [ "all" ] } num_cpus = { version = "1.8.0", optional = true } parking_lot = { version = "0.12.0", optional = true } @@ -113,11 +110,12 @@ signal-hook-registry = { version = "1.1.1", optional = true } [target.'cfg(unix)'.dev-dependencies] libc = { version = "0.2.42" } -nix = { version = "0.23" } +nix = { version = "0.24", default-features = false, features = ["fs", "socket"] } [target.'cfg(windows)'.dependencies.winapi] version = "0.3.8" default-features = false +features = ["std", "winsock2", "mswsock", "handleapi", "ws2ipdef", "ws2tcpip"] optional = true [target.'cfg(windows)'.dev-dependencies.ntapi] diff --git a/tokio/README.md b/tokio/README.md index 1cce34aeeff..46b1e089cfd 100644 --- a/tokio/README.md +++ b/tokio/README.md @@ -56,7 +56,7 @@ Make sure you activated the full features of the tokio crate on Cargo.toml: ```toml [dependencies] -tokio = { version = "1.17.0", features = ["full"] } +tokio = { version = "1.18.5", features = ["full"] } ``` Then, on your main.rs: diff --git a/tokio/src/io/driver/metrics.rs b/tokio/src/io/driver/metrics.rs new file mode 100644 index 00000000000..410732ce7dd --- /dev/null +++ b/tokio/src/io/driver/metrics.rs @@ -0,0 +1,22 @@ +//! This file contains mocks of the metrics types used in the I/O driver. +//! +//! The reason these mocks don't live in `src/runtime/mock.rs` is because +//! these need to be available in the case when `net` is enabled but +//! `rt` is not. + +cfg_not_rt_and_metrics! { + #[derive(Default)] + pub(crate) struct IoDriverMetrics {} + + impl IoDriverMetrics { + pub(crate) fn incr_fd_count(&self) {} + pub(crate) fn dec_fd_count(&self) {} + pub(crate) fn incr_ready_count_by(&self, _amt: u64) {} + } +} + +cfg_rt! { + cfg_metrics! { + pub(crate) use crate::runtime::IoDriverMetrics; + } +} diff --git a/tokio/src/io/driver/mod.rs b/tokio/src/io/driver/mod.rs index 19f67a24e7f..24939ca0e5c 100644 --- a/tokio/src/io/driver/mod.rs +++ b/tokio/src/io/driver/mod.rs @@ -14,10 +14,14 @@ pub(crate) use registration::Registration; mod scheduled_io; use scheduled_io::ScheduledIo; +mod metrics; + use crate::park::{Park, Unpark}; use crate::util::slab::{self, Slab}; use crate::{loom::sync::Mutex, util::bit}; +use metrics::IoDriverMetrics; + use std::fmt; use std::io; use std::sync::{Arc, Weak}; @@ -74,6 +78,8 @@ pub(super) struct Inner { /// Used to wake up the reactor from a call to `turn`. waker: mio::Waker, + + metrics: IoDriverMetrics, } #[derive(Debug, Eq, PartialEq, Clone, Copy)] @@ -130,6 +136,7 @@ impl Driver { registry, io_dispatch: allocator, waker, + metrics: IoDriverMetrics::default(), }), }) } @@ -167,14 +174,18 @@ impl Driver { } // Process all the events that came in, dispatching appropriately + let mut ready_count = 0; for event in events.iter() { let token = event.token(); if token != TOKEN_WAKEUP { self.dispatch(token, Ready::from_mio(event)); + ready_count += 1; } } + self.inner.metrics.incr_ready_count_by(ready_count); + self.events = Some(events); Ok(()) @@ -279,6 +290,21 @@ cfg_not_rt! { } } +cfg_metrics! { + impl Handle { + // TODO: Remove this when handle contains `Arc` so that we can return + // &IoDriverMetrics instead of using a closure. + // + // Related issue: https://github.com/tokio-rs/tokio/issues/4509 + pub(crate) fn with_io_driver_metrics(&self, f: F) -> Option + where + F: Fn(&IoDriverMetrics) -> R, + { + self.inner().map(|inner| f(&inner.metrics)) + } + } +} + impl Handle { /// Forces a reactor blocked in a call to `turn` to wakeup, or otherwise /// makes the next call to `turn` return immediately. @@ -335,12 +361,18 @@ impl Inner { self.registry .register(source, mio::Token(token), interest.to_mio())?; + self.metrics.incr_fd_count(); + Ok(shared) } /// Deregisters an I/O resource from the reactor. pub(super) fn deregister_source(&self, source: &mut impl mio::event::Source) -> io::Result<()> { - self.registry.deregister(source) + self.registry.deregister(source)?; + + self.metrics.dec_fd_count(); + + Ok(()) } } diff --git a/tokio/src/io/split.rs b/tokio/src/io/split.rs index 8258a0f7a08..a3aa9d60c07 100644 --- a/tokio/src/io/split.rs +++ b/tokio/src/io/split.rs @@ -74,7 +74,10 @@ impl ReadHalf { /// same `split` operation this method will panic. /// This can be checked ahead of time by comparing the stream ID /// of the two halves. - pub fn unsplit(self, wr: WriteHalf) -> T { + pub fn unsplit(self, wr: WriteHalf) -> T + where + T: Unpin, + { if self.is_pair_of(&wr) { drop(wr); diff --git a/tokio/src/io/stdio_common.rs b/tokio/src/io/stdio_common.rs index 7e4a198a82b..2715ba7923a 100644 --- a/tokio/src/io/stdio_common.rs +++ b/tokio/src/io/stdio_common.rs @@ -42,7 +42,7 @@ where // for further code. Since `AsyncWrite` can always shrink // buffer at its discretion, excessive (i.e. in tests) shrinking // does not break correctness. - // 2. If buffer is small, it will not be shrinked. + // 2. If buffer is small, it will not be shrunk. // That's why, it's "textness" will not change, so we don't have // to fixup it. if cfg!(not(any(target_os = "windows", test))) || buf.len() <= crate::io::blocking::MAX_BUF @@ -193,7 +193,7 @@ mod tests { fn test_pseudo_text() { // In this test we write a piece of binary data, whose beginning is // text though. We then validate that even in this corner case buffer - // was not shrinked too much. + // was not shrunk too much. let checked_count = super::MAGIC_CONST * super::MAX_BYTES_PER_CHAR; let mut data: Vec = str::repeat("a", checked_count).into(); data.extend(std::iter::repeat(0b1010_1010).take(MAX_BUF - checked_count + 1)); @@ -212,7 +212,7 @@ mod tests { writer.write_history.iter().copied().sum::(), data.len() ); - // Check that at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes were shrinked + // Check that at most MAX_BYTES_PER_CHAR + 1 (i.e. 5) bytes were shrunk // from the buffer: one because it was outside of MAX_BUF boundary, and // up to one "utf8 code point". assert!(data.len() - writer.write_history[0] <= super::MAX_BYTES_PER_CHAR + 1); diff --git a/tokio/src/lib.rs b/tokio/src/lib.rs index 35295d837a6..27d4dc83855 100644 --- a/tokio/src/lib.rs +++ b/tokio/src/lib.rs @@ -340,13 +340,43 @@ //! //! ### Unstable features //! -//! These feature flags enable **unstable** features. The public API may break in 1.x -//! releases. To enable these features, the `--cfg tokio_unstable` must be passed to -//! `rustc` when compiling. This is easiest done using the `RUSTFLAGS` env variable: -//! `RUSTFLAGS="--cfg tokio_unstable"`. +//! Some feature flags are only available when specifying the `tokio_unstable` flag: //! //! - `tracing`: Enables tracing events. //! +//! Likewise, some parts of the API are only available with the same flag: +//! +//! - [`task::JoinSet`] +//! - [`task::Builder`] +//! +//! This flag enables **unstable** features. The public API of these features +//! may break in 1.x releases. To enable these features, the `--cfg +//! tokio_unstable` argument must be passed to `rustc` when compiling. This +//! serves to explicitly opt-in to features which may break semver conventions, +//! since Cargo [does not yet directly support such opt-ins][unstable features]. +//! +//! You can specify it in your project's `.cargo/config.toml` file: +//! +//! ```toml +//! [build] +//! rustflags = ["--cfg", "tokio_unstable"] +//! ``` +//! +//! Alternatively, you can specify it with an environment variable: +//! +//! ```sh +//! ## Many *nix shells: +//! export RUSTFLAGS="--cfg tokio_unstable" +//! cargo build +//! ``` +//! +//! ```powershell +//! ## Windows PowerShell: +//! $Env:RUSTFLAGS="--cfg tokio_unstable" +//! cargo build +//! ``` +//! +//! [unstable features]: https://internals.rust-lang.org/t/feature-request-unstable-opt-in-non-transitive-crate-features/16193#why-not-a-crate-feature-2 //! [feature flags]: https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-section // Test that pointer width is compatible. This asserts that e.g. usize is at @@ -483,7 +513,7 @@ pub(crate) use self::doc::winapi; #[cfg(all(not(docsrs), windows, feature = "net"))] #[allow(unused)] -pub(crate) use ::winapi; +pub(crate) use winapi; cfg_macros! { /// Implementation detail of the `select!` macro. This macro is **not** diff --git a/tokio/src/loom/std/atomic_u64.rs b/tokio/src/loom/std/atomic_u64.rs index 113992d9775..ac20f352943 100644 --- a/tokio/src/loom/std/atomic_u64.rs +++ b/tokio/src/loom/std/atomic_u64.rs @@ -75,4 +75,12 @@ cfg_not_has_atomic_u64! { self.compare_exchange(current, new, success, failure) } } + + impl Default for AtomicU64 { + fn default() -> AtomicU64 { + Self { + inner: Mutex::new(0), + } + } + } } diff --git a/tokio/src/loom/std/parking_lot.rs b/tokio/src/loom/std/parking_lot.rs index 034a0ce69a5..e3af258d116 100644 --- a/tokio/src/loom/std/parking_lot.rs +++ b/tokio/src/loom/std/parking_lot.rs @@ -52,7 +52,7 @@ impl Mutex { } #[inline] - #[cfg(all(feature = "parking_lot", not(all(loom, test)),))] + #[cfg(all(feature = "parking_lot", not(all(loom, test))))] #[cfg_attr(docsrs, doc(cfg(all(feature = "parking_lot",))))] pub(crate) const fn const_new(t: T) -> Mutex { Mutex(PhantomData, parking_lot::const_mutex(t)) diff --git a/tokio/src/macros/cfg.rs b/tokio/src/macros/cfg.rs index b6beb3d6952..608eef08cea 100644 --- a/tokio/src/macros/cfg.rs +++ b/tokio/src/macros/cfg.rs @@ -195,6 +195,12 @@ macro_rules! cfg_not_metrics { } } +macro_rules! cfg_not_rt_and_metrics { + ($($item:item)*) => { + $( #[cfg(not(all(feature = "rt", all(tokio_unstable, not(loom)))))] $item )* + } +} + macro_rules! cfg_net { ($($item:item)*) => { $( diff --git a/tokio/src/macros/select.rs b/tokio/src/macros/select.rs index 051f8cb72a8..f38aee0f20b 100644 --- a/tokio/src/macros/select.rs +++ b/tokio/src/macros/select.rs @@ -101,6 +101,7 @@ /// * [`tokio::sync::watch::Receiver::changed`](crate::sync::watch::Receiver::changed) /// * [`tokio::net::TcpListener::accept`](crate::net::TcpListener::accept) /// * [`tokio::net::UnixListener::accept`](crate::net::UnixListener::accept) +/// * [`tokio::signal::unix::Signal::recv`](crate::signal::unix::Signal::recv) /// * [`tokio::io::AsyncReadExt::read`](crate::io::AsyncReadExt::read) on any `AsyncRead` /// * [`tokio::io::AsyncReadExt::read_buf`](crate::io::AsyncReadExt::read_buf) on any `AsyncRead` /// * [`tokio::io::AsyncWriteExt::write`](crate::io::AsyncWriteExt::write) on any `AsyncWrite` @@ -429,7 +430,8 @@ macro_rules! select { // // This module is defined within a scope and should not leak out of this // macro. - mod util { + #[doc(hidden)] + mod __tokio_select_util { // Generate an enum with one variant per select branch $crate::select_priv_declare_output_enum!( ( $($count)* ) ); } @@ -442,13 +444,13 @@ macro_rules! select { const BRANCHES: u32 = $crate::count!( $($count)* ); - let mut disabled: util::Mask = Default::default(); + let mut disabled: __tokio_select_util::Mask = Default::default(); // First, invoke all the pre-conditions. For any that return true, // set the appropriate bit in `disabled`. $( if !$c { - let mask: util::Mask = 1 << $crate::count!( $($skip)* ); + let mask: __tokio_select_util::Mask = 1 << $crate::count!( $($skip)* ); disabled |= mask; } )* @@ -525,7 +527,7 @@ macro_rules! select { } // The select is complete, return the value - return Ready($crate::select_variant!(util::Out, ($($skip)*))(out)); + return Ready($crate::select_variant!(__tokio_select_util::Out, ($($skip)*))(out)); } )* _ => unreachable!("reaching this means there probably is an off by one bug"), @@ -536,16 +538,16 @@ macro_rules! select { Pending } else { // All branches have been disabled. - Ready(util::Out::Disabled) + Ready(__tokio_select_util::Out::Disabled) } }).await }; match output { $( - $crate::select_variant!(util::Out, ($($skip)*) ($bind)) => $handle, + $crate::select_variant!(__tokio_select_util::Out, ($($skip)*) ($bind)) => $handle, )* - util::Out::Disabled => $else, + __tokio_select_util::Out::Disabled => $else, _ => unreachable!("failed to match bind"), } }}; @@ -801,6 +803,9 @@ macro_rules! count { (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { 63 }; + (_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _) => { + 64 + }; } #[macro_export] diff --git a/tokio/src/net/addr.rs b/tokio/src/net/addr.rs index 13f743c9628..36da860194d 100644 --- a/tokio/src/net/addr.rs +++ b/tokio/src/net/addr.rs @@ -136,6 +136,15 @@ impl sealed::ToSocketAddrsPriv for &[SocketAddr] { type Future = ReadyFuture; fn to_socket_addrs(&self, _: sealed::Internal) -> Self::Future { + // Clippy doesn't like the `to_vec()` call here (as it will allocate, + // while `self.iter().copied()` would not), but it's actually necessary + // in order to ensure that the returned iterator is valid for the + // `'static` lifetime, which the borrowed `slice::Iter` iterator would + // not be. + // Note that we can't actually add an `allow` attribute for + // `clippy::unnecessary_to_owned` here, as Tokio's CI runs clippy lints + // on Rust 1.52 to avoid breaking LTS releases of Tokio. Users of newer + // Rust versions who see this lint should just ignore it. let iter = self.to_vec().into_iter(); future::ready(Ok(iter)) } diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index ebb67b84d16..5228e1bc242 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -968,7 +968,9 @@ impl TcpStream { interest: Interest, f: impl FnOnce() -> io::Result, ) -> io::Result { - self.io.registration().try_io(interest, f) + self.io + .registration() + .try_io(interest, || self.io.try_io(f)) } /// Receives data on the socket from the remote address to which it is diff --git a/tokio/src/net/udp.rs b/tokio/src/net/udp.rs index 12af5152c28..bd905e91a5f 100644 --- a/tokio/src/net/udp.rs +++ b/tokio/src/net/udp.rs @@ -278,6 +278,28 @@ impl UdpSocket { self.io.local_addr() } + /// Returns the socket address of the remote peer this socket was connected to. + /// + /// # Example + /// + /// ``` + /// use tokio::net::UdpSocket; + /// + /// # use std::{io, net::SocketAddr}; + /// # #[tokio::main] + /// # async fn main() -> io::Result<()> { + /// let addr = "0.0.0.0:8080".parse::().unwrap(); + /// let peer = "127.0.0.1:11100".parse::().unwrap(); + /// let sock = UdpSocket::bind(addr).await?; + /// sock.connect(peer).await?; + /// assert_eq!(peer, sock.peer_addr()?); + /// # Ok(()) + /// # } + /// ``` + pub fn peer_addr(&self) -> io::Result { + self.io.peer_addr() + } + /// Connects the UDP socket setting the default destination for send() and /// limiting packets that are read via recv from the address specified in /// `addr`. @@ -1272,7 +1294,9 @@ impl UdpSocket { interest: Interest, f: impl FnOnce() -> io::Result, ) -> io::Result { - self.io.registration().try_io(interest, f) + self.io + .registration() + .try_io(interest, || self.io.try_io(f)) } /// Receives data from the socket, without removing it from the input queue. diff --git a/tokio/src/net/unix/datagram/socket.rs b/tokio/src/net/unix/datagram/socket.rs index d5b618663dc..def006c4761 100644 --- a/tokio/src/net/unix/datagram/socket.rs +++ b/tokio/src/net/unix/datagram/socket.rs @@ -1241,7 +1241,9 @@ impl UnixDatagram { interest: Interest, f: impl FnOnce() -> io::Result, ) -> io::Result { - self.io.registration().try_io(interest, f) + self.io + .registration() + .try_io(interest, || self.io.try_io(f)) } /// Returns the local address that this socket is bound to. diff --git a/tokio/src/net/unix/stream.rs b/tokio/src/net/unix/stream.rs index 4e7ef87b416..fe2d825bf98 100644 --- a/tokio/src/net/unix/stream.rs +++ b/tokio/src/net/unix/stream.rs @@ -685,7 +685,9 @@ impl UnixStream { interest: Interest, f: impl FnOnce() -> io::Result, ) -> io::Result { - self.io.registration().try_io(interest, f) + self.io + .registration() + .try_io(interest, || self.io.try_io(f)) } /// Creates new `UnixStream` from a `std::os::unix::net::UnixStream`. diff --git a/tokio/src/net/windows/named_pipe.rs b/tokio/src/net/windows/named_pipe.rs index 550fd4df2bc..51c625e8db6 100644 --- a/tokio/src/net/windows/named_pipe.rs +++ b/tokio/src/net/windows/named_pipe.rs @@ -12,6 +12,10 @@ use std::task::{Context, Poll}; use crate::io::{AsyncRead, AsyncWrite, Interest, PollEvented, ReadBuf, Ready}; use crate::os::windows::io::{AsRawHandle, FromRawHandle, RawHandle}; +cfg_io_util! { + use bytes::BufMut; +} + // Hide imports which are not used when generating documentation. #[cfg(not(docsrs))] mod doc { @@ -528,6 +532,86 @@ impl NamedPipeServer { .try_io(Interest::READABLE, || (&*self.io).read_vectored(bufs)) } + cfg_io_util! { + /// Tries to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: NamedPipeServer::readable() + /// [`ready()`]: NamedPipeServer::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-readable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// let server = named_pipe::ServerOptions::new().create(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the socket to be readable + /// server.readable().await?; + /// + /// let mut buf = Vec::with_capacity(4096); + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match server.try_read_buf(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_buf(&self, buf: &mut B) -> io::Result { + self.io.registration().try_io(Interest::READABLE, || { + use std::io::Read; + + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit] as *mut [u8]) }; + + // Safety: We trust `NamedPipeServer::read` to have filled up `n` bytes in the + // buffer. + let n = (&*self.io).read(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok(n) + }) + } + } + /// Waits for the pipe to become writable. /// /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually @@ -1186,6 +1270,86 @@ impl NamedPipeClient { .try_io(Interest::READABLE, || (&*self.io).read_vectored(bufs)) } + cfg_io_util! { + /// Tries to read data from the stream into the provided buffer, advancing the + /// buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored by + /// the async task and can exist entirely on the stack. + /// + /// Usually, [`readable()`] or [`ready()`] is used with this function. + /// + /// [`readable()`]: NamedPipeClient::readable() + /// [`ready()`]: NamedPipeClient::ready() + /// + /// # Return + /// + /// If data is successfully read, `Ok(n)` is returned, where `n` is the + /// number of bytes read. `Ok(0)` indicates the stream's read half is closed + /// and will no longer yield data. If the stream is not ready to read data + /// `Err(io::ErrorKind::WouldBlock)` is returned. + /// + /// # Examples + /// + /// ```no_run + /// use tokio::net::windows::named_pipe; + /// use std::error::Error; + /// use std::io; + /// + /// const PIPE_NAME: &str = r"\\.\pipe\tokio-named-pipe-client-readable"; + /// + /// #[tokio::main] + /// async fn main() -> Result<(), Box> { + /// let client = named_pipe::ClientOptions::new().open(PIPE_NAME)?; + /// + /// loop { + /// // Wait for the socket to be readable + /// client.readable().await?; + /// + /// let mut buf = Vec::with_capacity(4096); + /// + /// // Try to read data, this may still fail with `WouldBlock` + /// // if the readiness event is a false positive. + /// match client.try_read_buf(&mut buf) { + /// Ok(0) => break, + /// Ok(n) => { + /// println!("read {} bytes", n); + /// } + /// Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + /// continue; + /// } + /// Err(e) => { + /// return Err(e.into()); + /// } + /// } + /// } + /// + /// Ok(()) + /// } + /// ``` + pub fn try_read_buf(&self, buf: &mut B) -> io::Result { + self.io.registration().try_io(Interest::READABLE, || { + use std::io::Read; + + let dst = buf.chunk_mut(); + let dst = + unsafe { &mut *(dst as *mut _ as *mut [std::mem::MaybeUninit] as *mut [u8]) }; + + // Safety: We trust `NamedPipeClient::read` to have filled up `n` bytes in the + // buffer. + let n = (&*self.io).read(dst)?; + + unsafe { + buf.advance_mut(n); + } + + Ok(n) + }) + } + } + /// Waits for the pipe to become writable. /// /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually @@ -1517,11 +1681,10 @@ impl ServerOptions { /// /// [`dwPipeMode`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea pub fn pipe_mode(&mut self, pipe_mode: PipeMode) -> &mut Self { - self.pipe_mode = match pipe_mode { - PipeMode::Byte => winbase::PIPE_TYPE_BYTE, - PipeMode::Message => winbase::PIPE_TYPE_MESSAGE, - }; - + let is_msg = matches!(pipe_mode, PipeMode::Message); + // Pipe mode is implemented as a bit flag 0x4. Set is message and unset + // is byte. + bool_flag!(self.pipe_mode, is_msg, winbase::PIPE_TYPE_MESSAGE); self } @@ -2248,3 +2411,48 @@ unsafe fn named_pipe_info(handle: RawHandle) -> io::Result { max_instances, }) } + +#[cfg(test)] +mod test { + use self::winbase::{PIPE_REJECT_REMOTE_CLIENTS, PIPE_TYPE_BYTE, PIPE_TYPE_MESSAGE}; + use super::*; + + #[test] + fn opts_default_pipe_mode() { + let opts = ServerOptions::new(); + assert_eq!(opts.pipe_mode, PIPE_TYPE_BYTE | PIPE_REJECT_REMOTE_CLIENTS); + } + + #[test] + fn opts_unset_reject_remote() { + let mut opts = ServerOptions::new(); + opts.reject_remote_clients(false); + assert_eq!(opts.pipe_mode & PIPE_REJECT_REMOTE_CLIENTS, 0); + } + + #[test] + fn opts_set_pipe_mode_maintains_reject_remote_clients() { + let mut opts = ServerOptions::new(); + opts.pipe_mode(PipeMode::Byte); + assert_eq!(opts.pipe_mode, PIPE_TYPE_BYTE | PIPE_REJECT_REMOTE_CLIENTS); + + opts.reject_remote_clients(false); + opts.pipe_mode(PipeMode::Byte); + assert_eq!(opts.pipe_mode, PIPE_TYPE_BYTE); + + opts.reject_remote_clients(true); + opts.pipe_mode(PipeMode::Byte); + assert_eq!(opts.pipe_mode, PIPE_TYPE_BYTE | PIPE_REJECT_REMOTE_CLIENTS); + + opts.reject_remote_clients(false); + opts.pipe_mode(PipeMode::Message); + assert_eq!(opts.pipe_mode, PIPE_TYPE_MESSAGE); + + opts.reject_remote_clients(true); + opts.pipe_mode(PipeMode::Message); + assert_eq!( + opts.pipe_mode, + PIPE_TYPE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS + ); + } +} diff --git a/tokio/src/process/mod.rs b/tokio/src/process/mod.rs index 4e1a21dd449..719fdeee6a1 100644 --- a/tokio/src/process/mod.rs +++ b/tokio/src/process/mod.rs @@ -111,7 +111,7 @@ //! let mut cmd = Command::new("sort"); //! //! // Specifying that we want pipe both the output and the input. -//! // Similarily to capturing the output, by configuring the pipe +//! // Similarly to capturing the output, by configuring the pipe //! // to stdin it can now be used as an asynchronous writer. //! cmd.stdout(Stdio::piped()); //! cmd.stdin(Stdio::piped()); diff --git a/tokio/src/process/unix/mod.rs b/tokio/src/process/unix/mod.rs index 576fe6cb47f..0bc4e78503f 100644 --- a/tokio/src/process/unix/mod.rs +++ b/tokio/src/process/unix/mod.rs @@ -34,10 +34,10 @@ use crate::process::kill::Kill; use crate::process::SpawnedChild; use crate::signal::unix::driver::Handle as SignalHandle; use crate::signal::unix::{signal, Signal, SignalKind}; +use crate::util::once_cell::OnceCell; use mio::event::Source; use mio::unix::SourceFd; -use once_cell::sync::Lazy; use std::fmt; use std::fs::File; use std::future::Future; @@ -64,25 +64,29 @@ impl Kill for StdChild { } } -static ORPHAN_QUEUE: Lazy> = Lazy::new(OrphanQueueImpl::new); +fn get_orphan_queue() -> &'static OrphanQueueImpl { + static ORPHAN_QUEUE: OnceCell> = OnceCell::new(); + + ORPHAN_QUEUE.get(OrphanQueueImpl::new) +} pub(crate) struct GlobalOrphanQueue; impl fmt::Debug for GlobalOrphanQueue { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - ORPHAN_QUEUE.fmt(fmt) + get_orphan_queue().fmt(fmt) } } impl GlobalOrphanQueue { fn reap_orphans(handle: &SignalHandle) { - ORPHAN_QUEUE.reap_orphans(handle) + get_orphan_queue().reap_orphans(handle) } } impl OrphanQueue for GlobalOrphanQueue { fn push_orphan(&self, orphan: StdChild) { - ORPHAN_QUEUE.push_orphan(orphan) + get_orphan_queue().push_orphan(orphan) } } diff --git a/tokio/src/runtime/basic_scheduler.rs b/tokio/src/runtime/basic_scheduler.rs index 401f55b3f2f..cb48e98ef22 100644 --- a/tokio/src/runtime/basic_scheduler.rs +++ b/tokio/src/runtime/basic_scheduler.rs @@ -5,7 +5,7 @@ use crate::park::{Park, Unpark}; use crate::runtime::context::EnterGuard; use crate::runtime::driver::Driver; use crate::runtime::task::{self, JoinHandle, OwnedTasks, Schedule, Task}; -use crate::runtime::Callback; +use crate::runtime::{Callback, HandleInner}; use crate::runtime::{MetricsBatch, SchedulerMetrics, WorkerMetrics}; use crate::sync::notify::Notify; use crate::util::atomic_cell::AtomicCell; @@ -78,6 +78,9 @@ struct Shared { /// Indicates whether the blocked on thread was woken. woken: AtomicBool, + /// Handle to I/O driver, timer, blocking pool, ... + handle_inner: HandleInner, + /// Callback for a worker parking itself before_park: Option, @@ -119,6 +122,7 @@ scoped_thread_local!(static CURRENT: Context); impl BasicScheduler { pub(crate) fn new( driver: Driver, + handle_inner: HandleInner, before_park: Option, after_unpark: Option, ) -> BasicScheduler { @@ -130,6 +134,7 @@ impl BasicScheduler { owned: OwnedTasks::new(), unpark, woken: AtomicBool::new(false), + handle_inner, before_park, after_unpark, scheduler_metrics: SchedulerMetrics::new(), @@ -365,12 +370,12 @@ impl Context { impl Spawner { /// Spawns a future onto the basic scheduler - pub(crate) fn spawn(&self, future: F) -> JoinHandle + pub(crate) fn spawn(&self, future: F, id: super::task::Id) -> JoinHandle where F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - let (handle, notified) = self.shared.owned.bind(future, self.shared.clone()); + let (handle, notified) = self.shared.owned.bind(future, self.shared.clone(), id); if let Some(notified) = notified { self.shared.schedule(notified); @@ -397,6 +402,10 @@ impl Spawner { pub(crate) fn reset_woken(&self) -> bool { self.shared.woken.swap(false, AcqRel) } + + pub(crate) fn as_handle_inner(&self) -> &HandleInner { + &self.shared.handle_inner + } } cfg_metrics! { diff --git a/tokio/src/runtime/blocking/mod.rs b/tokio/src/runtime/blocking/mod.rs index 15fe05c9ade..88d5e6b6a99 100644 --- a/tokio/src/runtime/blocking/mod.rs +++ b/tokio/src/runtime/blocking/mod.rs @@ -21,28 +21,3 @@ use crate::runtime::Builder; pub(crate) fn create_blocking_pool(builder: &Builder, thread_cap: usize) -> BlockingPool { BlockingPool::new(builder, thread_cap) } - -/* -cfg_not_blocking_impl! { - use crate::runtime::Builder; - use std::time::Duration; - - #[derive(Debug, Clone)] - pub(crate) struct BlockingPool {} - - pub(crate) use BlockingPool as Spawner; - - pub(crate) fn create_blocking_pool(_builder: &Builder, _thread_cap: usize) -> BlockingPool { - BlockingPool {} - } - - impl BlockingPool { - pub(crate) fn spawner(&self) -> &BlockingPool { - self - } - - pub(crate) fn shutdown(&mut self, _duration: Option) { - } - } -} -*/ diff --git a/tokio/src/runtime/blocking/pool.rs b/tokio/src/runtime/blocking/pool.rs index daf1f63fac3..f73868ee9e7 100644 --- a/tokio/src/runtime/blocking/pool.rs +++ b/tokio/src/runtime/blocking/pool.rs @@ -7,7 +7,7 @@ use crate::runtime::blocking::shutdown; use crate::runtime::builder::ThreadNameFn; use crate::runtime::context; use crate::runtime::task::{self, JoinHandle}; -use crate::runtime::{Builder, Callback, Handle}; +use crate::runtime::{Builder, Callback, ToHandle}; use std::collections::{HashMap, VecDeque}; use std::fmt; @@ -104,6 +104,7 @@ const KEEP_ALIVE: Duration = Duration::from_secs(10); /// Runs the provided function on an executor dedicated to blocking operations. /// Tasks will be scheduled as non-mandatory, meaning they may not get executed /// in case of runtime shutdown. +#[track_caller] pub(crate) fn spawn_blocking(func: F) -> JoinHandle where F: FnOnce() -> R + Send + 'static, @@ -128,7 +129,7 @@ cfg_fs! { R: Send + 'static, { let rt = context::current(); - rt.spawn_mandatory_blocking(func) + rt.as_inner().spawn_mandatory_blocking(&rt, func) } } @@ -219,7 +220,7 @@ impl fmt::Debug for BlockingPool { // ===== impl Spawner ===== impl Spawner { - pub(crate) fn spawn(&self, task: Task, rt: &Handle) -> Result<(), ()> { + pub(crate) fn spawn(&self, task: Task, rt: &dyn ToHandle) -> Result<(), ()> { let mut shared = self.inner.shared.lock(); if shared.shutdown { @@ -240,17 +241,29 @@ impl Spawner { if shared.num_th == self.inner.thread_cap { // At max number of threads } else { - shared.num_th += 1; assert!(shared.shutdown_tx.is_some()); let shutdown_tx = shared.shutdown_tx.clone(); if let Some(shutdown_tx) = shutdown_tx { let id = shared.worker_thread_index; - shared.worker_thread_index += 1; - let handle = self.spawn_thread(shutdown_tx, rt, id); - - shared.worker_threads.insert(id, handle); + match self.spawn_thread(shutdown_tx, rt, id) { + Ok(handle) => { + shared.num_th += 1; + shared.worker_thread_index += 1; + shared.worker_threads.insert(id, handle); + } + Err(ref e) if is_temporary_os_thread_error(e) && shared.num_th > 0 => { + // OS temporarily failed to spawn a new thread. + // The task will be picked up eventually by a currently + // busy thread. + } + Err(e) => { + // The OS refused to spawn the thread and there is no thread + // to pick up the task that has just been pushed to the queue. + panic!("OS can't spawn worker thread: {}", e) + } + } } } } else { @@ -270,28 +283,32 @@ impl Spawner { fn spawn_thread( &self, shutdown_tx: shutdown::Sender, - rt: &Handle, + rt: &dyn ToHandle, id: usize, - ) -> thread::JoinHandle<()> { + ) -> std::io::Result> { let mut builder = thread::Builder::new().name((self.inner.thread_name)()); if let Some(stack_size) = self.inner.stack_size { builder = builder.stack_size(stack_size); } - let rt = rt.clone(); + let rt = rt.to_handle(); - builder - .spawn(move || { - // Only the reference should be moved into the closure - let _enter = crate::runtime::context::enter(rt.clone()); - rt.blocking_spawner.inner.run(id); - drop(shutdown_tx); - }) - .expect("OS can't spawn a new worker thread") + builder.spawn(move || { + // Only the reference should be moved into the closure + let _enter = crate::runtime::context::enter(rt.clone()); + rt.as_inner().blocking_spawner.inner.run(id); + drop(shutdown_tx); + }) } } +// Tells whether the error when spawning a thread is temporary. +#[inline] +fn is_temporary_os_thread_error(error: &std::io::Error) -> bool { + matches!(error.kind(), std::io::ErrorKind::WouldBlock) +} + impl Inner { fn run(&self, worker_thread_id: usize) { if let Some(f) = &self.after_start { diff --git a/tokio/src/runtime/builder.rs b/tokio/src/runtime/builder.rs index 91c365fd516..618474c05ce 100644 --- a/tokio/src/runtime/builder.rs +++ b/tokio/src/runtime/builder.rs @@ -555,32 +555,37 @@ impl Builder { } fn build_basic_runtime(&mut self) -> io::Result { - use crate::runtime::{BasicScheduler, Kind}; + use crate::runtime::{BasicScheduler, HandleInner, Kind}; let (driver, resources) = driver::Driver::new(self.get_cfg())?; + // Blocking pool + let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads); + let blocking_spawner = blocking_pool.spawner().clone(); + + let handle_inner = HandleInner { + io_handle: resources.io_handle, + time_handle: resources.time_handle, + signal_handle: resources.signal_handle, + clock: resources.clock, + blocking_spawner, + }; + // And now put a single-threaded scheduler on top of the timer. When // there are no futures ready to do something, it'll let the timer or // the reactor to generate some new stimuli for the futures to continue // in their life. - let scheduler = - BasicScheduler::new(driver, self.before_park.clone(), self.after_unpark.clone()); + let scheduler = BasicScheduler::new( + driver, + handle_inner, + self.before_park.clone(), + self.after_unpark.clone(), + ); let spawner = Spawner::Basic(scheduler.spawner().clone()); - // Blocking pool - let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads); - let blocking_spawner = blocking_pool.spawner().clone(); - Ok(Runtime { kind: Kind::CurrentThread(scheduler), - handle: Handle { - spawner, - io_handle: resources.io_handle, - time_handle: resources.time_handle, - signal_handle: resources.signal_handle, - clock: resources.clock, - blocking_spawner, - }, + handle: Handle { spawner }, blocking_pool, }) } @@ -662,23 +667,17 @@ cfg_rt_multi_thread! { impl Builder { fn build_threaded_runtime(&mut self) -> io::Result { use crate::loom::sys::num_cpus; - use crate::runtime::{Kind, ThreadPool}; - use crate::runtime::park::Parker; + use crate::runtime::{Kind, HandleInner, ThreadPool}; let core_threads = self.worker_threads.unwrap_or_else(num_cpus); let (driver, resources) = driver::Driver::new(self.get_cfg())?; - let (scheduler, launch) = ThreadPool::new(core_threads, Parker::new(driver), self.before_park.clone(), self.after_unpark.clone()); - let spawner = Spawner::ThreadPool(scheduler.spawner().clone()); - // Create the blocking pool let blocking_pool = blocking::create_blocking_pool(self, self.max_blocking_threads + core_threads); let blocking_spawner = blocking_pool.spawner().clone(); - // Create the runtime handle - let handle = Handle { - spawner, + let handle_inner = HandleInner { io_handle: resources.io_handle, time_handle: resources.time_handle, signal_handle: resources.signal_handle, @@ -686,6 +685,14 @@ cfg_rt_multi_thread! { blocking_spawner, }; + let (scheduler, launch) = ThreadPool::new(core_threads, driver, handle_inner, self.before_park.clone(), self.after_unpark.clone()); + let spawner = Spawner::ThreadPool(scheduler.spawner().clone()); + + // Create the runtime handle + let handle = Handle { + spawner, + }; + // Spawn the thread pool workers let _enter = crate::runtime::context::enter(handle.clone()); launch.launch(); diff --git a/tokio/src/runtime/context.rs b/tokio/src/runtime/context.rs index 1f44a534026..aebbe18755a 100644 --- a/tokio/src/runtime/context.rs +++ b/tokio/src/runtime/context.rs @@ -26,7 +26,7 @@ cfg_io_driver! { pub(crate) fn io_handle() -> crate::runtime::driver::IoHandle { match CONTEXT.try_with(|ctx| { let ctx = ctx.borrow(); - ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).io_handle.clone() + ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).as_inner().io_handle.clone() }) { Ok(io_handle) => io_handle, Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), @@ -39,7 +39,7 @@ cfg_signal_internal! { pub(crate) fn signal_handle() -> crate::runtime::driver::SignalHandle { match CONTEXT.try_with(|ctx| { let ctx = ctx.borrow(); - ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).signal_handle.clone() + ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).as_inner().signal_handle.clone() }) { Ok(signal_handle) => signal_handle, Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), @@ -51,7 +51,7 @@ cfg_time! { pub(crate) fn time_handle() -> crate::runtime::driver::TimeHandle { match CONTEXT.try_with(|ctx| { let ctx = ctx.borrow(); - ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).time_handle.clone() + ctx.as_ref().expect(crate::util::error::CONTEXT_MISSING_ERROR).as_inner().time_handle.clone() }) { Ok(time_handle) => time_handle, Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), @@ -60,7 +60,7 @@ cfg_time! { cfg_test_util! { pub(crate) fn clock() -> Option { - match CONTEXT.try_with(|ctx| (*ctx.borrow()).as_ref().map(|ctx| ctx.clock.clone())) { + match CONTEXT.try_with(|ctx| (*ctx.borrow()).as_ref().map(|ctx| ctx.as_inner().clock.clone())) { Ok(clock) => clock, Err(_) => panic!("{}", crate::util::error::THREAD_LOCAL_DESTROYED_ERROR), } diff --git a/tokio/src/runtime/handle.rs b/tokio/src/runtime/handle.rs index 9dbe6774dd0..9d4a35e5e48 100644 --- a/tokio/src/runtime/handle.rs +++ b/tokio/src/runtime/handle.rs @@ -16,7 +16,11 @@ use std::{error, fmt}; #[derive(Debug, Clone)] pub struct Handle { pub(super) spawner: Spawner, +} +/// All internal handles that are *not* the scheduler's spawner. +#[derive(Debug)] +pub(crate) struct HandleInner { /// Handles to the I/O drivers #[cfg_attr( not(any(feature = "net", feature = "process", all(unix, feature = "signal"))), @@ -47,6 +51,11 @@ pub struct Handle { pub(super) blocking_spawner: blocking::Spawner, } +/// Create a new runtime handle. +pub(crate) trait ToHandle { + fn to_handle(&self) -> Handle; +} + /// Runtime context guard. /// /// Returned by [`Runtime::enter`] and [`Handle::enter`], the context guard exits @@ -63,7 +72,8 @@ pub struct EnterGuard<'a> { impl Handle { /// Enters the runtime context. This allows you to construct types that must /// have an executor available on creation such as [`Sleep`] or [`TcpStream`]. - /// It will also allow you to call methods such as [`tokio::spawn`]. + /// It will also allow you to call methods such as [`tokio::spawn`] and [`Handle::current`] + /// without panicking. /// /// [`Sleep`]: struct@crate::time::Sleep /// [`TcpStream`]: struct@crate::net::TcpStream @@ -80,8 +90,9 @@ impl Handle { /// # Panic /// /// This will panic if called outside the context of a Tokio runtime. That means that you must - /// call this on one of the threads **being run by the runtime**. Calling this from within a - /// thread created by `std::thread::spawn` (for example) will cause a panic. + /// call this on one of the threads **being run by the runtime**, or from a thread with an active + /// `EnterGuard`. Calling this from within a thread created by `std::thread::spawn` (for example) + /// will cause a panic unless that thread has an active `EnterGuard`. /// /// # Examples /// @@ -105,9 +116,14 @@ impl Handle { /// # let handle = /// thread::spawn(move || { /// // Notice that the handle is created outside of this thread and then moved in - /// handle.spawn(async { /* ... */ }) - /// // This next line would cause a panic - /// // let handle2 = Handle::current(); + /// handle.spawn(async { /* ... */ }); + /// // This next line would cause a panic because we haven't entered the runtime + /// // and created an EnterGuard + /// // let handle2 = Handle::current(); // panic + /// // So we create a guard here with Handle::enter(); + /// let _guard = handle.enter(); + /// // Now we can call Handle::current(); + /// let handle2 = Handle::current(); /// }); /// # handle.join().unwrap(); /// # }); @@ -159,9 +175,10 @@ impl Handle { F: Future + Send + 'static, F::Output: Send + 'static, { + let id = crate::runtime::task::Id::next(); #[cfg(all(tokio_unstable, feature = "tracing"))] - let future = crate::util::trace::task(future, "task", None); - self.spawner.spawn(future) + let future = crate::util::trace::task(future, "task", None, id.as_u64()); + self.spawner.spawn(future, id) } /// Runs the provided function on an executor dedicated to blocking. @@ -189,85 +206,11 @@ impl Handle { F: FnOnce() -> R + Send + 'static, R: Send + 'static, { - let (join_handle, _was_spawned) = - if cfg!(debug_assertions) && std::mem::size_of::() > 2048 { - self.spawn_blocking_inner(Box::new(func), blocking::Mandatory::NonMandatory, None) - } else { - self.spawn_blocking_inner(func, blocking::Mandatory::NonMandatory, None) - }; - - join_handle - } - - cfg_fs! { - #[track_caller] - #[cfg_attr(any( - all(loom, not(test)), // the function is covered by loom tests - test - ), allow(dead_code))] - pub(crate) fn spawn_mandatory_blocking(&self, func: F) -> Option> - where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, - { - let (join_handle, was_spawned) = if cfg!(debug_assertions) && std::mem::size_of::() > 2048 { - self.spawn_blocking_inner( - Box::new(func), - blocking::Mandatory::Mandatory, - None - ) - } else { - self.spawn_blocking_inner( - func, - blocking::Mandatory::Mandatory, - None - ) - }; - - if was_spawned { - Some(join_handle) - } else { - None - } - } + self.as_inner().spawn_blocking(self, func) } - #[track_caller] - pub(crate) fn spawn_blocking_inner( - &self, - func: F, - is_mandatory: blocking::Mandatory, - name: Option<&str>, - ) -> (JoinHandle, bool) - where - F: FnOnce() -> R + Send + 'static, - R: Send + 'static, - { - let fut = BlockingTask::new(func); - - #[cfg(all(tokio_unstable, feature = "tracing"))] - let fut = { - use tracing::Instrument; - let location = std::panic::Location::caller(); - let span = tracing::trace_span!( - target: "tokio::task::blocking", - "runtime.spawn", - kind = %"blocking", - task.name = %name.unwrap_or_default(), - "fn" = %std::any::type_name::(), - spawn.location = %format_args!("{}:{}:{}", location.file(), location.line(), location.column()), - ); - fut.instrument(span) - }; - - #[cfg(not(all(tokio_unstable, feature = "tracing")))] - let _ = name; - - let (task, handle) = task::unowned(fut, NoopSchedule); - let spawned = self - .blocking_spawner - .spawn(blocking::Task::new(task, is_mandatory), self); - (handle, spawned.is_ok()) + pub(crate) fn as_inner(&self) -> &HandleInner { + self.spawner.as_handle_inner() } /// Runs a future to completion on this `Handle`'s associated `Runtime`. @@ -343,7 +286,8 @@ impl Handle { #[track_caller] pub fn block_on(&self, future: F) -> F::Output { #[cfg(all(tokio_unstable, feature = "tracing"))] - let future = crate::util::trace::task(future, "block_on", None); + let future = + crate::util::trace::task(future, "block_on", None, super::task::Id::next().as_u64()); // Enter the **runtime** context. This configures spawning, the current I/O driver, ... let _rt_enter = self.enter(); @@ -362,6 +306,12 @@ impl Handle { } } +impl ToHandle for Handle { + fn to_handle(&self) -> Handle { + self.clone() + } +} + cfg_metrics! { use crate::runtime::RuntimeMetrics; @@ -374,6 +324,100 @@ cfg_metrics! { } } +impl HandleInner { + #[track_caller] + pub(crate) fn spawn_blocking(&self, rt: &dyn ToHandle, func: F) -> JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let (join_handle, _was_spawned) = if cfg!(debug_assertions) + && std::mem::size_of::() > 2048 + { + self.spawn_blocking_inner(Box::new(func), blocking::Mandatory::NonMandatory, None, rt) + } else { + self.spawn_blocking_inner(func, blocking::Mandatory::NonMandatory, None, rt) + }; + + join_handle + } + + cfg_fs! { + #[track_caller] + #[cfg_attr(any( + all(loom, not(test)), // the function is covered by loom tests + test + ), allow(dead_code))] + pub(crate) fn spawn_mandatory_blocking(&self, rt: &dyn ToHandle, func: F) -> Option> + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let (join_handle, was_spawned) = if cfg!(debug_assertions) && std::mem::size_of::() > 2048 { + self.spawn_blocking_inner( + Box::new(func), + blocking::Mandatory::Mandatory, + None, + rt, + ) + } else { + self.spawn_blocking_inner( + func, + blocking::Mandatory::Mandatory, + None, + rt, + ) + }; + + if was_spawned { + Some(join_handle) + } else { + None + } + } + } + + #[track_caller] + pub(crate) fn spawn_blocking_inner( + &self, + func: F, + is_mandatory: blocking::Mandatory, + name: Option<&str>, + rt: &dyn ToHandle, + ) -> (JoinHandle, bool) + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let fut = BlockingTask::new(func); + let id = super::task::Id::next(); + #[cfg(all(tokio_unstable, feature = "tracing"))] + let fut = { + use tracing::Instrument; + let location = std::panic::Location::caller(); + let span = tracing::trace_span!( + target: "tokio::task::blocking", + "runtime.spawn", + kind = %"blocking", + task.name = %name.unwrap_or_default(), + task.id = id.as_u64(), + "fn" = %std::any::type_name::(), + spawn.location = %format_args!("{}:{}:{}", location.file(), location.line(), location.column()), + ); + fut.instrument(span) + }; + + #[cfg(not(all(tokio_unstable, feature = "tracing")))] + let _ = name; + + let (task, handle) = task::unowned(fut, NoopSchedule, id); + let spawned = self + .blocking_spawner + .spawn(blocking::Task::new(task, is_mandatory), rt); + (handle, spawned.is_ok()) + } +} + /// Error returned by `try_current` when no Runtime has been started #[derive(Debug)] pub struct TryCurrentError { diff --git a/tokio/src/runtime/metrics/io.rs b/tokio/src/runtime/metrics/io.rs new file mode 100644 index 00000000000..9706bfc9bc2 --- /dev/null +++ b/tokio/src/runtime/metrics/io.rs @@ -0,0 +1,30 @@ +#![cfg_attr(not(feature = "net"), allow(dead_code))] + +use crate::loom::sync::atomic::{AtomicU64, Ordering::Relaxed}; + +#[derive(Default)] +pub(crate) struct IoDriverMetrics { + pub(super) fd_registered_count: AtomicU64, + pub(super) fd_deregistered_count: AtomicU64, + pub(super) ready_count: AtomicU64, +} + +impl IoDriverMetrics { + pub(crate) fn incr_fd_count(&self) { + let prev = self.fd_registered_count.load(Relaxed); + let new = prev.wrapping_add(1); + self.fd_registered_count.store(new, Relaxed); + } + + pub(crate) fn dec_fd_count(&self) { + let prev = self.fd_deregistered_count.load(Relaxed); + let new = prev.wrapping_add(1); + self.fd_deregistered_count.store(new, Relaxed); + } + + pub(crate) fn incr_ready_count_by(&self, amt: u64) { + let prev = self.ready_count.load(Relaxed); + let new = prev.wrapping_add(amt); + self.ready_count.store(new, Relaxed); + } +} diff --git a/tokio/src/runtime/metrics/mod.rs b/tokio/src/runtime/metrics/mod.rs index ca643a59047..4b96f1b7115 100644 --- a/tokio/src/runtime/metrics/mod.rs +++ b/tokio/src/runtime/metrics/mod.rs @@ -21,6 +21,11 @@ cfg_metrics! { mod worker; pub(crate) use worker::WorkerMetrics; + + cfg_net! { + mod io; + pub(crate) use io::IoDriverMetrics; + } } cfg_not_metrics! { diff --git a/tokio/src/runtime/metrics/runtime.rs b/tokio/src/runtime/metrics/runtime.rs index 0f8055907f5..26a0118a475 100644 --- a/tokio/src/runtime/metrics/runtime.rs +++ b/tokio/src/runtime/metrics/runtime.rs @@ -386,7 +386,7 @@ impl RuntimeMetrics { /// Returns the number of tasks currently scheduled in the runtime's /// injection queue. /// - /// Tasks that are spanwed or notified from a non-runtime thread are + /// Tasks that are spawned or notified from a non-runtime thread are /// scheduled using the runtime's injection queue. This metric returns the /// **current** number of tasks pending in the injection queue. As such, the /// returned value may increase or decrease as new tasks are scheduled and @@ -447,3 +447,90 @@ impl RuntimeMetrics { self.handle.spawner.worker_local_queue_depth(worker) } } + +cfg_net! { + impl RuntimeMetrics { + /// Returns the number of file descriptors that have been registered with the + /// runtime's I/O driver. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let registered_fds = metrics.io_driver_fd_registered_count(); + /// println!("{} fds have been registered with the runtime's I/O driver.", registered_fds); + /// + /// let deregistered_fds = metrics.io_driver_fd_deregistered_count(); + /// + /// let current_fd_count = registered_fds - deregistered_fds; + /// println!("{} fds are currently registered by the runtime's I/O driver.", current_fd_count); + /// } + /// ``` + pub fn io_driver_fd_registered_count(&self) -> u64 { + self.with_io_driver_metrics(|m| { + m.fd_registered_count.load(Relaxed) + }) + } + + /// Returns the number of file descriptors that have been deregistered by the + /// runtime's I/O driver. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.io_driver_fd_deregistered_count(); + /// println!("{} fds have been deregistered by the runtime's I/O driver.", n); + /// } + /// ``` + pub fn io_driver_fd_deregistered_count(&self) -> u64 { + self.with_io_driver_metrics(|m| { + m.fd_deregistered_count.load(Relaxed) + }) + } + + /// Returns the number of ready events processed by the runtime's + /// I/O driver. + /// + /// # Examples + /// + /// ``` + /// use tokio::runtime::Handle; + /// + /// #[tokio::main] + /// async fn main() { + /// let metrics = Handle::current().metrics(); + /// + /// let n = metrics.io_driver_ready_count(); + /// println!("{} ready events procssed by the runtime's I/O driver.", n); + /// } + /// ``` + pub fn io_driver_ready_count(&self) -> u64 { + self.with_io_driver_metrics(|m| m.ready_count.load(Relaxed)) + } + + fn with_io_driver_metrics(&self, f: F) -> u64 + where + F: Fn(&super::IoDriverMetrics) -> u64, + { + // TODO: Investigate if this should return 0, most of our metrics always increase + // thus this breaks that guarantee. + self.handle + .as_inner() + .io_handle + .as_ref() + .and_then(|h| h.with_io_driver_metrics(f)) + .unwrap_or(0) + } + } +} diff --git a/tokio/src/runtime/metrics/worker.rs b/tokio/src/runtime/metrics/worker.rs index c9b85e48e4c..ec58de6b3a0 100644 --- a/tokio/src/runtime/metrics/worker.rs +++ b/tokio/src/runtime/metrics/worker.rs @@ -1,7 +1,7 @@ use crate::loom::sync::atomic::Ordering::Relaxed; use crate::loom::sync::atomic::{AtomicU64, AtomicUsize}; -/// Retreive runtime worker metrics. +/// Retrieve runtime worker metrics. /// /// **Note**: This is an [unstable API][unstable]. The public API of this type /// may break in 1.x releases. See [the documentation on unstable diff --git a/tokio/src/runtime/mod.rs b/tokio/src/runtime/mod.rs index 7c381b0bbd0..bd428525d00 100644 --- a/tokio/src/runtime/mod.rs +++ b/tokio/src/runtime/mod.rs @@ -187,6 +187,10 @@ cfg_metrics! { pub use metrics::RuntimeMetrics; pub(crate) use metrics::{MetricsBatch, SchedulerMetrics, WorkerMetrics}; + + cfg_net! { + pub(crate) use metrics::IoDriverMetrics; + } } cfg_not_metrics! { @@ -214,24 +218,20 @@ cfg_rt! { pub use self::builder::Builder; pub(crate) mod context; - pub(crate) mod driver; + mod driver; use self::enter::enter; mod handle; pub use handle::{EnterGuard, Handle, TryCurrentError}; + pub(crate) use handle::{HandleInner, ToHandle}; mod spawner; use self::spawner::Spawner; } cfg_rt_multi_thread! { - mod park; - use park::Parker; -} - -cfg_rt_multi_thread! { - mod queue; + use driver::Driver; pub(crate) mod thread_pool; use self::thread_pool::ThreadPool; @@ -467,7 +467,7 @@ cfg_rt! { #[track_caller] pub fn block_on(&self, future: F) -> F::Output { #[cfg(all(tokio_unstable, feature = "tracing"))] - let future = crate::util::trace::task(future, "block_on", None); + let future = crate::util::trace::task(future, "block_on", None, task::Id::next().as_u64()); let _enter = self.enter(); diff --git a/tokio/src/runtime/spawner.rs b/tokio/src/runtime/spawner.rs index d81a806cb59..fb4d7f91845 100644 --- a/tokio/src/runtime/spawner.rs +++ b/tokio/src/runtime/spawner.rs @@ -1,5 +1,6 @@ use crate::future::Future; -use crate::runtime::basic_scheduler; +use crate::runtime::task::Id; +use crate::runtime::{basic_scheduler, HandleInner}; use crate::task::JoinHandle; cfg_rt_multi_thread! { @@ -23,15 +24,23 @@ impl Spawner { } } - pub(crate) fn spawn(&self, future: F) -> JoinHandle + pub(crate) fn spawn(&self, future: F, id: Id) -> JoinHandle where F: Future + Send + 'static, F::Output: Send + 'static, { match self { - Spawner::Basic(spawner) => spawner.spawn(future), + Spawner::Basic(spawner) => spawner.spawn(future, id), #[cfg(feature = "rt-multi-thread")] - Spawner::ThreadPool(spawner) => spawner.spawn(future), + Spawner::ThreadPool(spawner) => spawner.spawn(future, id), + } + } + + pub(crate) fn as_handle_inner(&self) -> &HandleInner { + match self { + Spawner::Basic(spawner) => spawner.as_handle_inner(), + #[cfg(feature = "rt-multi-thread")] + Spawner::ThreadPool(spawner) => spawner.as_handle_inner(), } } } diff --git a/tokio/src/runtime/task/abort.rs b/tokio/src/runtime/task/abort.rs new file mode 100644 index 00000000000..4977377880d --- /dev/null +++ b/tokio/src/runtime/task/abort.rs @@ -0,0 +1,88 @@ +use crate::runtime::task::{Id, RawTask}; +use std::fmt; +use std::panic::{RefUnwindSafe, UnwindSafe}; + +/// An owned permission to abort a spawned task, without awaiting its completion. +/// +/// Unlike a [`JoinHandle`], an `AbortHandle` does *not* represent the +/// permission to await the task's completion, only to terminate it. +/// +/// The task may be aborted by calling the [`AbortHandle::abort`] method. +/// Dropping an `AbortHandle` releases the permission to terminate the task +/// --- it does *not* abort the task. +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// [unstable]: crate#unstable-features +/// [`JoinHandle`]: crate::task::JoinHandle +#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +pub struct AbortHandle { + raw: Option, + id: Id, +} + +impl AbortHandle { + pub(super) fn new(raw: Option, id: Id) -> Self { + Self { raw, id } + } + + /// Abort the task associated with the handle. + /// + /// Awaiting a cancelled task might complete as usual if the task was + /// already completed at the time it was cancelled, but most likely it + /// will fail with a [cancelled] `JoinError`. + /// + /// If the task was already cancelled, such as by [`JoinHandle::abort`], + /// this method will do nothing. + /// + /// [cancelled]: method@super::error::JoinError::is_cancelled + /// [`JoinHandle::abort`]: method@super::JoinHandle::abort + // the `AbortHandle` type is only publicly exposed when `tokio_unstable` is + // enabled, but it is still defined for testing purposes. + #[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] + pub fn abort(&self) { + if let Some(ref raw) = self.raw { + raw.remote_abort(); + } + } + + /// Returns a [task ID] that uniquely identifies this task relative to other + /// currently spawned tasks. + /// + /// **Note**: This is an [unstable API][unstable]. The public API of this type + /// may break in 1.x releases. See [the documentation on unstable + /// features][unstable] for details. + /// + /// [task ID]: crate::task::Id + /// [unstable]: crate#unstable-features + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + pub fn id(&self) -> super::Id { + self.id.clone() + } +} + +unsafe impl Send for AbortHandle {} +unsafe impl Sync for AbortHandle {} + +impl UnwindSafe for AbortHandle {} +impl RefUnwindSafe for AbortHandle {} + +impl fmt::Debug for AbortHandle { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.debug_struct("AbortHandle") + .field("id", &self.id) + .finish() + } +} + +impl Drop for AbortHandle { + fn drop(&mut self) { + if let Some(raw) = self.raw.take() { + raw.drop_abort_handle(); + } + } +} diff --git a/tokio/src/runtime/task/core.rs b/tokio/src/runtime/task/core.rs index 776e8341f37..548c56da3d4 100644 --- a/tokio/src/runtime/task/core.rs +++ b/tokio/src/runtime/task/core.rs @@ -13,7 +13,7 @@ use crate::future::Future; use crate::loom::cell::UnsafeCell; use crate::runtime::task::raw::{self, Vtable}; use crate::runtime::task::state::State; -use crate::runtime::task::Schedule; +use crate::runtime::task::{Id, Schedule}; use crate::util::linked_list; use std::pin::Pin; @@ -49,6 +49,9 @@ pub(super) struct Core { /// Either the future or the output. pub(super) stage: CoreStage, + + /// The task's ID, used for populating `JoinError`s. + pub(super) task_id: Id, } /// Crate public as this is also needed by the pool. @@ -102,7 +105,7 @@ pub(super) enum Stage { impl Cell { /// Allocates a new task cell, containing the header, trailer, and core /// structures. - pub(super) fn new(future: T, scheduler: S, state: State) -> Box> { + pub(super) fn new(future: T, scheduler: S, state: State, task_id: Id) -> Box> { #[cfg(all(tokio_unstable, feature = "tracing"))] let id = future.id(); Box::new(Cell { @@ -120,6 +123,7 @@ impl Cell { stage: CoreStage { stage: UnsafeCell::new(Stage::Running(future)), }, + task_id, }, trailer: Trailer { waker: UnsafeCell::new(None), diff --git a/tokio/src/runtime/task/error.rs b/tokio/src/runtime/task/error.rs index 1a8129b2b6f..22b688aa221 100644 --- a/tokio/src/runtime/task/error.rs +++ b/tokio/src/runtime/task/error.rs @@ -2,12 +2,13 @@ use std::any::Any; use std::fmt; use std::io; +use super::Id; use crate::util::SyncWrapper; - cfg_rt! { /// Task failed to execute to completion. pub struct JoinError { repr: Repr, + id: Id, } } @@ -17,15 +18,17 @@ enum Repr { } impl JoinError { - pub(crate) fn cancelled() -> JoinError { + pub(crate) fn cancelled(id: Id) -> JoinError { JoinError { repr: Repr::Cancelled, + id, } } - pub(crate) fn panic(err: Box) -> JoinError { + pub(crate) fn panic(id: Id, err: Box) -> JoinError { JoinError { repr: Repr::Panic(SyncWrapper::new(err)), + id, } } @@ -111,13 +114,28 @@ impl JoinError { _ => Err(self), } } + + /// Returns a [task ID] that identifies the task which errored relative to + /// other currently spawned tasks. + /// + /// **Note**: This is an [unstable API][unstable]. The public API of this type + /// may break in 1.x releases. See [the documentation on unstable + /// features][unstable] for details. + /// + /// [task ID]: crate::task::Id + /// [unstable]: crate#unstable-features + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + pub fn id(&self) -> Id { + self.id.clone() + } } impl fmt::Display for JoinError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.repr { - Repr::Cancelled => write!(fmt, "cancelled"), - Repr::Panic(_) => write!(fmt, "panic"), + Repr::Cancelled => write!(fmt, "task {} was cancelled", self.id), + Repr::Panic(_) => write!(fmt, "task {} panicked", self.id), } } } @@ -125,8 +143,8 @@ impl fmt::Display for JoinError { impl fmt::Debug for JoinError { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { match &self.repr { - Repr::Cancelled => write!(fmt, "JoinError::Cancelled"), - Repr::Panic(_) => write!(fmt, "JoinError::Panic(...)"), + Repr::Cancelled => write!(fmt, "JoinError::Cancelled({:?})", self.id), + Repr::Panic(_) => write!(fmt, "JoinError::Panic({:?}, ...)", self.id), } } } diff --git a/tokio/src/runtime/task/harness.rs b/tokio/src/runtime/task/harness.rs index 261dccea415..1d3ababfb17 100644 --- a/tokio/src/runtime/task/harness.rs +++ b/tokio/src/runtime/task/harness.rs @@ -100,7 +100,8 @@ where let header_ptr = self.header_ptr(); let waker_ref = waker_ref::(&header_ptr); let cx = Context::from_waker(&*waker_ref); - let res = poll_future(&self.core().stage, cx); + let core = self.core(); + let res = poll_future(&core.stage, core.task_id.clone(), cx); if res == Poll::Ready(()) { // The future completed. Move on to complete the task. @@ -114,14 +115,15 @@ where TransitionToIdle::Cancelled => { // The transition to idle failed because the task was // cancelled during the poll. - - cancel_task(&self.core().stage); + let core = self.core(); + cancel_task(&core.stage, core.task_id.clone()); PollFuture::Complete } } } TransitionToRunning::Cancelled => { - cancel_task(&self.core().stage); + let core = self.core(); + cancel_task(&core.stage, core.task_id.clone()); PollFuture::Complete } TransitionToRunning::Failed => PollFuture::Done, @@ -144,7 +146,8 @@ where // By transitioning the lifecycle to `Running`, we have permission to // drop the future. - cancel_task(&self.core().stage); + let core = self.core(); + cancel_task(&core.stage, core.task_id.clone()); self.complete(); } @@ -432,7 +435,7 @@ enum PollFuture { } /// Cancels the task and store the appropriate error in the stage field. -fn cancel_task(stage: &CoreStage) { +fn cancel_task(stage: &CoreStage, id: super::Id) { // Drop the future from a panic guard. let res = panic::catch_unwind(panic::AssertUnwindSafe(|| { stage.drop_future_or_output(); @@ -440,17 +443,17 @@ fn cancel_task(stage: &CoreStage) { match res { Ok(()) => { - stage.store_output(Err(JoinError::cancelled())); + stage.store_output(Err(JoinError::cancelled(id))); } Err(panic) => { - stage.store_output(Err(JoinError::panic(panic))); + stage.store_output(Err(JoinError::panic(id, panic))); } } } /// Polls the future. If the future completes, the output is written to the /// stage field. -fn poll_future(core: &CoreStage, cx: Context<'_>) -> Poll<()> { +fn poll_future(core: &CoreStage, id: super::Id, cx: Context<'_>) -> Poll<()> { // Poll the future. let output = panic::catch_unwind(panic::AssertUnwindSafe(|| { struct Guard<'a, T: Future> { @@ -473,7 +476,7 @@ fn poll_future(core: &CoreStage, cx: Context<'_>) -> Poll<()> { let output = match output { Ok(Poll::Pending) => return Poll::Pending, Ok(Poll::Ready(output)) => Ok(output), - Err(panic) => Err(JoinError::panic(panic)), + Err(panic) => Err(JoinError::panic(id, panic)), }; // Catch and ignore panics if the future panics on drop. diff --git a/tokio/src/runtime/task/join.rs b/tokio/src/runtime/task/join.rs index 8beed2eaacb..86580c84b59 100644 --- a/tokio/src/runtime/task/join.rs +++ b/tokio/src/runtime/task/join.rs @@ -1,4 +1,4 @@ -use crate::runtime::task::RawTask; +use crate::runtime::task::{Id, RawTask}; use std::fmt; use std::future::Future; @@ -144,6 +144,7 @@ cfg_rt! { /// [`JoinError`]: crate::task::JoinError pub struct JoinHandle { raw: Option, + id: Id, _p: PhantomData, } } @@ -155,9 +156,10 @@ impl UnwindSafe for JoinHandle {} impl RefUnwindSafe for JoinHandle {} impl JoinHandle { - pub(super) fn new(raw: RawTask) -> JoinHandle { + pub(super) fn new(raw: RawTask, id: Id) -> JoinHandle { JoinHandle { raw: Some(raw), + id, _p: PhantomData, } } @@ -210,6 +212,31 @@ impl JoinHandle { } } } + + /// Returns a new `AbortHandle` that can be used to remotely abort this task. + #[cfg(any(tokio_unstable, test))] + pub(crate) fn abort_handle(&self) -> super::AbortHandle { + let raw = self.raw.map(|raw| { + raw.ref_inc(); + raw + }); + super::AbortHandle::new(raw, self.id.clone()) + } + + /// Returns a [task ID] that uniquely identifies this task relative to other + /// currently spawned tasks. + /// + /// **Note**: This is an [unstable API][unstable]. The public API of this type + /// may break in 1.x releases. See [the documentation on unstable + /// features][unstable] for details. + /// + /// [task ID]: crate::task::Id + /// [unstable]: crate#unstable-features + #[cfg(tokio_unstable)] + #[cfg_attr(docsrs, doc(cfg(tokio_unstable)))] + pub fn id(&self) -> super::Id { + self.id.clone() + } } impl Unpin for JoinHandle {} @@ -270,6 +297,8 @@ where T: fmt::Debug, { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt.debug_struct("JoinHandle").finish() + fmt.debug_struct("JoinHandle") + .field("id", &self.id) + .finish() } } diff --git a/tokio/src/runtime/task/list.rs b/tokio/src/runtime/task/list.rs index 7758f8db7aa..7a1dff0bbfc 100644 --- a/tokio/src/runtime/task/list.rs +++ b/tokio/src/runtime/task/list.rs @@ -84,13 +84,14 @@ impl OwnedTasks { &self, task: T, scheduler: S, + id: super::Id, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + Send + 'static, T::Output: Send + 'static, { - let (task, notified, join) = super::new_task(task, scheduler); + let (task, notified, join) = super::new_task(task, scheduler, id); unsafe { // safety: We just created the task, so we have exclusive access @@ -187,13 +188,14 @@ impl LocalOwnedTasks { &self, task: T, scheduler: S, + id: super::Id, ) -> (JoinHandle, Option>) where S: Schedule, T: Future + 'static, T::Output: 'static, { - let (task, notified, join) = super::new_task(task, scheduler); + let (task, notified, join) = super::new_task(task, scheduler, id); unsafe { // safety: We just created the task, so we have exclusive access diff --git a/tokio/src/runtime/task/mod.rs b/tokio/src/runtime/task/mod.rs index 2a492dc985d..37909b75c6d 100644 --- a/tokio/src/runtime/task/mod.rs +++ b/tokio/src/runtime/task/mod.rs @@ -155,7 +155,14 @@ cfg_rt_multi_thread! { pub(super) use self::inject::Inject; } +#[cfg(all(feature = "rt", any(tokio_unstable, test)))] +mod abort; mod join; + +#[cfg(all(feature = "rt", any(tokio_unstable, test)))] +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +pub use self::abort::AbortHandle; + #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 pub use self::join::JoinHandle; @@ -177,6 +184,27 @@ use std::marker::PhantomData; use std::ptr::NonNull; use std::{fmt, mem}; +/// An opaque ID that uniquely identifies a task relative to all other currently +/// running tasks. +/// +/// # Notes +/// +/// - Task IDs are unique relative to other *currently running* tasks. When a +/// task completes, the same ID may be used for another task. +/// - Task IDs are *not* sequential, and do not indicate the order in which +/// tasks are spawned, what runtime a task is spawned on, or any other data. +/// +/// **Note**: This is an [unstable API][unstable]. The public API of this type +/// may break in 1.x releases. See [the documentation on unstable +/// features][unstable] for details. +/// +/// [unstable]: crate#unstable-features +#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] +#[cfg_attr(not(tokio_unstable), allow(unreachable_pub))] +// TODO(eliza): there's almost certainly no reason not to make this `Copy` as well... +#[derive(Clone, Debug, Hash, Eq, PartialEq)] +pub struct Id(u64); + /// An owned handle to the task, tracked by ref count. #[repr(transparent)] pub(crate) struct Task { @@ -243,14 +271,15 @@ cfg_rt! { /// notification. fn new_task( task: T, - scheduler: S + scheduler: S, + id: Id, ) -> (Task, Notified, JoinHandle) where S: Schedule, T: Future + 'static, T::Output: 'static, { - let raw = RawTask::new::(task, scheduler); + let raw = RawTask::new::(task, scheduler, id.clone()); let task = Task { raw, _p: PhantomData, @@ -259,7 +288,7 @@ cfg_rt! { raw, _p: PhantomData, }); - let join = JoinHandle::new(raw); + let join = JoinHandle::new(raw, id); (task, notified, join) } @@ -268,13 +297,13 @@ cfg_rt! { /// only when the task is not going to be stored in an `OwnedTasks` list. /// /// Currently only blocking tasks use this method. - pub(crate) fn unowned(task: T, scheduler: S) -> (UnownedTask, JoinHandle) + pub(crate) fn unowned(task: T, scheduler: S, id: Id) -> (UnownedTask, JoinHandle) where S: Schedule, T: Send + Future + 'static, T::Output: Send + 'static, { - let (task, notified, join) = new_task(task, scheduler); + let (task, notified, join) = new_task(task, scheduler, id); // This transfers the ref-count of task and notified into an UnownedTask. // This is valid because an UnownedTask holds two ref-counts. @@ -443,3 +472,52 @@ unsafe impl linked_list::Link for Task { NonNull::from(target.as_ref().owned.with_mut(|ptr| &mut *ptr)) } } + +impl fmt::Display for Id { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl Id { + // When 64-bit atomics are available, use a static `AtomicU64` counter to + // generate task IDs. + // + // Note(eliza): we _could_ just use `crate::loom::AtomicU64`, which switches + // between an atomic and mutex-based implementation here, rather than having + // two separate functions for targets with and without 64-bit atomics. + // However, because we can't use the mutex-based implementation in a static + // initializer directly, the 32-bit impl also has to use a `OnceCell`, and I + // thought it was nicer to avoid the `OnceCell` overhead on 64-bit + // platforms... + cfg_has_atomic_u64! { + pub(crate) fn next() -> Self { + use std::sync::atomic::{AtomicU64, Ordering::Relaxed}; + static NEXT_ID: AtomicU64 = AtomicU64::new(1); + Self(NEXT_ID.fetch_add(1, Relaxed)) + } + } + + cfg_not_has_atomic_u64! { + pub(crate) fn next() -> Self { + use crate::util::once_cell::OnceCell; + use crate::loom::sync::Mutex; + + fn init_next_id() -> Mutex { + Mutex::new(1) + } + + static NEXT_ID: OnceCell> = OnceCell::new(); + + let next_id = NEXT_ID.get(init_next_id); + let mut lock = next_id.lock(); + let id = *lock; + *lock += 1; + Self(id) + } + } + + pub(crate) fn as_u64(&self) -> u64 { + self.0 + } +} diff --git a/tokio/src/runtime/task/raw.rs b/tokio/src/runtime/task/raw.rs index 2e4420b5c13..5555298a4d4 100644 --- a/tokio/src/runtime/task/raw.rs +++ b/tokio/src/runtime/task/raw.rs @@ -1,5 +1,5 @@ use crate::future::Future; -use crate::runtime::task::{Cell, Harness, Header, Schedule, State}; +use crate::runtime::task::{Cell, Harness, Header, Id, Schedule, State}; use std::ptr::NonNull; use std::task::{Poll, Waker}; @@ -27,6 +27,9 @@ pub(super) struct Vtable { /// The join handle has been dropped. pub(super) drop_join_handle_slow: unsafe fn(NonNull
), + /// An abort handle has been dropped. + pub(super) drop_abort_handle: unsafe fn(NonNull
), + /// The task is remotely aborted. pub(super) remote_abort: unsafe fn(NonNull
), @@ -42,18 +45,19 @@ pub(super) fn vtable() -> &'static Vtable { try_read_output: try_read_output::, try_set_join_waker: try_set_join_waker::, drop_join_handle_slow: drop_join_handle_slow::, + drop_abort_handle: drop_abort_handle::, remote_abort: remote_abort::, shutdown: shutdown::, } } impl RawTask { - pub(super) fn new(task: T, scheduler: S) -> RawTask + pub(super) fn new(task: T, scheduler: S, id: Id) -> RawTask where T: Future, S: Schedule, { - let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new())); + let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new(), id)); let ptr = unsafe { NonNull::new_unchecked(ptr as *mut Header) }; RawTask { ptr } @@ -104,6 +108,11 @@ impl RawTask { unsafe { (vtable.drop_join_handle_slow)(self.ptr) } } + pub(super) fn drop_abort_handle(self) { + let vtable = self.header().vtable; + unsafe { (vtable.drop_abort_handle)(self.ptr) } + } + pub(super) fn shutdown(self) { let vtable = self.header().vtable; unsafe { (vtable.shutdown)(self.ptr) } @@ -113,6 +122,13 @@ impl RawTask { let vtable = self.header().vtable; unsafe { (vtable.remote_abort)(self.ptr) } } + + /// Increment the task's reference count. + /// + /// Currently, this is used only when creating an `AbortHandle`. + pub(super) fn ref_inc(self) { + self.header().state.ref_inc(); + } } impl Clone for RawTask { @@ -154,6 +170,11 @@ unsafe fn drop_join_handle_slow(ptr: NonNull
) { harness.drop_join_handle_slow() } +unsafe fn drop_abort_handle(ptr: NonNull
) { + let harness = Harness::::from_raw(ptr); + harness.drop_reference(); +} + unsafe fn remote_abort(ptr: NonNull
) { let harness = Harness::::from_raw(ptr); harness.remote_abort() diff --git a/tokio/src/runtime/tests/loom_queue.rs b/tokio/src/runtime/tests/loom_queue.rs index b5f78d7ebe7..d0ebf5d4350 100644 --- a/tokio/src/runtime/tests/loom_queue.rs +++ b/tokio/src/runtime/tests/loom_queue.rs @@ -1,6 +1,7 @@ use crate::runtime::blocking::NoopSchedule; use crate::runtime::task::Inject; -use crate::runtime::{queue, MetricsBatch}; +use crate::runtime::thread_pool::queue; +use crate::runtime::MetricsBatch; use loom::thread; diff --git a/tokio/src/runtime/tests/mod.rs b/tokio/src/runtime/tests/mod.rs index 4b49698a86a..08724d43ee4 100644 --- a/tokio/src/runtime/tests/mod.rs +++ b/tokio/src/runtime/tests/mod.rs @@ -2,7 +2,7 @@ use self::unowned_wrapper::unowned; mod unowned_wrapper { use crate::runtime::blocking::NoopSchedule; - use crate::runtime::task::{JoinHandle, Notified}; + use crate::runtime::task::{Id, JoinHandle, Notified}; #[cfg(all(tokio_unstable, feature = "tracing"))] pub(crate) fn unowned(task: T) -> (Notified, JoinHandle) @@ -13,7 +13,7 @@ mod unowned_wrapper { use tracing::Instrument; let span = tracing::trace_span!("test_span"); let task = task.instrument(span); - let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next()); (task.into_notified(), handle) } @@ -23,7 +23,7 @@ mod unowned_wrapper { T: std::future::Future + Send + 'static, T::Output: Send + 'static, { - let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule); + let (task, handle) = crate::runtime::task::unowned(task, NoopSchedule, Id::next()); (task.into_notified(), handle) } } diff --git a/tokio/src/runtime/tests/queue.rs b/tokio/src/runtime/tests/queue.rs index 0fd1e0c6d9e..2bdaecf9f7c 100644 --- a/tokio/src/runtime/tests/queue.rs +++ b/tokio/src/runtime/tests/queue.rs @@ -1,5 +1,5 @@ -use crate::runtime::queue; use crate::runtime::task::{self, Inject, Schedule, Task}; +use crate::runtime::thread_pool::queue; use crate::runtime::MetricsBatch; use std::thread; diff --git a/tokio/src/runtime/tests/task.rs b/tokio/src/runtime/tests/task.rs index 04e1b56e777..173e5b0b23f 100644 --- a/tokio/src/runtime/tests/task.rs +++ b/tokio/src/runtime/tests/task.rs @@ -1,5 +1,5 @@ use crate::runtime::blocking::NoopSchedule; -use crate::runtime::task::{self, unowned, JoinHandle, OwnedTasks, Schedule, Task}; +use crate::runtime::task::{self, unowned, Id, JoinHandle, OwnedTasks, Schedule, Task}; use crate::util::TryLock; use std::collections::VecDeque; @@ -55,6 +55,7 @@ fn create_drop1() { unreachable!() }, NoopSchedule, + Id::next(), ); drop(notified); handle.assert_not_dropped(); @@ -71,6 +72,7 @@ fn create_drop2() { unreachable!() }, NoopSchedule, + Id::next(), ); drop(join); handle.assert_not_dropped(); @@ -78,6 +80,46 @@ fn create_drop2() { handle.assert_dropped(); } +#[test] +fn drop_abort_handle1() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + Id::next(), + ); + let abort = join.abort_handle(); + drop(join); + handle.assert_not_dropped(); + drop(notified); + handle.assert_not_dropped(); + drop(abort); + handle.assert_dropped(); +} + +#[test] +fn drop_abort_handle2() { + let (ad, handle) = AssertDrop::new(); + let (notified, join) = unowned( + async { + drop(ad); + unreachable!() + }, + NoopSchedule, + Id::next(), + ); + let abort = join.abort_handle(); + drop(notified); + handle.assert_not_dropped(); + drop(abort); + handle.assert_not_dropped(); + drop(join); + handle.assert_dropped(); +} + // Shutting down through Notified works #[test] fn create_shutdown1() { @@ -88,6 +130,7 @@ fn create_shutdown1() { unreachable!() }, NoopSchedule, + Id::next(), ); drop(join); handle.assert_not_dropped(); @@ -104,6 +147,7 @@ fn create_shutdown2() { unreachable!() }, NoopSchedule, + Id::next(), ); handle.assert_not_dropped(); notified.shutdown(); @@ -113,7 +157,7 @@ fn create_shutdown2() { #[test] fn unowned_poll() { - let (task, _) = unowned(async {}, NoopSchedule); + let (task, _) = unowned(async {}, NoopSchedule, Id::next()); task.run(); } @@ -228,7 +272,7 @@ impl Runtime { T: 'static + Send + Future, T::Output: 'static + Send, { - let (handle, notified) = self.0.owned.bind(future, self.clone()); + let (handle, notified) = self.0.owned.bind(future, self.clone(), Id::next()); if let Some(notified) = notified { self.schedule(notified); diff --git a/tokio/src/runtime/tests/task_combinations.rs b/tokio/src/runtime/tests/task_combinations.rs index 76ce2330c2c..5c7a0b0109b 100644 --- a/tokio/src/runtime/tests/task_combinations.rs +++ b/tokio/src/runtime/tests/task_combinations.rs @@ -3,6 +3,7 @@ use std::panic; use std::pin::Pin; use std::task::{Context, Poll}; +use crate::runtime::task::AbortHandle; use crate::runtime::Builder; use crate::sync::oneshot; use crate::task::JoinHandle; @@ -56,6 +57,12 @@ enum CombiAbort { AbortedAfterConsumeOutput = 4, } +#[derive(Copy, Clone, Debug, PartialEq)] +enum CombiAbortSource { + JoinHandle, + AbortHandle, +} + #[test] fn test_combinations() { let mut rt = &[ @@ -90,6 +97,13 @@ fn test_combinations() { CombiAbort::AbortedAfterFinish, CombiAbort::AbortedAfterConsumeOutput, ]; + let ah = [ + None, + Some(CombiJoinHandle::DropImmediately), + Some(CombiJoinHandle::DropFirstPoll), + Some(CombiJoinHandle::DropAfterNoConsume), + Some(CombiJoinHandle::DropAfterConsume), + ]; for rt in rt.iter().copied() { for ls in ls.iter().copied() { @@ -98,7 +112,34 @@ fn test_combinations() { for ji in ji.iter().copied() { for jh in jh.iter().copied() { for abort in abort.iter().copied() { - test_combination(rt, ls, task, output, ji, jh, abort); + // abort via join handle --- abort handles + // may be dropped at any point + for ah in ah.iter().copied() { + test_combination( + rt, + ls, + task, + output, + ji, + jh, + ah, + abort, + CombiAbortSource::JoinHandle, + ); + } + // if aborting via AbortHandle, it will + // never be dropped. + test_combination( + rt, + ls, + task, + output, + ji, + jh, + None, + abort, + CombiAbortSource::AbortHandle, + ); } } } @@ -108,6 +149,7 @@ fn test_combinations() { } } +#[allow(clippy::too_many_arguments)] fn test_combination( rt: CombiRuntime, ls: CombiLocalSet, @@ -115,12 +157,24 @@ fn test_combination( output: CombiOutput, ji: CombiJoinInterest, jh: CombiJoinHandle, + ah: Option, abort: CombiAbort, + abort_src: CombiAbortSource, ) { - if (jh as usize) < (abort as usize) { - // drop before abort not possible - return; + match (abort_src, ah) { + (CombiAbortSource::JoinHandle, _) if (jh as usize) < (abort as usize) => { + // join handle dropped prior to abort + return; + } + (CombiAbortSource::AbortHandle, Some(_)) => { + // abort handle dropped, we can't abort through the + // abort handle + return; + } + + _ => {} } + if (task == CombiTask::PanicOnDrop) && (output == CombiOutput::PanicOnDrop) { // this causes double panic return; @@ -130,7 +184,7 @@ fn test_combination( return; } - println!("Runtime {:?}, LocalSet {:?}, Task {:?}, Output {:?}, JoinInterest {:?}, JoinHandle {:?}, Abort {:?}", rt, ls, task, output, ji, jh, abort); + println!("Runtime {:?}, LocalSet {:?}, Task {:?}, Output {:?}, JoinInterest {:?}, JoinHandle {:?}, AbortHandle {:?}, Abort {:?} ({:?})", rt, ls, task, output, ji, jh, ah, abort, abort_src); // A runtime optionally with a LocalSet struct Rt { @@ -282,8 +336,24 @@ fn test_combination( ); } + // If we are either aborting the task via an abort handle, or dropping via + // an abort handle, do that now. + let mut abort_handle = if ah.is_some() || abort_src == CombiAbortSource::AbortHandle { + handle.as_ref().map(JoinHandle::abort_handle) + } else { + None + }; + + let do_abort = |abort_handle: &mut Option, + join_handle: Option<&mut JoinHandle<_>>| { + match abort_src { + CombiAbortSource::AbortHandle => abort_handle.take().unwrap().abort(), + CombiAbortSource::JoinHandle => join_handle.unwrap().abort(), + } + }; + if abort == CombiAbort::AbortedImmediately { - handle.as_mut().unwrap().abort(); + do_abort(&mut abort_handle, handle.as_mut()); aborted = true; } if jh == CombiJoinHandle::DropImmediately { @@ -301,12 +371,15 @@ fn test_combination( } if abort == CombiAbort::AbortedFirstPoll { - handle.as_mut().unwrap().abort(); + do_abort(&mut abort_handle, handle.as_mut()); aborted = true; } if jh == CombiJoinHandle::DropFirstPoll { drop(handle.take().unwrap()); } + if ah == Some(CombiJoinHandle::DropFirstPoll) { + drop(abort_handle.take().unwrap()); + } // Signal the future that it can return now let _ = on_complete.send(()); @@ -318,23 +391,42 @@ fn test_combination( if abort == CombiAbort::AbortedAfterFinish { // Don't set aborted to true here as the task already finished - handle.as_mut().unwrap().abort(); + do_abort(&mut abort_handle, handle.as_mut()); } if jh == CombiJoinHandle::DropAfterNoConsume { - // The runtime will usually have dropped every ref-count at this point, - // in which case dropping the JoinHandle drops the output. - // - // (But it might race and still hold a ref-count) - let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + if ah == Some(CombiJoinHandle::DropAfterNoConsume) { drop(handle.take().unwrap()); - })); - if panic.is_err() { - assert!( - (output == CombiOutput::PanicOnDrop) - && (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) - && !aborted, - "Dropping JoinHandle shouldn't panic here" - ); + // The runtime will usually have dropped every ref-count at this point, + // in which case dropping the AbortHandle drops the output. + // + // (But it might race and still hold a ref-count) + let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + drop(abort_handle.take().unwrap()); + })); + if panic.is_err() { + assert!( + (output == CombiOutput::PanicOnDrop) + && (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) + && !aborted, + "Dropping AbortHandle shouldn't panic here" + ); + } + } else { + // The runtime will usually have dropped every ref-count at this point, + // in which case dropping the JoinHandle drops the output. + // + // (But it might race and still hold a ref-count) + let panic = panic::catch_unwind(panic::AssertUnwindSafe(|| { + drop(handle.take().unwrap()); + })); + if panic.is_err() { + assert!( + (output == CombiOutput::PanicOnDrop) + && (!matches!(task, CombiTask::PanicOnRun | CombiTask::PanicOnRunAndDrop)) + && !aborted, + "Dropping JoinHandle shouldn't panic here" + ); + } } } @@ -362,11 +454,15 @@ fn test_combination( _ => unreachable!(), } - let handle = handle.take().unwrap(); + let mut handle = handle.take().unwrap(); if abort == CombiAbort::AbortedAfterConsumeOutput { - handle.abort(); + do_abort(&mut abort_handle, Some(&mut handle)); } drop(handle); + + if ah == Some(CombiJoinHandle::DropAfterConsume) { + drop(abort_handle.take()); + } } // The output should have been dropped now. Check whether the output diff --git a/tokio/src/runtime/thread_pool/mod.rs b/tokio/src/runtime/thread_pool/mod.rs index d3f46517cb0..ef6b5775ca2 100644 --- a/tokio/src/runtime/thread_pool/mod.rs +++ b/tokio/src/runtime/thread_pool/mod.rs @@ -3,14 +3,19 @@ mod idle; use self::idle::Idle; +mod park; +pub(crate) use park::{Parker, Unparker}; + +pub(super) mod queue; + mod worker; pub(crate) use worker::Launch; pub(crate) use worker::block_in_place; use crate::loom::sync::Arc; -use crate::runtime::task::JoinHandle; -use crate::runtime::{Callback, Parker}; +use crate::runtime::task::{self, JoinHandle}; +use crate::runtime::{Callback, Driver, HandleInner}; use std::fmt; use std::future::Future; @@ -42,11 +47,14 @@ pub(crate) struct Spawner { impl ThreadPool { pub(crate) fn new( size: usize, - parker: Parker, + driver: Driver, + handle_inner: HandleInner, before_park: Option, after_unpark: Option, ) -> (ThreadPool, Launch) { - let (shared, launch) = worker::create(size, parker, before_park, after_unpark); + let parker = Parker::new(driver); + let (shared, launch) = + worker::create(size, parker, handle_inner, before_park, after_unpark); let spawner = Spawner { shared }; let thread_pool = ThreadPool { spawner }; @@ -90,17 +98,21 @@ impl Drop for ThreadPool { impl Spawner { /// Spawns a future onto the thread pool - pub(crate) fn spawn(&self, future: F) -> JoinHandle + pub(crate) fn spawn(&self, future: F, id: task::Id) -> JoinHandle where F: crate::future::Future + Send + 'static, F::Output: Send + 'static, { - worker::Shared::bind_new_task(&self.shared, future) + worker::Shared::bind_new_task(&self.shared, future, id) } pub(crate) fn shutdown(&mut self) { self.shared.close(); } + + pub(crate) fn as_handle_inner(&self) -> &HandleInner { + self.shared.as_handle_inner() + } } cfg_metrics! { diff --git a/tokio/src/runtime/park.rs b/tokio/src/runtime/thread_pool/park.rs similarity index 100% rename from tokio/src/runtime/park.rs rename to tokio/src/runtime/thread_pool/park.rs diff --git a/tokio/src/runtime/queue.rs b/tokio/src/runtime/thread_pool/queue.rs similarity index 97% rename from tokio/src/runtime/queue.rs rename to tokio/src/runtime/thread_pool/queue.rs index ad9085a6545..1f5841d6dda 100644 --- a/tokio/src/runtime/queue.rs +++ b/tokio/src/runtime/thread_pool/queue.rs @@ -11,14 +11,14 @@ use std::ptr; use std::sync::atomic::Ordering::{AcqRel, Acquire, Relaxed, Release}; /// Producer handle. May only be used from a single thread. -pub(super) struct Local { +pub(crate) struct Local { inner: Arc>, } /// Consumer handle. May be used from many threads. -pub(super) struct Steal(Arc>); +pub(crate) struct Steal(Arc>); -pub(super) struct Inner { +pub(crate) struct Inner { /// Concurrently updated by many threads. /// /// Contains two `u16` values. The LSB byte is the "real" head of the queue. @@ -65,7 +65,7 @@ fn make_fixed_size(buffer: Box<[T]>) -> Box<[T; LOCAL_QUEUE_CAPACITY]> { } /// Create a new local run-queue -pub(super) fn local() -> (Steal, Local) { +pub(crate) fn local() -> (Steal, Local) { let mut buffer = Vec::with_capacity(LOCAL_QUEUE_CAPACITY); for _ in 0..LOCAL_QUEUE_CAPACITY { @@ -89,7 +89,7 @@ pub(super) fn local() -> (Steal, Local) { impl Local { /// Returns true if the queue has entries that can be stealed. - pub(super) fn is_stealable(&self) -> bool { + pub(crate) fn is_stealable(&self) -> bool { !self.inner.is_empty() } @@ -97,12 +97,12 @@ impl Local { /// /// Separate to is_stealable so that refactors of is_stealable to "protect" /// some tasks from stealing won't affect this - pub(super) fn has_tasks(&self) -> bool { + pub(crate) fn has_tasks(&self) -> bool { !self.inner.is_empty() } /// Pushes a task to the back of the local queue, skipping the LIFO slot. - pub(super) fn push_back( + pub(crate) fn push_back( &mut self, mut task: task::Notified, inject: &Inject, @@ -259,7 +259,7 @@ impl Local { } /// Pops a task from the local queue. - pub(super) fn pop(&mut self) -> Option> { + pub(crate) fn pop(&mut self) -> Option> { let mut head = self.inner.head.load(Acquire); let idx = loop { @@ -301,12 +301,12 @@ impl Local { } impl Steal { - pub(super) fn is_empty(&self) -> bool { + pub(crate) fn is_empty(&self) -> bool { self.0.is_empty() } /// Steals half the tasks from self and place them into `dst`. - pub(super) fn steal_into( + pub(crate) fn steal_into( &self, dst: &mut Local, dst_metrics: &mut MetricsBatch, diff --git a/tokio/src/runtime/thread_pool/worker.rs b/tokio/src/runtime/thread_pool/worker.rs index 7e4989701e5..3d58767f308 100644 --- a/tokio/src/runtime/thread_pool/worker.rs +++ b/tokio/src/runtime/thread_pool/worker.rs @@ -63,10 +63,9 @@ use crate::loom::sync::{Arc, Mutex}; use crate::park::{Park, Unpark}; use crate::runtime; use crate::runtime::enter::EnterContext; -use crate::runtime::park::{Parker, Unparker}; use crate::runtime::task::{Inject, JoinHandle, OwnedTasks}; -use crate::runtime::thread_pool::Idle; -use crate::runtime::{queue, task, Callback, MetricsBatch, SchedulerMetrics, WorkerMetrics}; +use crate::runtime::thread_pool::{queue, Idle, Parker, Unparker}; +use crate::runtime::{task, Callback, HandleInner, MetricsBatch, SchedulerMetrics, WorkerMetrics}; use crate::util::atomic_cell::AtomicCell; use crate::util::FastRand; @@ -122,11 +121,16 @@ struct Core { /// State shared across all workers pub(super) struct Shared { + /// Handle to the I/O driver, timer, blocking spawner, ... + handle_inner: HandleInner, + /// Per-worker remote state. All other workers have access to this and is /// how they communicate between each other. remotes: Box<[Remote]>, - /// Submits work to the scheduler while **not** currently on a worker thread. + /// Global task queue used for: + /// 1. Submit work to the scheduler while **not** currently on a worker thread. + /// 2. Submit work to the scheduler when a worker run queue is saturated inject: Inject>, /// Coordinates idle workers @@ -191,12 +195,13 @@ scoped_thread_local!(static CURRENT: Context); pub(super) fn create( size: usize, park: Parker, + handle_inner: HandleInner, before_park: Option, after_unpark: Option, ) -> (Arc, Launch) { - let mut cores = vec![]; - let mut remotes = vec![]; - let mut worker_metrics = vec![]; + let mut cores = Vec::with_capacity(size); + let mut remotes = Vec::with_capacity(size); + let mut worker_metrics = Vec::with_capacity(size); // Create the local queues for _ in 0..size { @@ -221,6 +226,7 @@ pub(super) fn create( } let shared = Arc::new(Shared { + handle_inner, remotes: remotes.into_boxed_slice(), inject: Inject::new(), idle: Idle::new(size), @@ -470,6 +476,17 @@ impl Context { core } + /// Parks the worker thread while waiting for tasks to execute. + /// + /// This function checks if indeed there's no more work left to be done before parking. + /// Also important to notice that, before parking, the worker thread will try to take + /// ownership of the Driver (IO/Time) and dispatch any events that might have fired. + /// Whenever a worker thread executes the Driver loop, all waken tasks are scheduled + /// in its own local queue until the queue saturates (ntasks > LOCAL_QUEUE_CAPACITY). + /// When the local queue is saturated, the overflow tasks are added to the injection queue + /// from where other workers can pick them up. + /// Also, we rely on the workstealing algorithm to spread the tasks amongst workers + /// after all the IOs get dispatched fn park(&self, mut core: Box) -> Box { if let Some(f) = &self.worker.shared.before_park { f(); @@ -545,6 +562,11 @@ impl Core { self.lifo_slot.take().or_else(|| self.run_queue.pop()) } + /// Function responsible for stealing tasks from another worker + /// + /// Note: Only if less than half the workers are searching for tasks to steal + /// a new worker will actually try to steal. The idea is to make sure not all + /// workers will be trying to steal at the same time. fn steal_work(&mut self, worker: &Worker) -> Option { if !self.transition_to_searching(worker) { return None; @@ -594,7 +616,7 @@ impl Core { /// Prepares the worker state for parking. /// - /// Returns true if the transition happend, false if there is work to do first. + /// Returns true if the transition happened, false if there is work to do first. fn transition_to_parked(&mut self, worker: &Worker) -> bool { // Workers should not park if they have work to do if self.lifo_slot.is_some() || self.run_queue.has_tasks() { @@ -697,12 +719,20 @@ impl task::Schedule for Arc { } impl Shared { - pub(super) fn bind_new_task(me: &Arc, future: T) -> JoinHandle + pub(crate) fn as_handle_inner(&self) -> &HandleInner { + &self.handle_inner + } + + pub(super) fn bind_new_task( + me: &Arc, + future: T, + id: crate::runtime::task::Id, + ) -> JoinHandle where T: Future + Send + 'static, T::Output: Send + 'static, { - let (handle, notified) = me.owned.bind(future, me.clone()); + let (handle, notified) = me.owned.bind(future, me.clone(), id); if let Some(notified) = notified { me.schedule(notified, false); @@ -835,6 +865,19 @@ impl Shared { } } +impl crate::runtime::ToHandle for Arc { + fn to_handle(&self) -> crate::runtime::Handle { + use crate::runtime::thread_pool::Spawner; + use crate::runtime::{self, Handle}; + + Handle { + spawner: runtime::Spawner::ThreadPool(Spawner { + shared: self.clone(), + }), + } + } +} + cfg_metrics! { impl Shared { pub(super) fn injection_queue_depth(&self) -> usize { diff --git a/tokio/src/signal/registry.rs b/tokio/src/signal/registry.rs index 6d8eb9e7487..e1b3d108767 100644 --- a/tokio/src/signal/registry.rs +++ b/tokio/src/signal/registry.rs @@ -1,10 +1,9 @@ #![allow(clippy::unit_arg)] use crate::signal::os::{OsExtraData, OsStorage}; - use crate::sync::watch; +use crate::util::once_cell::OnceCell; -use once_cell::sync::Lazy; use std::ops; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; @@ -152,19 +151,25 @@ impl Globals { } } +fn globals_init() -> Globals +where + OsExtraData: 'static + Send + Sync + Init, + OsStorage: 'static + Send + Sync + Init, +{ + Globals { + extra: OsExtraData::init(), + registry: Registry::new(OsStorage::init()), + } +} + pub(crate) fn globals() -> Pin<&'static Globals> where OsExtraData: 'static + Send + Sync + Init, OsStorage: 'static + Send + Sync + Init, { - static GLOBALS: Lazy>> = Lazy::new(|| { - Box::pin(Globals { - extra: OsExtraData::init(), - registry: Registry::new(OsStorage::init()), - }) - }); - - GLOBALS.as_ref() + static GLOBALS: OnceCell = OnceCell::new(); + + Pin::new(GLOBALS.get(globals_init)) } #[cfg(all(test, not(loom)))] @@ -237,7 +242,7 @@ mod tests { #[test] fn record_invalid_event_does_nothing() { let registry = Registry::new(vec![EventInfo::default()]); - registry.record_event(42); + registry.record_event(1302); } #[test] diff --git a/tokio/src/signal/reusable_box.rs b/tokio/src/signal/reusable_box.rs index 02f32474b16..796fa210b03 100644 --- a/tokio/src/signal/reusable_box.rs +++ b/tokio/src/signal/reusable_box.rs @@ -151,7 +151,6 @@ impl fmt::Debug for ReusableBoxFuture { } #[cfg(test)] -#[cfg(not(miri))] // Miri breaks when you use Pin<&mut dyn Future> mod test { use super::ReusableBoxFuture; use futures::future::FutureExt; diff --git a/tokio/src/signal/unix.rs b/tokio/src/signal/unix.rs index 86ea9a93ee6..11f848b5a99 100644 --- a/tokio/src/signal/unix.rs +++ b/tokio/src/signal/unix.rs @@ -22,13 +22,17 @@ use self::driver::Handle; pub(crate) type OsStorage = Vec; -// Number of different unix signals -// (FreeBSD has 33) -const SIGNUM: usize = 33; - impl Init for OsStorage { fn init() -> Self { - (0..SIGNUM).map(|_| SignalInfo::default()).collect() + // There are reliable signals ranging from 1 to 33 available on every Unix platform. + #[cfg(not(target_os = "linux"))] + let possible = 0..=33; + + // On Linux, there are additional real-time signals available. + #[cfg(target_os = "linux")] + let possible = 0..=libc::SIGRTMAX(); + + possible.map(|_| SignalInfo::default()).collect() } } @@ -60,7 +64,7 @@ impl Init for OsExtraData { } /// Represents the specific kind of signal to listen for. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] pub struct SignalKind(libc::c_int); impl SignalKind { @@ -84,6 +88,17 @@ impl SignalKind { Self(signum as libc::c_int) } + /// Get the signal's numeric value. + /// + /// ```rust + /// # use tokio::signal::unix::SignalKind; + /// let kind = SignalKind::interrupt(); + /// assert_eq!(kind.as_raw_value(), libc::SIGINT); + /// ``` + pub fn as_raw_value(&self) -> std::os::raw::c_int { + self.0 + } + /// Represents the SIGALRM signal. /// /// On Unix systems this signal is sent when a real-time timer has expired. @@ -190,6 +205,18 @@ impl SignalKind { } } +impl From for SignalKind { + fn from(signum: std::os::raw::c_int) -> Self { + Self::from_raw(signum as libc::c_int) + } +} + +impl From for std::os::raw::c_int { + fn from(kind: SignalKind) -> Self { + kind.as_raw_value() + } +} + pub(crate) struct SignalInfo { event_info: EventInfo, init: Once, @@ -380,6 +407,12 @@ impl Signal { /// /// `None` is returned if no more events can be received by this stream. /// + /// # Cancel safety + /// + /// This method is cancel safe. If you use it as the event in a + /// [`tokio::select!`](crate::select) statement and some other branch + /// completes first, then it is guaranteed that no signal is lost. + /// /// # Examples /// /// Wait for SIGHUP @@ -474,4 +507,15 @@ mod tests { ) .unwrap_err(); } + + #[test] + fn from_c_int() { + assert_eq!(SignalKind::from(2), SignalKind::interrupt()); + } + + #[test] + fn into_c_int() { + let value: std::os::raw::c_int = SignalKind::interrupt().into(); + assert_eq!(value, libc::SIGINT as _); + } } diff --git a/tokio/src/sync/barrier.rs b/tokio/src/sync/barrier.rs index dfc76a40ebf..b2f24bbadda 100644 --- a/tokio/src/sync/barrier.rs +++ b/tokio/src/sync/barrier.rs @@ -105,7 +105,7 @@ impl Barrier { n, wait, #[cfg(all(tokio_unstable, feature = "tracing"))] - resource_span: resource_span, + resource_span, } } diff --git a/tokio/src/sync/batch_semaphore.rs b/tokio/src/sync/batch_semaphore.rs index 4f5effff319..4db88351d0c 100644 --- a/tokio/src/sync/batch_semaphore.rs +++ b/tokio/src/sync/batch_semaphore.rs @@ -582,7 +582,7 @@ impl<'a> Acquire<'a> { tracing::trace!( target: "runtime::resource::async_op::state_update", - permits_obtained = 0 as usize, + permits_obtained = 0usize, permits.op = "override", ); diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index 0d9cd3bc176..846d6c027a6 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -230,7 +230,7 @@ pub mod error { /// /// [`recv`]: crate::sync::broadcast::Receiver::recv /// [`Receiver`]: crate::sync::broadcast::Receiver - #[derive(Debug, PartialEq)] + #[derive(Debug, PartialEq, Clone)] pub enum RecvError { /// There are no more active senders implying no further messages will ever /// be sent. @@ -258,7 +258,7 @@ pub mod error { /// /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv /// [`Receiver`]: crate::sync::broadcast::Receiver - #[derive(Debug, PartialEq)] + #[derive(Debug, PartialEq, Clone)] pub enum TryRecvError { /// The channel is currently empty. There are still active /// [`Sender`] handles, so data may yet become available. @@ -425,6 +425,11 @@ const MAX_RECEIVERS: usize = usize::MAX >> 2; /// tx.send(20).unwrap(); /// } /// ``` +/// +/// # Panics +/// +/// This will panic if `capacity` is equal to `0` or larger +/// than `usize::MAX / 2`. pub fn channel(mut capacity: usize) -> (Sender, Receiver) { assert!(capacity > 0, "capacity is empty"); assert!(capacity <= usize::MAX >> 1, "requested capacity too large"); @@ -691,6 +696,73 @@ impl Drop for Sender { } impl Receiver { + /// Returns the number of messages that were sent into the channel and that + /// this [`Receiver`] has yet to receive. + /// + /// If the returned value from `len` is larger than the next largest power of 2 + /// of the capacity of the channel any call to [`recv`] will return an + /// `Err(RecvError::Lagged)` and any call to [`try_recv`] will return an + /// `Err(TryRecvError::Lagged)`, e.g. if the capacity of the channel is 10, + /// [`recv`] will start to return `Err(RecvError::Lagged)` once `len` returns + /// values larger than 16. + /// + /// [`Receiver`]: crate::sync::broadcast::Receiver + /// [`recv`]: crate::sync::broadcast::Receiver::recv + /// [`try_recv`]: crate::sync::broadcast::Receiver::try_recv + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// + /// assert_eq!(rx1.len(), 2); + /// assert_eq!(rx1.recv().await.unwrap(), 10); + /// assert_eq!(rx1.len(), 1); + /// assert_eq!(rx1.recv().await.unwrap(), 20); + /// assert_eq!(rx1.len(), 0); + /// } + /// ``` + pub fn len(&self) -> usize { + let next_send_pos = self.shared.tail.lock().pos; + (next_send_pos - self.next) as usize + } + + /// Returns true if there aren't any messages in the channel that the [`Receiver`] + /// has yet to receive. + /// + /// [`Receiver]: create::sync::broadcast::Receiver + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel(16); + /// + /// assert!(rx1.is_empty()); + /// + /// tx.send(10).unwrap(); + /// tx.send(20).unwrap(); + /// + /// assert!(!rx1.is_empty()); + /// assert_eq!(rx1.recv().await.unwrap(), 10); + /// assert_eq!(rx1.recv().await.unwrap(), 20); + /// assert!(rx1.is_empty()); + /// } + /// ``` + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + /// Locks the next value if there is one. fn recv_ref( &mut self, diff --git a/tokio/src/sync/mpsc/error.rs b/tokio/src/sync/mpsc/error.rs index 3fe6bac5e17..1c789da0cdd 100644 --- a/tokio/src/sync/mpsc/error.rs +++ b/tokio/src/sync/mpsc/error.rs @@ -78,7 +78,7 @@ impl Error for TryRecvError {} // ===== RecvError ===== /// Error returned by `Receiver`. -#[derive(Debug)] +#[derive(Debug, Clone)] #[doc(hidden)] #[deprecated(note = "This type is unused because recv returns an Option.")] pub struct RecvError(()); diff --git a/tokio/src/sync/mpsc/unbounded.rs b/tokio/src/sync/mpsc/unbounded.rs index b133f9f35e3..f8338fb0885 100644 --- a/tokio/src/sync/mpsc/unbounded.rs +++ b/tokio/src/sync/mpsc/unbounded.rs @@ -79,8 +79,14 @@ impl UnboundedReceiver { /// Receives the next value for this receiver. /// - /// `None` is returned when all `Sender` halves have dropped, indicating - /// that no further values can be sent on the channel. + /// This method returns `None` if the channel has been closed and there are + /// no remaining messages in the channel's buffer. This indicates that no + /// further values can ever be received from this `Receiver`. The channel is + /// closed when all senders have been dropped, or when [`close`] is called. + /// + /// If there are no messages in the channel's buffer, but the channel has + /// not yet been closed, this method will sleep until a message is sent or + /// the channel is closed. /// /// # Cancel safety /// @@ -89,6 +95,8 @@ impl UnboundedReceiver { /// completes first, it is guaranteed that no messages were received on this /// channel. /// + /// [`close`]: Self::close + /// /// # Examples /// /// ``` @@ -207,6 +215,9 @@ impl UnboundedReceiver { /// /// This prevents any further messages from being sent on the channel while /// still enabling the receiver to drain messages that are buffered. + /// + /// To guarantee that no messages are dropped, after calling `close()`, + /// `recv()` must be called until `None` is returned. pub fn close(&mut self) { self.chan.close(); } diff --git a/tokio/src/sync/oneshot.rs b/tokio/src/sync/oneshot.rs index 2240074e733..d5fc811da2a 100644 --- a/tokio/src/sync/oneshot.rs +++ b/tokio/src/sync/oneshot.rs @@ -323,11 +323,11 @@ pub mod error { use std::fmt; /// Error returned by the `Future` implementation for `Receiver`. - #[derive(Debug, Eq, PartialEq)] + #[derive(Debug, Eq, PartialEq, Clone)] pub struct RecvError(pub(super) ()); /// Error returned by the `try_recv` function on `Receiver`. - #[derive(Debug, Eq, PartialEq)] + #[derive(Debug, Eq, PartialEq, Clone)] pub enum TryRecvError { /// The send half of the channel has not yet sent a value. Empty, @@ -526,7 +526,7 @@ pub fn channel() -> (Sender, Receiver) { let rx = Receiver { inner: Some(inner), #[cfg(all(tokio_unstable, feature = "tracing"))] - resource_span: resource_span, + resource_span, #[cfg(all(tokio_unstable, feature = "tracing"))] async_op_span, #[cfg(all(tokio_unstable, feature = "tracing"))] diff --git a/tokio/src/sync/watch.rs b/tokio/src/sync/watch.rs index 5673e0fca78..afeb7c2c9f7 100644 --- a/tokio/src/sync/watch.rs +++ b/tokio/src/sync/watch.rs @@ -60,6 +60,7 @@ use crate::loom::sync::atomic::Ordering::Relaxed; use crate::loom::sync::{Arc, RwLock, RwLockReadGuard}; use std::mem; use std::ops; +use std::panic; /// Receives values from the associated [`Sender`](struct@Sender). /// @@ -154,7 +155,7 @@ pub mod error { impl std::error::Error for SendError {} /// Error produced when receiving a change notification. - #[derive(Debug)] + #[derive(Debug, Clone)] pub struct RecvError(pub(super) ()); // ===== impl RecvError ===== @@ -466,6 +467,22 @@ impl Receiver { } } + /// Returns `true` if receivers belong to the same channel. + /// + /// # Examples + /// + /// ``` + /// let (tx, rx) = tokio::sync::watch::channel(true); + /// let rx2 = rx.clone(); + /// assert!(rx.same_channel(&rx2)); + /// + /// let (tx3, rx3) = tokio::sync::watch::channel(true); + /// assert!(!rx3.same_channel(&rx2)); + /// ``` + pub fn same_channel(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.shared, &other.shared) + } + cfg_process_driver! { pub(crate) fn try_has_changed(&mut self) -> Option> { maybe_changed(&self.shared, &mut self.version) @@ -530,27 +547,47 @@ impl Sender { Ok(()) } - /// Sends a new value via the channel, notifying all receivers and returning - /// the previous value in the channel. + /// Modifies watched value, notifying all receivers. /// - /// This can be useful for reusing the buffers inside a watched value. - /// Additionally, this method permits sending values even when there are no - /// receivers. + /// This can useful for modifying the watched value, without + /// having to allocate a new instance. Additionally, this + /// method permits sending values even when there are no receivers. + /// + /// # Panics + /// + /// This function panics if calling `func` results in a panic. + /// No receivers are notified if panic occurred, but if the closure has modified + /// the value, that change is still visible to future calls to `borrow`. /// /// # Examples /// /// ``` /// use tokio::sync::watch; /// - /// let (tx, _rx) = watch::channel(1); - /// assert_eq!(tx.send_replace(2), 1); - /// assert_eq!(tx.send_replace(3), 2); + /// struct State { + /// counter: usize, + /// } + /// let (state_tx, state_rx) = watch::channel(State { counter: 0 }); + /// state_tx.send_modify(|state| state.counter += 1); + /// assert_eq!(state_rx.borrow().counter, 1); /// ``` - pub fn send_replace(&self, value: T) -> T { - let old = { + pub fn send_modify(&self, func: F) + where + F: FnOnce(&mut T), + { + { // Acquire the write lock and update the value. let mut lock = self.shared.value.write().unwrap(); - let old = mem::replace(&mut *lock, value); + // Update the value and catch possible panic inside func. + let result = panic::catch_unwind(panic::AssertUnwindSafe(|| { + func(&mut lock); + })); + // If the func panicked return the panic to the caller. + if let Err(error) = result { + // Drop the lock to avoid poisoning it. + drop(lock); + panic::resume_unwind(error); + } self.shared.state.increment_version(); @@ -560,14 +597,32 @@ impl Sender { // that receivers are able to figure out the version number of the // value they are currently looking at. drop(lock); + } - old - }; - - // Notify all watchers self.shared.notify_rx.notify_waiters(); + } + + /// Sends a new value via the channel, notifying all receivers and returning + /// the previous value in the channel. + /// + /// This can be useful for reusing the buffers inside a watched value. + /// Additionally, this method permits sending values even when there are no + /// receivers. + /// + /// # Examples + /// + /// ``` + /// use tokio::sync::watch; + /// + /// let (tx, _rx) = watch::channel(1); + /// assert_eq!(tx.send_replace(2), 1); + /// assert_eq!(tx.send_replace(3), 2); + /// ``` + pub fn send_replace(&self, mut value: T) -> T { + // swap old watched value with the new one + self.send_modify(|old| mem::swap(old, &mut value)); - old + value } /// Returns a reference to the most recently sent value diff --git a/tokio/src/task/builder.rs b/tokio/src/task/builder.rs index 2086302fb92..976ecc3c4b0 100644 --- a/tokio/src/task/builder.rs +++ b/tokio/src/task/builder.rs @@ -108,8 +108,13 @@ impl<'a> Builder<'a> { Output: Send + 'static, { use crate::runtime::Mandatory; - let (join_handle, _was_spawned) = - context::current().spawn_blocking_inner(function, Mandatory::NonMandatory, self.name); + let handle = context::current(); + let (join_handle, _was_spawned) = handle.as_inner().spawn_blocking_inner( + function, + Mandatory::NonMandatory, + self.name, + &handle, + ); join_handle } } diff --git a/tokio/src/task/join_set.rs b/tokio/src/task/join_set.rs index 8e8f74f66d1..d036fcc3cda 100644 --- a/tokio/src/task/join_set.rs +++ b/tokio/src/task/join_set.rs @@ -4,7 +4,7 @@ use std::pin::Pin; use std::task::{Context, Poll}; use crate::runtime::Handle; -use crate::task::{JoinError, JoinHandle, LocalSet}; +use crate::task::{AbortHandle, Id, JoinError, JoinHandle, LocalSet}; use crate::util::IdleNotifiedSet; /// A collection of tasks spawned on a Tokio runtime. @@ -48,6 +48,7 @@ use crate::util::IdleNotifiedSet; /// ``` /// /// [unstable]: crate#unstable-features +#[cfg_attr(docsrs, doc(cfg(all(feature = "rt", tokio_unstable))))] pub struct JoinSet { inner: IdleNotifiedSet>, } @@ -72,61 +73,76 @@ impl JoinSet { } impl JoinSet { - /// Spawn the provided task on the `JoinSet`. + /// Spawn the provided task on the `JoinSet`, returning an [`AbortHandle`] + /// that can be used to remotely cancel the task. /// /// # Panics /// /// This method panics if called outside of a Tokio runtime. - pub fn spawn(&mut self, task: F) + /// + /// [`AbortHandle`]: crate::task::AbortHandle + pub fn spawn(&mut self, task: F) -> AbortHandle where F: Future, F: Send + 'static, T: Send, { - self.insert(crate::spawn(task)); + self.insert(crate::spawn(task)) } - /// Spawn the provided task on the provided runtime and store it in this `JoinSet`. - pub fn spawn_on(&mut self, task: F, handle: &Handle) + /// Spawn the provided task on the provided runtime and store it in this + /// `JoinSet` returning an [`AbortHandle`] that can be used to remotely + /// cancel the task. + /// + /// [`AbortHandle`]: crate::task::AbortHandle + pub fn spawn_on(&mut self, task: F, handle: &Handle) -> AbortHandle where F: Future, F: Send + 'static, T: Send, { - self.insert(handle.spawn(task)); + self.insert(handle.spawn(task)) } - /// Spawn the provided task on the current [`LocalSet`] and store it in this `JoinSet`. + /// Spawn the provided task on the current [`LocalSet`] and store it in this + /// `JoinSet`, returning an [`AbortHandle`] that can be used to remotely + /// cancel the task. /// /// # Panics /// /// This method panics if it is called outside of a `LocalSet`. /// /// [`LocalSet`]: crate::task::LocalSet - pub fn spawn_local(&mut self, task: F) + /// [`AbortHandle`]: crate::task::AbortHandle + pub fn spawn_local(&mut self, task: F) -> AbortHandle where F: Future, F: 'static, { - self.insert(crate::task::spawn_local(task)); + self.insert(crate::task::spawn_local(task)) } - /// Spawn the provided task on the provided [`LocalSet`] and store it in this `JoinSet`. + /// Spawn the provided task on the provided [`LocalSet`] and store it in + /// this `JoinSet`, returning an [`AbortHandle`] that can be used to + /// remotely cancel the task. /// /// [`LocalSet`]: crate::task::LocalSet - pub fn spawn_local_on(&mut self, task: F, local_set: &LocalSet) + /// [`AbortHandle`]: crate::task::AbortHandle + pub fn spawn_local_on(&mut self, task: F, local_set: &LocalSet) -> AbortHandle where F: Future, F: 'static, { - self.insert(local_set.spawn_local(task)); + self.insert(local_set.spawn_local(task)) } - fn insert(&mut self, jh: JoinHandle) { + fn insert(&mut self, jh: JoinHandle) -> AbortHandle { + let abort = jh.abort_handle(); let mut entry = self.inner.insert_idle(jh); // Set the waker that is notified when the task completes. entry.with_value_and_context(|jh, ctx| jh.set_join_waker(ctx.waker())); + abort } /// Waits until one of the tasks in the set completes and returns its output. @@ -139,6 +155,24 @@ impl JoinSet { /// statement and some other branch completes first, it is guaranteed that no tasks were /// removed from this `JoinSet`. pub async fn join_one(&mut self) -> Result, JoinError> { + crate::future::poll_fn(|cx| self.poll_join_one(cx)) + .await + .map(|opt| opt.map(|(_, res)| res)) + } + + /// Waits until one of the tasks in the set completes and returns its + /// output, along with the [task ID] of the completed task. + /// + /// Returns `None` if the set is empty. + /// + /// # Cancel Safety + /// + /// This method is cancel safe. If `join_one_with_id` is used as the event in a `tokio::select!` + /// statement and some other branch completes first, it is guaranteed that no tasks were + /// removed from this `JoinSet`. + /// + /// [task ID]: crate::task::Id + pub async fn join_one_with_id(&mut self) -> Result, JoinError> { crate::future::poll_fn(|cx| self.poll_join_one(cx)).await } @@ -175,8 +209,8 @@ impl JoinSet { /// Polls for one of the tasks in the set to complete. /// - /// If this returns `Poll::Ready(Ok(Some(_)))` or `Poll::Ready(Err(_))`, then the task that - /// completed is removed from the set. + /// If this returns `Poll::Ready(Some((_, Ok(_))))` or `Poll::Ready(Some((_, + /// Err(_)))`, then the task that completed is removed from the set. /// /// When the method returns `Poll::Pending`, the `Waker` in the provided `Context` is scheduled /// to receive a wakeup when a task in the `JoinSet` completes. Note that on multiple calls to @@ -189,17 +223,19 @@ impl JoinSet { /// /// * `Poll::Pending` if the `JoinSet` is not empty but there is no task whose output is /// available right now. - /// * `Poll::Ready(Ok(Some(value)))` if one of the tasks in this `JoinSet` has completed. The - /// `value` is the return value of one of the tasks that completed. + /// * `Poll::Ready(Ok(Some((id, value)))` if one of the tasks in this `JoinSet` has completed. The + /// `value` is the return value of one of the tasks that completed, while + /// `id` is the [task ID] of that task. /// * `Poll::Ready(Err(err))` if one of the tasks in this `JoinSet` has panicked or been - /// aborted. + /// aborted. The `err` is the `JoinError` from the panicked/aborted task. /// * `Poll::Ready(Ok(None))` if the `JoinSet` is empty. /// /// Note that this method may return `Poll::Pending` even if one of the tasks has completed. /// This can happen if the [coop budget] is reached. /// /// [coop budget]: crate::task#cooperative-scheduling - fn poll_join_one(&mut self, cx: &mut Context<'_>) -> Poll, JoinError>> { + /// [task ID]: crate::task::Id + fn poll_join_one(&mut self, cx: &mut Context<'_>) -> Poll, JoinError>> { // The call to `pop_notified` moves the entry to the `idle` list. It is moved back to // the `notified` list if the waker is notified in the `poll` call below. let mut entry = match self.inner.pop_notified(cx.waker()) { @@ -217,7 +253,10 @@ impl JoinSet { let res = entry.with_value_and_context(|jh, ctx| Pin::new(jh).poll(ctx)); if let Poll::Ready(res) = res { - entry.remove(); + let entry = entry.remove(); + // If the task succeeded, add the task ID to the output. Otherwise, the + // `JoinError` will already have the task's ID. + let res = res.map(|output| (entry.id(), output)); Poll::Ready(Some(res).transpose()) } else { // A JoinHandle generally won't emit a wakeup without being ready unless diff --git a/tokio/src/task/local.rs b/tokio/src/task/local.rs index 2dbd9706047..32e376872f4 100644 --- a/tokio/src/task/local.rs +++ b/tokio/src/task/local.rs @@ -301,12 +301,13 @@ cfg_rt! { where F: Future + 'static, F::Output: 'static { - let future = crate::util::trace::task(future, "local", name); + let id = crate::runtime::task::Id::next(); + let future = crate::util::trace::task(future, "local", name, id.as_u64()); CURRENT.with(|maybe_cx| { let cx = maybe_cx .expect("`spawn_local` called from outside of a `task::LocalSet`"); - let (handle, notified) = cx.owned.bind(future, cx.shared.clone()); + let (handle, notified) = cx.owned.bind(future, cx.shared.clone(), id); if let Some(notified) = notified { cx.shared.schedule(notified); @@ -385,9 +386,13 @@ impl LocalSet { F: Future + 'static, F::Output: 'static, { - let future = crate::util::trace::task(future, "local", None); + let id = crate::runtime::task::Id::next(); + let future = crate::util::trace::task(future, "local", None, id.as_u64()); - let (handle, notified) = self.context.owned.bind(future, self.context.shared.clone()); + let (handle, notified) = self + .context + .owned + .bind(future, self.context.shared.clone(), id); if let Some(notified) = notified { self.context.shared.schedule(notified); diff --git a/tokio/src/task/mod.rs b/tokio/src/task/mod.rs index d532155a1fe..cebc269bb40 100644 --- a/tokio/src/task/mod.rs +++ b/tokio/src/task/mod.rs @@ -303,6 +303,7 @@ cfg_rt! { cfg_unstable! { mod join_set; pub use join_set::JoinSet; + pub use crate::runtime::task::{Id, AbortHandle}; } cfg_trace! { diff --git a/tokio/src/task/spawn.rs b/tokio/src/task/spawn.rs index a9d736674c0..5a60f9d66e6 100644 --- a/tokio/src/task/spawn.rs +++ b/tokio/src/task/spawn.rs @@ -142,8 +142,10 @@ cfg_rt! { T: Future + Send + 'static, T::Output: Send + 'static, { - let spawn_handle = crate::runtime::context::spawn_handle().expect(CONTEXT_MISSING_ERROR); - let task = crate::util::trace::task(future, "task", name); - spawn_handle.spawn(task) + use crate::runtime::{task, context}; + let id = task::Id::next(); + let spawn_handle = context::spawn_handle().expect(CONTEXT_MISSING_ERROR); + let task = crate::util::trace::task(future, "task", name, id.as_u64()); + spawn_handle.spawn(task, id) } } diff --git a/tokio/src/time/driver/sleep.rs b/tokio/src/time/driver/sleep.rs index 7f27ef201f7..a629cb81ee9 100644 --- a/tokio/src/time/driver/sleep.rs +++ b/tokio/src/time/driver/sleep.rs @@ -72,7 +72,9 @@ pub fn sleep_until(deadline: Instant) -> Sleep { /// /// No work is performed while awaiting on the sleep future to complete. `Sleep` /// operates at millisecond granularity and should not be used for tasks that -/// require high-resolution timers. +/// require high-resolution timers. The implementation is platform specific, +/// and some platforms (specifically Windows) will provide timers with a +/// larger resolution than 1 ms. /// /// To run something regularly on a schedule, see [`interval`]. /// @@ -261,7 +263,7 @@ impl Sleep { let inner = { let time_source = handle.time_source().clone(); let deadline_tick = time_source.deadline_to_tick(deadline); - let duration = deadline_tick.checked_sub(time_source.now()).unwrap_or(0); + let duration = deadline_tick.saturating_sub(time_source.now()); let location = location.expect("should have location if tracing"); let resource_span = tracing::trace_span!( @@ -373,7 +375,7 @@ impl Sleep { let duration = { let now = me.inner.time_source.now(); let deadline_tick = me.inner.time_source.deadline_to_tick(deadline); - deadline_tick.checked_sub(now).unwrap_or(0) + deadline_tick.saturating_sub(now) }; tracing::trace!( diff --git a/tokio/src/util/mod.rs b/tokio/src/util/mod.rs index 618f5543802..ef0fdce6261 100644 --- a/tokio/src/util/mod.rs +++ b/tokio/src/util/mod.rs @@ -6,6 +6,15 @@ cfg_io_driver! { #[cfg(feature = "rt")] pub(crate) mod atomic_cell; +cfg_has_atomic_u64! { + #[cfg(any(feature = "signal", feature = "process"))] + pub(crate) mod once_cell; +} +cfg_not_has_atomic_u64! { + #[cfg(any(feature = "rt", feature = "signal", feature = "process"))] + pub(crate) mod once_cell; +} + #[cfg(any( // io driver uses `WakeList` directly feature = "net", diff --git a/tokio/src/util/once_cell.rs b/tokio/src/util/once_cell.rs new file mode 100644 index 00000000000..15639e6307f --- /dev/null +++ b/tokio/src/util/once_cell.rs @@ -0,0 +1,70 @@ +#![cfg_attr(loom, allow(dead_code))] +use std::cell::UnsafeCell; +use std::mem::MaybeUninit; +use std::sync::Once; + +pub(crate) struct OnceCell { + once: Once, + value: UnsafeCell>, +} + +unsafe impl Send for OnceCell {} +unsafe impl Sync for OnceCell {} + +impl OnceCell { + pub(crate) const fn new() -> Self { + Self { + once: Once::new(), + value: UnsafeCell::new(MaybeUninit::uninit()), + } + } + + /// Get the value inside this cell, intiailizing it using the provided + /// function if necessary. + /// + /// If the `init` closure panics, then the `OnceCell` is poisoned and all + /// future calls to `get` will panic. + #[inline] + pub(crate) fn get(&self, init: fn() -> T) -> &T { + if !self.once.is_completed() { + self.do_init(init); + } + + // Safety: The `std::sync::Once` guarantees that we can only reach this + // line if a `call_once` closure has been run exactly once and without + // panicking. Thus, the value is not uninitialized. + // + // There is also no race because the only `&self` method that modifies + // `value` is `do_init`, but if the `call_once` closure is still + // running, then no thread has gotten past the `call_once`. + unsafe { &*(self.value.get() as *const T) } + } + + #[cold] + fn do_init(&self, init: fn() -> T) { + let value_ptr = self.value.get() as *mut T; + + self.once.call_once(|| { + let set_to = init(); + + // Safety: The `std::sync::Once` guarantees that this initialization + // will run at most once, and that no thread can get past the + // `call_once` until it has run exactly once. Thus, we have + // exclusive access to `value`. + unsafe { + std::ptr::write(value_ptr, set_to); + } + }); + } +} + +impl Drop for OnceCell { + fn drop(&mut self) { + if self.once.is_completed() { + let value_ptr = self.value.get() as *mut T; + unsafe { + std::ptr::drop_in_place(value_ptr); + } + } + } +} diff --git a/tokio/src/util/trace.rs b/tokio/src/util/trace.rs index 6080e2358ae..76e8a6cbf55 100644 --- a/tokio/src/util/trace.rs +++ b/tokio/src/util/trace.rs @@ -10,7 +10,7 @@ cfg_trace! { #[inline] #[track_caller] - pub(crate) fn task(task: F, kind: &'static str, name: Option<&str>) -> Instrumented { + pub(crate) fn task(task: F, kind: &'static str, name: Option<&str>, id: u64) -> Instrumented { use tracing::instrument::Instrument; let location = std::panic::Location::caller(); let span = tracing::trace_span!( @@ -18,6 +18,7 @@ cfg_trace! { "runtime.spawn", %kind, task.name = %name.unwrap_or_default(), + task.id = id, loc.file = location.file(), loc.line = location.line(), loc.col = location.column(), @@ -91,7 +92,7 @@ cfg_time! { cfg_not_trace! { cfg_rt! { #[inline] - pub(crate) fn task(task: F, _: &'static str, _name: Option<&str>) -> F { + pub(crate) fn task(task: F, _: &'static str, _name: Option<&str>, _: u64) -> F { // nop task } diff --git a/tokio/tests/macros_rename_test.rs b/tokio/tests/macros_rename_test.rs new file mode 100644 index 00000000000..fd5554ced1f --- /dev/null +++ b/tokio/tests/macros_rename_test.rs @@ -0,0 +1,26 @@ +#![cfg(feature = "full")] + +#[allow(unused_imports)] +use std as tokio; + +use ::tokio as tokio1; + +async fn compute() -> usize { + let join = tokio1::spawn(async { 1 }); + join.await.unwrap() +} + +#[tokio1::main(crate = "tokio1")] +async fn compute_main() -> usize { + compute().await +} + +#[test] +fn crate_rename_main() { + assert_eq!(1, compute_main()); +} + +#[tokio1::test(crate = "tokio1")] +async fn crate_rename_test() { + assert_eq!(1, compute().await); +} diff --git a/tokio/tests/macros_select.rs b/tokio/tests/macros_select.rs index 755365affbc..c60a4a9506f 100644 --- a/tokio/tests/macros_select.rs +++ b/tokio/tests/macros_select.rs @@ -461,6 +461,7 @@ async fn many_branches() { x = async { 1 } => x, x = async { 1 } => x, x = async { 1 } => x, + x = async { 1 } => x, }; assert_eq!(1, num); diff --git a/tokio/tests/rt_metrics.rs b/tokio/tests/rt_metrics.rs index 0a26b80285d..1521cd26074 100644 --- a/tokio/tests/rt_metrics.rs +++ b/tokio/tests/rt_metrics.rs @@ -369,6 +369,40 @@ fn worker_local_queue_depth() { }); } +#[cfg(any(target_os = "linux", target_os = "macos"))] +#[test] +fn io_driver_fd_count() { + let rt = basic(); + let metrics = rt.metrics(); + + // Since this is enabled w/ the process driver we always + // have 1 fd registered. + assert_eq!(metrics.io_driver_fd_registered_count(), 1); + + let stream = tokio::net::TcpStream::connect("google.com:80"); + let stream = rt.block_on(async move { stream.await.unwrap() }); + + assert_eq!(metrics.io_driver_fd_registered_count(), 2); + assert_eq!(metrics.io_driver_fd_deregistered_count(), 0); + + drop(stream); + + assert_eq!(metrics.io_driver_fd_deregistered_count(), 1); + assert_eq!(metrics.io_driver_fd_registered_count(), 2); +} + +#[cfg(any(target_os = "linux", target_os = "macos"))] +#[test] +fn io_driver_ready_count() { + let rt = basic(); + let metrics = rt.metrics(); + + let stream = tokio::net::TcpStream::connect("google.com:80"); + let _stream = rt.block_on(async move { stream.await.unwrap() }); + + assert_eq!(metrics.io_driver_ready_count(), 2); +} + fn basic() -> Runtime { tokio::runtime::Builder::new_current_thread() .enable_all() diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs index 1b68eb7edbd..ca8b4d7f4ce 100644 --- a/tokio/tests/sync_broadcast.rs +++ b/tokio/tests/sync_broadcast.rs @@ -457,6 +457,25 @@ fn lagging_receiver_recovers_after_wrap_open() { assert_empty!(rx); } +#[test] +fn receiver_len_with_lagged() { + let (tx, mut rx) = broadcast::channel(3); + + tx.send(10).unwrap(); + tx.send(20).unwrap(); + tx.send(30).unwrap(); + tx.send(40).unwrap(); + + assert_eq!(rx.len(), 4); + assert_eq!(assert_recv!(rx), 10); + + tx.send(50).unwrap(); + tx.send(60).unwrap(); + + assert_eq!(rx.len(), 5); + assert_lagged!(rx.try_recv(), 1); +} + fn is_closed(err: broadcast::error::RecvError) -> bool { matches!(err, broadcast::error::RecvError::Closed) } diff --git a/tokio/tests/sync_mutex.rs b/tokio/tests/sync_mutex.rs index 51dbe03dc73..bcd9b1ebac6 100644 --- a/tokio/tests/sync_mutex.rs +++ b/tokio/tests/sync_mutex.rs @@ -155,7 +155,7 @@ fn try_lock() { let g1 = m.try_lock(); assert!(g1.is_ok()); let g2 = m.try_lock(); - assert!(!g2.is_ok()); + assert!(g2.is_err()); } let g3 = m.try_lock(); assert!(g3.is_ok()); diff --git a/tokio/tests/sync_mutex_owned.rs b/tokio/tests/sync_mutex_owned.rs index 2ce15de5b9d..98ced158390 100644 --- a/tokio/tests/sync_mutex_owned.rs +++ b/tokio/tests/sync_mutex_owned.rs @@ -122,7 +122,7 @@ fn try_lock_owned() { let g1 = m.clone().try_lock_owned(); assert!(g1.is_ok()); let g2 = m.clone().try_lock_owned(); - assert!(!g2.is_ok()); + assert!(g2.is_err()); } let g3 = m.try_lock_owned(); assert!(g3.is_ok()); diff --git a/tokio/tests/sync_watch.rs b/tokio/tests/sync_watch.rs index 8b9ea81bb89..d47f0df7326 100644 --- a/tokio/tests/sync_watch.rs +++ b/tokio/tests/sync_watch.rs @@ -211,3 +211,32 @@ fn reopened_after_subscribe() { drop(rx); assert!(tx.is_closed()); } + +#[test] +#[cfg(not(target_arch = "wasm32"))] // wasm currently doesn't support unwinding +fn send_modify_panic() { + let (tx, mut rx) = watch::channel("one"); + + tx.send_modify(|old| *old = "two"); + assert_eq!(*rx.borrow_and_update(), "two"); + + let mut rx2 = rx.clone(); + assert_eq!(*rx2.borrow_and_update(), "two"); + + let mut task = spawn(rx2.changed()); + + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + tx.send_modify(|old| { + *old = "panicked"; + panic!(); + }) + })); + assert!(result.is_err()); + + assert_pending!(task.poll()); + assert_eq!(*rx.borrow(), "panicked"); + + tx.send_modify(|old| *old = "three"); + assert_ready_ok!(task.poll()); + assert_eq!(*rx.borrow_and_update(), "three"); +} diff --git a/tokio/tests/task_join_set.rs b/tokio/tests/task_join_set.rs index 66a2fbb021d..470f861fe9b 100644 --- a/tokio/tests/task_join_set.rs +++ b/tokio/tests/task_join_set.rs @@ -106,6 +106,40 @@ async fn alternating() { } } +#[tokio::test(start_paused = true)] +async fn abort_tasks() { + let mut set = JoinSet::new(); + let mut num_canceled = 0; + let mut num_completed = 0; + for i in 0..16 { + let abort = set.spawn(async move { + tokio::time::sleep(Duration::from_secs(i as u64)).await; + i + }); + + if i % 2 != 0 { + // abort odd-numbered tasks. + abort.abort(); + } + } + loop { + match set.join_one().await { + Ok(Some(res)) => { + num_completed += 1; + assert_eq!(res % 2, 0); + } + Err(e) => { + assert!(e.is_cancelled()); + num_canceled += 1; + } + Ok(None) => break, + } + } + + assert_eq!(num_canceled, 8); + assert_eq!(num_completed, 8); +} + #[test] fn runtime_gone() { let mut set = JoinSet::new(); diff --git a/tokio/tests/uds_stream.rs b/tokio/tests/uds_stream.rs index 5f1b4cffbcc..b8c4e6a8eed 100644 --- a/tokio/tests/uds_stream.rs +++ b/tokio/tests/uds_stream.rs @@ -25,13 +25,13 @@ async fn accept_read_write() -> std::io::Result<()> { let connect = UnixStream::connect(&sock_path); let ((mut server, _), mut client) = try_join(accept, connect).await?; - // Write to the client. TODO: Switch to write_all. - let write_len = client.write(b"hello").await?; - assert_eq!(write_len, 5); + // Write to the client. + client.write_all(b"hello").await?; drop(client); - // Read from the server. TODO: Switch to read_to_end. - let mut buf = [0u8; 5]; - server.read_exact(&mut buf).await?; + + // Read from the server. + let mut buf = vec![]; + server.read_to_end(&mut buf).await?; assert_eq!(&buf, b"hello"); let len = server.read(&mut buf).await?; assert_eq!(len, 0);