diff --git a/td3.py b/td3.py index db0994a..229553b 100644 --- a/td3.py +++ b/td3.py @@ -158,7 +158,10 @@ def save_checkpoint(self): def load_checkpoint(self,load_file = ''): print('... loading checkpoint ...') - self.load_state_dict(T.load(load_file)) + if T.cuda.is_available(): + self.load_state_dict(T.load(load_file)) + else: + self.load_state_dict(T.load(load_file, map_location=T.device('cpu'))) class ActorNetwork(nn.Module): def __init__(self, alpha, input_dims, fc1_dims, fc2_dims, fc3_dims, fc4_dims, n_actions, name, @@ -241,7 +244,10 @@ def save_checkpoint(self): def load_checkpoint(self, load_file=''): print('... loading checkpoint ...') - self.load_state_dict(T.load(load_file)) + if T.cuda.is_available(): + self.load_state_dict(T.load(load_file)) + else: + self.load_state_dict(T.load(load_file, map_location=T.device('cpu'))) class Agent(object): def __init__(self, alpha, beta, input_dims, tau, env, gamma=0.99,