Skip to content

Commit

Permalink
initial commit!
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeff Wu and Adrien Ecoffet and Manas Joglekar and Jan Hendrik Kirchner and Pavel Izmailov authored and WuTheFWasThat committed Dec 14, 2023
0 parents commit 1dbfbd6
Show file tree
Hide file tree
Showing 24 changed files with 1,723 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
dump
*.pyc
*.swp
*.swo
7 changes: 7 additions & 0 deletions LICENSE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Copyright 2023 OpenAI

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
60 changes: 60 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
**STATUS**: This codebase is not well tested and does not use the exact same settings we used in the paper, but in our experience gives qualitatively similar results when using large model size gaps and multiple seeds. Expected results can be found for two datasets below. We may update the code significantly in the coming week.

# Weak-to-strong generalization

![Our setup and how it relates to superhuman AI alignment](./weak-to-strong-setup.png)

This project contains code for implementing our [paper on weak-to-strong generalization](https://cdn.openai.com/papers/weak-to-strong-generalization.pdf).

The primary codebase contains a re-implementation of our weak-to-strong learning setup for binary classification tasks. The codebase contains code for fine-tuning pretrained language models, and also training against the labels from another language model. We support various losses described in the paper as well, such as the confidence auxiliary loss.

The `vision` directory contains stand-alone code for weak-to-strong in the vision models setting (AlexNet -> DINO on ImageNet).

### Getting Started

These instructions will get you a copy of the project up and running on your local machine for development and testing purposes.

#### Installation

You need to have Python installed on your machine. The project also has some dependencies, which can be installed with pip:

```
pip install -r requirements.txt
```

#### Running the Script

The main script of the project is train_weak_to_strong.py. It can be run from the command line using the following command:
```
python train_weak_to_strong.py
```

The script accepts several command-line arguments to customize the training process. Here are some examples:

```
python train_weak_to_strong.py --batch_size 32 --max_ctx 512 --ds_name "sciq" --loss "logconf" --n_docs 1000 --n_test_docs 100 --weak_model_size "gpt2-medium" --strong_model_size "gpt2-large" --seed 42
```

#### Expected results

<img src="notebooks/amazon_polarity_None.png" width="350">
<br>
<img src="notebooks/sciq_None.png" width="350">
<br>
<img src="notebooks/Anthropic-hh-rlhf_None.png" width="350">

### Authors

- Adrien Ecoffet
- Manas Joglekar
- Jeffrey Wu
- Jan Hendrik Kirchner
- Pavel Izmailov (vision)

### License

This project is licensed under the MIT License - see the LICENSE.md file for details.

### Acknowledgments

- Hugging Face for their open-source transformer models
Binary file added notebooks/Anthropic-hh-rlhf_None.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
159 changes: 159 additions & 0 deletions notebooks/Plotting.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "eb9a4b5a",
"metadata": {},
"source": [
"# Simple Plotting\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "88c7ff9f",
"metadata": {},
"outputs": [],
"source": [
"RESULTS_PATH = \"../../your_sweep_results_path\"\n",
"\n",
"PLOT_ALL_SEEDS = False\n",
"# Full sweep\n",
"MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\", \"gpt2-xl\", \"Qwen/Qwen-1_8B\", \"Qwen/Qwen-7B\", \"Qwen/Qwen-14B\"]\n",
"# Minimal sweep\n",
"# MODELS_TO_PLOT = [\"gpt2\", \"gpt2-medium\", \"gpt2-large\"]\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "00ca073c",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"sns.set_style('whitegrid')\n",
"\n",
"from IPython.display import display\n",
"\n",
"import os\n",
"import glob\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e5caa051",
"metadata": {},
"outputs": [],
"source": [
"records = []\n",
"all_results_folders = ['/'.join(e.split('/')[:-1]) for e in glob.glob(os.path.join(RESULTS_PATH, \"**/*.results_summary.json\"), recursive=True)]\n",
"for result_folder in set(all_results_folders):\n",
" config_file = os.path.join(result_folder, \"config.json\")\n",
" config = json.load(open(config_file, \"r\"))\n",
" if config[\"strong_model_size\"] not in MODELS_TO_PLOT:\n",
" continue\n",
" if 'seed' not in config:\n",
" config['seed'] = 0\n",
" result_filename = (config[\"weak_model_size\"].replace('.', '_') + \"_\" + config[\"strong_model_size\"].replace('.', '_') + \".results_summary.json\").replace('/', '_')\n",
" record = config.copy()\n",
" record.update(json.load(open(config_file.replace('config.json', result_filename))))\n",
" records.append(record)\n",
"\n",
"df = pd.DataFrame.from_records(records).sort_values(['ds_name', 'weak_model_size', 'strong_model_size'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2f628577",
"metadata": {},
"outputs": [],
"source": [
"datasets = df.ds_name.unique()\n",
"for dataset in datasets:\n",
" cur_df = df[(df.ds_name == dataset)]\n",
" base_df = pd.concat([\n",
" pd.DataFrame.from_dict({\"strong_model_size\": cur_df['weak_model_size'].to_list(), \"accuracy\": cur_df['weak_acc'].to_list(), \"seed\": cur_df['seed'].to_list()}),\n",
" pd.DataFrame.from_dict({\"strong_model_size\": cur_df['strong_model_size'].to_list(), \"accuracy\": cur_df['strong_acc'].to_list(), \"seed\": cur_df['seed'].to_list()})\n",
" ])\n",
" base_accuracies = base_df.groupby('strong_model_size').agg({'accuracy': 'mean', 'seed': 'count'}).sort_values('accuracy')\n",
" base_accuracy_lookup = base_accuracies['accuracy'].to_dict()\n",
" base_accuracies = base_accuracies.reset_index()\n",
" base_df.reset_index(inplace=True)\n",
" base_df['weak_model_size'] = 'ground truth'\n",
" base_df['loss'] = 'xent'\n",
" base_df['strong_model_accuracy'] = base_df['strong_model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
"\n",
" weak_to_strong = cur_df[['weak_model_size', 'strong_model_size', 'seed'] + [e for e in cur_df.columns if e.startswith('transfer_acc')]]\n",
" weak_to_strong = weak_to_strong.melt(id_vars=['weak_model_size', 'strong_model_size', 'seed'], var_name='loss', value_name='accuracy')\n",
" weak_to_strong = weak_to_strong.dropna(subset=['accuracy'])\n",
" weak_to_strong.reset_index(inplace=True)\n",
" weak_to_strong['loss'] = weak_to_strong['loss'].str.replace('transfer_acc_', '')\n",
" weak_to_strong['strong_model_accuracy'] = weak_to_strong['strong_model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
"\n",
" # Exclude cases where the weak model is better than the strong model from PGR calculation.\n",
" pgr_df = cur_df[(cur_df['weak_model_size'] != cur_df['strong_model_size']) & (cur_df['strong_acc'] > cur_df['weak_acc'])]\n",
" pgr_df = pgr_df.melt(id_vars=[e for e in cur_df.columns if not e.startswith('transfer_acc')], var_name='loss', value_name='transfer_acc')\n",
" pgr_df = pgr_df.dropna(subset=['transfer_acc'])\n",
" pgr_df['loss'] = pgr_df['loss'].str.replace('transfer_acc_', '')\n",
" pgr_df['pgr'] = (pgr_df['transfer_acc'] - pgr_df['weak_acc']) / (pgr_df['strong_acc'] - pgr_df['weak_acc'])\n",
"\n",
" for seed in [None] + (sorted(cur_df['seed'].unique().tolist()) if PLOT_ALL_SEEDS else []):\n",
" plot_df = pd.concat([base_df, weak_to_strong])\n",
" seed_pgr_df = pgr_df\n",
" if seed is not None:\n",
" plot_df = plot_df[plot_df['seed'] == seed]\n",
" # We mean across seeds, this is because sometimes the weak and strong models will have run on different hardware and therefore\n",
" # have slight differences. We want to average these out when filtering by seed.\n",
"\n",
" seed_pgr_df = pgr_df[pgr_df['seed'] == seed]\n",
"\n",
" if seed is not None or cur_df['seed'].nunique() == 1:\n",
" plot_df = plot_df[['strong_model_accuracy', 'weak_model_size', 'loss', 'accuracy']].groupby(['strong_model_accuracy', 'weak_model_size', 'loss']).mean().reset_index().sort_values(['loss', 'weak_model_size'], ascending=False)\n",
"\n",
" print(f\"Dataset: {dataset} (seed: {seed})\")\n",
"\n",
" pgr_results = seed_pgr_df.groupby(['loss']).aggregate({\"pgr\": \"median\"})\n",
" display(pgr_results)\n",
"\n",
" palette = sns.color_palette('colorblind', n_colors=len(plot_df['weak_model_size'].unique()) - 1)\n",
" color_dict = {model: (\"black\" if model == 'ground truth' else palette.pop()) for model in plot_df['weak_model_size'].unique()}\n",
"\n",
" sns.lineplot(data=plot_df, x='strong_model_accuracy', y='accuracy', hue='weak_model_size', style='loss', markers=True, palette=color_dict)\n",
" pd.plotting.table(plt.gca(), pgr_results.round(4), loc='lower right', colWidths=[0.1, 0.1], cellLoc='center', rowLoc='center')\n",
" plt.xticks(ticks=base_accuracies['accuracy'], labels=[f\"{e} ({base_accuracy_lookup[e]:.4f})\" for e in base_accuracies['strong_model_size']], rotation=90)\n",
" plt.title(f\"Dataset: {dataset} (seed: {seed})\")\n",
" plt.legend(loc='upper left')\n",
" plt.savefig(f\"{dataset.replace('/', '-')}_{seed}.png\", dpi=300, bbox_inches='tight')\n",
" plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "openai",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Binary file added notebooks/amazon_polarity_None.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added notebooks/sciq_None.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 18 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "weak_to_strong"
version = "0.0.1"
authors = [
{ name="OpenAI", email="[email protected]" },
]
description = "Weak-to-strong generalization"
readme = "README.md"
requires-python = ">=3.7"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
8 changes: 8 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
torch~=2.1
numpy~=1.24
transformers~=4.36
datasets~=2.14
fire~=0.4
accelerate~=0.25
transformers-stream-generator~=0.0.4
torch_optimizer~=0.3
12 changes: 12 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import setuptools

setuptools.setup(
name="weak_to_strong",
version="0.1",
description="Weak-to-strong generalization",
url="#",
author="OpenAI",
author_email="[email protected]",
packages=setuptools.find_packages(),
zip_safe=False,
)
Loading

0 comments on commit 1dbfbd6

Please sign in to comment.