
Commit

Merge branch 'main' into hotfix/invalid_dtype_kwargs
WuTheFWasThat committed Jan 23, 2024
2 parents c01b150 + 7055122 commit 6c908db
Showing 14 changed files with 614 additions and 54 deletions.
62 changes: 52 additions & 10 deletions README.md
@@ -1,4 +1,4 @@
**STATUS**: This codebase is not well tested and does not use the exact same settings we used in the paper, but in our experience gives qualitatively similar results when using large model size gaps and multiple seeds. Expected results can be found for two datasets below. We may update the code significantly in the coming week.
**STATUS**: This codebase is not well tested and does not use the exact same settings we used in the paper, but in our experience gives qualitatively similar results when using large model size gaps and multiple seeds. Expected results can be found for two datasets below.

# Weak-to-strong generalization

@@ -24,24 +24,45 @@ pip install .

#### Running the Script

The main script of the project is train_weak_to_strong.py. It can be run from the command line using the following command:
The main script of the project is `sweep.py`. It can be run from the command line using the following command:
```
python train_weak_to_strong.py
python sweep.py --model_sizes=gpt2,gpt2-medium
```

The script accepts several command-line arguments to customize the training process. Here are some examples:
In addition to `--model_sizes`, `sweep.py` takes in almost all of the arguments that `train_simple.py` takes (e.g.
`--batch_size`, `--n_docs`, `--n_test_docs` etc., see `train_simple.py` for a full list). These arguments are simply
forwarded to `train_simple.py`.
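
For example, a few of these forwarded arguments can be passed alongside `--model_sizes` (an illustrative invocation; the specific values are placeholders, not recommended settings):
```
python sweep.py --model_sizes=gpt2,gpt2-medium --batch_size 32 --n_docs 1000 --n_test_docs 100
```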

```
python train_weak_to_strong.py --batch_size 32 --max_ctx 512 --ds_name "sciq" --loss "logconf" --n_docs 1000 --n_test_docs 100 --weak_model_size "gpt2-medium" --strong_model_size "gpt2-large" --seed 42
```
`sweep.py` calls `train_simple.py` in the following way:
1. First, it calls `train_simple.py` for each model size to train the ground truth models.
2. Then, for each pair of weak and strong models in `model_sizes` (where a model can be the strong model in the pair
only if its index in the `model_sizes` list is >= the index of the weak model), it calls `train_simple.py` with a
`--weak_model_size` argument so that the strong model is trained with the labels of the weak model.

E.g. the example above will run gpt2 (ground truth), gpt2-medium (ground truth), gpt2 -> gpt2, gpt2 -> gpt2-medium, and
gpt2-medium -> gpt2-medium.
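
To make the pairing rule concrete, here is a minimal sketch of the resulting run schedule (illustrative only, not code from the repository):
```
# Illustrative sketch of the run schedule described above (not the actual sweep.py code).
model_sizes = ["gpt2", "gpt2-medium"]

# Step 1: one ground-truth run per model size.
ground_truth_runs = list(model_sizes)

# Step 2: weak -> strong pairs, where the strong model's index must be >= the weak model's index.
transfer_runs = [
    (weak, strong)
    for i, weak in enumerate(model_sizes)
    for strong in model_sizes[i:]
]

print(ground_truth_runs)  # ['gpt2', 'gpt2-medium']
print(transfer_runs)      # [('gpt2', 'gpt2'), ('gpt2', 'gpt2-medium'), ('gpt2-medium', 'gpt2-medium')]
```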

If needed, you can also run `train_simple.py` directly.

Note that `sweep.py` will not accept the arguments `--weak_model_size`, `--weak_labels_path` or `--model_size` (as opposed
to `--model_sizes`, with an "s"), since choosing their values automatically is precisely the point of `sweep.py`.

An example Jupyter notebook for plotting results can be found in `notebooks/Plotting.ipynb`.

At the time of release, the main script was called `train_weak_to_strong.py`, but it was less usable than
`sweep.py` and `train_simple.py`. It is preserved here and the old instructions are given at the end of the document.

#### Expected results

<img src="notebooks/amazon_polarity_None.png" width="350">
<img src="notebooks/amazon_polarity.png" width="350">
<br>
<img src="notebooks/sciq_None.png" width="350">
<img src="notebooks/anthropic_hh.png" width="350">
<br>
<img src="notebooks/Anthropic-hh-rlhf_None.png" width="350">
<img src="notebooks/boolq.png" width="350">
<br>
<img src="notebooks/cosmos_qa.png" width="350">
<br>
<img src="notebooks/sciq.png" width="350">

### Authors

@@ -58,3 +79,24 @@ This project is licensed under the MIT License - see the LICENSE.md file for det
### Acknowledgments

- Hugging Face for their open-source transformer models

### Original single run script

You can run the original training script using:
```
python train_weak_to_strong.py
```

The script accepts several command-line arguments to customize the training process. Here are some examples:

```
python train_weak_to_strong.py --batch_size 32 --max_ctx 512 --ds_name "sciq" --loss "logconf" --n_docs 1000 --n_test_docs 100 --weak_model_size "gpt2-medium" --strong_model_size "gpt2-large" --seed 42
```

`notebooks/Plotting_old.ipynb` preserves the plotting notebook corresponding to the old-style training.

The key difference between this style and the new `sweep.py` style is that `train_weak_to_strong.py` will always
train three models: a weak model, a transfer model, and a strong model. `sweep.py` optimizes this by training
a series of ground truth models (which will serve as weak and strong models) as well as a series of transfer models
all in one go. This reduces training duplication and is arguably simpler. The files generated by `train_simple.py`
and `sweep.py` are also simpler to use.
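
As a rough illustration of the duplication point (an assumption-based sketch, not a measurement from the code; it assumes every weak -> strong pair, including self-pairs, is covered in both workflows):
```
# Rough count of training runs needed to cover all weak -> strong pairs (assumptions noted above).
def old_style_trainings(n_models: int) -> int:
    # train_weak_to_strong.py trains 3 models per invocation (weak, transfer, strong),
    # so the ground-truth models get retrained for every pair.
    n_pairs = n_models * (n_models + 1) // 2
    return 3 * n_pairs

def sweep_trainings(n_models: int) -> int:
    # sweep.py trains each ground-truth model once, plus one transfer model per pair.
    n_pairs = n_models * (n_models + 1) // 2
    return n_models + n_pairs

print(old_style_trainings(2), sweep_trainings(2))  # 9 5 (for gpt2 and gpt2-medium)
```
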
Binary file removed notebooks/Anthropic-hh-rlhf_None.png
73 changes: 31 additions & 42 deletions notebooks/Plotting.ipynb
@@ -15,7 +15,7 @@
"metadata": {},
"outputs": [],
"source": [
"RESULTS_PATH = \"../../your_sweep_results_path\"\n",
"RESULTS_PATH = \"../../your_sweep_path/default\"\n",
"\n",
"PLOT_ALL_SEEDS = False\n",
"# Full sweep\n",
@@ -52,20 +52,24 @@
"outputs": [],
"source": [
"records = []\n",
"all_results_folders = ['/'.join(e.split('/')[:-1]) for e in glob.glob(os.path.join(RESULTS_PATH, \"**/*.results_summary.json\"), recursive=True)]\n",
"for result_folder in set(all_results_folders):\n",
" config_file = os.path.join(result_folder, \"config.json\")\n",
"for result_filename in glob.glob(os.path.join(RESULTS_PATH, \"**/results_summary.json\"), recursive=True):\n",
" config_file = os.path.join(\"/\".join(result_filename.split(\"/\")[:-1]), \"config.json\")\n",
" config = json.load(open(config_file, \"r\"))\n",
" if config[\"strong_model_size\"] not in MODELS_TO_PLOT:\n",
" if config[\"model_size\"] not in MODELS_TO_PLOT:\n",
" continue\n",
" if 'seed' not in config:\n",
" config['seed'] = 0\n",
" result_filename = (config[\"weak_model_size\"].replace('.', '_') + \"_\" + config[\"strong_model_size\"].replace('.', '_') + \".results_summary.json\").replace('/', '_')\n",
" record = config.copy()\n",
" record.update(json.load(open(config_file.replace('config.json', result_filename))))\n",
" if 'weak_model' in config:\n",
" for k in record['weak_model']:\n",
" if k == 'model_size':\n",
" assert record['weak_model'][k] == record['weak_model_size']\n",
" record['weak_' + k] = record['weak_model'][k]\n",
" del record['weak_model']\n",
" record.update(json.load(open(result_filename)))\n",
" records.append(record)\n",
"\n",
"df = pd.DataFrame.from_records(records).sort_values(['ds_name', 'weak_model_size', 'strong_model_size'])"
"df = pd.DataFrame.from_records(records).sort_values(['ds_name', 'model_size'])"
]
},
{
@@ -77,60 +77,45 @@
"source": [
"datasets = df.ds_name.unique()\n",
"for dataset in datasets:\n",
" cur_df = df[(df.ds_name == dataset)]\n",
" base_df = pd.concat([\n",
" pd.DataFrame.from_dict({\"strong_model_size\": cur_df['weak_model_size'].to_list(), \"accuracy\": cur_df['weak_acc'].to_list(), \"seed\": cur_df['seed'].to_list()}),\n",
" pd.DataFrame.from_dict({\"strong_model_size\": cur_df['strong_model_size'].to_list(), \"accuracy\": cur_df['strong_acc'].to_list(), \"seed\": cur_df['seed'].to_list()})\n",
" ])\n",
" base_accuracies = base_df.groupby('strong_model_size').agg({'accuracy': 'mean', 'seed': 'count'}).sort_values('accuracy')\n",
" cur_df = df[(df.ds_name == dataset)].copy()\n",
" base_accuracies = cur_df[cur_df['weak_model_size'].isna()].groupby('model_size').agg({'accuracy': 'mean', 'seed': 'count'}).sort_values('accuracy')\n",
" base_accuracy_lookup = base_accuracies['accuracy'].to_dict()\n",
" base_accuracies = base_accuracies.reset_index()\n",
" base_df.reset_index(inplace=True)\n",
" base_df['weak_model_size'] = 'ground truth'\n",
" base_df['loss'] = 'xent'\n",
" base_df['strong_model_accuracy'] = base_df['strong_model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
"\n",
" weak_to_strong = cur_df[['weak_model_size', 'strong_model_size', 'seed'] + [e for e in cur_df.columns if e.startswith('transfer_acc')]]\n",
" weak_to_strong = weak_to_strong.melt(id_vars=['weak_model_size', 'strong_model_size', 'seed'], var_name='loss', value_name='accuracy')\n",
" weak_to_strong = weak_to_strong.dropna(subset=['accuracy'])\n",
" weak_to_strong.reset_index(inplace=True)\n",
" weak_to_strong['loss'] = weak_to_strong['loss'].str.replace('transfer_acc_', '')\n",
" weak_to_strong['strong_model_accuracy'] = weak_to_strong['strong_model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
" cur_df['strong_model_accuracy'] = cur_df['model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
" cur_df.loc[~cur_df['weak_model_size'].isna(), 'weak_model_accuracy'] = cur_df.loc[~cur_df['weak_model_size'].isna(), 'weak_model_size'].apply(lambda x: base_accuracy_lookup[x])\n",
"\n",
" # Exclude cases where the weak model is better than the strong model from PGR calculation.\n",
" pgr_df = cur_df[(cur_df['weak_model_size'] != cur_df['strong_model_size']) & (cur_df['strong_acc'] > cur_df['weak_acc'])]\n",
" pgr_df = pgr_df.melt(id_vars=[e for e in cur_df.columns if not e.startswith('transfer_acc')], var_name='loss', value_name='transfer_acc')\n",
" pgr_df = pgr_df.dropna(subset=['transfer_acc'])\n",
" pgr_df['loss'] = pgr_df['loss'].str.replace('transfer_acc_', '')\n",
" pgr_df['pgr'] = (pgr_df['transfer_acc'] - pgr_df['weak_acc']) / (pgr_df['strong_acc'] - pgr_df['weak_acc'])\n",
" valid_pgr_index = (\n",
" (~cur_df['weak_model_size'].isna()) & \n",
" (cur_df['weak_model_size'] != cur_df['model_size']) & \n",
" (cur_df['strong_model_accuracy'] > cur_df['weak_model_accuracy'])\n",
" )\n",
" cur_df.loc[valid_pgr_index, 'pgr'] = (cur_df.loc[valid_pgr_index, 'accuracy'] - cur_df.loc[valid_pgr_index, 'weak_model_accuracy']) / (cur_df.loc[valid_pgr_index, 'strong_model_accuracy'] - cur_df.loc[valid_pgr_index, 'weak_model_accuracy'])\n",
"\n",
" cur_df.loc[cur_df['weak_model_size'].isna(), \"weak_model_size\"] = \"ground truth\"\n",
"\n",
" for seed in [None] + (sorted(cur_df['seed'].unique().tolist()) if PLOT_ALL_SEEDS else []):\n",
" plot_df = pd.concat([base_df, weak_to_strong])\n",
" seed_pgr_df = pgr_df\n",
" plot_df = cur_df.copy().sort_values(['strong_model_accuracy']).sort_values(['loss'], ascending=False)\n",
" if seed is not None:\n",
" plot_df = plot_df[plot_df['seed'] == seed]\n",
" # We mean across seeds, this is because sometimes the weak and strong models will have run on different hardware and therefore\n",
" # have slight differences. We want to average these out when filtering by seed.\n",
"\n",
" seed_pgr_df = pgr_df[pgr_df['seed'] == seed]\n",
"\n",
" if seed is not None or cur_df['seed'].nunique() == 1:\n",
" plot_df = plot_df[['strong_model_accuracy', 'weak_model_size', 'loss', 'accuracy']].groupby(['strong_model_accuracy', 'weak_model_size', 'loss']).mean().reset_index().sort_values(['loss', 'weak_model_size'], ascending=False)\n",
"\n",
" print(f\"Dataset: {dataset} (seed: {seed})\")\n",
"\n",
" pgr_results = seed_pgr_df.groupby(['loss']).aggregate({\"pgr\": \"median\"})\n",
" display(pgr_results)\n",
" pgr_results = plot_df[~plot_df['pgr'].isna()].groupby(['loss']).aggregate({\"pgr\": \"median\"})\n",
"\n",
" palette = sns.color_palette('colorblind', n_colors=len(plot_df['weak_model_size'].unique()) - 1)\n",
" color_dict = {model: (\"black\" if model == 'ground truth' else palette.pop()) for model in plot_df['weak_model_size'].unique()}\n",
"\n",
" sns.lineplot(data=plot_df, x='strong_model_accuracy', y='accuracy', hue='weak_model_size', style='loss', markers=True, palette=color_dict)\n",
" pd.plotting.table(plt.gca(), pgr_results.round(4), loc='lower right', colWidths=[0.1, 0.1], cellLoc='center', rowLoc='center')\n",
" plt.xticks(ticks=base_accuracies['accuracy'], labels=[f\"{e} ({base_accuracy_lookup[e]:.4f})\" for e in base_accuracies['strong_model_size']], rotation=90)\n",
" plt.xticks(ticks=base_accuracies['accuracy'], labels=[f\"{e} ({base_accuracy_lookup[e]:.4f})\" for e in base_accuracies['model_size']], rotation=90)\n",
" plt.title(f\"Dataset: {dataset} (seed: {seed})\")\n",
" plt.legend(loc='upper left')\n",
" plt.savefig(f\"{dataset.replace('/', '-')}_{seed}.png\", dpi=300, bbox_inches='tight')\n",
" suffix = \"\"\n",
" if seed is not None:\n",
" suffix = f\"_{seed}\"\n",
" plt.savefig(f\"{dataset.replace('/', '-')}{suffix}.png\", dpi=300, bbox_inches='tight')\n",
" plt.show()"
]
}
