diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py
index 9da1058f2c8e..abbabb594579 100755
--- a/deepspeed/runtime/config.py
+++ b/deepspeed/runtime/config.py
@@ -502,6 +502,12 @@ def get_zero_allow_untested_optimizer(param_dict):
                             ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT)
 
 
+def get_zero_force_ds_cpu_optimizer(param_dict):
+    return get_scalar_param(param_dict,
+                            ZERO_FORCE_DS_CPU_OPTIMIZER,
+                            ZERO_FORCE_DS_CPU_OPTIMIZER_DEFAULT)
+
+
 def get_scheduler_name(param_dict):
     if SCHEDULER in param_dict.keys() and TYPE in param_dict[SCHEDULER].keys():
         return param_dict[SCHEDULER][TYPE]
@@ -859,6 +865,8 @@ def _initialize_params(self, param_dict):
         self.zero_allow_untested_optimizer = get_zero_allow_untested_optimizer(
             param_dict)
 
+        self.zero_force_ds_cpu_optimizer = get_zero_force_ds_cpu_optimizer(param_dict)
+
         self.scheduler_name = get_scheduler_name(param_dict)
         self.scheduler_params = get_scheduler_params(param_dict)
 
diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py
index 6925745a8e5a..450847126f24 100755
--- a/deepspeed/runtime/constants.py
+++ b/deepspeed/runtime/constants.py
@@ -73,6 +73,8 @@
 #############################################
 ZERO_ALLOW_UNTESTED_OPTIMIZER = "zero_allow_untested_optimizer"
 ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT = False
+ZERO_FORCE_DS_CPU_OPTIMIZER = "zero_force_ds_cpu_optimizer"
+ZERO_FORCE_DS_CPU_OPTIMIZER_DEFAULT = True
 
 # Steps
 STEPS_PER_PRINT = "steps_per_print"
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 08cb1fd7276a..13c669b6de90 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -719,6 +719,9 @@ def zero_optimization(self):
     def zero_allow_untested_optimizer(self):
         return self._config.zero_allow_untested_optimizer
 
+    def zero_force_ds_cpu_optimizer(self):
+        return self._config.zero_force_ds_cpu_optimizer
+
     def zero_reduce_scatter(self):
         return self._config.zero_config.reduce_scatter
 
@@ -1265,6 +1268,13 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
             else:
                 basic_optimizer = client_optimizer(model_parameters)
                 log_dist('Using client callable to create basic optimizer', ranks=[0])
+
+            if self.zero_use_cpu_optimizer() and not isinstance(
+                    basic_optimizer,
+                    deepspeed.ops.adam.DeepSpeedCPUAdam):
+                if self.zero_force_ds_cpu_optimizer():
+                    msg = f'You are using ZeRO-Offload with a client provided optimizer ({type(basic_optimizer)}) which in most cases will yield poor performance. Please either use deepspeed.ops.adam.DeepSpeedCPUAdam or set an optimizer in your ds-config (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters). If you really want to use a custom optimizer with ZeRO-Offload and understand the performance impacts, you can also set <"zero_force_ds_cpu_optimizer": false> in your configuration file.'
+                    raise ZeRORuntimeException(msg)
         else:
             basic_optimizer = self._configure_basic_optimizer(model_parameters)
             log_dist(
diff --git a/tests/unit/runtime/half_precision/test_fp16.py b/tests/unit/runtime/half_precision/test_fp16.py
index c3c933fca144..b8b3e3d39db6 100644
--- a/tests/unit/runtime/half_precision/test_fp16.py
+++ b/tests/unit/runtime/half_precision/test_fp16.py
@@ -466,7 +466,8 @@ def test(self, zero_stage, use_cpu_offload):
                 "stage": zero_stage,
                 "cpu_offload": use_cpu_offload
             },
-            "zero_allow_untested_optimizer": False
+            "zero_allow_untested_optimizer": False,
+            "zero_force_ds_cpu_optimizer": False
         }
         hidden_dim = 10
 
diff --git a/tests/unit/runtime/zero/test_zero.py b/tests/unit/runtime/zero/test_zero.py
index 5de3ffca27df..f5d84060bd8e 100644
--- a/tests/unit/runtime/zero/test_zero.py
+++ b/tests/unit/runtime/zero/test_zero.py
@@ -18,6 +18,7 @@
 from deepspeed.runtime.engine import DeepSpeedEngine
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+from deepspeed.runtime.zero.utils import ZeRORuntimeException
 from deepspeed.accelerator import get_accelerator
 
 
@@ -1384,3 +1385,40 @@ def forward(self, x, y):
             loss = loss[1]
         model.backward(loss)
         model.step()
+
+
+@pytest.mark.parametrize('force_ds_optim', [True, False])
+class TestZeroOffloadOptim(DistributedTest):
+    world_size = 1
+
+    def test(self, force_ds_optim):
+        config_dict = {
+            "train_batch_size": 4,
+            "gradient_accumulation_steps": 2,
+            "steps_per_print": 1,
+            "fp16": {
+                "enabled": True
+            },
+            "zero_optimization": {
+                "stage": 1,
+                "offload_optimizer": {
+                    "device": "cpu"
+                }
+            },
+            "zero_force_ds_cpu_optimizer": force_ds_optim,
+        }
+        hidden_dim = 10
+
+        model = SimpleModel(hidden_dim)
+
+        optimizer = torch.optim.Adam(model.parameters())
+
+        if force_ds_optim:
+            with pytest.raises(ZeRORuntimeException):
+                model, _, _, _ = deepspeed.initialize(model=model,
+                                                      optimizer=optimizer,
+                                                      config=config_dict)
+        else:
+            model, _, _, _ = deepspeed.initialize(model=model,
+                                                  optimizer=optimizer,
+                                                  config=config_dict)
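
For context, here is a minimal usage sketch of the new flag, assuming this patch is applied. The model, learning rate, and batch size below are illustrative placeholders, not values from this PR. Passing `DeepSpeedCPUAdam` as the client optimizer satisfies the new check when offloading optimizer state to CPU; any other client optimizer would raise `ZeRORuntimeException` unless `"zero_force_ds_cpu_optimizer"` is set to `false`:

```python
import torch
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam

# Illustrative model; any torch.nn.Module works here.
model = torch.nn.Linear(10, 10)

ds_config = {
    "train_batch_size": 4,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 1,
        "offload_optimizer": {"device": "cpu"}
    },
    # Defaults to True: with CPU offload, a client optimizer other than
    # DeepSpeedCPUAdam raises ZeRORuntimeException. Set to False to opt
    # out, accepting the performance impact of a custom optimizer.
    "zero_force_ds_cpu_optimizer": True,
}

# DeepSpeedCPUAdam is the recommended optimizer for ZeRO-Offload and
# passes the new check.
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3)

engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               optimizer=optimizer,
                                               config=ds_config)
```

As usual, `deepspeed.initialize` still needs a distributed launcher/environment to run; the sketch only shows how the flag interacts with the choice of client optimizer.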