Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NVIDIA GPU] Remove control knobs for each individual async collective and use the global xla_gpu_disable_async_collectives #11422

Closed
wants to merge 3 commits into from

Conversation

Tixxx
Copy link
Contributor

@Tixxx Tixxx commented Apr 10, 2024

Currently we have 1 global flag(xla_gpu_enable_async_collectives) to control whether we want to asynchronize collectives or not. This flag overrides all other control knobs for each collective. The usage of it is confusing, instead we introduce a new flag xla_gpu_disable_async_collectives which will consolidate all the async flags we have now. We remove xla_gpu_enable_async_collectives and all other individual control knobs.
Sample usage:
xla_gpu_disable_async_collectives=allreduce,reducescatter
disables async allreduce and reducescatter
By default it's empty which indicates enabling async for all collectives.

xla/xla.proto Outdated Show resolved Hide resolved
Copy link
Member

@golechwierowicz golechwierowicz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the change, I wanted to do it myself for some time :)

@golechwierowicz golechwierowicz self-requested a review April 15, 2024 08:08
@Tixxx
Copy link
Contributor Author

Tixxx commented Apr 15, 2024

Thanks for the change, I wanted to do it myself for some time :)

Thanks for the review @golechwierowicz , I'm improving this pr with a better consolidation. Will address your comments and push a new commit today.

@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Apr 16, 2024
@Tixxx Tixxx changed the title [NVIDIA GPU] Remove control knobs for each individual async collective and use the global xla_gpu_enable_async_collectives [NVIDIA GPU] Remove control knobs for each individual async collective and use the global xla_gpu_disable_async_collectives Apr 16, 2024
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Apr 16, 2024
@Tixxx
Copy link
Contributor Author

Tixxx commented Apr 17, 2024

Thanks for the change, I wanted to do it myself for some time :)

Thanks for the review @golechwierowicz , I'm improving this pr with a better consolidation. Will address your comments and push a new commit today.

I have updated this with a better approach. Please take another look. Thanks.

Copy link
Member

@golechwierowicz golechwierowicz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please fix the flag description, then LGTM from my side.

@cheshire can you also take a look?

xla/debug_options_flags.cc Outdated Show resolved Hide resolved
xla/debug_options_flags.cc Outdated Show resolved Hide resolved
@github-actions github-actions bot added the kokoro:force-run Forces CI to rerun label Apr 17, 2024
@Tixxx Tixxx requested a review from frgossen April 17, 2024 16:14
@kokoro-team kokoro-team removed the kokoro:force-run Forces CI to rerun label Apr 17, 2024
@Tixxx
Copy link
Contributor Author

Tixxx commented Apr 17, 2024

adding @frgossen FYI

opts.set_xla_gpu_enable_async_collective_broadcast(true);
opts.set_xla_gpu_enable_async_collective_permute(false);
opts.set_xla_gpu_enable_async_all_to_all(false);
opts.set_xla_gpu_enable_async_reduce_scatter(false);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these default values are not preserved? IIRC, enabling all async collectives by default caused issues in the past.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current behavior is that xla_gpu_enable_async_collectives overrides all other control knobs, that has been true by default for a while. So this pr wont change the current default settings.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for clarifying

opts.set_xla_gpu_enable_async_collective_broadcast(true);
opts.set_xla_gpu_enable_async_collective_permute(false);
opts.set_xla_gpu_enable_async_all_to_all(false);
opts.set_xla_gpu_enable_async_reduce_scatter(false);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for clarifying

@Tixxx
Copy link
Contributor Author

Tixxx commented Apr 19, 2024

@golechwierowicz @frgossen could one of you help take a look the copybara failure? Thanks

@frgossen
Copy link
Member

Internal tmp import issue. Rerunning it

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 20, 2024
…c collective and use the global xla_gpu_disable_async_collectives

Imported from GitHub PR openxla/xla#11422

Currently we have 1 global flag(xla_gpu_enable_async_collectives) to control whether we want to asynchronize collectives or not. This flag overrides all other control knobs for each collective. The usage of it is confusing, instead we introduce a new flag xla_gpu_disable_async_collectives which will consolidate all the async flags we have now. We remove xla_gpu_enable_async_collectives and all other individual control knobs.
Sample usage:
xla_gpu_disable_async_collectives=allreduce,reducescatter
disables async allreduce and reducescatter
By default it's empty which indicates enabling async for all collectives.

Copybara import of the project:

--
afff139cb742662801c49052b70a8f234bb280e4 by TJ Xu <[email protected]>:

Remove control knobs for each individual async collective and use the
global xla_gpu_enable_async_collectives

--
e1385c5b1f82aee1d3ea98e5d453c3e60fc5fa8a by TJ Xu <[email protected]>:

Consolidate all flags into one

--
72020bef1368f514760cbfd6630c22d2348121c9 by TJ Xu <[email protected]>:

Change description of xla_gpu_disable_async_collectives

Merging this change closes #11422

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11422 from Tixxx:tixxx/flag_clean_up 72020bef1368f514760cbfd6630c22d2348121c9
PiperOrigin-RevId: 624859309
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 29, 2024
…c collective and use the global xla_gpu_disable_async_collectives

Imported from GitHub PR openxla/xla#11422

Currently we have 1 global flag(xla_gpu_enable_async_collectives) to control whether we want to asynchronize collectives or not. This flag overrides all other control knobs for each collective. The usage of it is confusing, instead we introduce a new flag xla_gpu_disable_async_collectives which will consolidate all the async flags we have now. We remove xla_gpu_enable_async_collectives and all other individual control knobs.
Sample usage:
xla_gpu_disable_async_collectives=allreduce,reducescatter
disables async allreduce and reducescatter
By default it's empty which indicates enabling async for all collectives.

Copybara import of the project:

--
afff139cb742662801c49052b70a8f234bb280e4 by TJ Xu <[email protected]>:

Remove control knobs for each individual async collective and use the
global xla_gpu_enable_async_collectives

--
e1385c5b1f82aee1d3ea98e5d453c3e60fc5fa8a by TJ Xu <[email protected]>:

Consolidate all flags into one

--
72020bef1368f514760cbfd6630c22d2348121c9 by TJ Xu <[email protected]>:

Change description of xla_gpu_disable_async_collectives

Merging this change closes #11422

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11422 from Tixxx:tixxx/flag_clean_up 72020bef1368f514760cbfd6630c22d2348121c9
PiperOrigin-RevId: 627992711
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 29, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11422 from Tixxx:tixxx/flag_clean_up 72020bef1368f514760cbfd6630c22d2348121c9
PiperOrigin-RevId: 628977605
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 29, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11422 from Tixxx:tixxx/flag_clean_up 72020bef1368f514760cbfd6630c22d2348121c9
PiperOrigin-RevId: 628442535
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 29, 2024
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11422 from Tixxx:tixxx/flag_clean_up 72020bef1368f514760cbfd6630c22d2348121c9
PiperOrigin-RevId: 628941758
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 29, 2024
…c collective and use the global xla_gpu_disable_async_collectives

Imported from GitHub PR openxla/xla#11422

Currently we have 1 global flag(xla_gpu_enable_async_collectives) to control whether we want to asynchronize collectives or not. This flag overrides all other control knobs for each collective. The usage of it is confusing, instead we introduce a new flag xla_gpu_disable_async_collectives which will consolidate all the async flags we have now. We remove xla_gpu_enable_async_collectives and all other individual control knobs.
Sample usage:
xla_gpu_disable_async_collectives=allreduce,reducescatter
disables async allreduce and reducescatter
By default it's empty which indicates enabling async for all collectives.

Copybara import of the project:

--
afff139cb742662801c49052b70a8f234bb280e4 by TJ Xu <[email protected]>:

Remove control knobs for each individual async collective and use the
global xla_gpu_enable_async_collectives

--
e1385c5b1f82aee1d3ea98e5d453c3e60fc5fa8a by TJ Xu <[email protected]>:

Consolidate all flags into one

--
72020bef1368f514760cbfd6630c22d2348121c9 by TJ Xu <[email protected]>:

Change description of xla_gpu_disable_async_collectives

Merging this change closes #11422

PiperOrigin-RevId: 629003173
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Apr 29, 2024
After this CL, the output shape will not be a tuple if it contains only a single element. I broke this by accident in cl/626122610.

I also added tests that ensure the correct behavior.

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11422 from Tixxx:tixxx/flag_clean_up 72020bef1368f514760cbfd6630c22d2348121c9
PiperOrigin-RevId: 629004592
@frgossen
Copy link
Member

This breaks internal tests that are still using the old flags. Wdyt about just adding the flag and deprecating the individual ones for now?

@Tixxx
Copy link
Contributor Author

Tixxx commented Apr 29, 2024

This breaks internal tests that are still using the old flags. Wdyt about just adding the flag and deprecating the individual ones for now?

