Skip to content

Commit

Permalink
[RLlib] Replace calls to socket in learner group for getting ip address with ray (ray-project#35218)
Browse files Browse the repository at this point in the history

Signed-off-by: Avnish <[email protected]>
  • Loading branch information
avnishn committed May 10, 2023
1 parent ce95fb9 commit 11d4cdb
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions rllib/core/learner/learner_group.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections import deque
import pathlib
import socket
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -489,16 +488,15 @@ def load_state(self, path: str) -> None:
self._learner.load_state(path)
else:
assert len(self._workers) == self._worker_manager.num_healthy_actors()
head_node_ip = socket.gethostbyname(socket.gethostname())
head_node_ip = ray.util.get_node_ip_address()
workers = self._worker_manager.healthy_actor_ids()

def _load_state(w):
# doing imports here since they might not be imported on the worker
import socket
import ray
import tempfile

hostname = socket.gethostname()
worker_node_ip = socket.gethostbyname(hostname)
worker_node_ip = ray.util.get_node_ip_address()
# if the worker is on the same node as the head, load the checkpoint
# directly from the path otherwise sync the checkpoint from the head
# to the worker and load it from there
Expand Down Expand Up @@ -540,11 +538,9 @@ def _get_ip_address(_=None) -> str:
The address of this process.
"""
import socket
import ray

hostname = socket.gethostname()

return socket.gethostbyname(hostname)
return ray.util.get_node_ip_address()

def shutdown(self):
"""Shuts down the LearnerGroup."""
Expand Down

0 comments on commit 11d4cdb

Please sign in to comment.