A package to reduce the size of 🤗 Hugging Face models via vocabulary trimming.
The library currently supports the following models, along with their pretrained versions available on the Hugging Face Model Hub:
- BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation
- mBART: Multilingual Denoising Pre-training for Neural Machine Translation
- T5: Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer
- mT5: A Massively Multilingual Pre-trained Text-to-Text Transformer
Put simply, vocabulary trimming drops the tokens a model does not need for your data, together with their rows in the model's embedding matrices. This reduces the model's memory footprint while retaining most of its performance.
Read more here.
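The core idea can be sketched in plain Python. This is a minimal illustration, not hf-trim's implementation. It assumes the text has already been tokenized into ids, that special tokens (e.g. padding) must always be kept, and that the kept ids are renumbered in ascending order. The function name trim_vocab is hypothetical.

```python
# Minimal sketch of vocabulary trimming on a toy vocabulary (not hf-trim's code).
from typing import Dict, List, Sequence, Tuple


def trim_vocab(
    embeddings: List[List[float]],
    tokenized_data: Sequence[Sequence[int]],
    special_ids: Sequence[int],
) -> Tuple[List[List[float]], Dict[int, int]]:
    """Keep only the embedding rows for token ids that occur in the data
    (plus special tokens) and return an old-id -> new-id mapping."""
    # Collect every id the data actually uses; always keep special tokens.
    used = set(special_ids)
    for ids in tokenized_data:
        used.update(ids)
    # Assign new ids in ascending order of the old ids (an assumption).
    kept = sorted(used)
    old_to_new = {old: new for new, old in enumerate(kept)}
    # Gather the corresponding rows; the new matrix has len(kept) rows.
    trimmed = [embeddings[old] for old in kept]
    return trimmed, old_to_new


# Toy example: a 6-token vocabulary with 2-dimensional embeddings.
emb = [[float(i), float(i)] for i in range(6)]
data_ids = [[3, 5, 3], [5, 1]]
trimmed, mapping = trim_vocab(emb, data_ids, special_ids=[0])
print(len(trimmed), mapping)  # 4 {0: 0, 1: 1, 3: 2, 5: 3}
```

Because the kept tokens are renumbered, text must be encoded with the new ids. This is why a trimmed model has to be paired with its trimmed tokenizer. In a real model, the same row selection would apply to every vocabulary-sized matrix, such as the input embedding and an untied output projection.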
If you use this software, please cite it as follows:
@software{Srivastava_hf-trim,
author = {Srivastava, Aditya},
license = {MPL-2.0},
title = {{hf-trim}},
url = {https://github.com/IamAdiSri/hf-trim}
}
You can install from PyPI (recommended) with the following command:
$ pip install hf-trim
Alternatively, you can install from source:
$ git clone https://github.com/IamAdiSri/hf-trim
$ cd hf-trim
$ pip install .
from transformers import MT5Config, MT5Tokenizer, MT5ForConditionalGeneration
from hftrim.TokenizerTrimmer import TokenizerTrimmer
from hftrim.ModelTrimmers import MT5Trimmer
# sample sentences (English and Romanian); the trimmed vocabulary is built from their tokens
data = [
" UN Chief Says There Is No Military Solution in Syria",
"Şeful ONU declară că nu există o soluţie militară în Siria"
]
# load pretrained config, tokenizer and model
config = MT5Config.from_pretrained("google/mt5-small")
tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
# trim tokenizer
tt = TokenizerTrimmer(tokenizer)
tt.make_vocab(data)
tt.make_tokenizer()
# trim model
mt = MT5Trimmer(model, config, tt.trimmed_tokenizer)
mt.make_weights(tt.trimmed_vocab_ids)
mt.make_model()
You can directly use the trimmed model with mt.trimmed_model and the trimmed tokenizer with tt.trimmed_tokenizer.
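The size reduction comes from the vocabulary-sized matrices: the input embedding and, when it is not tied to the embedding, the output projection. Each of these loses one row of width d_model per removed token. The sketch below gives a rough estimate. The helper trimmed_params is hypothetical. The mt5-small figures (vocab_size 250112, d_model 512, untied embeddings) are taken from its published config and are worth double-checking. The kept vocabulary size of 20,000 is made up for illustration.

```python
# Back-of-the-envelope estimate of parameters removed by vocabulary trimming.


def trimmed_params(vocab_old: int, vocab_new: int, d_model: int, n_matrices: int) -> int:
    """Parameters removed when each of n_matrices vocabulary-sized matrices
    loses (vocab_old - vocab_new) rows of width d_model."""
    return (vocab_old - vocab_new) * d_model * n_matrices


# Assumed mt5-small values: vocab_size=250112, d_model=512, and an untied
# input embedding and LM head (2 vocabulary-sized matrices).
# The kept vocabulary size of 20,000 is purely hypothetical.
removed = trimmed_params(250_112, 20_000, 512, n_matrices=2)
removed_mib_fp32 = removed * 4 / 2**20  # 4 bytes per float32 parameter
print(removed, removed_mib_fp32)  # 235634688 898.875
```

The actual kept vocabulary size depends on the data passed to make_vocab, so the real saving for your data will differ.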
# save the trimmed tokenizer and model (the model also writes its config) to one directory
tt.trimmed_tokenizer.save_pretrained('trimT5')
mt.trimmed_model.save_pretrained('trimT5')
# load them back later like any other pretrained MT5 checkpoint
config = MT5Config.from_pretrained("trimT5")
tokenizer = MT5Tokenizer.from_pretrained("trimT5")
model = MT5ForConditionalGeneration.from_pretrained("trimT5")
The library has the following known limitations:
- Fast tokenizers are currently unsupported.
- TensorFlow and Flax models are currently unsupported.
Planned additions include:
- Add support for MarianMT models.
- Add sup