Gemma 2B with recurrent local attention and a context length of up to 10M tokens. Our implementation uses <32GB of memory!
Features:
- 10M sequence length on Gemma 2B.
- Runs on less than 32GB of memory.
- Native inference optimized for CUDA.
- Recurrent local attention for O(N) memory.
Note: This is a very early checkpoint of the model, trained for only 200 steps. We plan on training for a lot more tokens!
Install the requirements:
pip install -r requirements.txt
Download the model from Hugging Face (Huggingface Model), then run:
python main.py
To use your own prompt, change the inference code in main.py:
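import torch
from transformers import AutoTokenizer
# GemmaForCausalLM (with recurrent local attention) and generate() come from
# this repository's own code; their import lines are not shown in this excerpt.

# Path to the downloaded gemma-2b-10m checkpoint
model_path = "./models/gemma-2b-10m"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = GemmaForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16
)

# Replace this with the prompt you want to run
prompt_text = "Summarize this harry potter book..."

# Inference only: disable gradient tracking to save memory
with torch.no_grad():
    generated_text = generate(
        model, tokenizer, prompt_text, max_length=512, temperature=0.8
    )

print(generated_text)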
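For long sequences, the largest memory bottleneck in LLM inference is the KV cache. In vanilla multi-head attention the cache grows with every token of context, and the attention score matrix grows quadratically with sequence length, which limits how long your sequences can be.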
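To make the scale concrete, here is a back-of-the-envelope estimate of the KV cache for a full 10M-token context. The model dimensions (18 layers, 1 key/value head, head dimension 256, bfloat16 at 2 bytes per element) are assumptions taken from the publicly released Gemma 2B configuration, not from this repository, and the 2048-token block size is a hypothetical value chosen only for illustration.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    # One key and one value vector per token, per layer, per key/value head
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem * seq_len


# Assumed Gemma 2B dimensions (public config); bfloat16 = 2 bytes per element
GEMMA_2B = dict(n_layers=18, n_kv_heads=1, head_dim=256, bytes_per_elem=2)

per_token = kv_cache_bytes(1, **GEMMA_2B)
full = kv_cache_bytes(10_000_000, **GEMMA_2B)
# Hypothetical local block: assume only one block's keys/values are cached at a time
local = kv_cache_bytes(2048, **GEMMA_2B)

print(f"per token: {per_token / 1024:.0f} KiB")            # per token: 18 KiB
print(f"10M tokens, full cache: {full / 1e9:.1f} GB")      # 10M tokens, full cache: 184.3 GB
print(f"one 2048-token block: {local / 2**20:.0f} MiB")    # one 2048-token block: 36 MiB
```

Under these assumptions a full cache for 10M tokens would need roughly 184 GB, far beyond a 32GB budget, while a single local block needs about 36 MiB. A recurrent scheme also stores a fixed-size compressed state, which this estimate leaves out.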
Our approach splits attention into local attention blocks, as outlined in the Infini-attention paper. We then apply recurrence across those local attention blocks, which yields global attention over a 10M-token context.
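The block-plus-recurrence idea can be illustrated with a minimal NumPy sketch of Infini-attention-style attention. It assumes a single head, no learned projections, and a fixed scalar gate `beta` (Infini-attention learns this gate). Each block runs causal softmax attention over its own tokens, reads everything earlier from a fixed-size compressive memory, and then folds its own keys and values into that memory. The function and parameter names are hypothetical; this is not the code used in this repository.

```python
import numpy as np


def elu_plus_one(x):
    # sigma(x) = ELU(x) + 1 keeps features positive for the linear-attention memory
    return np.where(x > 0, x + 1.0, np.exp(np.minimum(x, 0.0)))


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def recurrent_local_attention(q, k, v, block_size, beta=0.0):
    """Hypothetical sketch. q, k: (seq_len, d); v: (seq_len, d_v)."""
    seq_len, d = q.shape
    d_v = v.shape[1]
    memory = np.zeros((d, d_v))  # compressive memory M (fixed size)
    norm = np.zeros(d)           # normalization term z (fixed size)
    gate = 1.0 / (1.0 + np.exp(-beta))  # sigmoid(beta): weight on the memory read
    out = np.empty((seq_len, d_v))

    for start in range(0, seq_len, block_size):
        end = min(start + block_size, seq_len)
        qb, kb, vb = q[start:end], k[start:end], v[start:end]
        n = end - start

        # 1) Causal softmax attention restricted to the current block
        scores = qb @ kb.T / np.sqrt(d)
        future = np.triu(np.ones((n, n), dtype=bool), k=1)
        local = softmax(np.where(future, -np.inf, scores)) @ vb

        # 2) Read all earlier blocks from the compressive memory
        sq = elu_plus_one(qb)
        denom = np.maximum(sq @ norm, 1e-6)  # zero memory in the first block reads as 0
        mem = (sq @ memory) / denom[:, None]

        # 3) Gate the two sources together
        out[start:end] = gate * mem + (1.0 - gate) * local

        # 4) Recurrence: fold this block's keys/values into the memory
        sk = elu_plus_one(kb)
        memory += sk.T @ vb
        norm += sk.sum(axis=0)

    return out


rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 16, 8))
print(recurrent_local_attention(q, k, v, block_size=4).shape)  # (16, 8)
```

Because `memory` and `norm` have a fixed size, the state carried from block to block does not grow with sequence length, and the softmax attention inside each block is bounded by `block_size`.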
A lot of the inspiration for our ideas comes from the Transformer-XL paper.
For more context about our motivations, implementation details, and the theory behind the work, check out our technical overview on Medium.
This was built by: