Skip to content

Commit

Permalink
[RLlib] Fix Batch Norm Model and QMix VDN mixer flakeyness. (ray-project#31371)
Browse files Browse the repository at this point in the history

Signed-off-by: tmynn <[email protected]>
  • Loading branch information
ArturNiederfahrenhorst authored and tamohannes committed Jan 16, 2023
1 parent e7dea14 commit cb72895
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
10 changes: 5 additions & 5 deletions rllib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ py_test(
name = "learning_tests_two_step_game_qmix_vdn_mixer",
main = "tests/run_regression_tests.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_discrete"],
- size = "medium", # bazel may complain about it being too long sometimes - medium is on purpose as some frameworks take longer
+ size = "large",
srcs = ["tests/run_regression_tests.py"],
data = ["tuned_examples/qmix/two-step-game-qmix-vdn-mixer.yaml"],
args = ["--dir=tuned_examples/qmix", "--framework=torch"]
Expand Down Expand Up @@ -2683,18 +2683,18 @@ py_test(
name = "examples/batch_norm_model_dqn_tf",
main = "examples/batch_norm_model.py",
tags = ["team:rllib", "exclusive", "examples"],
- size = "medium",
+ size = "large",
srcs = ["examples/batch_norm_model.py"],
- args = ["--as-test", "--run=DQN", "--stop-reward=70"]
+ args = ["--as-test", "--run=DQN", "--stop-reward=70", "--stop-time=400"]
)

py_test(
name = "examples/batch_norm_model_dqn_torch",
main = "examples/batch_norm_model.py",
tags = ["team:rllib", "exclusive", "examples"],
- size = "medium", # DQN learns much slower with BatchNorm.
+ size = "large", # DQN learns much slower with BatchNorm.
srcs = ["examples/batch_norm_model.py"],
- args = ["--as-test", "--framework=torch", "--run=DQN", "--stop-reward=70"]
+ args = ["--as-test", "--framework=torch", "--run=DQN", "--stop-reward=70", "--stop-time=400"]
)

py_test(
Expand Down
9 changes: 8 additions & 1 deletion rllib/examples/batch_norm_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@
parser.add_argument(
"--stop-reward", type=float, default=150.0, help="Reward at which we stop training."
)
+ parser.add_argument(
+ "--stop-time",
+ type=float,
+ default=60 * 60,
+ help="Time (in seconds) after which we stop training.",
+ )

if __name__ == "__main__":
args = parser.parse_args()
Expand All @@ -61,7 +67,7 @@
.get_default_config()
.environment("Pendulum-v1" if args.run in ["DDPG", "SAC"] else "CartPole-v1")
.framework(args.framework)
- .rollouts(num_rollout_workers=0)
+ .rollouts(num_rollout_workers=3)
.training(model={"custom_model": "bn_model"}, lr=0.0003)
# Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
.resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
Expand All @@ -71,6 +77,7 @@
"training_iteration": args.stop_iters,
"timesteps_total": args.stop_timesteps,
"episode_reward_mean": args.stop_reward,
+ "time_total_s": args.stop_time,
}

tuner = tune.Tuner(
Expand Down

0 comments on commit cb72895

Please sign in to comment.