Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] RWKV LM #207

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add HF hub download links within the model code
  • Loading branch information
kyobrien committed Apr 25, 2023
commit 53e3f0488f0f4ca89452fa267c480b5f678c2b92
29 changes: 21 additions & 8 deletions elk/rwkv_lm/rwkv_hf.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import os

from transformers import (
PretrainedConfig,
PreTrainedModel,
)
from huggingface_hub import hf_hub_download
from transformers.modeling_outputs import CausalLMOutput
from transformers import PretrainedConfig, PreTrainedModel

# The rwkv.model is the official build
# from rwkv.model import RWKV

# rwkv_hiddens is a custom implementation that exposes all the hidden states as layer states - written by Nora
from .rwkv_hiddens import RWKV

os.environ["RWKV_JIT_ON"] = "1"
Expand All @@ -25,9 +25,22 @@ def __init__(self, **kwargs):
class RWKVModel(PreTrainedModel):
def __init__(self):
super().__init__(RWKVConfig())
weights_path = "/home/kyle/HF-MODEL/rwkv-4-pile-1b5/models--BlinkDL--rwkv-4-pile-1b5/snapshots/6ea995eaa87a17af560c9b41ce1a3d92355c5a49/RWKV-4-Pile-1B5-20220903-8040.pth"
# weights_path = "/home/kyle/HF-MODEL/rwkv-4-pile-14b/models--BlinkDL--rwkv-4-pile-14b/snapshots/939b6851f96122b7b49bd00d446b3b49481214dd/RWKV-4-Pile-14B-20230213-8019.pth"
self.model = RWKV(model=weights_path, strategy="cuda fp16")

# TODO: Add support for specifying the parameter count through the HF path provided in the CLI args

# 1.5b
# weights_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-1b5", filename="RWKV-4-Pile-1B5-20220903-8040.pth")

# 3b
# weights_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-3b", filename="RWKV-4-Pile-3B-20221008-8023.pth")

# 7b
weights_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-7b", filename="RWKV-4-Pile-7B-20221115-8047.pth")

# 14b
# weights_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-14b", filename="RWKV-4-Pile-14B-20230213-8019.pth")

self.model = RWKV(model=weights_path, strategy="cuda bf16")

def forward(
self,
Expand Down
5 changes: 3 additions & 2 deletions elk/rwkv_lm/test_rwkv.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
Expand All @@ -26,7 +26,8 @@
],
"source": [
"from huggingface_hub import hf_hub_download\n",
"hf_hub_download(repo_id=\"BlinkDL/rwkv-4-pile-14b\", filename=\"RWKV-4-Pile-14B-20230213-8019.pth\", cache_dir=\"/home/kyle/HF-MODEL/rwkv-4-pile-14b\")"
"download_path = hf_hub_download(repo_id=\"BlinkDL/rwkv-4-pile-1b5\", filename=\"RWKV-4-Pile-1B5-20220903-8040.pth\")\n",
"print(\"Download path is \" + download_path)"
]
},
{
Expand Down