Skip to content

Commit

Permalink
store att for aakd
Browse files Browse the repository at this point in the history
  • Loading branch information
triomino committed Aug 5, 2020
1 parent ab84437 commit 85140ad
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 17 deletions.
2 changes: 1 addition & 1 deletion models/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def forward(self, x):

class SelfA(nn.Module):
"""Cross layer Self Attention"""
def __init__(self, s_len, t_len, input_channel, s_n, s_t, factor=2):
def __init__(self, s_len, t_len, input_channel, s_n, s_t, factor=4):
super(SelfA, self).__init__()

self.avgpool = nn.AdaptiveAvgPool2d((1,1))
Expand Down
8 changes: 7 additions & 1 deletion results/records.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,4 +90,10 @@ python train_student.py --path-t ./save/models/ResNet34_vanilla/resnet34_transfo
这个一开始爆显存了,所以 DALI 都用了 cpu 来减少显存用量。
## aakd
7.30 早上十点多开始,开始的 GPU 用量:
![vgg_shuffle_aakd_GPU](vgg_shuffle_aakd_GPU.png)
![vgg_shuffle_aakd_GPU](vgg_shuffle_aakd_GPU.png)

# 懒得写记录了 直接画表格
VGG13 -> ShuffleV2
|aakd(b=100)|aakd(b=400)|hint(re-run required)|irg(b=0.05)|sp|vid|
|-|-|-|-|-|-|
|63.228|62.144|63.134|63.82x| | |
2 changes: 2 additions & 0 deletions train_student.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,8 @@ def main_worker(gpu, ngpus_per_node, opt):
'model': model_s.state_dict(),
'best_acc': best_acc,
}
if opt.distill == 'aakd':
state['attention'] = trainable_list[-1].state_dict()
save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model_s))

test_merics = {'test_loss': test_loss,
Expand Down
30 changes: 15 additions & 15 deletions vgg13_shufflev2.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,25 @@ EPOCH=90
LEARNING_RATE=0.1
DALI=cpu

# kd
# # kd
# python train_student.py --path-t ./save/models/vgg13_imagenet_vanilla/vgg13_transformed.pth \
# --batch_size $BATCH_SIZE --epochs $EPOCH --dataset imagenet --gpu_id $GPU --dist-url tcp://127.0.0.1:23333 \
# --print-freq 100 --num_workers $WORKER --model_s ShuffleV2_Imagenet --trial release \
# --multiprocessing-distributed --learning_rate $LEARNING_RATE --lr_decay_epochs 30,60 --weight_decay 1e-4 --dali $DALI \
# --distill kd -r 1 -a 1 -b 0
# fitnet
python train_student.py --path-t ./save/models/vgg13_imagenet_vanilla/vgg13_transformed.pth \
--batch_size $BATCH_SIZE --epochs $EPOCH --dataset imagenet --gpu_id $GPU --dist-url tcp://127.0.0.1:23333 \
--print-freq 100 --num_workers $WORKER --model_s ShuffleV2_Imagenet --trial release \
--multiprocessing-distributed --learning_rate $LEARNING_RATE --lr_decay_epochs 30,60 --weight_decay 1e-4 --dali $DALI \
--distill hint -r 1 -a 1 -b 100 --hint_layer 1
# sp
python train_student.py --path-t ./save/models/vgg13_imagenet_vanilla/vgg13_transformed.pth \
--batch_size $BATCH_SIZE --epochs $EPOCH --dataset imagenet --gpu_id $GPU --dist-url tcp://127.0.0.1:23333 \
--print-freq 100 --num_workers $WORKER --model_s ShuffleV2_Imagenet --trial release \
--multiprocessing-distributed --learning_rate $LEARNING_RATE --lr_decay_epochs 30,60 --weight_decay 1e-4 --dali $DALI \
--distill similarity -r 1 -a 1 -b 3000
# vid
# # fitnet
# python train_student.py --path-t ./save/models/vgg13_imagenet_vanilla/vgg13_transformed.pth \
# --batch_size $BATCH_SIZE --epochs $EPOCH --dataset imagenet --gpu_id $GPU --dist-url tcp://127.0.0.1:23333 \
# --print-freq 100 --num_workers $WORKER --model_s ShuffleV2_Imagenet --trial release \
# --multiprocessing-distributed --learning_rate $LEARNING_RATE --lr_decay_epochs 30,60 --weight_decay 1e-4 --dali $DALI \
# --distill hint -r 1 -a 1 -b 100 --hint_layer 1
# # sp
# python train_student.py --path-t ./save/models/vgg13_imagenet_vanilla/vgg13_transformed.pth \
# --batch_size $BATCH_SIZE --epochs $EPOCH --dataset imagenet --gpu_id $GPU --dist-url tcp://127.0.0.1:23333 \
# --print-freq 100 --num_workers $WORKER --model_s ShuffleV2_Imagenet --trial release \
# --multiprocessing-distributed --learning_rate $LEARNING_RATE --lr_decay_epochs 30,60 --weight_decay 1e-4 --dali $DALI \
# --distill similarity -r 1 -a 1 -b 3000
# # vid
# python train_student.py --path-t ./save/models/vgg13_imagenet_vanilla/vgg13_transformed.pth \
# --batch_size $BATCH_SIZE --epochs $EPOCH --dataset imagenet --gpu_id $GPU --dist-url tcp://127.0.0.1:23333 \
# --print-freq 100 --num_workers $WORKER --model_s ShuffleV2_Imagenet --trial release \
Expand All @@ -35,7 +35,7 @@ python train_student.py --path-t ./save/models/vgg13_imagenet_vanilla/vgg13_tran
--print-freq 100 --num_workers $WORKER --model_s ShuffleV2_Imagenet --trial release \
--multiprocessing-distributed --learning_rate $LEARNING_RATE --lr_decay_epochs 30,60 --weight_decay 1e-4 --dali $DALI \
--distill aakd -r 1 -a 1 -b 100
# irg
# # irg
# python train_student.py --path-t ./save/models/vgg13_imagenet_vanilla/vgg13_transformed.pth \
# --batch_size $BATCH_SIZE --epochs $EPOCH --dataset imagenet --gpu_id $GPU --dist-url tcp://127.0.0.1:23333 \
# --print-freq 100 --num_workers $WORKER --model_s ShuffleV2_Imagenet --trial release \
Expand Down

0 comments on commit 85140ad

Please sign in to comment.