-
Notifications
You must be signed in to change notification settings - Fork 405
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NVIDIA GPU] Remove control knobs for each individual async collective and use the global xla_gpu_disable_async_collectives #11422
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the change, I wanted to do it myself for some time :)
Thanks for the review @golechwierowicz , I'm improving this pr with a better consolidation. Will address your comments and push a new commit today. |
global xla_gpu_enable_async_collectives
7fabaa2
to
e1385c5
Compare
I have updated this with a better approach. Please take another look. Thanks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please fix the flag description, then LGTM from my side.
@cheshire can you also take a look?
adding @frgossen FYI |
opts.set_xla_gpu_enable_async_collective_broadcast(true); | ||
opts.set_xla_gpu_enable_async_collective_permute(false); | ||
opts.set_xla_gpu_enable_async_all_to_all(false); | ||
opts.set_xla_gpu_enable_async_reduce_scatter(false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are these default values are not preserved? IIRC, enabling all async collectives by default caused issues in the past.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current behavior is that xla_gpu_enable_async_collectives
overrides all other control knobs, that has been true by default for a while. So this pr wont change the current default settings.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for clarifying
opts.set_xla_gpu_enable_async_collective_broadcast(true); | ||
opts.set_xla_gpu_enable_async_collective_permute(false); | ||
opts.set_xla_gpu_enable_async_all_to_all(false); | ||
opts.set_xla_gpu_enable_async_reduce_scatter(false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for clarifying
@golechwierowicz @frgossen could one of you help take a look the copybara failure? Thanks |
Internal tmp import issue. Rerunning it |
…c collective and use the global xla_gpu_disable_async_collectives Imported from GitHub PR openxla/xla#11422 Currently we have 1 global flag(xla_gpu_enable_async_collectives) to control whether we want to asynchronize collectives or not. This flag overrides all other control knobs for each collective. The usage of it is confusing, instead we introduce a new flag xla_gpu_disable_async_collectives which will consolidate all the async flags we have now. We remove xla_gpu_enable_async_collectives and all other individual control knobs. Sample usage: xla_gpu_disable_async_collectives=allreduce,reducescatter disables async allreduce and reducescatter By default it's empty which indicates enabling async for all collectives. Copybara import of the project: -- afff139cb742662801c49052b70a8f234bb280e4 by TJ Xu <[email protected]>: Remove control knobs for each individual async collective and use the global xla_gpu_enable_async_collectives -- e1385c5b1f82aee1d3ea98e5d453c3e60fc5fa8a by TJ Xu <[email protected]>: Consolidate all flags into one -- 72020bef1368f514760cbfd6630c22d2348121c9 by TJ Xu <[email protected]>: Change description of xla_gpu_disable_async_collectives Merging this change closes #11422 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11422 from Tixxx:tixxx/flag_clean_up 72020bef1368f514760cbfd6630c22d2348121c9 PiperOrigin-RevId: 624859309
…c collective and use the global xla_gpu_disable_async_collectives Imported from GitHub PR openxla/xla#11422 Currently we have 1 global flag(xla_gpu_enable_async_collectives) to control whether we want to asynchronize collectives or not. This flag overrides all other control knobs for each collective. The usage of it is confusing, instead we introduce a new flag xla_gpu_disable_async_collectives which will consolidate all the async flags we have now. We remove xla_gpu_enable_async_collectives and all other individual control knobs. Sample usage: xla_gpu_disable_async_collectives=allreduce,reducescatter disables async allreduce and reducescatter By default it's empty which indicates enabling async for all collectives. Copybara import of the project: -- afff139cb742662801c49052b70a8f234bb280e4 by TJ Xu <[email protected]>: Remove control knobs for each individual async collective and use the global xla_gpu_enable_async_collectives -- e1385c5b1f82aee1d3ea98e5d453c3e60fc5fa8a by TJ Xu <[email protected]>: Consolidate all flags into one -- 72020bef1368f514760cbfd6630c22d2348121c9 by TJ Xu <[email protected]>: Change description of xla_gpu_disable_async_collectives Merging this change closes #11422 FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11422 from Tixxx:tixxx/flag_clean_up 72020bef1368f514760cbfd6630c22d2348121c9 PiperOrigin-RevId: 627992711
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11422 from Tixxx:tixxx/flag_clean_up 72020bef1368f514760cbfd6630c22d2348121c9 PiperOrigin-RevId: 628977605
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11422 from Tixxx:tixxx/flag_clean_up 72020bef1368f514760cbfd6630c22d2348121c9 PiperOrigin-RevId: 628442535
FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11422 from Tixxx:tixxx/flag_clean_up 72020bef1368f514760cbfd6630c22d2348121c9 PiperOrigin-RevId: 628941758
…c collective and use the global xla_gpu_disable_async_collectives Imported from GitHub PR openxla/xla#11422 Currently we have 1 global flag(xla_gpu_enable_async_collectives) to control whether we want to asynchronize collectives or not. This flag overrides all other control knobs for each collective. The usage of it is confusing, instead we introduce a new flag xla_gpu_disable_async_collectives which will consolidate all the async flags we have now. We remove xla_gpu_enable_async_collectives and all other individual control knobs. Sample usage: xla_gpu_disable_async_collectives=allreduce,reducescatter disables async allreduce and reducescatter By default it's empty which indicates enabling async for all collectives. Copybara import of the project: -- afff139cb742662801c49052b70a8f234bb280e4 by TJ Xu <[email protected]>: Remove control knobs for each individual async collective and use the global xla_gpu_enable_async_collectives -- e1385c5b1f82aee1d3ea98e5d453c3e60fc5fa8a by TJ Xu <[email protected]>: Consolidate all flags into one -- 72020bef1368f514760cbfd6630c22d2348121c9 by TJ Xu <[email protected]>: Change description of xla_gpu_disable_async_collectives Merging this change closes #11422 PiperOrigin-RevId: 629003173
After this CL, the output shape will not be a tuple if it contains only a single element. I broke this by accident in cl/626122610. I also added tests that ensure the correct behavior. FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#11422 from Tixxx:tixxx/flag_clean_up 72020bef1368f514760cbfd6630c22d2348121c9 PiperOrigin-RevId: 629004592
This breaks internal tests that are still using the old flags. Wdyt about just adding the flag and deprecating the individual ones for now? |
hmm I didn't know those flags were still set explicitly. Yea in that case it makes sense to just deprecate them instead of removing them from xla.proto. Do you need me to make the changes? |
I think it is resolved already, the internal users of the flag were migrated to the new flags. |
"This disables a certain set of async collectives and turn them into" | ||
" synchornous ones. By default, this is empty which indicates enabling" | ||
" async execution for all collectives. A sample usage is: " | ||
" --xla_gpu_disable_async_collectives=ALLREDUCE,REDUCESCATTER")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Tixxx
How to disable all collectives?
It would be good to have a list of all collective somewhere, or a special value like "ALL" to disable all of them.
XLA PR openxla/xla#11422 removed some XLA flags relating to async collectives. This caused the A3 configs to fail to run, so this change removes such flags from the A3 configs. The flags removed are: --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_async_all_reduce=true Such flags had no impact before the XLA PR as the async collectives were already enabled by default.
* Move tpu end-to-end test scripts to tpu folder * unify WORKDIR to /deps * Share GCS path between Gemma-7b tests * Add README for llama2-7B * adding script to fix the style and adding modified/fixed files with line length 125 * Move apt install from `rto_setup.sh` to `setup.sh` * Update instructions for installing snap. * Removes batch size from prefill attention calculation. * Fixes for inf testing. * Revert "Fixes for inf testing." This reverts commit b15b1d5. * Fixes * Fix subset of hosts dataloading * inference microbenchmark - allow run specified stages - allow run specific prefill length(s) - delete prefill result - printout prefill result added funcs in max_utils * Update Run_MaxText_via_xpk.md Fixing typo. * inference_microbenchmark: - time prefill only - benchmark prefill and insert * Mark nvidia devtools repo as trusted This is a stopgaps measure to circumvent the nvidia repo's gpg signature issue * Explicitly set AQT Freezer mode in MaxText. PiperOrigin-RevId: 627250589 * Move aqtp pin up * Pre-commit config * Update 128B config on v5e to use qkv_proj_offloaded remat_policy * [MaxText] Rename llama2_7b_single_host_gpu.yml to make it clear that it can be used for any number of host. PiperOrigin-RevId: 627804089 * Split Mixtral test into two scripts * Update jax.tree_map to jax.tree_util.tree_map * change norm sharding fix lint Revert "fix lint" This reverts commit d8dc450. fix lint * Change l2norm to use jnp.sqrt * Fix test_tokenize * Streamlined setup.sh to have fewer apt install calls * loosen tolerance in assert_params_sufficiently_sharded * Enable entropy on multihost CPUs. * Add tests to GPU runner * Replace deprecated np.product with np.prod * fix norm sharding * Add Llama2-70b test * Internal change only. PiperOrigin-RevId: 630446330 * Add more tests for Mixtral * Make some AQT dataclasses to use keyword-only fields (1/N) This cl introduces an temporary decorator that will be temporarily used during this migration. The eventual goal is to enforce kw_only=True in all dataclasses unless it's not feasible, aiming to make AQT less error-prune and improve readability. PiperOrigin-RevId: 631132072 * Reverts e8b53e5 PiperOrigin-RevId: 631465526 * Update tflops calculation * fix sharding on generate cache in prefill results. * Remove async XLA_FLAGS from A3 configs. XLA PR openxla/xla#11422 removed some XLA flags relating to async collectives. This caused the A3 configs to fail to run, so this change removes such flags from the A3 configs. The flags removed are: --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_async_all_reduce=true Such flags had no impact before the XLA PR as the async collectives were already enabled by default. * Update llama2_7b_gpu.yml PiperOrigin-RevId: 631752008 * Add forward pass logit check test for Llama2-7b * Eval the command string from XPK for GPU script * Remove cases where the deprecated --xla_gpu_simplify_all_fp_conversions is set to its default value. PiperOrigin-RevId: 633645462 * streamline CI test structure * fix pylint fix pylint: Using variable 'p_eval_step' before assignment (#651) * Remove async XLA_FLAGS from A3 configs * Add llama-70b gpu config. PiperOrigin-RevId: 634267313 * Support data input from HuggingFace * Update the NCCL flags for A3+. * add gemma logit test * Integrate orbax logger in maxtext for structured logging. Integrate orbax logger in maxtext for structured logging. Integrate orbax logger in maxtext for structured logging. Integrate orbax logger in maxtext for structured logging. Integrate orbax logger in maxtext for structured logging. Integrate orbax logger in maxtext for structured logging. * fix hf input pipeline * Fix prefill assertion * Remove decode asserts from Gemma test files * add single controller flag * fix OOM issue running inference microbenchmark with llama13b on v5e4 * Add Llama2 13B Tests * Don't clip fp8 stats * Integrate nsys profiler Remove 'enable_profiler' config and add 'profiler' config instead * Add MoE matmul implementation * fix OUTPUT_PATH in v5e/128b.sh * squash * Update flops calculation to active experts in moe * Enable kv cache layout control * Fix Gemma Readme link * Internal change only. Reverts a28f518 PiperOrigin-RevId: 639890999 * Upgrade Pinned Base Image for GPU * Metrics bug: server_lib should be config_lib * Fix MoE matmul scale issue * Removed unused Pallas import from layers/attentions.py PiperOrigin-RevId: 640481280 * Change norm sharding for llama2-7b to fsdp. PiperOrigin-RevId: 640498890 * Copybara import of the project: -- d7d694f by RissyRan <[email protected]>: Fix forward test for Mixtral COPYBARA_INTEGRATE_REVIEW=#679 from google:ranran_fix_forward_test d7d694f PiperOrigin-RevId: 640537456 * Set additional flags for a3 and a3plus * Use run_id instead of sha for docker tag * refactor data input pipeline and add perf data * Add gpt3 175b on v5e config * Pipeline parallelism support (linear only) * Turn on layer scanning for llama2-7b on GPU. This better utilizes recent optimizations to collective approximation in the XLA latency hiding scheduler. PiperOrigin-RevId: 642559284 * reshape q * Add profiler flags to JetStream server Add jetstream config backward compatible * fix tfds instruction * Add vanilla megablox to MoE * Add llama2 70b training config for v5e * base.yml changes circular changes to pipeline.py pyconfig circ changes pipeline parallel tests circular style tree map, half passed tests Total iterations circularized improved iteration comment run all tests test both circular and non-circular circ storage comment circ storage pushing index comment * Account for new mesh axes for llama2-7b, and llama2-70b on GPUs. PiperOrigin-RevId: 643999933 * Sharding the llama2 70b on v5e-16 more efficiently. https://arxiv.org/pdf/2211.05102 https://arxiv.org/pdf/1909.08053 * add compute_axis_order * Add maxengine_server configs to base.yml * Add FSDP + Megablox * Llama3-8b model config * MaxText package * fix data loading from HF hub * Fix llama2-{7,70}b sharding on GPU. PiperOrigin-RevId: 645365795 * Move stage to second axis in mesh Move stage to second axis in mesh * Copybara import of the project: -- 1718b89 by RissyRan <[email protected]>: Refactor permute and unpermute operations COPYBARA_INTEGRATE_REVIEW=#714 from google:refactor_mega b101cbc PiperOrigin-RevId: 645591567 * Fix Mesh setup for multiprocess CPUs. * add kv_quant_axis * Add a directory check for the . If it fails, attempt to check a path relative to the base config, similar to what is done for model configurations. Minor update Remove the raised exception * Add mistral tokenizer to maxtext/assets * Update the dependencies to prepare for integration of emergency checkpointing Withhold some package versions Update version of typing_extensions * Make broadcasting from one replica to all more memory efficient PiperOrigin-RevId: 646526020 * Inference Microbenchmark Sweep * Fix mesh_axes and data_sharding for LLaMA 2 GPU configs. PiperOrigin-RevId: 646795068 * Allow owners to have any approver Fix AddLabel syntax Fix punctuation * Enable saving using Orbax's emergency checkpoint manager fix data loading from HF hub Add explanation to the emergency checkpoint feature Fix pylint issues Minor changes to the config file resolve conflicts Inference Microbenchmark Sweep Fix mesh_axes and data_sharding for LLaMA 2 GPU configs. PiperOrigin-RevId: 646795068 * Add Llama2 7B, 13B high performance training configs * Load/Save Aqt quantized checkpoint. * modify prefill to return first token * Fix and protect simple_layer Fix and protect simple_layer Fix and protect simple_layer Fix and protect simple_layer * Adding option for int4 quantization to kvcache. * support eval dataset and refactor * Support partial overrides for logical_axis_rules. * Fix simple test step count * Clean up MoE brute force implementation * Preliminary restore with lots of hardcoding and hacking Refactor the code and remove the hardcoding More refactoring Cleanup for pull request Address linting issues Preliminary restore with lots of hardcoding and hacking Refactor the code and remove the hardcoding More refactoring Cleanup for pull request Address linting issues Small formatting Fix merging issues * Add convergence tests on A3 GPU * Update tile size * Handle cases where memstats are not available for the device. Memstats are not guaranteed to be available and can throw an error or return None. This change will handle both `jaxlib.xla_extension.XlaRuntimeError` if the device is not a PjRt addressable device or `KeyError` if the memstats returns None if they are not available. * Fix validation error for other models * Fix decode.py to also use first_token from prefill_call * Add moe perf number * move num_experts pyconfig assertion to fix tests * Cast type for inputs before kernel call * Move sharding overrides to models/ directory. PiperOrigin-RevId: 650994392 * Enable quantization for MoE Matmul implementation * Integrate and test Goodput Monitor with MaxText * Adding Tokens/s/device to the log. * Adding support for mixed precision quantization configs. --------- Co-authored-by: maxtext authors <[email protected]> Co-authored-by: Nina Cai <[email protected]> Co-authored-by: NinaCai <[email protected]> Co-authored-by: michelle-yooh <[email protected]> Co-authored-by: In-Ho Yi <[email protected]> Co-authored-by: A9isha <[email protected]> Co-authored-by: In-Ho Yi <[email protected]> Co-authored-by: ssusie <[email protected]> Co-authored-by: tonyjohnchen <[email protected]> Co-authored-by: Roshani Narasimhan <[email protected]> Co-authored-by: Pate Motter <[email protected]> Co-authored-by: khatwanimohit <[email protected]> Co-authored-by: Morgan Du <[email protected]> Co-authored-by: DongHyun Choi <[email protected]> Co-authored-by: gobbleturk <[email protected]> Co-authored-by: Raymond Zou <[email protected]> Co-authored-by: Bixia Zheng <[email protected]> Co-authored-by: Ran Ran <[email protected]> Co-authored-by: Zhiyu Li <[email protected]> Co-authored-by: Rafi Witten <[email protected]> Co-authored-by: RissyRan <[email protected]> Co-authored-by: Greg Olechwierowicz <[email protected]> Co-authored-by: Junwei Yang <[email protected]> Co-authored-by: Reed Wanderman-Milne <[email protected]> Co-authored-by: Dimitar (Mitko) Asenov <[email protected]> Co-authored-by: aireenmei <[email protected]> Co-authored-by: yangyuwei <[email protected]> Co-authored-by: Abhinav Singh <[email protected]> Co-authored-by: Sadi Kneipp <[email protected]> Co-authored-by: jwyang-google <[email protected]> Co-authored-by: Anfal Siddiqui <[email protected]> Co-authored-by: Brendan Slabe <[email protected]> Co-authored-by: Sergei Lebedev <[email protected]> Co-authored-by: Jon Bolin <[email protected]> Co-authored-by: Zijun Zhou <[email protected]> Co-authored-by: Zhihao Shan <[email protected]> Co-authored-by: Adam O'Brien <[email protected]> Co-authored-by: Vipan Nalla <[email protected]> Co-authored-by: Vipan Nalla <[email protected]> Co-authored-by: Xuefeng Gu <[email protected]> Co-authored-by: Andy Ye <[email protected]> Co-authored-by: Mitali Singh <[email protected]> Co-authored-by: xuefgu <[email protected]> Co-authored-by: Luke Baumann <[email protected]> Co-authored-by: Dipannita Shaw <[email protected]>
Currently we have 1 global flag(xla_gpu_enable_async_collectives) to control whether we want to asynchronize collectives or not. This flag overrides all other control knobs for each collective. The usage of it is confusing, instead we introduce a new flag xla_gpu_disable_async_collectives which will consolidate all the async flags we have now. We remove xla_gpu_enable_async_collectives and all other individual control knobs.
Sample usage:
xla_gpu_disable_async_collectives=allreduce,reducescatter
disables async allreduce and reducescatter
By default it's empty which indicates enabling async for all collectives.