Skip to content

Commit

Permalink
Replace references to _DataLoaderIter with _BaseDataLoaderIter (pytor…
Browse files Browse the repository at this point in the history
…ch#27105)

Summary:
Back in April, malmaud added type annotations for `dataloader.py`. However, at about the same time, SsnL in pytorch#19228 replaced `_DataLoaderIter` with `_BaseDataLoaderIter` and two subclasses, `_SingleProcessDataLoaderIter`, and `_MultiProcessingDataLoaderIter`. However - probably because these changes happened in parallel at roughly the same time, the type stubs and several other references in the codebase were never updated to match this refactoring.

I've gone ahead and done the updates to reflect the refactoring in pytorch#19228, which fixes the specific type stub/impelementation mismatch pointed out in pytorch#26673, although not the broader problem that pytorch doesn't have a test to make sure that the `.pyi` type stub files match the real API defined in `.py` files.
Pull Request resolved: pytorch#27105

Differential Revision: D17813641

Pulled By: ezyang

fbshipit-source-id: ed7ac025c8d6ad3f298dd073347ec83bb4b6600c
  • Loading branch information
ngoldbaum authored and Thiago Crepaldi committed Feb 4, 2020
1 parent e1fc304 commit 3f81566
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 14 deletions.
4 changes: 2 additions & 2 deletions torch/csrc/DataLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ static PyObject *THPModule_setWorkerPIDs(PyObject *module, PyObject *args) {
}
int64_t key = THPUtils_unpackLong(PyTuple_GET_ITEM(args, 0));
if (worker_pids.find(key) != worker_pids.end()) {
throw ValueError("_set_worker_pids should be called only once for each _DataLoaderIter.");
throw ValueError("_set_worker_pids should be called only once for each _BaseDataLoaderIter.");
}
PyObject *child_pids = PyTuple_GET_ITEM(args, 1);
if (!PyTuple_Check(child_pids)) {
Expand All @@ -182,7 +182,7 @@ static PyObject *THPModule_removeWorkerPIDs(PyObject *module, PyObject *loader_i
int64_t key = THPUtils_unpackLong(loader_id);
auto it = worker_pids.find(key);
if (it == worker_pids.end()) {
throw ValueError("Cannot find worker information for _DataLoaderIter with id %ld.", key);
throw ValueError("Cannot find worker information for _BaseDataLoaderIter with id %ld.", key);
}
worker_pids.erase(it);

Expand Down
2 changes: 1 addition & 1 deletion torch/utils/data/_utils/collate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
r""""Contains definitions of the methods used by the _DataLoaderIter workers to
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).
These **needs** to be in global scope since Py2 doesn't support serializing
Expand Down
2 changes: 1 addition & 1 deletion torch/utils/data/_utils/fetch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
r""""Contains definitions of the methods used by the _DataLoaderIter to fetch
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to fetch
data from an iterable-style or map-style dataset. This logic is shared in both
single- and multi-processing data loading.
"""
Expand Down
2 changes: 1 addition & 1 deletion torch/utils/data/_utils/pin_memory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
r""""Contains definitions of the methods used by the _DataLoaderIter to put
r""""Contains definitions of the methods used by the _BaseDataLoaderIter to put
fetched tensors into pinned memory.
These **needs** to be in global scope since Py2 doesn't support serializing
Expand Down
4 changes: 2 additions & 2 deletions torch/utils/data/_utils/signal_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
our best effort to provide some error message to users when such unfortunate
events happen.
When a _DataLoaderIter starts worker processes, their pids are registered in a
defined in `DataLoader.cpp`: id(_DataLoaderIter) => Collection[ Worker pids ]
When a _BaseDataLoaderIter starts worker processes, their pids are registered in a
defined in `DataLoader.cpp`: id(_BaseDataLoaderIter) => Collection[ Worker pids ]
via `_set_worker_pids`.
When an error happens in a worker process, the main process received a SIGCHLD,
Expand Down
2 changes: 1 addition & 1 deletion torch/utils/data/_utils/worker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
r""""Contains definitions of the methods used by the _DataLoaderIter workers.
r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers.
These **needs** to be in global scope since Py2 doesn't support serializing
static methods.
Expand Down
2 changes: 1 addition & 1 deletion torch/utils/data/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
r"""Definition of the DataLoader and it's iterator _DataLoaderIter classes.
r"""Definition of the DataLoader and associated iterators that subclass _BaseDataLoaderIter
To support these two classes, in `./_utils` we define many utility methods and
functions to be run in multiprocessing. E.g., the data loading worker loop is
Expand Down
12 changes: 7 additions & 5 deletions torch/utils/data/dataloader.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ class DataLoader(Generic[T_co]):
worker_init_fn: _worker_init_fn_t=...) -> None: ...

def __len__(self) -> int: ...
# We quote '_DataLoaderIter' since it isn't defined yet and the definition can't be moved up since
# '_DataLoaderIter' references 'DataLoader'. Pending updates of PEP 484 will fix this.
def __iter__(self) -> '_DataLoaderIter':...
# We quote '_BaseDataLoaderIter' since it isn't defined yet and the definition can't be moved up
# since '_BaseDataLoaderIter' references 'DataLoader'. In mypy 0.720 and newer a new semantic
# analyzer is used that obviates the need for this but we leave the quoting in to support older
# versions of mypy
def __iter__(self) -> '_BaseDataLoaderIter':...

class _DataLoaderIter:
class _BaseDataLoaderIter:
def __init__(self, loader: DataLoader) -> None:...
def __len__(self) -> int: ...
def __iter__(self) -> _DataLoaderIter: ...
def __iter__(self) -> _BaseDataLoaderIter: ...
def __next__(self) -> Any: ...

0 comments on commit 3f81566

Please sign in to comment.