Don't pass envpool envs where vectorenvs are needed #1096

Open
MischaPanch opened this issue Apr 3, 2024 · 0 comments
Labels: bug (Something isn't working), good first issue (Good for newcomers), refactoring (No change to functionality)

MischaPanch (Collaborator) commented Apr 3, 2024

See the block comments in the test and in the Collector method. Somewhere, a raw envpool env is passed where an instance of BaseVectorEnv is expected, so the interface is not respected.

This means we rely on the two interfaces coinciding by accident, and they already don't fully coincide: envpool envs return `info` as a single dict of arrays (one entry per env along the first axis), whereas tianshou's VectorEnvs return an array of dicts (one dict per env).
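
The difference between the two info formats can be made concrete with a small converter. This is a minimal sketch in plain NumPy, assuming every leaf of the envpool info dict is an array whose first axis indexes the envs. `dict_of_arr_to_arr_of_dicts` and its helpers are hypothetical names; they are not tianshou's private `_dict_of_arr_to_arr_of_dicts`, whose implementation may differ.

```python
from typing import Any

import numpy as np


def _leading_dim(info: dict[str, Any]) -> int:
    """Length of the first array found, recursing into nested dicts."""
    for value in info.values():
        if isinstance(value, dict):
            return _leading_dim(value)
        return len(value)
    raise ValueError("cannot infer the number of envs from an empty info dict")


def _select(info: dict[str, Any], i: int) -> dict[str, Any]:
    """Pick entry i of every array, keeping the (possibly nested) dict structure."""
    return {k: _select(v, i) if isinstance(v, dict) else v[i] for k, v in info.items()}


def dict_of_arr_to_arr_of_dicts(info: dict[str, Any]) -> np.ndarray:
    """Hypothetical helper: {key: array of length N} -> object array of N dicts."""
    num_envs = _leading_dim(info)
    result = np.empty(num_envs, dtype=object)
    for i in range(num_envs):
        result[i] = _select(info, i)
    return result


# envpool-style info: one dict whose values hold one entry per env
pool_info = {"env_id": np.arange(3), "elapsed_step": np.array([5, 5, 7])}
per_env_info = dict_of_arr_to_arr_of_dicts(pool_info)
print(per_env_info[2]["elapsed_step"])  # prints 7
```

Building the result with `np.empty(..., dtype=object)` and assigning element by element guarantees a 1-D array of dicts, rather than letting NumPy guess a shape from the dicts themselves.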

@Trinkle23897 this issue might be of interest to you

From the test `test_venv_wrapper_envpool_gym_reset_return_info`:

```python
import pytest

from tianshou.env import VectorEnvNormObs

try:
    import envpool
except ImportError:
    envpool = None


@pytest.mark.skipif(envpool is None, reason="EnvPool doesn't support this platform")
def test_venv_wrapper_envpool_gym_reset_return_info() -> None:
    num_envs = 4
    env = VectorEnvNormObs(
        envpool.make_gymnasium("Ant-v3", num_envs=num_envs, gym_reset_return_info=True),
    )
    obs, info = env.reset()
    assert obs.shape[0] == num_envs
    # mypy marks this branch as unreachable, but it is reached: envpool returns info as a
    # single dict of arrays instead of an array of dicts, i.e. not in the expected format
    if isinstance(info, dict):  # type: ignore[unreachable]
        for _, v in info.items():  # type: ignore[unreachable]
            if not isinstance(v, dict):
                assert v.shape[0] == num_envs
    else:
        for _info in info:
            for _, v in _info.items():
                if not isinstance(v, dict):
                    assert v.shape[0] == num_envs
```

From `Collector.reset_env`:

```python
def reset_env(
    self,
    gym_reset_kwargs: dict[str, Any] | None = None,
) -> None:
    """Reset the environments and the initial obs, info, and hidden state of the collector."""
    gym_reset_kwargs = gym_reset_kwargs or {}
    self._pre_collect_obs_RO, self._pre_collect_info_R = self.env.reset(**gym_reset_kwargs)
    # TODO: hack, wrap envpool envs such that they don't return a dict
    if isinstance(self._pre_collect_info_R, dict):  # type: ignore[unreachable]
        # This happens if the env is an envpool env: reset then returns a dict
        # with array entries instead of an array of dicts.
        # Convert it into an array of dicts (one per env), the format tianshou's VectorEnvs return.
        self._pre_collect_info_R = _dict_of_arr_to_arr_of_dicts(self._pre_collect_info_R)  # type: ignore[unreachable]

    self._pre_collect_hidden_state_RH = None
```
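
The TODO in `reset_env` suggests wrapping envpool envs so that they no longer return a dict. A minimal sketch of such a wrapper follows, under stated assumptions: the wrapped env follows the gymnasium-style API (`reset` returns `(obs, info)`, `step` returns `(obs, rew, terminated, truncated, info)`), the number of envs is passed in explicitly, and `ArrayInfoAdapter` and `FakeDictInfoEnv` are hypothetical names. It deliberately depends on neither tianshou nor envpool; the fake env only mimics envpool's dict-of-arrays info.

```python
from typing import Any

import numpy as np


def _select(info: dict[str, Any], i: int) -> dict[str, Any]:
    # Pick entry i of every array, recursing into nested dicts.
    return {k: _select(v, i) if isinstance(v, dict) else v[i] for k, v in info.items()}


class ArrayInfoAdapter:
    """Hypothetical wrapper: makes a dict-of-arrays env return infos as an array of dicts.

    Only the info is rewritten; observations, rewards and done flags pass through unchanged.
    """

    def __init__(self, env: Any, num_envs: int) -> None:
        self.env = env
        self.num_envs = num_envs

    def _convert(self, info: Any) -> Any:
        if not isinstance(info, dict):
            return info  # assumed to already hold one entry per env
        result = np.empty(self.num_envs, dtype=object)
        for i in range(self.num_envs):
            result[i] = _select(info, i)
        return result

    def reset(self, **kwargs: Any) -> tuple[Any, Any]:
        obs, info = self.env.reset(**kwargs)
        return obs, self._convert(info)

    def step(self, action: np.ndarray) -> tuple[Any, Any, Any, Any, Any]:
        obs, rew, terminated, truncated, info = self.env.step(action)
        return obs, rew, terminated, truncated, self._convert(info)


class FakeDictInfoEnv:
    """Hypothetical stand-in for an envpool gymnasium env, so the sketch runs without envpool."""

    def __init__(self, num_envs: int) -> None:
        self.num_envs = num_envs

    def reset(self, **kwargs: Any) -> tuple[np.ndarray, dict[str, Any]]:
        return np.zeros((self.num_envs, 2)), {"env_id": np.arange(self.num_envs)}

    def step(self, action: np.ndarray) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        n = self.num_envs
        info = {"env_id": np.arange(n), "elapsed_step": np.ones(n, dtype=int)}
        done = np.zeros(n, dtype=bool)
        return np.zeros((n, 2)), np.zeros(n), done, done.copy(), info


env = ArrayInfoAdapter(FakeDictInfoEnv(num_envs=4), num_envs=4)
obs, info = env.reset()
print(len(info), info[3]["env_id"])  # prints 4 3
```

An actual fix in tianshou would more plausibly implement the `BaseVectorEnv` interface for envpool envs, so that `Collector.reset_env` could drop its `isinstance(..., dict)` check; this sketch only isolates the info conversion.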
MischaPanch added the bug, good first issue, and refactoring labels on Apr 3, 2024.