Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] RWKV LM #207

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
Remove custom approach and begin generalizing
  • Loading branch information
Kyle1668 committed May 6, 2023
commit a970761632bf1878aeb82a52770d23fb9fe47bca
78 changes: 51 additions & 27 deletions elk/rwkv_lm/rwkv_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@
)
from transformers.modeling_outputs import CausalLMOutput

# The rwkv.model is the official build
# from rwkv.model import RWKV
# rwkv_hiddens is a custom implementation that exposes all the hidden states as layer states - written by Nora
# from .rwkv_hiddens import RWKV
from rwkv.model import RWKV
from rwkv.utils import PIPELINE

Expand All @@ -21,35 +17,45 @@


class RWKVConfig(PretrainedConfig):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.hidden_size = 4096
self.num_hidden_layers = 160
def __init__(self, hidden_size, num_hidden_layers):
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.is_encoder_decoder = False
self.architectures = ["RWKV-LM"]


class RWKVModel(PreTrainedModel):
def __init__(self, device):
super().__init__(RWKVConfig())

# TODO: Add support for specifying the parameter count through the HF path provided in the CLI args

# 1.5b
# weights_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-1b5", filename="RWKV-4-Pile-1B5-20220903-8040.pth")

# 3b
# weights_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-3b", filename="RWKV-4-Pile-3B-20221008-8023.pth")

# 7b
weights_path = hf_hub_download(
repo_id="BlinkDL/rwkv-4-pile-7b",
filename="RWKV-4-Pile-7B-20221115-8047.pth",
@staticmethod
def from_pretrained(pretrained_model_name_or_path, **kwargs):
path_config_maps = {
"BlinkDL/rwkv-4-pile-1b5": {
"hidden_size": 2048,
"num_hidden_layers": 120,
},
"BlinkDL/rwkv-4-pile-3b": {
"hidden_size": 2560,
"num_hidden_layers": 160,
},
"BlinkDL/rwkv-4-pile-7b": {
"hidden_size": 4096,
"num_hidden_layers": 160,
},
"BlinkDL/rwkv-4-pile-14b": {
"hidden_size": 5120,
"num_hidden_layers": 200,
},
"BlinkDL/rwkv-4-raven": {
"hidden_size": 5120,
"num_hidden_layers": 200,
}
}
return RWKVConfig(
hidden_size=path_config_maps[pretrained_model_name_or_path]["hidden_size"],
num_hidden_layers=path_config_maps[pretrained_model_name_or_path]["num_hidden_layers"],
)

# 14b
# weights_path = hf_hub_download(repo_id="BlinkDL/rwkv-4-pile-14b", filename="RWKV-4-Pile-14B-20230213-8019.pth")

class RWKVModel(PreTrainedModel):
def __init__(self, config, weights_path, device):
super().__init__(config)
strategy = f"{device} fp16"
self.model = RWKV(model=weights_path, strategy=strategy)
self.device_object = torch.device(device)
Expand All @@ -75,6 +81,24 @@ def forward(

return response

@staticmethod
def from_pretrained(pretrained_model_name_or_path):
repo_weights_paths = {
"BlinkDL/rwkv-4-pile-1b5": "RWKV-4-Pile-1B5-20220903-8040.pth",
"BlinkDL/rwkv-4-pile-3b": "RWKV-4-Pile-3B-20221008-8023.pth",
"BlinkDL/rwkv-4-pile-7b": "RWKV-4-Pile-7B-20221115-8047.pth",
"BlinkDL/rwkv-4-pile-14b": "RWKV-4-Pile-14B-20230213-8019.pth",
"BlinkDL/rwkv-4-raven": "RWKV-4-Raven-14B-v10-Eng99%-Other1%-20230427-ctx8192.pth",
}

if pretrained_model_name_or_path not in repo_weights_paths:
raise ValueError(f"Unsupported RWKV model: {pretrained_model_name_or_path}")

weights_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename=repo_weights_paths[pretrained_model_name_or_path])
config = RWKVConfig.from_pretrained(pretrained_model_name_or_path)
model = RWKVModel(config, weights_path, device="cuda")
return model


class RWKVTokenizer(PreTrainedTokenizer):
model_max_length = 2048
Expand Down
Loading