-
Notifications
You must be signed in to change notification settings - Fork 234
/
test_tflops_32b_params.sh
27 lines (22 loc) · 1.06 KB
/
test_tflops_32b_params.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
#!/bin/bash
#
# Runs MaxText/train.py with a fixed set of model/config overrides, then calls
# end_to_end/eval_assert.py on the resulting metrics file to assert TFLOP/s.
#
# Usage:
#   test_tflops_32b_params.sh USER TFLOP_THRESHOLD OUTPUT_PATH DATASET_PATH [RUN_NAME]
#
# Arguments:
#   $1 USER             - fallback run name when RUN_NAME is not given
#   $2 TFLOP_THRESHOLD  - threshold forwarded to eval_assert.py
#   $3 OUTPUT_PATH      - forwarded to train.py as base_output_directory
#   $4 DATASET_PATH     - forwarded to train.py as dataset_path
#   $5 RUN_NAME         - optional; defaults to USER when unset or empty
#
# Must be run from a directory in which the relative paths MaxText/ and
# end_to_end/ resolve. metrics.txt is referenced relative to the current
# directory.
set -e

# NOTE: this reuses the conventional USER variable name, as the original
# script did. If USER is already exported in the caller's environment, the
# child python3 processes will see the overridden value.
USER=${1:-}
TFLOP_THRESHOLD=${2:-}
OUTPUT_PATH=${3:-}
DATASET_PATH=${4:-}
# Fall back to USER when the fifth argument is unset or empty.
RUN_NAME=${5:-${USER}}

# Train
# Flags are kept one per line and joined with single spaces, so the exported
# value is a single space-separated string.
libtpu_flags=(
  --xla_tpu_enable_data_parallel_all_reduce_opt=true
  --xla_tpu_data_parallel_opt_different_sized_ops=true
  --xla_tpu_enable_async_collective_fusion=true
  --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true
  --xla_tpu_enable_async_collective_fusion_multiple_steps=true
  --xla_tpu_overlap_compute_collective_tc=true
  --xla_enable_async_all_gather=true
)
export LIBTPU_INIT_ARGS="${libtpu_flags[*]}"

# Each override is its own array element so that no two arguments can be
# fused together by a line continuation, and values containing whitespace
# are passed through intact.
train_args=(
  MaxText/configs/base.yml
  "run_name=${RUN_NAME}"
  steps=150
  per_device_batch_size=4
  enable_checkpointing=false
  enable_profiler=false
  remat_policy=full
  base_emb_dim=8192
  base_mlp_dim=32768
  base_num_heads=32
  base_num_decoder_layers=40
  head_dim=256
  max_target_length=2048
  metrics_file=metrics.txt
  "base_output_directory=${OUTPUT_PATH}"
  "dataset_path=${DATASET_PATH}"
  log_period=150
)
python3 MaxText/train.py "${train_args[@]}"

# Assert TFLOP/s
# Reads the same metrics.txt path that was passed to training as metrics_file.
python3 end_to_end/eval_assert.py metrics_average metrics.txt \
  "${TFLOP_THRESHOLD}" perf/per_device_tflops_per_sec