Skip to content

Commit

Permalink
fix train module
Browse files Browse the repository at this point in the history
  • Loading branch information
PengNi committed May 5, 2022
1 parent e15b08c commit b1e622f
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 20 deletions.
2 changes: 2 additions & 0 deletions deepsignal_plant/call_mods_freq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from .utils.process_utils import is_file_empty
import uuid

os.environ['MKL_THREADING_LAYER'] = 'GNU'

time_wait = 3


Expand Down
6 changes: 3 additions & 3 deletions deepsignal_plant/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def parse_a_line2(line):

class SignalFeaData2(Dataset):
def __init__(self, filename, transform=None):
print(">>>using linecache to access '{}'<<<\n"
">>>after done using the file, "
"remember to use linecache.clearcache() to clear cache for safety<<<".format(filename))
# print(">>>using linecache to access '{}'<<<\n"
# ">>>after done using the file, "
# "remember to use linecache.clearcache() to clear cache for safety<<<".format(filename))
self._filename = os.path.abspath(filename)
self._total_data = 0
self._transform = transform
Expand Down
3 changes: 1 addition & 2 deletions deepsignal_plant/deepsignal_plant.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ def main_call_mods(args):


def main_call_freq(args):
import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
from .call_mods_freq import call_mods_frequency_to_file

display_args(args)
call_mods_frequency_to_file(args)

Expand Down
27 changes: 12 additions & 15 deletions deepsignal_plant/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def train(args):
# train at most max_epoch_num epochs
for epoch in range(args.max_epoch_num):
curr_best_accuracy_epoch = 0
no_best_model = True
tlosses = []
start = time.time()
for i, sfeatures in enumerate(train_loader):
Expand Down Expand Up @@ -147,13 +148,6 @@ def train(args):
if use_cuda:
vlabels = vlabels.cpu()
vpredicted = vpredicted.cpu()
# i_accuracy = metrics.accuracy_score(vlabels.numpy(), vpredicted)
# i_precision = metrics.precision_score(vlabels.numpy(), vpredicted)
# i_recall = metrics.recall_score(vlabels.numpy(), vpredicted)

# vaccus.append(i_accuracy)
# vprecs.append(i_precision)
# vrecas.append(i_recall)
vlosses.append(vloss.item())
vlabels_total += vlabels.tolist()
vpredicted_total += vpredicted.tolist()
Expand All @@ -163,11 +157,14 @@ def train(args):
v_recall = metrics.recall_score(vlabels_total, vpredicted_total)
if v_accuracy > curr_best_accuracy_epoch:
curr_best_accuracy_epoch = v_accuracy
if curr_best_accuracy_epoch > curr_best_accuracy - 0.0005:
if curr_best_accuracy_epoch > curr_best_accuracy - 0.0002:
torch.save(model.state_dict(),
model_dir + args.model_type + '.b{}_s{}_epoch{}.ckpt'.format(args.seq_len,
args.signal_len,
epoch + 1))
if curr_best_accuracy_epoch > curr_best_accuracy:
curr_best_accuracy = curr_best_accuracy_epoch
no_best_model = False

time_cost = time.time() - start
print('Epoch [{}/{}], Step [{}/{}], TrainLoss: {:.4f}; '
Expand All @@ -182,16 +179,16 @@ def train(args):
sys.stdout.flush()
model.train()
scheduler.step()
if curr_best_accuracy_epoch > curr_best_accuracy:
curr_best_accuracy = curr_best_accuracy_epoch
else:
if epoch >= args.min_epoch_num - 1:
print("best accuracy: {}, early stop!".format(curr_best_accuracy))
break

if no_best_model and epoch >= args.min_epoch_num - 1:
print("early stop!")
break

endtime = time.time()
clear_linecache()
print("[train] training cost {} seconds".format(endtime - total_start))
print("[train]training cost {} seconds, "
"best accuracy: {}".format(endtime - total_start,
curr_best_accuracy))


def main():
Expand Down

0 comments on commit b1e622f

Please sign in to comment.