
pytorch gradient checkpointing is much better than deepspeed ! #63

Closed
agemagician opened this issue Feb 11, 2020 · 4 comments

agemagician commented Feb 11, 2020

Hello,

I have a script that trains a 12-layer transformer model (about 85 million parameters) using gradient checkpointing. It worked with a local batch size of 32 per Nvidia Titan GPU.
I tried to use DeepSpeed instead and I always get OOM, even with a batch size of 8.

minimal code:
Initialization:

# build the transformer and hand only the trainable parameters to DeepSpeed
model = models.TransformerModel(ntokens, args.emsize, args.nhead, args.nhid, args.nlayers, args.dropout)

parameters = filter(lambda p: p.requires_grad, model.parameters())

model_engine, optimizer, _, _ = deepspeed.initialize(args=args, model=model, model_parameters=parameters)

Training:

with tqdm(total=int(args.log_interval),
          desc='Train Step     #{}-{}'.format(step + 1, step + args.log_interval),
          disable=False) as t:
    for batch_idx, batch in enumerate(datasetGenerator):
        # move the batch to this rank's device
        data = batch['input'].to(model_engine.local_rank)
        target = batch['target'].to(model_engine.local_rank)
        src_padding = batch['padding_mask'].to(model_engine.local_rank)

        output = model_engine(data, has_mask=False, src_key_padding_mask=src_padding.t())

        # train_loss / train_accuracy / criterion are metric helpers defined elsewhere
        train_accuracy.update(accuracy(target, output))
        loss = criterion(output.view(-1, ntokens), target.view(-1))
        train_loss.update(loss)

        # DeepSpeed handles loss scaling and gradient accumulation boundaries internally
        model_engine.backward(loss)
        model_engine.step()

        t.set_postfix({'loss': train_loss.avg.item(),
                       'accuracy': 100. * train_accuracy.avg.item()})
        t.update(1)

Original Transformer code with gradient checkpointing:

def forward(self, src, mask=None, src_key_padding_mask=None):
        r"""Pass the input through the encoder layers in turn.
        Args:
            src: the sequence to the encoder (required).
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        Shape:
            see the docs in Transformer class.
        """
        output = src

        for i in range(self.num_layers):
            #output = self.layers[i](output, src_mask=mask,
            #                        src_key_padding_mask=src_key_padding_mask)
            # checkpoint is torch.utils.checkpoint.checkpoint: each layer's
            # activations are recomputed in the backward pass instead of stored
            output = checkpoint(self.layers[i], output, mask, src_key_padding_mask)

        if self.norm:
            output = self.norm(output)

        return output

With DeepSpeed, the largest batch size that works for the data loader is only 4.
Any idea how I can reach the same batch size with DeepSpeed as with gradient checkpointing?

agemagician changed the title from "pytorch gradient checkpointing working but not deepspeed" to "pytorch gradient checkpointing is much better than deepspeed !" on Feb 11, 2020

samyam (Contributor) commented Feb 11, 2020

You can train with a larger batch size in two ways:

  1. Use gradient accumulation. If you want to train with a batch size of 32 but can only fit a batch size of 4, you can use a train_micro_batch_size_per_gpu of 4 and gradient_accumulation_steps of 8 (see the config sketch at the end of this comment). Was the original code using gradient accumulation? Are the batch sizes consistent between the DeepSpeed config file and the data loader you are using?

  2. Use activation checkpointing (also called re-materialization or re-computation). I assume this is what you mean by gradient checkpointing. With it, you should be able to run a significantly larger batch size for the model you are describing.

What is the sequence length you are using, and what are your hidden dimensions?

Please take a look at lines 390-398 in the following file for an example of activation checkpointing.

https://github.com/microsoft/DeepSpeedExamples/blob/07d1ce0d26044602a7b8bb289e590a980f14aded/Megatron-LM/mpu/transformer.py
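For illustration, option 1 might look like the following minimal sketch (the file name ds_config.json and the exact values are assumptions for the example, not taken from this thread): a per-GPU micro batch of 4 accumulated over 8 steps gives an effective batch of 32 per GPU.

import json

# minimal DeepSpeed config combining a small micro batch with gradient
# accumulation; the file is passed to the launcher via --deepspeed_config
ds_config = {
    "train_micro_batch_size_per_gpu": 4,  # what fits in GPU memory per forward/backward
    "gradient_accumulation_steps": 8,     # 4 * 8 = effective batch of 32 per GPU
}

with open("ds_config.json", "w") as f:
    json.dump(ds_config, f, indent=2)

If train_batch_size is also set in the config, it must equal the per-GPU micro batch times the accumulation steps times the number of GPUs, which is exactly the kind of consistency asked about above.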



agemagician commented Feb 11, 2020

Thanks @samyam for the feedback.

In the original code, I was using both gradient accumulation and activation checkpointing (re-materialization/re-computation).
The total batch size was 32 (per-GPU batch) * 8 (gradient accumulation steps) * 6 GPUs = 1,536 sequences per optimizer step.
Input length: 1024
Embedding dimension: 768
Hidden dimension: 3072
Vocab size: 22

Gradient accumulation is not the difference between my code and the DeepSpeed code.
The batch size is consistent between my custom data loader and the DeepSpeed config file.

The problem is the activation checkpointing (re-materialization/re-computation). I assumed that if I disabled it and enabled DeepSpeed, I would get the same results. Apparently, that is not the case.

I will re-integrate gradient checkpointing and see whether there is a benefit to using DeepSpeed in this case or not.


samyam commented Feb 11, 2020

Yes, DeepSpeed does not do activation checkpointing automatically; you have to add it in the model. Please keep us posted on whether this solves your issue.
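For reference, DeepSpeed also exposes its own activation checkpointing API, deepspeed.checkpointing.checkpoint, which can be dropped in where torch.utils.checkpoint.checkpoint was used in the forward shown earlier in this issue. A minimal sketch, not code from this thread:

import deepspeed

def forward(self, src, mask=None, src_key_padding_mask=None):
    output = src
    for i in range(self.num_layers):
        # recompute each encoder layer's activations during the backward
        # pass instead of storing them during the forward pass
        output = deepspeed.checkpointing.checkpoint(
            self.layers[i], output, mask, src_key_padding_mask)
    if self.norm:
        output = self.norm(output)
    return output

Either variant has to be written into the model's forward; deepspeed.initialize does not insert checkpointing on its own, which is the point above.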

ShadenSmith (Contributor) commented:

Closing the issue. Please let us know if there are further questions or issues!
