-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
1,189 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Simple Q (DQN) | ||
|
||
[Simple Q](https://arxiv.org/abs/1602.01783) is an implementation of the DQN algorithm without any | ||
optimizations. | ||
|
||
|
||
## Installation | ||
|
||
``` | ||
conda create -n rllib-simpleq python=3.10 | ||
conda activate rllib-simpleq | ||
pip install -r requirements.txt | ||
pip install -e '.[development]' | ||
``` | ||
|
||
## Usage | ||
|
||
[SimpleQ Example]() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import argparse | ||
|
||
from rllib_simple_q.simple_q import SimpleQ, SimpleQConfig | ||
|
||
import ray | ||
from ray import air, tune | ||
from ray.rllib.utils.test_utils import check_learning_achieved | ||
|
||
|
||
def get_cli_args(): | ||
"""Create CLI parser and return parsed arguments""" | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--run-as-test", action="store_true", default=False) | ||
args = parser.parse_args() | ||
print(f"Running with following CLI args: {args}") | ||
return args | ||
|
||
|
||
if __name__ == "__main__": | ||
args = get_cli_args() | ||
|
||
ray.init() | ||
|
||
config = SimpleQConfig().framework("torch").environment("CartPole-v1") | ||
|
||
stop_reward = 150 | ||
|
||
tuner = tune.Tuner( | ||
SimpleQ, | ||
param_space=config.to_dict(), | ||
run_config=air.RunConfig( | ||
stop={ | ||
"sampler_results/episode_reward_mean": stop_reward, | ||
"timesteps_total": 50000, | ||
}, | ||
failure_config=air.FailureConfig(fail_fast="raise"), | ||
), | ||
) | ||
results = tuner.fit() | ||
|
||
if args.run_as_test: | ||
check_learning_achieved(results, stop_reward) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[build-system] | ||
requires = ["setuptools>=61.0"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[tool.setuptools.packages.find] | ||
where = ["src"] | ||
|
||
[project] | ||
name = "rllib-simpleq" | ||
authors = [{name = "Anyscale Inc."}] | ||
version = "0.1.0" | ||
description = "" | ||
readme = "README.md" | ||
requires-python = ">=3.7, <3.11" | ||
dependencies = ["gymnasium==0.26.3", "ray[rllib]==2.5.1"] | ||
|
||
[project.optional-dependencies] | ||
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "tensorflow==2.11.0", "torch==1.12.0"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
tensorflow==2.11.0 | ||
torch==1.12.0 |
18 changes: 18 additions & 0 deletions
18
rllib_contrib/simple_q/src/rllib_simple_q/simple_q/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from rllib_simple_q.simple_q.simple_q import SimpleQ, SimpleQConfig | ||
from rllib_simple_q.simple_q.simple_q_tf_policy import ( | ||
SimpleQTF1Policy, | ||
SimpleQTF2Policy, | ||
) | ||
from rllib_simple_q.simple_q.simple_q_torch_policy import SimpleQTorchPolicy | ||
|
||
from ray.tune.registry import register_trainable | ||
|
||
__all__ = [ | ||
"SimpleQ", | ||
"SimpleQConfig", | ||
"SimpleQTF1Policy", | ||
"SimpleQTF2Policy", | ||
"SimpleQTorchPolicy", | ||
] | ||
|
||
register_trainable("rllib-contrib-simple-q", SimpleQ) |
Oops, something went wrong.