hf-trim

A package to reduce the size of 🤗 Hugging Face models via vocabulary trimming.

The library currently supports the following models (and their pretrained versions available on the Hugging Face Model Hub):

  1. BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation
  2. mBART: Multilingual Denoising Pre-training for Neural Machine Translation
  3. T5: Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
  4. mT5: A Massively Multilingual Pre-trained Text-to-Text Transformer

"Why would I need to trim the vocabulary of a model?" 🤔

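To put it simply, vocabulary trimming reduces a language model's memory footprint while retaining most of its performance. The token-embedding matrix (and, in models whose output projection is not tied to it, the LM head as well) has one row per vocabulary entry, so in multilingual models with very large vocabularies these matrices account for a large share of all parameters. Trimming keeps only the rows for tokens that actually occur in your data and leaves the rest of the network unchanged.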

Read more here.
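The arithmetic behind that claim can be sketched in a few lines. The helper names below are hypothetical, and the shapes used (a 250,112-token vocabulary, `d_model = 512`, and an LM head that is not tied to the input embeddings, roughly matching `google/mt5-small`) together with the 30,000-token trimmed vocabulary are illustrative assumptions, not measurements from hf-trim.

```python
def vocab_params(vocab_size: int, d_model: int, tied: bool) -> int:
    """Parameters in the token embedding plus the output (LM head) projection.

    With tied weights both share one vocab_size x d_model matrix;
    otherwise each has its own matrix of that shape.
    """
    matrices = 1 if tied else 2
    return matrices * vocab_size * d_model


def trimming_savings(full_vocab: int, kept_vocab: int, d_model: int, tied: bool) -> int:
    """Parameters removed by shrinking the vocabulary from full_vocab to kept_vocab."""
    return vocab_params(full_vocab, d_model, tied) - vocab_params(kept_vocab, d_model, tied)


# Assumed mT5-small-like shapes (untied LM head) and a hypothetical
# trimmed vocabulary of 30,000 tokens.
saved = trimming_savings(250_112, 30_000, 512, tied=False)
print(f"{saved:,} parameters removed")  # 225,394,688 parameters removed
```

Under these assumptions the vocabulary-sized matrices alone shrink by about 225 million parameters (roughly 0.9 GB in float32), which is why trimming pays off most for models with very large vocabularies such as mBART and mT5.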

Citation

If you use this software, please cite it as given below:

@software{Srivastava_hf-trim,
author = {Srivastava, Aditya},
license = {MPL-2.0},
title = {{hf-trim}},
url = {https://github.com/IamAdiSri/hf-trim}
}

Installation

You can run the following command to install from PyPI (recommended):

$ pip install hf-trim

You can also install from source:

$ git clone https://github.com/IamAdiSri/hf-trim
$ cd hf-trim
$ pip install .

Usage

Simple Example

from transformers import MT5Config, MT5Tokenizer, MT5ForConditionalGeneration
from hftrim.TokenizerTrimmer import TokenizerTrimmer
from hftrim.ModelTrimmers import MT5Trimmer

data = [
        " UN Chief Says There Is No Military Solution in Syria", 
        "Şeful ONU declară că nu există o soluţie militară în Siria"
]

# load pretrained config, tokenizer and model
config = MT5Config.from_pretrained("google/mt5-small")
tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")

# trim tokenizer
tt = TokenizerTrimmer(tokenizer)
tt.make_vocab(data)
tt.make_tokenizer()

# trim model
mt = MT5Trimmer(model, config, tt.trimmed_tokenizer)
mt.make_weights(tt.trimmed_vocab_ids)
mt.make_model()

You can directly use the trimmed model with mt.trimmed_model and the trimmed tokenizer with tt.trimmed_tokenizer.
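Conceptually, vocabulary trimming of this kind comes down to three steps: collect the token ids that occur in your data (plus special tokens), copy only those rows out of the embedding matrix, and remap old ids to their new positions. The toy sketch below illustrates the idea in plain Python; it is not hf-trim's implementation, and the function names and sample ids are hypothetical.

```python
from typing import Dict, Iterable, List, Sequence

# Hypothetical helpers illustrating vocabulary trimming; not hf-trim's internals.


def build_kept_ids(tokenized: Iterable[Sequence[int]], special_ids: Iterable[int]) -> List[int]:
    """Sorted ids of every token seen in the data, plus the special tokens."""
    kept = set(special_ids)
    for ids in tokenized:
        kept.update(ids)
    return sorted(kept)


def trim_rows(matrix: Sequence[Sequence[float]], kept_ids: Sequence[int]) -> List[List[float]]:
    """Copy only the embedding rows for kept_ids; row k of the result is new id k."""
    return [list(matrix[i]) for i in kept_ids]


def remap(ids: Sequence[int], kept_ids: Sequence[int]) -> List[int]:
    """Translate original token ids into ids of the trimmed vocabulary."""
    old_to_new: Dict[int, int] = {old: new for new, old in enumerate(kept_ids)}
    return [old_to_new[i] for i in ids]


# Toy example: a 6-token vocabulary with 2-dimensional embeddings.
embedding = [[0.0, 0.0], [0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4], [0.5, 0.5]]
data_ids = [[4, 2], [2, 5]]                          # made-up tokenized corpus
kept = build_kept_ids(data_ids, special_ids=[0, 1])  # e.g. pad and eos ids
small = trim_rows(embedding, kept)

print(kept)                 # [0, 1, 2, 4, 5]
print(len(small))           # 5
print(remap([4, 2], kept))  # [3, 2]
```

Token 3 never appears in the data, so its row is dropped and the ids above it shift down. A real trimmer would need to apply the same mapping to every vocabulary-sized tensor (such as an untied LM head) and to the tokenizer, so that encoded ids and embedding rows stay aligned.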

Saving and Loading

# save with
tt.trimmed_tokenizer.save_pretrained('trimT5')
mt.trimmed_model.save_pretrained('trimT5')

# load with
config = MT5Config.from_pretrained("trimT5")
tokenizer = MT5Tokenizer.from_pretrained("trimT5")
model = MT5ForConditionalGeneration.from_pretrained("trimT5")

Limitations

  • Fast tokenizers are currently unsupported.
  • TensorFlow and Flax models are currently unsupported.
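Because fast tokenizers are unsupported, it can help to fail early when one slips through, for example from `AutoTokenizer.from_pretrained(...)`, which returns a fast tokenizer by default when one is available. Tokenizers in 🤗 Transformers expose an `is_fast` property, and the guard below checks it; `ensure_slow_tokenizer` is a hypothetical helper, not part of hf-trim. Passing `use_fast=False` to `AutoTokenizer.from_pretrained`, or using a slow class such as `MT5Tokenizer` as in the usage example, gives a supported tokenizer.

```python
def ensure_slow_tokenizer(tokenizer):
    """Return the tokenizer unchanged, or raise if it is a fast (Rust-backed) one.

    Hypothetical helper: relies only on the `is_fast` attribute that
    Hugging Face tokenizers expose; objects without it are treated as slow.
    """
    if getattr(tokenizer, "is_fast", False):
        raise TypeError(
            f"{type(tokenizer).__name__} is a fast tokenizer, which hf-trim does not support; "
            "load a slow one instead, e.g. AutoTokenizer.from_pretrained(name, use_fast=False)"
        )
    return tokenizer


# Example (requires transformers, sentencepiece, and network access):
# from transformers import AutoTokenizer
# tokenizer = ensure_slow_tokenizer(
#     AutoTokenizer.from_pretrained("google/mt5-small", use_fast=False)
# )
```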

Roadmap

  • Add support for MarianMT models.
  • Add sup