hmm I didn't know those flags were still set explicitly. Yea in that case it makes sense to just deprecate them instead of removing them from xla.proto. Do you need me to make the changes?

@akuegel
Copy link
Member

akuegel commented Apr 30, 2024

This breaks internal tests that are still using the old flags. Wdyt about just adding the flag and deprecating the individual ones for now?

hmm I didn't know those flags were still set explicitly. Yea in that case it makes sense to just deprecate them instead of removing them from xla.proto. Do you need me to make the changes?

I think it is resolved already, the internal users of the flag were migrated to the new flags.

"This disables a certain set of async collectives and turn them into"
" synchornous ones. By default, this is empty which indicates enabling"
" async execution for all collectives. A sample usage is: "
" --xla_gpu_disable_async_collectives=ALLREDUCE,REDUCESCATTER"));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Tixxx
How to disable all collectives?
It would be good to have a list of all collective somewhere, or a special value like "ALL" to disable all of them.

reedwm added a commit to reedwm/maxtext that referenced this pull request May 10, 2024
XLA PR openxla/xla#11422 removed some XLA flags relating to async collectives. This caused the A3 configs to fail to run, so this change removes such flags from the A3 configs. The flags removed are:

--xla_gpu_enable_async_all_gather=true
--xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_async_all_reduce=true

Such flags had no impact before the XLA PR as the async collectives were already enabled by default.
shauryagup added a commit to AI-Hypercomputer/maxtext that referenced this pull request Jul 15, 2024
* Move tpu end-to-end test scripts to tpu folder

* unify WORKDIR to /deps

* Share GCS path between Gemma-7b tests

* Add README for llama2-7B

* adding script to fix the style and adding modified/fixed files with line length 125

* Move apt install from `rto_setup.sh` to `setup.sh`

* Update instructions for installing snap.

* Removes batch size from prefill attention calculation.

* Fixes for inf testing.

* Revert "Fixes for inf testing."

This reverts commit b15b1d5.

* Fixes

* Fix subset of hosts dataloading

* inference microbenchmark

  - allow run specified stages
  - allow run specific prefill length(s)
  - delete prefill result
  - printout prefill result

added funcs in max_utils

* Update Run_MaxText_via_xpk.md

Fixing typo.

* inference_microbenchmark:

  - time prefill only
  - benchmark prefill and insert

* Mark nvidia devtools repo as trusted

This is a stopgaps measure to circumvent the nvidia repo's gpg signature issue

* Explicitly set AQT Freezer mode in MaxText.

PiperOrigin-RevId: 627250589

* Move aqtp pin up

* Pre-commit config

* Update 128B config on v5e to use qkv_proj_offloaded remat_policy

* [MaxText] Rename llama2_7b_single_host_gpu.yml to make it clear that it can be used for any number of host.

PiperOrigin-RevId: 627804089

* Split Mixtral test into two scripts

* Update jax.tree_map to jax.tree_util.tree_map

* change norm sharding

fix lint

Revert "fix lint"

This reverts commit d8dc450.

fix lint

* Change l2norm to use jnp.sqrt

* Fix test_tokenize

* Streamlined setup.sh to have fewer apt install calls

* loosen tolerance in assert_params_sufficiently_sharded

* Enable entropy on multihost CPUs.

* Add tests to GPU runner

* Replace deprecated np.product with np.prod

* fix norm sharding

* Add Llama2-70b test

* Internal change only.

PiperOrigin-RevId: 630446330

* Add more tests for Mixtral

* Make some AQT dataclasses to use keyword-only fields (1/N)

This cl introduces an temporary decorator that will be temporarily used during this migration. The eventual goal is to enforce kw_only=True in all dataclasses unless it's not feasible, aiming to make AQT less error-prune and improve readability.

PiperOrigin-RevId: 631132072

* Reverts e8b53e5

PiperOrigin-RevId: 631465526

* Update tflops calculation

* fix sharding on generate cache in prefill results.

* Remove async XLA_FLAGS from A3 configs.

XLA PR openxla/xla#11422 removed some XLA flags relating to async collectives. This caused the A3 configs to fail to run, so this change removes such flags from the A3 configs. The flags removed are:

--xla_gpu_enable_async_all_gather=true
--xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_async_all_reduce=true

Such flags had no impact before the XLA PR as the async collectives were already enabled by default.

* Update llama2_7b_gpu.yml

PiperOrigin-RevId: 631752008

* Add forward pass logit check test for Llama2-7b

* Eval the command string from XPK for GPU script

* Remove cases where the deprecated --xla_gpu_simplify_all_fp_conversions is set to its default value.

