-
Notifications
You must be signed in to change notification settings - Fork 37
/
train_oneflow.sh
35 lines (28 loc) · 821 Bytes
/
train_oneflow.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#!/usr/bin/env bash
#
# Train an image classifier with train_oneflow.py on Imagenette (OFRecord).
#
# On first run the dataset archive is downloaded and unpacked, which is
# expected to produce ./ofrecord. Checkpoints are written to ./checkpoints.
# All paths are relative to the current working directory, so run this from
# the directory containing train_oneflow.py.
#
# Optional environment variables:
#   LOAD_CHECKPOINT  checkpoint to resume from; empty/unset trains from scratch.
#                    e.g. LOAD_CHECKPOINT=checkpoints/epoch_10_val_acc_0.230020
#
# Uses only POSIX sh features, so it also runs under `sh train_oneflow.sh`.

# -e: stop if the download or extraction fails instead of training on
#     missing data.  -u: error on unset variables.  -x: trace commands.
set -eux

OFRECORD_PATH="ofrecord"
OFRECORD_URL="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/imagenette_ofrecord.tar.gz"
OFRECORD_ARCHIVE="imagenette_ofrecord.tar.gz"
if [ ! -d "$OFRECORD_PATH" ]; then
  # -O always (re)writes this exact filename. Without it, a leftover partial
  # archive would make wget save to "*.tar.gz.1" while tar extracted the
  # stale file.
  wget -O "$OFRECORD_ARCHIVE" "$OFRECORD_URL"
  tar zxf "$OFRECORD_ARCHIVE"
fi

CHECKPOINT_PATH="checkpoints"
mkdir -p -- "$CHECKPOINT_PATH"

LEARNING_RATE=0.001
MOM=0.9
EPOCH=1000
TRAIN_BATCH_SIZE=16
VAL_BATCH_SIZE=16
MODEL=vgg16_bn

# Resume support: default to empty so `set -u` does not abort when unset.
LOAD_CHECKPOINT="${LOAD_CHECKPOINT:-}"

# ${LOAD_CHECKPOINT:+...} expands to `--load_checkpoint <path>` only when
# LOAD_CHECKPOINT is non-empty. Otherwise it expands to nothing, so the flag
# is omitted entirely.
python3 train_oneflow.py \
  --save_checkpoint_path "$CHECKPOINT_PATH" \
  --ofrecord_path "$OFRECORD_PATH" \
  --learning_rate "$LEARNING_RATE" \
  --mom "$MOM" \
  --epochs "$EPOCH" \
  --train_batch_size "$TRAIN_BATCH_SIZE" \
  --val_batch_size "$VAL_BATCH_SIZE" \
  --model "$MODEL" \
  ${LOAD_CHECKPOINT:+--load_checkpoint "$LOAD_CHECKPOINT"}