From 31855432290023a4885f2ee8ec75a56ac03cd3a7 Mon Sep 17 00:00:00 2001 From: Avnish Narayan <38871737+avnishn@users.noreply.github.com> Date: Fri, 12 May 2023 13:11:14 -0700 Subject: [PATCH] [RLlib] RLlib contrib (#35141) Signed-off-by: Avnish --- .buildkite/pipeline.ml.yml | 27 + ci/pipeline/determine_tests_to_run.py | 8 + rllib_contrib/README.md | 30 + rllib_contrib/a3c/README.rst | 21 + rllib_contrib/a3c/examples/a3c_cartpole_v1.py | 29 + rllib_contrib/a3c/pyproject.toml | 18 + rllib_contrib/a3c/requirements.txt | 2 + .../a3c/src/rllib_a3c/a3c/__init__.py | 7 + rllib_contrib/a3c/src/rllib_a3c/a3c/a3c.py | 261 +++++++++ .../a3c/src/rllib_a3c/a3c/a3c_tf_policy.py | 183 ++++++ .../a3c/src/rllib_a3c/a3c/a3c_torch_policy.py | 152 +++++ rllib_contrib/a3c/tests/test_a3c.py | 100 ++++ rllib_contrib/maml/README.rst | 27 + .../maml/examples/cartpole_mass_maml.py | 52 ++ rllib_contrib/maml/pyproject.toml | 18 + rllib_contrib/maml/requirements.txt | 2 + rllib_contrib/maml/src/rllib_maml/__init__.py | 0 .../maml/src/rllib_maml/envs/__init__.py | 11 + .../maml/src/rllib_maml/envs/ant_rand_goal.py | 86 +++ .../maml/src/rllib_maml/envs/cartpole_mass.py | 31 ++ .../maml/src/rllib_maml/envs/pendulum_mass.py | 33 ++ .../maml/src/rllib_maml/maml/__init__.py | 12 + .../maml/src/rllib_maml/maml/maml.py | 388 +++++++++++++ .../src/rllib_maml/maml/maml_tf_policy.py | 520 ++++++++++++++++++ .../src/rllib_maml/maml/maml_torch_policy.py | 449 +++++++++++++++ rllib_contrib/maml/tests/test_maml.py | 61 ++ 26 files changed, 2528 insertions(+) create mode 100644 rllib_contrib/README.md create mode 100644 rllib_contrib/a3c/README.rst create mode 100644 rllib_contrib/a3c/examples/a3c_cartpole_v1.py create mode 100644 rllib_contrib/a3c/pyproject.toml create mode 100644 rllib_contrib/a3c/requirements.txt create mode 100644 rllib_contrib/a3c/src/rllib_a3c/a3c/__init__.py create mode 100644 rllib_contrib/a3c/src/rllib_a3c/a3c/a3c.py create mode 100644 rllib_contrib/a3c/src/rllib_a3c/a3c/a3c_tf_policy.py create mode 100644 rllib_contrib/a3c/src/rllib_a3c/a3c/a3c_torch_policy.py create mode 100644 rllib_contrib/a3c/tests/test_a3c.py create mode 100644 rllib_contrib/maml/README.rst create mode 100644 rllib_contrib/maml/examples/cartpole_mass_maml.py create mode 100644 rllib_contrib/maml/pyproject.toml create mode 100644 rllib_contrib/maml/requirements.txt create mode 100644 rllib_contrib/maml/src/rllib_maml/__init__.py create mode 100644 rllib_contrib/maml/src/rllib_maml/envs/__init__.py create mode 100644 rllib_contrib/maml/src/rllib_maml/envs/ant_rand_goal.py create mode 100644 rllib_contrib/maml/src/rllib_maml/envs/cartpole_mass.py create mode 100644 rllib_contrib/maml/src/rllib_maml/envs/pendulum_mass.py create mode 100644 rllib_contrib/maml/src/rllib_maml/maml/__init__.py create mode 100644 rllib_contrib/maml/src/rllib_maml/maml/maml.py create mode 100644 rllib_contrib/maml/src/rllib_maml/maml/maml_tf_policy.py create mode 100644 rllib_contrib/maml/src/rllib_maml/maml/maml_torch_policy.py create mode 100644 rllib_contrib/maml/tests/test_maml.py diff --git a/.buildkite/pipeline.ml.yml b/.buildkite/pipeline.ml.yml index 6c1007d3cd507..ad474cf46e671 100644 --- a/.buildkite/pipeline.ml.yml +++ b/.buildkite/pipeline.ml.yml @@ -528,3 +528,30 @@ - ./ci/env/env_info.sh - python ./ci/env/setup_credentials.py wandb comet_ml - bazel test --config=ci $(./ci/run/bazel_export_options) --build_tests_only --test_tag_filters=needs_credentials,-timeseries_libs,-gpu,-py37,-post_wheel_build doc/... + + +- label: ":exploding_death_star: RLlib Contrib: A3C Tests" + conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"] + commands: + - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT + - (cd rllib_contrib/a3c && pip install -r requirements.txt && pip install -e .) + - ./ci/env/env_info.sh + - pytest rllib_contrib/a3c/tests/test_a3c.py + +- label: ":exploding_death_star: RLlib Contrib: MAML Tests" + conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"] + commands: + - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT + + # Install mujoco necessary for the testing environments + - sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf -y + - wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz + - mkdir /root/.mujoco + - mv mujoco210-linux-x86_64.tar.gz /root/.mujoco/. + - (cd /root/.mujoco && tar -xf /root/.mujoco/mujoco210-linux-x86_64.tar.gz) + - echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin' >> /root/.bashrc + - source /root/.bashrc + + - (cd rllib_contrib/maml && pip install -r requirements.txt && pip install -e .) + - ./ci/env/env_info.sh + - pytest rllib_contrib/maml/tests/test_maml.py diff --git a/ci/pipeline/determine_tests_to_run.py b/ci/pipeline/determine_tests_to_run.py index bed9110be9384..7a3cd86d43200 100644 --- a/ci/pipeline/determine_tests_to_run.py +++ b/ci/pipeline/determine_tests_to_run.py @@ -88,6 +88,8 @@ def get_commit_range(): # Whether all RLlib tests should be run. # Set to 1 only when a source file in `ray/rllib` has been changed. RAY_CI_RLLIB_DIRECTLY_AFFECTED = 0 + # Whether to run all RLlib contrib tests + RAY_CI_RLLIB_CONTRIB_AFFECTED = 0 RAY_CI_SERVE_AFFECTED = 0 RAY_CI_CORE_CPP_AFFECTED = 0 RAY_CI_CPP_AFFECTED = 0 @@ -179,6 +181,9 @@ def get_commit_range(): RAY_CI_RLLIB_DIRECTLY_AFFECTED = 1 RAY_CI_LINUX_WHEELS_AFFECTED = 1 RAY_CI_MACOS_WHEELS_AFFECTED = 1 + elif re.match("rllib_contrib/", changed_file): + if not changed_file.endswith(".md"): + RAY_CI_RLLIB_CONTRIB_AFFECTED = 1 elif changed_file.startswith("python/ray/serve"): RAY_CI_DOC_AFFECTED = 1 RAY_CI_SERVE_AFFECTED = 1 @@ -307,6 +312,8 @@ def get_commit_range(): RAY_CI_TRAIN_AFFECTED = 1 RAY_CI_RLLIB_AFFECTED = 1 RAY_CI_RLLIB_DIRECTLY_AFFECTED = 1 + # the rllib contrib ci should only be run on pull requests + RAY_CI_RLLIB_CONTRIB_AFFECTED = 0 RAY_CI_SERVE_AFFECTED = 1 RAY_CI_CPP_AFFECTED = 1 RAY_CI_CORE_CPP_AFFECTED = 1 @@ -331,6 +338,7 @@ def get_commit_range(): "RAY_CI_TRAIN_AFFECTED={}".format(RAY_CI_TRAIN_AFFECTED), "RAY_CI_RLLIB_AFFECTED={}".format(RAY_CI_RLLIB_AFFECTED), "RAY_CI_RLLIB_DIRECTLY_AFFECTED={}".format(RAY_CI_RLLIB_DIRECTLY_AFFECTED), + "RAY_CI_RLLIB_CONTRIB_AFFECTED={}".format(RAY_CI_RLLIB_CONTRIB_AFFECTED), "RAY_CI_SERVE_AFFECTED={}".format(RAY_CI_SERVE_AFFECTED), "RAY_CI_DASHBOARD_AFFECTED={}".format(RAY_CI_DASHBOARD_AFFECTED), "RAY_CI_DOC_AFFECTED={}".format(RAY_CI_DOC_AFFECTED), diff --git a/rllib_contrib/README.md b/rllib_contrib/README.md new file mode 100644 index 0000000000000..1cc2e0e775eac --- /dev/null +++ b/rllib_contrib/README.md @@ -0,0 +1,30 @@ +# RLlib-Contrib + +RLlib-Contrib is a directory for more experimental community contributions to RLlib including contributed algorithms. **This directory has a more relaxed bar for contributions than Ray or RLlib.** If you are interested in contributing to RLlib-Contrib, please see the [contributing guide](CONTRIBUTING.md). + +## Getting Started and Installation +Navigate to the algorithm sub-directory you are interested in and see the README.md for installation instructions and example scripts to help you get started! + +## Maintenance + +**Any issues that are filed in `rllib_contrib` will be solved best-effort by the community and there is no expectation of maintenance by the RLlib team.** + +**The API surface between algorithms in `rllib_contrib` and current versions of Ray / RLlib is not guaranteed. This means that any APIs that are used in rllib_contrib could potentially become modified/removed in newer version of Ray/RLlib.** + +We will generally accept contributions to this directory that meet any of the following criteria: + +1. Updating dependencies. +2. Submitting community contributed algorithms that have been tested and are ready for use. +3. Enabling algorithms to be run in different environments (ex. adding support for a new type of gymnasium environment). +4. Updating algorithms for use with the newer RLlib APIs. +5. General bug fixes. + +We will not accept contributions that generally add a significant maintenance burden. In this case users should instead make their own repo with their contribution, using the same guidelines as this directory, and the RLlib team can help to market/promote it in the Ray docs. + +## Getting Involved + +| Platform | Purpose | Support Level | +| --- | --- | --- | +| [Discuss Forum](https://discuss.ray.io) | For discussions about development and questions about usage. | Community | +| [GitHub Issues](https://github.com/ray-project/rllib-contrib-maml/issues) | For reporting bugs and filing feature requests. | Community | +| [Slack](https://forms.gle/9TSdDYUgxYs8SA9e8) | For collaborating with other Ray users. | Community | diff --git a/rllib_contrib/a3c/README.rst b/rllib_contrib/a3c/README.rst new file mode 100644 index 0000000000000..df3665c1408e5 --- /dev/null +++ b/rllib_contrib/a3c/README.rst @@ -0,0 +1,21 @@ +A3C (Asynchronous Advantage Actor-Critic) +----------------------------------------- + +`A3C ` is the asynchronous version of A2C, where gradients are computed on the workers directly after trajectory rollouts, and only then shipped to a central learner to accumulate these gradients on the central model. After the central model update, parameters are broadcast back to all workers. Similar to A2C, A3C scales to 16-32+ worker processes depending on the environment. + + +Installation +------------ + +.. code-block:: bash + + conda create -n rllib-a3c python=3.10 + conda activate rllib-a3c + pip install -r requirements.txt + pip install -e '.[development]' + + +Usage +----- + +.. literalinclude:: examples/a3c_cartpole_v1.py \ No newline at end of file diff --git a/rllib_contrib/a3c/examples/a3c_cartpole_v1.py b/rllib_contrib/a3c/examples/a3c_cartpole_v1.py new file mode 100644 index 0000000000000..2f57ff71e1057 --- /dev/null +++ b/rllib_contrib/a3c/examples/a3c_cartpole_v1.py @@ -0,0 +1,29 @@ +from rllib_a3c.a3c import A3C, A3CConfig + +import ray +from ray import air, tune + +if __name__ == "__main__": + ray.init() + + config = ( + A3CConfig() + .rollouts(num_rollout_workers=1) + .framework("torch") + .environment("CartPole-v1") + .training( + gamma=0.95, + ) + ) + + num_iterations = 100 + + tuner = tune.Tuner( + A3C, + param_space=config.to_dict(), + run_config=air.RunConfig( + stop={"episode_reward_mean": 150, "timesteps_total": 200000}, + failure_config=air.FailureConfig(fail_fast="raise"), + ), + ) + results = tuner.fit() diff --git a/rllib_contrib/a3c/pyproject.toml b/rllib_contrib/a3c/pyproject.toml new file mode 100644 index 0000000000000..173999a039a85 --- /dev/null +++ b/rllib_contrib/a3c/pyproject.toml @@ -0,0 +1,18 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[project] +name = "rllib-a3c" +authors = [{name = "Anyscale Inc."}] +version = "0.1.0" +description = "" +readme = "README.md" +requires-python = ">=3.7, <3.11" +dependencies = ["gym[accept-rom-license]", "gymnasium[mujoco]==0.26.3", "higher", "ray[rllib]==2.3.1"] + +[project.optional-dependencies] +development = ["pytest>=7.2.2", "pre-commit==2.21.0", "tensorflow==2.11.0", "torch==1.12.0"] diff --git a/rllib_contrib/a3c/requirements.txt b/rllib_contrib/a3c/requirements.txt new file mode 100644 index 0000000000000..f1191ef524126 --- /dev/null +++ b/rllib_contrib/a3c/requirements.txt @@ -0,0 +1,2 @@ +tensorflow==2.11.0 +torch==1.12.0 diff --git a/rllib_contrib/a3c/src/rllib_a3c/a3c/__init__.py b/rllib_contrib/a3c/src/rllib_a3c/a3c/__init__.py new file mode 100644 index 0000000000000..3b050de0dca52 --- /dev/null +++ b/rllib_contrib/a3c/src/rllib_a3c/a3c/__init__.py @@ -0,0 +1,7 @@ +from rllib_a3c.a3c.a3c import A3C, A3CConfig + +from ray.tune.registry import register_trainable + +__all__ = ["A3CConfig", "A3C"] + +register_trainable("rllib-contrib-a3c", A3C) diff --git a/rllib_contrib/a3c/src/rllib_a3c/a3c/a3c.py b/rllib_contrib/a3c/src/rllib_a3c/a3c/a3c.py new file mode 100644 index 0000000000000..7f5a661cb94d2 --- /dev/null +++ b/rllib_contrib/a3c/src/rllib_a3c/a3c/a3c.py @@ -0,0 +1,261 @@ +import logging +from typing import Any, Dict, List, Optional, Type, Union + +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.evaluation.rollout_worker import RolloutWorker +from ray.rllib.policy.policy import Policy +from ray.rllib.utils.annotations import override +from ray.rllib.utils.metrics import ( + APPLY_GRADS_TIMER, + GRAD_WAIT_TIMER, + NUM_AGENT_STEPS_SAMPLED, + NUM_AGENT_STEPS_TRAINED, + NUM_ENV_STEPS_SAMPLED, + NUM_ENV_STEPS_TRAINED, + SYNCH_WORKER_WEIGHTS_TIMER, +) +from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder +from ray.rllib.utils.typing import ResultDict + +logger = logging.getLogger(__name__) + + +class A3CConfig(AlgorithmConfig): + """Defines a configuration class from which a A3C Algorithm can be built. + + Example: + >>> from ray import tune + >>> from ray.rllib.algorithms.a3c import A3CConfig + >>> config = A3CConfig() # doctest: +SKIP + >>> config = config.training(lr=0.01, grad_clip=30.0) # doctest: +SKIP + >>> config = config.resources(num_gpus=0) # doctest: +SKIP + >>> config = config.rollouts(num_rollout_workers=4) # doctest: +SKIP + >>> config = config.environment("CartPole-v1") # doctest: +SKIP + >>> print(config.to_dict()) # doctest: +SKIP + >>> # Build a Algorithm object from the config and run 1 training iteration. + >>> algo = config.build() # doctest: +SKIP + >>> algo.train() # doctest: +SKIP + + Example: + >>> from ray.rllib.algorithms.a3c import A3CConfig + >>> config = A3CConfig() + >>> # Print out some default values. + >>> print(config.sample_async) # doctest: +SKIP + >>> # Update the config object. + >>> config = config.training( # doctest: +SKIP + ... lr=tune.grid_search([0.001, 0.0001]), use_critic=False) + >>> # Set the config object's env. + >>> config = config.environment(env="CartPole-v1") # doctest: +SKIP + >>> # Use to_dict() to get the old-style python config dict + >>> # when running with tune. + >>> tune.Tuner( # doctest: +SKIP + ... "A3C", + ... stop={"episode_reward_mean": 200}, + ... param_space=config.to_dict(), + ... ).fit() + """ + + def __init__(self, algo_class=None): + """Initializes a A3CConfig instance.""" + super().__init__(algo_class=algo_class or A3C) + + # fmt: off + # __sphinx_doc_begin__ + # + # A3C specific settings. + self.use_critic = True + self.use_gae = True + self.lambda_ = 1.0 + self.grad_clip = 40.0 + self.lr_schedule = None + self.vf_loss_coeff = 0.5 + self.entropy_coeff = 0.01 + self.entropy_coeff_schedule = None + self.sample_async = True + + # Override some of AlgorithmConfig's default values with PPO-specific values. + self.num_rollout_workers = 2 + self.rollout_fragment_length = 10 + self.lr = 0.0001 + # Min time (in seconds) per reporting. + # This causes not every call to `training_iteration` to be reported, + # but to wait until n seconds have passed and then to summarize the + # thus far collected results. + self.min_time_s_per_iteration = 5 + self.exploration_config = { + # The Exploration class to use. In the simplest case, this is the name + # (str) of any class present in the `rllib.utils.exploration` package. + # You can also provide the python class directly or the full location + # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy. + # EpsilonGreedy"). + "type": "StochasticSampling", + # Add constructor kwargs here (if any). + } + # __sphinx_doc_end__ + # fmt: on + + @override(AlgorithmConfig) + def training( + self, + *, + lr_schedule: Optional[List[List[Union[int, float]]]] = NotProvided, + use_critic: Optional[bool] = NotProvided, + use_gae: Optional[bool] = NotProvided, + lambda_: Optional[float] = NotProvided, + grad_clip: Optional[float] = NotProvided, + vf_loss_coeff: Optional[float] = NotProvided, + entropy_coeff: Optional[float] = NotProvided, + entropy_coeff_schedule: Optional[List[List[Union[int, float]]]] = NotProvided, + sample_async: Optional[bool] = NotProvided, + **kwargs, + ) -> "A3CConfig": + """Sets the training related configuration. + + Args: + lr_schedule: Learning rate schedule. In the format of + [[timestep, lr-value], [timestep, lr-value], ...] + Intermediary timesteps will be assigned to interpolated learning rate + values. A schedule should normally start from timestep 0. + use_critic: Should use a critic as a baseline (otherwise don't use value + baseline; required for using GAE). + use_gae: If true, use the Generalized Advantage Estimator (GAE) + with a value function, see https://arxiv.org/pdf/1506.02438.pdf. + lambda_: GAE(gamma) parameter. + grad_clip: Max global norm for each gradient calculated by worker. + vf_loss_coeff: Value Function Loss coefficient. + entropy_coeff: Coefficient of the entropy regularizer. + entropy_coeff_schedule: Decay schedule for the entropy regularizer. + sample_async: Whether workers should sample async. Note that this + increases the effective rollout_fragment_length by up to 5x due + to async buffering of batches. + + Returns: + This updated AlgorithmConfig object. + """ + # Pass kwargs onto super's `training()` method. + super().training(**kwargs) + + if lr_schedule is not NotProvided: + self.lr_schedule = lr_schedule + if use_critic is not NotProvided: + self.lr_schedule = use_critic + if use_gae is not NotProvided: + self.use_gae = use_gae + if lambda_ is not NotProvided: + self.lambda_ = lambda_ + if grad_clip is not NotProvided: + self.grad_clip = grad_clip + if vf_loss_coeff is not NotProvided: + self.vf_loss_coeff = vf_loss_coeff + if entropy_coeff is not NotProvided: + self.entropy_coeff = entropy_coeff + if entropy_coeff_schedule is not NotProvided: + self.entropy_coeff_schedule = entropy_coeff_schedule + if sample_async is not NotProvided: + self.sample_async = sample_async + + return self + + @override(AlgorithmConfig) + def validate(self) -> None: + # Call super's validation method. + super().validate() + + if self.entropy_coeff < 0: + raise ValueError("`entropy_coeff` must be >= 0.0!") + if self.num_rollout_workers <= 0 and self.sample_async: + raise ValueError("`num_workers` for A3C must be >= 1!") + + +class A3C(Algorithm): + @classmethod + @override(Algorithm) + def get_default_config(cls) -> AlgorithmConfig: + return A3CConfig() + + @classmethod + @override(Algorithm) + def get_default_policy_class( + cls, config: AlgorithmConfig + ) -> Optional[Type[Policy]]: + if config["framework"] == "torch": + from ray.rllib.algorithms.a3c.a3c_torch_policy import A3CTorchPolicy + + return A3CTorchPolicy + elif config["framework"] == "tf": + from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTF1Policy + + return A3CTF1Policy + else: + from ray.rllib.algorithms.a3c.a3c_tf_policy import A3CTF2Policy + + return A3CTF2Policy + + def training_step(self) -> ResultDict: + # Shortcut. + local_worker = self.workers.local_worker() + + # Define the function executed in parallel by all RolloutWorkers to collect + # samples + compute and return gradients (and other information). + + def sample_and_compute_grads(worker: RolloutWorker) -> Dict[str, Any]: + """Call sample() and compute_gradients() remotely on workers.""" + samples = worker.sample() + grads, infos = worker.compute_gradients(samples) + return { + "grads": grads, + "infos": infos, + "agent_steps": samples.agent_steps(), + "env_steps": samples.env_steps(), + } + + # Perform rollouts and gradient calculations asynchronously. + with self._timers[GRAD_WAIT_TIMER]: + # Results are a mapping from ActorHandle (RolloutWorker) to their + # returned gradient calculation results. + self.workers.foreach_worker_async( + func=sample_and_compute_grads, + healthy_only=True, + ) + async_results = self.workers.fetch_ready_async_reqs() + + # Loop through all fetched worker-computed gradients (if any) + # and apply them - one by one - to the local worker's model. + # After each apply step (one step per worker that returned some gradients), + # update that particular worker's weights. + global_vars = None + learner_info_builder = LearnerInfoBuilder(num_devices=1) + to_sync_workers = set() + for worker_id, result in async_results: + # Apply gradients to local worker. + with self._timers[APPLY_GRADS_TIMER]: + local_worker.apply_gradients(result["grads"]) + self._timers[APPLY_GRADS_TIMER].push_units_processed(result["agent_steps"]) + + # Update all step counters. + self._counters[NUM_AGENT_STEPS_SAMPLED] += result["agent_steps"] + self._counters[NUM_ENV_STEPS_SAMPLED] += result["env_steps"] + self._counters[NUM_AGENT_STEPS_TRAINED] += result["agent_steps"] + self._counters[NUM_ENV_STEPS_TRAINED] += result["env_steps"] + + learner_info_builder.add_learn_on_batch_results_multi_agent(result["infos"]) + + # Create current global vars. + global_vars = { + "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED], + } + + # Add this worker to be synced. + to_sync_workers.add(worker_id) + + # Synch updated weights back to the particular worker + # (only those policies that are trainable). + with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: + self.workers.sync_weights( + policies=local_worker.get_policies_to_train(), + to_worker_indices=list(to_sync_workers), + global_vars=global_vars, + ) + + return learner_info_builder.finalize() diff --git a/rllib_contrib/a3c/src/rllib_a3c/a3c/a3c_tf_policy.py b/rllib_contrib/a3c/src/rllib_a3c/a3c/a3c_tf_policy.py new file mode 100644 index 0000000000000..bdc77f5790aeb --- /dev/null +++ b/rllib_contrib/a3c/src/rllib_a3c/a3c/a3c_tf_policy.py @@ -0,0 +1,183 @@ +"""Note: Keep in sync with changes to VTraceTFPolicy.""" +from typing import Dict, List, Optional, Type, Union + +from ray.rllib.evaluation.episode import Episode +from ray.rllib.evaluation.postprocessing import ( + Postprocessing, + compute_gae_for_sample_batch, +) +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution +from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 +from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_mixins import ( + EntropyCoeffSchedule, + LearningRateSchedule, + ValueNetworkMixin, + compute_gradients, +) +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.tf_utils import explained_variance +from ray.rllib.utils.typing import ( + AgentID, + LocalOptimizer, + ModelGradients, + TensorType, + TFPolicyV2Type, +) + +tf1, tf, tfv = try_import_tf() + + +# We need this builder function because we want to share the same +# custom logics between TF1 dynamic and TF2 eager policies. +def get_a3c_tf_policy(name: str, base: TFPolicyV2Type) -> TFPolicyV2Type: + """Construct a A3CTFPolicy inheriting either dynamic or eager base policies. + + Args: + base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2. + + Returns: + A TF Policy to be used with MAML. + """ + + class A3CTFPolicy( + ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule, base + ): + def __init__( + self, + observation_space, + action_space, + config, + existing_model=None, + existing_inputs=None, + ): + # First thing first, enable eager execution if necessary. + base.enable_eager_execution_if_necessary() + + # Initialize base class. + base.__init__( + self, + observation_space, + action_space, + config, + existing_inputs=existing_inputs, + existing_model=existing_model, + ) + + ValueNetworkMixin.__init__(self, self.config) + LearningRateSchedule.__init__( + self, self.config["lr"], self.config["lr_schedule"] + ) + EntropyCoeffSchedule.__init__( + self, config["entropy_coeff"], config["entropy_coeff_schedule"] + ) + + # Note: this is a bit ugly, but loss and optimizer initialization must + # happen after all the MixIns are initialized. + self.maybe_initialize_optimizer_and_loss() + + @override(base) + def loss( + self, + model: Union[ModelV2, "tf.keras.Model"], + dist_class: Type[TFActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + model_out, _ = model(train_batch) + action_dist = dist_class(model_out, model) + if self.is_recurrent(): + max_seq_len = tf.reduce_max(train_batch[SampleBatch.SEQ_LENS]) + valid_mask = tf.sequence_mask( + train_batch[SampleBatch.SEQ_LENS], max_seq_len + ) + valid_mask = tf.reshape(valid_mask, [-1]) + else: + valid_mask = tf.ones_like(train_batch[SampleBatch.REWARDS]) + + log_prob = action_dist.logp(train_batch[SampleBatch.ACTIONS]) + vf = model.value_function() + + # The "policy gradients" loss + self.pi_loss = -tf.reduce_sum( + tf.boolean_mask( + log_prob * train_batch[Postprocessing.ADVANTAGES], valid_mask + ) + ) + + delta = tf.boolean_mask( + vf - train_batch[Postprocessing.VALUE_TARGETS], valid_mask + ) + + # Compute a value function loss. + if self.config.get("use_critic", True): + self.vf_loss = 0.5 * tf.reduce_sum(tf.math.square(delta)) + # Ignore the value function. + else: + self.vf_loss = tf.constant(0.0) + + self.entropy_loss = tf.reduce_sum( + tf.boolean_mask(action_dist.entropy(), valid_mask) + ) + + self.total_loss = ( + self.pi_loss + + self.vf_loss * self.config["vf_loss_coeff"] + - self.entropy_loss * self.entropy_coeff + ) + + return self.total_loss + + @override(base) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + return { + "cur_lr": tf.cast(self.cur_lr, tf.float64), + "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64), + "policy_loss": self.pi_loss, + "policy_entropy": self.entropy_loss, + "var_gnorm": tf.linalg.global_norm( + list(self.model.trainable_variables()) + ), + "vf_loss": self.vf_loss, + } + + @override(base) + def grad_stats_fn( + self, train_batch: SampleBatch, grads: ModelGradients + ) -> Dict[str, TensorType]: + return { + "grad_gnorm": tf.linalg.global_norm(grads), + "vf_explained_var": explained_variance( + train_batch[Postprocessing.VALUE_TARGETS], + self.model.value_function(), + ), + } + + @override(base) + def postprocess_trajectory( + self, + sample_batch: SampleBatch, + other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, + episode: Optional[Episode] = None, + ): + sample_batch = super().postprocess_trajectory(sample_batch) + return compute_gae_for_sample_batch( + self, sample_batch, other_agent_batches, episode + ) + + @override(base) + def compute_gradients_fn( + self, optimizer: LocalOptimizer, loss: TensorType + ) -> ModelGradients: + return compute_gradients(self, optimizer, loss) + + A3CTFPolicy.__name__ = name + A3CTFPolicy.__qualname__ = name + + return A3CTFPolicy + + +A3CTF1Policy = get_a3c_tf_policy("A3CTF1Policy", DynamicTFPolicyV2) +A3CTF2Policy = get_a3c_tf_policy("A3CTF2Policy", EagerTFPolicyV2) diff --git a/rllib_contrib/a3c/src/rllib_a3c/a3c/a3c_torch_policy.py b/rllib_contrib/a3c/src/rllib_a3c/a3c/a3c_torch_policy.py new file mode 100644 index 0000000000000..e702254cd16c8 --- /dev/null +++ b/rllib_contrib/a3c/src/rllib_a3c/a3c/a3c_torch_policy.py @@ -0,0 +1,152 @@ +from typing import Dict, List, Optional, Type, Union + +from ray.rllib.evaluation.episode import Episode +from ray.rllib.evaluation.postprocessing import ( + Postprocessing, + compute_gae_for_sample_batch, +) +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_mixins import ( + EntropyCoeffSchedule, + LearningRateSchedule, + ValueNetworkMixin, +) +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.torch_utils import apply_grad_clipping, sequence_mask +from ray.rllib.utils.typing import AgentID, TensorType + +torch, nn = try_import_torch() + + +class A3CTorchPolicy( + ValueNetworkMixin, LearningRateSchedule, EntropyCoeffSchedule, TorchPolicyV2 +): + """PyTorch Policy class used with A3C.""" + + def __init__(self, observation_space, action_space, config): + TorchPolicyV2.__init__( + self, + observation_space, + action_space, + config, + max_seq_len=config["model"]["max_seq_len"], + ) + ValueNetworkMixin.__init__(self, config) + LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) + EntropyCoeffSchedule.__init__( + self, config["entropy_coeff"], config["entropy_coeff_schedule"] + ) + + # TODO: Don't require users to call this manually. + self._initialize_loss_from_dummy_batch() + + @override(TorchPolicyV2) + def loss( + self, + model: ModelV2, + dist_class: Type[TorchDistributionWrapper], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + """Constructs the loss function. + + Args: + model: The Model to calculate the loss for. + dist_class: The action distr. class. + train_batch: The training data. + + Returns: + The A3C loss tensor given the input batch. + """ + logits, _ = model(train_batch) + values = model.value_function() + + if self.is_recurrent(): + B = len(train_batch[SampleBatch.SEQ_LENS]) + max_seq_len = logits.shape[0] // B + mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) + valid_mask = torch.reshape(mask_orig, [-1]) + else: + valid_mask = torch.ones_like(values, dtype=torch.bool) + + dist = dist_class(logits, model) + log_probs = dist.logp(train_batch[SampleBatch.ACTIONS]).reshape(-1) + pi_err = -torch.sum( + torch.masked_select( + log_probs * train_batch[Postprocessing.ADVANTAGES], valid_mask + ) + ) + + # Compute a value function loss. + if self.config["use_critic"]: + value_err = 0.5 * torch.sum( + torch.pow( + torch.masked_select( + values.reshape(-1) - train_batch[Postprocessing.VALUE_TARGETS], + valid_mask, + ), + 2.0, + ) + ) + # Ignore the value function. + else: + value_err = 0.0 + + entropy = torch.sum(torch.masked_select(dist.entropy(), valid_mask)) + + total_loss = ( + pi_err + + value_err * self.config["vf_loss_coeff"] + - entropy * self.entropy_coeff + ) + + # Store values for stats function in model (tower), such that for + # multi-GPU, we do not override them during the parallel loss phase. + model.tower_stats["entropy"] = entropy + model.tower_stats["pi_err"] = pi_err + model.tower_stats["value_err"] = value_err + + return total_loss + + @override(TorchPolicyV2) + def optimizer( + self, + ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]: + """Returns a torch optimizer (Adam) for A3C.""" + return torch.optim.Adam(self.model.parameters(), lr=self.config["lr"]) + + @override(TorchPolicyV2) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + return convert_to_numpy( + { + "cur_lr": self.cur_lr, + "entropy_coeff": self.entropy_coeff, + "policy_entropy": torch.mean( + torch.stack(self.get_tower_stats("entropy")) + ), + "policy_loss": torch.mean(torch.stack(self.get_tower_stats("pi_err"))), + "vf_loss": torch.mean(torch.stack(self.get_tower_stats("value_err"))), + } + ) + + @override(TorchPolicyV2) + def postprocess_trajectory( + self, + sample_batch: SampleBatch, + other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, + episode: Optional[Episode] = None, + ): + sample_batch = super().postprocess_trajectory(sample_batch) + return compute_gae_for_sample_batch( + self, sample_batch, other_agent_batches, episode + ) + + @override(TorchPolicyV2) + def extra_grad_process( + self, optimizer: "torch.optim.Optimizer", loss: TensorType + ) -> Dict[str, TensorType]: + return apply_grad_clipping(self, optimizer, loss) diff --git a/rllib_contrib/a3c/tests/test_a3c.py b/rllib_contrib/a3c/tests/test_a3c.py new file mode 100644 index 0000000000000..66984eb1e4ae4 --- /dev/null +++ b/rllib_contrib/a3c/tests/test_a3c.py @@ -0,0 +1,100 @@ +import unittest + +from rllib_a3c.a3c import A3CConfig + +import ray +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO, LEARNER_STATS_KEY +from ray.rllib.utils.test_utils import ( + check_compute_single_action, + check_train_results, + framework_iterator, +) + + +class TestA3C(unittest.TestCase): + """Sanity tests for A2C exec impl.""" + + def setUp(self): + ray.init(num_cpus=4) + + def tearDown(self): + ray.shutdown() + + def test_a3c_compilation(self): + """Test whether an A3C can be built with both frameworks.""" + config = A3CConfig().rollouts(num_rollout_workers=2, num_envs_per_worker=2) + + num_iterations = 2 + + # Test against all frameworks. + for _ in framework_iterator(config, with_eager_tracing=False): + for env in ["CartPole-v1", "Pendulum-v1"]: + print("env={}".format(env)) + config.model["use_lstm"] = env == "CartPole-v1" + algo = config.build(env=env) + for i in range(num_iterations): + results = algo.train() + check_train_results(results) + print(results) + check_compute_single_action( + algo, include_state=config.model["use_lstm"] + ) + algo.stop() + + def test_a3c_entropy_coeff_schedule(self): + """Test A3C entropy coeff schedule support.""" + config = A3CConfig().rollouts( + num_rollout_workers=1, + num_envs_per_worker=1, + batch_mode="truncate_episodes", + rollout_fragment_length=10, + ) + # Initial entropy coeff, doesn't really matter because of the schedule below. + config.training( + train_batch_size=20, + entropy_coeff=0.01, + entropy_coeff_schedule=[ + [0, 0.01], + [120, 0.0001], + ], + ) + # 0 metrics reporting delay, this makes sure timestep, + # which entropy coeff depends on, is updated after each worker rollout. + config.reporting( + min_time_s_per_iteration=0, min_sample_timesteps_per_iteration=20 + ) + + def _step_n_times(trainer, n: int): + """Step trainer n times. + + Returns: + learning rate at the end of the execution. + """ + for _ in range(n): + results = trainer.train() + return results["info"][LEARNER_INFO][DEFAULT_POLICY_ID][LEARNER_STATS_KEY][ + "entropy_coeff" + ] + + # Test against all frameworks. + for _ in framework_iterator(config): + algo = config.build(env="CartPole-v1") + + coeff = _step_n_times(algo, 1) # 20 timesteps + # Should be close to the starting coeff of 0.01 + self.assertGreaterEqual(coeff, 0.005) + + coeff = _step_n_times(algo, 10) # 200 timesteps + # Should have annealed to the final coeff of 0.0001. + self.assertLessEqual(coeff, 0.00011) + + algo.stop() + + +if __name__ == "__main__": + import sys + + import pytest + + sys.exit(pytest.main(["-v", __file__])) diff --git a/rllib_contrib/maml/README.rst b/rllib_contrib/maml/README.rst new file mode 100644 index 0000000000000..912fca39ed35c --- /dev/null +++ b/rllib_contrib/maml/README.rst @@ -0,0 +1,27 @@ +MAML (Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks) +------------------------------------------------------------------------ + +`MAML ` is an on-policy meta RL algorithm. Unlike standard RL algorithms, which aim to maximize the sum of rewards into the future for a single task (e.g. HalfCheetah), meta RL algorithms seek to maximize the sum of rewards for *a given distribution of tasks*. + +On a high level, MAML seeks to learn quick adaptation across different tasks (e.g. different velocities for HalfCheetah). Quick adaptation is defined by the number of gradient steps it takes to adapt. MAML aims to maximize the RL objective for each task after `X` gradient steps. Doing this requires partitioning the algorithm into two steps. The first step is data collection. This involves collecting data for each task for each step of adaptation (from `1, 2, ..., X`). The second step is the meta-update step. This second step takes all the aggregated ddata from the first step and computes the meta-gradient. + +Code here is adapted from `https://github.com/jonasrothfuss`, which outperforms vanilla MAML and avoids computation of the higher order gradients during the meta-update step. MAML is evaluated on custom environments that are described in greater detail here. + +MAML uses additional metrics to measure performance; episode_reward_mean measures the agent’s returns before adaptation, episode_reward_mean_adapt_N measures the agent’s returns after N gradient steps of inner adaptation, and adaptation_delta measures the difference in performance before and after adaptation. + + +Installation +------------ + +.. code-block:: bash + + conda create -n rllib-maml python=3.10 + conda activate rllib-maml + pip install -r requirements.txt + pip install -e '.[development]' + + +Usage +----- + +.. literalinclude:: examples/cartpole_mass_maml.py \ No newline at end of file diff --git a/rllib_contrib/maml/examples/cartpole_mass_maml.py b/rllib_contrib/maml/examples/cartpole_mass_maml.py new file mode 100644 index 0000000000000..72c27f83056c9 --- /dev/null +++ b/rllib_contrib/maml/examples/cartpole_mass_maml.py @@ -0,0 +1,52 @@ +from gymnasium.wrappers import TimeLimit +from rllib_maml.maml import MAML, MAMLConfig + +import ray +from ray import air, tune +from ray.rllib.examples.env.cartpole_mass import CartPoleMassEnv +from ray.tune.registry import register_env + +if __name__ == "__main__": + ray.init() + register_env( + "cartpole", + lambda env_cfg: TimeLimit(CartPoleMassEnv(), max_episode_steps=200), + ) + + rollout_fragment_length = 32 + + config = ( + MAMLConfig() + .rollouts( + num_rollout_workers=4, rollout_fragment_length=rollout_fragment_length + ) + .framework("torch") + .environment("cartpole", clip_actions=False) + .training( + inner_adaptation_steps=1, + maml_optimizer_steps=5, + gamma=0.99, + lambda_=1.0, + lr=0.001, + vf_loss_coeff=0.5, + inner_lr=0.03, + use_meta_env=False, + clip_param=0.3, + kl_target=0.01, + kl_coeff=0.001, + model=dict(fcnet_hiddens=[64, 64]), + train_batch_size=rollout_fragment_length, + ) + ) + + num_iterations = 100 + + tuner = tune.Tuner( + MAML, + param_space=config.to_dict(), + run_config=air.RunConfig( + stop={"training_iteration": num_iterations}, + failure_config=air.FailureConfig(fail_fast="raise"), + ), + ) + results = tuner.fit() diff --git a/rllib_contrib/maml/pyproject.toml b/rllib_contrib/maml/pyproject.toml new file mode 100644 index 0000000000000..bf6df70018fe6 --- /dev/null +++ b/rllib_contrib/maml/pyproject.toml @@ -0,0 +1,18 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.packages.find] +where = ["src"] + +[project] +name = "rllib-maml" +authors = [{name = "Anyscale Inc."}] +version = "0.1.0" +description = "" +readme = "README.md" +requires-python = ">=3.7, <3.11" +dependencies = ["gymnasium[mujoco]==0.26.3", "higher", "ray[rllib]==2.3.1"] + +[project.optional-dependencies] +development = ["pytest>=7.2.2", "pre-commit==2.21.0", "tensorflow==2.11.0", "torch==1.12.0"] diff --git a/rllib_contrib/maml/requirements.txt b/rllib_contrib/maml/requirements.txt new file mode 100644 index 0000000000000..f1191ef524126 --- /dev/null +++ b/rllib_contrib/maml/requirements.txt @@ -0,0 +1,2 @@ +tensorflow==2.11.0 +torch==1.12.0 diff --git a/rllib_contrib/maml/src/rllib_maml/__init__.py b/rllib_contrib/maml/src/rllib_maml/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/rllib_contrib/maml/src/rllib_maml/envs/__init__.py b/rllib_contrib/maml/src/rllib_maml/envs/__init__.py new file mode 100644 index 0000000000000..1796db67d13e0 --- /dev/null +++ b/rllib_contrib/maml/src/rllib_maml/envs/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2023-onwards Anyscale, Inc. The use of this library is subject to the +# included LICENSE file. +from rllib_maml.envs.ant_rand_goal import AntRandGoalEnv +from rllib_maml.envs.cartpole_mass import CartPoleMassEnv +from rllib_maml.envs.pendulum_mass import PendulumMassEnv + +__all__ = [ + "AntRandGoalEnv", + "CartPoleMassEnv", + "PendulumMassEnv", +] diff --git a/rllib_contrib/maml/src/rllib_maml/envs/ant_rand_goal.py b/rllib_contrib/maml/src/rllib_maml/envs/ant_rand_goal.py new file mode 100644 index 0000000000000..5dd2f3c8e0265 --- /dev/null +++ b/rllib_contrib/maml/src/rllib_maml/envs/ant_rand_goal.py @@ -0,0 +1,86 @@ +import numpy as np +from gymnasium.envs.mujoco.mujoco_env import MujocoEnv +from gymnasium.utils import EzPickle + +from ray.rllib.env.apis.task_settable_env import TaskSettableEnv + + +class AntRandGoalEnv(EzPickle, MujocoEnv, TaskSettableEnv): + """Ant Environment that randomizes goals as tasks + + Goals are randomly sampled 2D positions + """ + + def __init__(self): + self.set_task(self.sample_tasks(1)[0]) + MujocoEnv.__init__(self, "ant.xml", 5) + EzPickle.__init__(self) + + def sample_tasks(self, n_tasks): + # Samples a goal position (2x1 position ector) + a = np.random.random(n_tasks) * 2 * np.pi + r = 3 * np.random.random(n_tasks) ** 0.5 + return np.stack((r * np.cos(a), r * np.sin(a)), axis=-1) + + def set_task(self, task): + """ + Args: + task: task of the meta-learning environment + """ + self.goal_pos = task + + def get_task(self): + """ + Returns: + task: task of the meta-learning environment + """ + return self.goal_pos + + def step(self, a): + self.do_simulation(a, self.frame_skip) + xposafter = self.get_body_com("torso") + goal_reward = -np.sum( + np.abs(xposafter[:2] - self.goal_pos) + ) # make it happy, not suicidal + ctrl_cost = 0.1 * np.square(a).sum() + contact_cost = ( + 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) + ) + # survive_reward = 1.0 + survive_reward = 0.0 + reward = goal_reward - ctrl_cost - contact_cost + survive_reward + # notdone = np.isfinite(state).all() and 1.0 >= state[2] >= 0. + # done = not notdone + done = False + ob = self._get_obs() + return ( + ob, + reward, + done, + dict( + reward_forward=goal_reward, + reward_ctrl=-ctrl_cost, + reward_contact=-contact_cost, + reward_survive=survive_reward, + ), + ) + + def _get_obs(self): + return np.concatenate( + [ + self.sim.data.qpos.flat, + self.sim.data.qvel.flat, + np.clip(self.sim.data.cfrc_ext, -1, 1).flat, + ] + ) + + def reset_model(self): + qpos = self.init_qpos + self.np_random.uniform( + size=self.model.nq, low=-0.1, high=0.1 + ) + qvel = self.init_qvel + self.np_random.randn(self.model.nv) * 0.1 + self.set_state(qpos, qvel) + return self._get_obs() + + def viewer_setup(self): + self.viewer.cam.distance = self.model.stat.extent * 0.5 diff --git a/rllib_contrib/maml/src/rllib_maml/envs/cartpole_mass.py b/rllib_contrib/maml/src/rllib_maml/envs/cartpole_mass.py new file mode 100644 index 0000000000000..bfd481402eb7c --- /dev/null +++ b/rllib_contrib/maml/src/rllib_maml/envs/cartpole_mass.py @@ -0,0 +1,31 @@ +import numpy as np +from gymnasium.envs.classic_control.cartpole import CartPoleEnv +from gymnasium.utils import EzPickle + +from ray.rllib.env.apis.task_settable_env import TaskSettableEnv + + +class CartPoleMassEnv(CartPoleEnv, EzPickle, TaskSettableEnv): + """CartPoleMassEnv varies the weights of the cart and the pole.""" + + def sample_tasks(self, n_tasks): + # Sample new cart- and pole masses (random floats between 0.5 and 2.0 + # (cart) and between 0.05 and 0.2 (pole)). + cart_masses = np.random.uniform(low=0.5, high=2.0, size=(n_tasks, 1)) + pole_masses = np.random.uniform(low=0.05, high=0.2, size=(n_tasks, 1)) + return np.concatenate([cart_masses, pole_masses], axis=-1) + + def set_task(self, task): + """ + Args: + task (Tuple[float]): Masses of the cart and the pole. + """ + self.masscart = task[0] + self.masspole = task[1] + + def get_task(self): + """ + Returns: + Tuple[float]: The current mass of the cart- and pole. + """ + return np.array([self.masscart, self.masspole]) diff --git a/rllib_contrib/maml/src/rllib_maml/envs/pendulum_mass.py b/rllib_contrib/maml/src/rllib_maml/envs/pendulum_mass.py new file mode 100644 index 0000000000000..2b4abdf20107e --- /dev/null +++ b/rllib_contrib/maml/src/rllib_maml/envs/pendulum_mass.py @@ -0,0 +1,33 @@ +import numpy as np +from gymnasium.envs.classic_control.pendulum import PendulumEnv +from gymnasium.utils import EzPickle + +from ray.rllib.env.apis.task_settable_env import TaskSettableEnv + + +class PendulumMassEnv(PendulumEnv, EzPickle, TaskSettableEnv): + """PendulumMassEnv varies the weight of the pendulum + + Tasks are defined to be weight uniformly sampled between [0.5,2] + """ + + def sample_tasks(self, n_tasks): + # Sample new pendulum masses (random floats between 0.5 and 2). + return np.random.uniform(low=0.5, high=2.0, size=(n_tasks,)) + + def set_task(self, task): + """ + Args: + task: Task of the meta-learning environment (here: mass of + the pendulum). + """ + # self.m is the mass property of the pendulum. + self.m = task + + def get_task(self): + """ + Returns: + float: The current mass of the pendulum (self.m in the PendulumEnv + object). + """ + return self.m diff --git a/rllib_contrib/maml/src/rllib_maml/maml/__init__.py b/rllib_contrib/maml/src/rllib_maml/maml/__init__.py new file mode 100644 index 0000000000000..1ec07956fabd3 --- /dev/null +++ b/rllib_contrib/maml/src/rllib_maml/maml/__init__.py @@ -0,0 +1,12 @@ +# Copyright 2023-onwards Anyscale, Inc. The use of this library is subject to the +# included LICENSE file. +from rllib_maml.maml.maml import MAML, MAMLConfig + +from ray.tune.registry import register_trainable + +__all__ = [ + "MAML", + "MAMLConfig", +] + +register_trainable("rllib-contrib-maml", MAML) diff --git a/rllib_contrib/maml/src/rllib_maml/maml/maml.py b/rllib_contrib/maml/src/rllib_maml/maml/maml.py new file mode 100644 index 0000000000000..e03a7ff3f6caf --- /dev/null +++ b/rllib_contrib/maml/src/rllib_maml/maml/maml.py @@ -0,0 +1,388 @@ +import logging +from typing import Optional, Type + +import numpy as np + +from ray.rllib.algorithms.algorithm import Algorithm +from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided +from ray.rllib.evaluation.metrics import collect_metrics, get_learner_stats +from ray.rllib.evaluation.worker_set import WorkerSet +from ray.rllib.execution.common import ( + STEPS_SAMPLED_COUNTER, + STEPS_TRAINED_COUNTER, + STEPS_TRAINED_THIS_ITER_COUNTER, + _get_shared_metrics, +) +from ray.rllib.execution.metric_ops import CollectMetrics +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import ( + concat_samples, + convert_ma_batch_to_sample_batch, +) +from ray.rllib.utils.annotations import override +from ray.rllib.utils.deprecation import DEPRECATED_VALUE +from ray.rllib.utils.metrics.learner_info import LEARNER_INFO +from ray.rllib.utils.sgd import standardized +from ray.util.iter import LocalIterator, from_actors + +logger = logging.getLogger(__name__) + + +class MAMLConfig(AlgorithmConfig): + """Defines a configuration class from which a MAML Algorithm can be built. + + Example: + >>> from ray.rllib.algorithms.maml import MAMLConfig + >>> config = MAMLConfig().training(use_gae=False).resources(num_gpus=1) + >>> print(config.to_dict()) # doctest: +SKIP + >>> # Build a Algorithm object from the config and run 1 training iteration. + >>> algo = config.build(env="CartPole-v1") # doctest: +SKIP + >>> algo.train() # doctest: +SKIP + + Example: + >>> from ray.rllib.algorithms.maml import MAMLConfig + >>> from ray import air + >>> from ray import tune + >>> config = MAMLConfig() + >>> # Print out some default values. + >>> print(config.lr) # doctest: +SKIP + >>> # Update the config object. + >>> config = config.training( # doctest: +SKIP + ... grad_clip=tune.grid_search([10.0, 40.0])) + >>> # Set the config object's env. + >>> config = config.environment(env="CartPole-v1") + >>> # Use to_dict() to get the old-style python config dict + >>> # when running with tune. + >>> tune.Tuner( # doctest: +SKIP + ... "MAML", + ... run_config=air.RunConfig(stop={"episode_reward_mean": 200}), + ... param_space=config.to_dict(), + ... ).fit() + """ + + def __init__(self, algo_class=None): + """Initializes a PGConfig instance.""" + super().__init__(algo_class=algo_class or MAML) + + # fmt: off + # __sphinx_doc_begin__ + # MAML-specific config settings. + self.use_gae = True + self.lambda_ = 1.0 + self.kl_coeff = 0.0005 + self.vf_loss_coeff = 0.5 + self.entropy_coeff = 0.0 + self.clip_param = 0.3 + self.vf_clip_param = 10.0 + self.grad_clip = None + self.kl_target = 0.01 + self.inner_adaptation_steps = 1 + self.maml_optimizer_steps = 5 + self.inner_lr = 0.1 + self.use_meta_env = True + + # Override some of AlgorithmConfig's default values with MAML-specific values. + self.num_rollout_workers = 2 + self.rollout_fragment_length = 200 + self.create_env_on_local_worker = True + self.lr = 1e-3 + + # Share layers for value function. + self.model.update({ + "vf_share_layers": False, + }) + + self.batch_mode = "complete_episodes" + self._disable_execution_plan_api = False + self.exploration_config = { + # The Exploration class to use. In the simplest case, this is the name + # (str) of any class present in the `rllib.utils.exploration` package. + # You can also provide the python class directly or the full location + # of your class (e.g. "ray.rllib.utils.exploration.epsilon_greedy. + # EpsilonGreedy"). + "type": "StochasticSampling", + # Add constructor kwargs here (if any). + } + # __sphinx_doc_end__ + # fmt: on + + # Deprecated keys: + self.vf_share_layers = DEPRECATED_VALUE + + def training( + self, + *, + use_gae: Optional[bool] = NotProvided, + lambda_: Optional[float] = NotProvided, + kl_coeff: Optional[float] = NotProvided, + vf_loss_coeff: Optional[float] = NotProvided, + entropy_coeff: Optional[float] = NotProvided, + clip_param: Optional[float] = NotProvided, + vf_clip_param: Optional[float] = NotProvided, + grad_clip: Optional[float] = NotProvided, + kl_target: Optional[float] = NotProvided, + inner_adaptation_steps: Optional[int] = NotProvided, + maml_optimizer_steps: Optional[int] = NotProvided, + inner_lr: Optional[float] = NotProvided, + use_meta_env: Optional[bool] = NotProvided, + **kwargs, + ) -> "MAMLConfig": + """Sets the training related configuration. + + Args: + use_gae: If true, use the Generalized Advantage Estimator (GAE) + with a value function, see https://arxiv.org/pdf/1506.02438.pdf. + lambda_: The GAE (lambda) parameter. + kl_coeff: Initial coefficient for KL divergence. + vf_loss_coeff: Coefficient of the value function loss. + entropy_coeff: Coefficient of the entropy regularizer. + clip_param: PPO clip parameter. + vf_clip_param: Clip param for the value function. Note that this is + sensitive to the scale of the rewards. If your expected V is large, + increase this. + grad_clip: If specified, clip the global norm of gradients by this amount. + kl_target: Target value for KL divergence. + inner_adaptation_steps: Number of Inner adaptation steps for the MAML + algorithm. + maml_optimizer_steps: Number of MAML steps per meta-update iteration + (PPO steps). + inner_lr: Inner Adaptation Step size. + use_meta_env: Use Meta Env Template. + + Returns: + This updated AlgorithmConfig object. + """ + # Pass kwargs onto super's `training()` method. + super().training(**kwargs) + + if use_gae is not NotProvided: + self.use_gae = use_gae + if lambda_ is not NotProvided: + self.lambda_ = lambda_ + if kl_coeff is not NotProvided: + self.kl_coeff = kl_coeff + if vf_loss_coeff is not NotProvided: + self.vf_loss_coeff = vf_loss_coeff + if entropy_coeff is not NotProvided: + self.entropy_coeff = entropy_coeff + if clip_param is not NotProvided: + self.clip_param = clip_param + if vf_clip_param is not NotProvided: + self.vf_clip_param = vf_clip_param + if grad_clip is not NotProvided: + self.grad_clip = grad_clip + if kl_target is not NotProvided: + self.kl_target = kl_target + if inner_adaptation_steps is not NotProvided: + self.inner_adaptation_steps = inner_adaptation_steps + if maml_optimizer_steps is not NotProvided: + self.maml_optimizer_steps = maml_optimizer_steps + if inner_lr is not NotProvided: + self.inner_lr = inner_lr + if use_meta_env is not NotProvided: + self.use_meta_env = use_meta_env + + return self + + @override(AlgorithmConfig) + def validate(self) -> None: + # Call super's validation method. + super().validate() + + if self.num_gpus > 1: + raise ValueError("`num_gpus` > 1 not yet supported for MAML!") + if self.inner_adaptation_steps <= 0: + raise ValueError("Inner Adaptation Steps must be >=1!") + if self.maml_optimizer_steps <= 0: + raise ValueError("PPO steps for meta-update needs to be >=0!") + if self.entropy_coeff < 0: + raise ValueError("`entropy_coeff` must be >=0.0!") + if self.batch_mode != "complete_episodes": + raise ValueError("`batch_mode`=truncate_episodes not supported!") + if self.num_rollout_workers <= 0: + raise ValueError("Must have at least 1 worker/task!") + if self.create_env_on_local_worker is False: + raise ValueError( + "Must have an actual Env created on the driver " + "(local) worker! Try setting `config.environment(" + "create_env_on_local_worker=True)`." + ) + + +# @mluo: TODO +def set_worker_tasks(workers, use_meta_env): + if use_meta_env: + n_tasks = len(workers.remote_workers()) + tasks = workers.local_worker().foreach_env(lambda x: x)[0].sample_tasks(n_tasks) + for i, worker in enumerate(workers.remote_workers()): + worker.foreach_env.remote(lambda env: env.set_task(tasks[i])) + + +class MetaUpdate: + def __init__(self, workers, maml_steps, metric_gen, use_meta_env): + self.workers = workers + self.maml_optimizer_steps = maml_steps + self.metric_gen = metric_gen + self.use_meta_env = use_meta_env + + def __call__(self, data_tuple): + # Metaupdate Step + samples = data_tuple[0] + adapt_metrics_dict = data_tuple[1] + + # Metric Updating + metrics = _get_shared_metrics() + metrics.counters[STEPS_SAMPLED_COUNTER] += samples.count + fetches = None + for i in range(self.maml_optimizer_steps): + fetches = self.workers.local_worker().learn_on_batch(samples) + learner_stats = get_learner_stats(fetches) + + # Sync workers with meta policy + self.workers.sync_weights() + + # Set worker tasks + set_worker_tasks(self.workers, self.use_meta_env) + + # Update KLS + def update(pi, pi_id): + assert "inner_kl" not in learner_stats, ( + "inner_kl should be nested under policy id key", + learner_stats, + ) + if pi_id in learner_stats: + assert "inner_kl" in learner_stats[pi_id], (learner_stats, pi_id) + pi.update_kls(learner_stats[pi_id]["inner_kl"]) + else: + logger.warning("No data for {}, not updating kl".format(pi_id)) + + self.workers.local_worker().foreach_policy_to_train(update) + + # Modify Reporting Metrics + metrics = _get_shared_metrics() + metrics.info[LEARNER_INFO] = fetches + metrics.counters[STEPS_TRAINED_THIS_ITER_COUNTER] = samples.count + metrics.counters[STEPS_TRAINED_COUNTER] += samples.count + + res = self.metric_gen.__call__(None) + res.update(adapt_metrics_dict) + + return res + + +def post_process_metrics(adapt_iter, workers, metrics): + # Obtain Current Dataset Metrics and filter out + name = "_adapt_" + str(adapt_iter) if adapt_iter > 0 else "" + + # Only workers are collecting data + res = collect_metrics(workers=workers) + + metrics["episode_reward_max" + str(name)] = res["episode_reward_max"] + metrics["episode_reward_mean" + str(name)] = res["episode_reward_mean"] + metrics["episode_reward_min" + str(name)] = res["episode_reward_min"] + + return metrics + + +def inner_adaptation(workers, samples): + # Each worker performs one gradient descent + for i, e in enumerate(workers.remote_workers()): + e.learn_on_batch.remote(samples[i]) + + +class MAML(Algorithm): + @classmethod + @override(Algorithm) + def get_default_config(cls) -> AlgorithmConfig: + return MAMLConfig() + + @classmethod + @override(Algorithm) + def get_default_policy_class( + cls, config: AlgorithmConfig + ) -> Optional[Type[Policy]]: + if config["framework"] == "torch": + from ray.rllib.algorithms.maml.maml_torch_policy import MAMLTorchPolicy + + return MAMLTorchPolicy + elif config["framework"] == "tf": + from ray.rllib.algorithms.maml.maml_tf_policy import MAMLTF1Policy + + return MAMLTF1Policy + else: + from ray.rllib.algorithms.maml.maml_tf_policy import MAMLTF2Policy + + return MAMLTF2Policy + + @staticmethod + @override(Algorithm) + def execution_plan( + workers: WorkerSet, config: AlgorithmConfig, **kwargs + ) -> LocalIterator[dict]: + assert ( + len(kwargs) == 0 + ), "MAML execution_plan does NOT take any additional parameters" + + # Sync workers with meta policy + workers.sync_weights() + + # Samples and sets worker tasks + use_meta_env = config.use_meta_env + set_worker_tasks(workers, use_meta_env) + + # Metric Collector + metric_collect = CollectMetrics( + workers, + min_history=config.metrics_num_episodes_for_smoothing, + timeout_seconds=config.metrics_episode_collection_timeout_s, + ) + + # Iterator for Inner Adaptation Data gathering (from pre->post + # adaptation) + inner_steps = config.inner_adaptation_steps + + def inner_adaptation_steps(itr): + buf = [] + split = [] + metrics = {} + for samples in itr: + # Processing Samples (Standardize Advantages) + split_lst = [] + for sample in samples: + sample = convert_ma_batch_to_sample_batch(sample) + sample["advantages"] = standardized(sample["advantages"]) + split_lst.append(sample.count) + buf.append(sample) + + split.append(split_lst) + + adapt_iter = len(split) - 1 + metrics = post_process_metrics(adapt_iter, workers, metrics) + if len(split) > inner_steps: + out = concat_samples(buf) + out["split"] = np.array(split) + buf = [] + split = [] + + # Reporting Adaptation Rew Diff + ep_rew_pre = metrics["episode_reward_mean"] + ep_rew_post = metrics[ + "episode_reward_mean_adapt_" + str(inner_steps) + ] + metrics["adaptation_delta"] = ep_rew_post - ep_rew_pre + yield out, metrics + metrics = {} + else: + inner_adaptation(workers, samples) + + rollouts = from_actors(workers.remote_workers()) + rollouts = rollouts.batch_across_shards() + rollouts = rollouts.transform(inner_adaptation_steps) + + # Metaupdate Step + train_op = rollouts.for_each( + MetaUpdate( + workers, config.maml_optimizer_steps, metric_collect, use_meta_env + ) + ) + return train_op diff --git a/rllib_contrib/maml/src/rllib_maml/maml/maml_tf_policy.py b/rllib_contrib/maml/src/rllib_maml/maml/maml_tf_policy.py new file mode 100644 index 0000000000000..d81bf8d834ecc --- /dev/null +++ b/rllib_contrib/maml/src/rllib_maml/maml/maml_tf_policy.py @@ -0,0 +1,520 @@ +import logging +from typing import Dict, List, Type, Union + +from ray.rllib.algorithms.ppo.ppo_tf_policy import validate_config +from ray.rllib.evaluation.postprocessing import ( + Postprocessing, + compute_gae_for_sample_batch, +) +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.tf.tf_action_dist import TFActionDistribution +from ray.rllib.models.utils import get_activation_fn +from ray.rllib.policy.dynamic_tf_policy_v2 import DynamicTFPolicyV2 +from ray.rllib.policy.eager_tf_policy_v2 import EagerTFPolicyV2 +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.tf_mixins import ( + LocalOptimizer, + ModelGradients, + ValueNetworkMixin, + compute_gradients, +) +from ray.rllib.utils import try_import_tf +from ray.rllib.utils.annotations import override +from ray.rllib.utils.typing import TensorType + +tf1, tf, tfv = try_import_tf() + +logger = logging.getLogger(__name__) + + +def PPOLoss( + dist_class, + actions, + curr_logits, + behaviour_logits, + advantages, + value_fn, + value_targets, + vf_preds, + cur_kl_coeff, + entropy_coeff, + clip_param, + vf_clip_param, + vf_loss_coeff, + clip_loss=False, +): + def surrogate_loss( + actions, curr_dist, prev_dist, advantages, clip_param, clip_loss + ): + pi_new_logp = curr_dist.logp(actions) + pi_old_logp = prev_dist.logp(actions) + + logp_ratio = tf.math.exp(pi_new_logp - pi_old_logp) + if clip_loss: + return tf.minimum( + advantages * logp_ratio, + advantages + * tf.clip_by_value(logp_ratio, 1 - clip_param, 1 + clip_param), + ) + return advantages * logp_ratio + + def kl_loss(curr_dist, prev_dist): + return prev_dist.kl(curr_dist) + + def entropy_loss(dist): + return dist.entropy() + + def vf_loss(value_fn, value_targets, vf_preds, vf_clip_param=0.1): + # GAE Value Function Loss + vf_loss1 = tf.math.square(value_fn - value_targets) + vf_clipped = vf_preds + tf.clip_by_value( + value_fn - vf_preds, -vf_clip_param, vf_clip_param + ) + vf_loss2 = tf.math.square(vf_clipped - value_targets) + vf_loss = tf.maximum(vf_loss1, vf_loss2) + return vf_loss + + pi_new_dist = dist_class(curr_logits, None) + pi_old_dist = dist_class(behaviour_logits, None) + + surr_loss = tf.reduce_mean( + surrogate_loss( + actions, pi_new_dist, pi_old_dist, advantages, clip_param, clip_loss + ) + ) + kl_loss = tf.reduce_mean(kl_loss(pi_new_dist, pi_old_dist)) + vf_loss = tf.reduce_mean(vf_loss(value_fn, value_targets, vf_preds, vf_clip_param)) + entropy_loss = tf.reduce_mean(entropy_loss(pi_new_dist)) + + total_loss = -surr_loss + cur_kl_coeff * kl_loss + total_loss += vf_loss_coeff * vf_loss - entropy_coeff * entropy_loss + return total_loss, surr_loss, kl_loss, vf_loss, entropy_loss + + +# This is the computation graph for workers (inner adaptation steps) +class WorkerLoss(object): + def __init__( + self, + dist_class, + actions, + curr_logits, + behaviour_logits, + advantages, + value_fn, + value_targets, + vf_preds, + cur_kl_coeff, + entropy_coeff, + clip_param, + vf_clip_param, + vf_loss_coeff, + clip_loss=False, + ): + self.loss, surr_loss, kl_loss, vf_loss, ent_loss = PPOLoss( + dist_class=dist_class, + actions=actions, + curr_logits=curr_logits, + behaviour_logits=behaviour_logits, + advantages=advantages, + value_fn=value_fn, + value_targets=value_targets, + vf_preds=vf_preds, + cur_kl_coeff=cur_kl_coeff, + entropy_coeff=entropy_coeff, + clip_param=clip_param, + vf_clip_param=vf_clip_param, + vf_loss_coeff=vf_loss_coeff, + clip_loss=clip_loss, + ) + self.loss = tf1.Print(self.loss, ["Worker Adapt Loss", self.loss]) + + +# This is the Meta-Update computation graph for main (meta-update step) +class MAMLLoss(object): + def __init__( + self, + model, + config, + dist_class, + value_targets, + advantages, + actions, + behaviour_logits, + vf_preds, + cur_kl_coeff, + policy_vars, + obs, + num_tasks, + split, + inner_adaptation_steps=1, + entropy_coeff=0, + clip_param=0.3, + vf_clip_param=0.1, + vf_loss_coeff=1.0, + use_gae=True, + ): + self.config = config + self.num_tasks = num_tasks + self.inner_adaptation_steps = inner_adaptation_steps + self.clip_param = clip_param + self.dist_class = dist_class + self.cur_kl_coeff = cur_kl_coeff + + # Split episode tensors into [inner_adaptation_steps+1, num_tasks, -1] + self.obs = self.split_placeholders(obs, split) + self.actions = self.split_placeholders(actions, split) + self.behaviour_logits = self.split_placeholders(behaviour_logits, split) + self.advantages = self.split_placeholders(advantages, split) + self.value_targets = self.split_placeholders(value_targets, split) + self.vf_preds = self.split_placeholders(vf_preds, split) + + # Construct name to tensor dictionary for easier indexing + self.policy_vars = {} + for var in policy_vars: + self.policy_vars[var.name] = var + + # Calculate pi_new for PPO + pi_new_logits, current_policy_vars, value_fns = [], [], [] + for i in range(self.num_tasks): + pi_new, value_fn = self.feed_forward( + self.obs[0][i], self.policy_vars, policy_config=config["model"] + ) + pi_new_logits.append(pi_new) + value_fns.append(value_fn) + current_policy_vars.append(self.policy_vars) + + inner_kls = [] + inner_ppo_loss = [] + + # Recompute weights for inner-adaptation (same weights as workers) + for step in range(self.inner_adaptation_steps): + kls = [] + for i in range(self.num_tasks): + # PPO Loss Function (only Surrogate) + ppo_loss, _, kl_loss, _, _ = PPOLoss( + dist_class=dist_class, + actions=self.actions[step][i], + curr_logits=pi_new_logits[i], + behaviour_logits=self.behaviour_logits[step][i], + advantages=self.advantages[step][i], + value_fn=value_fns[i], + value_targets=self.value_targets[step][i], + vf_preds=self.vf_preds[step][i], + cur_kl_coeff=0.0, + entropy_coeff=entropy_coeff, + clip_param=clip_param, + vf_clip_param=vf_clip_param, + vf_loss_coeff=vf_loss_coeff, + clip_loss=False, + ) + adapted_policy_vars = self.compute_updated_variables( + ppo_loss, current_policy_vars[i] + ) + pi_new_logits[i], value_fns[i] = self.feed_forward( + self.obs[step + 1][i], + adapted_policy_vars, + policy_config=config["model"], + ) + current_policy_vars[i] = adapted_policy_vars + kls.append(kl_loss) + inner_ppo_loss.append(ppo_loss) + + self.kls = kls + inner_kls.append(kls) + + mean_inner_kl = tf.stack( + [tf.reduce_mean(tf.stack(inner_kl)) for inner_kl in inner_kls] + ) + self.mean_inner_kl = mean_inner_kl + + ppo_obj = [] + for i in range(self.num_tasks): + ppo_loss, surr_loss, kl_loss, val_loss, entropy_loss = PPOLoss( + dist_class=dist_class, + actions=self.actions[self.inner_adaptation_steps][i], + curr_logits=pi_new_logits[i], + behaviour_logits=self.behaviour_logits[self.inner_adaptation_steps][i], + advantages=self.advantages[self.inner_adaptation_steps][i], + value_fn=value_fns[i], + value_targets=self.value_targets[self.inner_adaptation_steps][i], + vf_preds=self.vf_preds[self.inner_adaptation_steps][i], + cur_kl_coeff=0.0, + entropy_coeff=entropy_coeff, + clip_param=clip_param, + vf_clip_param=vf_clip_param, + vf_loss_coeff=vf_loss_coeff, + clip_loss=True, + ) + ppo_obj.append(ppo_loss) + self.mean_policy_loss = surr_loss + self.mean_kl = kl_loss + self.mean_vf_loss = val_loss + self.mean_entropy = entropy_loss + self.inner_kl_loss = tf.reduce_mean( + tf.multiply(self.cur_kl_coeff, mean_inner_kl) + ) + self.loss = tf.reduce_mean(tf.stack(ppo_obj, axis=0)) + self.inner_kl_loss + self.loss = tf1.Print( + self.loss, ["Meta-Loss", self.loss, "Inner KL", self.mean_inner_kl] + ) + + def feed_forward(self, obs, policy_vars, policy_config): + # Hacky for now, reconstruct FC network with adapted weights + # @mluo: TODO for any network + def fc_network( + inp, network_vars, hidden_nonlinearity, output_nonlinearity, policy_config + ): + bias_added = False + x = inp + for name, param in network_vars.items(): + if "kernel" in name: + x = tf.matmul(x, param) + elif "bias" in name: + x = tf.add(x, param) + bias_added = True + else: + raise NameError + + if bias_added: + if "out" not in name: + x = hidden_nonlinearity(x) + elif "out" in name: + x = output_nonlinearity(x) + else: + raise NameError + bias_added = False + return x + + policyn_vars = {} + valuen_vars = {} + log_std = None + for name, param in policy_vars.items(): + if "value" in name: + valuen_vars[name] = param + elif "log_std" in name: + log_std = param + else: + policyn_vars[name] = param + + output_nonlinearity = tf.identity + hidden_nonlinearity = get_activation_fn(policy_config["fcnet_activation"]) + + pi_new_logits = fc_network( + obs, policyn_vars, hidden_nonlinearity, output_nonlinearity, policy_config + ) + if log_std is not None: + pi_new_logits = tf.concat([pi_new_logits, 0.0 * pi_new_logits + log_std], 1) + value_fn = fc_network( + obs, valuen_vars, hidden_nonlinearity, output_nonlinearity, policy_config + ) + + return pi_new_logits, tf.reshape(value_fn, [-1]) + + def compute_updated_variables(self, loss, network_vars): + grad = tf.gradients(loss, list(network_vars.values())) + adapted_vars = {} + for i, tup in enumerate(network_vars.items()): + name, var = tup + if grad[i] is None: + adapted_vars[name] = var + else: + adapted_vars[name] = var - self.config["inner_lr"] * grad[i] + return adapted_vars + + def split_placeholders(self, placeholder, split): + inner_placeholder_list = tf.split( + placeholder, tf.math.reduce_sum(split, axis=1), axis=0 + ) + placeholder_list = [] + for index, split_placeholder in enumerate(inner_placeholder_list): + placeholder_list.append(tf.split(split_placeholder, split[index], axis=0)) + return placeholder_list + + +class KLCoeffMixin: + def __init__(self, config): + self.kl_coeff_val = [config["kl_coeff"]] * config["inner_adaptation_steps"] + self.kl_target = self.config["kl_target"] + self.kl_coeff = tf1.get_variable( + initializer=tf.keras.initializers.Constant(self.kl_coeff_val), + name="kl_coeff", + shape=(config["inner_adaptation_steps"]), + trainable=False, + dtype=tf.float32, + ) + + def update_kls(self, sampled_kls): + for i, kl in enumerate(sampled_kls): + if kl < self.kl_target / 1.5: + self.kl_coeff_val[i] *= 0.5 + elif kl > 1.5 * self.kl_target: + self.kl_coeff_val[i] *= 2.0 + print(self.kl_coeff_val) + self.kl_coeff.load(self.kl_coeff_val, session=self.get_session()) + return self.kl_coeff_val + + +# We need this builder function because we want to share the same +# custom logics between TF1 dynamic and TF2 eager policies. +def get_maml_tf_policy(name: str, base: type) -> type: + """Construct a MAMLTFPolicy inheriting either dynamic or eager base policies. + + Args: + base: Base class for this policy. DynamicTFPolicyV2 or EagerTFPolicyV2. + + Returns: + A TF Policy to be used with MAML. + """ + + class MAMLTFPolicy(KLCoeffMixin, ValueNetworkMixin, base): + def __init__( + self, + observation_space, + action_space, + config, + existing_model=None, + existing_inputs=None, + ): + # First thing first, enable eager execution if necessary. + base.enable_eager_execution_if_necessary() + + validate_config(config) + + # Initialize base class. + base.__init__( + self, + observation_space, + action_space, + config, + existing_inputs=existing_inputs, + existing_model=existing_model, + ) + + KLCoeffMixin.__init__(self, config) + ValueNetworkMixin.__init__(self, config) + + # Create the `split` placeholder before initialize loss. + if self.framework == "tf": + self._loss_input_dict["split"] = tf1.placeholder( + tf.int32, + name="Meta-Update-Splitting", + shape=( + self.config["inner_adaptation_steps"] + 1, + self.config["num_workers"], + ), + ) + + # Note: this is a bit ugly, but loss and optimizer initialization must + # happen after all the MixIns are initialized. + self.maybe_initialize_optimizer_and_loss() + + @override(base) + def loss( + self, + model: Union[ModelV2, "tf.keras.Model"], + dist_class: Type[TFActionDistribution], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + logits, state = model(train_batch) + self.cur_lr = self.config["lr"] + + if self.config["worker_index"]: + self.loss_obj = WorkerLoss( + dist_class=dist_class, + actions=train_batch[SampleBatch.ACTIONS], + curr_logits=logits, + behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], + advantages=train_batch[Postprocessing.ADVANTAGES], + value_fn=model.value_function(), + value_targets=train_batch[Postprocessing.VALUE_TARGETS], + vf_preds=train_batch[SampleBatch.VF_PREDS], + cur_kl_coeff=0.0, + entropy_coeff=self.config["entropy_coeff"], + clip_param=self.config["clip_param"], + vf_clip_param=self.config["vf_clip_param"], + vf_loss_coeff=self.config["vf_loss_coeff"], + clip_loss=False, + ) + else: + self.var_list = tf1.get_collection( + tf1.GraphKeys.TRAINABLE_VARIABLES, tf1.get_variable_scope().name + ) + self.loss_obj = MAMLLoss( + model=model, + dist_class=dist_class, + value_targets=train_batch[Postprocessing.VALUE_TARGETS], + advantages=train_batch[Postprocessing.ADVANTAGES], + actions=train_batch[SampleBatch.ACTIONS], + behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], + vf_preds=train_batch[SampleBatch.VF_PREDS], + cur_kl_coeff=self.kl_coeff, + policy_vars=self.var_list, + obs=train_batch[SampleBatch.CUR_OBS], + num_tasks=self.config["num_workers"], + split=train_batch["split"], + config=self.config, + inner_adaptation_steps=self.config["inner_adaptation_steps"], + entropy_coeff=self.config["entropy_coeff"], + clip_param=self.config["clip_param"], + vf_clip_param=self.config["vf_clip_param"], + vf_loss_coeff=self.config["vf_loss_coeff"], + use_gae=self.config["use_gae"], + ) + + return self.loss_obj.loss + + @override(base) + def optimizer( + self, + ) -> Union[ + "tf.keras.optimizers.Optimizer", List["tf.keras.optimizers.Optimizer"] + ]: + """ + Workers use simple SGD for inner adaptation + Meta-Policy uses Adam optimizer for meta-update + """ + if not self.config["worker_index"]: + return tf1.train.AdamOptimizer(learning_rate=self.config["lr"]) + return tf1.train.GradientDescentOptimizer( + learning_rate=self.config["inner_lr"] + ) + + @override(base) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + if self.config["worker_index"]: + return {"worker_loss": self.loss_obj.loss} + else: + return { + "cur_kl_coeff": tf.cast(self.kl_coeff, tf.float64), + "cur_lr": tf.cast(self.cur_lr, tf.float64), + "total_loss": self.loss_obj.loss, + "policy_loss": self.loss_obj.mean_policy_loss, + "vf_loss": self.loss_obj.mean_vf_loss, + "kl": self.loss_obj.mean_kl, + "inner_kl": self.loss_obj.mean_inner_kl, + "entropy": self.loss_obj.mean_entropy, + } + + @override(base) + def postprocess_trajectory( + self, sample_batch, other_agent_batches=None, episode=None + ): + sample_batch = super().postprocess_trajectory(sample_batch) + return compute_gae_for_sample_batch( + self, sample_batch, other_agent_batches, episode + ) + + @override(base) + def compute_gradients_fn( + self, optimizer: LocalOptimizer, loss: TensorType + ) -> ModelGradients: + return compute_gradients(self, optimizer, loss) + + MAMLTFPolicy.__name__ = name + MAMLTFPolicy.__qualname__ = name + + return MAMLTFPolicy + + +MAMLTF1Policy = get_maml_tf_policy("MAMLTF1Policy", DynamicTFPolicyV2) +MAMLTF2Policy = get_maml_tf_policy("MAMLTF2Policy", EagerTFPolicyV2) diff --git a/rllib_contrib/maml/src/rllib_maml/maml/maml_torch_policy.py b/rllib_contrib/maml/src/rllib_maml/maml/maml_torch_policy.py new file mode 100644 index 0000000000000..4a16f5eb950a1 --- /dev/null +++ b/rllib_contrib/maml/src/rllib_maml/maml/maml_torch_policy.py @@ -0,0 +1,449 @@ +import logging +from typing import Dict, List, Type, Union + +import ray +from ray.rllib.algorithms.ppo.ppo_tf_policy import validate_config +from ray.rllib.evaluation.postprocessing import ( + Postprocessing, + compute_gae_for_sample_batch, +) +from ray.rllib.models.modelv2 import ModelV2 +from ray.rllib.models.torch.torch_action_dist import TorchDistributionWrapper +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy.torch_mixins import ValueNetworkMixin +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 +from ray.rllib.utils.annotations import override +from ray.rllib.utils.framework import try_import_torch +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.torch_utils import apply_grad_clipping +from ray.rllib.utils.typing import TensorType + +torch, nn = try_import_torch() +logger = logging.getLogger(__name__) + +try: + import higher +except (ImportError, ModuleNotFoundError): + raise ImportError( + ( + "The MAML and MB-MPO algorithms require the `higher` module to be " + "installed! However, there was no installation found. You can install it " + "via `pip install higher`." + ) + ) + + +def PPOLoss( + dist_class, + actions, + curr_logits, + behaviour_logits, + advantages, + value_fn, + value_targets, + vf_preds, + cur_kl_coeff, + entropy_coeff, + clip_param, + vf_clip_param, + vf_loss_coeff, + clip_loss=False, +): + def surrogate_loss( + actions, curr_dist, prev_dist, advantages, clip_param, clip_loss + ): + pi_new_logp = curr_dist.logp(actions) + pi_old_logp = prev_dist.logp(actions) + + logp_ratio = torch.exp(pi_new_logp - pi_old_logp) + if clip_loss: + return torch.min( + advantages * logp_ratio, + advantages * torch.clamp(logp_ratio, 1 - clip_param, 1 + clip_param), + ) + return advantages * logp_ratio + + def kl_loss(curr_dist, prev_dist): + return prev_dist.kl(curr_dist) + + def entropy_loss(dist): + return dist.entropy() + + def vf_loss(value_fn, value_targets, vf_preds, vf_clip_param=0.1): + # GAE Value Function Loss + vf_loss1 = torch.pow(value_fn - value_targets, 2.0) + vf_clipped = vf_preds + torch.clamp( + value_fn - vf_preds, -vf_clip_param, vf_clip_param + ) + vf_loss2 = torch.pow(vf_clipped - value_targets, 2.0) + vf_loss = torch.max(vf_loss1, vf_loss2) + return vf_loss + + pi_new_dist = dist_class(curr_logits, None) + pi_old_dist = dist_class(behaviour_logits, None) + + surr_loss = torch.mean( + surrogate_loss( + actions, pi_new_dist, pi_old_dist, advantages, clip_param, clip_loss + ) + ) + kl_loss = torch.mean(kl_loss(pi_new_dist, pi_old_dist)) + vf_loss = torch.mean(vf_loss(value_fn, value_targets, vf_preds, vf_clip_param)) + entropy_loss = torch.mean(entropy_loss(pi_new_dist)) + + total_loss = -surr_loss + cur_kl_coeff * kl_loss + total_loss += vf_loss_coeff * vf_loss + total_loss -= entropy_coeff * entropy_loss + return total_loss, surr_loss, kl_loss, vf_loss, entropy_loss + + +# This is the computation graph for workers (inner adaptation steps) +class WorkerLoss(object): + def __init__( + self, + model, + dist_class, + actions, + curr_logits, + behaviour_logits, + advantages, + value_fn, + value_targets, + vf_preds, + cur_kl_coeff, + entropy_coeff, + clip_param, + vf_clip_param, + vf_loss_coeff, + clip_loss=False, + ): + self.loss, surr_loss, kl_loss, vf_loss, ent_loss = PPOLoss( + dist_class=dist_class, + actions=actions, + curr_logits=curr_logits, + behaviour_logits=behaviour_logits, + advantages=advantages, + value_fn=value_fn, + value_targets=value_targets, + vf_preds=vf_preds, + cur_kl_coeff=cur_kl_coeff, + entropy_coeff=entropy_coeff, + clip_param=clip_param, + vf_clip_param=vf_clip_param, + vf_loss_coeff=vf_loss_coeff, + clip_loss=clip_loss, + ) + + +# This is the Meta-Update computation graph for main (meta-update step) +class MAMLLoss(object): + def __init__( + self, + model, + config, + dist_class, + value_targets, + advantages, + actions, + behaviour_logits, + vf_preds, + cur_kl_coeff, + policy_vars, + obs, + num_tasks, + split, + meta_opt, + inner_adaptation_steps=1, + entropy_coeff=0, + clip_param=0.3, + vf_clip_param=0.1, + vf_loss_coeff=1.0, + use_gae=True, + ): + self.config = config + self.num_tasks = num_tasks + self.inner_adaptation_steps = inner_adaptation_steps + self.clip_param = clip_param + self.dist_class = dist_class + self.cur_kl_coeff = cur_kl_coeff + self.model = model + self.vf_clip_param = vf_clip_param + self.vf_loss_coeff = vf_loss_coeff + self.entropy_coeff = entropy_coeff + + # Split episode tensors into [inner_adaptation_steps+1, num_tasks, -1] + self.obs = self.split_placeholders(obs, split) + self.actions = self.split_placeholders(actions, split) + self.behaviour_logits = self.split_placeholders(behaviour_logits, split) + self.advantages = self.split_placeholders(advantages, split) + self.value_targets = self.split_placeholders(value_targets, split) + self.vf_preds = self.split_placeholders(vf_preds, split) + + inner_opt = torch.optim.SGD(model.parameters(), lr=config["inner_lr"]) + surr_losses = [] + val_losses = [] + kl_losses = [] + entropy_losses = [] + meta_losses = [] + kls = [] + + meta_opt.zero_grad() + for i in range(self.num_tasks): + with higher.innerloop_ctx(model, inner_opt, copy_initial_weights=False) as ( + fnet, + diffopt, + ): + inner_kls = [] + for step in range(self.inner_adaptation_steps): + ppo_loss, _, inner_kl_loss, _, _ = self.compute_losses( + fnet, step, i + ) + diffopt.step(ppo_loss) + inner_kls.append(inner_kl_loss) + kls.append(inner_kl_loss.detach()) + + # Meta Update + ppo_loss, s_loss, kl_loss, v_loss, ent = self.compute_losses( + fnet, self.inner_adaptation_steps - 1, i, clip_loss=True + ) + + inner_loss = torch.mean( + torch.stack( + [ + a * b + for a, b in zip( + self.cur_kl_coeff[ + i + * self.inner_adaptation_steps : (i + 1) + * self.inner_adaptation_steps + ], + inner_kls, + ) + ] + ) + ) + meta_loss = (ppo_loss + inner_loss) / self.num_tasks + meta_loss.backward() + + surr_losses.append(s_loss.detach()) + kl_losses.append(kl_loss.detach()) + val_losses.append(v_loss.detach()) + entropy_losses.append(ent.detach()) + meta_losses.append(meta_loss.detach()) + + meta_opt.step() + + # Stats Logging + self.mean_policy_loss = torch.mean(torch.stack(surr_losses)) + self.mean_kl_loss = torch.mean(torch.stack(kl_losses)) + self.mean_vf_loss = torch.mean(torch.stack(val_losses)) + self.mean_entropy = torch.mean(torch.stack(entropy_losses)) + self.mean_inner_kl = kls + self.loss = torch.sum(torch.stack(meta_losses)) + # Hacky, needed to bypass RLlib backend + self.loss.requires_grad = True + + def compute_losses(self, model, inner_adapt_iter, task_iter, clip_loss=False): + obs = self.obs[inner_adapt_iter][task_iter] + obs_dict = {"obs": obs, "obs_flat": obs} + curr_logits, _ = model.forward(obs_dict, None, None) + value_fns = model.value_function() + ppo_loss, surr_loss, kl_loss, val_loss, ent_loss = PPOLoss( + dist_class=self.dist_class, + actions=self.actions[inner_adapt_iter][task_iter], + curr_logits=curr_logits, + behaviour_logits=self.behaviour_logits[inner_adapt_iter][task_iter], + advantages=self.advantages[inner_adapt_iter][task_iter], + value_fn=value_fns, + value_targets=self.value_targets[inner_adapt_iter][task_iter], + vf_preds=self.vf_preds[inner_adapt_iter][task_iter], + cur_kl_coeff=0.0, + entropy_coeff=self.entropy_coeff, + clip_param=self.clip_param, + vf_clip_param=self.vf_clip_param, + vf_loss_coeff=self.vf_loss_coeff, + clip_loss=clip_loss, + ) + return ppo_loss, surr_loss, kl_loss, val_loss, ent_loss + + def split_placeholders(self, placeholder, split): + inner_placeholder_list = torch.split( + placeholder, torch.sum(split, dim=1).tolist(), dim=0 + ) + placeholder_list = [] + for index, split_placeholder in enumerate(inner_placeholder_list): + placeholder_list.append( + torch.split(split_placeholder, split[index].tolist(), dim=0) + ) + return placeholder_list + + +class KLCoeffMixin: + def __init__(self, config): + self.kl_coeff_val = ( + [config["kl_coeff"]] + * config["inner_adaptation_steps"] + * config["num_workers"] + ) + self.kl_target = self.config["kl_target"] + + def update_kls(self, sampled_kls): + for i, kl in enumerate(sampled_kls): + if kl < self.kl_target / 1.5: + self.kl_coeff_val[i] *= 0.5 + elif kl > 1.5 * self.kl_target: + self.kl_coeff_val[i] *= 2.0 + return self.kl_coeff_val + + +class MAMLTorchPolicy(ValueNetworkMixin, KLCoeffMixin, TorchPolicyV2): + """PyTorch policy class used with MAML.""" + + def __init__(self, observation_space, action_space, config): + config = dict(ray.rllib.algorithms.maml.maml.MAMLConfig(), **config) + validate_config(config) + + TorchPolicyV2.__init__( + self, + observation_space, + action_space, + config, + max_seq_len=config["model"]["max_seq_len"], + ) + + KLCoeffMixin.__init__(self, config) + ValueNetworkMixin.__init__(self, config) + + # TODO: Don't require users to call this manually. + self._initialize_loss_from_dummy_batch() + + @override(TorchPolicyV2) + def loss( + self, + model: ModelV2, + dist_class: Type[TorchDistributionWrapper], + train_batch: SampleBatch, + ) -> Union[TensorType, List[TensorType]]: + """Constructs the loss function. + + Args: + model: The Model to calculate the loss for. + dist_class: The action distr. class. + train_batch: The training data. + + Returns: + The PPO loss tensor given the input batch. + """ + logits, state = model(train_batch) + self.cur_lr = self.config["lr"] + + if self.config["worker_index"]: + self.loss_obj = WorkerLoss( + model=model, + dist_class=dist_class, + actions=train_batch[SampleBatch.ACTIONS], + curr_logits=logits, + behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], + advantages=train_batch[Postprocessing.ADVANTAGES], + value_fn=model.value_function(), + value_targets=train_batch[Postprocessing.VALUE_TARGETS], + vf_preds=train_batch[SampleBatch.VF_PREDS], + cur_kl_coeff=0.0, + entropy_coeff=self.config["entropy_coeff"], + clip_param=self.config["clip_param"], + vf_clip_param=self.config["vf_clip_param"], + vf_loss_coeff=self.config["vf_loss_coeff"], + clip_loss=False, + ) + else: + self.var_list = model.named_parameters() + + # `split` may not exist yet (during test-loss call), use a dummy value. + # Cannot use get here due to train_batch being a TrackingDict. + if "split" in train_batch: + split = train_batch["split"] + else: + split_shape = ( + self.config["inner_adaptation_steps"], + self.config["num_workers"], + ) + split_const = int( + train_batch["obs"].shape[0] // (split_shape[0] * split_shape[1]) + ) + split = torch.ones(split_shape, dtype=int) * split_const + self.loss_obj = MAMLLoss( + model=model, + dist_class=dist_class, + value_targets=train_batch[Postprocessing.VALUE_TARGETS], + advantages=train_batch[Postprocessing.ADVANTAGES], + actions=train_batch[SampleBatch.ACTIONS], + behaviour_logits=train_batch[SampleBatch.ACTION_DIST_INPUTS], + vf_preds=train_batch[SampleBatch.VF_PREDS], + cur_kl_coeff=self.kl_coeff_val, + policy_vars=self.var_list, + obs=train_batch[SampleBatch.CUR_OBS], + num_tasks=self.config["num_workers"], + split=split, + config=self.config, + inner_adaptation_steps=self.config["inner_adaptation_steps"], + entropy_coeff=self.config["entropy_coeff"], + clip_param=self.config["clip_param"], + vf_clip_param=self.config["vf_clip_param"], + vf_loss_coeff=self.config["vf_loss_coeff"], + use_gae=self.config["use_gae"], + meta_opt=self.meta_opt, + ) + + return self.loss_obj.loss + + @override(TorchPolicyV2) + def optimizer( + self, + ) -> Union[List["torch.optim.Optimizer"], "torch.optim.Optimizer"]: + """ + Workers use simple SGD for inner adaptation + Meta-Policy uses Adam optimizer for meta-update + """ + if not self.config["worker_index"]: + self.meta_opt = torch.optim.Adam( + self.model.parameters(), lr=self.config["lr"] + ) + return self.meta_opt + return torch.optim.SGD(self.model.parameters(), lr=self.config["inner_lr"]) + + @override(TorchPolicyV2) + def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: + if self.config["worker_index"]: + return convert_to_numpy({"worker_loss": self.loss_obj.loss}) + else: + return convert_to_numpy( + { + "cur_kl_coeff": self.kl_coeff_val, + "cur_lr": self.cur_lr, + "total_loss": self.loss_obj.loss, + "policy_loss": self.loss_obj.mean_policy_loss, + "vf_loss": self.loss_obj.mean_vf_loss, + "kl_loss": self.loss_obj.mean_kl_loss, + "inner_kl": self.loss_obj.mean_inner_kl, + "entropy": self.loss_obj.mean_entropy, + } + ) + + @override(TorchPolicyV2) + def extra_grad_process( + self, optimizer: "torch.optim.Optimizer", loss: TensorType + ) -> Dict[str, TensorType]: + return apply_grad_clipping(self, optimizer, loss) + + @override(TorchPolicyV2) + def postprocess_trajectory( + self, sample_batch, other_agent_batches=None, episode=None + ): + # Do all post-processing always with no_grad(). + # Not using this here will introduce a memory leak + # in torch (issue #6962). + # TODO: no_grad still necessary? + with torch.no_grad(): + return compute_gae_for_sample_batch( + self, sample_batch, other_agent_batches, episode + ) diff --git a/rllib_contrib/maml/tests/test_maml.py b/rllib_contrib/maml/tests/test_maml.py new file mode 100644 index 0000000000000..774be4ecde413 --- /dev/null +++ b/rllib_contrib/maml/tests/test_maml.py @@ -0,0 +1,61 @@ +import unittest + +from gymnasium.wrappers import TimeLimit +from rllib_maml.envs.cartpole_mass import CartPoleMassEnv +from rllib_maml.envs.pendulum_mass import PendulumMassEnv +from rllib_maml.maml import MAMLConfig + +import ray +from ray.rllib.utils.test_utils import ( + check_compute_single_action, + check_train_results, + framework_iterator, +) +from ray.tune.registry import register_env + + +class TestMAML(unittest.TestCase): + @classmethod + def setUpClass(cls): + ray.init() + register_env( + "cartpole", + lambda env_cfg: TimeLimit(CartPoleMassEnv(), max_episode_steps=200), + ) + register_env( + "pendulum", + lambda env_cfg: TimeLimit(PendulumMassEnv(), max_episode_steps=200), + ) + + @classmethod + def tearDownClass(cls): + ray.shutdown() + + def test_maml_compilation(self): + """Test whether MAML can be built with all frameworks.""" + config = MAMLConfig().rollouts(num_rollout_workers=1) + + num_iterations = 1 + + # Test for tf framework (torch not implemented yet). + for fw in framework_iterator(config, frameworks=("tf", "torch")): + for env in ["cartpole", "pendulum"]: + if fw == "tf" and env.startswith("cartpole"): + continue + print("env={}".format(env)) + config.environment(env) + algo = config.build() + for i in range(num_iterations): + results = algo.train() + check_train_results(results) + print(results) + check_compute_single_action(algo, include_prev_action_reward=True) + algo.stop() + + +if __name__ == "__main__": + import sys + + import pytest + + sys.exit(pytest.main(["-v", __file__]))