Skip to content

Commit

Permalink
Added Early Stopping and RMAE Dict saving
Browse files Browse the repository at this point in the history
  • Loading branch information
atharva-tendle committed May 22, 2021
1 parent bea3562 commit 8f96e0a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
2 changes: 2 additions & 0 deletions configs.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"lr": 0.001,
"batch_size": 512,
"target_val_acc": 94.0,
"save_path": "./",
"tolerance": 5,
"arr_save_path": "/home/shared/results/",
"num_workers": 4,
"data_path": "/home/shared/imagenet"
Expand Down
22 changes: 19 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import torch
from os.path import join

from utils.delta import compute_delta
from utils.helpers import accuracy
Expand All @@ -12,6 +13,7 @@ def training(epochs, loaders, model, optimizer, criterion, prev_list,
"""

min_test_loss = np.Inf
early_stopping_counter = 0

train_acc_arr, test_acc_arr = [], []

Expand Down Expand Up @@ -75,6 +77,9 @@ def training(epochs, loaders, model, optimizer, criterion, prev_list,
rmae_delta_dict, prev_list = compute_delta(
model, prev_list, rmae_delta_dict)

# save rmae dict
save_path = join(configs.arr_save_path, configs.exp_name)
np.save(save_path + f"rmae_dict_{epoch}.npy", rmae_delta_dict)

test_top1, test_top5 = [], []

Expand Down Expand Up @@ -104,9 +109,6 @@ def training(epochs, loaders, model, optimizer, criterion, prev_list,
# accumulate total number of examples
test_total += data.size(0)

if test_loss < min_test_loss:
print(f"Saving model at Epoch: {epoch}")
# torch.save(model.state_dict(), 'drive/My Drive/cifar10-resnet18-gradual-adam')

test_loss = round(test_loss/len(loaders['test'].dataset), 4)
test_acc = round(((test_correct/test_total) * 100), 4)
Expand All @@ -127,6 +129,20 @@ def training(epochs, loaders, model, optimizer, criterion, prev_list,
f"Epoch: {epoch} \tTrain Loss: {train_loss} \tTrain Top-1: {epoch_train_t1} \tTrain Top-5: {epoch_train_t5}%
\tTest Loss: {test_loss} \tTest Top-1: {epoch_test_t1} \tTest Top-5: {epoch_test_t5}%")

# early stopping
# accumulate consecutive counters
if test_loss < min_test_loss and early_stopping_counter:
early_stopping_counter += 1
min_test_loss = test_loss
else:
min_test_loss = test_loss
early_stopping_counter = 0

# check if we passed tolerance
if early_stopping_counter >= configs.tolerance:
print(f"Saving model at Epoch: {epoch}")
torch.save(model.state_dict(), configs.save_path)

if float(test_acc) >= configs.target_val_acc:
break

Expand Down

0 comments on commit 8f96e0a

Please sign in to comment.