Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] step_and_maybe_reset in env #1611

Merged
merged 122 commits into from
Oct 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
122 commits
Select commit Hold shift + click to select a range
1f539dd
init
vmoens Oct 6, 2023
d4c16e1
amend
vmoens Oct 6, 2023
d2321aa
amend
vmoens Oct 6, 2023
565115a
amend
vmoens Oct 6, 2023
3c46136
amend
vmoens Oct 6, 2023
78cfa41
amend
vmoens Oct 6, 2023
a6bd8eb
amend
vmoens Oct 6, 2023
04d4ae7
amend
vmoens Oct 6, 2023
2d0b4c6
init
vmoens Oct 10, 2023
528609a
amend
vmoens Oct 10, 2023
51cd5af
fix
vmoens Oct 10, 2023
3e31963
Merge remote-tracking branch 'origin/main' into step_maybe_reset
vmoens Oct 10, 2023
b11e73a
amend
vmoens Oct 10, 2023
1540407
remove `pip3 install -e .`
vmoens Oct 10, 2023
f1b0ea4
tensordict_
vmoens Oct 10, 2023
16b3538
amend rollout logic
vmoens Oct 10, 2023
7ad0864
amend
vmoens Oct 10, 2023
bcac398
amend
vmoens Oct 10, 2023
428f8ee
inference
vmoens Oct 10, 2023
6fbd0bd
cpu -> cuda
vmoens Oct 10, 2023
02db623
checks
vmoens Oct 10, 2023
08a8f47
using pipe instead of event
vmoens Oct 10, 2023
45e64f7
amend
vmoens Oct 10, 2023
7dd4821
amend
vmoens Oct 10, 2023
e1a2206
rm cuda event
vmoens Oct 10, 2023
dc2caab
amend
vmoens Oct 10, 2023
01ffbf9
amend
vmoens Oct 10, 2023
ac76ec3
amend
vmoens Oct 10, 2023
ceab010
amend
vmoens Oct 10, 2023
f0327c9
amend
vmoens Oct 10, 2023
518b3d1
amend
vmoens Oct 10, 2023
47dd93b
amend
vmoens Oct 10, 2023
354fb6f
amend
vmoens Oct 10, 2023
53d5f9a
amend
vmoens Oct 10, 2023
78c00e8
amend
vmoens Oct 10, 2023
f63480e
amend
vmoens Oct 10, 2023
9a3631f
amend
vmoens Oct 10, 2023
44336ed
Merge remote-tracking branch 'origin/main' into fix_ci
vmoens Oct 10, 2023
512f9f7
specs fix
vmoens Oct 10, 2023
2ceb438
amend
vmoens Oct 10, 2023
c666d5a
amend
vmoens Oct 10, 2023
6ecebda
amend
vmoens Oct 10, 2023
ce6e9bd
amend
vmoens Oct 10, 2023
5c613c3
amend
vmoens Oct 10, 2023
9f97e58
amend
vmoens Oct 10, 2023
9cbcbb0
amend
vmoens Oct 10, 2023
ae2748d
amend
vmoens Oct 10, 2023
72c4163
amend
vmoens Oct 10, 2023
6f4c374
amend
vmoens Oct 10, 2023
bf36bec
amend
vmoens Oct 10, 2023
4095766
amend
vmoens Oct 11, 2023
9f2a9ad
amend
vmoens Oct 11, 2023
5b34961
amend
vmoens Oct 11, 2023
deba78d
amend
vmoens Oct 11, 2023
b638461
amend
vmoens Oct 11, 2023
f8b2f60
amend
vmoens Oct 11, 2023
dfc868a
amend
vmoens Oct 11, 2023
1221c48
Merge branch 'fix_ci' into step_maybe_reset
vmoens Oct 11, 2023
22d3a27
amend
vmoens Oct 11, 2023
38e0e90
amend
vmoens Oct 11, 2023
5828a84
Merge branch 'fix_ci' into step_maybe_reset
vmoens Oct 11, 2023
004d381
amend
vmoens Oct 11, 2023
6ef720d
amend
vmoens Oct 11, 2023
d2167ea
Merge branch 'fix_ci' into step_maybe_reset
vmoens Oct 11, 2023
f66c6c9
amend
vmoens Oct 11, 2023
4fd2768
amend
vmoens Oct 11, 2023
ded1883
amend
vmoens Oct 11, 2023
cb1a83f
amend
vmoens Oct 11, 2023
6af6a45
amend
vmoens Oct 11, 2023
eaf5ebd
amend
vmoens Oct 12, 2023
f426322
Merge branch 'fix_ci' into step_maybe_reset
vmoens Oct 12, 2023
83672c6
amend
vmoens Oct 12, 2023
f0a6134
amend
vmoens Oct 12, 2023
c589a4a
amend
vmoens Oct 12, 2023
a85b321
amend
vmoens Oct 12, 2023
ed5e96f
amend
vmoens Oct 13, 2023
e705719
amend
vmoens Oct 16, 2023
4c3fa33
fix
vmoens Oct 17, 2023
e354bc8
fix
vmoens Oct 18, 2023
faa3b41
amend
vmoens Oct 18, 2023
80120b4
fix
vmoens Oct 18, 2023
09205e5
empty
vmoens Oct 18, 2023
3c41dab
amend
vmoens Oct 18, 2023
4db8110
amend
vmoens Oct 18, 2023
38cac91
empty
vmoens Oct 18, 2023
a336f0e
Merge remote-tracking branch 'origin/main' into step_maybe_reset
vmoens Oct 18, 2023
a9f4678
amend
vmoens Oct 18, 2023
9ee6348
amend
vmoens Oct 19, 2023
62c848b
amend
vmoens Oct 19, 2023
684d527
amend
vmoens Oct 19, 2023
89ddfd2
amend
vmoens Oct 19, 2023
a0376e4
amend
vmoens Oct 19, 2023
44dfe86
amend
vmoens Oct 19, 2023
d332597
amend
vmoens Oct 19, 2023
85cf664
amend
vmoens Oct 19, 2023
b55fae1
Merge remote-tracking branch 'origin/main' into step_maybe_reset
vmoens Oct 19, 2023
5dcbf15
amend
vmoens Oct 19, 2023
0236818
amend
vmoens Oct 20, 2023
36f151d
amend
vmoens Oct 20, 2023
afecc65
amend
vmoens Oct 20, 2023
f4a0beb
amend
vmoens Oct 20, 2023
5a331e5
amend
vmoens Oct 20, 2023
1ee4ba7
amend
vmoens Oct 20, 2023
9d14427
amend
vmoens Oct 20, 2023
8c15741
amend
vmoens Oct 20, 2023
1fe2cd3
amend
vmoens Oct 20, 2023
c378acc
amend
vmoens Oct 20, 2023
551c9eb
amend
vmoens Oct 20, 2023
e177c77
amend
vmoens Oct 20, 2023
28b0059
amend
vmoens Oct 20, 2023
5a21f2a
amend
vmoens Oct 20, 2023
b7b2081
lint
vmoens Oct 20, 2023
de1ecf2
amend
vmoens Oct 21, 2023
6fec99e
amend
vmoens Oct 21, 2023
6bfe517
amend
vmoens Oct 22, 2023
8f78ac0
cache keys
vmoens Oct 23, 2023
d2b734b
fix empty cache
vmoens Oct 23, 2023
a0c12d5
Merge remote-tracking branch 'origin/main' into step_maybe_reset
vmoens Oct 23, 2023
83eb52f
amend
vmoens Oct 24, 2023
5943c51
Merge remote-tracking branch 'origin/main' into step_maybe_reset
vmoens Oct 24, 2023
3e86847
addressing comments
vmoens Oct 24, 2023
ca1dd78
amend
vmoens Oct 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
amend
  • Loading branch information
