Fix minor issues with config option defaults, and indentation error in calculate_derived
Quentin-Anthony committed May 2, 2023
1 parent 586f514 commit 507ad04
Showing 3 changed files with 11 additions and 7 deletions.
12 changes: 6 additions & 6 deletions megatron/neox_arguments/arguments.py
@@ -784,7 +784,7 @@ def calculate_batch_parameters(
         else:
             assert (
                 False
-            ), "Either train_batch_size or micro_batch_per_gpu needs to be provided"
+            ), "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided"
         return int(train_batch), int(micro_batch), int(grad_acc)

     @staticmethod
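
The assert's message now names the actual config keys. For context, the three quantities are linked by global batch = micro batch per GPU × gradient-accumulation steps × data-parallel world size, so any one can be derived from the other two. A minimal sketch of that derivation (a standalone helper written for illustration, not the library's implementation):

# Sketch of the batch-size relationship the assert enforces; names mirror
# the NeoX/DeepSpeed config keys, but this helper is illustrative only.
def derive_batch_parameters(dp_world_size, train_batch_size=None,
                            micro_batch_per_gpu=None, grad_acc=None):
    # global batch = micro batch per GPU * grad-acc steps * DP ranks
    assert (
        train_batch_size is not None or micro_batch_per_gpu is not None
    ), "Either train_batch_size or train_micro_batch_size_per_gpu needs to be provided"
    if micro_batch_per_gpu is None:
        grad_acc = grad_acc or 1
        micro_batch_per_gpu = train_batch_size // (grad_acc * dp_world_size)
    elif train_batch_size is None:
        grad_acc = grad_acc or 1
        train_batch_size = micro_batch_per_gpu * grad_acc * dp_world_size
    elif grad_acc is None:
        grad_acc = train_batch_size // (micro_batch_per_gpu * dp_world_size)
    return int(train_batch_size), int(micro_batch_per_gpu), int(grad_acc)

For example, train_batch_size=32 with micro_batch_per_gpu=4 on 4 data-parallel ranks derives grad_acc=2.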
@@ -907,11 +907,11 @@ def calculate_derived(self):
             save_iters = list(save_iters)
             save_iters.sort()

-                self.update_values(
-                    {
-                        "save_iters": save_iters,
-                    }
-                )
+            self.update_values(
+                {
+                    "save_iters": save_iters,
+                }
+            )

         # derive precision
         if (self.fp16 or {}).get("type", self.precision) == "bfloat16":
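
The indentation fix matters because the update_values call sat one level deeper than the statements before it, with no block opener in between, which Python rejects before the module can even load. A minimal sketch reproducing that failure mode (illustrative snippet, not the NeoX source):

import textwrap

# Illustrative only: a statement indented deeper than the preceding one,
# with no enclosing if/for/while/def, is rejected at compile time.
snippet = textwrap.dedent(
    """
    save_iters = list(save_iters)
    save_iters.sort()

        update_values({"save_iters": save_iters})
    """
)

try:
    compile(snippet, "<calculate_derived>", "exec")
except IndentationError as err:
    print(err)  # "unexpected indent" at the update_values line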
2 changes: 1 addition & 1 deletion megatron/neox_arguments/neox_args.py
@@ -490,7 +490,7 @@ class NeoXArgsLogging(NeoXArgsTemplate):
     Write TensorBoard logs to this directory.
     """

-    log_interval: int = None
+    log_interval: int = 100
     """
     Interval between logging.
     """
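
Giving log_interval a concrete default means consumers no longer have to guard against None before computing iteration % log_interval. A sketch of the difference (standalone dataclass for illustration, not the NeoX class):

from dataclasses import dataclass

# Illustrative stand-in for the logging arguments: with the old default
# of None, `iteration % args.log_interval` raised TypeError whenever the
# config omitted the key; a default of 100 makes the modulo always safe.
@dataclass
class LoggingArgs:
    log_interval: int = 100  # iterations between log lines

args = LoggingArgs()
for iteration in range(1, 301):
    if iteration % args.log_interval == 0:
        print(f"iteration {iteration}: logging metrics")
# prints at iterations 100, 200, and 300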
4 changes: 4 additions & 0 deletions megatron/training.py
@@ -447,6 +447,10 @@ def get_optimizer(model, neox_args):
     """Set up the optimizer."""
     if neox_args.no_load_optim:
         return None, None
+
+    if neox_args.optimizer is None:
+        print_rank_0(f'ERROR: Optimizer is None. Either set the optimizer dict in your config (if training) or set no_load_optim in your config (if inference)')
+        exit()
     # Build parameter groups (weight decay and non-decay).
     param_groups = get_params_for_weight_decay_optimization(model, neox_args)
     print_rank_0(
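
The new guard fails fast with a message that names both ways out. A sketch of the two config shapes it points at (dict form for illustration; only the keys come from the error message, the optimizer "params" values are made-up examples):

# Two illustrative config fragments that satisfy the new guard
# (values are examples, not recommended settings).
training_config = {
    "optimizer": {
        "type": "Adam",
        "params": {"lr": 1.0e-4, "betas": [0.9, 0.95]},
    },
}

inference_config = {
    "no_load_optim": True,  # get_optimizer() returns (None, None) early
}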
