diff --git a/MaxText/configs/16b.sh b/MaxText/configs/16b.sh
new file mode 100644
index 000000000..4a8b08fb0
--- /dev/null
+++ b/MaxText/configs/16b.sh
@@ -0,0 +1,10 @@
+echo "Running 16b.sh"
+
+RUN_NAME=${1}
+OUTPUT_PATH=${2}
+DATASET_PATH=${3}
+
+bash rto_setup.sh
+
+TFLOP_THRESHOLD=0 # set to 0 since we are not actually running as a test.
+bash end_to_end/test_tflops_16b_params.sh ${RUN_NAME} ${TFLOP_THRESHOLD} ${OUTPUT_PATH} ${DATASET_PATH}
diff --git a/MaxText/configs/32b.sh b/MaxText/configs/32b.sh
new file mode 100644
index 000000000..a274de167
--- /dev/null
+++ b/MaxText/configs/32b.sh
@@ -0,0 +1,10 @@
+echo "Running 32b.sh"
+
+RUN_NAME=${1}
+OUTPUT_PATH=${2}
+DATASET_PATH=${3}
+
+bash rto_setup.sh
+
+TFLOP_THRESHOLD=0 # set to 0 since we are not actually running as a test.
+bash end_to_end/test_tflops_32b_params.sh ${RUN_NAME} ${TFLOP_THRESHOLD} ${OUTPUT_PATH} ${DATASET_PATH}
diff --git a/end_to_end/test_tflops_16b_params.sh b/end_to_end/test_tflops_16b_params.sh
index c8a691b30..d2d3d06f8 100644
--- a/end_to_end/test_tflops_16b_params.sh
+++ b/end_to_end/test_tflops_16b_params.sh
@@ -9,16 +9,16 @@ DATASET_PATH=${4}
 if [ -z ${5} ]
 then
-    RUN_NAME=${USER}_$(date +%Y-%m-%d-%H-%M-%S)
+    RUN_NAME=${USER}
 else
-    RUN_NAME=${5}_$(date +%Y-%m-%d-%H)
+    RUN_NAME=${5}
 fi
 
 # Train
 export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
 python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\
-    steps=150 per_device_batch_size=2 enable_checkpointing=false\
-    enable_profiler=false remat_policy=proj base_emb_dim=6144 base_mlp_dim=24576\
+    steps=150 per_device_batch_size=6 enable_checkpointing=false\
+    enable_profiler=false remat_policy=full base_emb_dim=6144 base_mlp_dim=24576\
     base_num_heads=24 base_num_decoder_layers=36 head_dim=256\
     max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\
     dataset_path=$DATASET_PATH log_period=150
 
diff --git a/end_to_end/test_tflops_32b_params.sh b/end_to_end/test_tflops_32b_params.sh
index 62aa76e53..a5e014359 100644
--- a/end_to_end/test_tflops_32b_params.sh
+++ b/end_to_end/test_tflops_32b_params.sh
@@ -9,16 +9,16 @@ DATASET_PATH=${4}
 if [ -z ${5} ]
 then
-    RUN_NAME=${USER}_$(date +%Y-%m-%d-%H-%M-%S)
+    RUN_NAME=${USER}
 else
-    RUN_NAME=${5}_$(date +%Y-%m-%d-%H)
+    RUN_NAME=${5}
 fi
 
 # Train
 export LIBTPU_INIT_ARGS="--xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
 python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\
-    steps=150 per_device_batch_size=1 enable_checkpointing=false\
-    enable_profiler=false remat_policy=proj base_emb_dim=8192 base_mlp_dim=32768\
+    steps=150 per_device_batch_size=4 enable_checkpointing=false\
+    enable_profiler=false remat_policy=full base_emb_dim=8192 base_mlp_dim=32768\
     base_num_heads=32 base_num_decoder_layers=40 head_dim=256\
     max_target_length=2048 metrics_file='metrics.txt' base_output_directory=$OUTPUT_PATH\
     dataset_path=$DATASET_PATH log_period=150
 
diff --git a/rto_setup.sh b/rto_setup.sh
new file mode 100644
index 000000000..4c13a4df9
--- /dev/null
+++ b/rto_setup.sh
@@ -0,0 +1,5 @@
+echo "Running rto_setup.sh..."
+first_line_res=$(ip route show | head -n 1)
+sudo ip route change ${first_line_res} rto_min 5ms
+sudo ethtool -K ens9 tx-nocache-copy on
+echo "rto_setup finished"
\ No newline at end of file
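
Usage note: a minimal sketch of invoking the new wrapper, run from the repo root on each TPU VM worker (the wrapper calls rto_setup.sh and end_to_end/ via relative paths and needs sudo for the route change). Only the positional order RUN_NAME, OUTPUT_PATH, DATASET_PATH comes from the scripts above; the run name and GCS bucket paths below are hypothetical placeholders:

    # Hypothetical invocation; substitute real GCS buckets for the placeholders.
    bash MaxText/configs/16b.sh my_16b_run gs://my-output-bucket gs://my-dataset-bucket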