inv_freq seems not calculated right #24

Closed
dwzhu-pku opened this issue Sep 9, 2023 · 9 comments


@dwzhu-pku

Hello, I'm thrilled to see that linear and NTK interpolation have been elegantly combined to create a much stronger interpolation strategy, YaRN. However, while going through the code in modeling_llama.py, I find myself a bit confused by the calculation of inv_freq, particularly at line 398.

According to the YaRN paper, in equation 23, it is stated as follows:

$$ \lambda_d'=(1-\gamma_d)s\lambda_d+\gamma_d\lambda_d $$

Consequently, we can derive:

$$ h(\theta_d) = \frac{2\pi}{\lambda_d'} = \frac{2\pi}{(1-\gamma_d)s\lambda_d+\gamma_d\lambda_d} = \frac{\theta_d}{(1-\gamma_d)s+\gamma_d} $$

However, in the paper, the calculation of $h(\theta_d)$ in equation 25 is different:

$$ h(\theta_d) = \left(\frac{(1-\gamma_d)}{s}+\gamma_d\right)\theta_d \neq \frac{2\pi}{\lambda_d'} $$

Hence, I think there might be some problem with equation 25 and also with line 398. Perhaps we can revise the yarn function as follows, since I've empirically found that this fix can further enhance performance:

```python
def revised_yarn(self, device):
    # Base RoPE inverse frequencies: theta_d = base^(-2d/|D|)
    inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))

    # gamma_d ramp between the beta_fast/beta_slow correction boundaries
    # (1 in the extrapolation region, 0 in the interpolation region)
    low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, self.dim, self.base, self.original_max_position_embeddings)
    inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, self.dim // 2).float().to(device)) * self.extrapolation_factor

    # Mix in the wavelength domain: lambda'_d = lambda_d * ((1 - gamma_d) * s + gamma_d)
    inv_freq = inv_freq / ((1 - inv_freq_mask) * self.scale + inv_freq_mask)

    self.register_buffer("inv_freq", inv_freq, persistent=False)
    self.mscale = float(_yarn_get_mscale(self.scale) * self.attn_factor)
```
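For context, here is a minimal standalone sketch of how an inv_freq buffer built this way feeds the usual RoPE cos/sin cache; the dim/base/scale values and the linear gamma ramp below are made-up placeholders, not the repository's actual settings or helpers:

```python
import torch

# Toy stand-ins for the class attributes used above (hypothetical values).
dim, base, scale, seq_len = 128, 10000.0, 16.0, 8
gamma = torch.linspace(1.0, 0.0, dim // 2)  # placeholder ramp; the repo derives this from beta_fast/beta_slow

theta = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
inv_freq = theta / ((1 - gamma) * scale + gamma)  # the revised wavelength-domain scaling above

# Standard RoPE cos/sin cache built from inv_freq
t = torch.arange(seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos_cached, sin_cached = emb.cos(), emb.sin()
print(cos_cached.shape)  # torch.Size([8, 128])
```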
@cebtenzzre
Contributor

So in the code (and equation 25), the scaling and mixing are applied to theta, but according to equation 23 (which is probably incorrect) the scaling is applied to the wavelength, lambda, so the division happens in a different place.

What data do you have on the performance of your YaRN implementation?

@bloc97
Collaborator

bloc97 commented Sep 9, 2023

Hello! That's an interesting observation... Maybe my derivation was wrong, but shouldn't a "stretching" of the wavelength be exactly equivalent to a "compression" of the frequency? Both $s$ and $\gamma$ are multiplicative, so I assumed they would be equivalent the other way around, with only the $s$ value needing to be inverted...

Let's pick a dimension $d$ where we had to stretch the wavelength $\lambda$ by 2. No matter how the derivation is made, my intuition says that we should divide the frequency $\theta$ by 2, right?

Maybe I'm missing something here, please correct me if I'm wrong...

@cebtenzzre
Contributor

cebtenzzre commented Sep 9, 2023

> shouldn't a "stretching" of the wavelength be exactly equivalent to a "compression" of the frequency

That is certainly true, and for plain interpolation, multiplying the wavelength by the scale is the same as dividing theta by the scale. But the difference is in the mixing. This is what the Python implementation does:

$$h(\theta_d) = \theta_d \times \left(\frac{(1 - \gamma_d)}{s} + \gamma_d\right)$$ $$\large{\lambda'_d = \frac{\lambda_d}{\frac{(1 - \gamma_d)}{s} + \gamma_d}}$$

And this is what equation 23 in the paper implies:

$$h(\theta_d) = \frac{\theta_d}{(1 - \gamma_d) \times s + \gamma_d}$$ $$\lambda'_d = \lambda_d \times ((1 - \gamma_d) \times s + \gamma_d)$$
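To make the difference concrete, here is a small numerical sketch (with an arbitrary scale of 16 and a uniform grid of $\gamma_d$ values); the two wavelength multipliers coincide only at $\gamma_d = 0$ and $\gamma_d = 1$:

```python
import torch

s = 16.0
gamma = torch.linspace(0.0, 1.0, 9)

mult_code = 1.0 / ((1 - gamma) / s + gamma)   # lambda'_d / lambda_d implied by the Python implementation / eq. 25
mult_eq23 = (1 - gamma) * s + gamma           # lambda'_d / lambda_d implied by eq. 23

for g, a, b in zip(gamma.tolist(), mult_code.tolist(), mult_eq23.tolist()):
    print(f"gamma={g:.3f}  code: x{a:6.3f}  eq23: x{b:6.3f}")
```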

@bloc97
Collaborator

bloc97 commented Sep 9, 2023

Yes, but because $\gamma$ is symmetric when flipped (it is a ramp function with range $[0,1]$), shouldn't the two equations you wrote be equal? Do the two different Python implementations give different results?

@cebtenzzre
Contributor

cebtenzzre commented Sep 9, 2023

No, they are not equivalent. Substituting e.g. $\gamma_d = 0.2$ and $s = 2$, we get this for equation 1:

$\large{\lambda'_d = \Large{\frac{\lambda_d}{ (1 - \gamma_d)/s + \gamma_d }}}$
$\large{\lambda'_d = \Large{\frac{\lambda_d}{ (1 - 0.2)/2 + 0.2 }}}$
$\large{\lambda'_d = \Large{\frac{\lambda_d}{ 0.4 + 0.2 }}}$
$\large{\lambda'_d = \Large{\frac{\lambda_d}{0.6}}}$
$\large{\lambda'_d = \lambda_d \times 1.\overline{6}}$

But we get this for equation 2:
$\large{\lambda'_d = \lambda_d \times ((1 - \gamma_d) \times s + \gamma_d)}$
$\large{\lambda'_d = \lambda_d \times ((1 - 0.2) \times 2 + 0.2)}$
$\large{\lambda'_d = \lambda_d \times (1.6 + 0.2)}$
$\large{\lambda'_d = \lambda_d \times 1.8}$

To put it another way: the inverse of a sum is not equal to the sum of the inverses, and that is essentially what is going on here. Both formulas blend the scale factors $s$ and $1$ with weights $(1-\gamma_d)$ and $\gamma_d$; with $\gamma_d = 0.5$, I believe one is basically the harmonic mean of $s$ and $1$ and the other is the arithmetic mean.
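A two-line check of the substitution above (plain Python, nothing repo-specific):

```python
gamma, s = 0.2, 2.0
print(1.0 / ((1 - gamma) / s + gamma))  # 1.666..., the weighted harmonic mean of s and 1
print((1 - gamma) * s + gamma)          # 1.8, the weighted arithmetic mean of s and 1
```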

@bloc97
Collaborator

bloc97 commented Sep 9, 2023

Thanks, it had never occurred to me that mixing frequencies vs. mixing wavelengths with a ramp function would give such a difference. I will take some time and correct this mistake...

@bloc97
Collaborator

bloc97 commented Sep 10, 2023

Here are two contour plots from Wolfram; it is clear that the ramp function is not equal under the two scenarios:

[two contour plots comparing the frequency-domain and wavelength-domain mixing of the ramp]

For mathematical correctness' sake, I will fix this mistake in v2 of the preprint. However, please do understand that this ramp is arbitrary and did not have any robust hypothesis supporting its existence. We could have just used a Heaviside step function instead of the linear ramp. (However, since neural networks might not like that discontinuity, we chose the linear ramp.)

@dwzhu-pku
Author

dwzhu-pku commented Sep 10, 2023

Thanks to cebtenzzre and bloc97 for your insightful discussion! I conducted tests on the two implementations mentioned earlier, the original YaRN and the revised YaRN, using the llama-7b model (version 1, not llama2) on GovReport. I configured the scaling factor to 48 (i.e., stretching to support 96k) and sampled 50 texts longer than 64k from GovReport for evaluation. Note that fine-tuning was not performed, so these results should be easily reproducible.

I varied the input length from 1k to 64k and employed the stretched model to calculate perplexity. Since this is well within the supported context window of 96k, I did not use a sliding window, for simplicity. The results are presented in the plot below. Interestingly, in this scenario, the original version performs relatively well when the input length is below 8k, whereas the revised YaRN exhibits advantages as the input length increases. This phenomenon appears to be non-coincidental.

As previously discussed, in the revised version, the wavelength $\lambda_d'$ is defined as:

$$ \lambda_d'=\lambda_d((1-\gamma_d)s+\gamma_d) $$

On the other hand, in the original version, the wavelength $\lambda_d'$ is calculated as:

$$ \lambda_d'=\lambda_d\frac{1}{\frac{1-\gamma_d}{s}+\gamma_d} $$

Here are some observations:

  1. The original YaRN stretches the wavelength less than the revised one. This may explain why the original performs better when the input sequence is short but worse as it gets longer:

$$ ((1-\gamma_d)s+\gamma_d) \ge \frac{1}{\frac{1-\gamma_d}{s}+\gamma_d} \\ \Leftrightarrow (1-s)^2\gamma_d \ge (1-s)^2\gamma_d^2 \Leftrightarrow 1 \ge \gamma_d $$

  2. Although these two formulations can be viewed as different ramp functions, the major problem with the original one is that, as the scaling factor $s$ becomes very large, its impact on the wavelength weakens. The reason is that $\frac{1-\gamma_d}{s}\rightarrow 0$ as $s$ becomes very large, for instance 48. Consequently, $\lambda_d' \rightarrow \frac{\lambda_d}{\gamma_d}$, as the sketch below illustrates.
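For observation 2, a quick sketch (with an arbitrary $\gamma_d = 0.2$) shows the original multiplier saturating at $1/\gamma_d = 5$ while the revised one keeps growing with $s$:

```python
gamma = 0.2
for s in (2, 8, 48, 1000):
    original = 1.0 / ((1 - gamma) / s + gamma)  # saturates at 1 / gamma = 5
    revised = (1 - gamma) * s + gamma           # grows roughly linearly with s
    print(f"s={s:5d}  original: x{original:6.2f}  revised: x{revised:8.1f}")
```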

@bloc97
Collaborator

bloc97 commented Sep 10, 2023

> 2. Although these two formulations can be viewed as different ramp functions, the major problem with the original one is that, as the scaling factor $s$ becomes very large, its impact on the wavelength weakens. The reason is that $\frac{1-\gamma_d}{s}\rightarrow 0$ as $s$ becomes very large, for instance 48. Consequently, $\lambda_d' \rightarrow \frac{\lambda_d}{\gamma_d}$.

Interestingly, we are also currently investigating the special case where $s \to \infty$. In this case, both ramps collapse to a Heaviside step function whose center $\eta$ lies between $\alpha$ and $\beta$:

$$\alpha \le \eta \le \beta$$

Consequently, dimensions where $r<\eta$ would have infinite wavelength (equivalent to NoPE), and dimensions where $r \geq \eta$ would be unscaled (equivalent to standard RoPE).

This "Truncated" RoPE or "NoRoPE" embedding scheme, applied during pretraining, could potentially allow very long extrapolation capabilities; we are looking at testing this hypothesis in the near future.
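Such a scheme is not implemented anywhere in this thread; the following is only a minimal sketch of the idea, with a hypothetical `eta_index` marking where the step function switches:

```python
import torch

def truncated_rope_inv_freq(dim, base=10000.0, eta_index=16):
    """Hypothetical sketch of 'truncated' RoPE: high-frequency dimensions (low index,
    many rotations, r >= eta) keep their standard RoPE frequency, while low-frequency
    dimensions (r < eta) get zero frequency, i.e. infinite wavelength (NoPE-like)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    step = (torch.arange(dim // 2) < eta_index).float()  # Heaviside-style mask over dimensions
    return inv_freq * step

print(truncated_rope_inv_freq(128))
```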
