diff --git a/MaxText/configs/a3/llama_2_7b/16vm.sh b/MaxText/configs/a3/llama_2_7b/16vm.sh
index 76ee26b32..fa07a470e 100644
--- a/MaxText/configs/a3/llama_2_7b/16vm.sh
+++ b/MaxText/configs/a3/llama_2_7b/16vm.sh
@@ -37,6 +37,6 @@ export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
 
 # 16 nodes
 python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
-  steps=30 dcn_data_parallelism=16 ici_fsdp_parallelism=8 per_device_batch_size=6 max_target_length=4096 model_name=llama2-7b \
+  steps=30 dcn_data_parallelism=16 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
   enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
   dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true
diff --git a/MaxText/configs/a3/llama_2_7b/README.md b/MaxText/configs/a3/llama_2_7b/README.md
new file mode 100644
index 000000000..966049b8e
--- /dev/null
+++ b/MaxText/configs/a3/llama_2_7b/README.md
@@ -0,0 +1,28 @@
+
+
+# High Performance Model Configs on A3 GPU
+Expected performance results for Llama2-7B model running on A3 GPU:
+
+
+### Llama2-7B
+| Hardware               | TFLOP/sec/chip   |
+| ---------------------- | ---------------- |
+| 1x A3 (h100-80gb-8)    | 492              |
+| 2x A3 (h100-80gb-8)    | 422              |
+| 4x A3 (h100-80gb-8)    | 407              |
+| 8x A3 (h100-80gb-8)    | 409              |
+| 16x A3 (h100-80gb-8)   | 375              |