-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PG unify/cleanup tf vs torch and PG functionality test cases (tf + to…
…rch). (#6650) * Unifying the code for PGTrainer/Policy wrt tf vs torch. Adding loss function test cases for the PGAgent (confirm equivalence of tf and torch). * Fix LINT line-len errors. * Fix LINT errors. * Fix `tf_pg_policy` imports (formerly: `pg_policy`). * Rename tf_pg_... into pg_tf_... following <alg>_<framework>_... convention, where ...=policy/loss/agent/trainer. Retire `PGAgent` class (use PGTrainer instead). * - Move PG test into agents/pg/tests directory. - All test cases will be located near the classes that are tested and then built into the Bazel/Travis test suite. * Moved post_process_advantages into pg.py (from pg_tf_policy.py), b/c the function is not a tf-specific one. * Fix remaining import errors for agents/pg/... * Fix circular dependency in pg imports. * Add pg tests to Jenkins test suite.
- Loading branch information
Showing
21 changed files
with
215 additions
and
102 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
from ray.rllib.agents.pg.pg import PGTrainer, DEFAULT_CONFIG | ||
from ray.rllib.utils import renamed_agent | ||
from ray.rllib.agents.pg.pg_tf_policy import pg_tf_loss, \ | ||
post_process_advantages | ||
from ray.rllib.agents.pg.pg_torch_policy import pg_torch_loss | ||
|
||
PGAgent = renamed_agent(PGTrainer) | ||
|
||
__all__ = ["PGAgent", "PGTrainer", "DEFAULT_CONFIG"] | ||
__all__ = ["PGTrainer", "pg_tf_loss", "pg_torch_loss", | ||
"post_process_advantages", "DEFAULT_CONFIG"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import ray | ||
from ray.rllib.evaluation.postprocessing import Postprocessing, \ | ||
compute_advantages | ||
from ray.rllib.policy.tf_policy_template import build_tf_policy | ||
from ray.rllib.policy.sample_batch import SampleBatch | ||
from ray.rllib.utils import try_import_tf | ||
|
||
tf = try_import_tf() | ||
|
||
|
||
def post_process_advantages(policy, sample_batch, other_agent_batches=None, | ||
episode=None): | ||
"""This adds the "advantages" column to the sample train_batch.""" | ||
return compute_advantages(sample_batch, 0.0, policy.config["gamma"], | ||
use_gae=False) | ||
|
||
|
||
def pg_tf_loss(policy, model, dist_class, train_batch): | ||
"""The basic policy gradients loss.""" | ||
logits, _ = model.from_batch(train_batch) | ||
action_dist = dist_class(logits, model) | ||
return -tf.reduce_mean(action_dist.logp(train_batch[SampleBatch.ACTIONS]) | ||
* train_batch[Postprocessing.ADVANTAGES]) | ||
|
||
|
||
PGTFPolicy = build_tf_policy( | ||
name="PGTFPolicy", | ||
get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG, | ||
postprocess_fn=post_process_advantages, | ||
loss_fn=pg_tf_loss) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import ray | ||
from ray.rllib.agents.pg.pg_tf_policy import post_process_advantages | ||
from ray.rllib.evaluation.postprocessing import Postprocessing | ||
from ray.rllib.policy.sample_batch import SampleBatch | ||
from ray.rllib.policy.torch_policy_template import build_torch_policy | ||
from ray.rllib.utils.framework import try_import_torch | ||
|
||
torch, _ = try_import_torch() | ||
|
||
|
||
def pg_torch_loss(policy, model, dist_class, train_batch): | ||
"""The basic policy gradients loss.""" | ||
logits, _ = model.from_batch(train_batch) | ||
action_dist = dist_class(logits, model) | ||
log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS]) | ||
# Save the error in the policy object. | ||
# policy.pi_err = -train_batch[Postprocessing.ADVANTAGES].dot( | ||
# log_probs.reshape(-1)) / len(log_probs) | ||
policy.pi_err = -torch.mean( | ||
log_probs * train_batch[Postprocessing.ADVANTAGES] | ||
) | ||
return policy.pi_err | ||
|
||
|
||
def pg_loss_stats(policy, train_batch): | ||
""" The error is recorded when computing the loss.""" | ||
return {"policy_loss": policy.pi_err.item()} | ||
|
||
|
||
PGTorchPolicy = build_torch_policy( | ||
name="PGTorchPolicy", | ||
get_default_config=lambda: ray.rllib.agents.pg.pg.DEFAULT_CONFIG, | ||
loss_fn=pg_torch_loss, | ||
stats_fn=pg_loss_stats, | ||
postprocess_fn=post_process_advantages) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
import numpy as np | ||
import unittest | ||
|
||
import ray | ||
import ray.rllib.agents.pg as pg | ||
from ray.rllib.evaluation.postprocessing import Postprocessing | ||
from ray.rllib.models.tf.tf_action_dist import Categorical | ||
from ray.rllib.models.torch.torch_action_dist import TorchCategorical | ||
from ray.rllib.policy.sample_batch import SampleBatch | ||
from ray.rllib.utils import check, fc | ||
|
||
|
||
class TestPG(unittest.TestCase): | ||
|
||
ray.init() | ||
|
||
def test_pg_compilation(self): | ||
"""Test whether a PGTrainer can be built with both frameworks.""" | ||
config = pg.DEFAULT_CONFIG.copy() | ||
config["num_workers"] = 0 # Run locally. | ||
|
||
# tf. | ||
trainer = pg.PGTrainer(config=config, env="CartPole-v0") | ||
|
||
num_iterations = 2 | ||
for i in range(num_iterations): | ||
trainer.train() | ||
|
||
# Torch. | ||
config["use_pytorch"] = True | ||
trainer = pg.PGTrainer(config=config, env="CartPole-v0") | ||
for i in range(num_iterations): | ||
trainer.train() | ||
|
||
def test_pg_loss_functions(self): | ||
"""Tests the PG loss function math.""" | ||
config = pg.DEFAULT_CONFIG.copy() | ||
config["num_workers"] = 0 # Run locally. | ||
config["eager"] = True | ||
config["gamma"] = 0.99 | ||
config["model"]["fcnet_hiddens"] = [10] | ||
config["model"]["fcnet_activation"] = "linear" | ||
|
||
# Fake CartPole episode of n timesteps. | ||
train_batch = { | ||
SampleBatch.CUR_OBS: np.array([ | ||
[0.1, 0.2, 0.3, 0.4], | ||
[0.5, 0.6, 0.7, 0.8], | ||
[0.9, 1.0, 1.1, 1.2] | ||
]), | ||
SampleBatch.ACTIONS: np.array([0, 1, 1]), | ||
SampleBatch.REWARDS: np.array([1.0, 1.0, 1.0]), | ||
SampleBatch.DONES: np.array([False, False, True]) | ||
} | ||
|
||
# tf. | ||
trainer = pg.PGTrainer(config=config, env="CartPole-v0") | ||
policy = trainer.get_policy() | ||
vars = policy.model.trainable_variables() | ||
|
||
# Post-process (calculate simple (non-GAE) advantages) and attach to | ||
# train_batch dict. | ||
# A = [0.99^2 * 1.0 + 0.99 * 1.0 + 1.0, 0.99 * 1.0 + 1.0, 1.0] = | ||
# [2.9701, 1.99, 1.0] | ||
train_batch = pg.post_process_advantages(policy, train_batch) | ||
# Check Advantage values. | ||
check(train_batch[Postprocessing.ADVANTAGES], [2.9701, 1.99, 1.0]) | ||
|
||
# Actual loss results. | ||
results = pg.pg_tf_loss( | ||
policy, policy.model, dist_class=Categorical, | ||
train_batch=train_batch | ||
) | ||
|
||
# Calculate expected results. | ||
expected_logits = fc( | ||
fc( | ||
train_batch[SampleBatch.CUR_OBS], | ||
vars[0].numpy(), vars[1].numpy() | ||
), | ||
vars[2].numpy(), vars[3].numpy() | ||
) | ||
expected_logp = Categorical(expected_logits, policy.model).logp( | ||
train_batch[SampleBatch.ACTIONS] | ||
) | ||
expected_loss = -np.mean( | ||
expected_logp * train_batch[Postprocessing.ADVANTAGES] | ||
) | ||
check(results.numpy(), expected_loss, decimals=4) | ||
|
||
# Torch. | ||
config["use_pytorch"] = True | ||
trainer = pg.PGTrainer(config=config, env="CartPole-v0") | ||
policy = trainer.get_policy() | ||
train_batch = policy._lazy_tensor_dict(train_batch) | ||
results = pg.pg_torch_loss( | ||
policy, policy.model, dist_class=TorchCategorical, | ||
train_batch=train_batch | ||
) | ||
expected_logits = policy.model._last_output | ||
expected_logp = TorchCategorical(expected_logits, policy.model).logp( | ||
train_batch[SampleBatch.ACTIONS] | ||
) | ||
expected_loss = -np.mean( | ||
expected_logp.detach().numpy() * | ||
train_batch[Postprocessing.ADVANTAGES].numpy() | ||
) | ||
check(results.detach().numpy(), expected_loss, decimals=4) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.