DropBP: Accelerating Fine-Tuning of Large Language Models

This is the official repository for DropBP: Accelerating Fine-Tuning of Large Language Models.

Abstract

(Figure: DropBP overview)

Large language models (LLMs) have achieved significant success across various domains. However, training these LLMs typically involves substantial memory and computational costs during both forward and backward propagation. While parameter-efficient fine-tuning (PEFT) considerably reduces the training memory associated with parameters, it does not address the significant computational costs and activation memory. In this paper, we propose Dropping Backward Propagation (DropBP), a novel approach designed to reduce computational costs and activation memory while maintaining accuracy. DropBP randomly drops layers during backward propagation, which is essentially equivalent to training shallow submodules generated by undropped layers and residual connections. Additionally, DropBP calculates the sensitivity of each layer to assign an appropriate drop rate, thereby stabilizing the training process. DropBP is not only applicable to full fine-tuning but can also be orthogonally integrated with all types of PEFT by dropping layers during backward propagation. Specifically, by using DropBP when fine-tuning LLaMA2-70B through QLoRA, we can reduce training time by 44%, accelerate convergence to the same perplexity by 1.5×, and enable training with a sequence length 6.2× larger on a single NVIDIA A100 80GiB GPU.

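To picture the core mechanism, here is a minimal conceptual sketch in PyTorch. It is not the library implementation; the helper maybe_drop_backward and the ToyBlock module are our own illustration. With probability p, the output of a residual branch is detached from the computation graph, so the forward result is unchanged while backward propagation through that branch (and the activation memory it would need) is skipped; gradients still flow through the residual connection, which is why the effective backward network is a shallow submodule.

import torch
import torch.nn as nn

def maybe_drop_backward(h, p):
    # With probability p, cut the graph at h: the value of h is still used in
    # the forward pass, but no gradients are propagated back through the
    # branch that produced it. (Conceptual sketch only, not the DropBP library.)
    if torch.rand(1).item() < p:
        return h.detach()
    return h

class ToyBlock(nn.Module):
    def __init__(self, dim, p=0.5):
        super().__init__()
        self.layer = nn.Linear(dim, dim)
        self.p = p

    def forward(self, x):
        h = self.layer(x)
        # The residual connection keeps a gradient path for x even when the
        # backward pass through self.layer is dropped.
        return maybe_drop_backward(h, self.p) + x
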
Install

  1. Install PyTorch before installing the DropBP library.

  2. Build and install DropBP:

pip install -v -e .
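
If the build succeeds, the library should be importable; a quick sanity check (the module paths are the same ones used in the Usage section below):

python -c "from dropbp.layer import DropBP; from dropbp.handler import DropBPHandler"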

Usage

  1. Insert DropBP layers into your model
  • Note that you have to pass the FLOPs of each layer to its DropBP layer
  • For instance, in a general transformer (with batch size $b$, sequence length $s$, and hidden size $h$):
  • FLOPs of attention layers: $8bsh^2+4bhs^2$
  • FLOPs of MLP layers: $16bsh^2$
  • It is fine to provide only the ratio of FLOPs between layers rather than exact values (see the numeric example after the code block below)
import torch
import torch.nn as nn
...
from dropbp.layer import DropBP

...
class Block(nn.Module): # transformer block
    def __init__(self, ...):
        self.norm_1 = ...
        self.attn = ...
        self.norm_2 = ...
        self.mlp = ...

        # Define DropBP layers
        # The FLOPs below are for a general transformer block, per batch*seq,
        # with intermediate_size = 4*hidden_size
        attn_flops = 8*config.hidden_size**2 + 4*config.hidden_size*self.sequence_length
        mlp_flops = 16*config.hidden_size**2
        self.dropbp_attn = DropBP(flops=attn_flops)
        self.dropbp_mlp = DropBP(flops=mlp_flops)
        ...
    def forward(self, x, ...):
        h = self.attn(self.norm_1(x), ...)
        x = self.dropbp_attn(h) + x   # instead of 'x = h + x'
        h = self.mlp(self.norm_2(x))
        x = self.dropbp_mlp(h) + x    # instead of 'x = h + x'
        return x
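
As a purely illustrative example, with hypothetical values hidden_size = 4096 and sequence_length = 2048, the per-token FLOPs used above evaluate as follows; since only the ratio between layers matters to DropBP, rough numbers like these are sufficient.

hidden_size = 4096      # hypothetical hidden size, for illustration only
sequence_length = 2048  # hypothetical sequence length

attn_flops = 8 * hidden_size**2 + 4 * hidden_size * sequence_length  # 8h^2 + 4hs per token
mlp_flops = 16 * hidden_size**2                                      # 16h^2 per token

print(attn_flops)              # 167772160
print(mlp_flops)               # 268435456
print(mlp_flops / attn_flops)  # 1.6, i.e. the MLP costs about 1.6x the attention layer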
  2. Integrate the DropBP API into your training code
import torch
from dropbp.handler import DropBPHandler

model = ...     # user-defined model
optimizer = ... # user-defined optimizer

dropbp_handler = DropBPHandler(model)           # define the DropBP handler
dropbp_handler.set_initial_drop_rate(drop_rate) # set an initial drop rate

# training loop
for iter in ...:
    def backprop(): # define a closure that runs one forward and backward pass
        output = model(data)
        loss = loss_func(output, target)
        optimizer.zero_grad() # this line must be present
        loss.backward()

    if iter == int(max_iter * 0.1): # adjust drop rates at 10% of the training process
        dropbp_handler.sensitivity_based_drop_bp(backprop, drop_rate) # automatically adjusts per-layer drop rates

    optimizer.zero_grad() # reset gradients before the forward/backward pass
    output = model(data)
    loss = loss_func(output, target)
    non_grad = dropbp_handler.detact_non_grad() # detect when all layers are dropped
    if not non_grad: # skip backward in that case to avoid an error
        loss.backward()
    optimizer.step()
    ...
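
Two notes on the snippet above. The backprop closure is what sensitivity_based_drop_bp uses to measure each layer's sensitivity before reassigning per-layer drop rates, which is why it performs a full forward and backward pass and resets gradients itself. The detact_non_grad() guard covers the case, possible at high drop rates, where every layer happens to be dropped in an iteration: the loss then has no gradient path, so loss.backward() is skipped for that iteration to avoid an error.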

Applications

Our DropBP library can be easily integrated into existing training code, as demonstrated in:

Lit-GPT
