Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ohyeat committed Jan 19, 2020
0 parents commit 136491f
Show file tree
Hide file tree
Showing 279 changed files with 26,960 additions and 0 deletions.
21 changes: 21 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2020 Megvii Technology

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
117 changes: 117 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Moving Averge Batch Normalization

This reposity is the Pytorch implementation of **Moving Average Batch Normalization** on Imagenet classfication, COCO object detection and instance segmentation tasks. Notice the Imagenet classification simulate the small batch training settings by using small normalization batch size and regular SGD batch size.

The paper has been published as an ICLR2020 conference paper (https://openreview.net/forum?id=SkgGjRVKDS&noteId=BJeCWt3KiH).

## Results

### Overall comparation of MABN and its counterparts

Top 1 Error versus Batch Size:

<img src="https://uc23c04615403e17cfed78ccacc6.previews.dropboxusercontent.com/p/thumb/AApzTU_479fgUugJranC19juFKPjXUlOgbr2D6GpIje1c0f6aXL_VLQYmQ1ojgHTqSoPhXsDcyN7-dPXL3xgw9rXdckmGKgHR09Q5ihJAVeMUu5SwFcCmJHPidCPP9mo-ILrKKnInM6ohDdiV1541Ie9ozRy-PGHOlQ8zgu0z7JndBVBgw7Ave8gN1-ixsrPw-ANRBNZwmZTgfAha4BsCvJLU3E3pjaFm8BysoHn3UbHwOWaGdUuLDPDrfCNG0cWHBkqRJaimFhwG44y7-4W2wV_2_4V8s0Wj9vtkvWzIyGuOx6nWoqY0ID9iE9L8GhNZdJZ6qR6IE2hRVSenWmVVa43hbLj4QCWOOOLcjGDE0L4O_HwddjoFipHzSkLFsuZyD5NQiaWwasvNroMGm0G41jJ/p.png?fv_content=true&size_mode=5" width="500" height="350" />

Inference Speend

| Norm | Iterations/second |
|:-------:|:-------------:|
| BN/MABN | 237.88 |
| Instance Normalization | 105.60 |
| Group Normalization | 99.37 |
| Layer Normalization | 125.44 |

### Imagenet

| Model | Normalization Batch size | Norm | Top 1 Accuracy |
|:--------:|:-----------:|:----:|:------:|
| ResNet50 | 32 | BN | 23.41 |
| ResNet50 | 2 | BN | 35.22 |
| ResNet50 | 2 | BRN | 30.29 |
| ResNet50 | 2 | MABN | 23.67 |

### COCO
| Backbone | Method | Training Strategy | Norm | Batch Size | AP<sup>b</sup> | AP<sup>b</sup><sub>0.50</sub> | AP<sup>b</sup><sub>0.75</sub> | AP<sup>m</sup> | AP<sup>m</sup><sub>0.50</sub> | AP<sup>m</sup><sub>0.75</sub> |
|:-------------:|:------------:|:---------:|:----:|:------:|:----:|:----:|:----:|:----:|:----:|:----:|
| R50-FPN | Mask R-CNN | 2x from scratch | BN | 2 | 32.38 | 50.44 | 35.47 | 29.07 | 47.68 | 30.75 |
| R50-FPN | Mask R-CNN | 2x from scratch | BRN | 2 | 34.07 | 52.66 | 37.12 | 30.98 | 50.03 | 32.93 |
| R50-FPN | Mask R-CNN | 2x from scratch | SyncBN | 2x8 | 36.80 | 56.06 | 40.23 | 33.10 | 53.15 | 35.24 |
| R50-FPN | Mask R-CNN | 2x from scratch | MABN | 2 | 36.50 | 55.79 | 40.17 | 32.69 | 52.78 | 34.71 |
| R50-FPN | Mask R-CNN | 2x fine-tune | SyncBN | 2x8 | 38.25 | 57.81 | 42.01 | 34.22 | 54.97 | 36.34 |
| R50-FPN | Mask R-CNN | 2x fine-tune | MABN | 2 | 38.42 | 58.19 | 41.99 | 34.12 | 55.10 | 36.12 |


## Demo

One node with 8 GPUs.

### Imagenet
```bash
cd /your_path_to_repo/cls

# 8 GPUs Train and Test
python3 -m torch.distributed.launch --nproc_per_node=8 train.py --gpu_num=8 \
--save /your_path_to_logs \
--train_dir /your_imagenet_training_dataset_dir \
--val_dir /your_imagenet_eval_dataset_dir \
--gpu_num=8

# Only Test the trained model
python3 -m torch.distributed.launch --nproc_per_node=1 train.py --gpu_num=1 \
--save /your_path_to_logs \
--val_dir /your_imagenet_eval_dataset_dir \
--checkpoint_dir /your_path_to_checkpoint \
--test_only
```

### COCO
Please refer to [INSTALL.md](det/INSTALL.md) for installation and dataset preparation.

To use SyncBN, please do:
```bash
cd /your_path_to_repo/det/maskrcnn_benchmar/distributed_syncbn
bash compile.sh
```
You can download the pretrained model of ResNet-50 in [here](https://www.dropbox.com/sh/fbsi6935vmatbi9/AAA2jv0EBcSgySTgZnNZ3lmPa?dl=0). Notice R-50-2.pkl is the pretrained model for SyncBN while R50-wc.pth is for MABN.


```bash
cd /your_path_to_repo/det
# Train MABN from scratch
python3 -m torch.distributed.launch --nproc_per_node=8 tools/train_net.py \
--skip-test \
--config-file configs/e2e_mask_rcnn_R_50_FPN_mabn_2x_from_scratch.yaml \
DATALOADER.NUM_WORKERS 2 \
OUTPUT_DIR /your_path_to_logs

# Train MABN fine tuning (Download the pertrained model and set the path in configs at first)
python3 -m torch.distributed.launch --nproc_per_node=8 tools/train_net.py \
--skip-test \
--config-file configs/e2e_mask_rcnn_R_50_FPN_mabn_2x_fine_tune.yaml \
DATALOADER.NUM_WORKERS 2 \
OUTPUT_DIR /your_path_to_logs

# Test model
python3 -m torch.distributed.launch --nproc_per_node=8 tools/test_net.py \
--config-file configs/e2e_mask_rcnn_R_50_FPN_mabn_2x_from_scratch.yaml \
MODEL.WEIGHT /your_path_to_logs/model_0180000.pth \
TEST.IMS_PER_BATCH 8
```

## Thanks

This implementation of COCO is based on [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark). Ref to this link for more details about [maskrcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark).

## Citation

If you use Moving Average Batch Normalization in your research, please cite:
```bibtex
@inproceedings{
yan2020towards,
title={Towards Stablizing Batch Statistics in Backward Propagation of Batch Normalization},
author={Junjie Yan, Ruosi Wan, Xiangyu Zhang, Wei Zhang, Yichen Wei, Jian Sun},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://openreview.net/forum?id=SkgGjRVKDS}
}
```
69 changes: 69 additions & 0 deletions cls/SGD.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import torch
from torch.optim.optimizer import Optimizer, required


class SGD(Optimizer):

def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")

super(SGD, self).__init__(params, defaults)

def __setstate__(self, state):
super(SGD, self).__setstate__(state)
for group in self.param_groups:
group.setdefault('nesterov', False)

def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()

for group in self.param_groups:
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']

for p in group['params']:
if p.grad is None:
continue
d_p = p.grad.data

if momentum != 0:
param_state = self.state[p]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
else:
buf = param_state['momentum_buffer']
buf.mul_(momentum).add_(1 - dampening, d_p)

if nesterov:
d_p = d_p.add(momentum, buf)
else:
d_p = buf
'''
weight decay is still included in momentum
'''
if weight_decay != 0:
d_p.add_(weight_decay, p.data)
p.data.add_(-group['lr'], d_p)

return loss
109 changes: 109 additions & 0 deletions cls/networks/MABN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class MABNFunction(torch.autograd.Function):

@staticmethod
def forward(ctx, x, weight, bias,
running_var, eps, momentum,
sta_matrix, pre_x2, pre_gz, iters
):
ctx.eps = eps
current_iter = iters.item()
ctx.iter = current_iter
N, C, H, W = x.size()

x = x.view(N//2, 2, C, H, W)
x2 = (x * x).mean(dim=4).mean(dim=3).mean(dim=1)
var = torch.cat([pre_x2, x2], dim=0)

var = torch.mm(sta_matrix, var)
var = var.view(N//2, 1, C, 1, 1)

if current_iter == 1:
var = x2.view(N//2, 1, C, 1, 1)

z = x /(var + eps).sqrt()
r = (var + eps).sqrt() / (running_var.view(1, 1, C, 1, 1) + eps).sqrt()
if current_iter < 100:
r = torch.clamp(r, 1, 1)
else:
r = torch.clamp(r, 1/5, 5)
y = r * z
ctx.save_for_backward(z, var, weight, sta_matrix, pre_gz, r)

if current_iter == 1:
running_var.copy_(var.mean(dim=0).view(-1,))
running_var.copy_(momentum*running_var + (1-momentum)*var.mean(dim=0).view(-1,))
pre_x2.copy_(x2)
y = weight.view(1,C,1,1) * y.view(N, C, H, W) + bias.view(1,C,1,1)

return y

@staticmethod
def backward(ctx, grad_output):
eps = ctx.eps
current_iter = ctx.iter
N, C, H, W = grad_output.size()
z, var, weight, sta_matrix, pre_gz, r = ctx.saved_variables
y = r * z
g = grad_output * weight.view(1, C, 1, 1)
g = g.view(N//2, 2, C, H, W) * r
gz = (g * z).mean(dim=4).mean(dim=3).mean(dim=1)

mean_gz = torch.cat([pre_gz, gz], dim=0)
mean_gz = torch.mm(sta_matrix, mean_gz)
mean_gz = mean_gz.view(N//2, 1, C, 1, 1)

if current_iter == 1:
mean_gz = gz.view(N//2, 1, C, 1, 1)
gx = 1. / torch.sqrt(var + eps) * (g - z * mean_gz)
gx = gx.view(N, C, H, W)
pre_gz.copy_(gz)

return gx, (grad_output * y.view(N, C, H, W)).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(dim=0), None, None, None, None, None, None, None

class MABN2d(nn.Module):

def __init__(self, channels, eps=1e-5, momentum=0.98, buffer_size=16):
"""
buffer_size: Moving Average Batch Size / Normalization Batch Size
running_var: EMA statistics of x^2
buffer_x2: batch statistics of x^2 from last several iters
buffer_gz: batch statistics of phi from last several iters
iters: current iter
"""
super(MABN2d, self).__init__()
self.B = buffer_size
self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
self.register_buffer('running_var', torch.ones(channels))
self.register_buffer('sta_matrix', torch.ones(self.B, 2 *self.B)/self.B)
self.register_buffer('pre_x2', torch.ones(self.B, channels))
self.register_buffer('pre_gz', torch.zeros(self.B, channels))
self.register_buffer('iters', torch.zeros(1,))
self.eps = eps
self.momentum = momentum
self.init()

def init(self):
for i in range(self.sta_matrix.size(0)):
self.sta_matrix[i][:i+1] = 0
self.sta_matrix[i][self.B+i+1:] = 0

def forward(self, x):
if self.training:
self.iters.copy_(self.iters + 1)
x = MABNFunction.apply(x, self.weight, self.bias,
self.running_var, self.eps,
self.momentum, self.sta_matrix,
self.pre_x2, self.pre_gz,
self.iters)
return x
else:
N, C, H, W = x.size()
var = self.running_var.view(1, C, 1, 1)
x = x / (var + self.eps).sqrt()

return self.weight.view(1,C,1,1) * x + self.bias.view(1,C,1,1)
2 changes: 2 additions & 0 deletions cls/networks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import MABN
from . import resnet
Loading

0 comments on commit 136491f

Please sign in to comment.