vmoens committed Oct 21, 2023
commit de1ecf24b73e8c88310550d61111d387881f22e1
2 changes: 1 addition & 1 deletion examples/dqn/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ max_frames_per_traj: -1
weight_decay: 0.0
annealing_frames: 1000000
init_env_steps: 10000
record_frames: 50000
record_frames: 5000
loss_function: smooth_l1
batch_transform: 1
buffer_prefetch: 64
Expand Down
3 changes: 2 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -7511,7 +7511,8 @@ def test_dt_tensordict_keys(self):
loss_fn = DTLoss(actor)

default_keys = {
"action": "action",
"action_target": "action",
"action_pred": "action",
}

self.tensordict_keys_test(
Expand Down
4 changes: 2 additions & 2 deletions test/test_tensordictmodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -2191,10 +2191,10 @@ def test_dt_inference_wrapper(self, online):
)
with pytest.raises(
ValueError,
match="The action key action was not found in the policy out_keys",
match="The value of out_action_key",
):
result = inference_actor(td)
inference_actor.set_tensor_keys(action=action_key)
inference_actor.set_tensor_keys(action=action_key, out_action=action_key)
result = inference_actor(td)
# checks that the seq length has disappeared
assert result.get(action_key).shape == torch.Size([1, 2])
Expand Down
34 changes: 18 additions & 16 deletions torchrl/envs/libs/pettingzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import copy
import importlib
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Tuple, Union

