Skip to content

Commit

Permalink
Fix test cases.
Browse files Browse the repository at this point in the history
  • Loading branch information
workingloong committed Jun 3, 2024
1 parent 9a6c472 commit 65550ba
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions dlrover/trainer/tests/torch/deepspeed_ckpt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ def zero_optimization_stage(self):

def save_checkpoint(self, save_dir, tag, client_state, save_latest):
model_sd = self.model.state_dict()
model_path = os.path.join(save_dir, tag, "model_states.pt")
model_path = os.path.join(save_dir, str(tag), "model_states.pt")
torch.save(model_sd, model_path)
optimizer_sd = self.optimizer.state_dict()
optim_path = os.path.join(save_dir, tag, "optim_states.pt")
optim_path = os.path.join(save_dir, str(tag), "optim_states.pt")
torch.save(optimizer_sd, optim_path)

def load_checkpoint(
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_save_load(self):
engine = MockDeepSpeedEngine(model, optimizer)
checkpointer = DeepSpeedCheckpointer(engine, tmpdirname)
checkpointer.save_checkpoint(
tmpdirname, str(step), storage_type=StorageType.MEMORY
tmpdirname, step, storage_type=StorageType.MEMORY
)
shm_handler = checkpointer._async_save_engine._shm_handler
self.assertFalse(shm_handler.no_checkpint_state())
Expand All @@ -122,13 +122,13 @@ def test_save_load(self):
checkpointer._async_save_engine._shm_handler.metadata.get()
)
ds_ckpt_config = tensor_meta["_DLORVER_CKPT_CONFIG"]
self.assertEqual(ds_ckpt_config.step, str(step))
self.assertEqual(ds_ckpt_config.step, step)
self.assertIsNotNone(tensor_meta["model_states"])
tracer_file = os.path.join(tmpdirname, "latest")
self.assertFalse(os.path.exists(tracer_file))

checkpointer.save_checkpoint(
tmpdirname, str(step), storage_type=StorageType.DISK
tmpdirname, step, storage_type=StorageType.DISK
)
# Wait asynchronously saving.
start = time.time()
Expand Down

0 comments on commit 65550ba

Please sign in to comment.