
Fully-Sharded Data Parallel Transformer-XL

This repository extends the implementation of Transformer-XL to support fully-sharded data parallel (FSDP) training.

The original implementation can be found here:

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context.

Our work builds on the PyTorch implementation.

PyTorch

  • The source code is in the pytorch/ folder, supporting single-node multi-GPU training via torch.nn.DataParallel (the original implementation) and torch.distributed.fsdp.FullyShardedDataParallel; a minimal sketch of both wrapping paths follows this list.
  • Please refer to pytorch/README.md for details.
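The snippet below is a minimal sketch of the two wrapping paths rather than the repository's actual training code; the `wrap_model` helper and its `use_fsdp` flag are illustrative names, not ones taken from the source.

```python
# Minimal sketch (not the repository's training code) of the two wrapping paths.
import torch
import torch.distributed as dist
from torch.nn.parallel import DataParallel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_model(model: torch.nn.Module, use_fsdp: bool) -> torch.nn.Module:
    if use_fsdp:
        # FSDP runs one process per GPU and needs an initialized process group.
        if not dist.is_initialized():
            dist.init_process_group(backend="nccl")
        return FSDP(model.cuda(), device_id=torch.cuda.current_device())
    # Original path: a single process replicates the model across visible GPUs.
    return DataParallel(model.cuda())
```

DataParallel keeps a full copy of the model on every GPU inside one process, whereas FSDP runs one process per GPU and can shard parameters, gradients, and optimizer state across them.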

Results

We use 40GB NVIDIA A100 SXM GPUs and test with 1 device and 4 devices on a single node. The performance improvements are reported below.

Single Device FSDP

(Figure: single-device FSDP memory and batch-size results)

Key Findings:

  • FSDP with “no_shard” (DDP), “grad_op” (ZeRO Stage 2), and “full_shard” (ZeRO Stage 3) has an identical memory footprint in all three modes, because there is nothing to shard across on a single device. A sketch mapping these flag names onto PyTorch FSDP options follows this list.
  • “chkpt” (activation checkpointing) yields the largest memory savings, allowing up to a 6x increase in batch size.
  • “wrap” (wrapping the decoder layers) and “fp16” together give only a negligible boost to the maximum batch size.
  • Overall, the maximum batch size increases 9x over the baseline.

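The flag names used in these findings correspond to PyTorch FSDP constructor options. The sketch below shows one plausible mapping, assumed rather than copied from the training script; the `build_fsdp` helper and its `mode`/`fp16` arguments are illustrative.

```python
# Hypothetical mapping from the flag names used in the results ("no_shard",
# "grad_op", "full_shard", "fp16") onto PyTorch FSDP options.
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)

STRATEGIES = {
    "no_shard": ShardingStrategy.NO_SHARD,      # behaves like DDP
    "grad_op": ShardingStrategy.SHARD_GRAD_OP,  # ZeRO Stage 2: shard gradients + optimizer state
    "full_shard": ShardingStrategy.FULL_SHARD,  # ZeRO Stage 3: also shard parameters
}


def build_fsdp(model: torch.nn.Module, mode: str, fp16: bool) -> FSDP:
    # "fp16": keep parameters, gradient reductions, and buffers in half precision.
    mp = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ) if fp16 else None
    return FSDP(
        model,
        sharding_strategy=STRATEGIES[mode],
        mixed_precision=mp,
        device_id=torch.cuda.current_device(),
    )
```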
Multi-device (4) FSDP

(Figure: multi-device (4x) FSDP memory and batch-size results)

Key Findings:

  • FSDP with “grad_op” (ZeRO Stage 2) and “full_shard” (ZeRO Stage 3) has an identical memory footprint in both modes, most likely because the model is modest in size (277M parameters), so additionally sharding the parameters yields little extra memory saving.
  • “chkpt” (activation checkpointing) again yields the largest memory savings, allowing up to a 4x increase in batch size.
  • “wrap” (wrapping the decoder layers) gives no boost when activation checkpointing is not used; see the sketch after this list for how the two are combined.
  • Overall, the maximum batch size increases 10x over the baseline.

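The sketch below shows one way “wrap” and “chkpt” can be combined using PyTorch's FSDP utilities. `RelPartialLearnableDecoderLayer` is assumed to be the decoder-layer class from the original Transformer-XL PyTorch code, and the helper is illustrative rather than the repository's actual implementation.

```python
# Sketch of per-decoder-layer FSDP wrapping ("wrap") plus activation
# checkpointing ("chkpt"). Assumes the process group is already initialized
# (e.g. by torchrun) and that RelPartialLearnableDecoderLayer is importable
# from the Transformer-XL model code (assumed module/class names).
import functools
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from mem_transformer import RelPartialLearnableDecoderLayer  # assumed import path


def wrap_with_fsdp_and_checkpointing(model: torch.nn.Module) -> FSDP:
    # "wrap": make each decoder layer its own FSDP unit so parameters are
    # gathered and released layer by layer instead of all at once.
    auto_wrap = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={RelPartialLearnableDecoderLayer},
    )
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap,
        device_id=torch.cuda.current_device(),
    )
    # "chkpt": recompute each decoder layer's activations in the backward pass
    # instead of storing them, trading compute for memory.
    wrapper = functools.partial(
        checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    )
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=wrapper,
        check_fn=lambda m: isinstance(m, RelPartialLearnableDecoderLayer),
    )
    return model
```

With one FSDP process per GPU, a launcher such as `torchrun --nproc_per_node=4` would start the four workers; the actual entry point and flags are documented in pytorch/README.md.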