From 09f096ddb19bf1bd70980429aa02efb6dd5af791 Mon Sep 17 00:00:00 2001
From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
Date: Wed, 17 Jan 2024 19:43:59 -0500
Subject: [PATCH] Fixes for training models with bf16 + freshly initialized
 optimizer via `load_module_only` (#4141)

This PR makes some fixes to the case where we want to resume training from a
DeepSpeed ZeRO checkpoint and initialize a new optimizer, while not using the
old optimizer in the checkpoint or relying on its existence at all.

In this situation, despite passing `load_module_only=True` and
`load_optimizer_states=False` to `load_checkpoint()`, the previous behavior
was that:

- `self._load_zero_checkpoint` would still be called, which attempts to load
from the (in this case, nonexistent) checkpoint files. This PR stops this
function from being called when `load_module_only=True` and
`load_optimizer_states=False`. Alternatively, calling this function may be
alright if `"load_from_fp32_weights": true` is set in the DeepSpeed ZeRO
config (reference:
https://github.com/microsoft/DeepSpeed/blob/ff7d5275f2aa916cb5f320e0d817154e96f9cdb6/deepspeed/runtime/engine.py#L733),
but this parameter does not seem to be documented in the docs for ZeRO config
dicts.
- In `_load_checkpoint`, the following code block:
```
if self.optimizer is not None and self.fp16_enabled():
    self.optimizer.refresh_fp32_params()
```
results in `self.optimizer.refresh_fp32_params()` being called only when FP16
is enabled. As a result, the FP32 optimizer state is never initialized from
the 16-bit model weights. This PR removes the fp16-specific condition.

Previously reported in:
https://github.com/EleutherAI/gpt-neox/issues/947
https://github.com/EleutherAI/gpt-neox/issues/843

Should also close:
https://github.com/microsoft/DeepSpeed/issues/4017

Fixes: #4944 and #4017

This caused problems for a freshly converted LLaMA checkpoint, which did not
contain optimizer states, when trying to train with this model as
initialization. I have confirmed that the following fixes prevent this
behavior.

cc @Quentin-Anthony @zhangir-azerbayev

---------

Co-authored-by: Olatunji Ruwase
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
---
 deepspeed/runtime/engine.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index c4fb0778b1c8..f511d207fbfe 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -2791,10 +2791,12 @@ def load_checkpoint(self,
                                                           load_module_only=load_module_only,
                                                           custom_load_fn=custom_load_fn)
 
-        load_zero_checkpoint = load_optimizer_states and load_path is not None and (self.zero_optimization()
-                                                                                    or self.bfloat16_enabled())
+        load_zero_checkpoint = load_path is not None and (self.zero_optimization() or self.bfloat16_enabled())
         if load_zero_checkpoint:
-            success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
+            if load_optimizer_states and not load_module_only:
+                success = self._load_zero_checkpoint(load_dir, tag, load_optimizer_states=load_optimizer_states)
+            else:
+                success = False
             if not success:
                 self.optimizer._restore_from_bit16_weights()
 
@@ -2876,7 +2878,7 @@ def _load_checkpoint(self,
         optim_checkpoint = None
         if load_module_only:
             deepspeed_states = ['module']
-            if self.optimizer is not None and self.fp16_enabled():
+            if self.optimizer is not None:
                 self.optimizer.refresh_fp32_params()
         else:
             has_zero_optimizer_state = self.zero_optimization() or self.bfloat16_enabled()
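
For reference, a minimal sketch (not part of the patch) of the weights-only resume flow this fix targets. The toy model, config values, and checkpoint path below are illustrative placeholders, not taken from the PR; only the `load_checkpoint()` call pattern matters.

```python
import torch
import deepspeed

# Illustrative toy model and bf16 + ZeRO config (placeholder values).
model = torch.nn.Linear(8, 8)
ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 1},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)

# The checkpoint contains module weights but no optimizer states; load only the
# module and let the freshly created optimizer start from the loaded weights.
load_path, _ = engine.load_checkpoint("/path/to/checkpoint_dir",
                                      load_module_only=True,
                                      load_optimizer_states=False,
                                      load_lr_scheduler_states=False)
```

With this patch, the call above no longer tries to read nonexistent ZeRO optimizer state files, and `refresh_fp32_params()` runs even when FP16 is not enabled, so the new optimizer's FP32 state is initialized from the loaded 16-bit weights.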