
ShiftAddLLM: Accelerating Pretrained LLMs via Post-Training Multiplication-Less Reparameterization


GATECH-EIC/ShiftAddLLM




Your GPU-friendly multiplication-free LLMs without training or fine-tuning!

ShiftAddLLM: Accelerating Pretrained LLMs via Post-Training Multiplication-Less Reparameterization
Haoran You, Yipin Guo, Yichao Fu, Wei Zhou, Huihong Shi, Xiaofan Zhang,
Souvik Kundu, Amir Yazdanbakhsh, Yingyan (Celine) Lin
Georgia Institute of Technology, Intel Labs, Google, Google DeepMind


News 🔥🔥!

  • [ To Do ] Update the kernel evaluation guideline.
  • [ ✅ New ] Jun. 13, 2024. 🤗 Released our model checkpoints on Hugging Face!
  • [ ✅ New ] Jun. 10, 2024. 💥 Released the PyTorch implementation of ShiftAddLLM!

Table of Contents

Brief Introduction

Basic Usage

Reproduce ShiftAddLLM

Citation & Acknowledgement

Brief Introduction

Large language models (LLMs) excel in language tasks but struggle on resource-constrained devices due to high memory demands and latency from dense multiplications. Shift-and-add reparameterization replaces costly multiplications with hardware-friendly operations in LLMs' attention and MLP layers, but current methods need training from scratch or fine-tuning. We propose ShiftAddLLM, which accelerates pretrained LLMs via post-training shift-and-add reparameterization. We quantize weight matrices into binary matrices and scaling factors, reparameterizing multiplications into shifts, adds, and look-up table queries. Our multi-objective optimization minimizes reparameterization errors, and an automated bit allocation strategy reduces memory usage and latency. Experiments on five LLM families and eight tasks consistently validate the effectiveness of ShiftAddLLM, achieving average perplexity improvements of 5.6 and 22.7 points at comparable or lower latency compared to the most competitive quantized LLMs at 3 and 2 bits, respectively, and more than 80% memory and energy reductions over the original LLMs.

To avoid fine-tuning after reparameterization, ShiftAddLLM mimics the original LLM multiplications using Binary-Coding Quantization (BCQ) with customized CUDA kernels, eliminating the need for dequantization. As shown in the above figure, ShiftAddLLM quantizes pretrained weights into binary matrices $\mathbf{b}$ and scaling factors $\alpha$; during optimization, the scaling factors are further quantized to powers of two. Multiplications by the scaling factors then become bitwise shifts, and multiplications by the binary matrices become LUT-based lookups and additions, both of which are efficiently implemented on GPUs. This approach simplifies hardware operations, reduces redundant computation, and enables post-training quantization of all pretrained weights in LLMs.
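The reparameterization described above can be illustrated with a small pure-Python sketch for a single weight row. It is a minimal sketch under stated assumptions, not the repository's CUDA kernels: it uses a simple greedy BCQ (the iterative refinement controlled by `bcq_round` is omitted), rounds each scaling factor to the nearest power of two in the log domain, and evaluates the dot product using only additions, lookups into tables built over groups of 8 activations, and exponent shifts via `math.ldexp`. The function names and the group size of 8 are our own illustrative choices.

```python
import math

GROUP = 8  # activations per lookup table (illustrative choice)


def bcq_quantize(w, num_bits):
    """Greedy binary-coding quantization: w ~= sum_i alpha_i * b_i, b_i in {-1, +1}^n."""
    residual = list(w)
    alphas, planes = [], []
    for _ in range(num_bits):
        b = [1 if r >= 0 else -1 for r in residual]
        # For b = sign(r), the mean |r| minimizes ||r - alpha * b||^2.
        alpha = sum(abs(r) for r in residual) / len(residual)
        alphas.append(alpha)
        planes.append(b)
        residual = [r - alpha * s for r, s in zip(residual, b)]
    return alphas, planes


def round_pow2(alpha):
    """Return the exponent k such that 2**k is the power of two nearest to alpha (log domain)."""
    if alpha <= 0:
        raise ValueError("scaling factor must be positive")
    return round(math.log2(alpha))


def build_luts(x):
    """For each group of activations, precompute sum(+/- x_j) for every sign pattern."""
    luts = []
    for g in range(0, len(x), GROUP):
        xs = x[g:g + GROUP]
        # Bit j of `code` set means activation j gets sign +1, otherwise -1.
        luts.append([sum(v if (code >> j) & 1 else -v for j, v in enumerate(xs))
                     for code in range(1 << len(xs))])
    return luts


def plane_to_codes(b):
    """Encode each group of a {-1, +1} plane as a LUT index (can be done offline for weights)."""
    codes = []
    for g in range(0, len(b), GROUP):
        code = 0
        for j, s in enumerate(b[g:g + GROUP]):
            if s > 0:
                code |= 1 << j
        codes.append(code)
    return codes


def shift_add_dot(exponents, planes, x):
    """Compute sum_i 2**k_i * (b_i . x) with lookups, additions, and exponent shifts only."""
    luts = build_luts(x)
    y = 0.0
    for k, b in zip(exponents, planes):
        partial = sum(lut[c] for lut, c in zip(luts, plane_to_codes(b)))  # b_i . x
        y += math.ldexp(partial, k)  # partial * 2**k: a shift, not a multiplication
    return y


# Usage: quantize one row to 3 bits and compute its dot product with an activation vector.
w = [0.9, -0.3, 0.05, -1.2, 0.4, 0.7, -0.6, 0.2, 0.1, -0.05]
x = [1.0, 2.0, -1.0, 0.5, 3.0, -2.0, 1.5, 0.25, -0.75, 2.0]
alphas, planes = bcq_quantize(w, num_bits=3)
exps = [round_pow2(a) for a in alphas]
y = shift_add_dot(exps, planes, x)
```

The lookup tables are shared by every binary plane (and, in a real kernel, by every output row), which is where the savings over dense multiplication come from; the power-of-two rounding is what turns the remaining scaling step into a shift.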

To reduce accuracy loss, we present a multi-objective optimization method that minimizes both weight and output-activation reparameterization errors. Additionally, because layers differ in their sensitivity to reparameterization, we develop an automated bit allocation strategy to further reduce memory usage and latency. More technical details can be found in our paper.
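The two ideas above can be sketched as follows, under loudly stated assumptions. The multi-objective error is written here as a plain weighted sum of the weight error and the output-activation error (the weight `lam` and the function names are hypothetical, not the paper's exact formulation), and the bit allocation uses a simple greedy rule over per-layer errors measured offline, standing in for the paper's automated strategy.

```python
def combined_error(W, W_hat, X, lam=1.0):
    """Hypothetical weighted sum ||W - W_hat||_F^2 + lam * ||W X - W_hat X||_F^2.

    W, W_hat: lists of rows (out_features x in_features); X: lists of rows (in_features x tokens).
    """
    def matmul(A, B):
        return [[sum(a * B[k][j] for k, a in enumerate(row)) for j in range(len(B[0]))] for row in A]

    weight_err = sum((a - b) ** 2 for ra, rb in zip(W, W_hat) for a, b in zip(ra, rb))
    Y, Y_hat = matmul(W, X), matmul(W_hat, X)
    act_err = sum((a - b) ** 2 for ra, rb in zip(Y, Y_hat) for a, b in zip(ra, rb))
    return weight_err + lam * act_err


def allocate_bits(errors, sizes, avg_bit_budget, choices=(2, 3, 4)):
    """Greedy bit allocation (illustrative stand-in, not the paper's algorithm).

    errors[l][b]: measured reparameterization error of layer l at b bits.
    sizes[l]: parameter count of layer l.
    Start every layer at the lowest bit-width, then repeatedly upgrade the layer with the
    largest error drop per extra stored bit while the parameter-weighted average stays in budget.
    """
    bits = [choices[0]] * len(errors)
    total = sum(sizes)
    used = sum(b * n for b, n in zip(bits, sizes))
    while True:
        best, best_gain = None, 0.0
        for l, b in enumerate(bits):
            i = choices.index(b)
            if i + 1 == len(choices):
                continue  # already at the highest bit-width
            nb = choices[i + 1]
            extra = (nb - b) * sizes[l]
            if (used + extra) / total > avg_bit_budget:
                continue  # upgrade would exceed the average-bit budget
            gain = (errors[l][b] - errors[l][nb]) / extra
            if gain > best_gain:
                best, best_gain = l, gain
        if best is None:
            return bits
        nb = choices[choices.index(bits[best]) + 1]
        used += (nb - bits[best]) * sizes[best]
        bits[best] = nb
```

With a budget of 2.5 average bits, a sensitive layer (error dropping from 10.0 to 2.0 at 3 bits) is upgraded before a robust one (1.0 to 0.8), so `allocate_bits([{2: 10.0, 3: 2.0, 4: 1.0}, {2: 1.0, 3: 0.8, 4: 0.7}], [100, 100], 2.5)` keeps the robust layer at 2 bits.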

Basic Usage

Environment Setup

conda env create -f environment.yml
conda activate shiftaddllm
export PYTHONPATH='YOUR-PATH-TO-SHIFTADDLLM-REPO'

Core Optimization Options

  • model: Hugging Face path of the model to quantize.
  • dataset: dataset to use as calibration data.
  • wbits: number of bits used for weight quantization; use 16 to evaluate the base (unquantized) model.
  • groupsize: group size used for quantization; the default uses the full row.
  • act-order: whether to apply GPTQ's activation-order heuristic.
  • bcq: whether to quantize weights with binary-coding quantization (BCQ).
  • bcq_round: number of BCQ optimization iterations.
  • columnwise: whether to evaluate the model with column-wise BCQ quantization, with scaling factors rounded to powers of 2.
  • block_quant & cust_group: whether to evaluate the model with block-wise BCQ quantization (each block of 8 columns by 1/8 of the rows shares one set of quantization parameters), with scaling factors rounded to powers of 2. Must be used together with columnwise.
  • use_bst: whether to use binary search to obtain the binary weights.
  • apot_nums: number of shift terms used for quantization.
  • acc: whether to quantize the model with Ours (Acc.).
  • lat: whether to quantize the model with Ours (Lat.). Set only one of acc and lat.
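To make the constraints among these options concrete, here is an illustrative `argparse` sketch. Only `--wbits` and `--lat` appear verbatim in this README's commands; the other flag spellings are inferred from the option names, and all types and default values are assumptions, so consult the repository's scripts for the real parser.

```python
import argparse


def build_parser():
    """Hypothetical mirror of the documented options; flag spellings and defaults are assumptions."""
    p = argparse.ArgumentParser(description="ShiftAddLLM options (illustrative sketch)")
    p.add_argument("model", help="Hugging Face path of the model to quantize")
    p.add_argument("dataset", help="calibration dataset")
    p.add_argument("--wbits", type=int, default=16, help="16 evaluates the base model")
    p.add_argument("--groupsize", type=int, default=-1, help="-1 (placeholder) means full row")
    p.add_argument("--act-order", action="store_true")
    p.add_argument("--bcq", action="store_true")
    p.add_argument("--bcq_round", type=int, default=5)   # placeholder default
    p.add_argument("--columnwise", action="store_true")
    p.add_argument("--block_quant", action="store_true")
    p.add_argument("--cust_group", action="store_true")
    p.add_argument("--use_bst", action="store_true")
    p.add_argument("--apot_nums", type=int, default=2)   # placeholder default
    # "Only one of acc and lat should be set": argparse enforces this directly.
    mode = p.add_mutually_exclusive_group()
    mode.add_argument("--acc", action="store_true")
    mode.add_argument("--lat", action="store_true")
    return p


def parse(argv):
    p = build_parser()
    args = p.parse_args(argv)
    # Block-wise quantization is documented as requiring the column-wise setting.
    if (args.block_quant or args.cust_group) and not args.columnwise:
        p.error("--block_quant/--cust_group must be used together with --columnwise")
    return args
```

A mutually exclusive group is used for `--acc`/`--lat` because `argparse` then rejects the invalid combination with a clear message (and `SystemExit`) instead of relying on a manual check.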

Use Reparameterized Weights Directly

You can download our reparameterized ShiftAddLLM model checkpoints from our Hugging Face page.

Evaluate ShiftAddLLM (Acc.)

The weights in ShiftAddLLM (Acc.) mode are stored in FP16 precision and follow the official Hugging Face format.

To use these weights, you can call the Hugging Face API directly (note: these checkpoints are the same size as the original weights, since they are meant to verify accuracy after reparameterization).

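# Load the model directly with the Hugging Face API
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_id = "ShiftAddLLM/Llama-2-70b-wbits2-acc"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# The checkpoint is stored in FP16; loading it as torch.float16 avoids upcasting to FP32 (about 2x the memory).
# device_map="auto" (requires the `accelerate` package) spreads the 70B model across the available GPUs.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")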

To evaluate the same checkpoint with our code, run:

CUDA_VISIBLE_DEVICES=0 python model/llama.py \
    ShiftAddLLM/Llama-2-70b-wbits2-acc

Evaluate ShiftAddLLM (Lat.)

The weights in ShiftAddLLM (Lat.) mode are packed and stored in Int32 format, which significantly reduces storage compared with the original weights. However, Latency-mode weights must be loaded with the loading method provided in our code rather than the standard Hugging Face API.
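Why packing saves so much storage can be illustrated with a hypothetical scheme: each binary matrix (one per bit) is stored as sign bits, 32 per Int32 word. The bit order and helper names below are our assumptions and the repository's actual layout may differ; the scaling factors add extra storage that this sketch ignores.

```python
import math


def to_int32(word):
    """Reinterpret an unsigned 32-bit value as a signed Int32 (two's complement)."""
    return word - (1 << 32) if word >= (1 << 31) else word


def pack_plane(signs):
    """Pack a {-1, +1} vector into Int32 words: bit j of word i is 1 iff signs[32*i + j] == +1."""
    words = []
    for start in range(0, len(signs), 32):
        word = 0
        for j, s in enumerate(signs[start:start + 32]):
            if s > 0:
                word |= 1 << j
        words.append(to_int32(word))
    return words


def unpack_plane(words, n):
    """Recover the first n signs; >> and & on Python ints read the two's-complement bits."""
    return [1 if (words[i // 32] >> (i % 32)) & 1 else -1 for i in range(n)]


def plane_storage_bytes(rows, cols, wbits):
    """Bytes for `wbits` packed binary matrices of shape rows x cols (scaling factors excluded)."""
    return wbits * rows * math.ceil(cols / 32) * 4
```

For a 4096 × 4096 layer, FP16 weights take 4096 × 4096 × 2 = 33,554,432 bytes, while two packed binary matrices take `plane_storage_bytes(4096, 4096, 2)` = 4,194,304 bytes, an 8× reduction before counting scaling factors.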

To use these weights, first clone the model weight repository locally (Hugging Face stores large files with Git LFS). For example:

git lfs install
git clone https://huggingface.co/ShiftAddLLM/opt66b-2bit-lat

Then pass the directory where the packed weights are stored to the script via --load_temp_storage. Make sure the model name and wbits match the downloaded weights (here, OPT-66B at 2 bits).

CUDA_VISIBLE_DEVICES=0 python model/opt.py \
    facebook/opt-66b \
    --wbits 2 \
    --lat \
    --load_temp_storage <packed_weight_dir>

Reproduce ShiftAddLLM

ShiftAddLLM (Acc.)

To quantize LLMs using our ShiftAddLLM (Acc.) method with column-wise scaling factors and evaluate their performance, we provide scripts for five different LLM families.

ShiftAddLLM (Lat.)

To quantize LLMs using our ShiftAddLLM (Lat.) method with block-wise scaling factors and evaluate their performance, we provide scripts for five different LLM families.

Zero-Shot Downstream Task Evaluation

To evaluate the zero-shot accuracy of quantized LLMs on seven downstream tasks, run:

python3 main.py  <model_name> <calibration_dataset> --task <task_name> --num_fewshot <num_fewshot> 

We also provide example scripts for two LLM families.

Citation & Acknowledgement

@article{you2024shiftaddllm,
  title={ShiftAddLLM: Accelerating Pretrained LLMs via Post-Training Multiplication-Less Reparameterization},
  author={You, Haoran and Guo, Yipin and Fu, Yichao and Zhou, Wei and Shi, Huihong and Zhang, Xiaofan and Kundu, Souvik and Yazdanbakhsh, Amir and Lin, Yingyan},
  journal={arXiv preprint arXiv:2406.05981},
  year={2024}
}

Thanks to OPTQ, LUT-GEMM, and DeepShift for their wonderful work and codebases!

Disclaimer:

This “research quality code” is for Non-Commercial purposes and provided by the contributors “As Is” without any express or implied warranty of any kind. The organizations (Georgia Tech or Intel or Google or Google DeepMind) involved do not own the rights to the data sets used and do not confer any rights to them. The organizations (Georgia Tech or Intel or Google or Google DeepMind) do not warrant or assume responsibility for the accuracy or completeness of any information, text, graphics, links or other items within the code. A thorough security review has not been performed on this code. Additionally, this repository may contain components that are out of date or contain known security vulnerabilities.
