diff --git a/python/ray/tests/test_mpi.py b/python/ray/tests/test_mpi.py index 50b76266ccc9d..ac01b1e24fb69 100644 --- a/python/ray/tests/test_mpi.py +++ b/python/ray/tests/test_mpi.py @@ -2,7 +2,6 @@ import ray import sys import os -from mpi4py import MPI import numpy @@ -22,6 +21,8 @@ def compute_pi(samples): def run(): + from mpi4py import MPI + comm = MPI.COMM_WORLD nprocs = comm.Get_size() myrank = comm.Get_rank()