Skip to content

Commit

Permalink
Add maxtext sweep config (#109)
Browse files Browse the repository at this point in the history
* Add maxtext sweep config
  • Loading branch information
gobbleturk committed Aug 10, 2023
1 parent 12bc1ba commit d726b7a
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
19 changes: 19 additions & 0 deletions MaxText/aqt/jax/v2/google/maxtext_sweeps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""MaxText AQT sweeps configurations."""

# pylint: skip-file

import aqt.jax.v2.config as aqt_config


def sweep1(fwd_int8: bool, bwd_int8: bool) -> aqt_config.DotGeneral:
fqt_config = aqt_config.fully_quantized(
fwd_bits=8 if fwd_int8 else None,
bwd_bits=8 if bwd_int8 else None,
use_fwd_quant=False,
use_stochastic_rounding=None,
vjp_lhs_stochastic_rounding=True,
vjp_rhs_stochastic_rounding=False,
fwd_save_accumulator_memory=False,
bwd_save_accumulator_memory=False,
)
return fqt_config
1 change: 1 addition & 0 deletions end_to_end/eval_assert.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: skip-file
"""Reads and asserts over target values"""
from absl import app
from typing import Sequence
Expand Down

0 comments on commit d726b7a

Please sign in to comment.