Skip to content

Commit

Permalink
report eval harness version and do bootstrapping (#498)
Browse files: browse the repository at this point in the history
  • Loading branch information
leogao2 committed Jan 31, 2022
1 parent eb5ceb3 commit e2e2302
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions eval_tasks/eval_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ def _model_call(self, inps):
return logits

@torch.no_grad()
def run_eval(self, eval_tasks=None, num_fewshot=0):
def run_eval(self, eval_tasks=None, num_fewshot=0, bootstrap_iters=2):
was_training = self.model.training
self.model.eval()
in_micro_batches = (
Expand Down Expand Up @@ -336,8 +336,8 @@ def run_eval(self, eval_tasks=None, num_fewshot=0):
provide_description=False,
num_fewshot=num_fewshot,
limit=None,
bootstrap_iters=2,
).get("results")
bootstrap_iters=bootstrap_iters,
)

if was_training:
self.model.train()
Expand All @@ -346,8 +346,8 @@ def run_eval(self, eval_tasks=None, num_fewshot=0):


def run_eval_harness(
model, forward_step_fn, neox_args, batch_size=None, eval_tasks=None, num_fewshot=0,
model, forward_step_fn, neox_args, batch_size=None, eval_tasks=None, num_fewshot=0, bootstrap_iters=2
):
print_rank_0("Running evaluation harness...")
adapter = EvalHarnessAdapter(model, forward_step_fn, neox_args, batch_size)
return adapter.run_eval(eval_tasks=eval_tasks, num_fewshot=num_fewshot)
return adapter.run_eval(eval_tasks=eval_tasks, num_fewshot=num_fewshot, bootstrap_iters=bootstrap_iters)
2 changes: 1 addition & 1 deletion evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

def main():
model, neox_args = setup_for_inference_or_eval(inference=False, get_key_value=False)
results = run_eval_harness(model, forward_step, neox_args, eval_tasks=neox_args.eval_tasks)
results = run_eval_harness(model, forward_step, neox_args, eval_tasks=neox_args.eval_tasks, bootstrap_iters=10000)
if neox_args.rank == 0:
pprint(results)
results_path = f'eval_results_{datetime.now().strftime("%m-%d-%Y-%H-%M-%S")}.json'
Expand Down
2 changes: 1 addition & 1 deletion megatron/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ def evaluate(
eval_results.update(
run_eval_harness(
model, forward_step_fn, neox_args, eval_tasks=neox_args.eval_tasks
)
).get("results")
)
# Move model back to the train mode.
model.train()
Expand Down

0 comments on commit e2e2302

Please sign in to comment.