
Refactor permute and unpermute operations #714

Closed

Conversation

RissyRan

Description

  • Refactor permute and unpermute operations for better performance (see the sketch below).
  • Update rope_max_timescale to match the HF config from Mistral AI for both Mistral & Mixtral (thanks @ZhiyuLi-goog for bringing it up).
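
For context, a minimal JAX sketch of the permute/unpermute pattern this PR refactors for MoE token routing; the function and variable names here are illustrative, not MaxText's actual implementation:

```python
import jax.numpy as jnp

def permute(tokens, expert_indices):
    # Sort tokens by assigned expert so each expert's tokens are contiguous.
    sort_order = jnp.argsort(expert_indices)
    return tokens[sort_order], sort_order

def unpermute(expert_outputs, sort_order):
    # Invert the permutation to restore the original token order.
    inverse_order = jnp.argsort(sort_order)
    return expert_outputs[inverse_order]
```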

Test

Test locally: link

@RissyRan marked this pull request as ready for review June 20, 2024 18:25

@ZhiyuLi-goog left a comment


LGTM!

copybara-service bot pushed a commit that referenced this pull request Jun 22, 2024
--
1718b89 by RissyRan <[email protected]>:

Refactor permute and unpermute operations

COPYBARA_INTEGRATE_REVIEW=#714 from google:refactor_mega b101cbc
PiperOrigin-RevId: 645591567
@RissyRan closed this Jul 2, 2024
vivianrwu pushed a commit to vivianrwu/maxtext that referenced this pull request Jul 11, 2024
--
1718b89 by RissyRan <[email protected]>:

Refactor permute and unpermute operations

COPYBARA_INTEGRATE_REVIEW=google#714 from google:refactor_mega b101cbc
PiperOrigin-RevId: 645591567
shauryagup added a commit that referenced this pull request Jul 15, 2024
* Move tpu end-to-end test scripts to tpu folder

* unify WORKDIR to /deps

* Share GCS path between Gemma-7b tests

* Add README for llama2-7B

* Add script to fix style, and add the modified/fixed files with line length 125

* Move apt install from `rto_setup.sh` to `setup.sh`

* Update instructions for installing snap.

* Removes batch size from prefill attention calculation.

* Fixes for inf testing.

* Revert "Fixes for inf testing."

This reverts commit b15b1d5.

* Fixes

* Fix subset of hosts dataloading

* inference microbenchmark

  - allow running specified stages
  - allow running specific prefill length(s)
  - delete prefill result
  - print out prefill result

added funcs in max_utils

* Update Run_MaxText_via_xpk.md

Fixing typo.

* inference_microbenchmark:

  - time prefill only
  - benchmark prefill and insert

* Mark nvidia devtools repo as trusted

This is a stopgap measure to circumvent the nvidia repo's GPG signature issue

* Explicitly set AQT Freezer mode in MaxText.

PiperOrigin-RevId: 627250589

* Move aqtp pin up

* Pre-commit config

* Update 128B config on v5e to use qkv_proj_offloaded remat_policy

* [MaxText] Rename llama2_7b_single_host_gpu.yml to make it clear that it can be used for any number of hosts.

PiperOrigin-RevId: 627804089

* Split Mixtral test into two scripts

* Update jax.tree_map to jax.tree_util.tree_map
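
For reference, this rename is mechanical; `jax.tree_map` was a deprecated alias for `jax.tree_util.tree_map`:

```python
import jax

params = {"w": 1.0, "b": 2.0}
# Before (deprecated): jax.tree_map(lambda x: x * 2, params)
doubled = jax.tree_util.tree_map(lambda x: x * 2, params)
```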

* change norm sharding

fix lint

Revert "fix lint"

This reverts commit d8dc450.

fix lint

* Change l2norm to use jnp.sqrt
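
The commit does not show the function body; a plausible sketch, assuming `l2norm` computes a global L2 norm of an array:

```python
import jax.numpy as jnp

def l2norm(x):
    # Explicit square root instead of, e.g., raising the sum to the 0.5 power.
    return jnp.sqrt(jnp.sum(jnp.square(x)))
```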

* Fix test_tokenize

* Streamlined setup.sh to have fewer apt install calls

* loosen tolerance in assert_params_sufficiently_sharded

* Enable entropy on multihost CPUs.

* Add tests to GPU runner

* Replace deprecated np.product with np.prod
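
`np.product` was a deprecated alias of `np.prod` (since removed from NumPy), so this is a drop-in rename:

```python
import numpy as np

shape = (4, 8, 16)
num_elements = np.prod(shape)  # previously np.product(shape)
```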

* fix norm sharding

* Add Llama2-70b test

* Internal change only.

PiperOrigin-RevId: 630446330

* Add more tests for Mixtral

* Make some AQT dataclasses to use keyword-only fields (1/N)

This CL introduces a temporary decorator used during the migration. The eventual goal is to enforce kw_only=True in all dataclasses where feasible, making AQT less error-prone and more readable.

PiperOrigin-RevId: 631132072
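
To illustrate what the migration enforces (the class and field names below are hypothetical, not AQT's actual dataclasses):

```python
from dataclasses import dataclass

@dataclass(kw_only=True)  # Python 3.10+: all fields must be passed by keyword
class QuantConfig:
    bits: int = 8
    preserve_zero: bool = True

cfg = QuantConfig(bits=4)  # OK
# QuantConfig(4)           # TypeError: positional arguments are rejected
```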

* Reverts e8b53e5

PiperOrigin-RevId: 631465526

* Update tflops calculation

* fix sharding on generate cache in prefill results.

* Remove async XLA_FLAGS from A3 configs.

XLA PR openxla/xla#11422 removed some XLA flags relating to async collectives. This caused the A3 configs to fail to run, so this change removes such flags from the A3 configs. The flags removed are:

--xla_gpu_enable_async_all_gather=true
--xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_async_all_reduce=true

Such flags had no impact before the XLA PR as the async collectives were already enabled by default.
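
As an illustration only (not MaxText's actual config code), flags like these are typically passed through the XLA_FLAGS environment variable, and removing them amounts to filtering them out of that string:

```python
import os

removed = [
    "--xla_gpu_enable_async_all_gather=true",
    "--xla_gpu_enable_async_reduce_scatter=true",
    "--xla_gpu_enable_async_all_reduce=true",
]
flags = os.environ.get("XLA_FLAGS", "").split()
os.environ["XLA_FLAGS"] = " ".join(f for f in flags if f not in removed)
```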

* Update llama2_7b_gpu.yml

PiperOrigin-RevId: 631752008

* Add forward pass logit check test for Llama2-7b

* Eval the command string from XPK for GPU script

* Remove cases where the deprecated --xla_gpu_simplify_all_fp_conversions is set to its default value.

PiperOrigin-RevId: 633645462

* streamline CI test structure

* fix pylint

fix pylint: Using variable 'p_eval_step' before assignment (#651)

* Remove async XLA_FLAGS from A3 configs

* Add llama-70b gpu config.

PiperOrigin-RevId: 634267313

* Support data input from HuggingFace

* Update the NCCL flags for A3+.

* add gemma logit test

* Integrate orbax logger in maxtext for structured logging.

* fix hf input pipeline

* Fix prefill assertion

* Remove decode asserts from Gemma test files

* add single controller flag

* fix OOM issue running inference microbenchmark with llama13b on v5e4

* Add Llama2 13B Tests

* Don't clip fp8 stats

* Integrate nsys profiler

Remove 'enable_profiler' config and add 'profiler' config instead

* Add MoE matmul implementation

* fix OUTPUT_PATH in v5e/128b.sh

* squash

* Update flops calculation to active experts in moe

* Enable kv cache layout control

* Fix Gemma Readme link

* Internal change only.

Reverts a28f518

PiperOrigin-RevId: 639890999

* Upgrade Pinned Base Image for GPU

* Metrics bug: server_lib should be config_lib

* Fix MoE matmul scale issue

* Removed unused Pallas import from layers/attentions.py

PiperOrigin-RevId: 640481280

* Change norm sharding for llama2-7b to fsdp.

PiperOrigin-RevId: 640498890

* Copybara import of the project:

--
d7d694f by RissyRan <[email protected]>:

Fix forward test for Mixtral

COPYBARA_INTEGRATE_REVIEW=#679 from google:ranran_fix_forward_test d7d694f
PiperOrigin-RevId: 640537456

* Set additional flags for a3 and a3plus

* Use run_id instead of sha for docker tag

* refactor data input pipeline and add perf data

* Add gpt3 175b on v5e config

* Pipeline parallelism support (linear only)

* Turn on layer scanning for llama2-7b on GPU.

This better utilizes recent optimizations to collective approximation in the XLA latency-hiding scheduler.

PiperOrigin-RevId: 642559284

* reshape q

* Add profiler flags to JetStream server

Add jetstream config

backward compatible

* fix tfds instruction

* Add vanilla megablox to MoE

* Add llama2 70b training config for v5e

* base.yml changes

circular changes to pipeline.py

pyconfig circ changes

pipeline parallel tests circular style

tree map, half passed tests

Total iterations circularized

improved iteration comment

run all tests

test both circular and non-circular

circ storage comment

circ storage pushing index comment

* Account for new mesh axes for llama2-7b, and llama2-70b on GPUs.

PiperOrigin-RevId: 643999933

* Sharding the llama2 70b on v5e-16 more efficiently.

https://arxiv.org/pdf/2211.05102
https://arxiv.org/pdf/1909.08053

* add compute_axis_order

* Add maxengine_server configs to base.yml

* Add FSDP + Megablox

* Llama3-8b model config

* MaxText package

* fix data loading from HF hub

* Fix llama2-{7,70}b sharding on GPU.

PiperOrigin-RevId: 645365795

* Move stage to second axis in mesh

Move stage to second axis in mesh

* Copybara import of the project:

--
1718b89 by RissyRan <[email protected]>:

Refactor permute and unpermute operations

COPYBARA_INTEGRATE_REVIEW=#714 from google:refactor_mega b101cbc
PiperOrigin-RevId: 645591567

* Fix Mesh setup for multiprocess CPUs.

* add kv_quant_axis

* Add a directory check for the . If it fails, attempt to check a path relative to the base config, similar to what is done for model configurations.

Minor update

Remove the raised exception

* Add mistral tokenizer to maxtext/assets

* Update the dependencies to prepare for integration of emergency checkpointing

Withhold some package versions

Update version of typing_extensions

* Make broadcasting from one replica to all more memory efficient

PiperOrigin-RevId: 646526020

* Inference Microbenchmark Sweep

* Fix mesh_axes and data_sharding for LLaMA 2 GPU configs.

PiperOrigin-RevId: 646795068

* Allow owners to have any approver

Fix AddLabel syntax

Fix punctuation

* Enable saving using Orbax's emergency checkpoint manager

fix data loading from HF hub

Add explanation to the emergency checkpoint feature

Fix pylint issues

Minor changes to the config file

resolve conflicts

Inference Microbenchmark Sweep

Fix mesh_axes and data_sharding for LLaMA 2 GPU configs.

PiperOrigin-RevId: 646795068

* Add Llama2 7B, 13B high performance training configs

* Load/Save Aqt quantized checkpoint.

* modify prefill to return first token

* Fix and protect simple_layer

* Adding option for int4 quantization to kvcache.

* support eval dataset and refactor

* Support partial overrides for logical_axis_rules.

* Fix simple test step count

* Clean up MoE brute force implementation

* Preliminary restore with lots of hardcoding and hacking

Refactor the code and remove the hardcoding

More refactoring

Cleanup for pull request

Address linting issues

Small formatting

Fix merging issues

* Add convergence tests on A3 GPU

* Update tile size

* Handle cases where memstats are not available for the device.

Memstats are not guaranteed to be available and can throw an error or return None. This change handles both `jaxlib.xla_extension.XlaRuntimeError`, raised when the device is not a PjRt-addressable device, and `KeyError`, raised when memstats returns None because the stats are unavailable.
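
A minimal sketch of the guard described above, assuming a JAX `device.memory_stats()` call; the helper name and returned key are illustrative:

```python
def device_bytes_in_use(device):
    """Return bytes in use, or None when memstats are unavailable."""
    try:
        stats = device.memory_stats()  # may raise XlaRuntimeError
        return stats["bytes_in_use"]   # stats may be None or lack the key
    except Exception:  # jaxlib.xla_extension.XlaRuntimeError, KeyError, TypeError
        return None
```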

* Fix validation error for other models

* Fix decode.py to also use first_token from prefill_call

* Add moe perf number

* move num_experts pyconfig assertion to fix tests

* Cast type for inputs before kernel call

* Move sharding overrides to models/ directory.

PiperOrigin-RevId: 650994392

* Enable quantization for MoE Matmul implementation

* Integrate and test Goodput Monitor with MaxText

* Adding Tokens/s/device to the log.

* Adding support for mixed precision quantization configs.

---------

Co-authored-by: maxtext authors <[email protected]>
Co-authored-by: Nina Cai <[email protected]>
Co-authored-by: NinaCai <[email protected]>
Co-authored-by: michelle-yooh <[email protected]>
Co-authored-by: In-Ho Yi <[email protected]>
Co-authored-by: A9isha <[email protected]>
Co-authored-by: In-Ho Yi <[email protected]>
Co-authored-by: ssusie <[email protected]>
Co-authored-by: tonyjohnchen <[email protected]>
Co-authored-by: Roshani Narasimhan <[email protected]>
Co-authored-by: Pate Motter <[email protected]>
Co-authored-by: khatwanimohit <[email protected]>
Co-authored-by: Morgan Du <[email protected]>
Co-authored-by: DongHyun Choi <[email protected]>
Co-authored-by: gobbleturk <[email protected]>
Co-authored-by: Raymond Zou <[email protected]>
Co-authored-by: Bixia Zheng <[email protected]>
Co-authored-by: Ran Ran <[email protected]>
Co-authored-by: Zhiyu Li <[email protected]>
Co-authored-by: Rafi Witten <[email protected]>
Co-authored-by: RissyRan <[email protected]>
Co-authored-by: Greg Olechwierowicz <[email protected]>
Co-authored-by: Junwei Yang <[email protected]>
Co-authored-by: Reed Wanderman-Milne <[email protected]>
Co-authored-by: Dimitar (Mitko) Asenov <[email protected]>
Co-authored-by: aireenmei <[email protected]>
Co-authored-by: yangyuwei <[email protected]>
Co-authored-by: Abhinav Singh <[email protected]>
Co-authored-by: Sadi Kneipp <[email protected]>
Co-authored-by: jwyang-google <[email protected]>
Co-authored-by: Anfal Siddiqui <[email protected]>
Co-authored-by: Brendan Slabe <[email protected]>
Co-authored-by: Sergei Lebedev <[email protected]>
Co-authored-by: Jon Bolin <[email protected]>
Co-authored-by: Zijun Zhou <[email protected]>
Co-authored-by: Zhihao Shan <[email protected]>
Co-authored-by: Adam O'Brien <[email protected]>
Co-authored-by: Vipan Nalla <[email protected]>
Co-authored-by: Vipan Nalla <[email protected]>
Co-authored-by: Xuefeng Gu <[email protected]>
Co-authored-by: Andy Ye <[email protected]>
Co-authored-by: Mitali Singh <[email protected]>
Co-authored-by: xuefgu <[email protected]>
Co-authored-by: Luke Baumann <[email protected]>
Co-authored-by: Dipannita Shaw <[email protected]>