[Data] Add benchmark for Ray Data + Trainer #37624
Signed-off-by: Scott Lee <[email protected]>
import time

import ray
import torch
import torchvision
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

DEFAULT_IMAGE_SIZE = 224
def iterate(dataset, label, metrics):
    start = time.time()
    it = iter(dataset)
    num_rows = 0
    for batch in it:
        num_rows += len(batch)
    end = time.time()
    # `i` is the epoch index from the driver loop in __main__ below.
    print(label, end - start, "epoch", i)

    tput = num_rows / (end - start)
    metrics[label] = tput
def get_transform(to_torch_tensor):
    # Review comment: ditto (same refactoring suggestion as at the end of this file).
    # Note(swang): This is a different order from tf.data.
    # torch: decode -> randCrop+resize -> randFlip
    # tf.data: decode -> randCrop -> randFlip -> resize
    # The conditional below is parenthesized so that the crop/flip transforms
    # are kept even when to_torch_tensor is False.
    transform = torchvision.transforms.Compose(
        [
            torchvision.transforms.RandomResizedCrop(
                size=DEFAULT_IMAGE_SIZE,
                scale=(0.05, 1.0),
                ratio=(0.75, 1.33),
            ),
            torchvision.transforms.RandomHorizontalFlip(),
        ]
        + ([torchvision.transforms.ToTensor()] if to_torch_tensor else [])
    )
    return transform
def crop_and_flip_image_batch(image_batch):
    # Review comment: ditto.
    transform = get_transform(False)
    batch_size, height, width, channels = image_batch["image"].shape
    tensor_shape = (batch_size, channels, height, width)
    image_batch["image"] = transform(
        torch.Tensor(image_batch["image"].reshape(tensor_shape))
    )
    return image_batch
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--data-root",
        default="s3://air-cuj-imagenet-1gb",
        type=str,
        help='Directory path with TFRecords. Filenames should start with "train".',
    )
    parser.add_argument(
        "--batch-size",
        default=32,
        type=int,
        help="Batch size to use.",
    )
    parser.add_argument(
        "--num-epochs",
        default=2,
        type=int,
        help="Number of epochs to run. The throughput for the last epoch will be kept.",
    )
    args = parser.parse_args()

    metrics = {}
    ray_dataset = ray.data.read_images(args.data_root).map_batches(
        crop_and_flip_image_batch
    )
    # Review thread on this loop:
    # - Reviewer: This is probably not useful. In this benchmark, I think we
    #   want to make the usage as close to the real training workloads as
    #   possible.
    # - Author: I think we included this to compare throughput between the
    #   data ingestion and the training phases. If we don't need benchmarking
    #   for this part, I can just remove it.
    # - Reviewer: I think what we currently measure in the training loop
    #   function is already data ingestion throughput (because we don't
    #   apply a real model).
    for i in range(args.num_epochs):
        iterate(
            ray_dataset.iter_torch_batches(batch_size=args.batch_size),
            "ray.data+transform",
            metrics,
        )
    def train_loop_per_worker():
        # Get an iterator to the dataset we passed in below.
        it = session.get_dataset_shard("train")

        # Train for 10 epochs over the data. We'll use a shuffle buffer size
        # of 10k elements, and prefetch up to 10 batches of size 128 each.
        for _ in range(10):
            for batch in it.iter_batches(
                local_shuffle_buffer_size=10000, batch_size=128, prefetch_batches=10
            ):
                pass

    start_t = time.time()
    torch_trainer = TorchTrainer(
        train_loop_per_worker,
        scaling_config=ScalingConfig(num_workers=2),
        datasets={"train": ray_dataset},
    )
    torch_trainer.fit()
    end_t = time.time()
    metrics["ray.torchtrainer.fit"] = end_t - start_t
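One subtlety in `get_transform` above is worth calling out: in Python, a conditional expression binds more loosely than list concatenation, so `a + b if cond else c` parses as `(a + b) if cond else c`, and an unparenthesized conditional would silently drop the crop/flip transforms whenever `to_torch_tensor` is False. A minimal standalone sketch (plain lists stand in for the torchvision transforms):

```python
# `a + b if cond else c` parses as `(a + b) if cond else c`, so an
# unparenthesized conditional can silently drop the base list.
base = ["crop", "flip"]  # stand-ins for the torchvision transforms

# Unparenthesized: when the condition is False, the whole result is [].
unparenthesized = base + ["to_tensor"] if False else []

# Parenthesized: the base transforms are always kept.
parenthesized = base + (["to_tensor"] if False else [])

print(unparenthesized)  # []
print(parenthesized)    # ['crop', 'flip']
```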
Review comment: let's reuse the same method in https://github.com/ray-project/ray/blob/master/release/nightly_tests/dataset/image_loader_microbenchmark.py#L17-L27; we can put the method into a util.py.
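The refactor suggested above could look something like the following sketch of a shared `util.py` (the file name, placement, and exact signature are assumptions, not the actual Ray source layout; the existing helper lives in `image_loader_microbenchmark.py`):

```python
# Hypothetical util.py: a shared iterate() helper that both this benchmark
# and image_loader_microbenchmark.py could import instead of duplicating.
# Names and placement are illustrative, not the actual Ray source layout.
import time


def iterate(dataset, label, metrics):
    """Consume `dataset` once and record its throughput (rows/s) in `metrics`."""
    start = time.perf_counter()
    num_rows = 0
    for batch in dataset:
        num_rows += len(batch)
    end = time.perf_counter()
    metrics[label] = num_rows / (end - start)
    print(label, "took", end - start, "seconds")


# Usage with a plain iterable of batches; a Ray dataset batch iterator
# would be consumed the same way.
metrics = {}
iterate([[0] * 32 for _ in range(10)], "dummy", metrics)
```

Both call sites would then do `from util import iterate`, keeping the throughput bookkeeping in one place.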