Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: brits imputation test device mismatch #11

Merged
merged 3 commits into from
Aug 22, 2022
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
fix: brits imputation test device mismatch
  • Loading branch information
MaciejSkrabski committed Aug 2, 2022
commit 137b35dbd3fd9f247a08de165d494b519205df41
60 changes: 36 additions & 24 deletions pypots/imputation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@ def __init__(self, learning_rate, epochs, patience, batch_size, weight_decay, de
def assemble_input_data(self, data):
pass

def _train_model(self, training_loader, val_loader=None, val_X_intact=None, val_indicating_mask=None):
self.optimizer = torch.optim.Adam(self.model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay)
def _train_model(
self, training_loader, val_loader=None, val_X_intact=None, val_indicating_mask=None
):
self.optimizer = torch.optim.Adam(
self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay
)

# each training starts from the very beginning, so reset the loss and model dict here
self.best_loss = float('inf')
self.best_loss = float("inf")
self.best_model_dict = None

try:
Expand All @@ -87,12 +89,14 @@ def _train_model(self, training_loader, val_loader=None, val_X_intact=None, val_
inputs = self.assemble_input_data(data)
self.optimizer.zero_grad()
results = self.model.forward(inputs)
results['loss'].backward()
results["loss"].backward()
self.optimizer.step()
epoch_train_loss_collector.append(results['loss'].item())
epoch_train_loss_collector.append(results["loss"].item())

mean_train_loss = np.mean(epoch_train_loss_collector) # mean training loss of the current epoch
self.logger['training_loss'].append(mean_train_loss)
mean_train_loss = np.mean(
epoch_train_loss_collector
) # mean training loss of the current epoch
self.logger["training_loss"].append(mean_train_loss)

if val_loader is not None:
self.model.eval()
Expand All @@ -101,17 +105,21 @@ def _train_model(self, training_loader, val_loader=None, val_X_intact=None, val_
for idx, data in enumerate(val_loader):
inputs = self.assemble_input_data(data)
results = self.model.forward(inputs)
imputation_collector.append(results['imputed_data'])
imputation_collector.append(results["imputed_data"])

imputation_collector = torch.cat(imputation_collector)
imputation_collector = imputation_collector

mean_val_loss = cal_mae(imputation_collector, val_X_intact, val_indicating_mask)
self.logger['validating_loss'].append(mean_val_loss)
print(f'epoch {epoch}: training loss {mean_train_loss:.4f}, validating loss {mean_val_loss:.4f}')
mean_val_loss = cal_mae(
imputation_collector, val_X_intact, val_indicating_mask
)
self.logger["validating_loss"].append(mean_val_loss)
print(
f"epoch {epoch}: training loss {mean_train_loss:.4f}, validating loss {mean_val_loss:.4f}"
)
mean_loss = mean_val_loss
else:
print(f'epoch {epoch}: training loss {mean_train_loss:.4f}')
print(f"epoch {epoch}: training loss {mean_train_loss:.4f}")
mean_loss = mean_train_loss

if mean_loss < self.best_loss:
Expand All @@ -121,25 +129,29 @@ def _train_model(self, training_loader, val_loader=None, val_X_intact=None, val_
else:
self.patience -= 1

if os.getenv('enable_nni', False):
if os.getenv("enable_nni", False):
nni.report_intermediate_result(mean_loss)
if epoch == self.epochs - 1 or self.patience == 0:
nni.report_final_result(self.best_loss)

if self.patience == 0:
print('Exceeded the training patience. Terminating the training procedure...')
print("Exceeded the training patience. Terminating the training procedure...")
break

except Exception as e:
print(f'Exception: {e}')
print(f"Exception: {e}")
if self.best_model_dict is None:
raise RuntimeError('Training got interrupted. Model was not get trained. Please try fit() again.')
raise RuntimeError(
"Training got interrupted. Model was not get trained. Please try fit() again."
)
else:
RuntimeWarning('Training got interrupted. '
'Model will load the best parameters so far for testing. '
"If you don't want it, please try fit() again.")
RuntimeWarning(
"Training got interrupted. "
"Model will load the best parameters so far for testing. "
"If you don't want it, please try fit() again."
)

if np.equal(self.best_loss, float('inf')):
raise ValueError('Something is wrong. best_loss is Nan after training.')
if np.equal(self.best_loss.item(), float("inf")):
raise ValueError("Something is wrong. best_loss is Nan after training.")

print('Finished training.')
print("Finished training.")