Skip to content

Commit

Permalink
Update td3.py
Browse files Browse the repository at this point in the history
  • Loading branch information
yjwong1999 committed Jun 14, 2023
1 parent 4deb40b commit 6ac7699
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,10 @@ def save_checkpoint(self):

def load_checkpoint(self,load_file = ''):
print('... loading checkpoint ...')
self.load_state_dict(T.load(load_file))
if T.cuda.is_available():
self.load_state_dict(T.load(load_file))
else:
self.load_state_dict(T.load(load_file, map_location=T.device('cpu')))

class ActorNetwork(nn.Module):
def __init__(self, alpha, input_dims, fc1_dims, fc2_dims, fc3_dims, fc4_dims, n_actions, name,
Expand Down Expand Up @@ -241,7 +244,10 @@ def save_checkpoint(self):

def load_checkpoint(self, load_file=''):
print('... loading checkpoint ...')
self.load_state_dict(T.load(load_file))
if T.cuda.is_available():
self.load_state_dict(T.load(load_file))
else:
self.load_state_dict(T.load(load_file, map_location=T.device('cpu')))

class Agent(object):
def __init__(self, alpha, beta, input_dims, tau, env, gamma=0.99,
Expand Down

0 comments on commit 6ac7699

Please sign in to comment.