tuple index out of range in _exec_send_grads p2p.send #884

Closed
drcege opened this issue Apr 14, 2023 · 7 comments
Labels
bug Something isn't working

Comments

@drcege

drcege commented Apr 14, 2023

Describe the bug

When setting both pipe-parallel-size and model-parallel-size to 2, training crashes. However, setting either one to 2 individually (keeping the other at 1) works fine.

The stack trace:

...
done with setups ...                                                                                                                                                                 
time (ms) | model and optimizer: 5943.63 | train/valid/test data iterators: 1088.62                                                                                                  
training ...                                                                                                                                                                         
[2023-04-14 09:05:18,824] [INFO] [checkpointing.py:553:forward] Activation Checkpointing Information                                                                                 
[2023-04-14 09:05:18,824] [INFO] [checkpointing.py:554:forward] ----Partition Activations True, CPU CHECKPOINTING False                                                               
[2023-04-14 09:05:18,824] [INFO] [checkpointing.py:557:forward] ----contiguous Memory Checkpointing False with 24 total layers                                                        
[2023-04-14 09:05:18,824] [INFO] [checkpointing.py:560:forward] ----Synchronization True                                                                                              
[2023-04-14 09:05:18,824] [INFO] [checkpointing.py:561:forward] ----Profiling time in checkpointing False 
Traceback (most recent call last):
  File "train.py", line 27, in <module>
    pretrain(neox_args=neox_args)
  File "/gpt-neox/megatron/training.py", line 226, in pretrain
    iteration = train(
  File "/gpt-neox/megatron/training.py", line 778, in train
    loss_dict, skipped_iter = train_step(
  File "/gpt-neox/megatron/training.py", line 684, in train_step
    reduced_loss = train_step_pipe(
  File "/gpt-neox/megatron/training.py", line 734, in train_step_pipe
    loss = model.train_batch(data_iter=data_iterator)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/engine.py", line 346, in train_batch
    self._exec_schedule(sched)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/engine.py", line 1378, in _exec_schedule
    self._exec_instr(**cmd.kwargs)
  File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/engine.py", line 1025, in _exec_send_grads
    p2p.send(inputs[1], self.prev_stage)
IndexError: tuple index out of range
...

To Reproduce

Steps to reproduce the behavior:

  1. Use the latest Docker image leogao2/gpt-neox:sha-61b5eee
  2. Start two containers, each with four visible GPUs, to simulate two-node distributed training
  3. Set up passwordless SSH login between the containers
  4. Write /job/hostfile as follows:
container1-ip slots=4
container2-ip slots=4
  5. Execute python ./deepy.py train.py -d configs 1-3B.yml local_setup.yml after modifying data-path, vocab-file, and merge-file

Expected behavior

Should train without error.

Proposed solution

After debugging, I believe the error was triggered here:
https://github.com/EleutherAI/DeeperSpeed/blob/457850dc5ad72960f0e8a8f1597914d682a7792c/deepspeed/runtime/pipe/engine.py#L1023-L1025

It seems that the length of inputs is less than 2, so the indexing is out of range. Does this mean the grad is not properly partitioned when model-parallel-size > 1?

I know this code lives inside DeepSpeed, but these lines were written several years ago and are exercised by many projects, which suggests the error may instead be caused by gpt-neox passing incorrect inputs from the outside.
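
One way this could happen (a standalone sketch only, using placeholder tensors rather than NeoX or DeepSpeed code; the full _exec_send_grads is quoted later in this thread): if every trailing input, for example a boolean attention mask, has no grad, then inputs_grad_tail is empty, the rebuilt tuple holds only the two partition tensors, and the later "drop the attention mask" pop removes part.data() instead of the mask, so inputs[1] no longer exists when p2p.send runs.

    # Standalone sketch of the suspected failure mode; placeholder tensors only.
    import torch

    part_meta = torch.zeros(4)                            # stands in for part.to_meta()
    part_data = torch.zeros(8)                            # stands in for part.data()
    attention_mask = torch.ones(2, 2, dtype=torch.bool)   # trailing input, .grad is None

    # 0.8.3-style tail: only tensors that actually have a grad are kept.
    inputs_grad_tail = [t.grad for t in (attention_mask,) if t.grad is not None]
    inputs = (part_meta, part_data, *inputs_grad_tail)    # length 2, mask never appended

    # The "drop the attention mask" hack pops the last element, but here that
    # element is part_data, not the mask.
    inputs = tuple(list(inputs)[:-1])                     # length 1

    try:
        _ = inputs[1]                                     # mirrors p2p.send(inputs[1], ...)
    except IndexError as err:
        print(err)                                        # tuple index out of range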


Environment (please complete the following information):

  • GPUs: 8x V100
  • Configs: 1-3B.yml and local_setup.yml with the modifications above


drcege added the bug label on Apr 14, 2023
@Zlzzzupup

I also encountered the same problem. Have you solved it yet?

@Zlzzzupup

Zlzzzupup commented Apr 19, 2023

I replaced the _exec_send_grads function that raises the error in gpt-neox 2.0 (DeepSpeed 0.8.3) with the _exec_send_grads function from gpt-neox 1.0 (DeepSpeed 0.3.15). Now the program runs fine.

DeepSpeed 0.8.3:

    # def _exec_send_grads(self, buffer_id):
    #     if self.wall_clock_breakdown():
    #         self.timers('pipe_send_grad').start()

    #     inputs = self.pipe_buffers['inputs'][buffer_id]

    #     # Partition the gradient
    #     if self.is_grad_partitioned:
    #         if isinstance(inputs, tuple):
    #             first_input = inputs[0]
    #             assert all([torch.is_tensor(elt) for elt in inputs[1:]])
    #             inputs_grad_tail = [
    #                 elt.grad for elt in inputs[1:] if elt.grad is not None
    #             ]
    #         elif torch.is_tensor(inputs):
    #             first_input = inputs
    #             inputs_grad_tail = []
    #         else:
    #             raise ValueError("expecting a tensor or a tuple of tensors")
    #         assert torch.is_tensor(first_input)
    #         part = PartitionedTensor(tensor=first_input.grad,
    #                                  group=self.grid.get_slice_parallel_group())

    #         inputs = (part.to_meta(), part.data(), *inputs_grad_tail)

    #     # XXX Terrible hack
    #     # Drop the attention mask from the input buffer here. It does not have
    #     # a grad that needs to be communicated. We free the buffer immediately
    #     # after, so no need to restore it. The receiver also has a hack that skips
    #     # the recv. This is because NCCL does not let us send torch.BoolTensor :-(.
    #     if self.has_attention_mask or self.has_bool_tensors:
    #         inputs = list(inputs)
    #         inputs.pop()
    #         inputs = tuple(inputs)

    #     if isinstance(inputs, torch.Tensor):
    #         assert inputs.grad is not None
    #         p2p.send(inputs.grad, self.prev_stage)
    #     else:
    #         # XXX terrible hacky branch
    #         if self.is_grad_partitioned:
    #             # First two sends are partitioned gradient
    #             p2p.send(inputs[0], self.prev_stage)
    #             p2p.send(inputs[1], self.prev_stage)
    #         else:
    #             for idx, buffer in enumerate(inputs):
    #                 # Skip tensors that will not produce a grad
    #                 if not buffer.is_floating_point():
    #                     assert buffer.grad is None
    #                     continue
    #                 assert buffer.grad is not None
    #                 p2p.send(buffer.grad, self.prev_stage)

    #     # We can free up the input buffer now
    #     self.pipe_buffers['inputs'][buffer_id] = None

    #     if self.wall_clock_breakdown():
    #         self.timers('pipe_send_grad').stop()

DeepSpeed 0.3.15:

    def _exec_send_grads(self, buffer_id):
        if self.wall_clock_breakdown():
            self.timers('pipe_send_grad').start()

        inputs = self.pipe_buffers['inputs'][buffer_id]

        # Partition the gradient
        if self.is_grad_partitioned:
            part = PartitionedTensor(tensor=inputs[0].grad,
                                     group=self.grid.get_slice_parallel_group())
            # Clear the large output data, but save the computation graph
            # Inject the partitoned tensor into the output before sending

            # XXX Hack
            inputs = tuple([part.to_meta(), part.data(), inputs[1]])

        # XXX Terrible hack
        # Drop the attention mask from the input buffer here. It does not have
        # a grad that needs to be communicated. We free the buffer immediately
        # after, so no need to restore it. The receiver also has a hack that skips
        # the recv. This is because NCCL does not let us send torch.BoolTensor :-(.
        if self.module.__class__.__name__ == 'GPT2ModelPipe':
            inputs = list(inputs)
            inputs.pop()
            inputs = tuple(inputs)

        if isinstance(inputs, torch.Tensor):
            assert inputs.grad is not None
            p2p.send(inputs.grad, self.prev_stage)
        else:
            # XXX terrible hacky branch
            if self.is_grad_partitioned:
                # First two sends are partitioned gradient
                p2p.send(inputs[0], self.prev_stage)
                p2p.send(inputs[1], self.prev_stage)
                # XXX hack hack hack
                #p2p.send(inputs[2].grad, self.prev_stage)
            else:
                for idx, buffer in enumerate(inputs):
                    # Skip tensors that will not produce a grad
                    if not buffer.is_floating_point():
                        assert buffer.grad is None
                        continue
                    assert buffer.grad is not None
                    p2p.send(buffer.grad, self.prev_stage)

        # We can free up the input buffer now
        self.pipe_buffers['inputs'][buffer_id] = None

        if self.wall_clock_breakdown():
            self.timers('pipe_send_grad').stop()

@StellaAthena
Member

@Zlzzzupup thanks for looking into this! I've created a Code Diff to help understand what's going on.

I'm having trouble spotting any functional differences. In my understanding there are three blocks of differences:

  1. In the block under # Partition the gradient there is some additional error handling and the code has been generalized from an input length of 2 to arbitrary input lengths.
  2. Then part and inputs are defined. part is defined the same way, while inputs has been changed to remove the explicit tuple(...) cast and instead build a tuple literal that unpacks inputs_grad_tail.
  3. Instead of checking a class attribute, we directly check if the mask exists.

Can you try restoring the current version of the code, and then introduce each of these changes in isolation? I'm very curious to see which one(s) break the code.

@Zlzzzupup

Zlzzzupup commented Apr 20, 2023

@StellaAthena Thanks for your reply. I reproduced the three parts of the modified code and traced the problem to:

# inputs = (part.to_meta(), part.data(), *inputs_grad_tail)
inputs = tuple([part.to_meta(), part.data(), inputs[1]])

The other parts of the modification do not affect how the program runs.
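
For reference, a tiny standalone comparison (stand-in values only, not the actual pipeline tensors) of why that line matters: the 0.3.15 form always builds a 3-tuple, so popping the attention mask still leaves inputs[1], whereas the 0.8.3 form builds only (meta, data) whenever inputs_grad_tail is empty, so the same pop would leave a single element.

    # Illustrative only: stand-in values, not the actual pipeline tensors.
    meta, data, mask = "meta", "data", "mask"
    inputs_grad_tail = []                          # every trailing grad was None

    old_inputs = tuple([meta, data, mask])         # 0.3.15 style -> 3 elements
    new_inputs = (meta, data, *inputs_grad_tail)   # 0.8.3 style  -> 2 elements

    # After the attention-mask pop, only the old form still has an index 1.
    print(len(old_inputs[:-1]), len(new_inputs[:-1]))   # prints: 2 1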

@StellaAthena
Member

@Quentin-Anthony do you have any idea why this would cause an error? I can run some tests, but since inputs_grad_tail = [elt.grad for elt in inputs[1:] if elt.grad is not None], I'm not sure what could be causing an issue here.

@Quentin-Anthony
Member

I'm still discussing this with the DeepSpeed team. Please either apply microsoft/DeepSpeed#2538 or install from the latest DeeperSpeed, which already has this patch applied.

In general, please use the latest DeeperSpeed for running gpt-neox. We use it as a staging ground for fixes like these before they get merged into upstream DeepSpeed.

@StellaAthena
Member

@drcege I’ve updated requirements/requirements.txt to correctly install the DeeperSpeed fork @Quentin-Anthony is talking about. Installing from source should solve your problem.
