Data V2 #3700
Changes from 1 commit
The first file in the diff (the `allennlp_collate` function and the new registrable `DataLoader`):

```diff
@@ -1,7 +1,9 @@
+from typing import List
 import logging
 from torch.utils import data
 
 from allennlp.common.registrable import Registrable
 from allennlp.data.instance import Instance
 
+from allennlp.common.lazy import Lazy
 from allennlp.data.batch import Batch
```
```diff
@@ -19,19 +21,35 @@
 logger = logging.getLogger(__name__)
 
 
-def allennlp_collate(batch):
-    batch = Batch(batch)
+def allennlp_collate(instances: List[Instance]):
+    batch = Batch(instances)
     return batch.as_tensor_dict(batch.get_padding_lengths())
 
 
 class DataLoader(Registrable, data.DataLoader):
```
> **Review comment** (on `class DataLoader`): Why put this in …
""" | ||
A registrable version of the pytorch [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). | ||
The only reason this class exists is so that we can construct a DataLoader | ||
from a configuration file. Instead of using this class directly in python code, | ||
you should just use the pytorch dataloader with allennlp's custom collate function: | ||
|
||
``` | ||
from torch.utils.data import DataLoader | ||
|
||
from allennlp.data.samplers import allennlp_collate | ||
# Construct a dataloader directly for a dataset which contains allennlp | ||
# Instances which have _already_ been indexed. | ||
my_loader = DataLoader(dataset, batch_size=32, collate_fn=allennlp_collate) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
dataset: data.Dataset, | ||
batch_size: int = 1, | ||
shuffle: bool = False, | ||
sampler: Lazy[Sampler] = None, | ||
batch_sampler: Lazy[BatchSampler] = None, | ||
sampler: Sampler = None, | ||
batch_sampler: BatchSampler = None, | ||
num_workers: int = 0, | ||
collate_fn=None, | ||
pin_memory: bool = False, | ||
|
```diff
@@ -40,8 +58,37 @@ def __init__(
         worker_init_fn=None,
         multiprocessing_context: str = None,
     ):
+        super().__init__(
+            dataset=dataset,
+            batch_size=batch_size,
+            shuffle=shuffle,
+            sampler=sampler,
+            batch_sampler=batch_sampler,
+            num_workers=num_workers,
+            collate_fn=collate_fn,
+            pin_memory=pin_memory,
+            drop_last=drop_last,
+            timeout=timeout,
+            worker_init_fn=worker_init_fn,
+            multiprocessing_context=multiprocessing_context,
+        )
+
+    @classmethod
+    def from_partial_objects(
+        cls,
+        dataset: data.Dataset,
+        batch_size: int = 1,
+        shuffle: bool = False,
+        sampler: Lazy[Sampler] = None,
+        batch_sampler: Lazy[BatchSampler] = None,
+        num_workers: int = 0,
+        pin_memory: bool = False,
+        drop_last: bool = False,
+        timeout: int = 0,
+        worker_init_fn=None,
+        multiprocessing_context: str = None,
+    ) -> "DataLoader":
+
-        collate_fn = allennlp_collate
         if batch_sampler is not None:
             batch_sampler_ = batch_sampler.construct(data_source=dataset)
         else:
```
```diff
@@ -51,17 +98,23 @@ def __init__(
         else:
             sampler_ = None
 
-        super().__init__(
+        return cls(
             dataset=dataset,
             batch_size=batch_size,
             shuffle=shuffle,
             sampler=sampler_,
             batch_sampler=batch_sampler_,
             num_workers=num_workers,
-            collate_fn=collate_fn,
+            # NOTE: This default is different from the normal `None`.
+            # We assume that if you are using this class you are using an
+            # allennlp dataset of instances, which would require this.
+            collate_fn=allennlp_collate,
             pin_memory=pin_memory,
             drop_last=drop_last,
             timeout=timeout,
             worker_init_fn=worker_init_fn,
             multiprocessing_context=multiprocessing_context,
         )
 
 
+DataLoader.register("default", "from_partial_objects")(DataLoader)
```
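For context, here is a minimal sketch of how this registration would typically be exercised from a configuration file. The config keys and the `my_dataset` variable are illustrative assumptions, not taken from this diff (and, as the review thread at the bottom of this page notes, this also relies on a `default_implementation = "default"` class variable and a `constructor=` argument to `register`):

```python
from allennlp.common.params import Params

# Hypothetical usage sketch: `DataLoader` is the class defined in the diff
# above, and `my_dataset` is an already-indexed dataset of allennlp Instances.
loader = DataLoader.from_params(
    params=Params({"batch_size": 32, "sampler": {"type": "random"}}),
    dataset=my_dataset,
)

# Each batch is the tensor dict produced by `allennlp_collate`.
for batch in loader:
    ...
```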
The second file in the diff (the sampler classes):

```diff
@@ -35,13 +35,17 @@ def __iter__(self) -> Iterable[List[int]]:
 
 @Sampler.register("sequential")
 class SequentialSampler(Sampler, data.SequentialSampler):
```
> **Review comment** (on `class SequentialSampler`): Might be useful to have a docstring pointing to the pytorch classes for all of these.
""" | ||
A registerable version of pytorch's [SequentialSampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.SequentialSampler). | ||
""" | ||
def __init__(self, data_source: data.Dataset): | ||
super().__init__(data_source) | ||
|
||
|
||
@Sampler.register("random") | ||
class RandomSampler(Sampler, data.RandomSampler): | ||
""" | ||
A registerable version of pytorch's [RandomSampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.RandomSampler). | ||
Samples elements randomly. If without replacement, then sample from a shuffled dataset. | ||
If with replacement, then user can specify `num_samples` to draw. | ||
|
||
|
@@ -64,6 +68,7 @@ def __init__( | |
@Sampler.register("subset_random") | ||
class SubsetRandomSampler(Sampler, data.SubsetRandomSampler): | ||
""" | ||
A registerable version of pytorch's [SubsetRandomSampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.SubsetRandomSampler). | ||
Samples elements randomly from a given list of indices, without replacement. | ||
|
||
# Parameters | ||
|
```diff
@@ -78,7 +83,8 @@ def __init__(self, indices: List[int]):
 @Sampler.register("weighted_random")
 class WeightedRandomSampler(Sampler, data.WeightedRandomSampler):
     """
-    Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
+    A registrable version of pytorch's [WeightedRandomSampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler).
+    Samples elements from `[0,...,len(weights)-1]` with given probabilities (weights).
 
     # Parameters:
     weights : `List[float]`
 
@@ -106,6 +112,7 @@ def __init__(self, weights: List[float], num_samples: int, replacement: bool = T
 @BatchSampler.register("basic")
 class BasicBatchSampler(BatchSampler, data.BatchSampler):
     """
+    A registrable version of pytorch's [BatchSampler](https://pytorch.org/docs/stable/data.html#torch.utils.data.BatchSampler).
     Wraps another sampler to yield a mini-batch of indices.
 
     # Parameters
```
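As a quick illustration of what these `register` calls buy (the import path and the concrete argument values are assumptions based on allennlp's `Registrable` API, not taken from this diff):

```python
from allennlp.data.samplers import Sampler  # assumed import path

# Look up a sampler class by the name it was registered under, then
# construct it with the parameters documented in its docstring.
sampler_cls = Sampler.by_name("weighted_random")
sampler = sampler_cls(weights=[0.2, 0.7, 0.1], num_samples=2, replacement=True)
```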
```diff
@@ -134,10 +141,10 @@ def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool):
 class BatchInstanceSampler(BatchSampler):
     """
     A sampler which by default, argsorts batches with respect to the maximum input lengths `per
```
> **Review comment:** argsorts instances before batching? Not sure why the second sentence says "additionally", because you can't do an argsort without those padding keys.
```diff
-    batch`. Additionally, you can provide a list of field names and padding keys which the dataset
-    will be sorted by before doing this batching, causing inputs with similar length to be batched
-    together, making computation more efficient (as less time is wasted on padded elements of the
-    batch).
+    batch`. You can provide a list of field names and padding keys (or pass none, in which case they
+    will be inferred) which the dataset will be sorted by before doing this batching, causing inputs
+    with similar length to be batched together, making computation more efficient (as less time is
+    wasted on padded elements of the batch).
 
     # Parameters
 
```
```diff
@@ -154,15 +161,15 @@ class BatchInstanceSampler(BatchSampler):
         When you need to specify this yourself, you can create an instance from your dataset and
         call `Instance.get_padding_lengths()` to see a list of all keys used in your data. You
         should give one or more of those as the sorting keys here.
-    batch_size : int, required.
-        The size of each batch of instances yielded when calling the dataloader.
     padding_noise : float, optional (default=.1)
         When sorting by padding length, we add a bit of noise to the lengths, so that the sorting
         isn't deterministic. This parameter determines how much noise we add, as a percentage of
         the actual padding value for each instance.
 
-    Note that if you specify `max_instances_in_memory`, the first batch will only be the
-    biggest from among the first "max instances in memory" instances.
+    batch_size : int, optional, (default = 32)
+        The size of each batch of instances yielded when calling the iterator.
 
     """
```
> **Review comment:** Does this need `Registrable`, or just `FromParams`? Do you imagine someone having to pick a subclass of these?

> **Reply:** Oh, given my comment below, yes, it needs to be `Registrable`. And you need to have a line like this after the class definition: `DataLoader.register("default", constructor="from_partial_objects")(DataLoader)`. And you need a `default_implementation = "default"` line as a class variable.
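Putting the reviewer's two suggestions together on a toy class (a sketch of the suggested pattern using a hypothetical `Widget`, not the final PR code):

```python
from allennlp.common.registrable import Registrable

class Widget(Registrable):
    # The suggested class variable: `from_params` falls back to the
    # "default" registered constructor when a config omits "type".
    default_implementation = "default"

    def __init__(self, size: int = 1):
        self.size = size

    @classmethod
    def from_partial_objects(cls, size: int = 1) -> "Widget":
        return cls(size=size)

# The suggested registration: name the classmethod that should serve as
# the constructor for the "default" name.
Widget.register("default", constructor="from_partial_objects")(Widget)
```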