-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[RLlib-contrib] Alpha Zero. (#36736)
- Loading branch information
Showing
12 changed files
with
1,114 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Alpha Zero | ||
|
||
[Alpha Zero](https://arxiv.org/abs/1712.01815) is a general reinforcement learning approach that achieved superhuman performance in the games of chess, shogi, and Go through tabula rasa learning from games of self-play, surpassing previous state-of-the-art programs that relied on handcrafted evaluation functions and domain-specific adaptations. | ||
|
||
## Installation | ||
|
||
``` | ||
conda create -n rllib-alpha-zero python=3.10 | ||
conda activate rllib-alpha-zero | ||
pip install -r requirements.txt | ||
pip install -e '.[development]' | ||
``` | ||
|
||
## Usage | ||
|
||
[AlphaZero Example](examples/alpha_zero_cartpole_sparse_rewards.py) |
73 changes: 73 additions & 0 deletions
73
rllib_contrib/alpha_zero/examples/alpha_zero_cartpole_sparse_rewards.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import argparse | ||
|
||
from rllib_alpha_zero.alpha_zero import AlphaZero, AlphaZeroConfig | ||
from rllib_alpha_zero.alpha_zero.custom_torch_models import DenseModel | ||
|
||
import ray | ||
from ray import air, tune | ||
from ray.rllib.examples.env.cartpole_sparse_rewards import CartPoleSparseRewards | ||
from ray.rllib.utils.test_utils import check_learning_achieved | ||
|
||
|
||
def get_cli_args(argv=None):
    """Create the CLI parser and return the parsed arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is what
            the script does when launched from the command line.

    Returns:
        The ``argparse.Namespace`` with a boolean ``run_as_test`` attribute
        (``True`` only when ``--run-as-test`` was passed).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run-as-test", action="store_true", default=False)
    # Forwarding `argv` lets callers (and tests) parse an explicit argument
    # list instead of depending on the process-wide `sys.argv`.
    args = parser.parse_args(argv)
    print(f"Running with following CLI args: {args}")
    return args
|
||
|
||
if __name__ == "__main__":
    cli_args = get_cli_args()

    ray.init()

    # Search settings handed to the algorithm as its `mcts_config` option.
    mcts_settings = {
        "puct_coefficient": 1.5,
        "num_simulations": 100,
        "temperature": 1.0,
        "dirichlet_epsilon": 0.20,
        "dirichlet_noise": 0.03,
        "argmax_tree_policy": False,
        "add_dirichlet_noise": True,
    }

    # Build the algorithm config one section at a time; each builder call
    # returns the config object the next call continues from.
    algo_config = AlphaZeroConfig()
    algo_config = algo_config.rollouts(
        num_rollout_workers=6,
        rollout_fragment_length=50,
    )
    algo_config = algo_config.framework("torch")
    algo_config = algo_config.environment(CartPoleSparseRewards)
    algo_config = algo_config.training(
        train_batch_size=500,
        sgd_minibatch_size=64,
        lr=1e-4,
        num_sgd_iter=1,
        mcts_config=mcts_settings,
        ranked_rewards={"enable": True},
        model={"custom_model": DenseModel},
    )

    # The same reward threshold is used as a stop criterion and, in test
    # mode, as the bar passed to `check_learning_achieved`.
    stop_reward = 30.0
    stop_criteria = {
        "sampler_results/episode_reward_mean": stop_reward,
        "timesteps_total": 100000,
    }

    param_space = algo_config.to_dict()
    run_settings = air.RunConfig(
        stop=stop_criteria,
        failure_config=air.FailureConfig(fail_fast="raise"),
    )
    tuner = tune.Tuner(
        AlphaZero,
        param_space=param_space,
        run_config=run_settings,
    )
    results = tuner.fit()

    if cli_args.run_as_test:
        check_learning_achieved(results, stop_reward)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
[build-system] | ||
requires = ["setuptools>=61.0"] | ||
build-backend = "setuptools.build_meta" | ||
|
||
[tool.setuptools.packages.find] | ||
where = ["src"] | ||
|
||
[project] | ||
name = "rllib-alpha-zero" | ||
authors = [{name = "Anyscale Inc."}] | ||
version = "0.1.0" | ||
description = "" | ||
readme = "README.md" | ||
requires-python = ">=3.7, <3.11" | ||
dependencies = ["gymnasium==0.26.3", "ray[rllib]==2.5.1"] | ||
|
||
[project.optional-dependencies] | ||
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "torch==1.12.0"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
torch==1.12.0 |
17 changes: 17 additions & 0 deletions
17
rllib_contrib/alpha_zero/src/rllib_alpha_zero/alpha_zero/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from rllib_alpha_zero.alpha_zero.alpha_zero import ( | ||
AlphaZero, | ||
AlphaZeroConfig, | ||
AlphaZeroDefaultCallbacks, | ||
) | ||
from rllib_alpha_zero.alpha_zero.alpha_zero_policy import AlphaZeroPolicy | ||
|
||
from ray.tune.registry import register_trainable | ||
|
||
# Names exported by `from rllib_alpha_zero.alpha_zero import *`; these are the
# classes imported above from the `alpha_zero` and `alpha_zero_policy` modules.
__all__ = [
    "AlphaZeroConfig",
    "AlphaZero",
    "AlphaZeroDefaultCallbacks",
    "AlphaZeroPolicy",
]

# Import-time side effect: register the AlphaZero class with Ray Tune's
# registry under the string name "rllib-contrib-alpha-zero".
# NOTE(review): presumably this lets the algorithm be referenced by that name
# in Tune runs -- confirm against the `ray.tune.registry` documentation.
register_trainable("rllib-contrib-alpha-zero", AlphaZero)
Oops, something went wrong.