
Improve Eval Harness #471

Merged: 14 commits merged into main from eval_harness_dp on Dec 20, 2021

Conversation

@sdtblck (Contributor) commented on Nov 26, 2021:

Adds the following:

  • Data parallelism for the eval harness (needs a bit more thorough testing before merge)
  • The ability to specify a checkpoint to run evaluation on via a command-line argument, e.g.:
    ./deepy.py evaluate.py configs/tmp.yml --eval_tasks lambada pubmedqa --iteration 26000 --eval_results_prefix iteration_26000
  • The ability to specify a results prefix to save results under via a command-line argument (see above)

Apologies, the autoformatter also made a lot of cosmetic changes.
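
For reference, a rough sketch of how the new command-line overrides above could be parsed. The flag names come from the example command; everything else, including the use of plain argparse, is an assumption rather than the actual deepy.py / NeoXArgs implementation:

import argparse

# Sketch only: the real plumbing goes through NeoXArgs and deepy.py and is
# more involved than this.
parser = argparse.ArgumentParser()
parser.add_argument("--eval_tasks", nargs="+", default=None,
                    help="lm_eval_harness tasks to run, e.g. lambada pubmedqa")
parser.add_argument("--iteration", type=int, default=None,
                    help="checkpoint iteration to evaluate instead of the latest one")
parser.add_argument("--eval_results_prefix", type=str, default="",
                    help="prefix for the file the eval results are saved under")
args, _ = parser.parse_known_args()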

@StellaAthena (Member) commented:
I used this to run the eval harness on the saved checkpoints and it didn’t have any problems. Does that address the additional testing that you wanted @sdtblck, or are there additional configurations you still want to try?

@sweinbach (Contributor) commented:
I am testing a little on this branch and ran into the error below. Is this related? Micro batch size is set to 8; pipe and model parallelism are both 1 on 8 GPUs.

[screenshot of the error]

@StellaAthena (Member) commented:
> I am testing a little on this branch and ran into the error below. Is this related? Micro batch size is set to 8; pipe and model parallelism are both 1 on 8 GPUs. [screenshot]

According to this printout, one of the numbers you are reporting is off. Can you start a run, log it on wandb, and link to it? That will make it easiest to see all the derived parameters and where precisely it's failing.

@sweinbach (Contributor) commented:
> I am testing a little on this branch and ran into the error below. Is this related? Micro batch size is set to 8; pipe and model parallelism are both 1 on 8 GPUs. [screenshot]

> According to this printout, one of the numbers you are reporting is off. Can you start a run, log it on wandb, and link to it? That will make it easiest to see all the derived parameters and where precisely it's failing.

I tried to activate wandb but don't see any logs. Maybe wandb.init is not called in evaluate.py?
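
If wandb.init really isn't being called there, a minimal explicit setup would look something like this sketch (standard wandb API; the project and run names are placeholders, not what the repo actually uses):

import wandb

# Placeholder project/run names; the real values would come from the NeoX config.
wandb.init(project="neox-eval", name="eval_harness_debug")
# ... run the harness ...
# wandb.log(results)  # e.g. a dict of task -> metric values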

But I have looked into the problem further and found the following:

  • Verified that the batch size in EvalHarnessAdapter is set to 8 (equal to the number of GPUs)
  • Verified that mp and pp are set to 1
  • It works for some tasks (e.g. mrpc, copa)
  • It does not work for others (e.g. boolq)
  • When it does not work, the loglikelihood requests run for all but the last batch

=> I assume the total number of eval items is not divisible by the batch size, i.e. we need to treat a potentially smaller last batch differently.

My proposal is to change this function in eval_adapter.py to

def _model_call(self, inps):
    ######## DATA PARALLEL STUFF ########
    world_size = mpu.get_data_parallel_world_size()

    batch_size = inps.shape[0]
    if batch_size % world_size != 0:
        # The last batch may not fill the full batch size (if the dataset size
        # is not divisible by the batch size), so pad it with copies of the
        # first item until it is divisible by the data parallel world size.
        pad_size = world_size - (batch_size % world_size)
        inps = torch.cat([inps] + [inps[0:1, :] for _ in range(pad_size)], dim=0)

    assert inps.shape[0] % world_size == 0, (
        f"batch size ({inps.shape[0]}) must be divisible by world size ({world_size})"
    )
    # get a chunk for each data parallel rank
    chunk_size = inps.shape[0] // world_size
    rank = mpu.get_data_parallel_rank()
    inps = inps[rank * chunk_size:(rank + 1) * chunk_size]
    #####################################

    # make a dummy dataloader / iterator to pass to the model
    data_wrapped = iter([{'text': F.pad(inps, pad=(0, 1))}])
    if self.neox_args.is_pipe_parallel:
        # need these flags to stop deepspeed from hanging
        self.model.first_output_send = True
        self.model.pipe_recv_buf = None
    _, logits = self._forward_step_fn(model=self.model, data_iterator=data_wrapped)

    ######## DATA PARALLEL STUFF ########
    # gather logits from all data parallel ranks
    if logits is not None:
        tensor_list = [torch.zeros_like(logits) for _ in range(world_size)]
        torch.distributed.all_gather(tensor_list, logits, group=mpu.get_data_parallel_group())
        logits = torch.cat(tensor_list, dim=0)

    # drop the padded rows; logits can be None (see the check above)
    return logits[:batch_size, :, :] if logits is not None else None
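
For concreteness, a tiny worked example of the padding and chunking arithmetic above, with made-up numbers:

# Hypothetical final batch: 8-way data parallelism, only 5 requests left over.
world_size = 8
batch_size = 5
pad_size = world_size - (batch_size % world_size)   # 3 rows of padding
chunk_size = (batch_size + pad_size) // world_size  # 1 request per rank
print(pad_size, chunk_size)                         # 3 1
# After the all_gather, logits[:batch_size] drops the 3 padded rows again.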

If you agree, I can push. I was a little hesitant because I really don't know the eval harness implementation.

@sweinbach (Contributor) commented on Dec 10, 2021:

Btw, greedy_until requests for AliBi fail without PR #452. It would be nice to merge that at some point.

@StellaAthena (Member) commented:
> Btw, greedy_until requests for AliBi fail without PR #452. It would be nice to merge that at some point.

Merged.

@sweinbach (Contributor) commented:
A note so we don't forget: SQuADv2 fails at this line with "list index out of range". cont is not always a list of length greater than 0, so this needs to be more robust:

s = cont[0]['text'] or ''
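
One way to harden that line, assuming cont is a (possibly empty) list of dicts with a 'text' key, is sketched below; this is an illustration, not necessarily the fix that ended up in the PR:

# Fall back to an empty string when the generation came back empty,
# keeping the original "or ''" behaviour for a missing/empty 'text' field.
s = (cont[0]['text'] or '') if len(cont) > 0 else ''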

@sdtblck (Contributor, Author) commented on Dec 17, 2021:

Hey @sweinbach, I'm spending the day getting to all the past issues I've been too busy to look at 😆. Thanks for looking into this one; I think you're probably on the money about the last batch not being full-sized.

However, I suspect just naively padding might break things further down the line. I'll spend a couple of hours testing this out now, and also add docstrings to clarify the eval harness code, as it's a bit of a mess right now.

@sdtblck (Contributor, Author) commented on Dec 18, 2021:

OK @sweinbach, I think this should be fixed now.

The problem with SQuADv2 is threefold, and I think even with this fixed, you won't get satisfactory results.

To explain:

  1. lm_eval_harness has 'stop tokens' for each generation task, which signal to the model when to stop generating. The stop token for many tasks (SQuAD included) is \n. This is a problem for any model trained on the Pile, as the Stack Overflow component (the majority of QA pairs in pretraining) is formatted like Q:\n\n<question>. A:\n\n<answer>. So most of the time, after being prompted with a question, a Pile-trained model will produce a newline, and this newline will immediately stop the generation (illustrated below).

  2. If a stop token or EOS is produced as the first token, generation in NeoX will return an empty list. This is what caused the index-out-of-range error; that's now fixed.

  3. Additionally, on ranks other than mp rank == 0, an empty list is returned. This is fixed too: you can pass a 'broadcast_generated_tokens' arg to generate_samples_from_prompt, and this is now set to True by default in the eval harness.

So, despite 2 and 3 being fixed, 1 is still a problem, but more of a problem on lm_eval_harness's side. I suggest we open an issue there.
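
To make point 1 concrete, here is a tiny standalone illustration (not code from this PR or from the harness) of how a \n stop sequence cuts off a Pile-style completion before any answer text appears:

# Illustration only: a Pile-style completion tends to start with a blank line,
# so a "\n" stop sequence ends generation before the answer is produced.
stop = "\n"
completion = "\n\nThe answer is 42."    # hypothetical model output
truncated = completion.split(stop)[0]   # everything before the first stop token
print(repr(truncated))                  # '' -> the harness scores an empty answer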

Anyway, if someone can approve, I think this is ready to merge.

@StellaAthena (Member) left a review:

I have tested this, we are currently using it to evaluate models, and Sid has reviewed it as well. There are some lingering issues, but those appear to be issues with the design of the eval harness rather than with this code.

@sdtblck merged commit 3ad6195 into main on Dec 20, 2021.
@sdtblck deleted the eval_harness_dp branch on December 20, 2021.