Distributed TPU Training, training data stored in GCS #2690

Closed
tottenjordan opened this issue Dec 16, 2020 · 47 comments

@tottenjordan

We have built a terraform script that spins up 4 VMs and uses a v3-32 TPU for Resnet50 training. We store the Imagenet training and validation data in a GCS bucket. Full code repo can be found here

  • we use the torch_xla.distributed.xla_dist
  • as well as the test_train_mp_imagenet.py (only altering it to use our GCS data loader)

For the questions below, I've attached a log file (with metrics_debug), and used the following configuration:

  • VM machine types = n2-custom (72 vCPUs, 512 GB memory)
  • NUM_EPOCHS=20
  • BATCH_SIZE=512
  • TEST_BATCH_SIZE=64
  • NUM_WORKERS=8
  • log_steps=200
  • --conda-env=torch-xla-1.7
  • --env XLA_USE_BF16=1
  • default learning rate and schedule

Questions

  • Not sure what baseline to compare with, but epoch training time seems to be around 5-6 minutes.
    • This is true for 8 workers with batch sizes of 128, 256, and 512 (batch size of 128 with 32 workers comes in at around 4 minutes per epoch).
    • Is there anything from a code or configuration perspective we could do to improve this? 32 workers seems like overkill, but we've seen better results with it.
  • Sometimes we get BrokenPipeError: [Errno 32] Broken pipe or unhealthy mesh errors, and training automatically restarts (see line 20689 in the log file for a Broken Pipe error during epoch 13).
    • Is there anything we can do to overcome this?

imagenetraw_logfiles4-v3-32-512batch-8workers.txt

@zcain117
@shanemhansen

@tottenjordan
Author

Here is a screenshot of the dashboard for the previously mentioned training job

[Screenshot: 2020-12-16 at 11:39 AM]

@taylanbil
Collaborator

How does the XLA metrics report look? Does it differ much from our regular nightly ResNet-50 runs in terms of ExecuteTime, CompileTime, etc.? @zcain117

@zcain117
Collaborator

  • Compiles look fine: no new compiles after the 1st epoch
  • no aten:: calls except for a reasonable amount of aten::_local_scalar_dense

I don't have any recent logs with metrics so I will wrap up a few of my experiments running now and then try to kick off 2 runs on v3-32, one using regular PD and one with SSD. Then we can compare metrics

@zcain117 zcain117 self-assigned this Dec 16, 2020
@zcain117
Collaborator

zcain117 commented Dec 17, 2020

Attached the metrics for my v3-32 run (filtered to just one of the 4 VMs' stdout)
512_ssd_logs_grepped.txt

  • CompileTime is static for both after epoch 1
  • ExecuteTime yours: 04m56s -> 05m26s -> 07m35s -> 08m15s -> 09m45s -> 11m03s -> 12m23s -> 14m52s -> 16m13s -> 18m38s -> 19m49s -> 21m48s -> 22m01s -> 23m20s (epoch 14)
  • ExecuteTime mine: 02m06s -> 03m13s -> 04m18s -> 05m16s -> 06m21s -> 07m30s -> 09m34s -> 10m35s -> 11m34s -> 12m40s -> 13m45s -> 14m44s -> 15m15s (epoch 14)
  • No major difference in TransferToServerTime or TransferFromServerTime
  • DeviceLockWait total was ~11m on mine and ~20m on yours

Some of the XrtMetrics:

  • Metric: XrtAllocateFromTensor.c_tpu_worker.0 was ~1hr01m on mine and ~1hr26m on yours
  • The XrtExecute was substantially higher on your run. Here is from my run after 14 epochs:
2020-12-17 20:20:12 10.164.0.108 [0] Metric: XrtExecute.c_tpu_worker.0
2020-12-17 20:20:12 10.164.0.108 [0]   TotalSamples: 65744
2020-12-17 20:20:12 10.164.0.108 [0]   Accumulator: 05h10m41s479ms268.855us
2020-12-17 20:20:12 10.164.0.108 [0]   Mean: 490ms928.087us
2020-12-17 20:20:12 10.164.0.108 [0]   StdDev: 475ms630.074us
2020-12-17 20:20:12 10.164.0.108 [0]   Rate: 8.95582 / second
2020-12-17 20:20:12 10.164.0.108 [0]   Percentiles: 25%=030ms152.881us; 50%=532ms559.099us; 80%=839ms410.966us; 90%=941ms093.544us; 95%=01s107ms439.003us; 99%=02s541ms791.543us
2020-12-17 20:20:12 10.164.0.108 [0] Metric: XrtExecute.c_tpu_worker.1
2020-12-17 20:20:12 10.164.0.108 [0]   TotalSamples: 65744
2020-12-17 20:20:12 10.164.0.108 [0]   Accumulator: 05h07m12s957ms501.585us
2020-12-17 20:20:12 10.164.0.108 [0]   Mean: 497ms052.484us
2020-12-17 20:20:12 10.164.0.108 [0]   StdDev: 489ms446.525us
2020-12-17 20:20:12 10.164.0.108 [0]   Rate: 8.97735 / second
2020-12-17 20:20:12 10.164.0.108 [0]   Percentiles: 25%=030ms993.646us; 50%=533ms292.432us; 80%=834ms846.378us; 90%=01s008ms699.441us; 95%=01s223ms020.681us; 99%=02s710ms407.093us
2020-12-17 20:20:12 10.164.0.108 [0] Metric: XrtExecute.c_tpu_worker.2
2020-12-17 20:20:12 10.164.0.108 [0]   TotalSamples: 65744
2020-12-17 20:20:12 10.164.0.108 [0]   Accumulator: 05h02m44s926ms805.700us
2020-12-17 20:20:12 10.164.0.108 [0]   Mean: 478ms717.034us
2020-12-17 20:20:12 10.164.0.108 [0]   StdDev: 446ms347.398us
2020-12-17 20:20:12 10.164.0.108 [0]   Rate: 8.97536 / second
2020-12-17 20:20:12 10.164.0.108 [0]   Percentiles: 25%=030ms261.579us; 50%=550ms147.834us; 80%=803ms044.583us; 90%=886ms119.476us; 95%=01s028ms439.400us; 99%=02s965ms099.527us
2020-12-17 20:20:12 10.164.0.108 [0] Metric: XrtExecute.c_tpu_worker.3
2020-12-17 20:20:12 10.164.0.108 [0]   TotalSamples: 65744
2020-12-17 20:20:12 10.164.0.108 [0]   Accumulator: 05h60m37s069ms776.382us
2020-12-17 20:20:12 10.164.0.108 [0]   Mean: 492ms107.899us
2020-12-17 20:20:12 10.164.0.108 [0]   StdDev: 466ms978.928us
2020-12-17 20:20:12 10.164.0.108 [0]   Rate: 8.95665 / second
2020-12-17 20:20:12 10.164.0.108 [0]   Percentiles: 25%=030ms009.936us; 50%=572ms417.431us; 80%=779ms379.190us; 90%=893ms028.297us; 95%=01s192ms527.232us; 99%=02s026ms606.519us

And yours:

19309 2020-12-16 16:35:11 10.164.0.15 [0] Metric: XrtExecute.c_tpu_worker.0
19310 2020-12-16 16:35:11 10.164.0.15 [0]   TotalSamples: 36793
19311 2020-12-16 16:35:11 10.164.0.15 [0]   Accumulator: 13h49m38s272ms501.708us
19312 2020-12-16 16:35:11 10.164.0.15 [0]   Mean: 555ms735.382us
19313 2020-12-16 16:35:11 10.164.0.15 [0]   StdDev: 958ms486.022us
19314 2020-12-16 16:35:11 10.164.0.15 [0]   Rate: 2.88988 / second
19315 2020-12-16 16:35:11 10.164.0.15 [0]   Percentiles: 25%=029ms421.710us; 50%=496ms488.162us; 80%=558ms749.285us; 90%=693ms959.466us; 95%=02s444ms354.965us; 99%=06s568ms853.819us
19316 2020-12-16 16:35:11 10.164.0.15 [0] Metric: XrtExecute.c_tpu_worker.1
19317 2020-12-16 16:35:11 10.164.0.15 [0]   TotalSamples: 36794
19318 2020-12-16 16:35:11 10.164.0.15 [0]   Accumulator: 13h31m54s695ms601.280us
19319 2020-12-16 16:35:11 10.164.0.15 [0]   Mean: 555ms451.089us
19320 2020-12-16 16:35:11 10.164.0.15 [0]   StdDev: 956ms374.188us
19321 2020-12-16 16:35:11 10.164.0.15 [0]   Rate: 2.88073 / second
19322 2020-12-16 16:35:11 10.164.0.15 [0]   Percentiles: 25%=030ms713.184us; 50%=496ms333.137us; 80%=555ms457.000us; 90%=687ms633.420us; 95%=03s646ms049.735us; 99%=05s333ms692.363us
19323 2020-12-16 16:35:11 10.164.0.15 [0] Metric: XrtExecute.c_tpu_worker.2
19324 2020-12-16 16:35:11 10.164.0.15 [0]   TotalSamples: 36794
19325 2020-12-16 16:35:11 10.164.0.15 [0]   Accumulator: 13h40m56s431ms993.757us
19326 2020-12-16 16:35:11 10.164.0.15 [0]   Mean: 553ms892.010us
19327 2020-12-16 16:35:11 10.164.0.15 [0]   StdDev: 885ms607.958us
19328 2020-12-16 16:35:11 10.164.0.15 [0]   Rate: 2.88393 / second
19329 2020-12-16 16:35:11 10.164.0.15 [0]   Percentiles: 25%=030ms643.081us; 50%=497ms562.074us; 80%=561ms749.341us; 90%=679ms752.559us; 95%=03s819ms256.936us; 99%=05s577ms093.633us
19330 2020-12-16 16:35:11 10.164.0.15 [0] Metric: XrtExecute.c_tpu_worker.3
19331 2020-12-16 16:35:11 10.164.0.15 [0]   TotalSamples: 36793
19332 2020-12-16 16:35:11 10.164.0.15 [0]   Accumulator: 13h45m05s942ms659.191us
19333 2020-12-16 16:35:11 10.164.0.15 [0]   Mean: 582ms085.747us
19334 2020-12-16 16:35:11 10.164.0.15 [0]   StdDev: 01s044ms268.753us
19335 2020-12-16 16:35:11 10.164.0.15 [0]   Rate: 2.87976 / second
19336 2020-12-16 16:35:11 10.164.0.15 [0]   Percentiles: 25%=030ms526.811us; 50%=497ms784.653us; 80%=559ms869.838us; 90%=682ms877.519us; 95%=03s790ms202.928us; 99%=06s726ms888.469us

@JackCaoG do you think XrtExecute is usually an important metric?

@tottenjordan have you compared your speed when using a PD or SSD PD instead of the custom GCS reader? And regarding the custom GCS reader, is there any way you could increase the prefetching amount?

@taylanbil
Collaborator

XrtExecute doesn't get reset between runs; you need to restart the TPU for it to be reset. This looks like it didn't get restarted between runs?

ExecuteTime is very different. Are the models/batch sizes etc exactly the same?

@zcain117
Collaborator

GCS was ~5m30s per epoch and SSD PD was ~1m30s per epoch, so ~4 min * 14 epochs = ~56 min cumulative difference, but ExecuteTime was only ~8 min cumulative difference. Probably something else is causing the majority of the gap?

@zcain117
Collaborator

Forgot to include my training command:
python -m torch_xla.distributed.xla_dist --tpu=$SSD_TPU_NAME --conda-env=torch-xla-1.7 --env XLA_USE_BF16=1 --env ANY_OTHER=ENV_VAR -- python /usr/share/torch-xla-1.7/pytorch/xla/test/test_train_mp_imagenet.py --model=resnet50 --num_workers=8 --batch_size=512 --log_steps=200 --num_epochs=14 --metrics_debug --datadir=/mnt/disks/dataset/imagenet 2>&1 | tee ssd_training_logs.txt

@zcain117
Collaborator

@tottenjordan I think the next thing to try is increasing the prefetch_factor when you initialize the DataLoader here. Maybe try a few different sizes (the default value is 2).

I think this change should be in the 1.7 version you're using: https://github.com/pytorch/pytorch/releases/tag/v1.7.0
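For reference, a minimal sketch of that change (the tiny TensorDataset stands in for the ImageNet dataset the script builds; prefetch_factor has been available on DataLoader since PyTorch 1.7 and defaults to 2):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset; the real script builds an ImageNet dataset plus a DistributedSampler.
dataset = TensorDataset(torch.randn(256, 3, 32, 32), torch.randint(0, 1000, (256,)))

loader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=8,
    prefetch_factor=4,  # each worker keeps 4 batches queued instead of the default 2
    drop_last=True,
)

for images, labels in loader:
    pass  # training step goes here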

@taylanbil
Collaborator

What are the percentiles for ExecuteTime?

@tottenjordan
Author

@zcain117 , @shanemhansen ran training jobs with PD and will post log/metrics report when available.

Here is the first successful run with prefetch_factor=4. Attached the log/metrics report, but here are some highlights:

  • v2-32 (due to TPU capacity constraints)
  • reduced batch size from 512 to 256
    • (v2-32 couldn't handle 512 + prefetch of 4 * 8 workers)
  • average epoch training time is 6:24
    • Doesn't include epoch 1 of 6:38
    • epoch 5 training time = 9:21; epoch 10 = 10:24
export NUM_EPOCHS=14
export BATCH_SIZE=256
export TEST_BATCH_SIZE=64
export NUM_WORKERS=8

python -m torch_xla.distributed.xla_dist --tpu=$TPU_NAME --conda-env=torch-xla-1.7 \
    --env XLA_USE_BF16=1 \
    -- python /tmp/thepackage/test_train_mp_imagenet.py \
    --num_epochs=$NUM_EPOCHS \
    --batch_size=$BATCH_SIZE \
    --num_workers=$NUM_WORKERS \
    --log_steps=200 \
    --logdir=$LOGDIR \
    --datadir=$IMAGE_DIR \
    --test_set_batch_size=$TEST_BATCH_SIZE \
    --metrics_debug

imagenetr1aw_logfiles6-v2-32-256batch-8wrks-4prefetch.txt

@zcain117 @taylanbil Our data are stored as JPEG. Is it worth exploring tfrecords or other formats? Are the data in your PD tests stored as JPEG?

Will run prefetch_factor=8 next

@tottenjordan
Author

prefetch_factor=8

  • v2-32
  • average epoch training time = 5:43
    • doesn't include epoch 1 of 6:07
    • more consistent and lower average than prefetch_factor=4 w/ 8 workers
export NUM_EPOCHS=14
export BATCH_SIZE=256
export TPU_NAME=mytpu
export TEST_BATCH_SIZE=64
export NUM_WORKERS=8

python -m torch_xla.distributed.xla_dist --tpu=$TPU_NAME --conda-env=torch-xla-1.7 \
    --env XLA_USE_BF16=1 \
    -- python /tmp/thepackage/test_train_mp_imagenet.py \
    --num_epochs=$NUM_EPOCHS \
    --batch_size=$BATCH_SIZE \
    --num_workers=$NUM_WORKERS \
    --log_steps=200 \
    --logdir=$LOGDIR \
    --datadir=$IMAGE_DIR \
    --test_set_batch_size=$TEST_BATCH_SIZE \
    --metrics_debug

imagenetr1aw_logfiles7-v2-32-256batch-8wrks-8prefetch.txt

@tottenjordan
Author

This is interesting. Increased to num_workers=16 and kept prefetch_factor=8

  • v2-32
  • average epoch training is 3:53
    • consistent throughout job
export NUM_EPOCHS=10
export BATCH_SIZE=256
export TEST_BATCH_SIZE=64
export NUM_WORKERS=16

python -m torch_xla.distributed.xla_dist --tpu=$TPU_NAME --conda-env=torch-xla-1.7 \
    --env XLA_USE_BF16=1 \
    -- python /tmp/thepackage/test_train_mp_imagenet.py \
    --num_epochs=$NUM_EPOCHS \
    --batch_size=$BATCH_SIZE \
    --num_workers=$NUM_WORKERS \
    --log_steps=200 \
    --logdir=$LOGDIR \
    --datadir=$IMAGE_DIR \
    --test_set_batch_size=$TEST_BATCH_SIZE \
    --metrics_debug

imagenetr1aw_logfiles8-v2-32-256batch-16wrks-8prefetch.txt

@zcain117 how do I get metrics_debug to print out in sequence by worker, instead of every other line potentially being from another worker?

@zcain117
Collaborator

I don't know of any way to order the outputs. I just filter to 1 of the VM's output at a time when reading.

E.g. you see lines like this where the VM's IP is mentioned: 2020-12-18 15:52:56 10.128.0.70 [2] Step 0

So I was using grep "10.128.0.70" logs.txt

Earlier you mentioned good results with num_workers=32, maybe it's worth trying that + prefetch_factor even though that many workers is higher than we normally recommend

@tottenjordan
Author

Switched to v3-32 TPU pod.

Increasing the num_workers and prefetch_factor is showing better results:

Simply moving from v2-32 to v3-32 showed a slight speed-up of ~20s.

On v3-32, increasing num_workers while keeping prefetch_factor constant showed a speed-up of ~23s:
num_workers=32 and prefetch_factor=8

2020-12-21 16:29:15 10.164.0.29 [0]   TotalSamples: 448
2020-12-21 16:29:15 10.164.0.29 [0]   Accumulator: 276ms908.139us
2020-12-21 16:29:15 10.164.0.29 [0]   Mean: 615.866us
2020-12-21 16:29:15 10.164.0.29 [0]   StdDev: 001ms449.249us
2020-12-21 16:29:15 10.164.0.29 [0]   Rate: 0.0584148 / second
2020-12-21 16:29:15 10.164.0.29 [0]   Percentiles: 25%=350.820us; 50%=398.093us; 80%=497.008us; 90%=616.153us; 95%=937.838us; 99%=007ms552.195us
2020-12-21 16:29:15 10.164.0.29 [0] Metric: XrtReadLiteral.c_tpu_worker.1
2020-12-21 16:29:15 10.164.0.29 [0]   TotalSamples: 448
2020-12-21 16:29:15 10.164.0.29 [0]   Accumulator: 225ms417.282us
2020-12-21 16:29:15 10.164.0.29 [0]   Mean: 503.164us
2020-12-21 16:29:15 10.164.0.29 [0]   StdDev: 842.045us
2020-12-21 16:29:15 10.164.0.29 [0]   Rate: 0.0584197 / second
2020-12-21 16:29:15 10.164.0.29 [0]   Percentiles: 25%=352.787us; 50%=406.847us; 80%=508.414us; 90%=598.991us; 95%=718.260us; 99%=002ms539.905us
2020-12-21 16:29:15 10.164.0.29 [0] Metric: XrtReadLiteral.c_tpu_worker.2
2020-12-21 16:29:15 10.164.0.29 [0]   TotalSamples: 448
2020-12-21 16:29:15 10.164.0.29 [0]   Accumulator: 261ms891.467us
2020-12-21 16:29:15 10.164.0.29 [0]   Mean: 582.347us
2020-12-21 16:29:15 10.164.0.29 [0]   StdDev: 001ms256.640us
2020-12-21 16:29:15 10.164.0.29 [0]   Rate: 0.0584187 / second
2020-12-21 16:29:15 10.164.0.29 [0]   Percentiles: 25%=352.927us; 50%=411.384us; 80%=496.829us; 90%=583.585us; 95%=760.846us; 99%=006ms572.937us
2020-12-21 16:29:15 10.164.0.29 [0] Metric: XrtReadLiteral.c_tpu_worker.3
2020-12-21 16:29:15 10.164.0.29 [0]   TotalSamples: 448
2020-12-21 16:29:15 10.164.0.29 [0]   Accumulator: 257ms218.651us
2020-12-21 16:29:15 10.164.0.29 [0]   Mean: 574.149us
2020-12-21 16:29:15 10.164.0.29 [0]   StdDev: 001ms380.042us
2020-12-21 16:29:15 10.164.0.29 [0]   Rate: 0.0584215 / second
2020-12-21 16:29:15 10.164.0.29 [0]   Percentiles: 25%=355.008us; 50%=405.369us; 80%=516.042us; 90%=597.709us; 95%=727.514us; 99%=004ms020.203us


@zcain117
Collaborator

So a couple more results I would like to see:

  1. Same setup but using a PD and/or SSD PD instead of reading from GCS (omit the prefetch_factor and might also need to lower num_workers down from 32)
  2. GCS reading, 32 workers, higher prefetch_factor. I feel like we should go as high as possible until we see memory errors or performance degradation. It would be useful to understand the curve of performance vs prefetch_factor at various num_worker levels.

@tottenjordan
Author

Agreed. I'll work on these and follow up with results.

I set up an experiment to test different configs of (1) TPU Pod version, (2) batch_size, (3) prefetch_factor, and (4) num_workers, and plan to do these for both GCS and PD/SSD PD. I've made progress on the GCS trials and will continue to increase until I get errors, then will try similar configs (where applicable) for PD.

@shanemhansen

@tottenjordan
Author

tottenjordan commented Dec 23, 2020

@zcain117 Using SSD PD, average epoch training time = ~1:32

  • export NUM_EPOCHS=10
  • export BATCH_SIZE=256
  • export NUM_WORKERS=8
  • prefetch_factor left at the default
2020-12-23 17:54:01 10.164.0.51 [0] Metric: XrtReleaseAllocation.c_tpu_worker.0
2020-12-23 17:54:01 10.164.0.51 [0]   TotalSamples: 182400
2020-12-23 17:54:01 10.164.0.51 [0]   Accumulator: 54s489ms144.144us
2020-12-23 17:54:01 10.164.0.51 [0]   Mean: 279.874us
2020-12-23 17:54:01 10.164.0.51 [0]   StdDev: 810.276us
2020-12-23 17:54:01 10.164.0.51 [0]   Rate: 449.938 / second
2020-12-23 17:54:01 10.164.0.51 [0]   Percentiles: 25%=038.745us; 50%=170.667us; 80%=401.007us; 90%=513.775us; 95%=592.116us; 99%=903.837us
2020-12-23 17:54:01 10.164.0.51 [0] Metric: XrtReleaseAllocation.c_tpu_worker.1
2020-12-23 17:54:01 10.164.0.51 [0]   TotalSamples: 230948
2020-12-23 17:54:01 10.164.0.51 [0]   Accumulator: 55s416ms230.960us
2020-12-23 17:54:01 10.164.0.51 [0]   Mean: 205.796us
2020-12-23 17:54:01 10.164.0.51 [0]   StdDev: 203.502us
2020-12-23 17:54:01 10.164.0.51 [0]   Rate: 466.805 / second
2020-12-23 17:54:01 10.164.0.51 [0]   Percentiles: 25%=028.857us; 50%=144.102us; 80%=396.629us; 90%=517.070us; 95%=608.961us; 99%=762.827us
2020-12-23 17:54:01 10.164.0.51 [0] Metric: XrtReleaseAllocation.c_tpu_worker.2
2020-12-23 17:54:01 10.164.0.51 [0]   TotalSamples: 148753
2020-12-23 17:54:01 10.164.0.51 [0]   Accumulator: 56s508ms175.019us
2020-12-23 17:54:01 10.164.0.51 [0]   Mean: 220.988us
2020-12-23 17:54:01 10.164.0.51 [0]   StdDev: 273.390us
2020-12-23 17:54:01 10.164.0.51 [0]   Rate: 494.729 / second
2020-12-23 17:54:01 10.164.0.51 [0]   Percentiles: 25%=035.262us; 50%=136.353us; 80%=401.637us; 90%=537.228us; 95%=625.096us; 99%=808.244us
2020-12-23 17:54:01 10.164.0.51 [0] Metric: XrtReleaseAllocation.c_tpu_worker.3
2020-12-23 17:54:01 10.164.0.51 [0]   TotalSamples: 157583
2020-12-23 17:54:01 10.164.0.51 [0]   Accumulator: 59s824ms229.353us
2020-12-23 17:54:01 10.164.0.51 [0]   Mean: 216.674us
2020-12-23 17:54:01 10.164.0.51 [0]   StdDev: 209.225us
2020-12-23 17:54:01 10.164.0.51 [0]   Rate: 503.516 / second
2020-12-23 17:54:01 10.164.0.51 [0]   Percentiles: 25%=034.583us; 50%=154.641us; 80%=411.185us; 90%=523.400us; 95%=627.079us; 99%=755.876us
2020-12-23 17:54:01 10.164.0.51 [0] 
2020-12-23 17:54:01 10.164.0.51 [0] Max Accuracy: 41.56%

imagenetraw_logfiles11-SSDPD-v3-32-256batch-8workers-2prefetch.txt

TO DO

  1. For SSD PD reading, test increased num_workers, prefetch_factor, and batch_size
  2. For GCS reading, at 32 workers, increase prefetch_factor until memory error

@zcain117
Collaborator

That speed is consistent with what I've gotten using v3-32 and SSD on imagenet. It seems like this is ~twice as fast as the best GCS run so far, so there's still room for improvement in the GCS version for the given dataset+model architecture.

It seems like your speed at 256 batch size is about the same as my best speed which was using 128 batch size. So if you run into a memory error when trying GCS with higher prefetch_factor, maybe consider lowering batch size since 256 vs 128 might not make much difference and might be a win if it allows more prefetch

@tottenjordan
Author

I've tested different combinations of the following

  • batch_size= [128, 256, 512]
  • num_workers= [8, 16, 32]
  • prefetch_factor=[2,4,6,8,10,12,14,16,18,20,22]

Haven't had impressive results with batch_size=512

batch_size=128 is yielding the best results, so I'm testing num_workers=[16...30] for each prefetch_factor=[16,18...] until I receive an error or see diminishing returns.

Best avg. training epoch time = 2:46

  • batch_size=128
  • prefetch_factor=18
  • num_workers=22
  • v3-32 pod

logfiles-metrics-GCS-128bs-22work-18prefetch-v332pod.txt

2021-01-05 22:36:22 10.164.15.197 [0]   TotalSamples: 689499
2021-01-05 22:36:22 10.164.15.197 [0]   Accumulator: 04m14s621ms200.411us
2021-01-05 22:36:22 10.164.15.197 [0]   Mean: 210.507us
2021-01-05 22:36:22 10.164.15.197 [0]   StdDev: 224.131us
2021-01-05 22:36:22 10.164.15.197 [0]   Rate: 258.551 / second
2021-01-05 22:36:22 10.164.15.197 [0]   Percentiles: 25%=032.002us; 50%=145.004us; 80%=385.428us; 90%=529.361us; 95%=607.630us; 99%=806.144us
2021-01-05 22:36:22 10.164.15.197 [0] Metric: XrtReleaseAllocation.c_tpu_worker.1
2021-01-05 22:36:22 10.164.15.197 [0]   TotalSamples: 716089
2021-01-05 22:36:22 10.164.15.197 [0]   Accumulator: 04m58s347ms245.929us
2021-01-05 22:36:22 10.164.15.197 [0]   Mean: 201.961us
2021-01-05 22:36:22 10.164.15.197 [0]   StdDev: 218.517us
2021-01-05 22:36:22 10.164.15.197 [0]   Rate: 273.607 / second
2021-01-05 22:36:22 10.164.15.197 [0]   Percentiles: 25%=036.074us; 50%=145.123us; 80%=347.944us; 90%=489.711us; 95%=598.754us; 99%=732.415us
2021-01-05 22:36:22 10.164.15.197 [0] Metric: XrtReleaseAllocation.c_tpu_worker.2
2021-01-05 22:36:22 10.164.15.197 [0]   TotalSamples: 698446
2021-01-05 22:36:22 10.164.15.197 [0]   Accumulator: 04m13s077ms532.441us
2021-01-05 22:36:22 10.164.15.197 [0]   Mean: 208.645us
2021-01-05 22:36:22 10.164.15.197 [0]   StdDev: 229.160us
2021-01-05 22:36:22 10.164.15.197 [0]   Rate: 250.631 / second
2021-01-05 22:36:22 10.164.15.197 [0]   Percentiles: 25%=035.645us; 50%=156.691us; 80%=366.583us; 90%=483.195us; 95%=570.105us; 99%=809.667us
2021-01-05 22:36:22 10.164.15.197 [0] Metric: XrtReleaseAllocation.c_tpu_worker.3
2021-01-05 22:36:22 10.164.15.197 [0]   TotalSamples: 661643
2021-01-05 22:36:22 10.164.15.197 [0]   Accumulator: 04m10s768ms803.569us
2021-01-05 22:36:22 10.164.15.197 [0]   Mean: 177.682us
2021-01-05 22:36:22 10.164.15.197 [0]   StdDev: 229.021us
2021-01-05 22:36:22 10.164.15.197 [0]   Rate: 263.604 / second
2021-01-05 22:36:22 10.164.15.197 [0]   Percentiles: 25%=024.501us; 50%=095.764us; 80%=326.172us; 90%=462.122us; 95%=570.292us; 99%=696.035us

@tottenjordan
Author

I've yet to get a "memory error", but I am running into unexpected training metrics. I discussed with @zcain117 that this may be random behavior on the TPU side, but it occurred during nearly 20% of the ~30 training jobs I've run in the last two days.

Metrics Reports

  • Here are two metrics reports for the same configuration.
  • First attempt was unsuccessful, second attempt was successful.
  • Didn't allocate new TPU resource for second attempt.
  • Even though this combination of num_workers and prefetch_factor seems high, same situation occurs with more reasonable configs (e.g., num_workers=8, prefetch_factor=14)
  • Using this command:
prefetch_factor=18

python -m torch_xla.distributed.xla_dist \
    --tpu=$TPU_NAME --conda-env=torch-xla-1.7 --env XLA_USE_BF16=1 \
    -- python /tmp/thepackage/test_train_mp_imagenet.py \
    --num_epochs=5 \
    --batch_size=128 \
    --num_workers=26 \
    --log_steps=200 \
    --logdir=$LOGDIR --datadir=$IMAGE_DIR \
    --test_set_batch_size=64 \
    --metrics_debug

What happens:

  • During Epoch 1 at Step=0, Loss=~7.000 for each VM-worker
  • During Epoch 1 by Step=200, Loss=0.000 for each VM-worker. And this remains constant for the remaining Epochs
  • For all Epochs, Reduced Accuracy=0.00% and Replica Accuracy=0.00%
  • For most configurations, total training time and avg. epoch training time are consistent when experiencing and not experiencing this behavior
  • This behavior will occur for a certain config, and then I've experienced any of the following for the same config:
    • Successfully rerun the job without allocating new TPU resource
    • Experience same behavior without allocating new TPU resource
    • Successfully rerun the job after allocating new TPU resource
    • Experience same behavior after allocating new TPU resource
    • Sometimes, I can run a different config (successfully), and then rerun the original successfully. However, this doesn't always work

Questions

  • Loss=0 and Acc=0 doesn't make sense to me. What should I investigate? Not sure where to start
  • I haven't noticed any patterns with this yet, but should I take a closer look at (1) total prefetched images, (2) prefetched per VM/process, or (3) how these compare to batch size, e.g., the % of each batch that is prefetched?

@zcain117
Collaborator

zcain117 commented Jan 6, 2021

The loss is on training data and the accuracy is on the eval data I think. It could be that you're getting some kind of data caching bug in your custom GCS reading implementation where the training loop is iterating over a small portion of data and overfitting and doing terribly on the eval data. PyTorch dataloader tends to cause the OS to cache data (as explored in https://b.corp.google.com/issues/175324667) and the OS caching might be interacting in an unexpected way with your implementation. Or this could be independent of OS caching.

You might look at the actual prediction vs. true label tensors to see if it's guessing poorly or if it's predicting NaN or something.

You could also try printing the training data to see if you're iterating over all of it or just a portion. Maybe you could keep a count of how many times you've seen each input filename or how many times you've seen each class
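A hedged sketch of that check (the loader argument and the assumption that target is a label tensor stand in for whatever the training script actually uses; nothing here is code from the repo):

from collections import Counter

def summarize_labels(loader, max_batches=None):
    """Count how often each class label is seen by this process.

    A tiny or heavily skewed distribution would point at the data-loading
    path rather than the model."""
    counts = Counter()
    for step, (_, target) in enumerate(loader):
        counts.update(target.tolist())
        if max_batches is not None and step + 1 >= max_batches:
            break
    print("distinct classes seen:", len(counts))
    print("10 most common:", counts.most_common(10))
    return counts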

@harpone

harpone commented Jan 8, 2021

Just an FYI: I used a similar GCSDataset solution which streams JPEGs from GCS one JPEG at a time (on multiple processes of course, specified by DataLoader num_workers). It's a bit slow because of the overhead due to lots of small files. You also indeed need lots of workers => more CPUs.

Then I discovered webdataset: https://github.com/tmbdev/webdataset

You can pretty easily saturate the GCS download speed with webdataset (tested against gsutil cp from the bucket). Note that you could get IO bottlenecked even with an SSD. Webdataset works great and you don't even need many workers. I'm using 4 workers per CPU tops, but even 1 is OK.

One problem is that webdataset inherits from pytorch IterableDataset, and I've been having some issues with that in torch-xla... I can post an issue once I've pinpointed it further. torch-xla dataloader stuff can be a bit complicated :/

EDIT: I mean the torch-xla fork of torch.multiprocessing can be a bit complicated and is the probable issue here. Actually webdataset works just fine!

@tottenjordan
Author

thanks for the recommendation @harpone !

It looks like a good option to pursue... I'm still around ~2:50 for avg epoch time and trying to shave this down to something at least comparable to SSD PD (~1:30)

For your implementation, do you have any code you could share?

@harpone

harpone commented Jan 13, 2021

Unfortunately can't share code explicitly at the moment... but I'm following the webdataset tutorial notebooks pretty closely, something like

dataset = (wds.Dataset(urls,
                       length=None,
                       tarhandler=warn_and_cont)
           .pipe(my_data_decoder)
           .pipe(augment)
           .pipe(batched(batchsize=args.batch_size, partial=True, collation_fn=collate_fn))
           )

and then a pytorch dataloader with batch_size=None.

An update: I actually failed to get this working properly with torch-xla because the torch-xla dataloader fork requires a length for the dataset, which is a bit problematic with IterableDatasets... probably webdataset is not quite usable with torch-xla at the moment. Also having some more problems with torch-xla so I'm actually going to go with GPUs... sorry for getting your hopes up :/

But if you decide to give webdataset a try and run into issues, please @ me and I can try help!

@zcain117
Collaborator

@harpone @tottenjordan I was also curious if either of you had tried gcsfs, maybe in conjunction with CachedDataset

@tottenjordan
Author

@zcain117 the gcsdataset.py currently uses gcsfs. I have not looked into CachedDataset
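(For reference, the per-object access pattern underneath gcsfs looks roughly like the sketch below; the bucket and object names are made up. Each sample is a separate GCS request, which is where the small-file overhead comes from.)

import io

import gcsfs
from PIL import Image

# One read = one GCS object request; with over a million small JPEGs the
# per-request round trip dominates. Bucket/object names here are placeholders.
fs = gcsfs.GCSFileSystem()
with fs.open("my-bucket/imagenet/train/n01440764/img_0001.JPEG", "rb") as f:
    image = Image.open(io.BytesIO(f.read())).convert("RGB")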

@harpone

harpone commented Jan 14, 2021

@zcain117 yeah I tested gcsfs earlier and indeed it suffers from the same problems due to overhead when having to access lots of small files (e.g. JPEGs). Webdataset solves this problem by simply archiving all the data into "shards", which are just .tar archives containing ~10k image/target pairs, and streaming these tar files from (e.g.) GCS.
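For anyone wanting to try that, building the shards looks roughly like the sketch below using webdataset's ShardWriter (the paths, keys, and the (path, label) listing are placeholders, not code from this thread):

import webdataset as wds

# (path, class-index) pairs; in practice this would be the full training listing.
samples = [("/data/train/n01440764/img_0001.JPEG", 0)]

# Writes imagenet-train-000000.tar, imagenet-train-000001.tar, ... with up to 10k samples each.
with wds.ShardWriter("/data/shards/imagenet-train-%06d.tar", maxcount=10000) as sink:
    for index, (path, label) in enumerate(samples):
        with open(path, "rb") as stream:
            sink.write({
                "__key__": f"sample{index:08d}",
                "jpg": stream.read(),        # raw JPEG bytes; decoded later with .decode("pil")
                "cls": str(label).encode(),  # class index stored as text
            })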

@tmbdev really explains it best in his intro videos: https://www.youtube.com/watch?v=kNuA2wflygM

Update: actually I got webdataset working with torch-xla by setting an explicit length... but it does seem to mess up the training, possibly getting identical minibatches per TPU core or something. I think I could work out a minimal torch-xla + webdataset example and publish a gist and an issue. I'll post a link here too when it's done.

@harpone

harpone commented Jan 15, 2021

OK I have a webdataset working example here using torch-xla's pl.MpDeviceLoader.

No extra dependencies and uses the NVIDIA hosted OpenImages as in the webdataset examples.

It actually works fine (I was suspecting there were duplicate minibatches or something). Setting the dataset length works fine too.

@tottenjordan this should work as a minimal example how to implement your dataset as a webdataset.

@tmbdev

tmbdev commented Jan 17, 2021

@harpone WebDataset yields each training sample from each shard exactly once during each epoch, and it uses each shard exactly once during each epoch.

For multinode training, you need to add a "nodesplitter=" argument to the WebDataset constructor to determine how datasets are split across nodes. The default is for each node to train on the entire dataset. Generally, something like "nodesplitter=lambda l: l[node_index::num_nodes]" is a reasonable choice.

If you have suggestions for better defaults or diagnostic messages, please let me know.
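In the torch-xla setting discussed in this thread, that suggestion would look roughly like the sketch below, with xm.get_ordinal() / xm.xrt_world_size() standing in for node_index / num_nodes (the shard URL pattern is the one used elsewhere in this thread):

import torch_xla.core.xla_model as xm
import webdataset as wds

def split_by_node(urls):
    # Each replica gets a disjoint slice of the shard list.
    return urls[xm.get_ordinal()::xm.xrt_world_size()]

dataset = wds.WebDataset(
    "pipe:gsutil cat gs:https://$BUCKET/shards/imagenet-train-{000000..001281}.tar",
    splitter=wds.split_by_worker,   # split shards across DataLoader workers
    nodesplitter=split_by_node,     # split shards across replicas, as suggested above
    shardshuffle=True,
)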

@harpone

harpone commented Jan 18, 2021

Ah, the dev version seems to do things a bit differently... need to try that out. I was using the stable version (can't remember the version number, but the pypi one) and that seems to require setting the wds.Dataset().shard_selection and wds.Dataset().shard_shuffle (not really a shuffle but a split) explicitly. See lines 69 to 96 in the gist: https://gist.github.com/harpone/3b6003c22295a50cbd3d2cfc566dc115

I was checking if the different minibatches were in fact unique or not, and if I switch off either one of the shard_selection or shard_shuffle, I get duplicates. This is a very subtle bug which leads to convergence issues and should probably be tracked somehow...

Couple of things to note that may have an effect:

  1. This is running in the ddp setting, i.e. one process per accelerator, n dataloader worker processes per accelerator process
  2. Using xmp.spawn(..., start_method='fork') instead of the pytorch default 'spawn' start method (I think this is preferred by torch-xla)

@tottenjordan
Author

Thank you @harpone. After exploring petastorm, I've started to focus on webdataset. I got simpler examples to work, but now I'm trying to apply it to the distributed training setting. The webdataset API version I am using has a slightly different approach for shard_selection and shard_shuffle, in that I am using splitter and nodesplitter in the wds.WebDataset() function (example here).

I posted this issue here and the error suggests I'm not properly configuring the IterableDataset replica at each worker. When you mentioned the issue with duplicates, did you get a similar error?

Error message:
2021-03-15 16:16:21 10.164.0.29 [0] /anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py:447: UserWarning: Length of IterableDataset <webdataset.dataset.ResizedDataset object at 0x7f5f978c2438> was reported to be 78 (when accessing len(dataloader)), but 205 samples have been fetched. For multiprocessing data-loading, this could be caused by not properly configuring the IterableDataset replica at each worker. Please see https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset for examples.

@tottenjordan
Author

I have a working implementation of Webdataset which is achieving epoch times of ~40s. Loss is dropping as expected and validation accuracy seems to be on par with other configurations.

The data is now stored in POSIX tar files on GCS and I'm using webdataset to retrieve each shard, shuffle shards and samples, and deliver to the PyTorch Dataloader as usual. Essentially replacing DistributedSampler() with wds.WebDataset() like this:

def make_train_loader(img_dim, shuffle=10000, batch_size=FLAGS.batch_size):
    
    num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
    epoch_size = trainsize // num_dataset_instances
    num_batches = epoch_size // batch_size

    image_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(img_dim),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )
    dataset = (
        wds.WebDataset("pipe:gsutil cat gs:https://$BUCKET/shards/imagenet-train-{000000..001281}.tar", 
        splitter=wds.split_by_worker, nodesplitter=my_node_splitter, shardshuffle=True, length=num_batches) 
        .shuffle(shuffle)
        .decode("pil") # handler=wds.warn_and_continue
        .to_tuple("ppm;jpg;jpeg;png", "cls")
        .map_tuple(image_transform, identity)
        .batched(batch_size)
        )

    loader = torch.utils.data.DataLoader(dataset, batch_size=None, shuffle=False, num_workers=FLAGS.num_workers)
    return loader

Because this is now an IterableDataset (not map-style), I needed to adjust the training and validation loops to look something like this:

def repeatedly(loader, nepochs=999999999, nbatches=999999999999):
    """Repeatedly returns batches from a DataLoader."""
    for epoch in range(nepochs):
        for sample in islice(loader, nbatches):
            yield sample 
...

def train_imagenet():
    print('==> Preparing data..')

...

    def train_loop_fn(loader, epoch):
        num_dataset_instances = xm.xrt_world_size() * FLAGS.num_workers
        epoch_size = trainsize // num_dataset_instances
        num_batches = epoch_size // FLAGS.batch_size
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(islice(repeatedly(loader), 0, num_batches)):
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, epoch, writer))

However, once training reaches the last of my specified epochs, I receive a BrokenPipe error like this:

2021-03-16 13:49:42 10.164.0.61 [1] Exception ignored in: <_io.TextIOWrapper name='<stdout>' mode='w' encoding='UTF-8'>
2021-03-16 13:49:42 10.164.0.61 [1] BrokenPipeError: [Errno 32] Broken pipe

2021-03-16 13:49:42 10.164.0.61 [1] The above exception was the direct cause of the following exception:
2021-03-16 13:49:42 10.164.0.61 [1] 
2021-03-16 13:49:42 10.164.0.61 [1] Traceback (most recent call last):
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/threading.py", line 916, in _bootstrap_inner
2021-03-16 13:49:42 10.164.0.61 [1]     self.run()
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/threading.py", line 864, in run
2021-03-16 13:49:42 10.164.0.61 [1]     self._target(*self._args, **self._kwargs)
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/parallel_loader.py", line 141, in _loader_worker
2021-03-16 13:49:42 10.164.0.61 [1]     _, data = next(data_iter)
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
2021-03-16 13:49:42 10.164.0.61 [1]     data = self._next_data()
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1068, in _next_data
2021-03-16 13:49:42 10.164.0.61 [1]     idx, data = self._get_data()
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1034, in _get_data
2021-03-16 13:49:42 10.164.0.61 [1]     success, data = self._try_get_data()
2021-03-16 13:49:42 10.164.0.61 [1]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 885, in _try_get_data
2021-03-16 13:49:42 10.164.0.61 [1]     raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
2021-03-16 13:49:42 10.164.0.61 [1] RuntimeError: DataLoader worker (pid(s) 50847, 51213, 52225, 54052) exited unexpectedly

Reading the Multiprocessing Shutdown Logic in the docs here, it seems like I may not be exiting the iterator gracefully once it's depleted.

Any ideas I should pursue?

Full code script and metrics report are attached:

torchXLA-webdataset-trial2-metrics-debug-BrokenPipe.txt

test_train_mp_imagenet_wds.txt

@zcain117

@zcain117
Collaborator

3 things:

  1. Does this setup work on CPUs / GPUs?
  2. I see .batched(batch_size) when creating the WebDataset but then loader = torch.utils.data.DataLoader(dataset, batch_size=None, shuffle=False, num_workers=FLAGS.num_workers). Have you tried batch_size=batch_size when making the DataLoader? Or is there a reason to leave batch_size=None for your case?
  3. Have you tried drop_last=True ?

@tmbdev

tmbdev commented Mar 16, 2021

I'm updating the distributed training examples; I hope I can push those out this week.

Please also see my comments on webdataset/webdataset#47

@tottenjordan
Author

  1. Have not tried CPU or GPU
  2. When batch_size=batch_size in the DataLoader, training hangs before epoch 1 and never begins. Also, webdataset recommends doing the batching in the dataset and leaving batch_size=None in the DataLoader
  3. Yes, tried drop_last=True, but the error message said it can't be used with batch_size=None in the DataLoader. Trying both of these gave me the answer to (2)

Does the RuntimeError: DataLoader worker (pid(s) 51472) exited unexpectedly error suggest there is "leftover" data in the DataLoader or its queue?

@harpone

harpone commented Mar 17, 2021

[quotes @tottenjordan's previous comment in full: the WebDataset implementation and the BrokenPipeError traceback]

Oops, forgot to reply yesterday... anyway, it seems you got the IterableDataset issue fixed. I never got exactly the same warnings as you, but got similar ones when I was explicitly setting the length kwarg in wds.

I think I got a similar BrokenPipeError at the end of each epoch when one of the "pipes" ran out of images (I think... I'm using pytorch lightning and it was with CUDA). I'm now simply stopping training when about 90% of the data is used from the epoch, which is of course not a very elegant solution.

@tottenjordan
Author

When trying to run 90 epochs, something is accumulating memory, and by epoch ~16 I get the following:

RuntimeError: [enforce fail at CPUAllocator.cpp:65] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 77070336 bytes. Error code 12 (Cannot allocate memory)

This might be related to my use of repeatedly(), which I originally chose to ensure each device/worker gets the same number of samples:

def repeatedly(loader, nepochs=999999999, nbatches=999999999999):
    """Repeatedly returns batches from a DataLoader."""
    for epoch in range(nepochs):
        for sample in islice(loader, nbatches):
            yield sample   

in the training and test loops:

def train_loop_fn(loader, epoch):
        train_steps = trainsize // (FLAGS.batch_size * xm.xrt_world_size())
        tracker = xm.RateTracker()
        model.train()
        for step, (data, target) in enumerate(islice(repeatedly(loader), 0, train_steps)): 
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, epoch, writer))

@zcain117 would anything in the metrics report help identify this problem?
metricsD16-wds-128bs-8wrk-90epoch-MemError.txt

@zcain117
Collaborator

I don't see that error in the logs you gave. Is that TPU OOM or an OOM on the python/VM side?

@tottenjordan
Author

Damn you're right. When training restarted, I must not have captured both reports. I'll try to recreate

@tottenjordan
Author

Not sure why the memory errors are not getting copied in my grep "10.164.15.214" /tmp/out-wds-2.log, but I attached the metrics-debug report (which didn't copy the memory errors for some reason) and metricsD17-MemErrors.txt, where I manually copied them from the shell. They start during the Epoch 21 train.

Also attached the training code for this run. Notice I didn't use repeatedly() and still got this error (5 epochs later).

Here is a screenshot of the memory allocation during training. Something is being accumulated during training; I suspect it's training data from the loaders, but I'm not sure whether anything in the metrics-debug report can validate this.

[screenshot: memory allocation during training]

2021-03-19 15:26:30 10.164.15.214 [0] ERROR: Unexpected bus error encountered in worker. This might be caused by insufficient shared memory (shm).
2021-03-19 15:26:30 10.164.15.214 [0] Exception in thread Thread-81:
2021-03-19 15:26:30 10.164.15.214 [0] Traceback (most recent call last):
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 872, in _try_get_data
2021-03-19 15:26:30 10.164.15.214 [0]     data = self._data_queue.get(timeout=timeout)
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/multiprocessing/queues.py", line 113, in get
2021-03-19 15:26:30 10.164.15.214 [0]     return _ForkingPickler.loads(res)
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 282, in rebuild_storage_fd
2021-03-19 15:26:30 10.164.15.214 [0]     fd = df.detach()
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/multiprocessing/resource_sharer.py", line 57, in detach
2021-03-19 15:26:30 10.164.15.214 [0]     with _resource_sharer.get_connection(self._id) as conn:
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/multiprocessing/resource_sharer.py", line 87, in get_connection
2021-03-19 15:26:30 10.164.15.214 [0]     c = Client(address, authkey=process.current_process().authkey)
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/multiprocessing/connection.py", line 487, in Client
2021-03-19 15:26:30 10.164.15.214 [0]     c = SocketClient(address)
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/multiprocessing/connection.py", line 614, in SocketClient
2021-03-19 15:26:30 10.164.15.214 [0]     s.connect(address)
2021-03-19 15:26:30 10.164.15.214 [0] ConnectionRefusedError: [Errno 111] Connection refused
2021-03-19 15:26:30 10.164.15.214 [0] 
2021-03-19 15:26:30 10.164.15.214 [0] The above exception was the direct cause of the following exception:
2021-03-19 15:26:30 10.164.15.214 [0] 
2021-03-19 15:26:30 10.164.15.214 [0] Traceback (most recent call last):
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/threading.py", line 916, in _bootstrap_inner
2021-03-19 15:26:30 10.164.15.214 [0]     self.run()
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/threading.py", line 864, in run
2021-03-19 15:26:30 10.164.15.214 [0]     self._target(*self._args, **self._kwargs)
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/parallel_loader.py", line 141, in _loader_worker
2021-03-19 15:26:30 10.164.15.214 [0]     _, data = next(data_iter)
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
2021-03-19 15:26:30 10.164.15.214 [0]     data = self._next_data()
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1068, in _next_data
2021-03-19 15:26:30 10.164.15.214 [0]     idx, data = self._get_data()
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1034, in _get_data
2021-03-19 15:26:30 10.164.15.214 [0]     success, data = self._try_get_data()
2021-03-19 15:26:30 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 885, in _try_get_data
2021-03-19 15:26:30 10.164.15.214 [0]     raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
2021-03-19 15:26:30 10.164.15.214 [0] RuntimeError: DataLoader worker (pid(s) 16087) exited unexpectedly

...

2021-03-19 15:26:31 10.164.15.210 [3] Exception in thread Thread-81:
2021-03-19 15:26:31 10.164.15.210 [3] Traceback (most recent call last):
2021-03-19 15:26:31 10.164.15.210 [3]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/threading.py", line 916, in _bootstrap_inner
2021-03-19 15:26:31 10.164.15.210 [3]     self.run()
2021-03-19 15:26:31 10.164.15.210 [3]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/threading.py", line 864, in run
2021-03-19 15:26:31 10.164.15.210 [3]     self._target(*self._args, **self._kwargs)
2021-03-19 15:26:31 10.164.15.210 [3]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch_xla/distributed/parallel_loader.py", line 141, in _loader_worker
2021-03-19 15:26:31 10.164.15.214 [0] Exception ignored in: <_io.TextIOWrapper name='<stdout>' mode='w' encoding='UTF-8'>
2021-03-19 15:26:31 10.164.15.210 [3]     _, data = next(data_iter)
2021-03-19 15:26:31 10.164.15.214 [0] BrokenPipeError: [Errno 32] Broken pipe
2021-03-19 15:26:31 10.164.15.210 [3]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
2021-03-19 15:26:31 10.164.15.210 [3]     data = self._next_data()
2021-03-19 15:26:31 10.164.15.210 [3]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1085, in _next_data
2021-03-19 15:26:31 10.164.15.210 [3]     return self._process_data(data)
2021-03-19 15:26:31 10.164.15.210 [3]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1111, in _process_data
2021-03-19 15:26:31 10.164.15.210 [3]     data.reraise()
2021-03-19 15:26:31 10.164.15.210 [3]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/_utils.py", line 428, in reraise
2021-03-19 15:26:31 10.164.15.210 [3]     raise self.exc_type(msg)
2021-03-19 15:26:31 10.164.15.210 [3] RuntimeError: Caught RuntimeError in DataLoader worker process 1.
2021-03-19 15:26:31 10.164.15.210 [3] Original Traceback (most recent call last):
2021-03-19 15:26:31 10.164.15.210 [3]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
2021-03-19 15:26:31 10.164.15.210 [3]     data = fetcher.fetch(index)
2021-03-19 15:26:31 10.164.15.210 [3]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 34, in fetch
2021-03-19 15:26:31 10.164.15.210 [3]     data = next(self.dataset_iter)
2021-03-19 15:26:31 10.164.15.210 [3]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/webdataset/iterators.py", line 358, in batched
2021-03-19 15:26:31 10.164.15.210 [3]     yield collation_fn(batch)
2021-03-19 15:26:31 10.164.15.210 [3]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/webdataset/iterators.py", line 332, in default_collation_fn
2021-03-19 15:26:31 10.164.15.210 [3]     b = torch.stack(list(b))
2021-03-19 15:26:31 10.164.15.210 [3] RuntimeError: [enforce fail at CPUAllocator.cpp:65] . DefaultCPUAllocator: can't allocate memory: you tried to allocate 77070336 bytes. Error code 12 (Cannot allocate memory)

...

2021-03-19 15:26:37 10.164.15.214 [0] Traceback (most recent call last):
2021-03-19 15:26:37 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/multiprocessing/queues.py", line 234, in _feed
2021-03-19 15:26:37 10.164.15.214 [0]     obj = _ForkingPickler.dumps(obj)
2021-03-19 15:26:37 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/multiprocessing/reduction.py", line 51, in dumps
2021-03-19 15:26:37 10.164.15.214 [0]     cls(buf, protocol).dump(obj)
2021-03-19 15:26:37 10.164.15.214 [0]   File "/anaconda3/envs/torch-xla-1.7/lib/python3.6/site-packages/torch/multiprocessing/reductions.py", line 321, in reduce_storage
2021-03-19 15:26:37 10.164.15.214 [0]     fd, size = storage._share_fd_()
2021-03-19 15:26:37 10.164.15.214 [0] RuntimeError: unable to write to file </torch_16114_3281789289>

metricsD17-MemErrors.txt

metricsD17-wds-128bs-8wrk-90epoch-MemError.txt

test-train-mp-wds-metricsD17.txt

@zcain117
Collaborator

Looks like your VM is running out of memory and seems unrelated to TPUs, so I don't think anything in this TPU metric report will help much

I guess the data is not being released after the epoch ends and the new data is loaded

In your code, you are using MpDeviceLoader, which is not directly a PyTorch DataLoader. It is a class that maintains its own queues and loads data for every call to __iter__. Here is where it loads data into the queue: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/parallel_loader.py#L134

I am not sure what is happening but maybe the WebDataset itself is not releasing memory or maybe the MpDeviceLoader is creating references that prevent old data from being cleaned up.

Some things to test:

  • is the islice necessary in the train/test loops (i.e. for step, (data, target) in enumerate(islice(loader, 0, train_steps)):) ? Maybe you could just use enumerate(loader) in train loop and test loop and you can break the loop if step == train_steps. I'm worried islice might have some weird interactions with the MpDeviceLoader
  • In the training loop, could you try something like sys.getsizeof(loader) to see if the loader object is growing each epoch?
  • Is there any way you could try this workflow with just a regular PyTorch WebDataset and avoid using MpDeviceLoader? You'd probably need to use GPUs or let it run a long time on CPUs to see the shape of memory usage graph

@harpone

harpone commented Mar 20, 2021

Oh yeah damn, I had a similar issue... you could track the number of running python processes to see if they're increasing at the beginning of each epoch. Definitely happened also with GPUs. I don't remember how I resolved that :(

@tottenjordan
Author

Hi @tmbdev , can you take a look at the memory errors cited above? I was thinking that perhaps the data loader was accumulating samples over training, but I implemented the following to check the loader size during each train epoch:

    def train_loop_fn(loader, epoch):
        train_steps = trainsize // (FLAGS.batch_size * xm.xrt_world_size())
        tracker = xm.RateTracker()
        total_samples = 0
        model.train()
        loader_size = sys.getsizeof(loader)
        for step, (data, target) in enumerate(loader): # repeatedly(loader) | enumerate(islice(loader, 0, train_steps))
            optimizer.zero_grad()
            output = model(data)
            loss = loss_fn(output, target)
            loss.backward()
            xm.optimizer_step(optimizer)
            tracker.add(FLAGS.batch_size)
            total_samples += data.size()[0]
            if lr_scheduler:
                lr_scheduler.step()
            if step % FLAGS.log_steps == 0:
                xm.add_step_closure(
                    _train_update, args=(device, step, loss, tracker, epoch, writer))
                print("Size of TRAIN loader:", loader_size)
            if step == train_steps:
                break                    

        return total_samples 
  • I removed the islice(loader...)
  • The memory footprint of the loader object (sys.getsizeof) stays constant between epochs at 56 bytes.
  • The number of samples cycled through the train and test loops is also consistent.
    • train samples: 39,944
    • test samples: 1,600
  • I also tried all of:
    • (1) .batched(batch_size) on the dataset, DataLoader(batch_size=None...)
    • (2) removed .batched(batch_size) from dataset, DataLoader(batch_size=batch_size...)
    • (3) .batched(batch_size) on the dataset, DataLoader(batch_size=batch_size...)

Any ideas what this could be?

@tottenjordan
Author

I noticed that switching to spawn in xmp.spawn(_mp_fn, args=(FLAGS,), nprocs=FLAGS.num_cores, start_method='spawn') reduced the rate at which memory is utilized over training. Was able to get to 31 epochs vs 20 epochs with fork

[screenshot: memory utilization over training]

metricsD19-wds-128bs-8wrks-90epoch-MemError-31epochs.txt

@tmbdev

tmbdev commented Mar 24, 2021

We have been using WebDataset in workers for a long time and not seen any memory leaks.

WebDataset doesn't operate any differently whether it runs standalone, inside a worker, or on another node. If you suspect a memory leak in the worker process, you can simply run it standalone and see what happens to the memory:

# testing the WebDataset pipeline for memory leaks
import webdataset as wds
dataset = wds.WebDataset(...)
count = 0
for epoch in range(100):
    for batch in dataset:
        count += 1

If that doesn't run out of memory, then WebDataset isn't leaking any memory. Running such a standalone loop should be pretty quick (since you're not doing any processing).

(Note that process memory usage can grow and stay high even if there is no memory leak, due to the way the Python garbage collector works, but growth should usually stop before running out of memory.)
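(A hedged way to watch that while iterating the dataset standalone, assuming psutil, which is not otherwise part of this thread's setup:)

import os

import psutil
import webdataset as wds

process = psutil.Process(os.getpid())
dataset = wds.WebDataset("pipe:gsutil cat gs:https://$BUCKET/shards/imagenet-train-{000000..001281}.tar")

for step, sample in enumerate(dataset):
    if step % 1000 == 0:
        # Resident set size of this process; it should plateau if nothing is leaking.
        print(f"step {step}: rss = {process.memory_info().rss / 2**20:.0f} MiB")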

Error handling and recovery in subprocesses of DataLoader is messy, though, as is the queue handling. That's one of the reasons we developed github.com/nvlabs/tensorcom. It's not only more testable; it also gives you dynamically scalable and distributed preprocessing and data augmentation.

@tottenjordan
Author

tottenjordan commented Mar 25, 2021

Ok - a stupid, careless mistake was leading to the memory leakage. In the validation dataset/loader I was using splitter=None instead of splitter=my_worker_splitter.

It seems this issue occurs when we do not correctly split dataset instances across both nodes/devices and workers; I observed it when setting either or both of the worker and node splitters to None in the validation dataset/loader.

def my_node_splitter(urls):
    """Split urls_ correctly per accelerator node
    :param urls:
    :return: slice of urls_
    """
    rank=xm.get_ordinal()
    num_replicas=xm.xrt_world_size()

    urls_this = urls[rank::num_replicas]
    
    return urls_this

def my_worker_splitter(urls):
    """Split urls per worker
    Selects a subset of urls based on Torch get_worker_info.
    Used as a shard selection function in Dataset."""

    urls = [url for url in urls]

    assert isinstance(urls, list)

    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        wid = worker_info.id
        num_workers = worker_info.num_workers
        return urls[wid::num_workers]
    else:
        return urls

val_dataset = (
        wds.WebDataset("pipe:gsutil cat gs:https://$BUCKET/imagenet-val-{000000..000049}.tar", 
        splitter=my_worker_splitter, nodesplitter=my_node_splitter, shardshuffle=False, length=epoch_test_size) 

[screenshot]

Avg epoch training time is back to ~1:35... which is very close to the time when data is stored locally on PD/VM!

@stale

stale bot commented Jun 16, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

@stale stale bot added the stale Has not had recent activity label Jun 16, 2021
@stale stale bot closed this as completed Jun 26, 2021