Distributed TPU Training, training data stored in GCS #2690
How does the XLA metrics report look? Does it differ much from our regular nightly ResNet50 runs in terms of ExecuteTime, CompileTime, etc.? @zcain117
I don't have any recent logs with metrics, so I will wrap up a few of the experiments I have running now and then try to kick off 2 runs on v3-32, one using a regular PD and one with SSD. Then we can compare metrics.
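For context, these reports come from torch_xla's debug metrics. A minimal sketch of printing one from inside a training script (the reference script's metrics_debug option does something equivalent; the exact call site here is just illustrative):

```python
import torch_xla.debug.metrics as met

# Dump counters/metrics such as ExecuteTime and CompileTime, e.g. at the end of
# an epoch, so slow data loading can be told apart from recompilation.
print(met.metrics_report())
```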
Attached the metrics for my v3-32 run (filtered to just one of the 4 VMs' stdout).
Some of the XrtMetrics:
And yours:
@JackCaoG what do you think? @tottenjordan have you compared your speed when using a PD or SSD PD instead of the custom GCS reader? And regarding the custom GCS reader, is there any way you could increase the prefetching amount?
GCS was ~5m30s per epoch and SSD PD was ~1m30s per epoch, so ~4 min * 14 epochs = ~56 min cumulative difference, but
Forgot to include my training command: |
@tottenjordan I think the next thing to try is increasing the prefetch amount. I think this change should be in the 1.7 version you're using: https://github.com/pytorch/pytorch/releases/tag/v1.7.0
What are the percentiles for ExecuteTime?
@zcain117, @shanemhansen ran training jobs with PD and will post the log/metrics report when available. Here is the first successful run with
imagenetr1aw_logfiles6-v2-32-256batch-8wrks-4prefetch.txt
@zcain117 @taylanbil Our data are stored as JPEGs. Is it worth exploring
Will run
This is interesting. Increased to
imagenetr1aw_logfiles8-v2-32-256batch-16wrks-8prefetch.txt
@zcain117 how do I get metrics_debug to print out in sequence by worker, instead of every other line potentially being from another worker?
I don't know of any way to order the outputs. I just filter to one VM's output at a time when reading, e.g. you see lines like this where the VM's IP is mentioned:
So I was using
Earlier you mentioned good results with
Switched to a v3-32 TPU pod. Increasing the
Simply moving from v2-32 to v3-32 showed a slight speed-up of ~20s.
On v3-32, increasing
So a couple more results I would like to see:
Agreed. I'll work on these and follow up with results. I set up an experiment to test different configs of (1) TPU Pod version, (2) batch_size, (3) prefetch_factor, and (4) num_workers, and I plan to do these for both GCS and PD/SSD PD. I've made progress on the GCS trials and will continue to increase until I get errors, then will try similar configs (where applicable) for PD.
@zcain117 Using SSD PD, average epoch training time = ~1:32
imagenetraw_logfiles11-SSDPD-v3-32-256batch-8workers-2prefetch.txt
TO DO
That speed is consistent with what I've gotten using v3-32 and SSD on imagenet. It seems like this is ~twice as fast as the best GCS run so far, so there's still room for improvement in the GCS version for the given dataset+model architecture. It seems like your speed at 256 batch size is about the same as my best speed, which was using 128 batch size. So if you run into a memory error when trying GCS with a higher prefetch_factor, maybe consider lowering batch size, since 256 vs 128 might not make much difference and might be a win if it allows more prefetch.
I've tested different combinations of the following
Haven't had impressive results with
Best avg. training epoch time = 2:46
logfiles-metrics-GCS-128bs-22work-18prefetch-v332pod.txt
I've yet to get a "memory error", but I am running into unexpected training metrics. I discussed with @zcain117 that this may be random behavior on the TPU side, but it occurred during nearly 20% of the ~30 training jobs I've run in the last two days.
Metrics Reports
What happens:
Questions
The loss is on the training data and the accuracy is on the eval data, I think. It could be that you're getting some kind of data caching bug in your custom GCS reading implementation, where the training loop is iterating over a small portion of the data, overfitting, and doing terribly on the eval data. The PyTorch dataloader tends to cause the OS to cache data (as explored in https://b.corp.google.com/issues/175324667) and the OS caching might be interacting in an unexpected way with your implementation. Or this could be independent of OS caching. You might look at the actual prediction vs. true label tensors to see if it's guessing poorly or if it's predicting NaN or something. You could also try printing the training data to see if you're iterating over all of it or just a portion. Maybe you could keep a count of how many times you've seen each input filename or how many times you've seen each class.
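A minimal sketch of that last suggestion (the loader and label format here are assumptions, not the actual training code): tally how often each class label comes out of the loader in one epoch; a healthy epoch over the full dataset should cover every class roughly evenly.

```python
from collections import Counter

class_counts = Counter()
for data, target in train_loader:          # assumed: loader yields (images, integer labels)
    class_counts.update(target.tolist())   # count how often each class appears this epoch

# If only a slice of the data is being iterated, many classes will be missing
# or a few will be wildly over-represented.
print("distinct classes seen:", len(class_counts))
print("most common:", class_counts.most_common(10))
```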
Just an FYI: I used a similar GCSDataset solution which streams jpegs from GCS one jpeg at a time (on multiple processes of course, specified by the DataLoader num_workers). It's a bit slow because of the overhead due to lots of small files, and you also indeed need lots of workers => more CPUs.
Then I discovered webdataset: https://github.com/tmbdev/webdataset
You can pretty easily saturate the GCS download speed with webdataset (tested against gsutil cp from the bucket). Note that you could get IO bottlenecked even with an SSD. Webdataset works great and you don't even need many workers. I'm using 4 workers per CPU tops, but even 1 is OK.
One problem is that webdataset inherits from the PyTorch IterableDataset, and I've been having some issues with that in torch-xla... I can post an issue once I've pinpointed it further.
EDIT: I mean the torch-xla fork of torch.multiprocessing can be a bit complicated and is the probable issue here. Actually webdataset works just fine!
Thanks for the recommendation @harpone! It looks like a good option to pursue... I'm still at around ~2:50 for avg epoch time and trying to shave this down to something at least comparable to SSD PD (~1:30). For your implementation, do you have any code you could share?
Unfortunately can't share code explicitly at the moment... but I'm following the webdataset tutorial notebooks pretty closely, something like
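A minimal sketch of that kind of pipeline (not the actual code, which isn't shared above; the bucket path, shard naming, and transforms are placeholder assumptions):

```python
import webdataset as wds
from torchvision import transforms

preproc = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
])

# Stream .tar shards straight from GCS; the shard URL pattern is a placeholder.
urls = "pipe:gsutil cat gs://my-bucket/imagenet-train-{000000..000146}.tar"

dataset = (
    wds.WebDataset(urls)
    .shuffle(1000)                 # sample-level shuffle buffer
    .decode("pil")                 # decode jpeg bytes into PIL images
    .to_tuple("jpg;png", "cls")    # yield (image, label) pairs
    .map_tuple(preproc, lambda y: y)
)
```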
and then a PyTorch DataLoader with
An update: I actually failed to get this working properly with torch-xla because the torch-xla dataloader fork requires a length for the dataset, which is a bit problematic with
But if you decide to give webdataset a try and run into issues, please @ me and I can try to help!
@harpone @tottenjordan I was also curious if either of you had tried gcsfs, maybe in conjunction with CachedDataset |
@zcain117 the |
@zcain117 yeah I tested gcsfs earlier and indeed it suffers from the same problems due to overhead when having to access lots of small files (e.g. jpegs). Webdataset solves this problem by simply archiving all the data into "shards", which are just .tar archives containing ~10k image/target pairs, and streaming these tar files from (e.g.) GCS. @tmbdev really explains it best in his intro videos: https://www.youtube.com/watch?v=kNuA2wflygM
Update: actually I got webdataset working with torch-xla by setting an explicit length... but it does seem to mess up the training, possibly getting identical minibatches per TPU core or something. I think I could work out a minimal torch-xla + webdataset example and publish a gist and an issue. I'll post a link here too when it's done.
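For context, a minimal sketch of how such shards are typically produced with webdataset's ShardWriter (the paths, sample source, and keys below are placeholders, not part of the thread):

```python
import webdataset as wds

# `samples` is a placeholder iterable of (raw JPEG bytes, integer label) pairs.
# Each output .tar holds up to 10k samples and can then be copied to GCS.
with wds.ShardWriter("imagenet-train-%06d.tar", maxcount=10000) as sink:
    for idx, (jpeg_bytes, label) in enumerate(samples):
        sink.write({
            "__key__": f"sample{idx:08d}",  # unique key shared by all files of this sample
            "jpg": jpeg_bytes,              # image payload, stored as <key>.jpg in the tar
            "cls": label,                   # label, stored as <key>.cls
        })
```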
OK I have a webdataset working example here using torch-xla's pl.MpDeviceLoader. No extra dependencies and uses the NVIDIA hosted OpenImages as in the webdataset examples. It actually works fine (I was suspecting there were duplicate minibatches or something). Setting the dataset length works fine too. @tottenjordan this should work as a minimal example how to implement your dataset as a webdataset. |
@harpone WebDataset yields each training sample from each shard exactly once during each epoch, and it uses each shard exactly once during each epoch. For multinode training, you need to add a "nodesplitter=" argument to the WebDataset constructor to determine how datasets are split across nodes. The default is for each node to train on the entire dataset. Generally, something like "nodesplitter=lambda l: l[node_index::num_nodes]" is a reasonable choice. If you have suggestions for better defaults or diagnostic messages, please let me know. |
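A minimal sketch of that suggestion in a torch-xla setting (shard URLs are placeholders, and the exact constructor keyword depends on the installed webdataset version; here each XLA process is treated as a "node"):

```python
import webdataset as wds
import torch_xla.core.xla_model as xm

urls = "pipe:gsutil cat gs://my-bucket/imagenet-train-{000000..000146}.tar"

# Give every process (ordinal) a disjoint, strided subset of the shard list so
# no two processes train on the same shards during an epoch.
node_index, num_nodes = xm.get_ordinal(), xm.xrt_world_size()
dataset = wds.WebDataset(
    urls,
    nodesplitter=lambda shards: shards[node_index::num_nodes],
)
```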
Ah, the dev version seems to do things a bit differently... need to try that out. I was using the stable version (can't remember the version number, but the PyPI one) and that seems to require setting the
I was checking whether the different minibatches were in fact unique or not, and if I switch off either one of shard_selection or shard_shuffle, I get duplicates. This is a very subtle bug which leads to convergence issues and should probably be tracked somehow... A couple of things to note that may have an effect:
Thank you @harpone. After exploring petastorm, I've started to focus on webdataset. I got simpler examples to work, but now I'm trying to apply it to the distributed training setting. The webdataset API version I am using has a slightly different approach for
I posted this issue here, and the error suggests I'm not properly configuring the
Error message:
I have a working implementation of Webdataset which is achieving epoch times of ~40s. Loss is dropping as expected and validation accuracy seems to be on par with other configurations. The data is now stored in POSIX tar files on GCS, and I'm using webdataset to retrieve each shard, shuffle shards and samples, and deliver them to the PyTorch DataLoader as usual. Essentially replacing
Because this is now an IterableDataset (not map-style), I needed to adjust the training and validation loops to look something like this:
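Roughly (a sketch, not the attached code; it assumes model, optimizer, loss_fn, device, and flags are set up as in the reference test_train_mp_imagenet.py script):

```python
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

# train_loader is a plain DataLoader wrapped around the webdataset IterableDataset.
train_device_loader = pl.MpDeviceLoader(train_loader, device)

def train_loop_fn(loader, epoch):
    model.train()
    # Just iterate; there is no len() or index-based access on an IterableDataset.
    for step, (data, target) in enumerate(loader):
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)   # all-reduce grads and step across TPU cores
        if step % flags.log_steps == 0:
            xm.master_print(f"epoch {epoch} step {step} loss {loss.item():.4f}")

for epoch in range(1, flags.num_epochs + 1):
    train_loop_fn(train_device_loader, epoch)
```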
However, once training reaches the last of my specified epochs, I receive a BrokenPipeError like this:
Reading the Multiprocessing Shutdown Logic in the docs here, it seems like I may not be exiting the iterator gracefully, or once it's depleted. Any ideas I should pursue? Full code script and metrics report are attached: torchXLA-webdataset-trial2-metrics-debug-BrokenPipe.txt
3 things:
I'm updating the distributed training examples; I hope I can push those out this week. Please also see my comments on webdataset/webdataset#47 |
Does the |
Oops, forgot to reply yesterday... anyway, it seems you got the IterableDataset issue fixed. I never got exactly the same warnings as you, but got similar ones when I was explicitly setting the length kwarg in wds. I think I got a similar
When trying to run 90 epochs, something is accumulating memory and by epoch ~16 I get the following:
This might be related to my use of
in the training and test loops:
@zcain117 would anything in the metrics report help identify this problem? |
I don't see that error in the logs you gave. Is that TPU OOM or an OOM on the python/VM side? |
Damn you're right. When training restarted, I must not have captured both reports. I'll try to recreate |
Not sure why the memory errors are not getting copied in my
Also attached is the training code for this run. Notice I didn't use the
Here is a screenshot of the memory allocation during training. Something is being accumulated during training; I suspect it's training data from the loaders, but I'm not sure if anything from the
...
...
Looks like your VM is running out of memory and it seems unrelated to TPUs, so I don't think anything in this TPU metrics report will help much.
I guess the data is not being released after the epoch ends and the new data is loaded.
In your code, you are using MpDeviceLoader, which is not directly a PyTorch DataLoader. It is a class that maintains its own queues and loads data for every call to
I am not sure what is happening, but maybe the WebDataset itself is not releasing memory, or maybe the MpDeviceLoader is creating references that prevent old data from being cleaned up. Some things to test:
Oh yeah damn, I had a similar issue... you could track the number of running python processes to see if they're increasing at the beginning of each epoch. Definitely happened also with GPUs. I don't remember how I resolved that :( |
Hi @tmbdev, can you take a look at the memory errors cited above? I was thinking that perhaps the data loader was accumulating samples over training, but I implemented the following to check the loader size during each training epoch:
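Roughly (the exact snippet isn't preserved above; this is a sketch with assumed names), counting the samples the loader actually yields each epoch, to see whether that number or the process memory grows from epoch to epoch:

```python
samples_seen = 0
for step, (data, target) in enumerate(train_device_loader):
    samples_seen += data.size(0)   # batch size of this step
print(f"epoch {epoch}: loader yielded {samples_seen} samples")
```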
Any ideas what this could be? |
I noticed that switching to |
We have been using WebDataset in workers for a long time and not seen any memory leaks. WebDataset doesn't operate any differently whether it runs standalone, inside a worker, or on another node. If you suspect a memory leak in the worker process, you can simply run it standalone and see what happens to the memory:
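A minimal sketch of such a standalone loop (shard URLs are placeholders), iterating the dataset directly with no DataLoader workers while watching resident memory:

```python
import resource
import webdataset as wds

dataset = wds.WebDataset("pipe:gsutil cat gs://my-bucket/imagenet-train-{000000..000146}.tar")

for i, sample in enumerate(dataset):
    if i % 10000 == 0:
        # ru_maxrss is reported in KiB on Linux
        rss_mib = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024
        print(f"samples={i} peak_rss={rss_mib:.0f} MiB")
```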
If that doesn't run out of memory, then WebDataset isn't leaking any memory. Running such a standalone loop should be pretty quick (since you're not doing any processing). (Note that process memory usage can grow and stay high even if there is no memory leak, due to the way the Python garbage collector works, but growth should usually stop before running out of memory.)
Error handling and recovery in subprocesses of DataLoader is messy, though, as is the queue handling. That's one of the reasons we developed github.com/nvlabs/tensorcom. It's not only more testable, it also gives you dynamically scalable and distributed preprocessing and data augmentation.
OK - a stupid, careless mistake was leading to the memory leak. In the validation dataset/loader I was using
It seems this issue occurs when we do not correctly split dataset instances between both nodes/devices and workers - I observed this issue when setting either or both of the worker and node splitters to
Avg epoch training time is back to ~1:35... which is very close to the time when data is stored locally on PD/VM! |
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. |
We have built a terraform script that spins up 4 VMs and uses a v3-32 TPU for Resnet50 training. We store the Imagenet training and validation data in a GCS bucket. Full code repo can be found here
torch_xla.distributed.xla_dist
test_train_mp_imagenet.py
(only altering it to use our GCS data loader)
For the questions below, I've attached a log file (with metrics_debug) and used the following configuration:
n2-custom (72 vCPUs, 512 GB memory)
NUM_EPOCHS=20
BATCH_SIZE=512
TEST_BATCH_SIZE=64
NUM_WORKERS=8
log_steps=200
--conda-env=torch-xla-1.7
--env XLA_USE_BF16=1
Questions
BrokenPipeError: [Errno 32] Broken pipe or unhealthy mesh errors, and training will automatically restart (see line 20689 in the log file for the Broken Pipe Error during Epoch 13).
imagenetraw_logfiles4-v3-32-512batch-8workers.txt
@zcain117
@shanemhansen