PiperOrigin-RevId: 633645462

* streamline CI test structure

* fix pylint

fix pylint: Using variable 'p_eval_step' before assignment (#651)

* Remove async XLA_FLAGS from A3 configs

* Add llama-70b gpu config.

PiperOrigin-RevId: 634267313

* Support data input from HuggingFace

* Update the NCCL flags for A3+.

* add gemma logit test

* Integrate orbax logger in maxtext for structured logging.

Integrate orbax logger in maxtext for structured logging.

Integrate orbax logger in maxtext for structured logging.

Integrate orbax logger in maxtext for structured logging.

Integrate orbax logger in maxtext for structured logging.

Integrate orbax logger in maxtext for structured logging.

* fix hf input pipeline

* Fix prefill assertion

* Remove decode asserts from Gemma test files

* add single controller flag

* fix OOM issue running inference microbenchmark with llama13b on v5e4

* Add Llama2 13B Tests

* Don't clip fp8 stats

* Integrate nsys profiler

Remove 'enable_profiler' config and add 'profiler' config instead

* Add MoE matmul implementation

* fix OUTPUT_PATH in v5e/128b.sh

* squash

* Update flops calculation to active experts in moe

* Enable kv cache layout control

* Fix Gemma Readme link

* Internal change only.

Reverts a28f518

PiperOrigin-RevId: 639890999

* Upgrade Pinned Base Image for GPU

* Metrics bug: server_lib should be config_lib

* Fix MoE matmul scale issue

* Removed unused Pallas import from layers/attentions.py

PiperOrigin-RevId: 640481280

* Change norm sharding for llama2-7b to fsdp.

PiperOrigin-RevId: 640498890

* Copybara import of the project:

--
d7d694f by RissyRan <[email protected]>:

Fix forward test for Mixtral

COPYBARA_INTEGRATE_REVIEW=#679 from google:ranran_fix_forward_test d7d694f
PiperOrigin-RevId: 640537456

* Set additional flags for a3 and a3plus

* Use run_id instead of sha for docker tag

* refactor data input pipeline and add perf data

* Add gpt3 175b on v5e config

* Pipeline parallelism support (linear only)

* Turn on layer scanning for llama2-7b on GPU.

This better utilizes recent optimizations to collective approximation in the XLA latency hiding scheduler.

PiperOrigin-RevId: 642559284

* reshape q

* Add profiler flags to JetStream server

Add jetstream config

backward compatible

* fix tfds instruction

* Add vanilla megablox to MoE

* Add llama2 70b training config for v5e

* base.yml changes

circular changes to pipeline.py

pyconfig circ changes

pipeline parallel tests circular style

tree map, half passed tests

Total iterations circularized

improved iteration comment

run all tests

test both circular and non-circular

circ storage comment

circ storage pushing index comment

* Account for new mesh axes for llama2-7b, and llama2-70b on GPUs.

PiperOrigin-RevId: 643999933

* Sharding the llama2 70b on v5e-16 more efficiently.

https://arxiv.org/pdf/2211.05102
https://arxiv.org/pdf/1909.08053

* add compute_axis_order

* Add maxengine_server configs to base.yml

* Add FSDP + Megablox

* Llama3-8b model config

* MaxText package

* fix data loading from HF hub

* Fix llama2-{7,70}b sharding on GPU.

PiperOrigin-RevId: 645365795

* Move stage to second axis in mesh

Move stage to second axis in mesh

* Copybara import of the project:

--
1718b89 by RissyRan <[email protected]>:

Refactor permute and unpermute operations

COPYBARA_INTEGRATE_REVIEW=#714 from google:refactor_mega b101cbc
PiperOrigin-RevId: 645591567

* Fix Mesh setup for multiprocess CPUs.

* add kv_quant_axis

* Add a directory check for the . If it fails, attempt to check a path relative to the base config, similar to what is done for model configurations.

Minor update

Remove the raised exception

* Add mistral tokenizer to maxtext/assets

* Update the dependencies to prepare for integration of emergency checkpointing

Withhold some package versions

Update version of typing_extensions

* Make broadcasting from one replica to all more memory efficient

PiperOrigin-RevId: 646526020

* Inference Microbenchmark Sweep

* Fix mesh_axes and data_sharding for LLaMA 2 GPU configs.

PiperOrigin-RevId: 646795068

* Allow owners to have any approver

Fix AddLabel syntax

Fix punctuation

* Enable saving using Orbax's emergency checkpoint manager

fix data loading from HF hub

Add explanation to the emergency checkpoint feature

Fix pylint issues

Minor changes to the config file

resolve conflicts

Inference Microbenchmark Sweep

Fix mesh_axes and data_sharding for LLaMA 2 GPU configs.

PiperOrigin-RevId: 646795068

* Add Llama2 7B, 13B high performance training configs

* Load/Save Aqt quantized checkpoint.

* modify prefill to return first token

* Fix and protect simple_layer

Fix and protect simple_layer

Fix and protect simple_layer

Fix and protect simple_layer

* Adding option for int4 quantization to kvcache.

* support eval dataset and refactor

* Support partial overrides for logical_axis_rules.

* Fix simple test step count

* Clean up MoE brute force implementation

* Preliminary restore with lots of hardcoding and hacking

Refactor the code and remove the hardcoding

More refactoring

Cleanup for pull request

Address linting issues

Preliminary restore with lots of hardcoding and hacking

Refactor the code and remove the hardcoding

More refactoring

Cleanup for pull request

Address linting issues

Small formatting

Fix merging issues

* Add convergence tests on A3 GPU

* Update tile size

* Handle cases where memstats are not available for the device.

Memstats are not guaranteed to be available and can throw an error or return None. This change will handle both `jaxlib.xla_extension.XlaRuntimeError` if the device is not a PjRt addressable device or `KeyError` if the memstats returns None if they are not available.

* Fix validation error for other models

* Fix decode.py to also use first_token from prefill_call

* Add moe perf number

* move num_experts pyconfig assertion to fix tests

* Cast type for inputs before kernel call

* Move sharding overrides to models/ directory.

PiperOrigin-RevId: 650994392

* Enable quantization for MoE Matmul implementation

* Integrate and test Goodput Monitor with MaxText

* Adding Tokens/s/device to the log.

* Adding support for mixed precision quantization configs.

---------

Co-authored-by: maxtext authors <[email protected]>
Co-authored-by: Nina Cai <[email protected]>
Co-authored-by: NinaCai <[email protected]>
Co-authored-by: michelle-yooh <[email protected]>
Co-authored-by: In-Ho Yi <[email protected]>
Co-authored-by: A9isha <[email protected]>
Co-authored-by: In-Ho Yi <[email protected]>
Co-authored-by: ssusie <[email protected]>
Co-authored-by: tonyjohnchen <[email protected]>
Co-authored-by: Roshani Narasimhan <[email protected]>
Co-authored-by: Pate Motter <[email protected]>
Co-authored-by: khatwanimohit <[email protected]>
Co-authored-by: Morgan Du <[email protected]>
Co-authored-by: DongHyun Choi <[email protected]>
Co-authored-by: gobbleturk <[email protected]>
Co-authored-by: Raymond Zou <[email protected]>
Co-authored-by: Bixia Zheng <[email protected]>
Co-authored-by: Ran Ran <[email protected]>
Co-authored-by: Zhiyu Li <[email protected]>
Co-authored-by: Rafi Witten <[email protected]>
Co-authored-by: RissyRan <[email protected]>
Co-authored-by: Greg Olechwierowicz <[email protected]>
Co-authored-by: Junwei Yang <[email protected]>
Co-authored-by: Reed Wanderman-Milne <[email protected]>
Co-authored-by: Dimitar (Mitko) Asenov <[email protected]>
Co-authored-by: aireenmei <[email protected]>
Co-authored-by: yangyuwei <[email protected]>
Co-authored-by: Abhinav Singh <[email protected]>
Co-authored-by: Sadi Kneipp <[email protected]>
Co-authored-by: jwyang-google <[email protected]>
Co-authored-by: Anfal Siddiqui <[email protected]>
Co-authored-by: Brendan Slabe <[email protected]>
Co-authored-by: Sergei Lebedev <[email protected]>
Co-authored-by: Jon Bolin <[email protected]>
Co-authored-by: Zijun Zhou <[email protected]>
Co-authored-by: Zhihao Shan <[email protected]>
Co-authored-by: Adam O'Brien <[email protected]>
Co-authored-by: Vipan Nalla <[email protected]>
Co-authored-by: Vipan Nalla <[email protected]>
Co-authored-by: Xuefeng Gu <[email protected]>
Co-authored-by: Andy Ye <[email protected]>
Co-authored-by: Mitali Singh <[email protected]>
Co-authored-by: xuefgu <[email protected]>
Co-authored-by: Luke Baumann <[email protected]>
Co-authored-by: Dipannita Shaw <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

10 participants