Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Threaded collection and parallel envs #1559

Merged
merged 10 commits into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
tests
  • Loading branch information
vmoens committed Sep 22, 2023
commit 9021e190e5be7e8c9797319318da2d8c00f8304c
9 changes: 9 additions & 0 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from tensordict import tensorclass, TensorDict
from torchrl._utils import implement_for, seed_generator
from torchrl.data.utils import CloudpickleWrapper

from torchrl.envs import MultiThreadedEnv, ObservationNorm
from torchrl.envs.batched_envs import ParallelEnv, SerialEnv
Expand Down Expand Up @@ -433,3 +434,11 @@ def check_rollout_consistency_multikey_env(td: TensorDict, max_steps: int):
== td["nested_2", "observation"][~action_is_count]
).all()
assert (td["next", "nested_2", "reward"][~action_is_count] == 0).all()


def decorate_thread_sub_func(func, num_threads):
    """Wrap ``func`` so the call first checks the torch intra-op thread count.

    The returned callable asserts that ``torch.get_num_threads()`` equals
    ``num_threads`` before delegating to ``func`` unchanged.  It is wrapped in
    a ``CloudpickleWrapper`` so the decorated function can be pickled and
    shipped to worker processes.
    """

    def _checked_call(*args, **kwargs):
        # Fail loudly inside the worker if it was not configured with the
        # expected number of threads.
        assert torch.get_num_threads() == num_threads
        return func(*args, **kwargs)

    return CloudpickleWrapper(_checked_call)
31 changes: 31 additions & 0 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
# LICENSE file in the root directory of this source tree.

import argparse

import sys

import numpy as np
import pytest
import torch
from _utils_internal import (
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
generate_seeds,
PENDULUM_VERSIONED,
PONG_VERSIONED,
Expand Down Expand Up @@ -1783,6 +1785,35 @@ def make_env():
collector.shutdown()


def test_num_threads():
    """Check that MultiSyncDataCollector applies the requested thread counts.

    The main process should be switched to ``num_threads`` (7), and each
    collection sub-process to ``num_sub_threads`` (3).  The sub-process side
    is verified by monkey-patching the worker entry point
    (``_main_async_collector``) with a wrapper that asserts the thread count
    from inside the worker.
    """
    from torchrl.collectors import collectors

    _main_async_collector_saved = collectors._main_async_collector
    collectors._main_async_collector = decorate_thread_sub_func(
        collectors._main_async_collector, num_threads=3
    )
    num_threads = torch.get_num_threads()
    c = None
    try:
        env = ContinuousActionVecMockEnv()
        c = MultiSyncDataCollector(
            [env],
            policy=RandomPolicy(env.action_spec),
            num_threads=7,
            num_sub_threads=3,
            total_frames=200,
            frames_per_batch=200,
        )
        # The constructor must have raised the main-process thread count.
        assert torch.get_num_threads() == 7
        for _ in c:
            pass
    finally:
        # Shut the collector down even when an assertion above fails, so
        # worker processes do not outlive the test and hang pytest.
        if c is not None:
            c.shutdown()
            del c
        # reset vals
        collectors._main_async_collector = _main_async_collector_saved
        torch.set_num_threads(num_threads)


if __name__ == "__main__":
    # Run this test module directly: forward any extra CLI flags to pytest,
    # disable output capture, and stop at the first failure.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
24 changes: 24 additions & 0 deletions test/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
_make_envs,
CARTPOLE_VERSIONED,
check_rollout_consistency_multikey_env,
decorate_thread_sub_func,
get_default_devices,
HALFCHEETAH_VERSIONED,
PENDULUM_VERSIONED,
Expand Down Expand Up @@ -2088,6 +2089,29 @@ def test_mocking_envs(envclass):
check_env_specs(env, seed=100, return_contiguous=False)


def test_num_threads():
    """Check that ParallelEnv applies the requested thread counts.

    The main process should be switched to ``num_threads`` (7), and each env
    worker to ``num_sub_threads`` (3).  The worker side is verified by
    monkey-patching the worker entry point (``_run_worker_pipe_shared_mem``)
    with a wrapper that asserts the thread count from inside the worker.
    """
    from torchrl.envs import batched_envs

    _run_worker_pipe_shared_mem_save = batched_envs._run_worker_pipe_shared_mem
    batched_envs._run_worker_pipe_shared_mem = decorate_thread_sub_func(
        batched_envs._run_worker_pipe_shared_mem, num_threads=3
    )
    num_threads = torch.get_num_threads()
    env = None
    try:
        env = ParallelEnv(
            2, ContinuousActionVecMockEnv, num_sub_threads=3, num_threads=7
        )
        # We could test that the number of threads isn't changed until we start the procs.
        # Even though it's unlikely that we have 7 threads, we still disable this for safety
        # assert torch.get_num_threads() != 7
        env.rollout(3)
        assert torch.get_num_threads() == 7
    finally:
        # Close the env even on failure so its worker processes do not
        # outlive the test.
        if env is not None:
            env.close()
        # reset vals
        batched_envs._run_worker_pipe_shared_mem = _run_worker_pipe_shared_mem_save
        torch.set_num_threads(num_threads)


if __name__ == "__main__":
    # Run this test module directly: forward any extra CLI flags to pytest,
    # disable output capture, and stop at the first failure.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
3 changes: 1 addition & 2 deletions torchrl/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,8 @@ def __setstate__(self, ob: bytes):
self.fn, self.kwargs = pickle.loads(ob)

def __call__(self, *args, **kwargs) -> Any:
    """Call the wrapped function, forwarding positional args untouched.

    Keyword arguments captured at construction time (``self.kwargs``)
    override call-time keywords of the same name.
    """
    # NOTE(review): the scraped diff left both the old and new bodies here
    # (an unreachable duplicate ``return`` and a no-op dict-comprehension
    # copy); this is the reconstructed post-change body.  ``kwargs`` is a
    # fresh dict created by the ``**kwargs`` parameter, so updating it in
    # place cannot mutate any caller-owned mapping.
    kwargs.update(self.kwargs)
    return self.fn(*args, **kwargs)


def _process_action_space_spec(action_space, spec):
Expand Down
Loading