Commit 238020d: 7 changed files, with 100,882 additions and 0 deletions.

The first file is the standard Python `.gitignore` template (with `*.pth` appended, presumably to keep downloaded model checkpoints out of version control):
```
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# PyTorch model weights
*.pth
```
# RWKV based reward model

Work in progress, not ready for use.

## Introduction

This repository implements a basic reward model based on RWKV models. The purposes of training such a reward model are to:

- show that RWKV models can be used to learn reward functions;
- use the trained reward model to train a policy with RL algorithms, encouraging the base RWKV model to generate more diverse trajectories and higher-quality answers.
## To run this experiment

### 1. Environment

You will need to install the official RWKV PyPI package with `pip install rwkv`. You will also need the `datasets` and `torch` packages:

```
pip install rwkv
pip install datasets
pip install torch
```
### 2. Download the weights

In this experiment I chose a smaller RWKV model with 430M parameters, to enable quick training and testing on a single GPU. You can download the weights from https://huggingface.co/BlinkDL/rwkv-4-pile-430m

I used this checkpoint: https://huggingface.co/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth
### 3. Run the experiment

```
python train.py
```

In this experiment, we use the reward dataset from https://huggingface.co/datasets/yitingxie/rlhf-reward-datasets. We sample 100 data points from the training split and 20 from the validation split to show that the reward model can be trained to predict the reward values.
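The subsampling step can be sketched as follows. This is an illustrative stand-in, not the exact code in `train.py`: the `prompt`/`chosen`/`rejected` field names are assumed from the dataset card, and the real `datasets.load_dataset` call appears only in a comment so the snippet stays self-contained.

```python
import random

def sample_pairs(dataset, n_samples, seed=42):
    """Draw a fixed-size, reproducible subset of (prompt, chosen, rejected) records."""
    rng = random.Random(seed)
    indices = rng.sample(range(len(dataset)), n_samples)
    return [dataset[i] for i in indices]

# With the real data this would look like:
#   from datasets import load_dataset
#   ds = load_dataset("yitingxie/rlhf-reward-datasets")
#   train = sample_pairs(ds["train"], 100)
#   val   = sample_pairs(ds["test"], 20)

# Illustrative stand-in records with the assumed field layout:
toy = [{"prompt": f"q{i}", "chosen": f"good{i}", "rejected": f"bad{i}"} for i in range(10)]
train = sample_pairs(toy, 3)
```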
## The details of the reward model

In short, the reward model in this setting will 1) read the prompt and the output from another LLM, and 2) rate that output on a scale from 1 to -1 (as logits), corresponding to "accept" and "reject" respectively. In the dataset `yitingxie/rlhf-reward-datasets`, two answers were generated for each prompt by an anonymous LLM; human raters then labeled one answer "accept" and the other "reject". The reward model is trained to predict these human ratings.

Once such a reward model is trained, we can use it to train a policy with RL algorithms that encourage the base RWKV model to get higher scores (i.e. more "accept" ratings).
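The scoring scheme can be illustrated with a tiny numeric sketch. This commit does not show the loss used in `train.py`, so the binary log-sigmoid loss below is an assumption chosen for illustration, written in pure Python rather than torch:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def reward_loss(logit, label):
    """Binary loss for a scalar reward logit.

    label is +1 for a human-"accept" answer and -1 for "reject"; the loss
    pushes the model toward positive logits on accepted answers and
    negative logits on rejected ones.
    """
    # -log sigmoid(label * logit): small when the sign of the logit
    # matches the label, large when it does not.
    return -math.log(sigmoid(label * logit))

# A well-calibrated model: positive logit on "accept", negative on "reject".
good = reward_loss(2.0, +1) + reward_loss(-2.0, -1)
# A confused model that gives both answers the same neutral score.
bad = reward_loss(0.0, +1) + reward_loss(0.0, -1)
```

With this loss, `good` comes out well below `bad`, which is the gradient signal that would teach the model to separate accepted from rejected answers.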
## Work in progress

1. The reward model is not trained well yet; only the last few layers are trained.
2. I will implement QLoRA on top of this reward model to make it fully trainable.
3. The dataset used for illustration here is neither large nor diverse enough. We will need to collect more data to train a better reward model.
The commit also adds a small configuration module that maps a checkpoint name to its hyper-parameters:

```
import types

# Hyper-parameters for the supported RWKV checkpoints, keyed by a
# size tag that appears in the checkpoint name.
RWKV_CONFIGS = {
    "430M": {
        "n_layer": 24,
        "n_embd": 1024,
    }
}

def get_config(model_name):
    """
    Locates the configuration for a given model name.
    """
    for k, v in RWKV_CONFIGS.items():
        if k in model_name:
            args = types.SimpleNamespace(**v)
            args.MODEL_NAME = model_name
            return args
    # No known size tag matched; fail loudly rather than returning None.
    raise ValueError(f"No RWKV config matches model name: {model_name}")

# Example: get_config("RWKV-4-Pile-430M-20220808-8066") returns a namespace
# with n_layer=24 and n_embd=1024.
```