**`README.md`**

# Loss Surface Simplexes

This repository contains the code accompanying our paper _Loss Surface Simplexes for Mode Connecting Volumes and Fast Ensembling_ by Greg Benton, Wesley Maddox, Sanae Lotfi, and Andrew Gordon Wilson.

## Introduction

This repository holds the implementation of Simplicial Pointwise Random Optimization (SPRO), a method for finding simplicial complexes of low loss in the parameter space of neural networks. This work extends the approach of [Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs](https://arxiv.org/abs/1802.10026) by Garipov et al., allowing us to find not just mode-connecting paths but large, intricate volumes of low loss that connect independently trained models in parameter space.
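Concretely, a point in a simplex with vertices $w_0, \dots, w_k$ is a convex combination of those vertices, and SPRO trains the vertex locations so that points sampled this way have low loss:

$$
w = \sum_{j=0}^{k} \lambda_j w_j, \qquad \lambda_j \ge 0, \qquad \sum_{j=0}^{k} \lambda_j = 1 .
$$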

<p float="center">
  <img src="./plots/vggc10_mode_conn.jpg" width="400" />
  <img src="./plots/extended_c10.jpg" width="300" />
</p>

The left plot above shows loss surface projections along the edges and faces of a low-loss simplicial complex, where $w_0$ and $w_1$ are independently trained VGG16 models on CIFAR-10, and $\theta_0$ and $\theta_1$ are connecting points trained with SPRO. The right plot shows a three-dimensional projection of a simplicial complex containing 7 independently trained modes (orange) and 9 connecting points (blue), forming a total of 12 interconnected simplexes (blue shading) in the parameter space of a VGG16 network trained on CIFAR-100.

Beyond finding multidimensional simplexes of low loss in parameter space, our method provides a practical way to improve accuracy and calibration over standard deep ensembles.

<p float="center">
  <img src="./plots/cifar-simplex-acc.jpg" width="1000" />
</p>

### Dependencies
* [PyTorch](https://pytorch.org/)
* [Torchvision](https://github.com/pytorch/vision/)
* [Tabulate](https://pypi.python.org/pypi/tabulate/)
* [GPyTorch](https://github.com/cornellius-gp/gpytorch)

## Usage and General Repository Structure

There are two main directories:
- `experiments` contains the main directories to reproduce the core experiments of the paper,
- `simplex` contains the core code, including model definitions and utilities such as training and evaluation functions.

### Experiments

#### `vgg-cifar10` and `vgg-cifar100`

These directories operate very similarly. Each contains several scripts: `base_trainer.py` for training VGG-style networks on the corresponding dataset, `simplex_trainer.py` to train a simplex starting from one of the pre-trained models, and `complex_iterative.py`, which takes in a set of pre-trained models and computes a mode-connecting simplicial complex containing them. Finally, these directories contain `ensemble_trainer.py`, which can be used to train an ensemble of simplexes with a single command.

The specific commands to reproduce the results in the paper can be found in the READMEs in the corresponding directories.

### Simplex

The core class definition here is `SimplexNet`, contained in `simplex/models/simplex_models.py`, an adaptation of the `CurveNet` class used in [Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs](https://arxiv.org/abs/1802.10026) and its accompanying codebase [here](https://github.com/timgaripov/dnn-mode-connectivity). The `SimplexNet` class serves as a wrapper for a network such as `VGG16Simplex` in `simplex/models/vgg_noBN.py`, which is stored as the attribute `SimplexNet.net`.

The networks that have simplex base components (such as `VGG16Simplex`) hold each of the vertices of the simplex as their parameters and keep a list of `fix_points` that determines which of these vertices should receive gradients during training.
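For illustration, freezing a pre-trained vertex while training a new one might look like the following. This is a minimal sketch, assuming a `SimplexNet(n_output, architecture, n_vert, fix_points)` constructor; check `simplex/models/simplex_models.py` for the actual signature.

```python
# Minimal sketch -- the constructor arguments here are assumptions, not a
# verified API; see simplex/models/simplex_models.py for the real signature.
import sys
sys.path.append("simplex/")
sys.path.append("simplex/models/")
from simplex_models import SimplexNet
from vgg_noBN import VGG16Simplex

# Two vertices: the first (e.g. a pre-trained model) is fixed and receives
# no gradients; the second is free and is trained by SPRO.
simplex_model = SimplexNet(10, VGG16Simplex, n_vert=2,
                           fix_points=[True, False])
```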

By default, a forward call to a `SimplexNet` instance samples a set of parameters from within the simplex defined by the weight vectors in `SimplexNet.net` and computes an output using those sampled parameters. Practically, this means that repeated calls to `SimplexNet(input)` return different results each time.
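Continuing the sketch above, repeated calls on the same batch will generally disagree, and averaging softmax outputs over several sampled parameter vectors gives a simple within-simplex ensemble:

```python
import torch

x = torch.randn(4, 3, 32, 32)   # a CIFAR-sized input batch
out1 = simplex_model(x)
out2 = simplex_model(x)          # a new point is sampled, so out2 != out1 in general

# Cheap ensembling: average predictions over several sampled points.
with torch.no_grad():
    probs = torch.stack(
        [simplex_model(x).softmax(-1) for _ in range(5)]
    ).mean(0)
```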

### Relevant Papers

[Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs](https://arxiv.org/pdf/1802.10026.pdf) by Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov, and Andrew Gordon Wilson

[Essentially No Barriers in Neural Network Energy Landscape](https://arxiv.org/pdf/1803.00885.pdf) by Felix Draxler, Kambis Veschgini, Manfred Salmhofer, and Fred A. Hamprecht

[Large Scale Structure of Neural Network Loss Landscapes](https://arxiv.org/pdf/1906.04724.pdf) by Stanislav Fort and Stanislaw Jastrzebski

**`experiments/vgg-cifar10/.ipynb_checkpoints/README-checkpoint.md`**

# Training Classifiers and Simplexes

#### Ensemble Trainer

The core method introduced in the paper is ESPRO, which functions as an ensemble of simplexes.

To train an ESPRO ensemble as is done in the paper, run the following:

```bash
python3 ensemble_trainer.py --data_path=<YOUR DATA PATH> \
    --base_epochs=300 \
    --simplex_epochs=10 \
    --base_lr=0.05 \
    --simplex_lr=0.01 \
    --wd=5e-4 \
    --LMBD=1e-6 \
    --n_component=<NUMBER OF ENSEMBLE COMPONENTS> \
    --n_verts=<NUMBER OF VERTICES PER SIMPLEX>
```
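At test time an ESPRO ensemble averages predictions over its component simplexes, sampling several points within each. Below is a minimal sketch of that prediction step; `espro_predict` is a hypothetical helper (not a function in this repo), and it assumes each trained component behaves like a `SimplexNet` whose forward call samples a point in its simplex.

```python
import torch

def espro_predict(components, x, n_sample=5):
    """Hypothetical helper: average softmax outputs over ensemble
    components and over n_sample sampled points per simplex."""
    with torch.no_grad():
        probs = torch.stack([
            comp(x).softmax(-1)   # each forward call samples a new simplex point
            for comp in components
            for _ in range(n_sample)
        ])
    return probs.mean(0)
```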

#### Base Models

To train the base models, run:

```bash
python3 base_trainer.py --data_path=<YOUR DATA PATH> \
    --epochs=300 \
    --lr_init=0.05 \
    --wd=5e-4
```

Each time you run this, a new model is trained and saved in the `saved-outputs` folder.
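Checkpoints follow the `saved-outputs/model_<i>/base_model.pt` layout used by `base_trainer.py` (see the script below), so a saved base model can be restored along these lines; the index `model_0` here is just an example:

```python
import sys
import torch

sys.path.append("../../simplex/models/")
from vgg_noBN import VGG16

model = VGG16(10)   # 10 classes for CIFAR-10
model.load_state_dict(torch.load("saved-outputs/model_0/base_model.pt"))
model.eval()
```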

#### Simplexes

To use pre-trained base models to train simplexes around the SGD-found solutions, run:

```bash
python3 simplex_trainer.py --data_path=<YOUR DATA PATH> \
    --epochs=10 \
    --lr_init=0.01 \
    --wd=5e-4 \
    --base_idx=<INDEX OF PRETRAINED MODEL> \
    --n_verts=<TOTAL NUMBER OF VERTICES IN SIMPLEX> \
    --n_sample=5
```
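For example, a hypothetical run that grows a 4-vertex simplex around pre-trained model 0 (the data path and vertex count here are illustrative, not the paper's settings):

```bash
python3 simplex_trainer.py --data_path=~/datasets \
    --epochs=10 \
    --lr_init=0.01 \
    --wd=5e-4 \
    --base_idx=0 \
    --n_verts=4 \
    --n_sample=5
```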

#### Mode Connecting Complexes

To find simplicial complexes that connect modes in parameter space, use:

```bash
python3 complex_iterative.py --data_path=<YOUR DATA PATH> \
    --epochs=25 \
    --lr_init=0.01 \
    --wd=5e-4 \
    --n_verts=<HOW MANY PRETRAINED MODELS TO CONNECT> \
    --n_connector=<HOW MANY CONNECTING POINTS TO USE> \
    --n_sample=5
```

Here `n_verts` is the number of independently trained models you want to connect through `n_connector` connecting points.

**`experiments/vgg-cifar10/.ipynb_checkpoints/base_trainer-checkpoint.py`**

```python
import math
import os
import sys
import time
import glob
import argparse

import torch
from torch import nn
import numpy as np
import pandas as pd

from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

import tabulate

sys.path.append("../../simplex/")
import utils
from simplex_helpers import volume_loss
import surfaces
sys.path.append("../../simplex/models/")
from vgg_noBN import VGG16
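
# Trains a single VGG16 (no batch norm) classifier on CIFAR-10 with SGD,
# cosine learning-rate annealing, and periodic test-set evaluation, saving
# the final weights to saved-outputs/model_<i>/base_model.pt.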

def main(args):
    # Each run gets its own output directory: ./saved-outputs/model_<i>/
    trial_num = len(glob.glob("./saved-outputs/model_*"))
    savedir = "./saved-outputs/model_" + str(trial_num) + "/"
    os.makedirs(savedir, exist_ok=True)

    # Standard CIFAR-10 augmentation and normalization.
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    dataset = torchvision.datasets.CIFAR10(args.data_path,
                                           train=True, download=False,
                                           transform=transform_train)
    trainloader = DataLoader(dataset, shuffle=True, batch_size=args.batch_size)

    testset = torchvision.datasets.CIFAR10(args.data_path,
                                           train=False, download=False,
                                           transform=transform_test)
    testloader = DataLoader(testset, shuffle=True, batch_size=args.batch_size)

    model = VGG16(10)
    model = model.cuda()

    ## training setup ##
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr_init,
        momentum=0.9,
        weight_decay=args.wd
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    criterion = torch.nn.CrossEntropyLoss()

    ## train ##
    columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']
    for epoch in range(args.epochs):
        time_ep = time.time()
        train_res = utils.train_epoch(trainloader, model, criterion, optimizer)

        # Evaluate on the test set every eval_freq epochs, plus first and last.
        if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
            test_res = utils.eval(testloader, model, criterion)
        else:
            test_res = {'loss': None, 'accuracy': None}

        time_ep = time.time() - time_ep

        lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        values = [epoch + 1, lr, train_res['loss'], train_res['accuracy'],
                  test_res['loss'], test_res['accuracy'], time_ep]

        table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
        if epoch % 40 == 0:
            # Reprint the header every 40 epochs so long logs stay readable.
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table, flush=True)

    # Save the final weights into the output directory created above.
    checkpoint = model.state_dict()
    torch.save(checkpoint, savedir + "base_model.pt")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="cifar10 simplex")

    parser.add_argument(
        "--batch_size",
        type=int,
        default=128,
        metavar="N",
        help="input batch size (default: 128)",
    )
    parser.add_argument(
        "--lr_init",
        type=float,
        default=0.05,
        metavar="LR",
        help="initial learning rate (default: 0.05)",
    )
    parser.add_argument(
        "--data_path",
        default="/datasets/",
        help="directory where datasets are stored",
    )
    parser.add_argument(
        "--wd",
        type=float,
        default=5e-4,
        metavar="weight_decay",
        help="weight decay (default: 5e-4)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=300,
        metavar="epochs",
        help="number of training epochs (default: 300)",
    )
    parser.add_argument(
        "--eval_freq",
        type=int,
        default=5,
        metavar="N",
        help="evaluation frequency (default: 5)",
    )
    args = parser.parse_args()

    main(args)
```