
Flyte Llama

Flyte Llama is a fine-tuned model based on Code Llama, trained on the Flyte codebases.

Env Setup

python -m venv ~/venvs/flyte-llama
source ~/venvs/flyte-llama/bin/activate
pip install -r requirements.txt

Train model

Export Environment Variables

export PYTHONPATH=$(pwd):$PYTHONPATH
export FLYTECTL_CONFIG=~/.uctl/config-demo.yaml  # replace with your flyte/union cloud config
export REGISTRY=ghcr.io/unionai-oss  # replace this with your own registry
export FLYTE_PROJECT=llm-fine-tuning
export IMAGE=<your-training-image>  # replace: image used by the remote training commands below

Create dataset

python flyte_llama/dataset.py --output-path ~/datasets/flyte_llama

Train Model

Local

pyflyte run flyte_llama/workflows.py train \
    --dataset ~/datasets/flyte_llama \
    --config config/local.json

Flyte Llama 7b QLoRA

Train:

pyflyte -c $FLYTECTL_CONFIG run --remote \
    --copy-all \
    --project $FLYTE_PROJECT \
    flyte_llama/workflows.py train_workflow \
    --config config/flyte_llama_7b_qlora_v0.json

Publish:

pyflyte run --remote \
    --copy-all \
    --project $FLYTE_PROJECT \
    flyte_llama/workflows.py publish_model \
    --config config/flyte_llama_7b_qlora_v0.json \
    --model_dir s3://path/to/model

Flyte Llama 7b QLoRA from a previous adapter checkpoint

Pass the --pretrained_adapter flag to continue training from a previous adapter checkpoint. This is typically an S3 path produced by train_workflow.

pyflyte run --remote \
    --copy-all \
    --project $FLYTE_PROJECT \
    flyte_llama/workflows.py train_workflow \
    --config config/flyte_llama_7b_qlora_v0.json \
    --pretrained_adapter s3://path/to/checkpoint

Flyte Llama 7b Instruct QLoRA

Train:

pyflyte run --remote \
    --copy-all \
    --project $FLYTE_PROJECT \
    --image $IMAGE \
    flyte_llama/workflows.py train_workflow \
    --config config/flyte_llama_7b_instruct_qlora_v0.json

Publish:

pyflyte run --remote \
    --copy-all \
    --project $FLYTE_PROJECT \
    --image $IMAGE \
    flyte_llama/workflows.py publish_model \
    --config config/flyte_llama_7b_instruct_qlora_v0.json \
    --model_dir s3://path/to/model

Flyte Llama 13b QLoRA

pyflyte run --remote \
    --copy-all \
    --project $FLYTE_PROJECT \
    --image $IMAGE \
    flyte_llama/workflows.py train_workflow \
    --config config/flyte_llama_13b_qlora_v0.json

Flyte Llama 34b QLoRA

pyflyte run --remote \
    --copy-all \
    --project $FLYTE_PROJECT \
    --image $IMAGE \
    flyte_llama/workflows.py train_workflow \
    --config config/flyte_llama_34b_qlora_v0.json

Serve model

This project uses ModelZ as the serving layer.

Create a secrets.txt file to hold your sensitive credentials:

# do this once
echo MODELZ_USER_ID="<replace>" >> secrets.txt
echo MODELZ_API_KEY="<replace>" >> secrets.txt
echo HF_AUTH_TOKEN="<replace>" >> secrets.txt

Serving a POST Endpoint

Export env vars:

eval $(sed 's/^/export /g' secrets.txt)
export VERSION=$(git rev-parse --short=7 HEAD)
export SERVING_IMAGE=ghcr.io/unionai-oss/modelz-flyte-llama-serving:$VERSION

Build the serving image:

docker build . -f Dockerfile.server -t $SERVING_IMAGE
docker push $SERVING_IMAGE

Deploy:

python deploy.py \
    --deployment-name flyte-llama-$VERSION \
    --image $SERVING_IMAGE \
    --server-resource "nvidia-ada-l4-2-24c-96g"

Get the deployment_key from the output of the command above and use it to test the model:

python client.py \
    --prompt "The code snippet below shows a basic Flyte workflow" \
    --output-file output.txt \
    --deployment-key <deployment_key>

Serving a Server-Sent Events (SSE) Endpoint

Export env vars:

eval $(sed 's/^/export /g' secrets.txt)
export VERSION=$(git rev-parse --short=7 HEAD)
export SERVING_SSE_IMAGE=ghcr.io/unionai-oss/modelz-flyte-llama-serving-sse:$VERSION

Build the serving image:

docker build . -f Dockerfile.server_sse -t $SERVING_SSE_IMAGE
docker push $SERVING_SSE_IMAGE

Deploy:

python deploy.py \
    --deployment-name flyte-llama-sse-$VERSION \
    --image $SERVING_SSE_IMAGE \
    --server-resource "nvidia-ada-l4-4-48c-192g" \
    --stream

Get the deployment_key from the output of the command above and use it to test the model:

python client_sse.py \
    --prompt "The code snippet below shows a basic Flyte workflow" \
    --n-tokens 250 \
    --output-file output.txt \
    --deployment-key <deployment_key>

🔖 Model Card

Dataset

This system will be based on the Flyte codebases.

The dataset will consist of source files, tests, and documentation from these repositories (see the Evaluation section below for the specific repos).

Data Source Extensions

This dataset could be enriched with open source repos that use Flyte, including those maintained by the Flyte core team and those maintained by the community. This would further train the model on how the community uses flytekit and configures codebases in the wild.

LLM-augmented supervised finetuning

We can build a supervised finetuning dataset using an LLM to generate a synthetic instruction given a piece of Flyte code. For example, given a flytesnacks example, an LLM can be prompted to create an instruction associated with that example. Or, given a flytekit plugin, an LLM can be prompted to create an instruction associated with creating the flytekit plugin class that implements the plugin interface.
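
A minimal sketch of this idea in Python. The prompt wording and the generate callable are illustrative assumptions, not part of this repo:

INSTRUCTION_PROMPT = """\
Below is a code example from the Flyte ecosystem.

{code}

Write a one-sentence instruction that this code answers."""

def make_sft_example(code: str, generate) -> dict:
    # `generate` is any callable that sends a prompt to an LLM and
    # returns its completion (e.g. an API client or a local model).
    instruction = generate(INSTRUCTION_PROMPT.format(code=code))
    return {"instruction": instruction, "response": code}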

Training

There are several possible training approaches to take:

  • Causal language modeling (CLM)
  • Masked language modeling (MLM)
  • Fill in the middle (FIM)

We'll start with the simplest case using CLM to get a baseline, then experiment with FIM since we may want Flyte Llama to be able to both complete code and suggest code given some suffix and prefix (see Resources section below).
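
For reference, FIM training rearranges each document with sentinel tokens so the model learns to infill. A minimal sketch of such a data transform is below; the sentinel strings and the 50% FIM rate are assumptions to verify against the tokenizer actually used:

import random

PRE, SUF, MID, EOT = "<PRE>", "<SUF>", "<MID>", "<EOT>"  # verify against the tokenizer

def fim_transform(text: str, fim_rate: float = 0.5) -> str:
    # With probability `fim_rate`, split the document into
    # prefix/middle/suffix and reorder it so the model learns to infill;
    # otherwise keep it as a plain CLM example.
    if len(text) < 2 or random.random() > fim_rate:
        return text
    lo, hi = sorted(random.sample(range(len(text)), 2))
    prefix, middle, suffix = text[:lo], text[lo:hi], text[hi:]
    return f"{PRE} {prefix} {SUF}{suffix} {MID}{middle}{EOT}"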

Data Augmentation

There are many data augmentation techniques we can leverage on top of the training approaches mentioned above:

  • Add metadata to the context: This can include adding the repo name, file name, and file extension to the beginning of each training example to condition token completion on the context of the code (see the sketch below).
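
A minimal sketch of this augmentation; the field names are illustrative, not the repo's actual schema:

def add_metadata_header(example: dict) -> str:
    # Prepend repo/file metadata so completions are conditioned on
    # where the code lives.
    path = example["path"]
    header = (
        f"# repo: {example['repo']}\n"
        f"# file: {path}\n"
        f"# ext: {path.rsplit('.', 1)[-1]}\n"
    )
    return header + example["text"]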

Evaluation

We'll use perplexity as a baseline metric for evaluating the model. This will capture how well the fine-tuned model fits the data.
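
A minimal sketch of this metric, assuming a Hugging Face-style causal LM that returns a mean cross-entropy loss (the one-token label shift is ignored for simplicity):

import math
import torch

@torch.no_grad()
def perplexity(model, batches) -> float:
    # `batches` yields dicts with input_ids, attention_mask, and labels
    # (padding labeled -100). Perplexity = exp(mean token-level NLL).
    nll, n_tokens = 0.0, 0
    for batch in batches:
        out = model(**batch)  # out.loss is the mean NLL over valid tokens
        n = int(batch["labels"].ne(-100).sum())
        nll += float(out.loss) * n
        n_tokens += n
    return math.exp(nll / n_tokens)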

It may be useful to keep hold-out data for evaluating the model's ability to generalize by excluding data from certain repos. For example, we can pretrain the model on pure Flyte source code and test it on example and documentation repos, giving a train-test split as follows:

  • Training set: flyte, flytekit, flytepropeller, flyteplugins, flyteidl, flyteadmin, flyteconsole
  • Test set: flytesnacks, flyte-conference-talks

Though there may be some data leakage, the code in the example repos should for the most part be different enough from the core source code that the model has to figure out how to use the basic building blocks in the source code to generate the examples (roughly what a human does when writing code examples).
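
A minimal sketch of such a by-repo split, assuming the dataset directory has one top-level folder per repo (the layout is an assumption; adapt to how dataset.py actually writes files):

from pathlib import Path

TRAIN_REPOS = {"flyte", "flytekit", "flytepropeller", "flyteplugins",
               "flyteidl", "flyteadmin", "flyteconsole"}
TEST_REPOS = {"flytesnacks", "flyte-conference-talks"}

def split_by_repo(dataset_root: str):
    # Route each file to train or test based on its top-level repo folder.
    root = Path(dataset_root)
    train, test = [], []
    for path in root.rglob("*"):
        if not path.is_file():
            continue
        repo = path.relative_to(root).parts[0]
        if repo in TRAIN_REPOS:
            train.append(path)
        elif repo in TEST_REPOS:
            test.append(path)
    return train, test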

Resources

🔧 Resource Tuning

Local

Run:

pyflyte run flyte_llama/workflows.py tune_batch_size \
    --config config/local.json \
    --batch_sizes '[2, 4]'

Flyte Llama 7b QLoRA

Run:

pyflyte run flyte_llama/workflows.py tune_batch_size \
    --config config/flyte_llama_7b_qlora_v0.json \
    --batch_sizes '[4, 8, 16, 32]'