Skip to content

Commit

Permalink
[RLlib][RLModule] The flattening of 2D spaces is deprecated in the ne…
Browse files Browse the repository at this point in the history
…w stack, so we need to properly apply CNNs to PettingZoo pixel-based envs, even if they are grayscale (ray-project#33832)

Signed-off-by: Kourosh Hakhamaneshi <[email protected]>
  • Loading branch information
kouroshHakha committed Apr 4, 2023
1 parent 9fb9c74 commit 2ae4ea8
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion rllib/tests/test_pettingzoo_env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
from numpy import float32
from pettingzoo.butterfly import pistonball_v6
from pettingzoo.mpe import simple_spread_v2
from supersuit import normalize_obs_v0, dtype_v0, color_reduction_v0
from supersuit import (
color_reduction_v0,
dtype_v0,
normalize_obs_v0,
observation_lambda_v0,
resize_v1,
)
from supersuit.utils.convert_box import convert_box

import unittest

import ray
Expand All @@ -10,6 +18,16 @@
from ray.tune.registry import register_env


def change_observation(obs, obs_space):
# convert all images to a 3d array with 1 channel
obs = obs[..., None]
return obs


def change_obs_space(obs_space):
return convert_box(lambda obs: change_observation(obs, obs_space), obs_space)


# TODO(sven): Move into rllib/env/wrappers/tests/.
class TestPettingZooEnv(unittest.TestCase):
def setUp(self) -> None:
Expand All @@ -24,6 +42,11 @@ def env_creator(config):
env = dtype_v0(env, dtype=float32)
env = color_reduction_v0(env, mode="R")
env = normalize_obs_v0(env)
# add a wrapper to convert the observation space to a 3d array
env = observation_lambda_v0(env, change_observation, change_obs_space)
# resize the observation space to 84x84 so that RLlib defauls CNN can
# process it
env = resize_v1(env, x_size=84, y_size=84, linear_interp=True)
return env

# Register env
Expand Down

0 comments on commit 2ae4ea8

Please sign in to comment.