Skip to content

Commit

Permalink
cleanup imports + change print string
Browse files Browse the repository at this point in the history
  • Loading branch information
sdtblck committed Jun 24, 2021
1 parent 63fc228 commit 356a155
Showing 1 changed file with 8 additions and 18 deletions.
26 changes: 8 additions & 18 deletions megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,30 +27,20 @@
import sys

import torch
import deepspeed
import numpy as np

from megatron.utils import Timers, init_wandb
from megatron import print_rank_0

from megatron import mpu

from megatron.model import GPT2ModelPipe
from megatron.utils import Timers, init_wandb, get_ltor_masks_and_position_ids, reduce_losses
from megatron import print_rank_0, mpu
from megatron.model import GPT2ModelPipe, get_params_for_weight_decay_optimization
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.data.data_utils import build_train_valid_test_data_iterators

from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.model import get_params_for_weight_decay_optimization
from megatron.logging import tb_wandb_log
from megatron.utils import OverflowMonitor, get_noise_scale_logger
from megatron.utils import get_total_params
from megatron.logging import training_log

from megatron.logging import tb_wandb_log, training_log
from megatron.utils import OverflowMonitor, get_noise_scale_logger, get_total_params
from megatron.model.gpt2_model import cross_entropy
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import reduce_losses
from eval_tasks import run_eval_harness
import deepspeed
import numpy as np


def pretrain(neox_args):
Expand Down Expand Up @@ -554,7 +544,7 @@ def evaluate_and_print_results(neox_args, prefix, forward_step_func, data_iterat
"""Helper function to evaluate and dump results on screen."""
total_loss_dict = evaluate(neox_args=neox_args, forward_step_fn=forward_step_func, data_iterator=data_iterator,
model=model, verbose=verbose, timers=timers)
string = f' validation loss at {prefix} | '
string = f' validation results at {prefix} | '
for k, v in total_loss_dict.items():
if isinstance(v, dict):
for k2, v2 in v.items():
Expand Down

0 comments on commit 356a155

Please sign in to comment.