Skip to content

Commit

Permalink
rename and update
Browse files Browse the repository at this point in the history
  • Loading branch information
HobbitLong committed Oct 23, 2019
1 parent 559adc7 commit b13aecc
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 11 deletions.
5 changes: 2 additions & 3 deletions models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .resnetv2 import ResNet50
from .wrn import wrn_16_1, wrn_16_2, wrn_40_1, wrn_40_2
from .vgg import vgg19_bn, vgg16_bn, vgg13_bn, vgg11_bn, vgg8_bn
from .mobilenetv2 import mobile_half, mobile_full
from .mobilenetv2 import mobile_half
from .ShuffleNetv1 import ShuffleV1
from .ShuffleNetv2 import ShuffleV2

Expand All @@ -26,8 +26,7 @@
'vgg13': vgg13_bn,
'vgg16': vgg16_bn,
'vgg19': vgg19_bn,
'mobile_half': mobile_half,
'mobile_full': mobile_full,
'MobileNetV2': mobile_half,
'ShuffleV1': ShuffleV1,
'ShuffleV2': ShuffleV2,
}
6 changes: 1 addition & 5 deletions models/mobilenetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torch.nn as nn
import math

__all__ = ['mobilenetv2_T_w', 'mobile_half', 'mobile_full']
__all__ = ['mobilenetv2_T_w', 'mobile_half']

BN = None

Expand Down Expand Up @@ -184,10 +184,6 @@ def mobile_half(num_classes):
return mobilenetv2_T_w(6, 0.5, num_classes)


def mobile_full(num_classes):
return mobilenetv2_T_w(6, 1., num_classes)


if __name__ == '__main__':
x = torch.randn(2, 3, 32, 32)

Expand Down
4 changes: 2 additions & 2 deletions train_student.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def parse_option():
choices=['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110',
'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2',
'vgg8', 'vgg11', 'vgg13', 'vgg16', 'vgg19', 'ResNet50',
'mobile_half', 'mobile_full', 'ShuffleV1', 'ShuffleV2'])
'MobileNetV2', 'ShuffleV1', 'ShuffleV2'])
parser.add_argument('--path_t', type=str, default=None, help='teacher model snapshot')

# distillation
Expand Down Expand Up @@ -90,7 +90,7 @@ def parse_option():
opt = parser.parse_args()

# set different learning rate from these 4 models
if opt.model_s in ['mobile_half', 'mobile_full', 'ShuffleV1', 'ShuffleV2']:
if opt.model_s in ['MobileNetV2', 'ShuffleV1', 'ShuffleV2']:
opt.learning_rate = 0.01

# set the path according to the environment
Expand Down
2 changes: 1 addition & 1 deletion train_teacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def parse_option():
choices=['resnet8', 'resnet14', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110',
'resnet8x4', 'resnet32x4', 'wrn_16_1', 'wrn_16_2', 'wrn_40_1', 'wrn_40_2',
'vgg8', 'vgg11', 'vgg13', 'vgg16', 'vgg19',
'mobile_half', 'mobile_full', 'ShuffleV1', 'ShuffleV2', ])
'MobileNetV2', 'ShuffleV1', 'ShuffleV2', ])
parser.add_argument('--dataset', type=str, default='cifar100', choices=['cifar100'], help='dataset')

parser.add_argument('-t', '--trial', type=int, default=0, help='the experiment id')
Expand Down

0 comments on commit b13aecc

Please sign in to comment.