
Adding the possibility of passing a label dataset #958

Merged
merged 10 commits on Jun 7, 2023

Conversation

honglu2875
Contributor

As discussed with @haileyschoelkopf in DM, I'm cleaning up my personal fork and putting the major difference out here in case it is a useful feature. It allows passing an optional label_data_paths (which is assumed to be perfectly in sync with the training data).

If anything here does not fit with the current design, feel free to close it.

Why do I need it

I used this when finetuning diff models, where a custom loss mask needs to be applied to the dataset. Doing it at the dataset level gives a lot of control, and it is easy to double-check by inspecting the generated masks.
I also used this in some experiments finetuning Pythia for text repair, and it works well.

How it works

  • Without the argument label_data_paths, nothing should change.
  • With the argument, it will
    • also load the labels into the training dataset (see the build_the_dataset function), so that the dataset yields an extra "label" item;
    • apply a loss mask to positions with negative labels, shift by 1, and use the result as target labels during training (see the _get_batch function); a minimal sketch of this masking logic follows below the list.

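For illustration, here is a minimal sketch of what the label handling amounts to; the helper name, tensor layout, and exact masking convention are simplified for this example rather than copied from the diff:

    import torch

    def build_targets_and_mask(tokens, raw_labels):
        """Hypothetical helper: turn per-token labels into shifted targets plus a loss mask.

        tokens:     [batch, seq_len] input token ids
        raw_labels: [batch, seq_len] per-token labels; negative values mark positions
                    that should not contribute to the loss.
        """
        # Positions with negative labels are masked out of the loss.
        loss_mask = (raw_labels >= 0).float()

        # Shift by one so position t predicts the label at position t + 1,
        # mirroring the usual next-token objective.
        inputs = tokens[:, :-1].contiguous()
        targets = raw_labels[:, 1:].contiguous()
        loss_mask = loss_mask[:, 1:].contiguous()

        # Replace masked (negative) targets with 0 so the cross-entropy lookup stays
        # in range; the loss_mask already zeroes out their contribution.
        targets = torch.where(targets < 0, torch.zeros_like(targets), targets)
        return inputs, targets, loss_mask
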
Besides the label handling, there is also a minor fix in setup_for_inference_or_eval. I'm not sure if it's an issue with newer versions of DeepSpeed, but no_load_optim wasn't enough to prevent the optimizer from being loaded when I run inference and use things like generate_samples_unconditional.

@honglu2875 requested a review from a team as a code owner on May 29, 2023 07:15
@CLAassistant

CLAassistant commented May 29, 2023

CLA assistant check
All committers have signed the CLA.

@StellaAthena
Member

This seems reasonable. We can also use this to preserve metadata throughout the training pipeline, which someone was asking about recently.

@haileyschoelkopf
Contributor

This PR looks great and can be merged now if need be. Thank you very much @honglu2875!

I'll test this tomorrow before approving to doubly confirm though.

@haileyschoelkopf
Contributor

LGTM!

@Quentin-Anthony
Member

LGTM, but tools/preprocess_data_with_mask.py needs to be documented with the script's behavior, example run commands, and some context. This can be done as a README or as a comment block within the script, but it needs to be detailed enough that:

  1. You can hand the script off to another dev with minimal hand-holding
  2. We can maintain this code in the future without your explicit review

@haileyschoelkopf
Contributor

Addressed via a comment block! If this documentation seems like it belongs in the main README instead, I'm happy to move it there.

@Quentin-Anthony merged commit c00ce70 into EleutherAI:main Jun 7, 2023
0 of 3 checks passed
@jahatef mentioned this pull request Jun 22, 2023
@xealml

xealml commented Jul 13, 2023

This is really great, but I still want to know how to edit the yml file.

In local_setup.yml, like this?

    "data_path": "./data_text_document",
    "label_data_paths": ["./data_label_document"],

But I get this error:
    Traceback (most recent call last):
      File "train.py", line 27, in <module>
        pretrain(neox_args=neox_args)
      File "/data1/limenglin/gpt-neox/megatron/training.py", line 226, in pretrain
        iteration = train(
      File "/data1/limenglin/gpt-neox/megatron/training.py", line 794, in train
        loss_dict, skipped_iter = train_step(
      File "/data1/limenglin/gpt-neox/megatron/training.py", line 700, in train_step
        reduced_loss = train_step_pipe(
      File "/data1/limenglin/gpt-neox/megatron/training.py", line 750, in train_step_pipe
        loss = model.train_batch(data_iter=data_iterator)
      File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/engine.py", line 346, in train_batch
        self._exec_schedule(sched)
      File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/engine.py", line 1374, in _exec_schedule
        self._exec_instr(**cmd.kwargs)
      File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/engine.py", line 790, in _exec_load_micro_batch
        batch = self._next_batch()
      File "/usr/local/lib/python3.8/dist-packages/deepspeed/runtime/pipe/engine.py", line 626, in _next_batch
        batch = self.batch_fn(batch)
      File "/data1/limenglin/gpt-neox/megatron/training.py", line 329, in get_batch_pipe
        tokens, labels, loss_mask, attention_mask, position_ids = _get_batch(
      File "/data1/limenglin/gpt-neox/megatron/training.py", line 276, in _get_batch
        data_b = mpu.broadcast_data(keys, data, datatype)
      File "/data1/limenglin/gpt-neox/megatron/mpu/data.py", line 91, in broadcast_data
        key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys, data)
      File "/data1/limenglin/gpt-neox/megatron/mpu/data.py", line 44, in _build_key_size_numel_dictionaries
        assert data[key].dim() < max_dim, "you should increase MAX_DATA_DIM"
    KeyError: 'label'

@honglu2875
Contributor Author


Would it work if you specified train_data_paths, valid_data_paths, label_data_paths, and test_data_paths one by one instead of data_path? When I wrote the PR I wasn't even aware of the data_path argument, so it's possible there is still some problem with the automatic train/test split.
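For example, something along these lines (the file paths below are placeholders; I haven't tested this exact config):

    "train_data_paths": ["./data/train_text_document"],
    "valid_data_paths": ["./data/valid_text_document"],
    "test_data_paths": ["./data/test_text_document"],
    "label_data_paths": ["./data/train_label_document"],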

@TissueC

TissueC commented Jul 26, 2023

I followed your instructions, but a new problem occurs. The key question is: how do I pass the labels for the valid/test sets through the arguments?

@TissueC

TissueC commented Jul 26, 2023


Oh, I found the solution: in megatron/data/data_utils.py, in the build_weighted_datasets function, pass the label_prefix argument.
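To illustrate the shape of that change, here is a self-contained sketch; apart from the name label_prefix, the function signatures and variable names below are stand-ins rather than code from the repo:

    def build_the_dataset(data_prefix, label_prefix=None):
        """Stand-in for the real gpt-neox builder; it only records what it receives."""
        return {"data_prefix": data_prefix, "label_prefix": label_prefix}

    def build_weighted_datasets(train_paths, valid_paths, test_paths,
                                train_label, valid_label, test_label):
        """Sketch of the idea: forward a label_prefix for every split, not just train."""
        train = [build_the_dataset(p, label_prefix=train_label) for p in train_paths]
        valid = [build_the_dataset(p, label_prefix=valid_label) for p in valid_paths]
        test = [build_the_dataset(p, label_prefix=test_label) for p in test_paths]
        return train, valid, test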
