
Update tflops calculation #644

Merged: 1 commit merged into main from moe_param on May 10, 2024

Conversation

@RissyRan RissyRan commented May 9, 2024

Description

Update the TFLOPS calculation to match the Megatron-LM style, with sources:

  • MaxText reference PR
  • Megatron-LM link
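
For context, here is a minimal sketch of the Megatron-style estimate this PR moves toward. It paraphrases the dense-transformer formula from the Megatron-LM paper, not MaxText's actual code, and it ignores MoE routing and grouped-query attention, so it will not reproduce the table below exactly.

def megatron_style_tflops_per_step(batch_size, seq_len, num_layers, hidden_dim, vocab_size):
  """Estimated training TFLOPs per step for a dense decoder-only transformer.

  Treats forward + backward as 3x the forward pass, giving
  72 * B * s * L * h^2 * (1 + s/(6h) + V/(12*L*h)) FLOPs in total,
  without activation recomputation.
  """
  flops = (
      72.0 * batch_size * seq_len * num_layers * hidden_dim**2
      * (1 + seq_len / (6.0 * hidden_dim)
         + vocab_size / (12.0 * num_layers * hidden_dim))
  )
  return flops / 1e12

# Illustrative hyperparameters only, not the exact configs behind the table below.
print(megatron_style_tflops_per_step(
    batch_size=8, seq_len=4096, num_layers=32, hidden_dim=4096, vocab_size=32000))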

Test

TFLOPS/step for each model before and after the change (the change updates the total_ffn_flops calculation):

Model Name      TFLOPS/step (before)    TFLOPS/step (after)
llama2-7b       1033.2                  1053.4
llama2-13b      1981.1                  2018.7
mistral-7b      1107.4                  1127.6
llama2-70b      10368.9                 10528.0
mixtral-8x7b    6926.1                  6946.5
gemma-7b        1293.6                  1328.2
gemma-2b        380.7                   391.8
gpt3-6b         423.1                   430.4
gpt3-52k        0.01256                 0.01253
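
For scale, the relative change implied by the table can be computed directly from the numbers above:

# Before/after TFLOPS/step values copied from the table above.
results = {
    "llama2-7b": (1033.2, 1053.4),
    "llama2-13b": (1981.1, 2018.7),
    "mistral-7b": (1107.4, 1127.6),
    "llama2-70b": (10368.9, 10528.0),
    "mixtral-8x7b": (6926.1, 6946.5),
    "gemma-7b": (1293.6, 1328.2),
    "gemma-2b": (380.7, 391.8),
    "gpt3-6b": (423.1, 430.4),
    "gpt3-52k": (0.01256, 0.01253),
}
for name, (before, after) in results.items():
  print(f"{name}: {100 * (after - before) / before:+.2f}%")
# Dense models rise by roughly 1.5-3%, mixtral-8x7b by about 0.3%,
# and gpt3-52k drops by about 0.24%.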

Test code snippet

import os
import unittest

import optax
from jax import random
from jax.sharding import Mesh

# MaxText modules, assuming the repo-root import style MaxText's own tests use.
import max_utils
import maxtext_utils
import pyconfig
from layers import models
from layers import quantizations

Transformer = models.Transformer


class FlopsTest(unittest.TestCase):

  def setUp(self):
    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
    pyconfig.initialize(
      [None, 'configs/base.yml'],
      run_name='test',
      model_name='mixtral-8x7b',
      flops_change=False,  # toggles between the old and new formulas in this PR
    )
    self.cfg = pyconfig.config
    devices_array = max_utils.create_device_mesh(self.cfg)
    self.mesh = Mesh(devices_array, self.cfg.mesh_axes)
    quant = quantizations.configure_quantization(self.cfg)
    self.model = Transformer(self.cfg, mesh=self.mesh, quant=quant)

  def test_flops(self):
    rng = random.PRNGKey(0)

    # Build a real training state so the parameter count matches what training sees.
    tx = optax.adam(learning_rate=0.001)
    state, _, _ = max_utils.setup_training_state(self.model, None, tx, self.cfg, rng, self.mesh, None)
    num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params)
    maxtext_utils.calculate_tflops_training_per_device(num_model_parameters, self.cfg)
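
The snippet runs as a standard unittest module; the file name and location below are hypothetical (MaxText's own tests live under MaxText/tests/):

# Append to the test file, then run from the MaxText/ directory so the
# repo-root imports resolve, e.g.: python3 -m unittest flops_test
if __name__ == "__main__":
  unittest.main()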

@RissyRan RissyRan force-pushed the moe_param branch 5 times, most recently from 3955128 to 9d74f00 on May 9, 2024 07:45
@RissyRan RissyRan requested a review from ZhiyuLi-goog May 9, 2024 15:16
@rwitten rwitten left a comment

Let's make sure Abhinav Goel at NV is aligned

(Review comment on MaxText/maxtext_utils.py; outdated, resolved)
@abhinavgoel95 (Contributor) commented

@rwitten @RissyRan this PR looks good to me

@RissyRan RissyRan marked this pull request as ready for review May 9, 2024 20:19
@RissyRan RissyRan requested a review from gobbleturk as a code owner May 9, 2024 20:19
@RissyRan RissyRan assigned rwitten and unassigned rwitten May 9, 2024
@rwitten rwitten left a comment

Looks great! But please don't merge without Abhinav's approval.

Oh he already approved

@rwitten rwitten removed their assignment May 9, 2024
@copybara-service copybara-service bot merged commit f9c9dd8 into main May 10, 2024
11 checks passed
@copybara-service copybara-service bot deleted the moe_param branch May 10, 2024 19:45