Skip to content

Commit

Permalink
Merge branch 'main' into jax_nightly
Browse files Browse the repository at this point in the history
  • Loading branch information
SurbhiJainUSC committed Feb 29, 2024
2 parents 0c42227 + 2e7e967 commit a7f807d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
4 changes: 3 additions & 1 deletion MaxText/convert_gpt3_ckpt_from_paxml.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from jax import random
from jax.sharding import Mesh
from layers.models import Transformer
from layers import quantizations
import checkpointing

import numpy as np
Expand Down Expand Up @@ -89,7 +90,8 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
devices_array = max_utils.create_device_mesh(cfg)
mesh = Mesh(devices_array, cfg.mesh_axes)

model = Transformer(cfg, mesh)
quant = quantizations.configure_quantization(cfg)
model = Transformer(cfg, mesh, quant=quant)
learning_rate_schedule = max_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)

Expand Down
7 changes: 4 additions & 3 deletions end_to_end/test_convergence_1b_params.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/bin/bash

echo "Running test_convergence_1b_params.sh"
# Run this on 64 chips to achieve a loss value of ~2.5 (v4-128)
# Run this on 64 chips to achieve a loss value of ~2.5 after 20400 steps, or ~2.7 after 10200 steps (v4-128)
#
# Command Flags:
# OUTPUT_PATH (Required, unless base_output_directory is already set in base.yml)
Expand All @@ -16,6 +16,7 @@ echo "Running test_convergence_1b_params.sh"
set -e

export LOSS_THRESHOLD=100.0 # Set to large value so test is guaranteed to pass.
export STEPS=20400 # Run for 20B tokens for a 1B sized model for "chinchilla" scaling https://arxiv.org/abs/2203.15556

# Set environment variables
for ARGUMENT in "$@"; do
Expand All @@ -34,10 +35,10 @@ then
fi

TRAIN_CMD="python3 MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME\
steps=20400 per_device_batch_size=8.0 learning_rate=3e-4 enable_checkpointing=false \
steps=$STEPS per_device_batch_size=8.0 learning_rate=3e-4 enable_checkpointing=false \
max_target_length=2048 global_parameter_scale=1 \
enable_profiler=false metrics_file=metrics.txt base_output_directory=$OUTPUT_PATH\
dataset_path=$DATASET_PATH log_period=150 enable_data_shuffling=false"
dataset_path=$DATASET_PATH log_period=150 remat_policy=minimal enable_data_shuffling=false"
TRAIN_CMD+=$CMD_DATA

# Train
Expand Down

0 comments on commit a7f807d

Please sign in to comment.