diff --git a/MaxText/configs/llama2_70b_gpu.yml b/MaxText/configs/llama2_70b_gpu.yml new file mode 100644 index 000000000..7fab4727d --- /dev/null +++ b/MaxText/configs/llama2_70b_gpu.yml @@ -0,0 +1,18 @@ +base_config: "base.yml" + +run_name: "gpu_train_test" +# Args coming from the NVIDIA spreadsheet http://shortn/_AhULYn1mX4. +hardware: "gpu" +steps: 30 +model_name: "llama2-70b" +enable_checkpointing: False +attention: "cudnn_flash_te" +remat_policy: "full" +use_iota_embed: True +scan_layers: True +dataset_type: "synthetic" +async_checkpointing: False +logits_dot_in_fp32: False + +per_device_batch_size: 6 +max_target_length: 4096 \ No newline at end of file