Skip to content

Commit

Permalink
[RLlib] Fix bug in tf_learner.py; Learner.update method has accidenta…
Browse files Browse the repository at this point in the history
…l closure. (ray-project#35664)

Signed-off-by: sven1977 <[email protected]>
  • Loading branch information
sven1977 committed May 23, 2023
1 parent 5073be7 commit dad239a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion rllib/core/learner/tf/tf_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def helper(_batch):
# constraint on forward_train and compute_loss APIs. This seems to be
# in-efficient. However, for tf>=2.12, it works also w/o this conversion
# so remove this after we upgrade officially to tf==2.12.
_batch = NestedDict(batch)
_batch = NestedDict(_batch)
with tf.GradientTape() as tape:
fwd_out = self._module.forward_train(_batch)
loss_per_module = self.compute_loss(fwd_out=fwd_out, batch=_batch)
Expand Down

0 comments on commit dad239a

Please sign in to comment.