
Pipeline parallelism and gradient checkpointing (edit: and ZeRO 2!) don’t work together #62

Closed
StellaAthena opened this issue Jan 14, 2021 · 12 comments · Fixed by #77, #59 or #90
Labels: bug (Something isn't working)

@StellaAthena
Member

Pipeline parallelism and gradient checkpointing both work when you use them individually. However, when you turn them both on, you get a mysterious KeyError: 0 from somewhere deep in DeepSpeed.

@StellaAthena StellaAthena added the bug Something isn't working label Jan 14, 2021
@StellaAthena StellaAthena added this to To do in 1T or BUST via automation Jan 14, 2021
@StellaAthena StellaAthena moved this from To do to In progress in 1T or BUST Jan 14, 2021
This was referenced Jan 14, 2021
@StellaAthena StellaAthena linked a pull request Jan 14, 2021 that will close this issue
@StellaAthena StellaAthena removed a link to a pull request Jan 14, 2021
@StellaAthena StellaAthena removed this from In progress in 1T or BUST Jan 14, 2021
@StellaAthena
Member Author

Update: we can now use any two of the following three options: ZeRO Stage 2, pipeline parallelism, and activation checkpointing. If all three are enabled, training errors out, emitting the following warning:

/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/pipe/engine.py:993: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.

You can replicate this on the Stella branch: running sh scripts/train_enwik8_pipeline.sh has all three options enabled and errors out.

To turn off activation checkpointing, set "number_checkpoints": null in configs/deepspeed_zero2.json.

To turn off pipelining, run sh scripts/train_enwik8.sh.

To turn off ZeRO Stage 2, use configs/deepspeed_zero1.json as your config file.
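
For reference, here is a rough sketch of what those knobs look like in DeepSpeed config terms, written as a Python dict that could be passed to deepspeed.initialize(config_params=...). This is only an illustration: the values are placeholders and the repo's actual configs/deepspeed_zero2.json may nest or name keys differently.

# Hedged sketch only -- not the repo's actual config.
ds_config = {
    "train_batch_size": 256,            # placeholder value
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,                     # stage 1 corresponds to the deepspeed_zero1.json setup
        "contiguous_gradients": False,  # this flag gets toggled in the tests below
    },
    "activation_checkpointing": {
        "number_checkpoints": None,     # null/None turns activation checkpointing off
    },
}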

@StellaAthena StellaAthena changed the title Pipeline parallelism and gradient checkpointing don’t work together Pipeline parallelism and gradient checkpointing (edit: and ZeRO 2!) don’t work together Jan 15, 2021
@leogao2
Contributor

leogao2 commented Jan 16, 2021

I ran all pairs of the three options; the results are as follows.

Zero2+pipeline: does not work (contiguous gradients both on and off)
Checkpoint+pipeline: does work (contiguous gradients both on and off)
Zero2+checkpoint: does work (contiguous gradients on; didn't test off)

@leogao2
Contributor

leogao2 commented Jan 16, 2021

All of the errors and warnings that occur for zero2+pipeline:

Traceback (most recent call last):
  File "train_enwik8_pipeline.py", line 109, in <module>
    loss = model_engine.train_batch()
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/pipe/engine.py", line 273, in train_batch
    self._exec_schedule(sched)
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/pipe/engine.py", line 1162, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/pipe/engine.py", line 952, in _exec_optimizer_step
    self._take_model_step(lr_kwargs)
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 916, in _take_model_step
    self.optimizer.step()
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/zero/stage2.py", line 1341, in step
    self.check_overflow()
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/zero/stage2.py", line 1612, in check_overflow
    self._check_overflow(partition_gradients)
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/zero/stage2.py", line 1516, in _check_overflow
    self.overflow = self.has_overflow(partition_gradients)
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/zero/stage2.py", line 1535, in has_overflow
    overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial(
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/zero/stage2.py", line 1528, in has_overflow_partitioned_grads_serial
    for j, grad in enumerate(self.averaged_gradients[i]):
KeyError: 0

(this one shows up 4 times)

Traceback (most recent call last):
  File "train_enwik8_pipeline.py", line 109, in <module>
    self._exec_schedule(sched)
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/pipe/engine.py", line 1162, in _exec_schedule
    loss = model_engine.train_batch()
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/pipe/engine.py", line 273, in train_batch
    self._exec_schedule(sched)
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/pipe/engine.py", line 1162, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/pipe/engine.py", line 602, in _exec_backward_pass
    torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, ))
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/torch/autograd/__init__.py", line 132, in backward
    self._exec_instr(**cmd.kwargs)
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/pipe/engine.py", line 602, in _exec_backward_pass
    allow_unreachable=True)  # allow_unreachable flag
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/zero/stage2.py", line 594, in reduce_partition_and_remove_grads
    torch.autograd.backward(tensors=(outputs, ), grad_tensors=(grad_tensors, ))
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/torch/autograd/__init__.py", line 132, in backward
    self.reduce_ready_partitions_and_remove_grads(param, i)
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/zero/stage2.py", line 984, in reduce_ready_partitions_and_remove_grads
    allow_unreachable=True)  # allow_unreachable flag
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/zero/stage2.py", line 594, in reduce_partition_and_remove_grads
    self.reduce_independent_p_g_buckets_and_remove_grads(param, i)
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/zero/stage2.py", line 633, in reduce_independent_p_g_buckets_and_remove_grads
    self.reduce_ready_partitions_and_remove_grads(param, i)
  File "/root/anaconda3/envs/ds/lib/python3.7/site-packages/deepspeed/runtime/zero/stage2.py", line 984, in reduce_ready_partitions_and_remove_grads
    new_grad_tensor = self.ipg_buffer[self.ipg_index].narrow(
AttributeError: 'FP16_DeepSpeedZeroOptimizer' object has no attribute 'ipg_index'

(this one pops up 4 times and I think 2 of them got mangled together here)

The warning Stella mentioned also shows up.

@leogao2 leogao2 self-assigned this Jan 16, 2021
@leogao2
Contributor

leogao2 commented Jan 16, 2021

With contiguous gradients off, the FP16_DeepSpeedZeroOptimizer error no longer happens and I get 8 KeyErrors.

@leogao2
Contributor

leogao2 commented Jan 16, 2021

Checkpoint+pipeline works with contiguous gradients both on and off. Therefore, I don't think it's a major factor in zero2 breaking, but I'll keep it off for the remainder of my tests.

@leogao2
Contributor

leogao2 commented Jan 16, 2021

Focusing on the KeyError now.

The only place where self.averaged_gradients is written to within stage2.py is the independent_gradient_partition_epilogue function (https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage2.py#L485). So either that function isn't being called at all, or it is being called but L485 is never reached.
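
A toy sketch of the failure mode (not DeepSpeed's actual code): if the epilogue never fills averaged_gradients, the serial overflow check indexes an empty dict and raises exactly this KeyError.

# Toy reproduction of the symptom. averaged_gradients is a dict keyed by
# parameter-group index, and it is only populated by
# independent_gradient_partition_epilogue(); if that never runs, the
# overflow check below raises KeyError: 0, matching the traceback above.
averaged_gradients = {}   # epilogue skipped -> never populated

def has_overflow_partitioned_grads_serial(num_groups=1):
    for i in range(num_groups):
        for j, grad in enumerate(averaged_gradients[i]):   # KeyError: 0 here
            pass

has_overflow_partitioned_grads_serial()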

@leogao2
Contributor

leogao2 commented Jan 16, 2021

The pipeline code is problematic because it disables backward_allreduce:
https://github.com/microsoft/DeepSpeed/blob/81aeea361da3936b875a678b9cb44596800510b5/deepspeed/runtime/pipe/engine.py#L56
which means allreduce_gradients in the non-pipelined engine never runs:
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/engine.py#L914
which in turn means independent_gradient_partition_epilogue never gets called (see the previous comment).

The pipeline code pushes a ReduceTiedGrads and then a ReduceGrads here: https://github.com/microsoft/DeepSpeed/blob/81aeea361da3936b875a678b9cb44596800510b5/deepspeed/runtime/pipe/schedule.py#L235

Execution of that ReduceTiedGrads op:
https://github.com/microsoft/DeepSpeed/blob/81aeea361da3936b875a678b9cb44596800510b5/deepspeed/runtime/pipe/engine.py#L1139
https://github.com/microsoft/DeepSpeed/blob/81aeea361da3936b875a678b9cb44596800510b5/deepspeed/runtime/pipe/engine.py#L208
https://github.com/microsoft/DeepSpeed/blob/81aeea361da3936b875a678b9cb44596800510b5/deepspeed/runtime/pipe/module.py#L405

Execution of ReduceGrads:
https://github.com/microsoft/DeepSpeed/blob/81aeea361da3936b875a678b9cb44596800510b5/deepspeed/runtime/pipe/engine.py#L211
which calls buffered_allreduce_fallback, but only if data parallelism is enabled:
https://github.com/microsoft/DeepSpeed/blob/865104be85902ca398038045ad9cf94ec7d48745/deepspeed/runtime/engine.py#L1156
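
Putting this together, the shape of a fix is to make the pipeline engine's ReduceGrads step route through the ZeRO-aware reduction so the epilogue actually runs. A rough, hypothetical sketch of what the body of PipelineEngine._exec_reduce_grads could look like follows; the actual patch in the linked PRs may differ, and the helper names here are my best guess at the existing engine methods.

# Hypothetical sketch only -- not the actual patch. The point is that when
# ZeRO stage 2 partitions gradients, reduction must go through
# allreduce_gradients(), whose ZeRO path ends by calling
# independent_gradient_partition_epilogue() and populating averaged_gradients
# before the optimizer's overflow check.
def _exec_reduce_grads(self):
    if self.zero_optimization_partition_gradients():
        # ZeRO stage >= 2: use the ZeRO reduction path so the epilogue runs
        self.allreduce_gradients()
    elif self.is_data_parallel:
        # plain data parallelism keeps the existing buffered allreduce
        self.buffered_allreduce_fallback()
    # with neither, there is nothing to reduce on this rank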

@leogao2
Contributor

leogao2 commented Jan 16, 2021

With the patch applied,

Zero2+pipeline now works
Checkpoint+pipeline now works
Zero2+checkpoint now works

Zero2+checkpoint+pipeline now works

@leogao2
Contributor

leogao2 commented Jan 16, 2021

Profiling results:

patched, zero2+checkpoint+pipeline: 1159.741 samples/sec, max VRAM 3245 MiB
patched, zero2+checkpoint: 1120.857 samples/sec, max VRAM 1704 MiB

@StellaAthena
Member Author

With DeepSpeed's updates this seems to run just fine. The question of whether it runs efficiently is still open, though.

microsoft/DeepSpeed#677

@StellaAthena
Member Author

Turns out we weren’t using gradient checkpointing at all! You can add checkpointing to the params without initializing the checkpointer and you can initialize the checkpointer without actually using it! #90 should actually implement gradient checkpointing.
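
For context, a sketch of the distinction being described, using DeepSpeed's activation-checkpointing API. This is illustrative only, not the repo's code and not what #90 does; the layer wiring is hypothetical.

# Illustrative only. Configuring the checkpointer does nothing by itself;
# activations are only recomputed if the forward pass actually routes layers
# through the checkpoint() call.
import deepspeed

# Step 1: initialize the checkpointer (easy to do and then never actually use).
deepspeed.checkpointing.configure(None, deepspeed_config="configs/deepspeed_zero2.json")

# Step 2: the part that was missing -- wrap layer calls so their activations
# are recomputed during backward instead of being stored.
def checkpointed_forward(layer, x):
    return deepspeed.checkpointing.checkpoint(layer, x)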
