Flyte Llama is a fine-tuned model based on Code Llama.
python -m venv ~/venvs/flyte-llama
source ~/venvs/flyte-llama/bin/activate
pip install -r requirements.txt
export PYTHONPATH=$(pwd):$PYTHONPATH
export FLYTECTL_CONFIG=~/.uctl/config-demo.yaml # replace with your flyte/union cloud config
export REGISTRY=ghcr.io/unionai-oss # replace this with your own registry
export FLYTE_PROJECT=llm-fine-tuning
python flyte_llama/dataset.py --output-path ~/datasets/flyte_llama
Local
pyflyte run flyte_llama/workflows.py train \
--dataset ~/datasets/flyte_llama \
--config config/local.json
Flyte Llama 7B QLoRA
Train:
pyflyte -c $FLYTECTL_CONFIG run --remote \
--copy-all \
--project $FLYTE_PROJECT \
flyte_llama/workflows.py train_workflow \
--config config/flyte_llama_7b_qlora_v0.json
Publish:
pyflyte run --remote \
--copy-all \
--project $FLYTE_PROJECT \
flyte_llama/workflows.py publish_model \
--config config/flyte_llama_7b_qlora_v0.json \
--model_dir s3://path/to/model
Flyte Llama 7B QLoRA from a previous adapter checkpoint
Pass in the --pretrained_adapter flag to continue training from a previous adapter checkpoint. This is typically an s3 path produced by train_workflow.
pyflyte run --remote \
--copy-all \
--project $FLYTE_PROJECT \
flyte_llama/workflows.py train_workflow \
--config config/flyte_llama_7b_qlora_v0.json \
--pretrained_adapter s3://path/to/checkpoint
Flyte Llama 7B Instruct QLoRA
Train:
pyflyte run --remote \
--copy-all \
--project $FLYTE_PROJECT \
--image $IMAGE \
flyte_llama/workflows.py train_workflow \
--config config/flyte_llama_7b_instruct_qlora_v0.json
Publish:
pyflyte run --remote \
--copy-all \
--project $FLYTE_PROJECT \
--image $IMAGE \
flyte_llama/workflows.py publish_model \
--config config/flyte_llama_7b_instruct_qlora_v0.json \
--model_dir s3://path/to/model
Flyte Llama 13B QLoRA
pyflyte run --remote \
--copy-all \
--project $FLYTE_PROJECT \
--image $IMAGE \
flyte_llama/workflows.py train_workflow \
--config config/flyte_llama_13b_qlora_v0.json
Flyte Llama 34B QLoRA
pyflyte run --remote \
--copy-all \
--project $FLYTE_PROJECT \
--image $IMAGE \
flyte_llama/workflows.py train_workflow \
--config config/flyte_llama_34b_qlora_v0.json
This project uses ModelZ as the serving layer.
Create a secrets.txt file to hold your sensitive credentials:
# do this once
echo MODELZ_USER_ID="<replace>" >> secrets.txt
echo MODELZ_API_KEY="<replace>" >> secrets.txt
echo HF_AUTH_TOKEN="<replace>" >> secrets.txt
Serving a POST Endpoint
Export env vars:
eval $(sed 's/^/export /g' secrets.txt)
export VERSION=$(git rev-parse --short=7 HEAD)
export SERVING_IMAGE=ghcr.io/unionai-oss/modelz-flyte-llama-serving:$VERSION
Build the serving image:
docker build . -f Dockerfile.server -t $SERVING_IMAGE
docker push $SERVING_IMAGE
Deploy:
python deploy.py \
--deployment-name flyte-llama-$VERSION \
--image $SERVING_IMAGE \
--server-resource "nvidia-ada-l4-2-24c-96g"
Get the deployment_key from the output of the command above and use it to test the model:
python client.py \
--prompt "The code snippet below shows a basic Flyte workflow" \
--output-file output.txt \
--deployment-key <deployment_key>
Serving a Server-Sent Events (SSE) Endpoint
Export env vars:
eval $(sed 's/^/export /g' secrets.txt)
export VERSION=$(git rev-parse --short=7 HEAD)
export SERVING_SSE_IMAGE=ghcr.io/unionai-oss/modelz-flyte-llama-serving-sse:$VERSION
Build the serving image:
docker build . -f Dockerfile.server_sse -t $SERVING_SSE_IMAGE
docker push $SERVING_SSE_IMAGE
Deploy:
python deploy.py \
--deployment-name flyte-llama-sse-$VERSION \
--image $SERVING_SSE_IMAGE \
--server-resource "nvidia-ada-l4-4-48c-192g" \
--stream
Get the deployment_key from the output of the command above and use it to test the model:
python client_sse.py \
--prompt "The code snippet below shows a basic Flyte workflow" \
--n-tokens 250 \
--output-file output.txt \
--deployment-key <deployment_key>
This system will be based on all of the Flyte codebases:
- flyte: Flyte's main repo
- flytekit: Python SDK
- flytepropeller: Kubernetes-native operator for Flyte
- flyteplugins: Backend Flyte plugins
- flyteidl: Flyte language specification in protobuf
- flyteadmin: Flyte's control plane
- flyteconsole: UI console
- flytesnacks: Example repo
- flyte-conference-talks: Repo of conference talks
The dataset will consist of source files, tests, and documentation from all of these repositories.
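As a sketch of how such a dataset might be collected, the snippet below walks a checked-out repo and keeps only files whose extensions look like source, tests, or docs. The extension allow-list is a hypothetical example; the actual filtering in dataset.py may differ.

```python
from pathlib import Path

# Hypothetical allow-list for illustration; the real dataset.py may use a
# different set of extensions or additional filters (e.g. excluding vendored code).
INCLUDE_SUFFIXES = {".py", ".go", ".md", ".rst", ".proto"}

def iter_dataset_files(repo_root: Path):
    """Yield source, test, and documentation files from a checked-out repo."""
    for path in sorted(repo_root.rglob("*")):
        if path.is_file() and path.suffix in INCLUDE_SUFFIXES:
            yield path
```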
This dataset could be enriched with open source repos that use Flyte in their codebase, including repos maintained by the Flyte core team as well as those maintained by the community. This would further train the model on how the community uses flytekit and configures their codebases in the wild.
We can build a supervised finetuning dataset using an LLM to generate a synthetic
instruction given a piece of Flyte code. For example, given a flytesnacks
example,
an LLM can be prompted to create an instruction associated with that example. Or,
given a flytekit plugin, an LLM can be prompted to create an instruction associated
with creating the flytekit plugin class that implements the plugin interface.
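A minimal sketch of the prompt-construction side of this idea is shown below. The prompt wording is purely illustrative, not the actual prompt used for Flyte Llama, and the call to an LLM API is left out.

```python
def instruction_prompt(code_snippet: str) -> str:
    """Build a prompt asking an LLM to invent an instruction for a code example.

    The wording here is an illustrative assumption; the actual prompt used to
    build a supervised finetuning dataset would likely be more elaborate.
    """
    return (
        "Below is a Flyte code example. Write a single, self-contained "
        "instruction that this example would satisfy.\n\n"
        f"```python\n{code_snippet}\n```\n\nInstruction:"
    )
```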
There are several possible training approaches to take:
- Causal language modeling (CLM)
- Masked language modeling (MLM)
- Fill in the middle (FIM)
We'll start with the simplest case using CLM to get a baseline, then experiment with FIM since we may want Flyte Llama to be able to both complete code and suggest code given some suffix and prefix (see Resources section below).
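To make the FIM idea concrete, here is a sketch of turning a document into a prefix/suffix/middle training example. The `<PRE>`, `<SUF>`, and `<MID>` sentinel strings are modeled on Code Llama's infilling format, but the exact special tokens depend on the tokenizer in use.

```python
import random

def make_fim_example(document: str, rng: random.Random) -> str:
    """Split a document at two random points and emit a PSM-ordered FIM example.

    The sentinel strings are illustrative stand-ins for tokenizer special tokens.
    """
    # Pick two distinct cut points, then reorder as prefix-suffix-middle so the
    # model learns to fill in the middle given surrounding context.
    i, j = sorted(rng.sample(range(len(document) + 1), 2))
    prefix, middle, suffix = document[:i], document[i:j], document[j:]
    return f"<PRE>{prefix}<SUF>{suffix}<MID>{middle}"
```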
There are many data augmentation techniques we can leverage on top of the training approaches mentioned above:
- Add metadata to the context: this can include adding the repo name, file name, and file extension to the beginning of each training example to condition the token completion on the context of the code.
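The metadata augmentation above might look like the following sketch; the header format is an assumption for illustration, not the format used in the actual training pipeline.

```python
def add_metadata_header(text: str, repo: str, path: str) -> str:
    """Prepend repo/file metadata so token completion is conditioned on context.

    The comment-style header format here is a hypothetical choice.
    """
    ext = path.rsplit(".", 1)[-1] if "." in path else ""
    return f"# repo: {repo}\n# file: {path}\n# ext: {ext}\n{text}"
```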
We'll use perplexity as a baseline metric for evaluating the model. This will capture how well the fine-tuned model fits the data.
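Concretely, perplexity is the exponentiated mean negative log-likelihood per token, as in this small sketch:

```python
import math

def perplexity(token_nlls: list[float]) -> float:
    """Perplexity = exp(mean per-token negative log-likelihood).

    A model that assigns probability 1 to every token has perplexity 1;
    higher values mean a worse fit to the data.
    """
    return math.exp(sum(token_nlls) / len(token_nlls))
```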
It may be useful to keep hold-out data for evaluating the model's ability to generalize by excluding data from certain repos. For example, we can pretrain the model on the core Flyte source code and test it on the example and documentation repos, giving a train-test split as follows:
- Training set: flyte, flytekit, flytepropeller, flyteplugins, flyteidl, flyteadmin, flyteconsole
- Test set: flytesnacks, flyte-conference-talks
Though there may be some data leakage, for the most part the code in the example repos should be different enough from the core source code that the model has to figure out how to compose the basic building blocks from the source code into working examples (roughly what a human does when writing code examples).
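A repo-level split like the one above can be sketched as a simple partition over (repo, text) pairs; the set names below mirror the lists in this section.

```python
TRAIN_REPOS = {"flyte", "flytekit", "flytepropeller", "flyteplugins",
               "flyteidl", "flyteadmin", "flyteconsole"}
TEST_REPOS = {"flytesnacks", "flyte-conference-talks"}

def split_by_repo(examples):
    """Partition (repo, text) pairs into train/test sets by repo name."""
    train = [e for e in examples if e[0] in TRAIN_REPOS]
    test = [e for e in examples if e[0] in TEST_REPOS]
    return train, test
```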
Local
Run:
pyflyte run flyte_llama/workflows.py tune_batch_size \
--config config/local.json \
--batch_sizes '[2, 4]'
Flyte Llama 7B QLoRA
Run:
pyflyte run flyte_llama/workflows.py tune_batch_size \
--config config/flyte_llama_7b_qlora_v0.json \
--batch_sizes '[4, 8, 16, 32]'