From 6ac76999ebc8d44330395f8cde801604de8b2deb Mon Sep 17 00:00:00 2001 From: Wong Yi Jie <55955482+yjwong1999@users.noreply.github.com> Date: Wed, 14 Jun 2023 15:16:05 +0800 Subject: [PATCH] Update td3.py --- td3.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/td3.py b/td3.py index db0994a..229553b 100644 --- a/td3.py +++ b/td3.py @@ -158,7 +158,10 @@ def save_checkpoint(self): def load_checkpoint(self,load_file = ''): print('... loading checkpoint ...') - self.load_state_dict(T.load(load_file)) + if T.cuda.is_available(): + self.load_state_dict(T.load(load_file)) + else: + self.load_state_dict(T.load(load_file, map_location=T.device('cpu'))) class ActorNetwork(nn.Module): def __init__(self, alpha, input_dims, fc1_dims, fc2_dims, fc3_dims, fc4_dims, n_actions, name, @@ -241,7 +244,10 @@ def save_checkpoint(self): def load_checkpoint(self, load_file=''): print('... loading checkpoint ...') - self.load_state_dict(T.load(load_file)) + if T.cuda.is_available(): + self.load_state_dict(T.load(load_file)) + else: + self.load_state_dict(T.load(load_file, map_location=T.device('cpu'))) class Agent(object): def __init__(self, alpha, beta, input_dims, tau, env, gamma=0.99,