-
Notifications
You must be signed in to change notification settings - Fork 7
/
base_dataset.py
66 lines (50 loc) · 2.01 KB
/
base_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from collections.abc import Mapping

import torch
from torch.utils.data.dataset import Dataset
class BaseDataset(Dataset):
    """Base class for datasets; subclasses implement ``__len__`` and ``__getitem__``."""

    def __init__(self, config, dataset_type, dataset_name):
        """Base class for a dataset.

        Args:
            config: dataset config. Either a mapping (e.g. a ``dict``), read by
                key, or an attribute-style object, read via ``config.data_dir``.
                ``None`` is treated as an empty mapping.
            dataset_type: "train", "test" or "val".
            dataset_name: string of the dataset name.
        """
        super().__init__()
        if config is None:
            config = {}
        self.config = config
        self._dataset_name = dataset_name
        self._dataset_type = dataset_type
        # A dict (including the ``{}`` substituted for ``None`` above) has no
        # ``.data_dir`` attribute, so mappings are read by key; an absent key
        # yields ``None``. Attribute-style configs keep strict attribute access.
        if isinstance(config, Mapping):
            self._data_dir = config.get("data_dir")
        else:
            self._data_dir = config.data_dir

    def __len__(self):
        """Return the number of samples; must be implemented by subclasses."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """
        __getitem__ of a torch dataset; must be implemented by subclasses.

        Args:
            idx (int): Index of the sample to be loaded.
        """
        raise NotImplementedError
def custom_collate_fn(data):
    """Collate variable-length (audio, caption) pairs into zero-padded batch tensors.

    Samples are ordered by caption length, longest first (stable for ties).
    The caller's ``data`` sequence is not modified.

    Args:
        data: sequence of (audio, caption) tuples.
            - audio: 1-D tensor-like of variable length (assumed; TODO confirm).
            - caption: 1-D tensor-like of token ids, variable length.

    Returns:
        padded_audio: float32 tensor of shape (batch_size, max_audio_length).
        audio_lengths: int64 tensor of shape (batch_size,), unpadded audio lengths.
        padded_captions: int64 tensor of shape (batch_size, max_caption_length).
        cap_lengths: Python list of unpadded caption lengths.

    Raises:
        ValueError: if ``data`` is empty.
    """
    # sorted() rather than list.sort(): does not reorder the caller's batch and
    # also accepts tuples.
    batch = sorted(data, key=lambda sample: len(sample[1]), reverse=True)
    if not batch:
        raise ValueError("custom_collate_fn received an empty batch")
    audio_tracks, captions = zip(*batch)

    audio_lengths = [len(audio) for audio in audio_tracks]
    cap_lengths = [len(cap) for cap in captions]

    padded_audio = torch.zeros(
        len(audio_tracks), max(audio_lengths), dtype=torch.float32)
    padded_captions = torch.zeros(
        len(captions), max(cap_lengths), dtype=torch.long)
    for row, (audio, caption) in enumerate(zip(audio_tracks, captions)):
        padded_audio[row, :audio_lengths[row]] = audio
        padded_captions[row, :cap_lengths[row]] = caption

    # Build the int64 tensor directly: torch.Tensor(list) goes through float32,
    # which cannot represent every integer above 2**24 samples.
    audio_lengths = torch.tensor(audio_lengths, dtype=torch.long)
    return padded_audio, audio_lengths, padded_captions, cap_lengths