[air] Dreambooth finetuning workspace template (ray-project#37851)
Signed-off-by: Justin Yu <[email protected]>
Signed-off-by: NripeshN <[email protected]>
Showing 26 changed files with 370 additions and 118 deletions.
@@ -0,0 +1,77 @@
# DreamBooth fine-tuning of Stable Diffusion with Ray Train

| Template Specification | Description |
| ---------------------- | ----------- |
| Summary | This example shows how to do [DreamBooth fine-tuning](https://dreambooth.github.io/) of a Stable Diffusion model using Ray Train for data-parallel training with many workers and Ray Data for data ingestion. Use one of the provided datasets, or supply your own photos. By the end of this example, you'll be able to generate images of your subject in a variety of situations, just by feeding in a text prompt! |
| Time to Run | ~10-15 minutes to generate a regularization dataset and fine-tune the model on photos of your subject. |
| Minimum Compute Requirements | At least 2 GPUs, where each GPU has >= 24GB GRAM. The default is 1 node with A10G GPUs (AWS) or A100 40GB GPUs (GCE). |
| Cluster Environment | This template uses a docker image built on top of the latest Anyscale-provided Ray image using Python 3.9: [`anyscale/ray:latest-py39-cu118`](https://docs.anyscale.com/reference/base-images/overview). See the appendix below for more details. |

![Dreambooth fine-tuning sample results](dreambooth/images/dreambooth_example.png)
## Run the example

This README contains only minimal instructions for running this example on Anyscale.
See [the guide in the Ray documentation](https://docs.ray.io/en/latest/ray-air/examples/dreambooth_finetuning.html)
for a step-by-step walkthrough of the training code.

You can get started fine-tuning on a sample dog dataset with default settings with the following commands:

```bash
chmod +x ./dreambooth_run.sh
./dreambooth_run.sh
```
## Customizing the example

Here are a few modifications to the `dreambooth_run.sh` script that you may want to make:

1. The image dataset of your subject. This example provides two sample datasets, but you can also supply your own directory of 4-5 images, as well as the general class your subject falls under. For example, the dog dataset contains images of one particular puppy, and the general class this subject falls under is `dog`.
   - Modify the `$CLASS_NAME` and `$INSTANCE_DIR` environment variables.
2. The `$DATA_PREFIX` that the pre-trained model is downloaded to. This directory is also where the training dataset and the fine-tuned model checkpoint are written at the end of training.
   - If you add more worker nodes to the cluster, you should set `$DATA_PREFIX` to a shared NFS filesystem such as `/mnt/cluster_storage`. See [this page of the docs](https://docs.anyscale.com/develop/workspaces/storage#storage-shared-across-nodes) for all the options.
   - Note that each run of the script overwrites the fine-tuned model checkpoint from the previous run, so consider changing the `$DATA_PREFIX` environment variable on each run if you don't want to lose the models/data of previous runs.
3. The `$NUM_WORKERS` variable sets the number of data-parallel workers used during fine-tuning. The default is 2 workers (each using 2 GPUs); increase this number if you add more GPU worker nodes to the cluster.
4. Setting `--num_epochs` and `--max_train_steps` determines the number of fine-tuning steps to take.
   - Depending on the batch size and number of data-parallel workers, one epoch runs for a certain number of steps. The run terminates as soon as either limit (number of epochs or total number of steps) is reached.
5. `generate.py` is used to generate Stable Diffusion images after loading the model from a checkpoint. Modify the prompt at the end to be something more interesting, rather than just a photo of your subject.
6. If you want to launch another fine-tuning run, you may want to run *only* the `python train.py ...` command. Running the full bash script starts from the beginning (generating another regularization dataset).
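To make the epoch-versus-step termination concrete, here is a sketch in plain shell arithmetic. The numbers are hypothetical, chosen only to mirror the defaults in `dreambooth_run.sh`; the real steps-per-epoch depends on your dataset size and is computed inside `train.py`, not here:

```bash
# Hypothetical numbers for illustration only -- the real values depend on
# your dataset size, --train_batch_size, and $NUM_WORKERS.
dataset_size=200
train_batch_size=2
num_workers=2
num_epochs=10
max_train_steps=400

# Steps per epoch with data-parallel workers (ceiling division over the
# global batch size).
global_batch=$((train_batch_size * num_workers))
steps_per_epoch=$(( (dataset_size + global_batch - 1) / global_batch ))

# Training stops at whichever limit is hit first.
epoch_limited=$((steps_per_epoch * num_epochs))
total_steps=$(( epoch_limited < max_train_steps ? epoch_limited : max_train_steps ))
echo "$total_steps"   # 400: the 500 epoch-limited steps exceed max_train_steps
```

With these numbers, `--max_train_steps` is the binding limit; shrink the dataset or `--num_epochs` and the epoch limit takes over instead.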
## Interact with the fine-tuned model

### Generate images with a script

Use the `generate.py` script to generate images with a prompt.
Replace the variables with the values that you used in the fine-tuning script.
See `run_model_flags` in `flags.py` for a full list of available command line arguments to pass to the script.

```bash
python generate.py \
  --model_dir=$TUNED_MODEL_DIR \
  --output_dir=$IMAGES_NEW_DIR \
  --prompts="photo of a $UNIQUE_TOKEN $CLASS_NAME" \
  --num_samples_per_prompt=5
```
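For example, if you ran `dreambooth_run.sh` with its defaults (dog dataset, `DATA_PREFIX=/tmp`), the variables above resolve to:

```bash
# Values matching the defaults in dreambooth_run.sh (dog dataset, DATA_PREFIX=/tmp).
export DATA_PREFIX="/tmp"
export TUNED_MODEL_DIR="$DATA_PREFIX/model-tuned"
export IMAGES_NEW_DIR="$DATA_PREFIX/images-new"
export UNIQUE_TOKEN="unqtkn"
export CLASS_NAME="dog"

# The prompt string that generate.py receives with these values.
echo "photo of a $UNIQUE_TOKEN $CLASS_NAME"   # photo of a unqtkn dog
```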
### Generate images interactively in a notebook

See the `playground.ipynb` notebook for a more interactive way to generate images with the fine-tuned model.
Click on the Jupyter or VSCode icon on the workspace page and open the notebook.

## Appendix

### Advanced: Build off of this template's cluster environment

#### Option 1: Build a new cluster environment on Anyscale

The requirements are listed in `dreambooth/requirements.txt`. Feel free to modify this to include more requirements, then follow [this guide](https://docs.anyscale.com/configure/dependency-management/cluster-environments#creating-a-cluster-environment) to use the `anyscale` CLI to create a new cluster environment. The requirements should be pasted into the cluster environment yaml.

Finally, update your workspace's cluster environment to this new one after it's done building.

#### Option 2: Build a new docker image with your own infrastructure
Use the following `docker pull` command if you want to manually build a new Docker image based on this one.

```bash
docker pull us-docker.pkg.dev/anyscale-workspace-templates/workspace-templates/dreambooth-finetuning:latest
```
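From there, a minimal Dockerfile sketch for layering your own dependencies on top of the pulled image (`my-requirements.txt` is a placeholder name for your own requirements file, not something this template provides):

```dockerfile
# Build on top of the template's image pulled above.
FROM us-docker.pkg.dev/anyscale-workspace-templates/workspace-templates/dreambooth-finetuning:latest

# Hypothetical file listing your additional Python dependencies.
COPY my-requirements.txt ./
RUN pip install --no-cache-dir -U -r my-requirements.txt
```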
9 changes: 9 additions & 0 deletions
doc/source/templates/05_dreambooth_finetuning/configs/Dockerfile
@@ -0,0 +1,9 @@
# Run `docker build` with this from the 05_dreambooth_finetuning directory
FROM anyscale/ray:latest-py39-cu118

COPY dreambooth/requirements.txt ./

RUN pip install --no-cache-dir -U -r requirements.txt

RUN echo "Testing Ray Import..." && python -c "import ray"
RUN ray --version
5 changes: 5 additions & 0 deletions
doc/source/templates/05_dreambooth_finetuning/configs/compute/aws.yaml
@@ -0,0 +1,5 @@
head_node_type:
  name: head_node_type
  instance_type: g5.12xlarge

max_workers: 0
7 changes: 7 additions & 0 deletions
doc/source/templates/05_dreambooth_finetuning/configs/compute/gce.yaml
@@ -0,0 +1,7 @@
# 4 A100 GPUs with 40gb GRAM each
# This is a bit overkill, but instances with L4/A10G GPUs are not yet available on GCE
head_node_type:
  name: head_node_type
  instance_type: a2-highgpu-2g-nvidia-a100-40gb-4

max_workers: 0
12 files renamed without changes.
97 changes: 97 additions & 0 deletions
doc/source/templates/05_dreambooth_finetuning/dreambooth_run.sh
@@ -0,0 +1,97 @@
#!/bin/bash
# shellcheck disable=SC2086

set -xe

# Step 0
pushd dreambooth || true

# Step 0 cont
# TODO: If running on multiple nodes, change this path to a shared directory (ex: NFS)
export DATA_PREFIX="/tmp"
export ORIG_MODEL_NAME="CompVis/stable-diffusion-v1-4"
export ORIG_MODEL_HASH="b95be7d6f134c3a9e62ee616f310733567f069ce"
export ORIG_MODEL_DIR="$DATA_PREFIX/model-orig"
export ORIG_MODEL_PATH="$ORIG_MODEL_DIR/models--${ORIG_MODEL_NAME/\//--}/snapshots/$ORIG_MODEL_HASH"
export TUNED_MODEL_DIR="$DATA_PREFIX/model-tuned"
export IMAGES_REG_DIR="$DATA_PREFIX/images-reg"
export IMAGES_OWN_DIR="$DATA_PREFIX/images-own"
export IMAGES_NEW_DIR="$DATA_PREFIX/images-new"
# TODO: Add more worker nodes and increase NUM_WORKERS for more data-parallelism
export NUM_WORKERS=2

mkdir -p $ORIG_MODEL_DIR $TUNED_MODEL_DIR $IMAGES_REG_DIR $IMAGES_OWN_DIR $IMAGES_NEW_DIR

# Unique token to identify our subject (e.g., a random dog vs. our unqtkn dog)
export UNIQUE_TOKEN="unqtkn"

# Step 1
# Only uncomment one of the following:

# Option 1: Use the dog dataset ---------
export CLASS_NAME="dog"
python download_example_dataset.py ./images/dog
export INSTANCE_DIR=./images/dog
# ---------------------------------------

# Option 2: Use the lego car dataset ----
# export CLASS_NAME="car"
# export INSTANCE_DIR=./images/lego-car
# ---------------------------------------

# Option 3: Use your own images ---------
# export CLASS_NAME="<class-of-your-subject>"
# export INSTANCE_DIR="/path/to/images/of/subject"
# ---------------------------------------

# Copy own images into IMAGES_OWN_DIR
cp -rf $INSTANCE_DIR/* "$IMAGES_OWN_DIR/"

# Step 2
python cache_model.py --model_dir=$ORIG_MODEL_DIR --model_name=$ORIG_MODEL_NAME --revision=$ORIG_MODEL_HASH

# Clear reg dir
rm -rf "$IMAGES_REG_DIR"/*.jpg

# Step 3: START
python generate.py \
  --model_dir=$ORIG_MODEL_PATH \
  --output_dir=$IMAGES_REG_DIR \
  --prompts="photo of a $CLASS_NAME" \
  --num_samples_per_prompt=200 \
  --use_ray_data
# Step 3: END

# Step 4: START
python train.py \
  --model_dir=$ORIG_MODEL_PATH \
  --output_dir=$TUNED_MODEL_DIR \
  --instance_images_dir=$IMAGES_OWN_DIR \
  --instance_prompt="photo of $UNIQUE_TOKEN $CLASS_NAME" \
  --class_images_dir=$IMAGES_REG_DIR \
  --class_prompt="photo of a $CLASS_NAME" \
  --train_batch_size=2 \
  --lr=5e-6 \
  --num_epochs=10 \
  --max_train_steps=400 \
  --num_workers $NUM_WORKERS
# Step 4: END

# Clear new dir
rm -rf "$IMAGES_NEW_DIR"/*.jpg

# TODO: Change the prompt to something more interesting!
# Step 5: START
python generate.py \
  --model_dir=$TUNED_MODEL_DIR \
  --output_dir=$IMAGES_NEW_DIR \
  --prompts="photo of a $UNIQUE_TOKEN $CLASS_NAME" \
  --num_samples_per_prompt=5
# Step 5: END

# Save artifact
mkdir -p /tmp/artifacts
cp -f "$IMAGES_NEW_DIR"/0-*.jpg /tmp/artifacts/example_out.jpg

# Exit
popd || true
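One detail of the script worth calling out: the `ORIG_MODEL_PATH` line uses bash pattern substitution (`${ORIG_MODEL_NAME/\//--}` replaces the first `/` with `--`) to reconstruct the directory name the Hugging Face Hub cache uses for the snapshot. A standalone sketch of that expansion:

```bash
ORIG_MODEL_NAME="CompVis/stable-diffusion-v1-4"
ORIG_MODEL_HASH="b95be7d6f134c3a9e62ee616f310733567f069ce"
ORIG_MODEL_DIR="/tmp/model-orig"

# ${ORIG_MODEL_NAME/\//--} rewrites "CompVis/stable-diffusion-v1-4" to
# "CompVis--stable-diffusion-v1-4", matching the cache directory name.
ORIG_MODEL_PATH="$ORIG_MODEL_DIR/models--${ORIG_MODEL_NAME/\//--}/snapshots/$ORIG_MODEL_HASH"
echo "$ORIG_MODEL_PATH"
# /tmp/model-orig/models--CompVis--stable-diffusion-v1-4/snapshots/b95be7d6f134c3a9e62ee616f310733567f069ce
```

Note that this substitution is a bash feature, not POSIX `sh`, which is one reason the script starts with `#!/bin/bash`.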