[RLlib-Contrib] Simple Q. (#36688)
avnishn committed Oct 4, 2023
1 parent 5f25de5 commit 4657552
Showing 12 changed files with 1,189 additions and 0 deletions.
12 changes: 12 additions & 0 deletions .buildkite/pipeline.ml.yml
@@ -576,6 +576,18 @@
    - pytest rllib_contrib/r2d2/tests/
    - python rllib_contrib/r2d2/examples/r2d2_stateless_cartpole.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: SimpleQ Tests"
  conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
  commands:
    - cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
    - conda deactivate
    - conda create -n rllib_contrib python=3.8 -y
    - conda activate rllib_contrib
    - (cd rllib_contrib/simple_q && pip install -r requirements.txt && pip install -e ".[development]")
    - ./ci/env/env_info.sh
    - pytest rllib_contrib/simple_q/tests/
    - python rllib_contrib/simple_q/examples/simple_q_cartpole_v1.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: TD3 Tests"
  conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
  commands:
18 changes: 18 additions & 0 deletions rllib_contrib/simple_q/README.md
@@ -0,0 +1,18 @@
# Simple Q (DQN)

[Simple Q](https://arxiv.org/abs/1312.5602) is a basic implementation of the DQN algorithm, without later
extensions such as double Q-learning, dueling networks, or prioritized replay.
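
As a rough illustration of what a plain DQN update is built on, the sketch below computes the one-step Q-learning target `r + gamma * max_a' Q_target(s', a')`. The function names, the `gamma` default, and the use of plain Python lists are assumptions made for this example; they are not taken from `rllib_simple_q`.

```python
from typing import Sequence


def td_target(
    reward: float,
    next_q_values: Sequence[float],
    done: bool,
    gamma: float = 0.99,  # assumed discount factor, for illustration only
) -> float:
    """One-step Q-learning target: r + gamma * max_a' Q_target(s', a')."""
    if done:
        # No bootstrapping from a terminal state.
        return reward
    return reward + gamma * max(next_q_values)


def td_error(q_value: float, target: float) -> float:
    """Difference between the bootstrapped target and the current estimate."""
    return target - q_value
```

Here `next_q_values` stands for the target network's Q-values for every action in the next state; a real implementation would compute these in batches on tensors.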


## Installation

```
conda create -n rllib-simpleq python=3.10
conda activate rllib-simpleq
pip install -r requirements.txt
pip install -e '.[development]'
```

## Usage

[SimpleQ Example](examples/simple_q_cartpole_v1.py)
42 changes: 42 additions & 0 deletions rllib_contrib/simple_q/examples/simple_q_cartpole_v1.py
@@ -0,0 +1,42 @@
import argparse

from rllib_simple_q.simple_q import SimpleQ, SimpleQConfig

import ray
from ray import air, tune
from ray.rllib.utils.test_utils import check_learning_achieved


def get_cli_args():
    """Create CLI parser and return parsed arguments"""
    parser = argparse.ArgumentParser()
    parser.add_argument("--run-as-test", action="store_true", default=False)
    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args


if __name__ == "__main__":
    args = get_cli_args()

    ray.init()

    config = SimpleQConfig().framework("torch").environment("CartPole-v1")

    stop_reward = 150

    tuner = tune.Tuner(
        SimpleQ,
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop={
                "sampler_results/episode_reward_mean": stop_reward,
                "timesteps_total": 50000,
            },
            failure_config=air.FailureConfig(fail_fast="raise"),
        ),
    )
    results = tuner.fit()

    if args.run_as_test:
        check_learning_achieved(results, stop_reward)
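
The `stop` dictionary in the script above ends a trial as soon as any one of its criteria is reached, and a key such as `sampler_results/episode_reward_mean` addresses a nested result field. The following is a minimal sketch of that behavior, assuming a nested result dictionary; `flatten` and `should_stop` are hypothetical helpers written for this illustration, not Ray Tune's implementation.

```python
from typing import Any, Dict


def flatten(result: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts into "outer/inner" keys."""
    flat: Dict[str, Any] = {}
    for key, value in result.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, path))
        else:
            flat[path] = value
    return flat


def should_stop(result: Dict[str, Any], stop: Dict[str, float]) -> bool:
    """Return True once any metric in `stop` has reached its threshold."""
    flat = flatten(result)
    return any(
        key in flat and flat[key] >= threshold
        for key, threshold in stop.items()
    )


# Same criteria as the example script.
stop = {"sampler_results/episode_reward_mean": 150, "timesteps_total": 50000}
```

With these criteria, a run that never reaches a mean reward of 150 still terminates at 50,000 timesteps, which is why the script then calls `check_learning_achieved` when `--run-as-test` is set.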
18 changes: 18 additions & 0 deletions rllib_contrib/simple_q/pyproject.toml
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]

[project]
name = "rllib-simpleq"
authors = [{name = "Anyscale Inc."}]
version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.7, <3.11"
dependencies = ["gymnasium==0.26.3", "ray[rllib]==2.5.1"]

[project.optional-dependencies]
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "tensorflow==2.11.0", "torch==1.12.0"]
2 changes: 2 additions & 0 deletions rllib_contrib/simple_q/requirements.txt
@@ -0,0 +1,2 @@
tensorflow==2.11.0
torch==1.12.0
18 changes: 18 additions & 0 deletions rllib_contrib/simple_q/src/rllib_simple_q/simple_q/__init__.py
@@ -0,0 +1,18 @@
from rllib_simple_q.simple_q.simple_q import SimpleQ, SimpleQConfig
from rllib_simple_q.simple_q.simple_q_tf_policy import (
    SimpleQTF1Policy,
    SimpleQTF2Policy,
)
from rllib_simple_q.simple_q.simple_q_torch_policy import SimpleQTorchPolicy

from ray.tune.registry import register_trainable

__all__ = [
    "SimpleQ",
    "SimpleQConfig",
    "SimpleQTF1Policy",
    "SimpleQTF2Policy",
    "SimpleQTorchPolicy",
]

register_trainable("rllib-contrib-simple-q", SimpleQ)
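
The `register_trainable` call at the end of `__init__.py` associates the string `"rllib-contrib-simple-q"` with the `SimpleQ` class as an import side effect, so the algorithm can later be referred to by name. The toy registry below sketches that general pattern only; `register`, `lookup`, and `DummyAlgo` are hypothetical names and do not reflect how Ray's registry is implemented.

```python
from typing import Dict

# Hypothetical module-level registry mapping names to classes.
_REGISTRY: Dict[str, type] = {}


def register(name: str, cls: type) -> None:
    """Associate a string name with a class."""
    _REGISTRY[name] = cls


def lookup(name: str) -> type:
    """Resolve a previously registered name back to its class."""
    try:
        return _REGISTRY[name]
    except KeyError:
        raise ValueError(f"Unknown trainable: {name!r}") from None


class DummyAlgo:
    """Stand-in for an algorithm class such as SimpleQ."""


# Registration happens once, at import time of the package.
register("rllib-contrib-simple-q", DummyAlgo)
```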