Skip to content

Commit

Permalink
Merge pull request #188 from ranamihir/patch-1
Browse files Browse the repository at this point in the history
Fix data type for torch v0.4.1
  • Loading branch information
ikostrikov2 committed Apr 16, 2019
2 parents 4f04391 + 6a89459 commit b4133ec
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions a2c_ppo_acktr/algo/gail.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,11 @@ def __init__(self, file_name, num_trajectories=4, subsample_frequency=20):
idx = perm[:num_trajectories]

self.trajectories = {}


# See https://github.com/pytorch/pytorch/issues/14886
# .long() for fixing bug in torch v0.4.1
start_idx = torch.randint(
0, subsample_frequency, size=(num_trajectories, ))
0, subsample_frequency, size=(num_trajectories, )).long()

for k, v in all_trajectories.items():
data = v[idx]
Expand Down Expand Up @@ -162,4 +164,4 @@ def __getitem__(self, i):
traj_idx, i = self.get_idx[i]

return self.trajectories['states'][traj_idx][i], self.trajectories[
'actions'][traj_idx][i]
'actions'][traj_idx][i]

0 comments on commit b4133ec

Please sign in to comment.