Skip to content

Commit

Permalink
Fix GPT-3 conversion script: pass the quantization config to Transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhiyuLi-goog committed Feb 29, 2024
1 parent c3cf02a commit 988d8c9
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion MaxText/convert_gpt3_ckpt_from_paxml.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from jax import random
from jax.sharding import Mesh
from layers.models import Transformer
+ from layers import quantizations
import checkpointing

import numpy as np
Expand Down Expand Up @@ -89,7 +90,8 @@ def convert(paxml_ckpt_path, maxtext_model_name, base_output_directory, run_name
devices_array = max_utils.create_device_mesh(cfg)
mesh = Mesh(devices_array, cfg.mesh_axes)

- model = Transformer(cfg, mesh)
+ quant = quantizations.configure_quantization(cfg)
+ model = Transformer(cfg, mesh, quant=quant)
learning_rate_schedule = max_utils.create_learning_rate_schedule(cfg)
tx = optimizers.get_optimizer(cfg, learning_rate_schedule)

Expand Down

0 comments on commit 988d8c9

Please sign in to comment.