forked from ray-project/ray
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Data] Add stable diffusion benchmark (ray-project#39524)
This PR adds a nightly test that benchmarks stable diffusion batch inference. --------- Signed-off-by: Balaji Veeramani <[email protected]>
- Loading branch information
1 parent
baa861d
commit 5dba924
Showing
6 changed files
with
169 additions
and
3 deletions.
There are no files selected for viewing
94 changes: 94 additions & 0 deletions
94
release/nightly_tests/dataset/stable_diffusion_benchmark.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import argparse | ||
import json | ||
import os | ||
from timeit import default_timer as timer | ||
from typing import Dict | ||
|
||
import numpy as np | ||
import torch | ||
from diffusers import StableDiffusionImg2ImgPipeline | ||
|
||
import ray | ||
|
||
# Parquet dataset of synthetic raw images used as the benchmark input.
# NOTE: the scheme must be "s3://" — the scraped text had a mangled
# "s3:https://" prefix, which is not a valid URI scheme.
DATA_URI = "s3://air-example-data-2/10G-image-data-synthetic-raw-parquet/"
# This isn't the largest batch size that fits in memory, but it achieves virtually 100%
# GPU utilization, and throughput declines at higher batch sizes.
BATCH_SIZE = 32
# Img2img prompt applied to every input image.
PROMPT = "ghibli style"
|
||
|
||
def parse_args():
    """Return the parsed command-line options for the benchmark script.

    Supports a single flag, ``--smoke-test``, which shrinks the run for
    quick validation.
    """
    arg_parser = argparse.ArgumentParser(description="Stable diffusion benchmark")
    arg_parser.add_argument("--smoke-test", action="store_true")
    parsed = arg_parser.parse_args()
    return parsed
|
||
|
||
def main(args):
    """Run the stable diffusion batch-inference benchmark and report results.

    Reads the benchmark dataset, runs img2img inference over it with a pool
    of GPU actors, counts the generated images, and writes throughput
    metrics as JSON to the path named by the ``TEST_OUTPUT_JSON``
    environment variable (default ``release_test_out.json``).

    Args:
        args: Parsed CLI namespace; only ``args.smoke_test`` is read.
    """
    ray.init()
    ray.data.DataContext.get_current().execution_options.verbose_progress = True

    benchmark_start = timer()

    ds = ray.data.read_parquet(DATA_URI)
    if args.smoke_test:
        # Keep smoke tests cheap: process a single row only.
        ds = ds.limit(1)

    # Size the actor pool to one inference actor per GPU in the cluster.
    gpus_in_cluster = int(ray.cluster_resources().get("GPU"))
    ds = ds.map_batches(
        GenerateImage,
        compute=ray.data.ActorPoolStrategy(size=gpus_in_cluster),
        batch_size=BATCH_SIZE,
        num_gpus=1,
    )

    # Iterating the dataset drives execution; tally the generated images.
    image_count = 0
    for output_batch in ds.iter_batches(batch_format="pyarrow", batch_size=None):
        image_count += len(output_batch)

    elapsed_s = timer() - benchmark_start

    # For structured output integration with internal tooling
    results = {
        "data_uri": DATA_URI,
        "perf_metrics": {
            "total_time_s": elapsed_s,
            "throughput_images_s": image_count / elapsed_s,
            "num_images": image_count,
        },
    }

    test_output_json = os.environ.get("TEST_OUTPUT_JSON", "release_test_out.json")
    with open(test_output_json, "wt") as f:
        json.dump(results, f)

    print(results)
|
||
|
||
class GenerateImage:
    """Batch-inference callable that stylizes images via img2img.

    Each instance loads the Ghibli-Diffusion pipeline once (per actor), and
    ``__call__`` maps a batch of input images to generated images using the
    module-level PROMPT.
    """

    def __init__(self):
        # Run on the GPU when one is visible to this process, else CPU.
        if torch.cuda.is_available():
            target_device = "cuda"
        else:
            target_device = "cpu"
        pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
            "nitrosocke/Ghibli-Diffusion",
            torch_dtype=torch.float16,
            use_safetensors=True,
            requires_safety_checker=False,
            safety_checker=None,
        )
        self.pipeline = pipeline.to(target_device)
        # Per-image progress bars are noise in a batch benchmark.
        self.pipeline.set_progress_bar_config(disable=True)

    def __call__(self, batch: Dict[str, np.ndarray]):
        prompts = [PROMPT for _ in range(len(batch["image"]))]
        generated = self.pipeline(
            prompt=prompts,
            image=batch["image"],
            output_type="np",
        )
        return {"image": generated.images}
|
||
|
||
if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run the benchmark.
    main(parse_args())
13 changes: 13 additions & 0 deletions
13
release/nightly_tests/dataset/stable_diffusion_benchmark_compute.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Anyscale compute config (AWS) for the stable diffusion benchmark.
cloud_id: {{env["ANYSCALE_CLOUD_ID"]}}
region: us-west-2

head_node_type:
  name: head_node
  instance_type: m5.4xlarge

worker_node_types:
  - name: worker_node
    instance_type: g4dn.4xlarge
    # Fixed-size pool (min == max, so no autoscaling) of 16 GPU workers,
    # on-demand rather than spot.
    max_workers: 16
    min_workers: 16
    use_spot: false
15 changes: 15 additions & 0 deletions
15
release/nightly_tests/dataset/stable_diffusion_benchmark_compute_gce.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Anyscale compute config (GCE) for the stable diffusion benchmark.
# Inline comments note the roughly equivalent AWS instance types.
cloud_id: {{env["ANYSCALE_CLOUD_ID"]}}
region: us-west1
allowed_azs:
  - us-west1-b

head_node_type:
  name: head_node
  instance_type: n2-standard-16 # m5.4xlarge

worker_node_types:
  - name: worker_node
    instance_type: n1-standard-16-nvidia-tesla-t4-1 # g4dn.4xlarge
    # Fixed-size pool (min == max, so no autoscaling) of 16 GPU workers,
    # on-demand rather than spot.
    min_workers: 16
    max_workers: 16
    use_spot: false
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ crc32c | |
cupy-cuda113 | ||
datasets | ||
deepspeed | ||
diffusers | ||
evaluate | ||
fastapi | ||
filelock | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters