[s2t] DataLoader with BatchSampler or DistributedBatchSampler #1242

Merged
merged 2 commits on Dec 30, 2021
12 changes: 8 additions & 4 deletions paddlespeech/s2t/exps/u2_st/model.py
@@ -292,7 +292,8 @@ def setup_dataloader(self):
n_iter_processes=config.collator.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
num_encs=1)
num_encs=1,
dist_sampler=True)

self.valid_loader = BatchDataLoader(
json_file=config.data.dev_manifest,
@@ -313,7 +314,8 @@ def setup_dataloader(self):
n_iter_processes=config.collator.num_workers,
subsampling_factor=1,
load_aux_output=load_transcript,
num_encs=1)
num_encs=1,
dist_sampler=True)
logger.info("Setup train/valid Dataloader!")
else:
# test dataset, return raw text
@@ -335,7 +337,8 @@ def setup_dataloader(self):
augmentation_config, # aug will be off when train_mode=False
n_iter_processes=config.collator.num_workers,
subsampling_factor=1,
num_encs=1)
num_encs=1,
dist_sampler=False)

logger.info("Setup test Dataloader!")

@@ -542,7 +545,8 @@ def test(self):
len_refs += metrics['len_refs']
num_ins += metrics['num_ins']
rtf = num_time / (num_frames * stride_ms)
logger.info("RTF: %f, instance (%d), batch BELU = %f" % (rtf, num_ins, bleu))
logger.info("RTF: %f, instance (%d), batch BELU = %f" %
(rtf, num_ins, bleu))

rtf = num_time / (num_frames * stride_ms)
msg = "Test: "
27 changes: 16 additions & 11 deletions paddlespeech/s2t/io/converter.py
@@ -65,8 +65,9 @@ def __call__(self, batch):
# text data (output): (text_len, )
ys_data.append(ud)

assert xs_data[0][0] is not None, "please check Reader and Augmentation impl."

assert xs_data[0][
0] is not None, "please check Reader and Augmentation impl."

xs_pad, ilens = [], []
for xs in xs_data:
# perform subsampling
@@ -79,22 +80,26 @@
# perform padding and convert to tensor
# currently only support real number
xs_pad.append(pad_list(xs, 0).astype(self.dtype))

if not self.load_aux_input:
xs_pad, ilens = xs_pad[0], ilens[0]
break

# NOTE: this is for multi-output (e.g., speech translation)
ys_pad, olens = [], []

for ys in ys_data:
ys_pad.append(pad_list(
[np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],
self.ignore_id))
ys_pad.append(
pad_list([
np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys
], self.ignore_id))

olens.append(
np.array([
y[0].shape[0] if isinstance(y, tuple) else y.shape[0]
for y in ys
]))

olens.append(np.array(
[y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys]))

if not self.load_aux_output:
ys_pad, olens = ys_pad[0], olens[0]
break
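A note on the hunk above: pad_list right-pads every variable-length array in a batch to the length of the longest one, using 0 for the input features and ignore_id for the targets. A minimal self-contained sketch of that behaviour follows; the helper body is an assumption for illustration, only the pad_list(xs, pad_value) call shape is taken from the diff.

import numpy as np

def pad_list(xs, pad_value):
    # Right-pad a list of variable-length numpy arrays to the longest length.
    max_len = max(x.shape[0] for x in xs)
    padded = np.full((len(xs), max_len) + xs[0].shape[1:], pad_value,
                     dtype=xs[0].dtype)
    for i, x in enumerate(xs):
        padded[i, :x.shape[0]] = x
    return padded

ys = [np.array([1, 2, 3]), np.array([4, 5])]
print(pad_list(ys, -1))  # [[ 1  2  3]
                         #  [ 4  5 -1]]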
25 changes: 19 additions & 6 deletions paddlespeech/s2t/io/dataloader.py
@@ -18,6 +18,7 @@

import jsonlines
import numpy as np
from paddle.io import BatchSampler
from paddle.io import DataLoader
from paddle.io import DistributedBatchSampler

@@ -76,7 +77,8 @@ def __init__(self,
subsampling_factor: int=1,
load_aux_input: bool=False,
load_aux_output: bool=False,
num_encs: int=1):
num_encs: int=1,
dist_sampler: bool=False):
self.json_file = json_file
self.train_mode = train_mode
self.use_sortagrad = sortagrad == -1 or sortagrad > 0
@@ -94,6 +96,7 @@ def __init__(self,
self.n_iter_processes = n_iter_processes
self.load_aux_input = load_aux_input
self.load_aux_output = load_aux_output
self.dist_sampler = dist_sampler

# read json data
with jsonlines.open(json_file, 'r') as reader:
@@ -145,11 +148,18 @@ def __init__(self,
self.dataset = TransformDataset(self.minibaches, self.converter,
self.reader)

self.sampler = DistributedBatchSampler(
dataset=self.dataset,
batch_size=1,
shuffle=not self.use_sortagrad if self.train_mode else False,
)
if self.dist_sampler:
self.sampler = DistributedBatchSampler(
dataset=self.dataset,
batch_size=1,
shuffle=not self.use_sortagrad if self.train_mode else False,
drop_last=False, )
else:
self.sampler = BatchSampler(
dataset=self.dataset,
batch_size=1,
shuffle=not self.use_sortagrad if self.train_mode else False,
drop_last=False, )

self.dataloader = DataLoader(
dataset=self.dataset,
@@ -181,5 +191,8 @@ def __repr__(self):
echo += f"subsampling_factor: {self.subsampling_factor}, "
echo += f"num_encs: {self.num_encs}, "
echo += f"num_workers: {self.n_iter_processes}, "
echo += f"load_aux_input: {self.load_aux_input}, "
echo += f"load_aux_output: {self.load_aux_output}, "
echo += f"dist_sampler: {self.dist_sampler}, "
echo += f"file: {self.json_file}"
return echo
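This hunk is the core of the PR: BatchDataLoader now chooses its sampler from the new dist_sampler flag instead of always building a DistributedBatchSampler. A standalone sketch of the same pattern with paddle.io — the make_loader helper and its argument names are illustrative, not part of the PR; batch_size=1 mirrors the fact that each dataset item here is already a pre-built minibatch.

from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler

def make_loader(dataset, collate_fn, dist_sampler=False, shuffle=True,
                num_workers=0):
    # Pick a per-rank sampler for multi-GPU runs, or a plain single-process
    # one otherwise; both accept the same batching arguments.
    sampler_cls = DistributedBatchSampler if dist_sampler else BatchSampler
    sampler = sampler_cls(
        dataset=dataset, batch_size=1, shuffle=shuffle, drop_last=False)
    return DataLoader(
        dataset=dataset,
        batch_sampler=sampler,
        collate_fn=collate_fn,
        num_workers=num_workers)

With dist_sampler=True each rank launched via paddle.distributed.launch iterates only over its own shard of minibatches; with dist_sampler=False (the test loader in this PR) a single process sees all of them.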
7 changes: 5 additions & 2 deletions paddlespeech/t2s/exps/synthesize_e2e.py
@@ -203,12 +203,15 @@ def evaluate(args):
get_tone_ids = True
if args.lang == 'zh':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
sentence,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
elif args.lang == 'en':
input_ids = frontend.get_input_ids(sentence, merge_sentences=merge_sentences)
input_ids = frontend.get_input_ids(
sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")