tags: large-language-models, text-generation, context-extension, transformer-models, fine-tuning, long-contexts, natural-language-processing, context-window, context-length, nlp, llm, llm-context-window, llm-context-length
datasets: pg19, wikitext, custom-dataset
metrics: perplexity, accuracy
LongRoPE: Extending LLM Context Window Beyond 2 Million Tokens

Table of Contents

  • Introduction
  • Description
  • Model Architecture
  • Implementation Highlights
  • Usage
  • Advanced Usage
  • Results
  • Citation

Introduction

The paper introduces LongRoPE, a method to extend the context window of large language models (LLMs) beyond 2 million tokens.

The key ideas are:

  • Identify and exploit two forms of non-uniformities in positional embeddings to minimize information loss during interpolation. This enables 8x context extension without fine-tuning.

  • Use an efficient progressive extension strategy with 256k fine-tuning to reach a 2048k context, instead of directly fine-tuning on an extremely long context.

  • Adjust embeddings for shorter contexts to recover performance within the original window size.

The method is applied to LLaMA2 and Mistral. Experiments across various tasks demonstrate LongRoPE's effectiveness in maintaining performance from 4k to 2048k context lengths.

Description

The Transformer architecture struggles with the quadratic computational complexity of self-attention and its lack of generalization to token positions unseen at training time. To scale self-attention to large contexts, various methods have been proposed, such as RoPE, ALiBi, and attention sinks. Nonetheless, none of these solutions can effectively scale to contexts with millions of tokens while preserving the model's accuracy.

This paper presents a new technique, LongRoPE, which expands the context window of LLMs to over 2 million tokens.

LongRoPE utilizes a progressive extension strategy to attain a 2048k context window without necessitating direct fine-tuning on exceedingly lengthy texts, which are both rare and difficult to procure. This strategy initiates with a 256k extension on a pre-trained LLM, followed by fine-tuning at this length.

To address potential performance declines in the original (shorter) context window, LongRoPE further adjusts the RoPE rescale factors on the extended LLM, scaling down to 4k and 8k context windows on the 256k fine-tuned LLM using its search algorithm to minimize the information loss introduced by positional interpolation. During inference, for sequences shorter than 8k, RoPE is updated with these searched rescale factors.
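In practice, this amounts to keeping one set of rescale factors per target length and selecting the right set from the incoming sequence length at inference time. The sketch below illustrates that dispatch, assuming the model stores factors in dictionaries keyed by context size (as the progressive extension code later in this README does); the specific keys and thresholds are illustrative, not the repository's actual values.

def select_rope_rescale_factors(model, seq_len):
    # Assumed layout: model.lambda_factors / model.n_hat are dicts keyed by
    # context size, filled in by the rescale factor search.
    if seq_len <= 8 * 1024:
        key = "8k"        # factors re-searched for short contexts on the 256k model
    elif seq_len <= 256 * 1024:
        key = "256k"
    else:
        key = "2048k"
    return model.lambda_factors[key], model.n_hat[key]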

Testing across various LLMs and tasks requiring long contexts has validated LongRoPE's efficacy. The method maintains low perplexity across evaluation lengths from 4k to 2048k tokens, achieves over 90% accuracy in passkey retrieval, and delivers accuracy comparable to standard benchmarks within a 4096 context window.

Potential applications

  • Enable in-context learning with more examples to boost LLM reasoning
  • Build LLM agents that leverage longer context for tasks like dialog and question answering
  • Summarize very long documents by utilizing the full document context
  • Improve few-shot learning by providing more contextual examples to models
  • Enable long-term memory by utilizing the full context window

Model Architecture

An in-depth look at the structural modifications and their implications for model performance.

The LongRoPE model architecture is designed to extend the context window of large language models (LLMs) to over 2 million tokens, addressing the limitations of traditional Transformer architectures. The key innovation lies in the progressive extension strategy and the adjustment of positional embeddings.

Key components include:

  1. Rotary Position Encoding (RoPE):
import torch
import torch.nn as nn

class RoPEPositionalEncoding(nn.Module):
    """Rotary positional encoding with a configurable maximum length and base."""

    def __init__(self, d_model, max_len=1000000, base=10000):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.base = base
        # One inverse frequency per dimension; paired dimensions (2i, 2i+1) share a frequency.
        self.theta = torch.tensor([base ** (-2 * (i // 2) / d_model) for i in range(d_model)])

    def forward(self, positions):
        # positions: (..., seq_len) -> angles: (..., seq_len, d_model)
        angles = positions.unsqueeze(-1) * self.theta
        # Interleave cos/sin pairs, giving (..., seq_len, 2 * d_model).
        sin_cos = torch.stack([angles.cos(), angles.sin()], dim=-1)
        return sin_cos.view(*sin_cos.shape[:-2], -1)
  2. Non-uniform Interpolation (a quick shape check follows this list):
def non_uniform_interpolation(pos_embed, extension_ratio, lambda_factors, n_hat):
    """Rescale each cos/sin pair non-uniformly; the first n_hat positions stay untouched."""
    d_model = pos_embed.shape[-1]
    interpolated_pos = pos_embed.clone()
    # Positions below n_hat keep their original (non-interpolated) encoding.
    mask = torch.arange(pos_embed.shape[-2], device=pos_embed.device) < n_hat
    for i in range(d_model // 2):
        # Per-dimension rescale factor lambda_factors[i], applied only beyond n_hat.
        scale = torch.where(mask, torch.ones_like(pos_embed[..., 0]),
                            1 / (lambda_factors[i] * extension_ratio))
        interpolated_pos[..., 2 * i] *= scale
        interpolated_pos[..., 2 * i + 1] *= scale
    return interpolated_pos
  3. Progressive Extension Strategy:
def progressive_extension(model, data, base_length, target_length, population_size, num_mutations, num_crossovers, max_iterations):
    # Extend to 128k: search rescale factors, then fine-tune at that length.
    lambda_factors_128k, n_hat_128k = search_lambda_factors(model, data, 128000 / base_length, population_size, num_mutations, num_crossovers, max_iterations)
    model = fine_tune(model, data, 128000, lambda_factors_128k, n_hat_128k, steps=400)

    # Extend to 256k and fine-tune again.
    lambda_factors_256k, n_hat_256k = search_lambda_factors(model, data, 256000 / base_length, population_size, num_mutations, num_crossovers, max_iterations)
    model = fine_tune(model, data, 256000, lambda_factors_256k, n_hat_256k, steps=600)

    # Extend to the target length with a reduced search budget (no further fine-tuning).
    final_lambda_factors, final_n_hat = lambda_factors_256k, n_hat_256k
    if target_length > 256000:
        final_lambda_factors, final_n_hat = search_lambda_factors(model, data, target_length / base_length, population_size // 2, num_mutations // 2, num_crossovers // 2, max_iterations // 2)
        model.lambda_factors["2048k"] = final_lambda_factors
        model.n_hat["2048k"] = final_n_hat

    return model, final_lambda_factors, final_n_hat, lambda_factors_256k, n_hat_256k
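As a quick sanity check, the first two components can be chained directly. The sequence length, uniform starting factors, and n_hat value below are arbitrary illustration values, not searched ones.

import torch

rope = RoPEPositionalEncoding(d_model=512)
pos_embed = rope(torch.arange(1024, dtype=torch.float32))  # (1024, 1024): interleaved cos/sin pairs
lambda_factors = [1.0] * (pos_embed.shape[-1] // 2)        # one rescale factor per cos/sin pair
scaled = non_uniform_interpolation(pos_embed, extension_ratio=8.0,
                                   lambda_factors=lambda_factors, n_hat=128)
print(scaled.shape)  # torch.Size([1024, 1024]); the first 128 positions are left unscaled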

Progressive Extension Strategy

The architecture begins with a pre-trained LLM and extends its context window incrementally. Initially, the model is fine-tuned to handle a context length of 256k tokens. This progressive approach avoids the need for direct fine-tuning on extremely long texts, which are rare and computationally expensive to process. By gradually increasing the context length, the model can adapt more effectively to longer sequences.
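The progressive_extension snippet above relies on a search_lambda_factors routine, which the paper describes as an evolutionary search over the per-dimension rescale factors. The sketch below shows one way such a search could be structured; it is a minimal illustration, not the repository's actual implementation. In particular, the mutation ranges, the candidate n_hat values, the model.d_model attribute, and the evaluate_perplexity helper (assumed to score a candidate by perplexity on long sequences from data) are assumptions.

import random

def search_lambda_factors(model, data, extension_ratio, population_size,
                          num_mutations, num_crossovers, max_iterations):
    dim_pairs = model.d_model // 2

    def random_individual():
        # Start near uniform (PI-style) scaling and let the search perturb it.
        factors = [extension_ratio * random.uniform(0.9, 1.1) for _ in range(dim_pairs)]
        n_hat = random.choice([0, 1, 2, 4, 8, 16, 32, 64, 128])
        return factors, n_hat

    def mutate(individual):
        factors, n_hat = list(individual[0]), individual[1]
        for _ in range(num_mutations):
            factors[random.randrange(dim_pairs)] *= random.uniform(0.95, 1.05)
        return factors, n_hat

    def crossover(a, b):
        factors = [random.choice(pair) for pair in zip(a[0], b[0])]
        return factors, random.choice([a[1], b[1]])

    def fitness(individual):
        # Assumed helper: perplexity of `model` on `data` with RoPE rescaled
        # by this candidate at the extended length (lower is better).
        return evaluate_perplexity(model, data, *individual)

    population = [random_individual() for _ in range(population_size)]
    for _ in range(max_iterations):
        parents = sorted(population, key=fitness)[: max(2, population_size // 2)]
        children = [mutate(random.choice(parents)) for _ in range(num_mutations)]
        children += [crossover(*random.sample(parents, 2)) for _ in range(num_crossovers)]
        population = parents + children

    return min(population, key=fitness)  # (lambda_factors, n_hat)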

Positional Embeddings Adjustment

To maintain performance across varying context lengths, LongRoPE adjusts the Rotary Positional Embeddings (RoPE). The model identifies and exploits non-uniformities in positional embeddings to minimize information loss during interpolation. This allows for an 8x context extension without the need for fine-tuning. Additionally, the model employs a search algorithm to find optimal rescale factors for shorter contexts (e.g., 4k and 8k tokens) on the 256k fine-tuned LLM. These adjustments ensure that the model retains high performance even within the original context window size.
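Continuing the sketch above, the short-context readjustment can reuse the same search at the original window sizes and store one set of factors per length. The 4k/8k keys, the 4096 base window, and the search budget below are illustrative assumptions; the repository's actual procedure may differ.

def recover_short_context(model, data, base_length=4096, population_size=64,
                          num_mutations=16, num_crossovers=16, max_iterations=10):
    # Re-search rescale factors at the original (short) window sizes on the
    # 256k fine-tuned model, then keep one set of factors per context length.
    for length, key in [(4096, "4k"), (8192, "8k")]:
        factors, n_hat = search_lambda_factors(
            model, data, length / base_length,
            population_size, num_mutations, num_crossovers, max_iterations)
        model.lambda_factors[key] = factors
        model.n_hat[key] = n_hat
    return model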

Structural Modifications

The architecture incorporates several structural modifications to handle the increased context length efficiently:

  • Layer Scaling: Adjustments are made to the scaling of layers to ensure stability and performance as the context window grows.

  • Memory Management: Efficient memory management techniques are employed to handle the large context sizes without overwhelming the system resources.

  • Attention Mechanisms: Enhanced attention mechanisms are integrated to ensure that the model can focus on relevant parts of the input sequence, even with the extended context.

  • Token-wise Attention: Token-wise attention mechanisms are introduced to capture the contextual relationships between tokens, allowing the model to better understand the semantic meaning of the input.

Performance and Applications

Experiments demonstrate that LongRoPE maintains low perplexity across evaluation lengths from 4k to 2048k tokens and achieves high accuracy in tasks requiring long contexts. This makes it suitable for various applications, including in-context learning, long document summarization, and few-shot learning.

For more detailed information, please refer to the full paper (arXiv:2402.13753).

Implementation Highlights

Insights into the coding and operational specifics that enable LongRoPE's functionality; the key components are illustrated with the code snippets in the Model Architecture section above.

For more detailed information, please refer to the paper.

Usage

Comprehensive examples demonstrating how to leverage LongRoPE for various applications, from text analysis to generating extensive documents.

# Example usage
import torch

# LongRoPEModel and load_data come from this repository's source.
data_path = "path/to/your/dataset"
d_model = 512
n_heads = 8
num_layers = 6
base_length = 4096
target_length = 2048 * 1024

data = load_data(data_path)
model = LongRoPEModel(d_model, n_heads, num_layers, base_length)
model = model.extend_context(data, target_length)

# Random tensor standing in for model inputs; for a real run, pass tokenized inputs instead.
input_ids = torch.randn(2, target_length, d_model)
output = model(input_ids)
print(output.shape)  # Expected shape: (batch_size, target_length, d_model)

Advanced Usage

Custom Dataset Training

To train on a custom dataset:

  1. Prepare your dataset in a format compatible with the datasets library.
  2. Implement a custom preprocess_data function if needed.
  3. Use the extend_context method with your custom data (a sketch follows below).
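A sketch of steps 1–3 is shown below. The load_dataset call is the standard Hugging Face datasets API; the tokenizer, the sequence-packing logic in preprocess_data, and the file path are illustrative assumptions, not the repository's code.

from datasets import load_dataset

# 1. Load a custom corpus with the `datasets` library
#    (a local text file here; any dataset exposing a "text" column works).
raw_dataset = load_dataset("text", data_files={"train": "path/to/your/corpus.txt"})["train"]

# 2. Hypothetical preprocess_data: tokenize and pack the corpus into fixed-length chunks.
def preprocess_data(dataset, tokenizer, seq_len):
    token_ids = []
    for record in dataset:
        token_ids.extend(tokenizer.encode(record["text"]))
    return [token_ids[i:i + seq_len] for i in range(0, len(token_ids) - seq_len + 1, seq_len)]

# 3. Extend the context window on the preprocessed data,
#    mirroring the extend_context call from the usage example above.
# data = preprocess_data(raw_dataset, tokenizer, seq_len=4096)
# model = model.extend_context(data, target_length=2048 * 1024)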

Hyperparameter Tuning

LongRoPE's performance can be sensitive to hyperparameters. Key parameters to tune include:

  • population_size, num_mutations, and num_crossovers in the lambda factor search
  • Learning rate and scheduler parameters for fine-tuning
  • gradient_accumulation_steps for training stability
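These knobs map onto the arguments of progressive_extension and fine_tune shown earlier. A hypothetical starting configuration is sketched below; the values are illustrative, not tuned defaults.

# Illustrative starting points; adjust per model size and hardware.
search_config = {
    "population_size": 64,       # candidates per generation in the lambda factor search
    "num_mutations": 16,
    "num_crossovers": 16,
    "max_iterations": 10,
}
training_config = {
    "learning_rate": 2e-5,
    "lr_scheduler": "cosine",
    "gradient_accumulation_steps": 8,   # increase for stability on long sequences
}

# model, *_ = progressive_extension(model, data, base_length=4096,
#                                   target_length=2048 * 1024, **search_config)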

Results

My implementation of LongRoPE achieves the following results:

  1. Perplexity:

    • 4k context: X.XX
    • 128k context: X.XX
    • 2048k context: X.XX
  2. Passkey Retrieval Accuracy:

    • 4k context: XX%
    • 128k context: XX%
    • 2048k context: XX%
  3. Accuracy:

    • 4k context: XX%
    • 128k context: XX%
    • 2048k context: XX%
  4. Comparison with baseline models:

Citation

@article{ding2024longrope,
  title={LongRoPE: Extending LLM Context Window Beyond 2 Million Tokens},
  author={Ding, Yiran and Zhang, Li Lyna and Zhang, Chengruidong and Xu, Yuanyuan and Shang, Ning and Xu, Jiahang and Yang, Fan and Yang, Mao},
  journal={arXiv preprint arXiv:2402.13753},
  year={2024}
}

Note: This repository is a work in progress and is not yet ready for production use. Please refer to the paper for more details.