diff --git a/tests/integration/jupyter_tests.rs b/tests/integration/jupyter_tests.rs index af8101ea74611..75b1da085469e 100644 --- a/tests/integration/jupyter_tests.rs +++ b/tests/integration/jupyter_tests.rs @@ -47,25 +47,43 @@ impl ConnectionSpec { } } -fn pick_unused_port() -> u16 { +/// Gets an unused port from the OS, and returns the port number and a +/// `TcpListener` bound to that port. You can keep the listener alive +/// to prevent another process from binding to the port. +fn pick_unused_port() -> (u16, std::net::TcpListener) { let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - listener.local_addr().unwrap().port() + (listener.local_addr().unwrap().port(), listener) } -impl Default for ConnectionSpec { - fn default() -> Self { - Self { - key: "".into(), - signature_scheme: "hmac-sha256".into(), - transport: "tcp".into(), - ip: "127.0.0.1".into(), - hb_port: pick_unused_port(), - control_port: pick_unused_port(), - shell_port: pick_unused_port(), - stdin_port: pick_unused_port(), - iopub_port: pick_unused_port(), - kernel_name: "deno".into(), - } +impl ConnectionSpec { + fn new() -> (Self, Vec) { + let mut listeners = Vec::new(); + let (hb_port, listener) = pick_unused_port(); + listeners.push(listener); + let (control_port, listener) = pick_unused_port(); + listeners.push(listener); + let (shell_port, listener) = pick_unused_port(); + listeners.push(listener); + let (stdin_port, listener) = pick_unused_port(); + listeners.push(listener); + let (iopub_port, listener) = pick_unused_port(); + listeners.push(listener); + + ( + Self { + key: "".into(), + signature_scheme: "hmac-sha256".into(), + transport: "tcp".into(), + ip: "127.0.0.1".into(), + hb_port, + control_port, + shell_port, + stdin_port, + iopub_port, + kernel_name: "deno".into(), + }, + listeners, + ) } } @@ -191,25 +209,15 @@ async fn connect_socket( ) -> S { let addr = spec.endpoint(port); let mut socket = S::new(); - let mut connected = false; - for _ in 0..5 { - match timeout(Duration::from_secs(5), socket.connect(&addr)).await { - Ok(Ok(_)) => { - connected = true; - break; - } - Ok(Err(e)) => { - eprintln!("Failed to connect to {addr}: {e}"); - } - Err(e) => { - eprintln!("Timed out connecting to {addr}: {e}"); - } + match timeout(Duration::from_millis(5000), socket.connect(&addr)).await { + Ok(Ok(_)) => socket, + Ok(Err(e)) => { + panic!("Failed to connect to {addr}: {e}"); + } + Err(e) => { + panic!("Timed out connecting to {addr}: {e}"); } } - if !connected { - panic!("Failed to connect to {addr}"); - } - socket } #[derive(Clone)] @@ -236,7 +244,7 @@ use JupyterChannel::*; impl JupyterClient { async fn new(spec: &ConnectionSpec) -> Self { - Self::new_with_timeout(spec, Duration::from_secs(5)).await + Self::new_with_timeout(spec, Duration::from_secs(10)).await } async fn new_with_timeout(spec: &ConnectionSpec, timeout: Duration) -> Self { @@ -386,9 +394,36 @@ impl Drop for JupyterServerProcess { } } +async fn server_ready_on(addr: &str) -> bool { + matches!( + timeout( + Duration::from_millis(1000), + tokio::net::TcpStream::connect(addr.trim_start_matches("tcp://")), + ) + .await, + Ok(Ok(_)) + ) +} + +async fn server_ready(conn: &ConnectionSpec) -> bool { + let hb = conn.endpoint(conn.hb_port); + let control = conn.endpoint(conn.control_port); + let shell = conn.endpoint(conn.shell_port); + let stdin = conn.endpoint(conn.stdin_port); + let iopub = conn.endpoint(conn.iopub_port); + let (a, b, c, d, e) = tokio::join!( + server_ready_on(&hb), + server_ready_on(&control), + server_ready_on(&shell), + server_ready_on(&stdin), + server_ready_on(&iopub), + ); + a && b && c && d && e +} + async fn setup_server() -> (TestContext, ConnectionSpec, JupyterServerProcess) { let context = TestContextBuilder::new().use_temp_cwd().build(); - let mut conn = ConnectionSpec::default(); + let (mut conn, mut listeners) = ConnectionSpec::new(); let conn_file = context.temp_dir().path().join("connection.json"); conn_file.write_json(&conn); @@ -405,22 +440,38 @@ async fn setup_server() -> (TestContext, ConnectionSpec, JupyterServerProcess) { .unwrap() }; + // drop the listeners so the server can listen on the ports + drop(listeners); + // try to start the server, retrying up to 5 times // (this can happen due to TOCTOU errors with selecting unused TCP ports) let mut process = start_process(&conn_file); - tokio::time::sleep(Duration::from_millis(1000)).await; - - for _ in 0..5 { - if process.try_wait().unwrap().is_none() { - break; - } else { - conn = ConnectionSpec::default(); - conn_file.write_json(&conn); - process = start_process(&conn_file); - tokio::time::sleep(Duration::from_millis(1000)).await; + + 'outer: for i in 0..10 { + // try to see if the server is healthy + for _ in 0..10 { + // server still running? + if process.try_wait().unwrap().is_none() { + // listening on all ports? + if server_ready(&conn).await { + // server is ready to go + break 'outer; + } + } else { + // server exited, try again + break; + } + tokio::time::sleep(Duration::from_millis(500)).await; } + + // pick new ports and try again + (conn, listeners) = ConnectionSpec::new(); + conn_file.write_json(&conn); + drop(listeners); + process = start_process(&conn_file); + tokio::time::sleep(Duration::from_millis((i + 1) * 250)).await; } - if process.try_wait().unwrap().is_some() { + if process.try_wait().unwrap().is_some() || !server_ready(&conn).await { panic!("Failed to start Jupyter server"); } (context, conn, JupyterServerProcess(Some(process))) @@ -430,6 +481,9 @@ async fn setup() -> (TestContext, JupyterClient, JupyterServerProcess) { let (context, conn, process) = setup_server().await; let client = JupyterClient::new(&conn).await; client.io_subscribe("").await.unwrap(); + // make sure server is ready to receive messages + client.send_heartbeat(b"ping").await.unwrap(); + let _ = client.recv_heartbeat().await.unwrap(); (context, client, process) } @@ -530,7 +584,7 @@ async fn jupyter_execute_request() -> Result<()> { Err(e) => { if e.downcast_ref::().is_some() { // may timeout if we missed some messages - break; + eprintln!("Timed out waiting for messages"); } panic!("Error: {:#?}", e); }