Add llama2 configs for GPU A3
michelle-yooh committed Apr 8, 2024
1 parent f04ba76 commit 1200dd7
Showing 5 changed files with 215 additions and 0 deletions.
42 changes: 42 additions & 0 deletions MaxText/configs/a3/llama_2_7b/16vm.sh
@@ -0,0 +1,42 @@
echo "Running 16vm.sh"
# Example command to invoke this script
# bash MaxText/configs/a3/llama_2_7b/16vm.sh
#
# Example command to invoke this script via XPK
# python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
# --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
# --device-type ${DEVICE_TYPE} --num-slices 16 --priority=high \
# --command "bash MaxText/configs/a3/llama_2_7b/16vm.sh"

# Stop execution if any command exits with an error
set -e

export OUTPUT_PATH="gs://maxtext-experiments-multipod"
export RUN_NAME="llama-2-16vm-$(date +%Y-%m-%d-%H-%M)"

# Set environment variables
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
done
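# For example, "bash MaxText/configs/a3/llama_2_7b/16vm.sh OUTPUT_PATH=gs://my-bucket RUN_NAME=my-run"
# (hypothetical bucket and run names) would re-export those variables over the defaults above.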

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FUSED_ATTN=1
export NCCL_DEBUG=VERSION
export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions --xla_gpu_graph_level=0
--xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=134217728 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"
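# The combine thresholds above are in bytes: 1073741824 = 1 GiB for all-reduce,
# and 134217728 = 128 MiB for all-gather and reduce-scatter at this 16-VM scale.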

# 16 nodes
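# 16 A3 VMs x 8 H100 GPUs each = 128 chips: data parallelism spans nodes over DCN,
# while FSDP shards across the 8 GPUs inside each node (ICI).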
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=16 ici_fsdp_parallelism=8 per_device_batch_size=6 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true
43 changes: 43 additions & 0 deletions MaxText/configs/a3/llama_2_7b/1vm.sh
@@ -0,0 +1,43 @@
echo "Running 1vm.sh"
# Example command to invoke this script
# bash MaxText/configs/a3/llama_2_7b/1vm.sh
#
# Example command to invoke this script via XPK
# python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
# --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
# --device-type ${DEVICE_TYPE} --num-slices 1 \
# --command "bash MaxText/configs/a3/llama_2_7b/1vm.sh"

# Stop execution if any command exits with an error
set -e

export OUTPUT_PATH="gs://maxtext-experiments-multipod"
export RUN_NAME="llama-2-1vm-$(date +%Y-%m-%d-%H-%M)"

# Set environment variables
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
done

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FUSED_ATTN=1
export NCCL_DEBUG=VERSION
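# XLA_PYTHON_CLIENT_MEM_FRACTION caps JAX's GPU memory preallocation at 85%;
# CUDA_DEVICE_MAX_CONNECTIONS=1 serializes CUDA work queues so communication kernels
# can be scheduled ahead of compute; NVTE_FUSED_ATTN=1 enables Transformer Engine's
# fused attention; NCCL_DEBUG=VERSION logs only the NCCL version at startup.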
export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true
--xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions
--xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"


# 1 node, DATA_DP=1, ICI_FSDP=8
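# Global batch per step = 1 node x 8 GPUs x per_device_batch_size 4 = 32 sequences of 4096 tokens.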
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=1 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=$OUTPUT_PATH enable_profiler=false
44 changes: 44 additions & 0 deletions MaxText/configs/a3/llama_2_7b/2vm.sh
@@ -0,0 +1,44 @@
echo "Running 2vm.sh"
# Example command to invoke this script
# bash MaxText/configs/a3/llama_2_7b/2vm.sh
#
# Example command to invoke this script via XPK
# python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
# --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
# --device-type ${DEVICE_TYPE} --num-slices 2 \
# --command "bash MaxText/configs/a3/llama_2_7b/2vm.sh"

# Stop execution if any command exits with an error
set -e

export OUTPUT_PATH="gs://maxtext-experiments-multipod"
export RUN_NAME="llama-2-2vm-$(date +%Y-%m-%d-%H-%M)"

# Set environment variables
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
done

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FUSED_ATTN=1
export NCCL_DEBUG=VERSION
export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions --xla_gpu_graph_level=0
--xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=67108864 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"


# 2 nodes
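# dcn_data_parallelism tracks --num-slices (2 here); ici_fsdp_parallelism stays fixed
# at 8, the GPU count of a single A3 VM.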
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=2 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true

43 changes: 43 additions & 0 deletions MaxText/configs/a3/llama_2_7b/4vm.sh
@@ -0,0 +1,43 @@
echo "Running 4vm.sh"
# Example command to invoke this script
# bash MaxText/configs/a3/llama_2_7b/4vm.sh
#
# Example command to invoke this script via XPK
# python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
# --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
# --device-type ${DEVICE_TYPE} --num-slices 4 \
# --command "bash MaxText/configs/a3/llama_2_7b/4vm.sh"

# Stop execution if any command exits with an error
set -e

export OUTPUT_PATH="gs://maxtext-experiments-multipod"
export RUN_NAME="llama-2-4vm-$(date +%Y-%m-%d-%H-%M)"

# Set environment variables
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
done

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FUSED_ATTN=1
export NCCL_DEBUG=VERSION
export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true
--xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions
--xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=536870912 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"
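# The all-reduce combine threshold is tuned per scale: 536870912 bytes = 512 MiB here,
# versus 128 MiB on 1 VM, 64 MiB on 2 VMs, and 1 GiB on 8 and 16 VMs.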

# 4 nodes
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=4 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true

43 changes: 43 additions & 0 deletions MaxText/configs/a3/llama_2_7b/8vm.sh
@@ -0,0 +1,43 @@
echo "Running 8vm.sh"
# Example command to invoke this script
# bash MaxText/configs/a3/llama_2_7b/8vm.sh
#
# Example command to invoke this script via XPK
# python3 xpk/xpk.py workload create --cluster ${CLUSTER_NAME} \
# --workload ${WORKLOAD_NAME} --docker-image=gcr.io/supercomputer-testing/${LOCAL_IMAGE_NAME} \
# --device-type ${DEVICE_TYPE} --num-slices 8 \
# --command "bash MaxText/configs/a3/llama_2_7b/8vm.sh"

# Stop execution if any command exits with an error
set -e

export OUTPUT_PATH="gs://maxtext-experiments-multipod"
export RUN_NAME="llama-2-8vm-$(date +%Y-%m-%d-%H-%M)"

# Set environment variables
for ARGUMENT in "$@"; do
IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
export "$KEY"="$VALUE"
done

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FUSED_ATTN=1
export NCCL_DEBUG=VERSION
export XLA_FLAGS="--xla_dump_to=$OUTPUT_PATH/$RUN_NAME/HLO_dumps/
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true
--xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions --xla_gpu_graph_level=0
--xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=1073741824 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"

# 8 nodes
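# enable_profiler=true asks MaxText to capture a JAX profiler trace for the run,
# written under base_output_directory.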
python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu \
steps=30 dcn_data_parallelism=8 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b \
enable_checkpointing=false attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false \
dataset_type=synthetic async_checkpointing=false base_output_directory=gs://runner-maxtext-logs enable_profiler=true
