
[WIP] RWKV4Neo the RNN and GPT Hybrid Model #20809

Closed
wants to merge 6 commits

Conversation

ArEnSc
Contributor

@ArEnSc ArEnSc commented Dec 17, 2022

What does this PR do?

Adds the model from the linked issue.
Fixes #20737

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@younesbelkada
@ArthurZucker

@ArEnSc ArEnSc marked this pull request as draft December 17, 2022 20:50
@younesbelkada
Contributor

younesbelkada commented Dec 19, 2022

Hi @ArEnSc !
Thanks for starting the PR over 💪
Let @ArthurZucker or me know whenever you need help!

@ArEnSc
Contributor Author

ArEnSc commented Dec 19, 2022

Hi @ArEnSc ! Thanks for starting the PR over 💪 Let @ArthurZucker or me know whenever you need help!

Will do. Still doing some research; I just figured out how the training notebook works, and the model executes in the notebook, so that's a positive.

@ArEnSc
Contributor Author

ArEnSc commented Jan 5, 2023

Update: I am tracing the model and have come up with a state-based API for the RNN inference mode in my own code base to experiment with.

@younesbelkada
Contributor

Thanks a lot for the status update! Feel free to ping whenever you need help

@xloem
Contributor

xloem commented Jan 16, 2023

Sometimes I look at working on this a little. Here are my notes and possible tasks, started 2023-01-16.

  • The template appears to be from a T5 style model. The RWKV state could be the encoder hidden state (a little intuitive) and/or the past key values (normative generation). It will take some algebra and tests to add input state to the GPT training form from the RNN inference form.

  • The TensorFlow loading code looks complicated to me. I might move it out to another file for now.

  • The embeddings can likely be adjusted to reflect parts "i" and "ii" of the high-level outline below.

  • It could be helpful to organize the file to retain layout similarity with blinkdl’s files.

  • For the outline below, the next step is reviewing TimeMix.
    Draft of architecture (maybe leave out optional parts to start); minimal Python sketches of the block layout and of the single-token TimeMix and ChannelMix recurrences follow this list.

    High level:

    1. word embeddings emb
    2. layernorm ln0
      - optional 2-axis trained position embeddings seen in training code for image modeling pos_emb_x pos_emb_y. this is converted to 1-axis pos_emb and used prior to ln0 in inference.
    3. layers of blocks
      1. layernorm ln1
      2. timemix self attention time_mix_k, time_mix_v, time_mix_r, time_first, time_decay, key, value, receptance, output. time_first and time_decay are kept as float32 in inference.
      3. layernorm ln2
      4. feedforward channelmix time_mix_k, time_mix_r, key, value, receptance (see channelmix section below)
      - timemix self attention optionally replaced with feedforward channelmix for block 0 in training code
      - for one optional block, tiny attention tiny_ln, tiny_q, tiny_k, tiny_v, tiny_mask seen in training code, inference code in development
      - optionally inference code uses what looks like a numeric stability trick to extract a factor of 2 from the weights every 6 layers
    4. layernorm ln_out
      - optional "copy" attention head_q, head_k, copy_mask then summed to head in training code, inference code in development
    5. linear language modeling head
      - for training loss, BlinkDL presently has a function applied after cross entropy, called L2Wrap, to reduce magnitudes

    GPT(training) and RNN (inference) equivalence:

    • i think special training initialization values may be used in timemix, channelmix
    • for inference time_decay = -exp(time_decay) is factored out when loaded, but for training this is done in the forward pass.
    • 5 state elements per layer:
      • 0 = ChannelMix/FF xx
      • 1 = TimeMix/SA xx
      • 2 = aa
      • 3 = bb
      • 4 = pp in inference, o in training

    TimeMix:

    1. the previous state is shifted into the x vector to make xx. in training this is done by "time shifting" with nn.ZeroPad2d((0, 0, 1, -1)); in single token inference it is passed as state element 1, which is then replaced by x.
    2. linear interpolation between the old state xx and the new state x, weighting x by a ratio of time_mix_k, time_mix_v, and time_mix_r to make xk, xv, and xr respectively.
    3. k = key @ xk
    4. v = value @ xv
    5. sr = sigmoid(receptance @ xr) # called simply r in inference code
    • the GPT training form of this is now handed off to a hand-written cuda kernel, compiled on first run, from cuda/wkv_cuda.cu
      • kernel parameters: B = batchsize; T = sequence length; C = channel count; _w = time_decay; _u = time_first; _k = k; _v = v; _y = wkv.
      • i think this used to be a convolution; i'm not sure whether it still is
      • o and no appear to be running values for magnitude management in exponential space, initialized to -1e38; p and q are initialized to 0
      • k and v are indexed by thread so the token offset may represent different subregions. i'm not quite clear on that and should test or ask.
      1. no = max(o, time_first[channel] + k[token])
      2. A = exp(o - no) # this is e1 in the RNN form
      3. B = exp(time_first[channel] + k[token] - no) # this is e2 in RNN
      4. wkv[token] = (A * p + B * v[token]) / (A * q + B)
      5. no = max(time_decay[channel] + o, k[token])
      6. A = exp(time_decay[channel] + o - no)
      7. B = exp(k[token] - no)
      8. p = A * p + B * v[token]
      9. q = A * q + B
      10. o = no; token += 1
    • ... here would be the remaining core algebra and code inspection
    • WIP unified summary of wkv kernel between inference and training (a single-token Python sketch of this recurrence appears after this list):
      1. ww = time_first + k[token]
      2. next_pp = max(pp, ww)
      3. A = exp(pp - next_pp ...
    • rwkv = sr * wkv
    • return output @ rwkv

    ChannelMix:

    1. the previous state is shifted into the x vector to make xx. in training this is done by "time shifting" with nn.ZeroPad2d((0, 0, 1, -1)); in single token inference it is passed as state element 0, which is then replaced by x.
    2. linear interpolation between the old state xx and the new state x, weighting x by a ratio of time_mix_k and time_mix_r to make xk and xr respectively (see the ChannelMix sketch after this list).
    3. r = sigmoid(receptance @ xr)
    4. k = square(relu(key @ xk))
    5. kv = value @ k
    6. rkv = r * kv
    7. return rkv
  • review or improve model file further
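
A minimal PyTorch sketch of the block layout from the high-level outline above, with the optional parts left out. The class names, default sizes, and the residual connections around each sub-block are assumptions for illustration; the TimeMix and ChannelMix sub-blocks are stubbed with nn.Identity here and sketched as single-token functions below.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    # One block of the outline: ln1 -> TimeMix (self attention), ln2 -> ChannelMix (feedforward).
    # The residual connections are assumed from BlinkDL's training code; the sub-blocks are placeholders.
    def __init__(self, n_embd):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.att = nn.Identity()  # TimeMix: time_mix_k/v/r, time_first, time_decay, key, value, receptance, output
        self.ln2 = nn.LayerNorm(n_embd)
        self.ffn = nn.Identity()  # ChannelMix: time_mix_k/r, key, value, receptance

    def forward(self, x):
        x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class RWKV(nn.Module):
    # High-level layout: emb -> ln0 -> blocks -> ln_out -> linear language modeling head.
    def __init__(self, vocab_size=100, n_embd=64, n_layer=2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, n_embd)                           # 1. word embeddings
        self.ln0 = nn.LayerNorm(n_embd)                                       # 2. layernorm ln0
        self.blocks = nn.ModuleList([Block(n_embd) for _ in range(n_layer)])  # 3. layers of blocks
        self.ln_out = nn.LayerNorm(n_embd)                                    # 4. layernorm ln_out
        self.head = nn.Linear(n_embd, vocab_size, bias=False)                 # 5. LM head

    def forward(self, idx):
        x = self.ln0(self.emb(idx))
        for block in self.blocks:
            x = block(x)
        return self.head(self.ln_out(x))


logits = RWKV()(torch.randint(0, 100, (1, 8)))  # (batch=1, seq=8) -> (1, 8, vocab_size)
```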
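
A minimal sketch of the single-token (RNN inference) form of TimeMix, following the 5-element state layout and the WIP unified wkv summary above. It assumes x is a 1-D vector of size n_embd, state is a (5, n_embd) tensor, square weight shapes as in the example call, and that time_decay has already been replaced by -exp(time_decay) at load time, as noted in the equivalence section.

```python
import torch


def time_mix(x, state, time_mix_k, time_mix_v, time_mix_r,
             time_first, time_decay, key, value, receptance, output):
    # state[1] = previous x for this sub-block, state[2] = aa, state[3] = bb, state[4] = pp.
    xx = state[1]                                # 1. previous state shifted in as xx ...
    xk = x * time_mix_k + xx * (1 - time_mix_k)  # 2. interpolate old and new state with time_mix_k
    xv = x * time_mix_v + xx * (1 - time_mix_v)  #    ... with time_mix_v
    xr = x * time_mix_r + xx * (1 - time_mix_r)  #    ... with time_mix_r
    state[1] = x                                 # 1. ... which is then replaced by x

    k = key @ xk                                 # 3. k
    v = value @ xv                               # 4. v
    sr = torch.sigmoid(receptance @ xr)          # 5. sr (called simply r in the inference code)

    aa, bb, pp = state[2], state[3], state[4]
    ww = time_first + k                          # 1. ww = time_first + k[token]
    qq = torch.maximum(pp, ww)                   # 2. next_pp = max(pp, ww)
    e1 = torch.exp(pp - qq)                      # 3. A / e1
    e2 = torch.exp(ww - qq)                      #    B / e2
    wkv = (e1 * aa + e2 * v) / (e1 * bb + e2)    # 4. wkv

    ww = pp + time_decay                         # 5.-10. update the running numerator/denominator
    qq = torch.maximum(ww, k)
    e1 = torch.exp(ww - qq)
    e2 = torch.exp(k - qq)
    state[2] = e1 * aa + e2 * v                  # aa
    state[3] = e1 * bb + e2                      # bb
    state[4] = qq                                # pp

    return output @ (sr * wkv), state            # rwkv = sr * wkv; return output @ rwkv


# Example call with random weights; state[4] (pp) starts very negative, as in the kernel notes above.
n = 8
state = torch.zeros(5, n)
state[4] = -1e38
out, state = time_mix(torch.randn(n), state,
                      torch.rand(n), torch.rand(n), torch.rand(n),
                      torch.randn(n), -torch.exp(torch.randn(n)),
                      torch.randn(n, n), torch.randn(n, n), torch.randn(n, n), torch.randn(n, n))
```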
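
A matching single-token sketch of ChannelMix, following steps 1-7 above. state_xx is state element 0 for the layer; the 4x hidden expansion used for the example weights is an assumption.

```python
import torch


def channel_mix(x, state_xx, time_mix_k, time_mix_r, key, value, receptance):
    xx = state_xx                                # 1. previous state shifted in (state element 0)
    xk = x * time_mix_k + xx * (1 - time_mix_k)  # 2. interpolate old and new state with time_mix_k
    xr = x * time_mix_r + xx * (1 - time_mix_r)  #    ... with time_mix_r
    r = torch.sigmoid(receptance @ xr)           # 3. r = sigmoid(receptance @ xr)
    k = torch.square(torch.relu(key @ xk))       # 4. k = square(relu(key @ xk))
    kv = value @ k                               # 5. kv = value @ k
    return r * kv, x                             # 6.-7. rkv, and x becomes the new state element 0


# Example call with random weights (n_embd = 8, hidden = 32).
n, h = 8, 32
out, new_xx = channel_mix(torch.randn(n), torch.zeros(n),
                          torch.rand(n), torch.rand(n),
                          torch.randn(h, n), torch.randn(n, h), torch.randn(n, n))
```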

@Lundez

Lundez commented Jan 17, 2023

@ArEnSc do you need any help?

@ArEnSc
Contributor Author

ArEnSc commented Jan 17, 2023

@ArEnSc do you need any help?

If you want to help, PM me on Discord! Otherwise I should have a minor update by the end of the week.

@younesbelkada
Contributor

younesbelkada commented Jan 23, 2023

Hi @ArEnSc,
Can you share with us your discord handle? Thanks!

@ArEnSc
Contributor Author

ArEnSc commented Jan 23, 2023

Hi @ArEnSc, Can you share with us your discord handle? Thanks!

ARENSC#5905
Yeah, still working on it haha, it will be a while.

@ArEnSc
Contributor Author

ArEnSc commented Jan 30, 2023

Working on having the GPT encoder generate the context, with RNN-mode inference sharing the weights.

@ArEnSc
Contributor Author

ArEnSc commented Jan 30, 2023

Deleted a bunch of unneeded stuff.

@huggingface huggingface deleted a comment from github-actions bot Mar 15, 2023
@ArthurZucker ArthurZucker changed the title RWKV4Neo the RNN and GPT Hybrid Model [WIP] RWKV4Neo the RNN and GPT Hybrid Model Mar 15, 2023
@ArthurZucker
Collaborator

Added the [WIP] label to prevent the bot from coming back 😉

@huggingface huggingface deleted a comment from github-actions bot Apr 11, 2023
@fblgit fblgit mentioned this pull request Apr 11, 2023
@sgugger
Collaborator

sgugger commented Apr 11, 2023

@ArEnSc Please let us know if you won't have time to finish this PR. The model is heavily requested, as you can see from the linked issue; do you want us to take over this PR and finish it?

@ArEnSc
Contributor Author

ArEnSc commented Apr 12, 2023

@ArEnSc Please let us know if you won't have time to finish this PR. The model is heavily requested, as you can see from the linked issue; do you want us to take over this PR and finish it?

Sure, yes. Sorry, I've been busy at the hospital these days! I think it's probably important that you guys take this on =)

@github-actions

github-actions bot commented May 6, 2023

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this May 14, 2023