
by Stella Biderman, Sid Black, Charles Foster, Leo Gao, Eric Hallahan, Horace He, Ben Wang, and Phil Wang

<br>

## TL;DR:
Rotary Positional Embedding (RoPE) is a new type of position encoding that unifies absolute and relative approaches. Developed by Jianlin Su in a series of blog posts earlier this year [12, 13], it has already garnered widespread interest in some Chinese NLP circles. However, this development is not widely known to the global community, in large part due to the lack of English-language resources. This post walks through the method as we understand it, with the goal of bringing it to the attention of the wider academic community. In general we have found that, across a large suite of setups including regular, linear, and local self-attention, it **either matches or surpasses all other methods currently available for injecting positional information into transformers.**

<br>

## What's the Problem?

Since Vaswani et al., 2017 [16] there have been many schemes introduced for encoding positional information in transformers. When applying self-attention to a given domain, the choice of position encoding typically involves tradeoffs between simplicity, flexibility, and efficiency. For example, learned absolute positional encodings are very simple, but may not generalize and are not always particularly meaningful due to the common practice [1, 3, 9, 15] of packing short sentences and phrases together in a single context and breaking up sentences across contexts. Relative positional encodings, on the other hand, directly encode the offset between tokens rather than their absolute positions.

Another major limitation of existing methods is that they do not work with efficient attention variants.

A principled, easy to implement, and generally-applicable method for relative position encoding---one that works for both vanilla and “efficient” attention---is of great interest. Rotary Positional Embedding (RoPE) is designed to address this need.

<br>

## What's the Solution?

In this section we introduce and derive the rotary positional embedding. We begin by discussing the intuition before presenting a full derivation.
<br>

### Intuition

We would like to find a positional encoding function $f(\mathbf{x}, \ell)$ for an item $\mathbf{x}$ and its position $\ell$ such that, for two items $\mathbf{q}$ and $\mathbf{k}$ at positions $m$ and $n$, the inner product between $f(\mathbf{q}, m)$ and $f(\mathbf{k}, n)$ is sensitive only to the values of $\mathbf{q}$, $\mathbf{k}$, and their relative position $m-n$. This is related in spirit to the kernel trick: we are searching for a feature map such that its kernel has certain properties. A key piece of information is the geometric definition of the dot product between Euclidean vectors: $\mathbf{q} \cdot \mathbf{k} = \lVert\mathbf{q}\rVert \lVert\mathbf{k}\rVert \cos(\theta_{qk})$.

The following is an example illustrating the core idea of RoPE; a more rigorous derivation is presented in the next section. Viewing each two-dimensional pair of query/key coordinates as a complex number, we can encode position $m$ by rotating by the angle $m\theta$ for some fixed $\theta$, i.e. $\mathrm{RoPE}(x, m) = x e^{im\theta}$. The inner product between a rotated query and a rotated key then depends on the positions only through their difference:

\begin{align}
\langle \mathrm{RoPE}(q_j, m), \mathrm{RoPE}(k_j, n)\rangle &= \langle q_j e^{im\theta}, k_j e^{in\theta} \rangle \\
&= q_j \overline{k_j} e^{i(m - n)\theta} \\
&= \mathrm{RoPE}(q_j \overline{k_j}, m - n)
\end{align}
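To make this concrete, here is a small numerical check of the identity above. This is a toy sketch: the angle $\theta$ and the two-dimensional vectors are chosen arbitrarily, and each 2-D pair is represented as a complex number.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = 0.3                              # arbitrary rotation frequency for this toy example
q = complex(*rng.standard_normal(2))     # a 2-D query, viewed as a complex number
k = complex(*rng.standard_normal(2))     # a 2-D key, viewed as a complex number

def rope(x, m):
    # Encoding position m amounts to rotating by the angle m * theta.
    return x * np.exp(1j * m * theta)

# The (Hermitian) inner product depends only on the offset m - n:
# every pair below has m - n == 3, so all three printed values agree.
for m, n in [(5, 2), (105, 102), (7, 4)]:
    print(rope(q, m) * np.conj(rope(k, n)))
```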

<br>

### Visual Intuition
<figure>
<br>
<iframe id="waveplate-animation" src="/images/blog/rotary-embeddings/waveplate.html" class="auto" style="border-width:0; display: block;
margin-right: auto;
margin-left: auto;
height: 400px;
width: 800px" loading="lazy" ></iframe>
<figcaption>A quarter-waveplate can change the polarization of an electromagnetic wave.</figcaption>
<br>
</figure>

To see how relative position might be preserved in this transformation, we can look to an analogous situation in classical electrodynamics.
We imagine a linearly polarized electromagnetic wave that is sent through a quarter-waveplate, which alters the polarization of the wave as it propagates.

As the wave travels through the waveplate, we can see how the magnitude of the wave is preserved. We can also better see how the relative position may be encoded as the angle between subsequent timesteps: the angle between timesteps, and therefore distance along the axis of travel, is constant. This means the positional information must be orthogonal to the amplitude in the modulated wave.

<br>

### Derivation

We begin with absolute positional information: for each token, we know where it is in the sequence. However, dot products (and therefore attention) do not preserve absolute positional information, so if we encode that information in the absolute positions of the embeddings, we will lose a significant amount of it. On the other hand, dot products do preserve relative position, so if we can encode the absolute positional information into the token embeddings in a way that only leverages relative positional information, it will be preserved by the attention function.
The end result is that RoPE encodes the position of a token by rotating each two-dimensional pair of coordinates $(x_{2j}, x_{2j+1})$ at position $m$ by the matrix $M_j=\begin{pmatrix}\cos m\theta_j & -\sin m\theta_j \\ \sin m\theta_j & \cos m\theta_j\end{pmatrix}$, a plain two-dimensional rotation through the angle $m\theta_j$.
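As a minimal sketch of what this rotation looks like in the full embedding dimension, pairs of coordinates are rotated by position-dependent angles, and the resulting dot product is unchanged when both positions are shifted by the same amount. The frequency schedule $\theta_j = 10000^{-2j/d}$ below is the conventional choice and is assumed purely for illustration.

```python
import torch

def rope_rotate(x, m, theta):
    # Rotate each consecutive pair (x[2j], x[2j+1]) by the angle m * theta[j].
    x1, x2 = x[0::2], x[1::2]
    cos, sin = torch.cos(m * theta), torch.sin(m * theta)
    out = torch.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out

d = 8
theta = 10000 ** (-torch.arange(0, d, 2).float() / d)   # one theta_j per pair of coordinates
q, k = torch.randn(d), torch.randn(d)

# The attention score depends only on the offset m - n: shifting both
# positions by 100 leaves it unchanged (up to floating point error).
print(torch.dot(rope_rotate(q, 5, theta), rope_rotate(k, 2, theta)))
print(torch.dot(rope_rotate(q, 105, theta), rope_rotate(k, 102, theta)))
```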
A response many of us at EleutherAI had when first coming across this was "how does this differ from sinusoidal embeddings?", so we feel it is worth discussing the comparison. There are two ways that rotary embeddings differ from sinusoidal embeddings (contrasted in the short sketch below):
1. Sinusoidal embeddings apply to each coordinate individually, while rotary embeddings mix pairs of coordinates.
2. Sinusoidal embeddings add a $\cos(m\theta)$ or $\sin(m\theta)$ term, while rotary embeddings use a multiplicative factor.
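To sketch the contrast concretely (again assuming the conventional frequency schedule and a single position $m = 3$; this is illustrative rather than either method's reference implementation):

```python
import torch

d, m = 8, 3
theta = 10000 ** (-torch.arange(0, d, 2).float() / d)   # assumed frequency schedule
x = torch.randn(d)
cos, sin = torch.cos(m * theta), torch.sin(m * theta)

# Sinusoidal: an additive, per-coordinate offset computed from the position alone.
pe = torch.stack([sin, cos], dim=-1).flatten()
x_sinusoidal = x + pe

# Rotary: pairs of coordinates are rotated together by a position-dependent angle.
x1, x2 = x[0::2], x[1::2]
x_rotary = torch.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).flatten()
```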
<br><br>

## Okay, What About in Practice?

In our GPT-NeoX codebase, the rotation is applied to the query and key tensors using precomputed `cos` and `sin` tables:

```python
def apply_rotary_pos_emb(q, k, cos, sin):
    # Rotate the queries and keys by the position-dependent angles.
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
```
**N.B:** The layout of the queries and keys in GPT-NeoX, following Megatron, is `[seq, batch, heads, hdim]`, in order to avoid memory-intensive transpose operations. The code will need to be modified to work with the conventional layout of `[batch, seq, heads, hdim]`.
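For context, the snippet above relies on a `rotate_half` helper and precomputed sin/cos tables. The sketch below shows one way those pieces might look for the `[seq, batch, heads, hdim]` layout noted above; the actual GPT-NeoX code differs in details such as caching and layout handling, so treat this as illustrative rather than the library's implementation.

```python
import torch

def rotate_half(x):
    # (x1, x2) -> (-x2, x1), applied to the two halves of the last dimension.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def build_sin_cos(seq_len, dim, base=10000):
    # One frequency per pair of channels; the table is repeated so that channel j
    # and channel j + dim // 2 share a frequency, matching rotate_half's pairing.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(seq_len).float(), inv_freq)    # [seq, dim // 2]
    angles = torch.cat((angles, angles), dim=-1)                     # [seq, dim]
    # Add broadcast dimensions for batch and heads in the [seq, batch, heads, hdim] layout.
    return angles.sin()[:, None, None, :], angles.cos()[:, None, None, :]

# Toy usage:
q = torch.randn(2048, 1, 12, 64)
k = torch.randn(2048, 1, 12, 64)
sin, cos = build_sin_cos(2048, 64)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
```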
<br>

### Experiments

We have found rotary embeddings to be effective for many varieties of attention.

**Comparison against other PEs for Global attention:** Using our GPT-NeoX codebase, we conducted [comparisons](https://wandb.ai/eleutherai/neox/reports/Rotary-Test-3--Vmlldzo2MTIwMDM) of rotary embeddings with the learned absolute positional embeddings used in GPT-3 [1] and the learned relative positional embeddings (henceforth RPE) used in T5 [10]. Comparisons were done using 125M parameter models with the same hyperparameters as the equally-sized model from [1]. Models were trained on [OpenWebText2](https://www.eleuther.ai/projects/open-web-text2/), a large and diverse dataset of online text. We see faster convergence of training and validation curves and a lower overall validation loss with a minimal decrease in throughput.

<figure>
<img src="/images/blog/rotary-embeddings/rope-learned-rpe.png" alt="GPT-NeoX experiments" style="width:800px">
<figcaption>OWT2 validation loss with 150M parameter models in GPT-NeoX</figcaption>
</figure>


<figure>
<center>
<figcaption><b>Final validation loss / ppl scores on OWT2 validation set at 55k steps (~30B tokens):</b></figcaption>
<br>
<table style="width:50%">
<tr>
<th><b>Type</b></th>
<th>OWT2 Loss</th>
<th>OWT2 Ppl.</th>
</tr>
<tr>
<td><b>Learned Absolute</b></td>
<td>2.809</td>
<td>16.59</td>
</tr>
<tr>
<td><b>T5 RPE</b></td>
<td>2.801</td>
<td>16.46</td>
</tr>
<tr>
<td><b>Rotary</b></td>
<td>2.759</td>
<td>15.78</td>
</tr>
</table>
</center>
</figure>
<br>


**Billion+ parameter models:** We also conducted larger-scale experiments with the [mesh-transformer-jax](https://github.com/kingoflolz/mesh-transformer-jax) codebase and 1.4B parameter models, against baselines of learned absolute position embeddings and T5 RPE. Hyperparameters similar to those of GPT-3's 1.3B model were used, with the Pile [3] as the dataset. A similar improvement in convergence speed over learned absolute embeddings was observed (~30%), and a smaller improvement (10-20%) was still seen over the T5 relative position encoding, demonstrating scalability into the billion-parameter regime. For full details, see [here](https://wandb.ai/eleutherai/mesh-transformer-jax/reports/Position-encoding-shootout--Vmlldzo2MTg2MzY).

<figure>
<img src="/images/blog/rotary-embeddings/jax-experiments.png" alt="Jax experiments" style="width:800px">
<figcaption>Pile validation loss with 1.5B parameter models</figcaption>
</figure>



<figure>
<center>
<figcaption><b>Final validation loss / ppl scores on Pile validation set at 8k steps (~5B tokens):</b></figcaption>
<br>
<table style="width:50%">
<tr>
<th><b>Type</b></th>
<th>Pile Loss</th>
<th>Pile Ppl.</th>
</tr>
<tr>
<td><b>Learned Absolute</b></td>
<td>2.24</td>
<td>9.393</td>
</tr>
<tr>
<td><b>T5 RPE</b></td>
<td>2.223</td>
<td>9.234</td>
</tr>
<tr>
<td><b>Rotary</b></td>
<td>2.173</td>
<td>8.784</td>
</tr>
</table>
</center>
</figure>

**Comparison against learned absolute for Performer:** Performer [2] is an example of an alternative attention mechanism designed to avoid quadratic bottlenecks with respect to sequence lengths. We ran small scale tests of Performer on enwiki8, for 8 layer char-based transformers with 512 dimensions and 8 heads. [These tests indicated](https://wandb.ai/lucidrains/eleuther-blogpost/reports/performer-rotary--Vmlldzo2MTgyNDg) that substituting rotary embeddings into the Performer leads to stark decreases in validation loss and to rapid convergence. Though these improvements do not close the gap between efficient and quadratic attention mechanisms, such a significant improvement makes mechanisms like Performer more attractive.
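Conceptually the change is small: the rotation is applied to the queries and keys before they enter the kernel feature map, and the rest of the linear-attention machinery is untouched. The sketch below is a simplified non-causal version with a placeholder `feature_map` (standing in for, e.g., FAVOR+ random features) that reuses `apply_rotary_pos_emb` from above; it is not the code used in these experiments.

```python
import torch

def linear_attention_with_rope(q, k, v, cos, sin, feature_map):
    # q, k, v: [batch, seq, dim]; cos/sin broadcastable against q and k;
    # feature_map: [batch, seq, dim] -> [batch, seq, features].
    q, k = apply_rotary_pos_emb(q, k, cos, sin)        # the only RoPE-specific line
    q_prime, k_prime = feature_map(q), feature_map(k)
    kv = torch.einsum("bnr,bnd->brd", k_prime, v)      # O(N) summaries, no N x N matrix
    denom = torch.einsum("bnr,br->bn", q_prime, k_prime.sum(dim=1))
    return torch.einsum("bnr,brd->bnd", q_prime, kv) / denom.unsqueeze(-1)
```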

In smaller scale tests, we have also put RoPE head to head against other alternatives including the relative position method of Shaw et al. [11], TUPE [5], and position-infused attention [8], seeing positive results across the board.


<figure>
<img src="/images/blog/rotary-embeddings/performer.png" alt="x-transformers experiments" style="width:800px">
<figcaption>Enwik8 validation/train loss with Performer</figcaption>
</figure>
<br>

### Runtime

In general, we find that the runtime cost of rotary embeddings is fairly negligible. With the above implementation, applying the rotary embeddings is naively about 4-5x the cost of applying additive positional embeddings. With a fusing compiler such as TorchScript, the runtime can be reduced to about 2-2.5x that of additive positional embeddings. Concretely, for query and key tensors of shape $[2048, 16, 12, 64]$, applying rotary embeddings takes 5.3 milliseconds, while applying additive positional embeddings takes 2.1 milliseconds.

Unlike standard positional embeddings, however, rotary embeddings must be applied at every layer. Since large transformer models are typically dominated by matrix multiplies, we find that the overall overhead remains negligible. With fusion, we find that rotary embeddings impose a 1-3% overhead across a range of transformer sizes.
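As a rough illustration of how such a measurement might be taken (a sketch that assumes a CUDA device and the `apply_rotary_pos_emb` defined above; absolute numbers will vary with hardware and PyTorch version):

```python
import time
import torch

q = torch.randn(2048, 16, 12, 64, device="cuda")
k = torch.randn(2048, 16, 12, 64, device="cuda")
cos = torch.randn(2048, 1, 1, 64, device="cuda")
sin = torch.randn(2048, 1, 1, 64, device="cuda")

# Tracing produces a graph that the JIT can fuse into fewer elementwise kernels.
fused_apply = torch.jit.trace(apply_rotary_pos_emb, (q, k, cos, sin))

def bench(fn, iters=100):
    for _ in range(10):                                 # warmup (and JIT compilation)
        fn(q, k, cos, sin)
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn(q, k, cos, sin)
    torch.cuda.synchronize()
    return (time.time() - start) / iters * 1e3          # milliseconds per call

print(f"eager: {bench(apply_rotary_pos_emb):.2f} ms")
print(f"fused: {bench(fused_apply):.2f} ms")
```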

<br> <br>

## Conclusion
Rotary embeddings make it possible to implement relative attention in a straightforward and efficient manner. We are excited to read the upcoming rotary positional embeddings paper from the original authors and the work it inspires. Simple improvements to the transformer architecture that carry over robustly between different types of self-attention are few and far between [6].