Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Environment, Docs] SMACv2 and docs on action masking #1466

Merged
merged 42 commits into main from smacv2
Sep 15, 2023
Merged
Changes from 1 commit
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
7d291a7
init
vmoens Jul 27, 2023
32d7dda
amend
matteobettini Aug 18, 2023
58cbe43
amend
matteobettini Aug 18, 2023
d8913ec
amend
matteobettini Aug 21, 2023
97263eb
amend
matteobettini Aug 21, 2023
266a84d
amend
matteobettini Aug 21, 2023
7e6b9c9
amend
matteobettini Aug 21, 2023
646e233
add info
matteobettini Aug 21, 2023
e9f5257
add ci
matteobettini Aug 21, 2023
60a1563
add ci
matteobettini Aug 21, 2023
79ce182
amend
matteobettini Aug 21, 2023
a88e13b
amend
matteobettini Aug 21, 2023
28cd2b5
amend
matteobettini Aug 21, 2023
d0cf059
amend
matteobettini Aug 21, 2023
9014bc2
amend
matteobettini Aug 21, 2023
662dbb7
amend
matteobettini Aug 21, 2023
1b4325d
docs
matteobettini Aug 21, 2023
b6ea627
add tests
matteobettini Aug 22, 2023
9d9eb12
add group map
matteobettini Aug 29, 2023
0aec973
Merge branch 'main' into smacv2
matteobettini Aug 30, 2023
e5250d8
Merge branch 'main' into smacv2
matteobettini Sep 5, 2023
be12ac1
fixes
matteobettini Sep 5, 2023
74d7449
fix import
matteobettini Sep 5, 2023
111bbeb
collector test
matteobettini Sep 5, 2023
d6daa7c
review fixes
matteobettini Sep 5, 2023
613b4b6
change default categorical actions to true due to absence of one hot …
matteobettini Sep 5, 2023
09b0fc7
Merge branch 'main' into smacv2
matteobettini Sep 5, 2023
d6dd19b
add docs
matteobettini Sep 5, 2023
b399c40
amend
matteobettini Sep 5, 2023
f43ffed
amend
matteobettini Sep 5, 2023
64efb69
amend
matteobettini Sep 6, 2023
b1f0a05
Merge branch 'main' into smacv2
matteobettini Sep 6, 2023
d3b6d04
Merge branch 'main' into smacv2
matteobettini Sep 14, 2023
6a56db6
ci
matteobettini Sep 14, 2023
e2e90ca
add conditional ci
matteobettini Sep 14, 2023
efcd68e
Merge branch 'main' into smacv2
matteobettini Sep 14, 2023
ba10220
add conditional ci
matteobettini Sep 14, 2023
7778cca
import
matteobettini Sep 15, 2023
abd8281
test
matteobettini Sep 15, 2023
a3c2334
test
matteobettini Sep 15, 2023
b8aa329
test
matteobettini Sep 15, 2023
382d06b
empty
matteobettini Sep 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
add info
Signed-off-by: Matteo Bettini <[email protected]>
  • Loading branch information
matteobettini committed Aug 21, 2023
commit 646e2339e582c3b4c8a65fd1da95782c02a23e4e
41 changes: 38 additions & 3 deletions torchrl/envs/libs/smacv2.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -8,6 +8,7 @@
from tensordict import TensorDict, TensorDictBase

from torchrl.data import (
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
OneHotDiscreteTensorSpec,
Expand Down Expand Up @@ -153,6 +154,30 @@ def _make_observation_spec(self) -> CompositeSpec:
device=self.device,
dtype=torch.float32,
)
info_spec = CompositeSpec(
{
"battle_won": DiscreteTensorSpec(
2, dtype=torch.bool, device=self.device
),
"episode_limit": DiscreteTensorSpec(
2, dtype=torch.bool, device=self.device
),
"dead_allies": BoundedTensorSpec(
minimum=0,
maximum=self.n_agents,
dtype=torch.long,
device=self.device,
shape=(),
),
"dead_enemies": BoundedTensorSpec(
minimum=0,
maximum=self.n_enemies,
dtype=torch.long,
device=self.device,
shape=(),
),
}
)
mask_spec = DiscreteTensorSpec(
2,
torch.Size([self.n_agents, self.n_actions]),
Expand All @@ -170,6 +195,7 @@ def _make_observation_spec(self) -> CompositeSpec:
device=self.device,
dtype=torch.float32,
),
"info": info_spec,
}
)
return spec
Expand Down Expand Up @@ -200,6 +226,7 @@ def _reset(
# collect outputs
obs = self._to_tensor(obs)
state = self._to_tensor(state)
info = self.observation_spec["info"].zero()

mask = self.update_action_mask()

Expand All @@ -208,7 +235,7 @@ def _reset(
{"observation": obs, "mask": mask}, batch_size=(self.n_agents,)
)
tensordict_out = TensorDict(
source={"agents": agents_td, "state": state},
source={"agents": agents_td, "state": state, "info": info},
batch_size=(),
device=self.device,
)
Expand All @@ -226,9 +253,16 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
# collect outputs
obs = self.get_obs()
state = self.get_state()
info = self.observation_spec["info"].encode(info)
if "episode_limit" not in info.keys():
info["episode_limit"] = self.observation_spec["info"][
"episode_limit"
].zero()

reward = torch.tensor(reward, device=self.device, dtype=torch.float32)
done = torch.tensor(done, device=self.device, dtype=torch.bool)
reward = torch.tensor(
reward, device=self.device, dtype=torch.float32
).unsqueeze(-1)
done = torch.tensor(done, device=self.device, dtype=torch.bool).unsqueeze(-1)

mask = self.update_action_mask()

Expand All @@ -242,6 +276,7 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
"next": {
"agents": agents_td,
"state": state,
"info": info,
"reward": reward,
"done": done,
}
Expand Down
Loading