diff --git a/dlrover/trainer/tests/torch/deepspeed_ckpt_test.py b/dlrover/trainer/tests/torch/deepspeed_ckpt_test.py index 55c13acef..3168c0884 100644 --- a/dlrover/trainer/tests/torch/deepspeed_ckpt_test.py +++ b/dlrover/trainer/tests/torch/deepspeed_ckpt_test.py @@ -48,10 +48,10 @@ def zero_optimization_stage(self): def save_checkpoint(self, save_dir, tag, client_state, save_latest): model_sd = self.model.state_dict() - model_path = os.path.join(save_dir, tag, "model_states.pt") + model_path = os.path.join(save_dir, str(tag), "model_states.pt") torch.save(model_sd, model_path) optimizer_sd = self.optimizer.state_dict() - optim_path = os.path.join(save_dir, tag, "optim_states.pt") + optim_path = os.path.join(save_dir, str(tag), "optim_states.pt") torch.save(optimizer_sd, optim_path) def load_checkpoint( @@ -111,7 +111,7 @@ def test_save_load(self): engine = MockDeepSpeedEngine(model, optimizer) checkpointer = DeepSpeedCheckpointer(engine, tmpdirname) checkpointer.save_checkpoint( - tmpdirname, str(step), storage_type=StorageType.MEMORY + tmpdirname, step, storage_type=StorageType.MEMORY ) shm_handler = checkpointer._async_save_engine._shm_handler self.assertFalse(shm_handler.no_checkpint_state()) @@ -122,13 +122,13 @@ def test_save_load(self): checkpointer._async_save_engine._shm_handler.metadata.get() ) ds_ckpt_config = tensor_meta["_DLORVER_CKPT_CONFIG"] - self.assertEqual(ds_ckpt_config.step, str(step)) + self.assertEqual(ds_ckpt_config.step, step) self.assertIsNotNone(tensor_meta["model_states"]) tracer_file = os.path.join(tmpdirname, "latest") self.assertFalse(os.path.exists(tracer_file)) checkpointer.save_checkpoint( - tmpdirname, str(step), storage_type=StorageType.DISK + tmpdirname, step, storage_type=StorageType.DISK ) # Wait asynchronously saving. start = time.time()