-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Data V2 #3700
Data V2 #3700
Changes from 1 commit
0c42cb9
5ffedfc
80049f8
6f58c2a
1b3ad9a
effc445
9d44ad6
7e89ea6
883b6d7
56d022a
01e12f5
5aea291
f026946
5973b50
a23f47a
a76ea0a
ebf3854
57a67e5
7d21ed8
24a500c
fb13769
ef5187f
b1ea845
0231616
01d76bb
12b6efb
859d3ca
c22dee3
fe5b470
7533c91
ad45659
f944840
3b12a2f
be1f58c
ebdabe0
8693739
b9b0650
88314c7
14296a1
8a08899
3520280
0f1d8a4
93e1e89
40dd695
da3b1b4
801a8f5
007fd0c
d00e1a9
c066804
61c7b14
2b56b14
354010a
568291d
a016103
47db16a
04fdb70
d1d5c4a
5f0c8db
6f63a53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -5,6 +5,7 @@ | |||||||||||||||||||||||||||||||||||
from allennlp.common.registrable import Registrable | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
from allennlp.common.util import add_noise_to_dict_values, lazy_groups_of | ||||||||||||||||||||||||||||||||||||
from allennlp.common.lazy import Lazy | ||||||||||||||||||||||||||||||||||||
from allennlp.data.batch import Batch as AllennlpBatch | ||||||||||||||||||||||||||||||||||||
from allennlp.data.instance import Instance | ||||||||||||||||||||||||||||||||||||
from allennlp.data.vocabulary import Vocabulary | ||||||||||||||||||||||||||||||||||||
|
@@ -30,7 +31,7 @@ def __iter__(self) -> Iterable[List[int]]: | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
@Sampler.register("sequential") | ||||||||||||||||||||||||||||||||||||
class SequentialSampler(Sampler, data.SequentialSampler): | ||||||||||||||||||||||||||||||||||||
def __init__(self, data_source: data.Dataset): | ||||||||||||||||||||||||||||||||||||
def __init__(self, data_source: data.Dataset, **kwargs): | ||||||||||||||||||||||||||||||||||||
super().__init__(data_source) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
@@ -47,7 +48,7 @@ class RandomSampler(Sampler, data.RandomSampler): | |||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||
self, data_source: data.Dataset, replacement: bool = False, num_samples: int = None | ||||||||||||||||||||||||||||||||||||
self, data_source: data.Dataset, replacement: bool = False, num_samples: int = None, **kwargs | ||||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||
super().__init__(data_source, replacement, num_samples) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
@@ -60,7 +61,7 @@ class SubsetRandomSampler(Sampler, data.SubsetRandomSampler): | |||||||||||||||||||||||||||||||||||
indices (sequence): a sequence of indices | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def __init__(self, indices: List[int]): | ||||||||||||||||||||||||||||||||||||
def __init__(self, indices: List[int], **kwargs): | ||||||||||||||||||||||||||||||||||||
super().__init__(indices) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
@@ -82,7 +83,7 @@ class WeightedRandomSampler(Sampler, data.WeightedRandomSampler): | |||||||||||||||||||||||||||||||||||
[0, 1, 4, 3, 2] | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def __init__(self, weights: List[float], num_samples: int, replacement: bool = True): | ||||||||||||||||||||||||||||||||||||
def __init__(self, weights: List[float], num_samples: int, replacement: bool = True, **kwargs): | ||||||||||||||||||||||||||||||||||||
super().__init__(weights, num_samples, replacement) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
@@ -103,15 +104,15 @@ class BasicBatchSampler(BatchSampler, data.BatchSampler): | |||||||||||||||||||||||||||||||||||
[[0, 1, 2], [3, 4, 5], [6, 7, 8]] | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool): | ||||||||||||||||||||||||||||||||||||
def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool, **kwargs): | ||||||||||||||||||||||||||||||||||||
super().__init__(sampler, batch_size, drop_last) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
@BatchSampler.register("bucket") | ||||||||||||||||||||||||||||||||||||
class BatchInstanceSampler(BatchSampler): | ||||||||||||||||||||||||||||||||||||
def __init__( | ||||||||||||||||||||||||||||||||||||
self, | ||||||||||||||||||||||||||||||||||||
data: data.Dataset, | ||||||||||||||||||||||||||||||||||||
data_source: data.Dataset, | ||||||||||||||||||||||||||||||||||||
batch_size: int, | ||||||||||||||||||||||||||||||||||||
sorting_keys: List[Tuple[str, str]] = None, | ||||||||||||||||||||||||||||||||||||
padding_noise: float = 0.1, | ||||||||||||||||||||||||||||||||||||
|
@@ -121,7 +122,7 @@ def __init__( | |||||||||||||||||||||||||||||||||||
self._sorting_keys = sorting_keys | ||||||||||||||||||||||||||||||||||||
self._padding_noise = padding_noise | ||||||||||||||||||||||||||||||||||||
self._batch_size = batch_size | ||||||||||||||||||||||||||||||||||||
self.data = data | ||||||||||||||||||||||||||||||||||||
self.data_source = data_source | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def _argsort_by_padding(self, instances: List[Instance]) -> List[int]: | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
|
@@ -159,7 +160,7 @@ def _argsort_by_padding(self, instances: List[Instance]) -> List[int]: | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def __iter__(self) -> Iterable[List[int]]: | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
indices = self._argsort_by_padding(self.data) | ||||||||||||||||||||||||||||||||||||
indices = self._argsort_by_padding(self.data_source) | ||||||||||||||||||||||||||||||||||||
for group in lazy_groups_of(indices, self._batch_size): | ||||||||||||||||||||||||||||||||||||
yield list(group) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
@@ -195,8 +196,8 @@ def __init__( | |||||||||||||||||||||||||||||||||||
dataset: data.Dataset, | ||||||||||||||||||||||||||||||||||||
batch_size: int = 1, | ||||||||||||||||||||||||||||||||||||
shuffle: bool = False, | ||||||||||||||||||||||||||||||||||||
sampler: Sampler = None, | ||||||||||||||||||||||||||||||||||||
batch_sampler: BatchSampler = None, | ||||||||||||||||||||||||||||||||||||
sampler: Lazy[Sampler] = None, | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Having this annotation here is not good. It makes it really hard for someone to just instantiate this themselves, because instead of just constructing the sampler and passing it to allennlp/allennlp/common/lazy.py Lines 15 to 18 in 0b68e8e
Instead, keep this constructor, but remove the allennlp/allennlp/training/trainer.py Lines 851 to 863 in 0b68e8e
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see what you mean - we should probably even make this class private, because you should never use it - you should just use the pytorch dataloader. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Making it private means changing type annotations in other code to a private object, which is also unfortunate. Probably better to just be really clear in the class docstring why this object exists (purely to get type annotations that will let us construct it using our |
||||||||||||||||||||||||||||||||||||
batch_sampler: Lazy[BatchSampler] = None, | ||||||||||||||||||||||||||||||||||||
num_workers: int = 0, | ||||||||||||||||||||||||||||||||||||
collate_fn=None, | ||||||||||||||||||||||||||||||||||||
pin_memory: bool = False, | ||||||||||||||||||||||||||||||||||||
|
@@ -207,12 +208,21 @@ def __init__( | |||||||||||||||||||||||||||||||||||
): | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
collate_fn = allennlp_collocate | ||||||||||||||||||||||||||||||||||||
if batch_sampler is not None: | ||||||||||||||||||||||||||||||||||||
batch_sampler_ = batch_sampler.construct(dataset=dataset) | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
batch_sampler_ = None | ||||||||||||||||||||||||||||||||||||
if sampler is not None: | ||||||||||||||||||||||||||||||||||||
sampler_ = sampler.construct(dataset=dataset) | ||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||
sampler_ = None | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
super().__init__( | ||||||||||||||||||||||||||||||||||||
dataset=dataset, | ||||||||||||||||||||||||||||||||||||
batch_size=batch_size, | ||||||||||||||||||||||||||||||||||||
shuffle=shuffle, | ||||||||||||||||||||||||||||||||||||
sampler=sampler, | ||||||||||||||||||||||||||||||||||||
batch_sampler=batch_sampler, | ||||||||||||||||||||||||||||||||||||
sampler=sampler_, | ||||||||||||||||||||||||||||||||||||
batch_sampler=batch_sampler_, | ||||||||||||||||||||||||||||||||||||
num_workers=num_workers, | ||||||||||||||||||||||||||||||||||||
collate_fn=collate_fn, | ||||||||||||||||||||||||||||||||||||
pin_memory=pin_memory, | ||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why does it return 1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah.... This is just moving previous behaviour from the iterator onto the dataset instead:
https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/data_iterator.py#L305
We rely in a couple of places that calling
len
on the iterator (or now, the dataloader) doesn't raise an error. In the case that you have anIterableDataset
and you calllen
, the pytorch dataloader actually spits out a warning - but we need actually calling it to not crash.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
allennlp/allennlp/data/iterators/data_iterator.py
Line 314 in 465224f
I see. Maybe you can add a comment there?