Support data input from HuggingFace #646
Conversation
.github/workflows/UnitTests.yml
- name: Test train.py with HF c4
  run: |
    docker run -v /home/runner/actions-runner/_work/maxtext/maxtext:/app --rm --privileged maxtext_base_image bash -c \
    'python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M)-${RANDOM} base_output_directory=gs://runner-maxtext-logs hf_data_files=gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet hf_path=parquet dataset_type=hf steps=2 tokenizer_path=google-t5/t5-large enable_checkpointing=false'
Why use the specific file for hf_data_files, why not /c4-train-*.parquet? The tests serve as useful documentation and I think the * format is more generally applicable.
The concern here is that when given a lot of files, HF's data loader does an initialization pass to inspect all the files. c4-train-*.parquet contains 1637 files, so it takes ~30s, which feels like an overhead for short unit tests, so I only use one file here.
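For reference, a minimal sketch (not the MaxText code) of how the two `hf_data_files` choices map onto HuggingFace's `datasets` API; the paths are taken from the test above, and gs:// access via gcsfs is an assumption:

```python
from datasets import load_dataset

# Single shard: only one file is inspected at startup, so initialization is fast.
ds_single = load_dataset(
    "parquet",
    data_files="gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet",
    split="train",
    streaming=True,
)

# Glob over all 1637 shards: the loader lists and inspects every matching file
# before yielding data, which is where the ~30s initialization overhead comes from.
ds_all = load_dataset(
    "parquet",
    data_files="gs://maxtext-dataset/hf/c4/c4-train-*.parquet",
    split="train",
    streaming=True,
)
```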
dataset_type: c4 # must be c4 or synthetic
dataset_type: c4 # must be c4, hf or synthetic
# for HuggingFace input pipeline (dataset_type=hf)
hf_path: ''
Could we add some hints in the comments here for what this should be set to, or maybe a pointer to the nice documentation you wrote in Data Input Pipeline?
I added more comments and a link to the doc.
import grain.python as grain

from input_pipeline import _hf_operations
from input_pipeline import _grain_operations
Seems odd that hf depends on _grain_operations? Are these common operations that apply to both? If so, could we move them (or rename the existing file) to a file called something like "input_pipeline_utils" or "_common_operations"?
Edit: Please make changes in this PR if the _grain_operations make more sense to move to a common file.
They are now in _input_pipeline_utils.py
I'm not sure if any of these methods are hf specific or would also be better in a _common file
They are now in _input_pipeline_utils.py
@@ -37,6 +37,12 @@ then
  CMD_DATA=" dataset_type=c4-array_record dataset_name=array-record/c4/en/3.0.1 eval_dataset_name=array-record/c4/en/3.0.1"
fi

if [ "$DATASET_TYPE" == "hf" ]
then
  gsutil cp -r gs://maxtext-dataset/hf/llama2-tokenizer assets
Is there a difference between the hf llama2-tokenizer and the existing llama2-tokenizer? Is the HF tokenizer necessary for the HF dataset format? Generally different tokenizers can make a large difference in the trained model / loss profile, so this should be the same tokenizer as the other convergence test if possible.
Also we should avoid depending on bucket objects like this when possible. If this tokenizer is necessary could we instead add the hf/llama2-tokenizer into the assets of maxtext main (e.g. add the tokenizer in this PR?).
This tokenizer is from here: https://huggingface.co/meta-llama/Llama-2-7b-hf, which I believe is the same tokenizer as the existing llama2-tokenizer but in a different format (Rust based) that is more performant when used with HuggingFace datasets. Since it's from a gated model that requires user consent, I don't think we should add it to our repo. Users can use it by providing "meta-llama/Llama-2-7b-hf" as tokenizer_path and setting hf_access_token to their own token. I put this copy in our internal GCS for ease of testing so nobody's token is exposed. I added some comments to the file.
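For illustration, a sketch of loading this gated tokenizer directly with `transformers`, which is presumably close to what the pipeline does with tokenizer_path and hf_access_token (the token string is a placeholder; a recent transformers version with the `token` argument is assumed):

```python
from transformers import AutoTokenizer

# "meta-llama/Llama-2-7b-hf" is gated: you must accept the license on the Hub
# and pass your own access token (the string below is a placeholder).
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    token="hf_xxx_your_access_token",
)
print(tokenizer.is_fast)  # True -> the Rust-based "fast" tokenizer discussed above
```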
Thank you for the explanation and the added comments!
@@ -15,7 +15,32 @@
-->
## Data Input Pipeline

Currently MaxText supports two data input pipelines: the tfds (tensorflow_datasets) based pipeline as default, and the Grain pipeline for determinism.
Currently MaxText supports three types of data input for training: HuggingFace datasets, Tensorflow Datasets (TFRecord files) through the tf.data based pipeline, and ArrayRecord files through the Grain pipeline for determinism.
Can you give a little more flavor? Something like
Currently MaxText supports three types of data input for training: HuggingFace datasets, Tensorflow Datasets (TFRecord files) through the tf.data based pipeline, and ArrayRecord files through the Grain pipeline.
HuggingFace datasets tend to be convenient because many popular datasets are available, but tend to be the least performant/scalable and lack deterministic data replay. TensorFlow datasets are highly performant/scalable but similarly lack deterministic data replay. Finally, Grain tends to be adequately performant/scalable and provides deterministic data replay but is the least convenient
P.S. -- even after writing this I feel like we should actually delete the paragraph I just wrote and make a table of the three datasets on the axes of Performance/Scalability, Deterministic Replayability and Convenience where:
| Pipeline | Perf/Scalability | Deterministic Replay | Convenience |
| --- | --- | --- | --- |
| HuggingFace | Low | No | High |
| TFDS | High | No | Medium |
| Grain | Medium | Yes | Low |
Thanks for the suggestion! I added a table. I plan to make more improvements to the doc in future PRs.
### Grain pipeline - for determinism

<!-- TODO (aireenmei): add more details/examples on why determinism is important -->
One example where deterministic replay is crucial is sensitive convergence experiments such as testing quantization techniques. When comparing a quantized vs unquantized run we want everything else to be identical, in particular the data in each batch. Such experiments may be long and require saving/resuming, and we want the data order to be replayable even if the paired runs save/resume at different steps. Another use case is debugging training spikes/anomalies - if we can replay the exact data, it helps distinguish a bad data batch from a hardware/SDC issue.
Thanks for the examples! I added a few more paragraphs there and incorporated your examples, with Gemini's help :)
#### Cases where determinism is crucial
* **Model sensitive to repetition.** When models are sensitive to the frequency with which they encounter specific examples, precise control over the order and repetition of data during training is essential.
* **Convergence comparison.** In sensitive convergence experiments like testing quantization techniques, maintaining identical data batches between runs (e.g., quantized vs. unquantized) is essential for comparison. Determinism ensures consistency even the runs are long and undergo saving/resuming at different step.
nit: consistency even when the runs are long
Thanks for catching, fixed now.
* **Convergence comparison.** In sensitive convergence experiments like testing quantization techniques, maintaining identical data batches between runs (e.g., quantized vs. unquantized) is essential for comparison. Determinism ensures consistency even the runs are long and undergo saving/resuming at different step.
* **Debug training anomalies.** When troubleshooting training spikes or anomalies, the ability to replay the exact data sequence helps distinguish between bad data batches and underlying hardware or software issues.

#### How does Grain achieve determinism
Can you swap the order of the "How does Grain achieve determinism" section with the one above, "Cases where determinism is crucial"? The "How does Grain achieve determinism" section follows naturally from the end of the "Why do we need determinism" section, since that is left hanging on how determinism is even possible.
done
An updated implementation of the HF input pipeline. Key designs:
Use the HuggingFace API for loading the dataset and tokenization, then use Grain for downstream processing to leverage Grain's features such as packing, prefetch, and fetching data in a child process.
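A minimal sketch of the HuggingFace half of this design, using the dataset shard and tokenizer names from the unit test above (illustrative only; the Grain downstream packing/prefetching is not shown, and gs:// access via gcsfs is an assumption):

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-large")

# Stream the data lazily instead of materializing the full dataset up front.
ds = load_dataset(
    "parquet",
    data_files="gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet",
    split="train",
    streaming=True,
)

def tokenize(example):
  # Truncate to the training sequence length (2048 in the tests above).
  return tokenizer(example["text"], truncation=True, max_length=2048)

ds = ds.map(tokenize)
first = next(iter(ds))  # dict with "input_ids", "attention_mask", ...
```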
Performance:
A standalone dataloader test at 2048 (seq_len) * 8 (per_device_batch) measured ~50s for 1000 batches. Increasing the host count from 1 to 64 does not change the speed.
No step time regression compared with TFDS/Grain, with a small model (1B) on a large slice (v5e-256). Step time is the same (~1.008) for:
* HF pipeline
* TFDS
* Grain
Convergence test passes.
Clarifications: