Commit 457850d

[zero] prevent poor configs from running w. zero-offload (microsoft#2971)

jeffra committed Mar 9, 2023
1 parent 58a4a4d commit 457850d
Showing 5 changed files with 60 additions and 1 deletion.
8 changes: 8 additions & 0 deletions deepspeed/runtime/config.py
@@ -502,6 +502,12 @@ def get_zero_allow_untested_optimizer(param_dict):
                            ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT)


def get_zero_force_ds_cpu_optimizer(param_dict):
    return get_scalar_param(param_dict,
                            ZERO_FORCE_DS_CPU_OPTIMIZER,
                            ZERO_FORCE_DS_CPU_OPTIMIZER_DEFAULT)


def get_scheduler_name(param_dict):
    if SCHEDULER in param_dict.keys() and TYPE in param_dict[SCHEDULER].keys():
        return param_dict[SCHEDULER][TYPE]
@@ -859,6 +865,8 @@ def _initialize_params(self, param_dict):
        self.zero_allow_untested_optimizer = get_zero_allow_untested_optimizer(
            param_dict)

        self.zero_force_ds_cpu_optimizer = get_zero_force_ds_cpu_optimizer(param_dict)

        self.scheduler_name = get_scheduler_name(param_dict)
        self.scheduler_params = get_scheduler_params(param_dict)
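Aside: the new getter delegates to the existing get_scalar_param helper, which in effect is a dict lookup with a fallback. A simplified sketch of that behavior (not the actual implementation):

def get_scalar_param(param_dict, param_name, param_default):
    # Return the user-supplied value when the key is present,
    # otherwise fall back to the default.
    return param_dict.get(param_name, param_default)

With ZERO_FORCE_DS_CPU_OPTIMIZER_DEFAULT = True (see constants.py below), the check is on unless the user explicitly disables it.
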
2 changes: 2 additions & 0 deletions deepspeed/runtime/constants.py
@@ -73,6 +73,8 @@
#############################################
ZERO_ALLOW_UNTESTED_OPTIMIZER = "zero_allow_untested_optimizer"
ZERO_ALLOW_UNTESTED_OPTIMIZER_DEFAULT = False
ZERO_FORCE_DS_CPU_OPTIMIZER = "zero_force_ds_cpu_optimizer"
ZERO_FORCE_DS_CPU_OPTIMIZER_DEFAULT = True

# Steps
STEPS_PER_PRINT = "steps_per_print"
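
The new ZERO_FORCE_DS_CPU_OPTIMIZER constant maps to a top-level key in the user's DeepSpeed config, alongside zero_optimization. A minimal sketch of a config that opts out of the check, with the surrounding values illustrative only:

ds_config = {
    "train_batch_size": 4,
    "zero_optimization": {
        "stage": 1,
        "offload_optimizer": {
            "device": "cpu"
        }
    },
    # Disable the DeepSpeedCPUAdam requirement; the default is True.
    "zero_force_ds_cpu_optimizer": False,
}
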
10 changes: 10 additions & 0 deletions deepspeed/runtime/engine.py
@@ -719,6 +719,9 @@ def zero_optimization(self):
    def zero_allow_untested_optimizer(self):
        return self._config.zero_allow_untested_optimizer

    def zero_force_ds_cpu_optimizer(self):
        return self._config.zero_force_ds_cpu_optimizer

    def zero_reduce_scatter(self):
        return self._config.zero_config.reduce_scatter
@@ -1265,6 +1268,13 @@ def _configure_optimizer(self, client_optimizer, model_parameters):
            else:
                basic_optimizer = client_optimizer(model_parameters)
                log_dist('Using client callable to create basic optimizer', ranks=[0])

            if self.zero_use_cpu_optimizer() and not isinstance(
                    basic_optimizer,
                    deepspeed.ops.adam.DeepSpeedCPUAdam):
                if self.zero_force_ds_cpu_optimizer():
                    msg = f'You are using ZeRO-Offload with a client provided optimizer ({type(basic_optimizer)}) which in most cases will yield poor performance. Please either use deepspeed.ops.adam.DeepSpeedCPUAdam or set an optimizer in your ds-config (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters). If you really want to use a custom optimizer w. ZeRO-Offload and understand the performance impacts you can also set <"zero_force_ds_cpu_optimizer": false> in your configuration file.'
                    raise ZeRORuntimeException(msg)
        else:
            basic_optimizer = self._configure_basic_optimizer(model_parameters)
            log_dist(
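
The guard fires when the ZeRO CPU-optimizer path is active and the client-supplied optimizer is not a DeepSpeedCPUAdam instance. The remedy named in the error message is to construct the optimizer from deepspeed.ops.adam; a minimal sketch, assuming model and ds_config are defined elsewhere:

import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam

# DeepSpeedCPUAdam satisfies the isinstance check above, so the
# ZeRORuntimeException is not raised even with CPU offload enabled.
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3)
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               optimizer=optimizer,
                                               config=ds_config)

Alternatively, per the message, setting an optimizer block in the ds-config lets DeepSpeed choose the appropriate CPU implementation itself.
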
3 changes: 2 additions & 1 deletion tests/unit/runtime/half_precision/test_fp16.py
@@ -466,7 +466,8 @@ def test(self, zero_stage, use_cpu_offload):
                "stage": zero_stage,
                "cpu_offload": use_cpu_offload
            },
            "zero_allow_untested_optimizer": False,
            "zero_force_ds_cpu_optimizer": False
        }
        hidden_dim = 10

38 changes: 38 additions & 0 deletions tests/unit/runtime/zero/test_zero.py
@@ -18,6 +18,7 @@
from deepspeed.runtime.engine import DeepSpeedEngine
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
from deepspeed.runtime.zero.utils import ZeRORuntimeException
from deepspeed.accelerator import get_accelerator


@@ -1384,3 +1385,40 @@ def forward(self, x, y):
            loss = loss[1]
            model.backward(loss)
            model.step()


@pytest.mark.parametrize('force_ds_optim', [True, False])
class TestZeroOffloadOptim(DistributedTest):
    world_size = 1

    def test(self, force_ds_optim):
        config_dict = {
            "train_batch_size": 4,
            "gradient_accumulation_steps": 2,
            "steps_per_print": 1,
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": 1,
                "offload_optimizer": {
                    "device": "cpu"
                }
            },
            "zero_force_ds_cpu_optimizer": force_ds_optim,
        }
        hidden_dim = 10

        model = SimpleModel(hidden_dim)

        optimizer = torch.optim.Adam(model.parameters())

        if force_ds_optim:
            with pytest.raises(ZeRORuntimeException):
                model, _, _, _ = deepspeed.initialize(model=model,
                                                      optimizer=optimizer,
                                                      config=config_dict)
        else:
            model, _, _, _ = deepspeed.initialize(model=model,
                                                  optimizer=optimizer,
                                                  config=config_dict)
