Bug Fix: Test Metrics Overwrite Validation Metrics
This commit fixes a bug where test set metrics overwrite validation set
metrics on TensorBoard and aren't logged at all by Weights & Biases.
It corrects the bug by writing test set metrics to their own charts,
prefixed with "test/". It also logs those metrics with the x-axis value
set to the last training iteration rather than iteration 0, which
satisfies the Weights & Biases requirement that the iteration (aka step)
must always increase across subsequent log calls. For background,
validation loss and perplexity are written to the charts
"validation/lm_loss" and "validation/lm_loss_ppl". At the end of
training, the test loss and perplexity were also written to those two
charts at iteration 0. This caused TensorBoard to overwrite the
validation data and Weights & Biases to emit a warning such as
"wandb: WARNING Step must only increase in log calls. Step 0 < 32000;
dropping {'validation/lm_loss': 1.715476632118225}."

Tested manually to ensure that new charts were created for test metrics.
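For context (not part of the commit itself), here is a minimal sketch of the step-ordering behavior described above, using the public wandb API. The project name and loss values are illustrative; the metric key and step 32000 come from the warning quoted above.

import wandb

run = wandb.init(project="neox-logging-demo")  # hypothetical project name

# Validation metrics are logged at increasing steps during training.
run.log({"validation/lm_loss": 1.72}, step=31000)
run.log({"validation/lm_loss": 1.70}, step=32000)

# Pre-fix behavior: test metrics reused the "validation/" keys at step 0,
# so wandb warns "Step must only increase in log calls" and drops the value.
run.log({"validation/lm_loss": 1.715476632118225}, step=0)

# Post-fix behavior: test metrics use their own "test/" keys at the final iteration.
run.log({"test/lm_loss": 1.715476632118225}, step=32000)

run.finish()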
pwstegman committed Sep 12, 2022
1 parent 87d01ad commit 1c8c2ab
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions megatron/training.py
@@ -144,9 +144,10 @@ def pretrain(neox_args):
             forward_step_func=forward_step,
             data_iterator=test_data_iterator,
             model=model,
-            iteration=0, # iteration 0 in order to always use full test data
+            iteration=iteration,
             verbose=True,
             timers=timers,
+            chart_name="test"
         )


@@ -736,6 +737,7 @@ def evaluate_and_print_results(
     iteration,
     verbose=False,
     timers=None,
+    chart_name="validation"
 ):
     """Helper function to evaluate and dump results on screen."""
     total_loss_dict = evaluate(
@@ -746,14 +748,14 @@ def evaluate_and_print_results(
         verbose=verbose,
         timers=timers,
     )
-    string = f" validation results at {prefix} | "
+    string = f" {chart_name} results at {prefix} | "
     for k, v in total_loss_dict.items():
         if isinstance(v, dict):
             for k2, v2 in v.items():
                 k3 = "_".join([k, k2])
                 string += f"{k3} value: {v2:.6E} | "
                 tb_wandb_log(
-                    f"validation/{k3}",
+                    f"{chart_name}/{k3}",
                     v2,
                     iteration,
                     use_wandb=neox_args.use_wandb,
@@ -762,7 +764,7 @@ def evaluate_and_print_results(
         else:
             string += f"{k} value: {v:.6E} | "
             tb_wandb_log(
-                f"validation/{k}",
+                f"{chart_name}/{k}",
                 v,
                 iteration,
                 use_wandb=neox_args.use_wandb,
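The chart writing above goes through the tb_wandb_log helper, whose body is not part of this diff. As a rough sketch only (the real helper lives elsewhere in the repository and may differ), a function matching the call signature seen above could look like this; the name tb_wandb_log_sketch and the tensorboard_writer parameter are assumptions for illustration.

import wandb
from torch.utils.tensorboard import SummaryWriter  # writer type assumed for illustration

def tb_wandb_log_sketch(key, value, iteration, use_wandb, tensorboard_writer=None):
    """Illustrative stand-in for the repository's tb_wandb_log helper."""
    if tensorboard_writer is not None:
        # Each distinct key (e.g. "test/lm_loss") becomes its own TensorBoard
        # chart, with `iteration` as the x-axis value.
        tensorboard_writer.add_scalar(key, value, iteration)
    if use_wandb:
        # wandb expects the step to be non-decreasing across log calls,
        # hence passing the last training iteration instead of 0.
        wandb.log({key: value}, step=iteration)

# Example call, mirroring the test-metric case described in the commit message:
# tb_wandb_log_sketch("test/lm_loss", 1.7155, 32000, use_wandb=True,
#                     tensorboard_writer=SummaryWriter("logs"))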
