-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Data V2 #3700
Data V2 #3700
Changes from 1 commit
0c42cb9
5ffedfc
80049f8
6f58c2a
1b3ad9a
effc445
9d44ad6
7e89ea6
883b6d7
56d022a
01e12f5
5aea291
f026946
5973b50
a23f47a
a76ea0a
ebf3854
57a67e5
7d21ed8
24a500c
fb13769
ef5187f
b1ea845
0231616
01d76bb
12b6efb
859d3ca
c22dee3
fe5b470
7533c91
ad45659
f944840
3b12a2f
be1f58c
ebdabe0
8693739
b9b0650
88314c7
14296a1
8a08899
3520280
0f1d8a4
93e1e89
40dd695
da3b1b4
801a8f5
007fd0c
d00e1a9
c066804
61c7b14
2b56b14
354010a
568291d
a016103
47db16a
04fdb70
d1d5c4a
5f0c8db
6f63a53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
from typing import List, Iterable, Tuple, Dict, cast | ||
import logging | ||
from torch.utils import data | ||
|
||
from allennlp.common.util import add_noise_to_dict_values, lazy_groups_of | ||
from allennlp.data.instance import Instance | ||
from allennlp.data.samplers import BatchSampler | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@BatchSampler.register("bucket")
class BucketBatchSampler(BatchSampler):
    """
    A sampler which by default argsorts batches with respect to the maximum input lengths `per
    batch`. You can provide a list of field names and padding keys (or pass none, in which case
    they will be inferred) which the dataset will be sorted by before doing this batching, causing
    inputs with similar length to be batched together, making computation more efficient (as less
    time is wasted on padded elements of the batch).

    # Parameters

    data_source : `data.Dataset`, required
        The pytorch `Dataset` of allennlp Instances to bucket.
    batch_size : `int`, required
        The size of each batch of instances yielded when calling the dataloader.
    sorting_keys : `List[Tuple[str, str]]`, optional
        To bucket inputs into batches, we want to group the instances by padding length, so that
        we minimize the amount of padding necessary per batch. In order to do this, we need to
        know which fields need what type of padding, and in what order.

        Specifying the right keys for this is a bit cryptic, so if this is not given we try to
        auto-detect the right keys by iterating through a few instances up front, reading all of
        the padding keys and seeing which one has the longest length. We use that one for padding.
        This should give reasonable results in most cases.

        When you need to specify this yourself, you can create an instance from your dataset and
        call `Instance.get_padding_lengths()` to see a list of all keys used in your data. You
        should give one or more of those as the sorting keys here.
    padding_noise : `float`, optional (default = 0.1)
        When sorting by padding length, we add a bit of noise to the lengths, so that the sorting
        isn't deterministic. This parameter determines how much noise we add, as a percentage of
        the actual padding value for each instance.
    drop_last : `bool`, optional (default = False)
        If `True`, the sampler will drop the last batch if its size would be less than
        `batch_size`.
    """

    def __init__(
        self,
        data_source: data.Dataset,
        batch_size: int,
        sorting_keys: List[Tuple[str, str]] = None,
        padding_noise: float = 0.1,
        drop_last: bool = False,
    ):
        self.vocab = data_source.vocab
        self.sorting_keys = sorting_keys
        self.padding_noise = padding_noise
        self.batch_size = batch_size
        self.data_source = data_source
        self.drop_last = drop_last

    def _argsort_by_padding(self, instances: Iterable[Instance]) -> List[int]:
        """
        Argsorts the instances by their padding lengths, using the keys in
        `sorting_keys` (in the order in which they are provided). `sorting_keys`
        is a list of `(field_name, padding_key)` tuples.

        Returns the list of instance indices in sorted order.
        """
        if not self.sorting_keys:
            logger.info("No sorting keys given; trying to guess a good one")
            self._guess_sorting_keys(instances)
            logger.info(f"Using {self.sorting_keys} as the sorting keys")
        instances_with_lengths = []
        for instance in instances:
            # Make sure instance is indexed before calling .get_padding_lengths().
            instance.index_fields(self.vocab)
            padding_lengths = cast(Dict[str, Dict[str, float]], instance.get_padding_lengths())
            if self.padding_noise > 0.0:
                noisy_lengths = {}
                for field_name, field_lengths in padding_lengths.items():
                    noisy_lengths[field_name] = add_noise_to_dict_values(
                        field_lengths, self.padding_noise
                    )
                padding_lengths = noisy_lengths
            instance_with_lengths = (
                [
                    padding_lengths[field_name][padding_key]
                    for (field_name, padding_key) in self.sorting_keys
                ],
                instance,
            )
            instances_with_lengths.append(instance_with_lengths)
        with_indices = [(x, i) for i, x in enumerate(instances_with_lengths)]
        # Sorting by the list of (noisy) lengths is lexicographic, so ties on the
        # first sorting key fall through to the later ones.
        with_indices.sort(key=lambda x: x[0][0])
        return [instance_with_index[-1] for instance_with_index in with_indices]

    def __iter__(self) -> Iterable[List[int]]:
        indices = self._argsort_by_padding(self.data_source)
        for group in lazy_groups_of(indices, self.batch_size):
            batch_indices = list(group)
            if self.drop_last and len(batch_indices) < self.batch_size:
                continue
            yield batch_indices

    def _guess_sorting_keys(self, instances: Iterable[Instance], num_instances: int = 10) -> None:
        """
        Use `num_instances` instances from the dataset to infer the keys used
        for sorting the dataset for bucketing.

        # Parameters

        instances : `Iterable[Instance]`, required.
            The dataset to guess sorting keys for.
        num_instances : `int`, optional (default = 10)
            The number of instances to use to guess sorting keys. Typically
            the default value is completely sufficient, but if your instances
            are not homogeneous, you might need more.
        """
        max_length = 0.0
        longest_padding_key: Tuple[str, str] = None
        for i, instance in enumerate(instances):
            # Only use num_instances instances to guess the sorting keys.
            # Checking before processing (rather than `i > num_instances` after)
            # fixes an off-by-one that previously inspected num_instances + 2.
            if i >= num_instances:
                break
            instance.index_fields(self.vocab)
            padding_lengths = cast(Dict[str, Dict[str, float]], instance.get_padding_lengths())
            for field_name, field_padding in padding_lengths.items():
                for padding_key, length in field_padding.items():
                    if length > max_length:
                        max_length = length
                        longest_padding_key = (field_name, padding_key)

        if not longest_padding_key:
            # This shouldn't ever happen (you basically have to have an empty instance list), but
            # just in case...
            raise AssertionError(
                "Found no field that needed padding; we are surprised you got this error, please "
                "open an issue on github"
            )
        self.sorting_keys = [longest_padding_key]

    def __len__(self):
        # Report the number of batches __iter__ actually yields: the final
        # partial batch counts unless drop_last is set (plain floor division
        # previously under-counted by one whenever there was a remainder).
        full_batches, remainder = divmod(len(self.data_source), self.batch_size)
        if self.drop_last or remainder == 0:
            return full_batches
        return full_batches + 1
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,8 @@ | ||
from typing import List, Iterable, Tuple, Dict, cast | ||
import logging | ||
from typing import List, Iterable | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might be worth somewhere in here saying that you can just use the pytorch classes directly without issue if you aren't using |
||
from torch.utils import data | ||
|
||
from allennlp.common.registrable import Registrable | ||
|
||
from allennlp.common.util import add_noise_to_dict_values, lazy_groups_of | ||
from allennlp.data.instance import Instance | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Sampler(Registrable): | ||
""" | ||
|
@@ -141,138 +135,3 @@ class BasicBatchSampler(BatchSampler, data.BatchSampler): | |
|
||
def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool):
    # Thin passthrough: all batching behavior comes from the pytorch
    # `data.BatchSampler` parent (per the class declaration), which groups
    # indices drawn from `sampler` into batches of `batch_size`, dropping a
    # final short batch when `drop_last` is True — presumably the standard
    # torch semantics; confirm against torch.utils.data.BatchSampler docs.
    super().__init__(sampler, batch_size, drop_last)
|
||
|
||
@BatchSampler.register("bucket")
class BucketBatchSampler(BatchSampler):
    """
    A sampler which by default argsorts batches with respect to the maximum input lengths `per
    batch`. You can provide a list of field names and padding keys (or pass none, in which case
    they will be inferred) which the dataset will be sorted by before doing this batching, causing
    inputs with similar length to be batched together, making computation more efficient (as less
    time is wasted on padded elements of the batch).

    # Parameters

    batch_size : `int`, required
        The size of each batch of instances yielded when calling the dataloader.
    sorting_keys : `List[Tuple[str, str]]`, optional
        To bucket inputs into batches, we want to group the instances by padding length, so that
        we minimize the amount of padding necessary per batch. In order to do this, we need to
        know which fields need what type of padding, and in what order.

        Specifying the right keys for this is a bit cryptic, so if this is not given we try to
        auto-detect the right keys by iterating through a few instances up front, reading all of
        the padding keys and seeing which one has the longest length. We use that one for padding.
        This should give reasonable results in most cases.

        When you need to specify this yourself, you can create an instance from your dataset and
        call `Instance.get_padding_lengths()` to see a list of all keys used in your data. You
        should give one or more of those as the sorting keys here.
    padding_noise : `float`, optional (default = 0.1)
        When sorting by padding length, we add a bit of noise to the lengths, so that the sorting
        isn't deterministic. This parameter determines how much noise we add, as a percentage of
        the actual padding value for each instance.
    drop_last : `bool`, optional (default = False)
        If `True`, the sampler will drop the last batch if its size would be less than
        `batch_size`.
    """

    def __init__(
        self,
        data_source: data.Dataset,
        batch_size: int,
        sorting_keys: List[Tuple[str, str]] = None,
        padding_noise: float = 0.1,
        drop_last: bool = False,
    ):
        self.vocab = data_source.vocab
        self.sorting_keys = sorting_keys
        self.padding_noise = padding_noise
        self.batch_size = batch_size
        self.data_source = data_source
        self.drop_last = drop_last

    def _argsort_by_padding(self, instances: Iterable[Instance]) -> List[int]:
        """
        Argsorts the instances by their padding lengths, using the keys in
        `sorting_keys` (in the order in which they are provided). `sorting_keys`
        is a list of `(field_name, padding_key)` tuples.

        Returns the list of instance indices in sorted order.
        """
        if not self.sorting_keys:
            logger.info("No sorting keys given; trying to guess a good one")
            self._guess_sorting_keys(instances)
            logger.info(f"Using {self.sorting_keys} as the sorting keys")
        instances_with_lengths = []
        for instance in instances:
            # Make sure instance is indexed before calling .get_padding_lengths().
            instance.index_fields(self.vocab)
            padding_lengths = cast(Dict[str, Dict[str, float]], instance.get_padding_lengths())
            if self.padding_noise > 0.0:
                noisy_lengths = {}
                for field_name, field_lengths in padding_lengths.items():
                    noisy_lengths[field_name] = add_noise_to_dict_values(
                        field_lengths, self.padding_noise
                    )
                padding_lengths = noisy_lengths
            instance_with_lengths = (
                [
                    padding_lengths[field_name][padding_key]
                    for (field_name, padding_key) in self.sorting_keys
                ],
                instance,
            )
            instances_with_lengths.append(instance_with_lengths)
        with_indices = [(x, i) for i, x in enumerate(instances_with_lengths)]
        # Sorting by the list of (noisy) lengths is lexicographic, so ties on the
        # first sorting key fall through to the later ones.
        with_indices.sort(key=lambda x: x[0][0])
        return [instance_with_index[-1] for instance_with_index in with_indices]

    def __iter__(self) -> Iterable[List[int]]:
        indices = self._argsort_by_padding(self.data_source)
        for group in lazy_groups_of(indices, self.batch_size):
            batch_indices = list(group)
            if self.drop_last and len(batch_indices) < self.batch_size:
                continue
            yield batch_indices

    def _guess_sorting_keys(self, instances: Iterable[Instance], num_instances: int = 10) -> None:
        """
        Use `num_instances` instances from the dataset to infer the keys used
        for sorting the dataset for bucketing.

        # Parameters

        instances : `Iterable[Instance]`, required.
            The dataset to guess sorting keys for.
        num_instances : `int`, optional (default = 10)
            The number of instances to use to guess sorting keys. Typically
            the default value is completely sufficient, but if your instances
            are not homogeneous, you might need more.
        """
        max_length = 0.0
        longest_padding_key: Tuple[str, str] = None
        for i, instance in enumerate(instances):
            # Only use num_instances instances to guess the sorting keys.
            # Checking before processing (rather than `i > num_instances` after)
            # fixes an off-by-one that previously inspected num_instances + 2.
            if i >= num_instances:
                break
            instance.index_fields(self.vocab)
            padding_lengths = cast(Dict[str, Dict[str, float]], instance.get_padding_lengths())
            for field_name, field_padding in padding_lengths.items():
                for padding_key, length in field_padding.items():
                    if length > max_length:
                        max_length = length
                        longest_padding_key = (field_name, padding_key)

        if not longest_padding_key:
            # This shouldn't ever happen (you basically have to have an empty instance list), but
            # just in case...
            raise AssertionError(
                "Found no field that needed padding; we are surprised you got this error, please "
                "open an issue on github"
            )
        self.sorting_keys = [longest_padding_key]

    def __len__(self):
        # Report the number of batches __iter__ actually yields: the final
        # partial batch counts unless drop_last is set (plain floor division
        # previously under-counted by one whenever there was a remainder).
        full_batches, remainder = divmod(len(self.data_source), self.batch_size)
        if self.drop_last or remainder == 0:
            return full_batches
        return full_batches + 1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"auto-detect the right keys by iterating through a few instances up front" ?