Skip to content

Commit

Permalink
[Serve] Fix memory leak issue in serve inference (ray-project#27815)
Browse files Browse the repository at this point in the history
Signed-off-by: Stefan van der Kleij <[email protected]>
  • Loading branch information
sihanwang41 authored and Stefan van der Kleij committed Aug 18, 2022
1 parent 491f8d7 commit 85e0d36
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
7 changes: 6 additions & 1 deletion python/ray/dag/py_obj_scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,10 @@ def replace_nodes(self, table: Dict[SourceType, TransformedType]) -> Any:
def _replace_index(self, i: int) -> SourceType:
return self._replace_table[self._found[i]]

def clear(self):
"""Clear the scanner from the _instances"""
if id(self) in _instances:
del _instances[id(self)]

def __del__(self):
del _instances[id(self)]
self.clear()
35 changes: 34 additions & 1 deletion python/ray/dag/tests/test_py_obj_scanner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ray.dag.py_obj_scanner import _PyObjScanner
from ray.dag.py_obj_scanner import _PyObjScanner, _instances
import pytest


class Source:
Expand Down Expand Up @@ -31,3 +32,35 @@ def test_not_serializing_objects():

replaced = scanner.replace_nodes({obj: 1 for obj in found})
assert replaced == [not_serializable, {"key": 1}]


def test_scanner_clear():
"""Test scanner clear to make the scanner GCable"""
prev_len = len(_instances)

def call_find_nodes():
scanner = _PyObjScanner(source_type=Source)
my_objs = [Source(), [Source(), {"key": Source()}]]
scanner.find_nodes(my_objs)
scanner.clear()
assert id(scanner) not in _instances

call_find_nodes()
assert prev_len == len(_instances)

def call_find_and_replace_nodes():
scanner = _PyObjScanner(source_type=Source)
my_objs = [Source(), [Source(), {"key": Source()}]]
found = scanner.find_nodes(my_objs)
scanner.replace_nodes({obj: 1 for obj in found})
scanner.clear()
assert id(scanner) not in _instances

call_find_and_replace_nodes()
assert prev_len == len(_instances)


if __name__ == "__main__":
import sys

sys.exit(pytest.main(["-v", __file__]))
3 changes: 3 additions & 0 deletions python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ async def resolve_async_tasks(self):
replacement_table = dict(zip(tasks, resolved))
self.args, self.kwargs = scanner.replace_nodes(replacement_table)

# Make the scanner GCable to avoid memory leak
scanner.clear()


class ReplicaSet:
"""Data structure representing a set of replica actor handles"""
Expand Down

0 comments on commit 85e0d36

Please sign in to comment.