import torch
from tensordict.tensordict import TensorDictBase
Expand Down Expand Up @@ -154,11 +155,11 @@ def __init__(
"pettingzoo.utils.env.ParallelEnv", # noqa: F821
"pettingzoo.utils.env.AECEnv", # noqa: F821
] = None,
return_state: Optional[bool] = False,
group_map: Optional[Union[MarlGroupMapType, Dict[str, List[str]]]] = None,
return_state: bool = False,
group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
use_mask: bool = False,
categorical_actions: bool = True,
seed: Optional[int] = None,
seed: int | None = None,
**kwargs,
):
if env is not None:
Expand Down Expand Up @@ -401,7 +402,7 @@ def _check_kwargs(self, kwargs: Dict):
):
raise TypeError("env is not of type expected.")

def _init_env(self) -> Optional[int]:
def _init_env(self):
# Add info
if self.parallel:
_, info_dict = self._reset_parallel(seed=self.seed)
Expand Down Expand Up @@ -477,15 +478,16 @@ def _set_seed(self, seed: int):
self.reset(seed=self.seed)

def _reset(
self, tensordict: Optional[TensorDictBase] = None, **kwargs
self, tensordict: TensorDictBase | None = None, **kwargs
) -> TensorDictBase:

_reset = tensordict.get("_reset", None)
if _reset is not None and not _reset.all():
raise RuntimeError(
f"An attempt to call {type(self)}._reset was made when no reset signal could be found. "
f"Expected '_reset' entry to be `tensor(True)` or `None` but got `{_reset}`."
)
if tensordict is not None:
_reset = tensordict.get("_reset", None)
if _reset is not None and not _reset.all():
raise RuntimeError(
f"An attempt to call {type(self)}._reset was made when no "
f"reset signal could be found. Expected '_reset' entry to "
f"be `tensor(True)` or `None` but got `{_reset}`."
)
if self.parallel:
# This resets when any is done
observation_dict, info_dict = self._reset_parallel(**kwargs)
Expand Down Expand Up @@ -878,11 +880,11 @@ def __init__(
self,
task: str,
parallel: bool,
return_state: Optional[bool] = False,
group_map: Optional[Union[MarlGroupMapType, Dict[str, List[str]]]] = None,
return_state: bool = False,
group_map: MarlGroupMapType | Dict[str, List[str]] | None = None,
use_mask: bool = False,
categorical_actions: bool = True,
seed: Optional[int] = None,
seed: int | None = None,
**kwargs,
):
if not _has_pettingzoo:
Expand Down
20 changes: 16 additions & 4 deletions torchrl/record/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ def dump(self, suffix: Optional[str] = None) -> None:
self.count = 0
self.obs = []

def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
self._call(tensordict_reset)
return tensordict_reset


class TensorDictRecorder(Transform):
"""TensorDict recorder.
Expand Down Expand Up @@ -171,14 +177,14 @@ def __init__(
self.skip = skip
self.count = 0

def _call(self, td: TensorDictBase) -> TensorDictBase:
def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
self.count += 1
if self.count % self.skip == 0:
_td = td
_td = tensordict
if self.in_keys:
_td = td.select(*self.in_keys).to_tensordict()
_td = tensordict.select(*self.in_keys).to_tensordict()
self.td.append(_td)
return td
return tensordict

def dump(self, suffix: Optional[str] = None) -> None:
if suffix is None:
Expand All @@ -197,3 +203,9 @@ def dump(self, suffix: Optional[str] = None) -> None:
self.count = 0
del self.td
self.td = []

def _reset(
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
) -> TensorDictBase:
self._call(tensordict_reset)
return tensordict_reset
6 changes: 6 additions & 0 deletions torchrl/trainers/helpers/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def make_trainer(
)

if recorder is not None:
# create recorder object
recorder_obj = Recorder(
record_frames=cfg.record_frames,
frame_skip=cfg.frame_skip,
Expand All @@ -266,11 +267,14 @@ def make_trainer(
record_interval=cfg.record_interval,
log_keys=cfg.recorder_log_keys,
)
# register recorder
trainer.register_op(
"post_steps_log",
recorder_obj,
)
# call recorder - could be removed
recorder_obj(None)
# create explorative recorder - could be optional
recorder_obj_explore = Recorder(
record_frames=cfg.record_frames,
frame_skip=cfg.frame_skip,
Expand All @@ -281,10 +285,12 @@ def make_trainer(
suffix="exploration",
out_keys={("next", "reward"): "r_evaluation_exploration"},
)
# register recorder
trainer.register_op(
"post_steps_log",
recorder_obj_explore,
)
# call recorder - could be removed
recorder_obj_explore(None)

trainer.register_op(
Expand Down
Loading