Skip to content

Commit

Permalink
Merge pull request #287 from billh0420/dev
Browse files Browse the repository at this point in the history
Fix missing fields for checkpoint.
  • Loading branch information
daochenzha committed May 20, 2023
2 parents fe65713 + f294b82 commit 63ab6a4
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions rlcard/agents/dqn_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def __init__(self,
save_path (str): The path to save the model checkpoints
save_every (int): Save the model every X training steps
'''
self.use_raw = False
self.replay_memory_init_size = replay_memory_init_size
self.update_target_estimator_every = update_target_estimator_every
self.discount_factor = discount_factor
Expand Down Expand Up @@ -268,17 +267,20 @@ def checkpoint_attributes(self):
'memory': self.memory.checkpoint_attributes(),
'total_t': self.total_t,
'train_t': self.train_t,
'replay_memory_init_size': self.replay_memory_init_size,
'update_target_estimator_every': self.update_target_estimator_every,
'discount_factor': self.discount_factor,
'epsilon_start': self.epsilons.min(),
'epsilon_end': self.epsilons.max(),
'epsilon_decay_steps': self.epsilon_decay_steps,
'discount_factor': self.discount_factor,
'update_target_estimator_every': self.update_target_estimator_every,
'batch_size': self.batch_size,
'num_actions': self.num_actions,
'train_every': self.train_every,
'device': self.device
'device': self.device,
'save_path': self.save_path,
'save_every': self.save_every
}

@classmethod
def from_checkpoint(cls, checkpoint):
'''
Expand All @@ -291,17 +293,21 @@ def from_checkpoint(cls, checkpoint):
print("\nINFO - Restoring model from checkpoint...")
agent_instance = cls(
replay_memory_size=checkpoint['memory']['memory_size'],
replay_memory_init_size=checkpoint['replay_memory_init_size'],
update_target_estimator_every=checkpoint['update_target_estimator_every'],
discount_factor=checkpoint['discount_factor'],
epsilon_start=checkpoint['epsilon_start'],
epsilon_end=checkpoint['epsilon_end'],
epsilon_decay_steps=checkpoint['epsilon_decay_steps'],
batch_size=checkpoint['batch_size'],
num_actions=checkpoint['num_actions'],
device=checkpoint['device'],
state_shape=checkpoint['q_estimator']['state_shape'],
train_every=checkpoint['train_every'],
mlp_layers=checkpoint['q_estimator']['mlp_layers'],
train_every=checkpoint['train_every']
learning_rate=checkpoint['q_estimator']['learning_rate'],
device=checkpoint['device'],
save_path=checkpoint['save_path'],
save_every=checkpoint['save_every'],
)

agent_instance.total_t = checkpoint['total_t']
Expand All @@ -310,18 +316,19 @@ def from_checkpoint(cls, checkpoint):
agent_instance.q_estimator = Estimator.from_checkpoint(checkpoint['q_estimator'])
agent_instance.target_estimator = deepcopy(agent_instance.q_estimator)
agent_instance.memory = Memory.from_checkpoint(checkpoint['memory'])



return agent_instance

def save_checkpoint(self, path, filename='checkpoint_dqn.pt'):
''' Save the model checkpoint (all attributes)
Args:
path (str): the path to save the model
filename(str): the file name of checkpoint
'''
torch.save(self.checkpoint_attributes(), os.path.join(path, filename))



class Estimator(object):
'''
Approximate clone of rlcard.agents.dqn_agent.Estimator that
Expand Down

0 comments on commit 63ab6a4

Please sign in to comment.