Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jiamingkong committed Jun 5, 2023
0 parents commit 238020d
Show file tree
Hide file tree
Showing 7 changed files with 100,882 additions and 0 deletions.
161 changes: 161 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.pth
49 changes: 49 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# RWKV based reward model

Work in progress, not ready for use.

## Introduction

This git implements a basic reward model based on RWKV models. The purposes of training such a reward model are:

- Show that RWKV models can be used to learn reward functions
- With the reward model trained, we can use it to train a policy using RL algorithms that encourages the base RWKV model to generate more diverse trajectories and high quality answers.

## To run this experiment

### 0. Environments

You will need to install RWKV official pypi package using `pip install rwkv`. You will also need `datasets` and `pytorch` packages.

```
pip install rwkv
pip install datasets
pip install torch
```

### 2. Download the weights

In this experiment I chose a smaller RWKV model with 430M parameters for enabling a quick training and testing on a single GPU. You can download the weight from https://huggingface.co/BlinkDL/rwkv-4-pile-430m

I used this weight: https://huggingface.co/BlinkDL/rwkv-4-pile-430m/resolve/main/RWKV-4-Pile-430M-20220808-8066.pth

### 3. Run the experiment

```
python train.py
```

In this experiment, we will use the reward dataset from https://huggingface.co/datasets/yitingxie/rlhf-reward-datasets. We sampled 100 data points from the train and 20 from the validation to show that the reward model can be trained to predict the reward values.

## The detail of the reward model

In short, the reward model in this setting will 1) read the prompt and output from another LLM model; 2) Rate the output from 1 to -1 ( in logits) for "accept" and "reject" respectively. In the dataset `yitingxie/rlhf-reward-datasets`, for each prompt, two answers were generated by an anonymous LLM, then human raters rated one of the answers as "accept" and the other as "reject". The reward model is trained to predict the human ratings.

Once such a reward model is trained, we can use it to train a policy using RL algorithms that encourages the base RWKV model to get higher scores (i.e. more "accept" ratings).


## Work in progress

1. The reward model is not trained well yet. Only the last few layers are trained.
2. I will implement qLora on top of this reward model to make it fully trainable.
3. The dataset we used for illustration here is not big enough and diverse enough. We will need to collect more data to train a better reward model.
18 changes: 18 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import types

RWKV_CONFIGS = {
"430M": {
"n_layer": 24,
"n_embd": 1024,
}
}

def get_config(model_name):
"""
Locates the configuration for a given model name.
"""
for k, v in RWKV_CONFIGS.items():
if k in model_name:
args = types.SimpleNamespace(**v)
args.MODEL_NAME = model_name
return args
Loading

0 comments on commit 238020d

Please sign in to comment.