T5X

T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of sequence models (starting with language) at many scales.

It is essentially a new and improved implementation of the T5 codebase (based on Mesh TensorFlow) in JAX and Flax.

Below is a quick start guide for training models with TPUs on Google Cloud. For additional tutorials and background, see the complete documentation.

Quickstart (Recommended)

T5X can be run with XManager on Vertex AI. Vertex AI is a training platform that creates TPU instances, runs your code on them, and shuts the TPUs down when the jobs terminate. This is significantly easier than managing GCE VMs and TPU VM instances yourself.

  1. Follow the pre-requisites and directions to install XManager.

  2. Request TPU quota as required. GCP projects come with 8 cores by default, which is enough to run one training experiment on a single TPU host. If you want to run multi-host training or run multiple trials in parallel, you will need more quota. Navigate to Quotas.

The quota you want is:

  • Service: Vertex AI API
  • Dimensions (location): us-central1
  • If you want to run single-host experiments:
    • Custom model training TPU V2 cores per region
    • Custom model training TPU V3 cores per region
  • If you want to run multi-host experiments:
    • Custom model training TPU V2 pod cores per region
    • Custom model training TPU V3 pod cores per region

TIP: You won't be able to run single-host experiments with multi-host quota. For example, you can't run tpu_v2=8 using TPU V2 pod quota.
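The mapping from a TPU request to the quota names listed above can be sketched as a small helper. This is only an illustration: the function name `required_quota` is hypothetical, and it assumes that a request for exactly 8 cores is single-host while any larger multiple of 8 is a multi-host (pod) request.

```python
def required_quota(tpu_version: str, cores: int) -> str:
    """Return the Vertex AI quota name listed above for a TPU request.

    Assumption (not stated by Vertex AI here): exactly 8 cores means a
    single-host experiment, and any larger multiple of 8 means multi-host.
    """
    if tpu_version not in ("V2", "V3"):
        raise ValueError(f"unsupported TPU version: {tpu_version!r}")
    if cores < 8 or cores % 8 != 0:
        raise ValueError("cores must be a positive multiple of 8")
    # Single-host and multi-host requests draw from different quotas.
    kind = "cores" if cores == 8 else "pod cores"
    return f"Custom model training TPU {tpu_version} {kind} per region"


print(required_quota("V2", 8))
print(required_quota("V3", 32))
```

Because the two quota families are separate, a project with only pod quota would still need the single-host quota before it could run an 8-core experiment.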

  3. Launch the XManager script located at t5x/scripts/xm_launch.py.

As a running example, we use the WMT14 En-De translation task, which is described in more detail in the Examples section below.

export GOOGLE_CLOUD_BUCKET_NAME=...
export TFDS_DATA_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/data
export MODEL_DIR=gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/$(date +%Y%m%d)

# Pre-download dataset in multi-host experiments.
tfds build wmt_t2t_translate --data_dir=$TFDS_DATA_DIR

git clone https://github.com/google-research/t5x
cd ./t5x/

python3 ./t5x/scripts/xm_launch.py \
  --gin_file=t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin \
  --model_dir=$MODEL_DIR \
  --tfds_data_dir=$TFDS_DATA_DIR

Check gs://$GOOGLE_CLOUD_BUCKET_NAME/t5x/ for the output artifacts, which can be read by TensorBoard.
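To make the directory layout concrete, the following sketch rebuilds the two storage paths from the export commands in Python. It is an illustration only: the helper name `quickstart_paths` and the bucket name `my-bucket` are hypothetical placeholders, and the sketch assumes the standard `gs://` scheme for Cloud Storage paths.

```python
from datetime import date


def quickstart_paths(bucket: str, today: date) -> dict:
    """Mirror the TFDS_DATA_DIR and MODEL_DIR exports from the quickstart."""
    base = f"gs://{bucket}/t5x"
    return {
        # Shared dataset cache, reused across runs.
        "TFDS_DATA_DIR": f"{base}/data",
        # One model directory per day, matching $(date +%Y%m%d).
        "MODEL_DIR": f"{base}/{today.strftime('%Y%m%d')}",
    }


paths = quickstart_paths("my-bucket", date(2024, 1, 31))
print(paths["TFDS_DATA_DIR"])  # gs://my-bucket/t5x/data
print(paths["MODEL_DIR"])      # gs://my-bucket/t5x/20240131
```

Since the model directory is keyed by date only, two launches on the same day would share one directory; adding a time or run suffix is one way to keep them apart.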

Installation

Note that all the commands in this document should be run on the command line of the TPU VM instance unless otherwise stated.

  1. Follow the instructions to set up a Google Cloud Platform (GCP) account and enable the Cloud TPU API.

    Note: While T5X works with GPU as well, we haven't heavily tested the GPU usage.

  2. Create a Cloud TPU VM instance following these instructions. We recommend that you develop your workflow on a single v3-8 TPU (i.e., --accelerator-type=v3-8) and scale up to pod slices once the pipeline is ready. In this README, we focus on using a single v3-8 TPU. See here to learn more about TPU architectures.

  3. With Cloud TPU VMs, you ssh directly into the host machine of the TPU VM. You can install packages, run your code, etc. on the host machine. Once the TPU instance is created, ssh into it with

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE}

    where TPU_NAME and ZONE are the name and the zone used in step 2.

  4. Install T5X and the dependencies.

    git clone --branch=main https://github.com/google-research/t5x
    cd t5x
    
    python3 -m pip install -e '.[tpu]' -f \
      https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  5. Create a Google Cloud Storage (GCS) bucket to store the dataset and model checkpoints. To create a GCS bucket, see these instructions.

Example: English to German translation

As a running example, we use the WMT14 En-De translation. The raw dataset is available in TensorFlow Datasets as "wmt_t2t_translate".

T5 casts a translation example such as the following

{'en': 'That is good.', 'de': 'Das ist gut.'}

to the form called "text-to-text":

{'inputs': 'translate English to German: That is good.', 'targets': 'Das ist gut.'}

This formulation allows many different classes of language tasks to be expressed in a uniform manner, so a single encoder-decoder architecture can handle them without any task-specific parameters. For more detail, refer to the T5 paper (Raffel et al. 2019).
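The conversion shown above can be sketched in a few lines of plain Python. This is a minimal illustration of the text-to-text format, not the preprocessing code T5X actually uses; the function name `to_text_to_text` is hypothetical.

```python
def to_text_to_text(example: dict) -> dict:
    """Cast a raw En-De translation pair into the text-to-text form."""
    return {
        # The task is described in the input text itself, as a prefix.
        "inputs": "translate English to German: " + example["en"],
        "targets": example["de"],
    }


raw = {"en": "That is good.", "de": "Das ist gut."}
print(to_text_to_text(raw))
```

Because the task is named in the prefix rather than in the model, a different task only needs a different prefix and target string, which is what lets one architecture serve many tasks.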

For a scalable data pipeline and an evaluation framework