Supported features #571

Open
peregilk opened this issue Mar 30, 2024 · 18 comments
@peregilk

I mainly wanted to start by thanking you for making MaxText available. I have been using it for a few days, and the first impression is fantastic: getting started was really easy, it seems very stable, the performance is excellent, and it appears to scale very nicely.

There are a few things I have not been able to figure out yet; it might be because of a lack of documentation, or simply because they are not implemented.

  • Is there any support for Flash Attention, or any plans to implement it? This has been a major area where GPUs have been ahead of TPUs. I have noticed that there is now at least an experimental implementation from the JAX team: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py (see the sketch below this list).

  • Training directly from TFDS seemed straightforward. However, I was a bit confused about how to implement more advanced data loader features, for instance probability sampling as explained here. This can be somewhat tricky to do efficiently on multiple TPUs. What is the sensible approach here? Manually sampling into a TFDS dataset does not seem very efficient. Are there external libraries that are compatible with MaxText? (See the sampling sketch below this list.)

  • Are there plans for implementing DPO/RLHF?
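
As a rough illustration of the experimental Pallas kernel linked above (not MaxText code; the shapes, dtype, and the [batch, num_heads, seq_len, head_dim] layout are assumptions based on that file):

```python
# Hedged sketch: calling the experimental Pallas TPU flash-attention kernel
# directly. Shapes and dtype are assumptions, not MaxText defaults.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.flash_attention import flash_attention

batch, heads, seq_len, head_dim = 1, 8, 1024, 128
q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(q_key, (batch, heads, seq_len, head_dim), dtype=jnp.float32)
k = jax.random.normal(k_key, (batch, heads, seq_len, head_dim), dtype=jnp.float32)
v = jax.random.normal(v_key, (batch, heads, seq_len, head_dim), dtype=jnp.float32)

out = flash_attention(q, k, v, causal=True)  # fused attention kernel on TPU
print(out.shape)  # (1, 8, 1024, 128)
```

As noted in the reply below, MaxText already enables a fused attention variant by default, so calling the kernel by hand like this is only for experimentation.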
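
On the probability-sampling point, a minimal sketch of the standard tf.data approach for mixing sources with fixed weights (the three source datasets are placeholders):

```python
# Hedged sketch: probability sampling over several sources with tf.data.
# The source datasets below are toy placeholders for real TFDS/tf.data inputs.
import tensorflow as tf

books = tf.data.Dataset.from_tensor_slices(["book"] * 100)
web = tf.data.Dataset.from_tensor_slices(["web"] * 100)
code = tf.data.Dataset.from_tensor_slices(["code"] * 100)

mixed = tf.data.Dataset.sample_from_datasets(
    [books.repeat(), web.repeat(), code.repeat()],
    weights=[0.5, 0.3, 0.2],  # per-source sampling probabilities
    seed=0,
)

for example in mixed.take(5):
    print(example.numpy())
```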

I also shamelessly wanted to point you to my own repo: https://github.com/peregilk/ttconnect. It is a very simple bash script that should ideally be run on a VM in the same zone. It automatically opens synchronised tmux windows to all the VMs in the pod and lets you type the same command into all of them, which makes it even easier to go from a single TPU to pods.

@rwitten
Collaborator

rwitten commented Mar 31, 2024

Thank you for the comments!

(1) Fused attention is on by default for training! We use "splash attention", which is a custom and faster version. (And we're working on accelerated attention for inference.)
(2) We don't implement more advanced data loaders, though I think they can be implemented in TFDS. It is also easy to plug in your own data loader. Is there a specific data loading solution you'd like us to use?
(3) Yes, DPO is underway!

ttconnect is super cool, thanks for sending!

@peregilk
Author

peregilk commented Apr 1, 2024

Thanks for the answer. Looking forward to the DPO support.

It would of course be fantastic if HuggingFace datasets could be supported natively. I have never really been able to run large non-streaming datasets from HF on the TPUs (disk-size issues on the VMs), but we have been able to wrap HF datasets with split_dataset_by_node to stream on multiple TPUs. I'm not sure if I would be able to implement something like this in MaxText, though; I'm not really sure at what level it should be implemented.
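
Roughly, that streaming pattern looks like this (a sketch only; the dataset name is just an example, and split_dataset_by_node here is the helper from the datasets library, not part of MaxText):

```python
# Sketch: stream an HF dataset sharded per TPU host. The dataset name is an
# example; split_dataset_by_node comes from datasets.distributed.
import jax
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node

ds = load_dataset("allenai/c4", "en", split="train", streaming=True)
ds = split_dataset_by_node(ds, rank=jax.process_index(),
                           world_size=jax.process_count())

for i, example in enumerate(ds):
    # tokenize / batch here before feeding the training loop
    if i >= 2:
        break
```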

Any chance you will support HF datasets in the future?

Any way of preprocessing the data before it is split across the TPUs would be extremely useful for running experiments on dataset building, both for sampling and for filtering based on a field in the dataset.

@A9isha
Collaborator

A9isha commented May 6, 2024

Yes, support for HF datasets in MaxText is on the way.
@aireenmei

@aireenmei
Collaborator

Thank you for tagging me on this. Yes, supporting HuggingFace datasets is in our plan. We have some implementations and are running performance evaluations to understand them better. I will update here when we have it out.

@aireenmei
Collaborator

Hi @peregilk, HuggingFace datasets are supported now. Please check out https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md.

@peregilk
Author

peregilk commented May 21, 2024

Really fantastic! It makes things a lot more convenient; reading JSON Lines from the buckets in particular looks great. Do you support all native HF formats, like jsonl.gz?

@aireenmei
Collaborator

Yes, jsonl.gz is supported, as well as other formats supported by datasets.load_dataset (https://huggingface.co/docs/datasets/en/loading)
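
For illustration (file paths below are placeholders), compressed JSON Lines go through the "json" builder of load_dataset:

```python
# Sketch: datasets.load_dataset reads gzipped JSON Lines via the "json"
# builder; the file paths are placeholders.
from datasets import load_dataset

ds = load_dataset(
    "json",
    data_files={"train": "data/train-*.jsonl.gz",
                "validation": "data/val-*.jsonl.gz"},
    streaming=True,
)
print(next(iter(ds["train"])))
```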

@peregilk
Author

@aireenmei Is there more detailed documentation for this? For instance, I was unable to figure out how to specify the validation set.

@aireenmei
Collaborator

Hi @peregilk, a specific validation set is not supported yet, but it is on our list of items to work on.

@aireenmei
Collaborator

Hi @peregilk, eval is supported now: #738

@peregilk
Author

peregilk commented Jul 3, 2024

@aireenmei Thanks a lot. Really looking forward to testing this.

Since this seems very related, I am reporting it here; I can open a separate issue if you like:

I am training with:

hf_data_files='gs:https://mybucket/mydir/train*.jsonl'

There are 256 files in the directory. Close to the end of the first epoch, one of the workers throws this error in maxtext/MaxText/input_pipeline/_input_pipeline_utils.py, line 95, in __getitem__:

The above exception was the direct cause of the following exception:

ValueError: Run out of shards, shard 259 is not available

@aireenmei
Collaborator

Hi @peregilk, this should be the expected behavior. With the current implementation, you may not be able to use all the data in your train files. Say you have 256 files and you are using a v4-64, which has 8 hosts. Each host will read 256/8 = 32 shards. The i-th host reads the (8*x + i)-th shards (0 <= x < 32). For example, host 0 reads shards 0, 8, 16, ..., 248; host 7 reads shards 7, 15, ..., 255; and so on. When a host finishes its current shard, it moves on to the next shard assigned to it. But since each shard has a slightly different number of examples, training stops when one of the hosts runs out of data. In the example above, if host 0 is the first to finish its last shard, 248, it will look for shard 248 + 8 = 256, which is not available, and that results in the error you see.
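
In code, the assignment works out like this (a small illustration of the arithmetic above, not the actual pipeline code):

```python
# Illustration of the strided shard assignment described above
# (just the arithmetic, not the actual MaxText implementation).
num_shards = 256  # number of train files
num_hosts = 8     # e.g. a v4-64 has 8 hosts

for host in range(num_hosts):
    shards = list(range(host, num_shards, num_hosts))  # host, host + num_hosts, ...
    print(f"host {host}: shards {shards[0]}..{shards[-1]} (every {num_hosts}th)")

# When a host finishes shard s it asks for shard s + num_hosts, so the first
# host to exhaust its final shard requests an index >= num_shards and
# triggers "Run out of shards, shard N is not available".
```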

@peregilk
Author

peregilk commented Jul 3, 2024

Thanks @aireenmei. I'm not sure I understand, though. Wouldn't the logical behaviour here simply be to restart from the first shard that was given to the host once no more shards are available? Otherwise you would have to duplicate your dataset to train for more than one epoch, right?

@aireenmei
Collaborator

I did not implement the auto restart because some users may not want their model to see repeated data. I can add multi-epoch support to our backlog. Meanwhile, it should be straightforward to change the shard update logic here: https://github.com/google/maxtext/blob/main/MaxText/input_pipeline/_input_pipeline_utils.py#L105
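
For anyone wanting multi-epoch behaviour before it lands, the change amounts to wrapping the shard index around instead of raising. A hypothetical sketch (the function name and allow_repeat flag are illustrative, not the real MaxText code linked above):

```python
# Hypothetical wrap-around version of the shard update step; names and the
# allow_repeat flag are illustrative, not the actual MaxText implementation.
def next_shard(current_shard: int, num_hosts: int, num_shards: int,
               allow_repeat: bool = True) -> int:
    """Return the next shard for the host that just finished current_shard."""
    candidate = current_shard + num_hosts
    if candidate < num_shards:
        return candidate
    if allow_repeat:
        # Start another epoch by wrapping back to this host's first shard.
        return current_shard % num_hosts
    raise ValueError(f"Run out of shards, shard {candidate} is not available")
```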

@peregilk
Author

peregilk commented Jul 4, 2024

OK. Makes sense. Thanks.

@peregilk
Author

@aireenmei I have tried the validation support for HF datasets. I am seeing the same issue here when setting hf_eval_files: even if the number of shards is divisible by the number of workers, it still crashes asking for the next shard. I can't see any way to limit the number of eval steps so that it does not run out of shards. What am I missing?

@aireenmei
Collaborator

Hi @peregilk, this is indeed a bug; I will fix it. Meanwhile, this flag (https://github.com/google/maxtext/blob/main/MaxText/configs/base.yml#L336) controls the number of eval steps, so you can use it for now. I'll rename it to eval_steps in my next PR for clarity.

@Mddct

Mddct commented Jul 19, 2024

Any update on DPO?
