Skip to content

Commit

Permalink
fix problems
Browse files Browse the repository at this point in the history
  • Loading branch information
triomino committed Jul 4, 2020
1 parent a51d792 commit 2088f54
Show file tree
Hide file tree
Showing 7 changed files with 5 additions and 230 deletions.
29 changes: 0 additions & 29 deletions distiller_zoo/AB.py

This file was deleted.

48 changes: 0 additions & 48 deletions distiller_zoo/FSP.py

This file was deleted.

31 changes: 0 additions & 31 deletions distiller_zoo/FT.py

This file was deleted.

75 changes: 0 additions & 75 deletions distiller_zoo/KDSVD.py

This file was deleted.

42 changes: 0 additions & 42 deletions distiller_zoo/NST.py

This file was deleted.

2 changes: 1 addition & 1 deletion kd3.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1 +1 @@
python train_teacher.py --model ResNet18 --batch_size 64 --epochs 90 --learning_rate 0.1 --lr_decay_epochs 30,60 --weight_decay 1e-4 --dataset imagenet --gpu_id 0,1,2,3;
python3 train_teacher.py --model ResNet18 --batch_size 64 --epochs 90 --learning_rate 0.1 --lr_decay_epochs 30,60 --weight_decay 1e-4 --dataset imagenet --gpu_id 0,1,2,3
8 changes: 4 additions & 4 deletions train_student.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@

from helper.util import adjust_learning_rate, save_dict_to_json

from distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss, PKT, AAKDLoss, IRGLoss
from distiller_zoo import PKT, ABLoss, FactorTransfer, KDSVD, FSP, NSTLoss
from distiller_zoo import DistillKL, HintLoss, Attention, Similarity, Correlation, VIDLoss, RKDLoss, AAKDLoss, IRGLoss
from crd.criterion import CRDLoss

from helper.loops import train_distill as train, validate
Expand Down Expand Up @@ -60,6 +59,7 @@ def parse_option():
# distillation
parser.add_argument('--distill', type=str, default='kd', choices=['kd', 'hint', 'attention', 'similarity', 'vid',
'correlation', 'rkd', 'pkt', 'crd', 'aakd', 'irg'])
parser.add_argument('--trial', type=str, default='1', help='trial id')

parser.add_argument('-r', '--gamma', type=float, default=1.0, help='weight for classification')
parser.add_argument('-a', '--alpha', type=float, default=1.0, help='weight balance for KD')
Expand Down Expand Up @@ -265,11 +265,11 @@ def main():
if torch.cuda.is_available():
cudnn.benchmark = True
criterion_list.cuda()
if torch.cuda.device_count() > 1:
if torch.cuda.device_count() > 1 and len(opt.gpu_id.split(',')) > 1:
#model = nn.DataParallel(model, device_ids=opt.gpu_id).cuda()
module_list = nn.DataParallel(module_list).cuda()
else:
module_list = nn.DataParallel(module_list).cuda()
module_list.cuda()

# validate teacher accuracy
teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
Expand Down

0 comments on commit 2088f54

Please sign in to comment.