Skip to content

Commit

Permalink
修复Bug
Browse files Browse the repository at this point in the history
  • Loading branch information
moon-hotel committed Oct 1, 2021
1 parent 05929ab commit e7d0211
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def train_model(config):
for p in classification_model.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
model_save_path = os.path.join(config.model_save_dir, 'model.pkl')
model_save_path = os.path.join(config.model_save_dir, 'model.pt')
if os.path.exists(model_save_path):
loaded_paras = torch.load(model_save_path)
classification_model.load_state_dict(loaded_paras)
Expand All @@ -55,6 +55,7 @@ def train_model(config):
betas=(config.beta1, config.beta2),
eps=config.epsilon)
classification_model.train()
max_test_acc = 0
for epoch in range(config.epochs):
losses = 0
start_time = time.time()
Expand All @@ -75,25 +76,29 @@ def train_model(config):

acc = (logits.argmax(1) == label).float().mean()
if idx % 10 == 0:
print(
f"Epoch: {epoch}, Batch[{idx}/{len(train_iter)}], Train loss :{loss.item():.3f}, Train acc: {acc:.3f}")
print(f"Epoch: {epoch}, Batch[{idx}/{len(train_iter)}], "
f"Train loss :{loss.item():.3f}, Train acc: {acc:.3f}")
end_time = time.time()
train_loss = losses / len(train_iter)
print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Epoch time = {(end_time - start_time):.3f}s")
if (epoch + 1) % config.model_save_per_epoch == 0:
acc = evaluate(test_iter, classification_model, config.device)
print(f"Accuracy on test {acc:.3f}")
torch.save(classification_model.state_dict(), model_save_path)
print(f"Accuracy on test {acc:.3f}, max acc on test {max_test_acc:.3f}")
if acc > max_test_acc:
max_test_acc = acc
torch.save(classification_model.state_dict(), model_save_path)


def evaluate(data_iter, model, device):
    """Return the classification accuracy of ``model`` over ``data_iter``.

    The model is put into eval mode for the duration of the pass and
    gradients are disabled. Afterwards the model's previous train/eval
    mode is restored, even if a batch raises.

    Args:
        data_iter: Iterable yielding ``(x, y)`` batches. Both items are
            moved to ``device`` with ``.to(device)``.
        model: Module called as ``model(x)``; the argmax over dim 1 of its
            output is compared element-wise against ``y``.
        device: Target passed to ``.to(...)`` for every batch.

    Returns:
        The fraction of correctly predicted samples as a float, or ``0.0``
        when ``data_iter`` yields no samples.
    """
    # Remember the caller's mode so evaluation does not silently flip a
    # model that was already in eval mode back into training mode.
    was_training = model.training
    model.eval()
    correct, total = 0.0, 0
    try:
        with torch.no_grad():
            for x, y in data_iter:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                correct += (logits.argmax(1) == y).float().sum().item()
                total += len(y)
    finally:
        # Restore the mode even when a batch fails mid-evaluation.
        model.train(was_training)
    # Guard against an empty iterator instead of raising ZeroDivisionError.
    return correct / total if total else 0.0


Expand Down

0 comments on commit e7d0211

Please sign in to comment.