Initial commit
g-benton committed Feb 25, 2021
0 parents commit 882584b
Showing 74 changed files with 9,413 additions and 0 deletions.
57 changes: 57 additions & 0 deletions README.md
# Loss Surface Simplexes

This repository contains the code accompanying our paper _Loss Surface Simplexes for Mode Connecting Volumes and Fast Ensembling_ by Greg Benton, Wesley Maddox, Sanae Lotfi, and Andrew Gordon Wilson.

## Introduction

The repository holds the implementation of Simplicial Pointwise Random Optimization (SPRO), a method for finding simplicial complexes of low loss in the parameter space of neural networks. This work extends the approach of [Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs](https://arxiv.org/abs/1802.10026) by Garipov et al., allowing us to find not just mode-connecting paths but large, intricate volumes of low loss that connect independently trained models in parameter space.

<p float="center">
<img src="./plots/vggc10_mode_conn.jpg" width="400" />
<img src="./plots/extended_c10.jpg" width="300" />
</p>

The left plot above shows loss surface projections along the edges and faces of a simplicial complex of low loss, where <img src="https://render.githubusercontent.com/render/math?math=w_0"> and <img src="https://render.githubusercontent.com/render/math?math=w_1"> are VGG16 models independently trained on CIFAR-10, and <img src="https://render.githubusercontent.com/render/math?math=\theta_0"> and <img src="https://render.githubusercontent.com/render/math?math=\theta_1"> are connecting points trained with SPRO. The right plot shows a 3-dimensional projection of a simplicial complex containing 7 independently trained modes (orange) and 9 connecting points (blue), forming a total of 12 interconnected simplexes (blue shading) in the parameter space of a VGG16 network trained on CIFAR-100.

Beyond finding multidimensional simplexes of low loss in parameter space, our method provides a practical way to improve accuracy and calibration over standard deep ensembles (a sketch of the test-time averaging follows the figure below).

<p float="center">
<img src="./plots/cifar-simplex-acc.jpg" width="1000" />
</p>
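
At test time an ESPRO ensemble averages predictions over its members, and every forward pass through a simplex draws a new point from inside it. The snippet below is a minimal sketch of that averaging under assumed names (`simplex_models`, `inputs`); it is not the repository's exact API.

```python
import torch

@torch.no_grad()
def espro_predict(simplex_models, inputs, n_samples=5):
    # Average softmax outputs over ensemble members and over several
    # parameter draws from each member's simplex.
    probs = []
    for model in simplex_models:       # one SimplexNet per ensemble component
        for _ in range(n_samples):     # each call samples a new point in the simplex
            probs.append(torch.softmax(model(inputs), dim=-1))
    return torch.stack(probs).mean(dim=0)
```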

### Dependencies
* [PyTorch](https://pytorch.org/)
* [Torchvision](https://github.com/pytorch/vision/)
* [Tabulate](https://pypi.python.org/pypi/tabulate/)
* [GPyTorch](https://github.com/cornellius-gp/gpytorch)


## Usage and General Repository Structure

There are two main directories:
- `experiments` contains the main directories to reproduce the core experiments of the paper,
- `simplex` contains the core code, including model definitions and utilities such as training and evaluation functions.

### Experiments

#### `vgg-cifar10` and `vgg-cifar100`

These directories operate very similarly. Each contains several scripts: `base_trainer.py` for training VGG-style networks on the corresponding dataset, `simplex_trainer.py` for training a simplex starting from one of the pre-trained models, and `complex_iterative.py`, which takes in a set of pre-trained models and computes a mode-connecting simplicial complex containing them. Finally, these directories contain `ensemble_trainer.py`, which can be used to train an ensemble of simplexes in a single command.

The specific commands to reproduce the results in the paper can be found in the READMEs in the corresponding directories.

### Simplex

The core class definition here is `SimplexNet`, contained in `simplex/models/simplex_models.py`, which is an adaptation of the `CurveNet` class used in [Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs](https://arxiv.org/abs/1802.10026) and its accompanying codebase [here](https://github.com/timgaripov/dnn-mode-connectivity). The `SimplexNet` class serves as a wrapper for a network such as `VGG16Simplex` in `simplex/models/vgg_noBN.py`, which is stored as the attribute `SimplexNet.net`.

The networks that have simplex base components (such as `VGG16Simplex`) hold each of the vertices of the simplex as their parameters and keep a list of `fix_points` that determines which of these vertices should receive gradients during training.
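
As a rough illustration (assumed names, not the repository's exact classes), a simplex-aware layer can keep one weight tensor per vertex and use `fix_points` to freeze the vertices that should stay at their pre-trained values:

```python
import torch
from torch import nn

class SimplexLinearSketch(nn.Module):
    """Illustrative only: one weight tensor per simplex vertex."""
    def __init__(self, in_features, out_features, fix_points):
        super().__init__()
        self.n_vert = len(fix_points)
        for i, fixed in enumerate(fix_points):
            w = nn.Parameter(torch.randn(out_features, in_features) * 0.01,
                             requires_grad=not fixed)  # fixed vertices receive no gradients
            self.register_parameter(f"vertex_{i}", w)

    def forward(self, x, coeffs):
        # coeffs: barycentric coordinates (non-negative, summing to one)
        weight = sum(c * getattr(self, f"vertex_{i}") for i, c in enumerate(coeffs))
        return nn.functional.linear(x, weight)
```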

By default, a forward call to a `SimplexNet` instance samples a set of parameters from within the simplex defined by the weight vectors in `SimplexNet.net` and computes the output using those parameters. Practically, this means that repeated calls to `SimplexNet(input)` will return different results each time.
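
The sampling step can be pictured as drawing barycentric coordinates uniformly over the simplex and mixing the vertex weights with them; the snippet below sketches that idea and is not the exact implementation.

```python
import torch

def sample_barycentric(n_vert):
    # A flat Dirichlet(1, ..., 1) is the uniform distribution over the simplex.
    return torch.distributions.Dirichlet(torch.ones(n_vert)).sample()

# Two independent draws generally differ, which is why repeated calls to
# SimplexNet(input) return different outputs.
coeffs_a = sample_barycentric(3)
coeffs_b = sample_barycentric(3)
```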

### Relevant Papers

[Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs](https://arxiv.org/pdf/1802.10026.pdf) by Timur Garipov, Pavel Izmailov, Dmitrii Podoprikhin, Dmitry Vetrov, Andrew Gordon Wilson

[Essentially No Barriers in Neural Network Energy Landscape](https://arxiv.org/pdf/1803.00885.pdf) by Felix Draxler, Kambis Veschgini, Manfred Salmhofer, Fred A. Hamprecht

[Large Scale Structure of Neural Network Loss Landscapes](https://arxiv.org/pdf/1906.04724.pdf) by Stanislav Fort, Stanislaw Jastrzebski
60 changes: 60 additions & 0 deletions experiments/vgg-cifar10/.ipynb_checkpoints/README-checkpoint.md
# Training Classifiers and Simplexes

#### Ensemble Trainer

The core method introduced in the paper is ESPRO, an ensemble of simplexes.

To train an ESPRO ensemble as in the paper, run the following:

```bash
python3 ensemble_trainer.py --data_path=<YOUR DATA PATH> \
--base_epochs=300 \
--simplex_epochs=10 \
--base_lr=0.05 \
--simplex_lr=0.01 \
--wd=5e-4 \
--LMBD=1e-6 \
--n_component=<NUMBER OF ENSEMBLE COMPONENTS> \
--n_verts=<NUMBER OF VERTICES PER SIMPLEX>
```



#### Base Models

To train the base models, run:
```bash
python3 base_trainer.py --data_path=<YOUR DATA PATH> \
--epochs=300 \
--lr_init=0.05 \
--wd=5e-4
```
Each time you run this, a new model will be trained and saved in the `saved-outputs` folder.


#### Simplexes
To use pretrained base models to train simplexes around the SGD-found solutions, run:

```bash
python3 simplex_trainer.py --data_path=<YOUR DATA PATH> \
--epochs=10 \
--lr_init=0.01 \
--wd=5e-4 \
--base_idx=<INDEX OF PRETRAINED MODEL> \
--n_verts=<TOTAL NUMBER OF VERTICES IN SIMPLEX> \
--n_sample=5
```
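
Here `--n_sample` is read as the number of points drawn from the simplex on each batch when estimating the training loss; a hedged sketch of that estimate (illustrative names, not the script's exact code):

```python
def simplex_batch_loss(simplex_model, x, y, criterion, n_sample=5):
    # Each forward pass samples a fresh point inside the simplex, so averaging
    # over n_sample passes gives a Monte Carlo estimate of the loss over the simplex.
    loss = 0.0
    for _ in range(n_sample):
        loss = loss + criterion(simplex_model(x), y)
    return loss / n_sample
```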

#### Mode Connecting Complexes

To find simplicial complexes that connect modes in parameter space, use:
```bash
python3 complex_iterative.py --data_path=<YOUR DATA PATH> \
--epochs=25 \
--lr_init=0.01 \
--wd=5e-4 \
--n_verts=<HOW MANY PRETRAINED MODELS TO CONNECT> \
--n_connector=<HOW MANY CONNECTING POINTS TO USE> \
--n_sample=5
```
Here `n_verts` is the number of independently trained models you want to connect through `n_connector` points.
148 changes: 148 additions & 0 deletions experiments/vgg-cifar10/.ipynb_checkpoints/base_trainer-checkpoint.py
import math
import os
import torch
from torch import nn
import numpy as np
import pandas as pd
import argparse

from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import glob

import tabulate

import sys
sys.path.append("../../simplex/")
import utils
from simplex_helpers import volume_loss
import surfaces
import time
sys.path.append("../../simplex/models/")
from vgg_noBN import VGG16


def main(args):
    trial_num = len(glob.glob("./saved-outputs/model_*"))
    savedir = "./saved-outputs/model_" + str(trial_num) + "/"
    os.makedirs(savedir, exist_ok=True)

    # Standard CIFAR-10 augmentation; the normalization constants are the
    # common ImageNet channel statistics.
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    dataset = torchvision.datasets.CIFAR10(args.data_path,
                                           train=True, download=False,
                                           transform=transform_train)
    trainloader = DataLoader(dataset, shuffle=True, batch_size=args.batch_size)

    testset = torchvision.datasets.CIFAR10(args.data_path,
                                           train=False, download=False,
                                           transform=transform_test)
    testloader = DataLoader(testset, shuffle=True, batch_size=args.batch_size)

    model = VGG16(10)
    model = model.cuda()

    ## training setup ##
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr_init,
        momentum=0.9,
        weight_decay=args.wd
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    criterion = torch.nn.CrossEntropyLoss()

    ## train ##
    columns = ['ep', 'lr', 'tr_loss', 'tr_acc', 'te_loss', 'te_acc', 'time']
    for epoch in range(args.epochs):
        time_ep = time.time()
        train_res = utils.train_epoch(trainloader, model, criterion, optimizer)

        # Evaluate on the test set every eval_freq epochs, plus the first and last epoch.
        if epoch == 0 or epoch % args.eval_freq == args.eval_freq - 1 or epoch == args.epochs - 1:
            test_res = utils.eval(testloader, model, criterion)
        else:
            test_res = {'loss': None, 'accuracy': None}

        time_ep = time.time() - time_ep

        lr = optimizer.param_groups[0]['lr']
        scheduler.step()

        values = [epoch + 1, lr, train_res['loss'], train_res['accuracy'],
                  test_res['loss'], test_res['accuracy'], time_ep]

        # Print a progress row; re-emit the tabulate header every 40 epochs for readability.
        table = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
        if epoch % 40 == 0:
            table = table.split('\n')
            table = '\n'.join([table[1]] + table)
        else:
            table = table.split('\n')[2]
        print(table, flush=True)

    # Save the trained model in the directory created at the start of the run.
    checkpoint = model.state_dict()
    torch.save(checkpoint, savedir + "base_model.pt")


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="cifar10 simplex")

    parser.add_argument(
        "--batch_size",
        type=int,
        default=128,
        metavar="N",
        help="input batch size (default: 128)",
    )

    parser.add_argument(
        "--lr_init",
        type=float,
        default=0.05,
        metavar="LR",
        help="initial learning rate (default: 0.05)",
    )
    parser.add_argument(
        "--data_path",
        default="/datasets/",
        help="directory where datasets are stored",
    )

    parser.add_argument(
        "--wd",
        type=float,
        default=5e-4,
        metavar="weight_decay",
        help="weight decay",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=300,
        metavar="epochs",
        help="number of training epochs",
    )
    parser.add_argument(
        '--eval_freq',
        type=int,
        default=5,
        metavar='N',
        help='evaluation frequency (default: 5)'
    )
    args = parser.parse_args()

    main(args)