diff --git a/rllib/core/learner/learner_group.py b/rllib/core/learner/learner_group.py index cbb6870a72fdd..0ab6fb47f5696 100644 --- a/rllib/core/learner/learner_group.py +++ b/rllib/core/learner/learner_group.py @@ -1,6 +1,5 @@ from collections import deque import pathlib -import socket from typing import ( Any, Callable, @@ -489,16 +488,15 @@ def load_state(self, path: str) -> None: self._learner.load_state(path) else: assert len(self._workers) == self._worker_manager.num_healthy_actors() - head_node_ip = socket.gethostbyname(socket.gethostname()) + head_node_ip = ray.util.get_node_ip_address() workers = self._worker_manager.healthy_actor_ids() def _load_state(w): # doing imports here since they might not be imported on the worker - import socket + import ray import tempfile - hostname = socket.gethostname() - worker_node_ip = socket.gethostbyname(hostname) + worker_node_ip = ray.util.get_node_ip_address() # if the worker is on the same node as the head, load the checkpoint # directly from the path otherwise sync the checkpoint from the head # to the worker and load it from there @@ -540,11 +538,9 @@ def _get_ip_address(_=None) -> str: The address of this process. """ - import socket + import ray - hostname = socket.gethostname() - - return socket.gethostbyname(hostname) + return ray.util.get_node_ip_address() def shutdown(self): """Shuts down the LearnerGroup."""