
Implement kv-caching, add more variants of extrapolation (CFG) and interpolation methods #55

Merged: 10 commits merged into EleutherAI:dev on Nov 3, 2023

Conversation

@honglu2875 (Contributor) commented on Oct 31, 2023:

  • Implement kv-caching
  • Add a few CFG variants (varying CFG strength, negative prompt, ...)
  • Optimizations (wrapped decoding in with torch.inference_mode():; a combined sketch of these pieces follows below)
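
For readers skimming the PR, here is a minimal sketch of how these three pieces typically fit together in a decoding loop. It is illustrative only: the model(ids, past_kv=...) -> (logits, past_kv) interface, the generate_cfg name, and the cache layout are assumptions, not this repo's actual API.

```python
import torch

@torch.inference_mode()  # no autograd bookkeeping anywhere inside decoding
def generate_cfg(model, prompt_ids, neg_prompt_ids, max_new_tokens, cfg_scale=1.5):
    """Greedy decoding with a kv-cache and classifier-free guidance (CFG).

    kv-caching: each layer's keys/values are stored in past_kv, so after the
    first step only the newest token is fed through the model (O(n) work per
    generated token instead of re-encoding the whole prefix every step).
    CFG with a negative prompt: extrapolate away from the negative branch,
    logits = neg + cfg_scale * (pos - neg).
    """
    pos_ids, neg_ids = prompt_ids, neg_prompt_ids
    pos_kv = neg_kv = None
    for step in range(max_new_tokens):
        pos_in = pos_ids if step == 0 else pos_ids[:, -1:]
        neg_in = neg_ids if step == 0 else neg_ids[:, -1:]
        pos_logits, pos_kv = model(pos_in, past_kv=pos_kv)
        neg_logits, neg_kv = model(neg_in, past_kv=neg_kv)
        logits = neg_logits[:, -1] + cfg_scale * (pos_logits[:, -1] - neg_logits[:, -1])
        next_id = logits.argmax(dim=-1, keepdim=True)
        pos_ids = torch.cat([pos_ids, next_id], dim=1)
        neg_ids = torch.cat([neg_ids, next_id], dim=1)
    return pos_ids
```

Varying CFG strength is then just a schedule over cfg_scale per step; cfg_scale=1.0 recovers unguided decoding.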

super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
if device is None: # todo: maybe we don't need this...
@loubbrad (Contributor) commented:

Probably best not to do this explicitly. Could we instead perhaps do device = self.device? I'm worried that this will break the distributed code (with accelerate), but I'm unable to test it at the moment.

@honglu2875 (Contributor, Author) replied:

The problem with device=self.device is that device=None somehow defaults to CPU, and I get a device mismatch error.

@honglu2875 (Contributor, Author) added:

@loubbrad I didn't realize these two lines can now be safely removed thanks to the Hugging Face code. In fact, the device param itself is redundant!
In the original implementation of RotaryEmbedding (from the neox repo), the device needed to be known at initialization time because the cached cos_cached and sin_cached tensors had to be set explicitly on it. But now they are registered as buffers on the module, so moving them across devices is automatic.
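
For context, this is the pattern the comment refers to: a hedged sketch (not the repo's exact class) of caching cos/sin as buffers via register_buffer, so the tensors follow the module across devices and no device argument is needed at init.

```python
import torch
import torch.nn as nn

class RotaryEmbeddingSketch(nn.Module):
    """Illustrative only; the real RotaryEmbedding differs in details."""

    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).float()
        freqs = torch.outer(t, inv_freq)            # (max_seq_len, dim // 2)
        emb = torch.cat((freqs, freqs), dim=-1)     # (max_seq_len, dim)
        # Buffers, not parameters: untrained tensors that are still moved by
        # module.to(device) / accelerate, so no device is needed at init time.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(self, seq_len: int):
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
```

After model.to("cuda") (or accelerate's device placement) the caches land on the right device automatically, which is why the explicit device branch can simply be deleted.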

self.rotary_emb = RotaryEmbedding(self.d_head)
if use_yarn:
    # todo: need more testing on this
    self.rotary_emb = DynamicYaRNScaledRotaryEmbedding(self.d_head,
@loubbrad (Contributor) commented on Nov 1, 2023:

I think this import is missing?

@honglu2875 (Contributor, Author) replied:

Oh yes, you are right. Will fix it in a moment.
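
For readers unfamiliar with the class being imported: the idea behind dynamic YaRN/NTK-style scaling is to rescale the rotary base once the sequence outgrows the trained context, so the low frequencies stretch instead of wrapping around. Below is a deliberately simplified sketch of that rescaling; the actual DynamicYaRNScaledRotaryEmbedding does more (e.g. per-frequency interpolation ramps), and the function name and signature here are hypothetical.

```python
import torch

def dynamic_scaled_inv_freq(dim: int, seq_len: int, max_trained_len: int,
                            base: float = 10000.0) -> torch.Tensor:
    """Simplified dynamic NTK-style base rescaling (a rough relative of YaRN).

    Within the trained context this is a no-op; past it, the base grows so
    the rotary frequencies cover the longer sequence gracefully.
    """
    if seq_len > max_trained_len:
        scale = seq_len / max_trained_len
        base = base * scale ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
```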

@loubbrad merged commit 358bd22 into EleutherAI:dev on Nov 3, 2023.
@honglu2875 deleted the honglu/dev branch on Nov 6, 2023.