-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Data V2 #3700
Data V2 #3700
Changes from 1 commit
0c42cb9
5ffedfc
80049f8
6f58c2a
1b3ad9a
effc445
9d44ad6
7e89ea6
883b6d7
56d022a
01e12f5
5aea291
f026946
5973b50
a23f47a
a76ea0a
ebf3854
57a67e5
7d21ed8
24a500c
fb13769
ef5187f
b1ea845
0231616
01d76bb
12b6efb
859d3ca
c22dee3
fe5b470
7533c91
ad45659
f944840
3b12a2f
be1f58c
ebdabe0
8693739
b9b0650
88314c7
14296a1
8a08899
3520280
0f1d8a4
93e1e89
40dd695
da3b1b4
801a8f5
007fd0c
d00e1a9
c066804
61c7b14
2b56b14
354010a
568291d
a016103
47db16a
04fdb70
d1d5c4a
5f0c8db
6f63a53
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
|
||
from typing import List, Iterable, Tuple, Dict, cast | ||
import logging | ||
from torch.utils import data | ||
|
@@ -18,27 +17,23 @@ | |
|
||
|
||
class Sampler(Registrable): | ||
|
||
def __iter__(self) -> Iterable[int]: | ||
|
||
raise NotImplementedError | ||
|
||
|
||
class BatchSampler(Registrable): | ||
|
||
def __iter__(self) -> Iterable[List[int]]: | ||
|
||
raise NotImplementedError | ||
|
||
|
||
@Sampler.register("sequential") | ||
class SequentialSampler(Sampler, data.SequentialSampler): | ||
|
||
def __init__(self, data_source: data.Dataset): | ||
super().__init__(data_source) | ||
|
||
|
||
|
||
@Sampler.register("random") | ||
class RandomSampler(Sampler, data.RandomSampler): | ||
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. | ||
|
@@ -50,7 +45,10 @@ class RandomSampler(Sampler, data.RandomSampler): | |
num_samples (int): number of samples to draw, default=`len(dataset)`. This argument | ||
is supposed to be specified only when `replacement` is ``True``. | ||
""" | ||
def __init__(self, data_source: data.Dataset, replacement: bool = False, num_samples: int = None): | ||
|
||
def __init__( | ||
self, data_source: data.Dataset, replacement: bool = False, num_samples: int = None | ||
): | ||
super().__init__(data_source, replacement, num_samples) | ||
|
||
|
||
|
@@ -61,6 +59,7 @@ class SubsetRandomSampler(Sampler, data.SubsetRandomSampler): | |
Arguments: | ||
indices (sequence): a sequence of indices | ||
""" | ||
|
||
def __init__(self, indices: List[int]): | ||
super().__init__(indices) | ||
|
||
|
@@ -82,6 +81,7 @@ class WeightedRandomSampler(Sampler, data.WeightedRandomSampler): | |
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False)) | ||
[0, 1, 4, 3, 2] | ||
""" | ||
|
||
def __init__(self, weights: List[float], num_samples: int, replacement: bool = True): | ||
super().__init__(weights, num_samples, replacement) | ||
|
||
|
@@ -189,17 +189,35 @@ def allennlp_collocate(batch): | |
return batch.as_tensor_dict(batch.get_padding_lengths()) | ||
|
||
|
||
|
||
|
||
class DataLoader(Registrable, data.DataLoader): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why put this in |
||
|
||
def __init__(self, dataset: data.Dataset, batch_size: int = 1, shuffle: bool = False, sampler: Sampler = None, | ||
batch_sampler: BatchSampler = None, num_workers: int = 0, collate_fn=None, | ||
pin_memory: bool = False, drop_last: bool = False, timeout: bool = 0, | ||
worker_init_fn=None, multiprocessing_context: str = None): | ||
def __init__( | ||
self, | ||
dataset: data.Dataset, | ||
batch_size: int = 1, | ||
shuffle: bool = False, | ||
sampler: Sampler = None, | ||
batch_sampler: BatchSampler = None, | ||
num_workers: int = 0, | ||
collate_fn=None, | ||
pin_memory: bool = False, | ||
drop_last: bool = False, | ||
timeout: bool = 0, | ||
worker_init_fn=None, | ||
multiprocessing_context: str = None, | ||
): | ||
|
||
collate_fn = allennlp_collocate | ||
super().__init__(self, dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | ||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | ||
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, | ||
worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context) | ||
super().__init__( | ||
dataset=dataset, | ||
batch_size=batch_size, | ||
shuffle=shuffle, | ||
sampler=sampler, | ||
batch_sampler=batch_sampler, | ||
num_workers=num_workers, | ||
collate_fn=collate_fn, | ||
pin_memory=pin_memory, | ||
drop_last=drop_last, | ||
timeout=timeout, | ||
worker_init_fn=worker_init_fn, | ||
multiprocessing_context=multiprocessing_context, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,7 +44,9 @@ def setUp(self): | |
self.model = SimpleTagger.from_params(vocab=self.vocab, params=self.model_params) | ||
self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01, momentum=0.9) | ||
self.data_loader = DataLoader(self.instances, batch_size=2, collate_fn=allennlp_collocate) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I got here and realized that I don't really understand how the data loader interacts with the samplers. Why do you have a batch size here, when you can have a batch sampler that sets a different batch size? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You can either pass a batch size, in which case it will use a uniform random sampler, or you can pass a |
||
self.validation_data_loader = DataLoader(self.instances, batch_size=2, collate_fn=allennlp_collocate) | ||
self.validation_data_loader = DataLoader( | ||
self.instances, batch_size=2, collate_fn=allennlp_collocate | ||
) | ||
self.instances.index_with(vocab) | ||
|
||
def test_trainer_can_run(self): | ||
|
@@ -102,9 +104,7 @@ def test_trainer_can_run_exponential_moving_average(self): | |
@pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.") | ||
def test_trainer_can_run_cuda(self): | ||
self.model.cuda() | ||
trainer = Trainer( | ||
self.model, self.optimizer, self.data_loader, num_epochs=2, cuda_device=0 | ||
) | ||
trainer = Trainer(self.model, self.optimizer, self.data_loader, num_epochs=2, cuda_device=0) | ||
metrics = trainer.train() | ||
assert "peak_cpu_memory_MB" in metrics | ||
assert isinstance(metrics["peak_cpu_memory_MB"], float) | ||
|
@@ -118,11 +118,7 @@ def test_passing_trainer_multiple_gpus_raises_error(self): | |
|
||
with pytest.raises(ConfigurationError): | ||
Trainer( | ||
self.model, | ||
self.optimizer, | ||
self.data_loader, | ||
num_epochs=2, | ||
cuda_device=[0, 1], | ||
self.model, self.optimizer, self.data_loader, num_epochs=2, cuda_device=[0, 1], | ||
) | ||
|
||
def test_trainer_can_resume_training(self): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
import re | ||
import time | ||
import traceback | ||
from typing import Dict, List, Optional, Tuple, Union, Iterable, Any | ||
from typing import Dict, List, Optional, Tuple, Union, Any | ||
|
||
import torch | ||
import torch.distributed as dist | ||
|
@@ -16,7 +16,6 @@ | |
from allennlp.common import Lazy, Tqdm | ||
from allennlp.common.checks import ConfigurationError, check_for_gpu | ||
from allennlp.common import util as common_util | ||
from allennlp.data.instance import Instance | ||
|
||
from allennlp.data.samplers import DataLoader | ||
|
||
|
@@ -36,7 +35,7 @@ | |
logger = logging.getLogger(__name__) | ||
|
||
|
||
@TrainerBase.register("trainer", constructor="from_partial_objects") | ||
@TrainerBase.register("default", constructor="from_partial_objects") | ||
class Trainer(TrainerBase): | ||
def __init__( | ||
self, | ||
|
@@ -512,9 +511,13 @@ def _validation_loss(self) -> Tuple[float, int]: | |
if self._validation_data_loader is not None: | ||
validation_data_loader = self._validation_data_loader | ||
else: | ||
raise ConfigurationError("Validation results cannot be calculated without a validation_data_loader") | ||
raise ConfigurationError( | ||
"Validation results cannot be calculated without a validation_data_loader" | ||
) | ||
|
||
val_generator_tqdm = Tqdm.tqdm(iter(validation_data_loader), total=len(validation_data_loader)) | ||
val_generator_tqdm = Tqdm.tqdm( | ||
iter(validation_data_loader), total=len(validation_data_loader) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it not work to just say There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh yep, nice 👍 |
||
) | ||
batches_this_epoch = 0 | ||
val_loss = 0 | ||
done_early = False | ||
|
@@ -814,8 +817,8 @@ def from_partial_objects( | |
cls, | ||
model: Model, | ||
serialization_dir: str, | ||
data_loader: Lazy[DataLoader], | ||
validation_data_loader: Lazy[DataLoader] = None, | ||
data_loader: DataLoader, | ||
validation_data_loader: DataLoader = None, | ||
local_rank: int = 0, | ||
patience: int = None, | ||
validation_metric: str = "-loss", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Does this need `Registrable`, or just `FromParams`? Do you imagine someone having to pick a subclass of these?
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Oh, given my comment below, yes, it needs to be `Registrable`. And you need to have a line like this after the class definition: `DataLoader.register("default", constructor="from_partial_objects")(DataLoader)`. And you need a `default_implementation = "default"` line as a class variable.