WIP.
sven1977 committed Dec 20, 2020
1 parent cc31641 commit 67cf4d3
Showing 5 changed files with 37 additions and 20 deletions.
9 changes: 3 additions & 6 deletions rllib/agents/dqn/dqn.py
@@ -233,12 +233,9 @@ def update_prio(item):
         if config.get("prioritized_replay"):
             prio_dict = {}
             for policy_id, info in info_dict.items():
-                # TODO(sven): This is currently structured differently for
-                #  torch/tf. Clean up these results/info dicts across
-                #  policies (note: fixing this in torch_policy.py will
-                #  break e.g. DDPPO!).
-                td_error = info.get("td_error",
-                                    info[LEARNER_STATS_KEY].get("td_error"))
+                # TODO: Check whether this is correct now and resolve the ambiguity below.
+                td_error = info.get("td_error")
+                # info[LEARNER_STATS_KEY].get("td_error"))
                 prio_dict[policy_id] = (samples.policy_batches[policy_id]
                                         .data.get("batch_indexes"), td_error)
             local_replay_buffer.update_priorities(prio_dict)
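For context, the ambiguity the TODO above refers to: depending on the framework, "td_error" may sit at the top level of a policy's learner info dict or be nested under LEARNER_STATS_KEY. A minimal sketch of a lookup that tolerates both layouts (not RLlib's actual code; the constant's value and the dicts are assumed for illustration):

LEARNER_STATS_KEY = "learner_stats"  # assumed value of RLlib's constant of the same name

def get_td_error(info):
    # Prefer the top-level key; fall back to the nested learner-stats dict.
    td_error = info.get("td_error")
    if td_error is None:
        td_error = info.get(LEARNER_STATS_KEY, {}).get("td_error")
    return td_error

# Both layouts yield the same result (values are made up).
assert get_td_error({"td_error": [0.2, -0.1]}) == [0.2, -0.1]
assert get_td_error({LEARNER_STATS_KEY: {"td_error": [0.2, -0.1]}}) == [0.2, -0.1]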
12 changes: 7 additions & 5 deletions rllib/agents/ppo/ppo.py
@@ -20,7 +20,7 @@
     StandardizeFields, SelectExperiences
 from ray.rllib.execution.train_ops import TrainOneStep, TrainTFMultiGPU
 from ray.rllib.execution.metric_ops import StandardMetricsReporting
-from ray.rllib.policy.policy import Policy
+from ray.rllib.policy.policy import LEARNER_STATS_KEY, Policy
 from ray.rllib.utils.typing import TrainerConfigDict
 from ray.util.iter import LocalIterator
 
@@ -174,12 +174,14 @@ def __init__(self, workers):
 
     def __call__(self, fetches):
         def update(pi, pi_id):
-            assert "kl" not in fetches, (
-                "kl should be nested under policy id key", fetches)
+            assert LEARNER_STATS_KEY not in fetches, \
+                ("{} should be nested under policy id key".format(
+                    LEARNER_STATS_KEY), fetches)
             if pi_id in fetches:
-                assert "kl" in fetches[pi_id], (fetches, pi_id)
+                kl = fetches[pi_id].get(LEARNER_STATS_KEY, {}).get("kl")
+                assert kl is not None, (fetches, pi_id)
                 # Make the actual `Policy.update_kl()` call.
-                pi.update_kl(fetches[pi_id]["kl"])
+                pi.update_kl(kl)
             else:
                 logger.warning("No data for {}, not updating kl".format(pi_id))
 
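The new assertions encode the layout the KL update now expects: fetches keyed by policy id, with learner stats nested one level below under LEARNER_STATS_KEY. A rough sketch of that layout and of the lookup (the dict contents are made up; the constant's value is assumed):

LEARNER_STATS_KEY = "learner_stats"  # assumed value of RLlib's constant

# Illustrative fetches dict: learner stats sit below the policy id.
fetches = {
    "default_policy": {
        LEARNER_STATS_KEY: {"kl": 0.02, "entropy": 1.3},
        "td_error": [0.1, -0.4],  # stored outside LEARNER_STATS_KEY
    }
}

# The top level must not carry learner stats directly ...
assert LEARNER_STATS_KEY not in fetches
# ... and the KL value is read from the nested stats dict.
kl = fetches["default_policy"].get(LEARNER_STATS_KEY, {}).get("kl")
assert kl is not None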
5 changes: 3 additions & 2 deletions rllib/execution/multi_gpu_impl.py
@@ -251,8 +251,9 @@ def optimize(self, sess, batch_index):
             feed_dict.update(tower.loss_graph.extra_compute_grad_feed_dict())
 
         fetches = {"train": self._train_op}
-        for tower in self._towers:
-            fetches.update(tower.loss_graph._get_grad_and_stats_fetches())
+        for tower_num, tower in enumerate(self._towers):
+            tower_fetch = tower.loss_graph._get_grad_and_stats_fetches()
+            fetches["tower_{}".format(tower_num)] = tower_fetch
 
         return sess.run(fetches, feed_dict=feed_dict)
 
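With this change the per-tower stats are kept apart under "tower_0", "tower_1", ... instead of being merged into one flat dict, so the caller can see each tower's value for the same key. A schematic sketch of the resulting fetch structure (two towers; key names follow the code above, numbers are made up):

# Schematic result of sess.run(fetches, ...) with two towers:
batch_fetches = {
    "train": None,  # value of the train op itself, not used for stats
    "tower_0": {"learner_stats": {"kl": 0.021}, "td_error": [0.3, -0.1]},
    "tower_1": {"learner_stats": {"kl": 0.019}, "td_error": [0.2, 0.0]},
}

# Every tower contributes a structurally identical nested dict, which is what
# lets the train op (train_ops.py below) zip them together leaf by leaf.
assert batch_fetches["tower_0"].keys() == batch_fetches["tower_1"].keys()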
29 changes: 23 additions & 6 deletions rllib/execution/train_ops.py
@@ -2,6 +2,7 @@
 import logging
 import numpy as np
 import math
+import tree
 from typing import List, Tuple, Any
 
 import ray
@@ -206,14 +207,30 @@ def __call__(self,
                     batch_fetches = optimizer.optimize(
                         self.sess, permutation[batch_index] *
                         self.per_device_batch_size)
-                    for k, v in batch_fetches[LEARNER_STATS_KEY].items():
-                        # TODO: multi-GPU optimizer here does not collect td_error (which is stored outside LEARNER_STATS_KEY);
-                        #  that's why it doesn't show up in the returned fetches.
-                        iter_extra_fetches[k].append(v)
+
+                    # def mapping_fn(*s):
+                    #     s
+
+                    iter_extra_fetches = tree.map_structure(
+                        lambda *s: np.array(s),
+                        *(batch_fetches["tower_{}".format(i)] for i in range(len(self.devices))))
+
+                    # for k, v in batch_fetches.items():
+                    #     if k == LEARNER_STATS_KEY:
+                    #         for k, v in batch_fetches[LEARNER_STATS_KEY].items():
+                    #             # TODO: multi-GPU optimizer here does not collect td_error (which is stored outside LEARNER_STATS_KEY);
+                    #             #  that's why it doesn't show up in the returned fetches.
+                    #             iter_extra_fetches[k].append(v)
+                    #     elif k == "train_op":
+                    #         continue
+                    #     else:
+                    #         iter_extra_fetches[k].append(v)
                 if logger.getEffectiveLevel() <= logging.DEBUG:
-                    avg = averaged(iter_extra_fetches)
+                    avg = tree.map_structure_with_path(lambda p, s: np.nanmean(s, axis=0) if p[0] != "td_error" else np.concatenate(s, axis=0), iter_extra_fetches)
                     logger.debug("{} {}".format(i, avg))
-                fetches[policy_id] = averaged(iter_extra_fetches, axis=0)
+                fetches[policy_id] = tree.map_structure_with_path(lambda p, s: np.nanmean(s, axis=0) if p[0] != "td_error" else np.concatenate(s, axis=0), iter_extra_fetches)
+                # fetches[policy_id] = tree.map_structure(lambda s: np.nanmean(s, axis=0), iter_extra_fetches)  # averaged(iter_extra_fetches, axis=0)
+                # fetches[policy_id] = tree.unflatten_as()
 
                 load_timer.push_units_processed(samples.count)
                 learn_timer.push_units_processed(samples.count)
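Here `tree` is the dm-tree package. The two `tree` calls do the per-iteration aggregation: `map_structure` zips the structurally identical per-tower dicts and stacks every leaf across towers, and `map_structure_with_path` then reduces over the tower axis, averaging each stat except `td_error`, which is concatenated because it is per-sample data rather than a scalar metric. A standalone sketch of that aggregation (the tower dicts are illustrative, not real RLlib output):

import numpy as np
import tree  # dm-tree

# Structurally identical fetches from two towers (made-up numbers).
tower_fetches = [
    {"learner_stats": {"kl": 0.021}, "td_error": np.array([0.3, -0.1])},
    {"learner_stats": {"kl": 0.019}, "td_error": np.array([0.2, 0.0])},
]

# Stack each leaf across towers: scalars become shape (2,), td_error (2, 2).
stacked = tree.map_structure(lambda *s: np.array(s), *tower_fetches)

# Average stats over the tower axis, but concatenate per-sample td_error.
aggregated = tree.map_structure_with_path(
    lambda path, s: np.concatenate(s, axis=0)
    if path[0] == "td_error" else np.nanmean(s, axis=0),
    stacked)

print(aggregated["learner_stats"]["kl"])  # ~0.02 (mean over towers)
print(aggregated["td_error"])             # [ 0.3 -0.1  0.2  0. ]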
2 changes: 1 addition & 1 deletion rllib/policy/dynamic_tf_policy.py
@@ -370,7 +370,7 @@ def copy(self,
             self.action_space,
             self.config,
             existing_inputs=input_dict,
-            existing_model=[self.model, ("target_q_model", getattr(self, "target_q_model"))])
+            existing_model=[self.model, ("target_q_model", getattr(self, "target_q_model", None))])
 
         instance._loss_input_dict = input_dict
         loss = instance._do_loss_init(input_dict)
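The only change here is the `None` default: not every policy owns a `target_q_model`, and the three-argument `getattr` degrades to `None` instead of raising. A tiny illustration with a hypothetical stand-in class:

class _PolicyWithoutTarget:
    model = "base_model"  # stand-in attribute; no target_q_model defined

pi = _PolicyWithoutTarget()

# The two-argument getattr raises AttributeError for such policies ...
try:
    getattr(pi, "target_q_model")
except AttributeError:
    pass
# ... while the three-argument form falls back to None, so copy() still works.
assert getattr(pi, "target_q_model", None) is None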
