New error using the new update. #359

Closed · jaoeded opened this issue Sep 13, 2021 · 18 comments · Fixed by #367

Comments

@jaoeded

jaoeded commented Sep 13, 2021

[2021-09-13 11:39:11,114] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed info: version=0.5.1, git-hash=unknown, git-branch=unknown
[2021-09-13 11:39:11,216] [INFO] [logging.py:68:log_dist] [Rank 0] initializing deepspeed groups
[2021-09-13 11:39:11,216] [INFO] [logging.py:68:log_dist] [Rank 0] initializing deepspeed model parallel group with size 1
[2021-09-13 11:39:11,216] [INFO] [logging.py:68:log_dist] [Rank 0] initializing deepspeed expert parallel group with size 1
[2021-09-13 11:39:11,217] [INFO] [logging.py:68:log_dist] [Rank 0] creating expert data parallel process group with ranks: [0]
[2021-09-13 11:39:11,217] [INFO] [logging.py:68:log_dist] [Rank 0] creating expert parallel process group with ranks: [0]
[2021-09-13 11:39:11,240] [INFO] [engine.py:198:init] DeepSpeed Flops Profiler Enabled: False
Traceback (most recent call last):
  File "train_dalle.py", line 497, in <module>
    config_params=deepspeed_config,
  File "/home/valterjordan/DALLE-pytorch/dalle_pytorch/distributed_backends/distributed_backend.py", line 152, in distribute
    **kwargs,
  File "/home/valterjordan/DALLE-pytorch/dalle_pytorch/distributed_backends/deepspeed_backend.py", line 162, in _distribute
    **kwargs,
  File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/deepspeed/__init__.py", line 141, in initialize
    config_params=config_params)
  File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 204, in __init__
    self.training_dataloader = self.deepspeed_io(training_data)
  File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/deepspeed/runtime/engine.py", line 1188, in deepspeed_io
    data_parallel_rank=data_parallel_rank)
  File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/deepspeed/runtime/dataloader.py", line 52, in __init__
    rank=data_parallel_rank)
  File "/home/valterjordan/miniconda3/envs/dalle_env/lib/python3.7/site-packages/torch/utils/data/distributed.py", line 87, in __init__
    self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore
TypeError: object of type 'Processor' has no len()

@lucidrains
Owner

should be fixed in the latest commit!

@jaoeded
Author

jaoeded commented Sep 13, 2021

The issue still persists but disappears when downgrading webdataset, which is really strange.

@afiaka87
Contributor

Still getting the issue as well. @jaoeded what version of webdataset did you use to fix it?

@afiaka87
Contributor

afiaka87 commented Sep 25, 2021

I guess we can just monkey patch the __len__ dunder method...
#366
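
A minimal sketch of the monkey-patch idea (not the exact change in #366; the shard pattern and DATASET_SIZE are made-up placeholders). Since len() looks up __len__ on the type, the patch has to go on the dataset's class rather than the instance:

```python
import webdataset as wds

DATASET_SIZE = 1_000_000  # hypothetical, user-supplied sample count
dataset = wds.WebDataset("shards-{000000..000099}.tar")  # placeholder shards

# Patch the class so code like DistributedSampler can call len(dataset).
type(dataset).__len__ = lambda self: DATASET_SIZE

assert len(dataset) == DATASET_SIZE
```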

@afiaka87
Contributor

microsoft/DeepSpeed#1371

@rom1504
Contributor

rom1504 commented Sep 25, 2021

just a note:
deepspeed supports IterableDataset https://github.com/microsoft/DeepSpeed/blob/86dd6a6484a4c3aa8a04fc7e7e6c67652b09dad5/deepspeed/runtime/engine.py#L1141
webdataset exposes an IterableDataset https://github.com/webdataset/webdataset

Iterable datasets do not have a __len__ method, only __iter__.

So I'm wondering whether the way webdataset and DeepSpeed are used together here is incorrect in some way.
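
To make the mismatch concrete, here is a minimal, self-contained illustration (not project code; TinyIterable is a made-up stand-in for WebDataset's pipeline) of why DistributedSampler fails on any IterableDataset:

```python
from torch.utils.data import IterableDataset
from torch.utils.data.distributed import DistributedSampler

class TinyIterable(IterableDataset):
    """A stand-in for WebDataset's pipeline: it defines only __iter__."""
    def __iter__(self):
        yield from range(10)

ds = TinyIterable()
try:
    # DistributedSampler computes num_samples from len(dataset) in __init__,
    # which is exactly the call that fails in the traceback above.
    DistributedSampler(ds, num_replicas=1, rank=0)
except TypeError as err:
    print(err)  # -> object of type 'TinyIterable' has no len()
```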

@rom1504
Contributor

rom1504 commented Sep 25, 2021

For example, this call https://github.com/lucidrains/DALLE-pytorch/blob/main/train_dalle.py#L392 to torch.utils.data.distributed.DistributedSampler seems suspicious and unrelated to DeepSpeed.

@afiaka87
Contributor

@rom1504 All I know is that maintaining a DeepSpeed-compatible codebase has been an utter nightmare since day one. Interop with DeepSpeed breaks something fairly frequently, so my motivation to fix these things "properly" is pretty much non-existent.

@afiaka87
Contributor

afiaka87 commented Sep 25, 2021

I agree that it likely has something to do with the data sampler, but I didn't want to just remove it, since I believe it's there explicitly to handle the multi-GPU scenario with DeepSpeed. @janEbert, I would love some background on this if you have the time.

@rom1504
Contributor

rom1504 commented Sep 25, 2021

I guess @robvanvolt might be interested in having a look at this code, since it's related to his DeepSpeed issue.

@janEbert
Contributor

janEbert commented Sep 25, 2021

Hey, the DistributedSampler is indeed unrelated to DeepSpeed; it's actually for Horovod. I should have documented this, sorry about that.

I have an intuition about the issue here (something about the dataset returned by DeepSpeed colliding with WebDatasets), but I need to see whether I can fix it tomorrow (I'm also not very up-to-date with the code base). Otherwise, after Wednesday is the earliest I can get to it.

Sorry for the brevity and thanks for the ping!

@janEbert
Contributor

I had a quick look at this; I'm not yet familiar with WebDatasets, so maybe you can answer this more easily.
Why is it important to use a wds.WebLoader? Can't we pass the wds.WebDataset to distr_backend.distribute and let DeepSpeed handle data loading with its distr_dl (removing this if-branch accordingly)?

Sorry, I know you probably answered this already during all the testing and implementing.
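
For readers unfamiliar with the library, this is roughly what a WebDataset + WebLoader pipeline looks like (a hedged sketch; the shard pattern, keys, and batch size are invented, not taken from train_dalle.py):

```python
import webdataset as wds

# WebDataset is an IterableDataset built from tar shards (placeholder pattern).
dataset = (
    wds.WebDataset("shards-{000000..000099}.tar")
    .decode("rgb")            # decode image bytes to arrays
    .to_tuple("jpg", "txt")   # yield (image, caption) pairs
)

# WebLoader wraps torch's DataLoader but keeps the iterable-style API,
# so the loader, like the dataset, has no __len__ either.
loader = wds.WebLoader(dataset, batch_size=8, num_workers=4)
```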

@janEbert
Contributor

janEbert commented Sep 26, 2021

The problem is that PyTorch's sampling strategy does not work with IterableDatasets; see the open issue here: pytorch/pytorch#28743. DeepSpeed tries to apply a DistributedSampler when it is passed a dataset during initialization, and WebDatasets are IterableDatasets.
So to fix this, the only change we need is to pass None here when ENABLE_WEBDATASET is True, forgoing the (anyway redundant) DeepSpeed sampling wrapper that causes the error; see the sketch below.

I tried this on the supercomputer but cannot get a single iteration from the WebLoader, so could someone else please test it as well.
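
A rough sketch of that change against deepspeed.initialize (an illustrative fragment: args, model, dataset, deepspeed_config, and ENABLE_WEBDATASET are assumed to exist as in the training script, and the surrounding code is not reproduced):

```python
import deepspeed

model_engine, optimizer, distr_dl, scheduler = deepspeed.initialize(
    args=args,
    model=model,
    model_parameters=model.parameters(),
    # When the dataset is a WebDataset (an IterableDataset), pass None so that
    # deepspeed_io() never wraps it in a DistributedSampler; batching is done
    # by the separately constructed WebLoader instead.
    training_data=None if ENABLE_WEBDATASET else dataset,
    config_params=deepspeed_config,
)
```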

janEbert added a commit to janEbert/DALLE-pytorch that referenced this issue Sep 26, 2021
Causes errors due to PyTorch's `torch.utils.data.distributed.DistributedSampler`
not being applicable to `torch.utils.data.IterableDataset`s (which WebDatasets
implement).

Fix lucidrains#359.
janEbert added a commit to janEbert/DALLE-pytorch that referenced this issue Sep 26, 2021
Wrapping causes errors due to PyTorch's
`torch.utils.data.distributed.DistributedSampler` not being applicable to
`torch.utils.data.IterableDataset`s (which WebDatasets implement).

Fix lucidrains#359.
@jaoeded
Author

jaoeded commented Sep 26, 2021

Is this issue safe to close now?

@jaoeded
Author

jaoeded commented Sep 26, 2021

@janEbert @afiaka87 @rom1504 I'm closing this issue; janEbert's PR seems to have fixed it. Reopen if needed.

@jaoeded jaoeded closed this as completed Sep 26, 2021
@jaoeded
Author

jaoeded commented Sep 26, 2021

One last note: feel free to try the PR yourselves. It worked for me, but if it does not work for you, feel free to reopen.

lucidrains pushed a commit that referenced this issue Sep 26, 2021
Wrapping causes errors due to PyTorch's
`torch.utils.data.distributed.DistributedSampler` not being applicable to
`torch.utils.data.IterableDataset`s (which WebDatasets implement).

Fix #359.