Fix bug with number of evaluation steps (EleutherAI#384)
We were running far too many evaluation steps when the model is pipe parallel and gradient accumulation steps (g.a.s.) are enabled, because of this line:

```python
            for _ in range(neox_args.gradient_accumulation_steps):
```

- Fixing this to 1 when the model is pipe parallel resolves the issue, since `.eval_batch()` already takes gradient accumulation steps into account.
sdtblck committed Jul 30, 2021
1 parent 7261bde commit 54e622b
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion megatron/training.py
```diff
@@ -528,7 +528,8 @@ def evaluate(neox_args, forward_step_fn, data_iterator, model, verbose=False, ti

     # although we're not accumulating gradients here, we count one iter as train_batch_size_per_gpu * g.a.s
     # to be consistent with deepspeed's pipe parallel engine
-    for _ in range(neox_args.gradient_accumulation_steps):
+    # since pipe parallel already takes gas into account - default to 1 here if pipe parallel is true
+    for _ in range(1 if neox_args.is_pipe_parallel else neox_args.gradient_accumulation_steps):
         # Forward evaluation
         loss = forward_step_fn(model=model, data_iterator=data_iterator, neox_args=neox_args, timers=timers)
         losses.append(loss)
```
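The logic behind the fix can be sketched as a small standalone helper, purely for illustration: when pipe parallelism is on, DeepSpeed's `eval_batch()` already loops over the gradient accumulation micro-batches internally, so the outer evaluation loop must run only once or the work is multiplied by g.a.s. The helper name and its plain-argument signature below are hypothetical, not part of the repository:

```python
def num_eval_iters(is_pipe_parallel: bool, gradient_accumulation_steps: int) -> int:
    """How many times the outer evaluate() loop should run per eval iteration.

    With pipe parallelism, DeepSpeed's pipeline engine (eval_batch) already
    consumes gradient_accumulation_steps micro-batches internally, so the
    outer loop must run exactly once. Otherwise, the outer loop itself is
    responsible for stepping through each micro-batch.
    """
    return 1 if is_pipe_parallel else gradient_accumulation_steps


# Before the fix, the pipe-parallel case would have run
# gradient_accumulation_steps * gradient_accumulation_steps forward passes.
```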
