[s2t] DataLoader with BatchSampler or DistributedBatchSampler (#1242)
* BatchSampler or DistributedBatchSampler

* format
zh794390558 committed Dec 30, 2021
1 parent 6d93f3e commit c81a3f0
Showing 4 changed files with 48 additions and 23 deletions.
12 changes: 8 additions & 4 deletions paddlespeech/s2t/exps/u2_st/model.py
@@ -292,7 +292,8 @@ def setup_dataloader(self):
                 n_iter_processes=config.collator.num_workers,
                 subsampling_factor=1,
                 load_aux_output=load_transcript,
-                num_encs=1)
+                num_encs=1,
+                dist_sampler=True)
 
             self.valid_loader = BatchDataLoader(
                 json_file=config.data.dev_manifest,
@@ -313,7 +314,8 @@ def setup_dataloader(self):
                 n_iter_processes=config.collator.num_workers,
                 subsampling_factor=1,
                 load_aux_output=load_transcript,
-                num_encs=1)
+                num_encs=1,
+                dist_sampler=True)
             logger.info("Setup train/valid Dataloader!")
         else:
             # test dataset, return raw text
@@ -335,7 +337,8 @@ def setup_dataloader(self):
                 augmentation_config,  # aug will be off when train_mode=False
                 n_iter_processes=config.collator.num_workers,
                 subsampling_factor=1,
-                num_encs=1)
+                num_encs=1,
+                dist_sampler=False)
 
             logger.info("Setup test Dataloader!")

@@ -542,7 +545,8 @@ def test(self):
                 len_refs += metrics['len_refs']
                 num_ins += metrics['num_ins']
                 rtf = num_time / (num_frames * stride_ms)
-                logger.info("RTF: %f, instance (%d), batch BELU = %f" % (rtf, num_ins, bleu))
+                logger.info("RTF: %f, instance (%d), batch BELU = %f" %
+                            (rtf, num_ins, bleu))
 
         rtf = num_time / (num_frames * stride_ms)
         msg = "Test: "
27 changes: 16 additions & 11 deletions paddlespeech/s2t/io/converter.py
@@ -65,8 +65,9 @@ def __call__(self, batch):
             # text data (output): (text_len, )
             ys_data.append(ud)
 
-        assert xs_data[0][0] is not None, "please check Reader and Augmentation impl."
-
+        assert xs_data[0][
+            0] is not None, "please check Reader and Augmentation impl."
+
         xs_pad, ilens = [], []
         for xs in xs_data:
             # perform subsampling
@@ -79,22 +80,26 @@ def __call__(self, batch):
             # perform padding and convert to tensor
             # currently only support real number
             xs_pad.append(pad_list(xs, 0).astype(self.dtype))
+
             if not self.load_aux_input:
                 xs_pad, ilens = xs_pad[0], ilens[0]
                 break
 
         # NOTE: this is for multi-output (e.g., speech translation)
         ys_pad, olens = [], []
+
         for ys in ys_data:
-            ys_pad.append(pad_list(
-                [np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys],
-                self.ignore_id))
-
-            olens.append(np.array(
-                [y[0].shape[0] if isinstance(y, tuple) else y.shape[0] for y in ys]))
+            ys_pad.append(
+                pad_list([
+                    np.array(y[0][:]) if isinstance(y, tuple) else y for y in ys
+                ], self.ignore_id))
+
+            olens.append(
+                np.array([
+                    y[0].shape[0] if isinstance(y, tuple) else y.shape[0]
+                    for y in ys
+                ]))
 
             if not self.load_aux_output:
                 ys_pad, olens = ys_pad[0], olens[0]
                 break
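The converter changes above are formatting only; pad_list still pads a batch to its longest sequence, using 0 for features and ignore_id for targets. As an illustration of that padding behaviour only, here is an assumed numpy stand-in (pad_to_longest is a hypothetical helper, not the repository's pad_list):

import numpy as np

def pad_to_longest(xs, pad_value):
    # Pad a list of arrays along the first axis to the longest length.
    max_len = max(x.shape[0] for x in xs)
    out = np.full((len(xs), max_len) + xs[0].shape[1:], pad_value, dtype=xs[0].dtype)
    for i, x in enumerate(xs):
        out[i, :x.shape[0]] = x
    return out

ys = [np.array([1, 2, 3]), np.array([4, 5])]
print(pad_to_longest(ys, -1))  # [[ 1  2  3] [ 4  5 -1]]; -1 plays the role of ignore_id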
25 changes: 19 additions & 6 deletions paddlespeech/s2t/io/dataloader.py
@@ -18,6 +18,7 @@
 
 import jsonlines
 import numpy as np
+from paddle.io import BatchSampler
 from paddle.io import DataLoader
 from paddle.io import DistributedBatchSampler
 
@@ -76,7 +77,8 @@ def __init__(self,
                  subsampling_factor: int=1,
                  load_aux_input: bool=False,
                  load_aux_output: bool=False,
-                 num_encs: int=1):
+                 num_encs: int=1,
+                 dist_sampler: bool=False):
         self.json_file = json_file
         self.train_mode = train_mode
         self.use_sortagrad = sortagrad == -1 or sortagrad > 0
@@ -94,6 +96,7 @@ def __init__(self,
         self.n_iter_processes = n_iter_processes
         self.load_aux_input = load_aux_input
         self.load_aux_output = load_aux_output
+        self.dist_sampler = dist_sampler
 
         # read json data
         with jsonlines.open(json_file, 'r') as reader:
@@ -145,11 +148,18 @@ def __init__(self,
         self.dataset = TransformDataset(self.minibaches, self.converter,
                                         self.reader)
 
-        self.sampler = DistributedBatchSampler(
-            dataset=self.dataset,
-            batch_size=1,
-            shuffle=not self.use_sortagrad if self.train_mode else False,
-        )
+        if self.dist_sampler:
+            self.sampler = DistributedBatchSampler(
+                dataset=self.dataset,
+                batch_size=1,
+                shuffle=not self.use_sortagrad if self.train_mode else False,
+                drop_last=False, )
+        else:
+            self.sampler = BatchSampler(
+                dataset=self.dataset,
+                batch_size=1,
+                shuffle=not self.use_sortagrad if self.train_mode else False,
+                drop_last=False, )
 
         self.dataloader = DataLoader(
             dataset=self.dataset,
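For context on the branch above: both sampler classes come from paddle.io, and batch_size=1 is used because each dataset item here is already a pre-built minibatch (self.minibaches). A minimal standalone sketch, with a toy dataset rather than the repository's TransformDataset, of how either sampler feeds paddle.io.DataLoader:

import numpy as np
from paddle.io import BatchSampler, DataLoader, Dataset, DistributedBatchSampler

class ToyDataset(Dataset):
    # Each item stands in for one pre-built minibatch.
    def __init__(self, n=8):
        self.items = [np.array([i], dtype='int64') for i in range(n)]

    def __getitem__(self, idx):
        return self.items[idx]

    def __len__(self):
        return len(self.items)

ds = ToyDataset()

# Single-process case: BatchSampler yields every index exactly once per epoch.
sampler = BatchSampler(dataset=ds, batch_size=1, shuffle=True, drop_last=False)

# Multi-GPU case: DistributedBatchSampler gives each rank a disjoint shard of the indices.
# sampler = DistributedBatchSampler(dataset=ds, batch_size=1, shuffle=True, drop_last=False)

loader = DataLoader(dataset=ds, batch_sampler=sampler, num_workers=0)
for batch in loader:
    print(batch)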
@@ -181,5 +191,8 @@ def __repr__(self):
         echo += f"subsampling_factor: {self.subsampling_factor}, "
         echo += f"num_encs: {self.num_encs}, "
         echo += f"num_workers: {self.n_iter_processes}, "
+        echo += f"load_aux_input: {self.load_aux_input}, "
+        echo += f"load_aux_output: {self.load_aux_output}, "
+        echo += f"dist_sampler: {self.dist_sampler}, "
         echo += f"file: {self.json_file}"
         return echo
7 changes: 5 additions & 2 deletions paddlespeech/t2s/exps/synthesize_e2e.py
@@ -203,12 +203,15 @@ def evaluate(args):
             get_tone_ids = True
         if args.lang == 'zh':
             input_ids = frontend.get_input_ids(
-                sentence, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
+                sentence,
+                merge_sentences=merge_sentences,
+                get_tone_ids=get_tone_ids)
             phone_ids = input_ids["phone_ids"]
             if get_tone_ids:
                 tone_ids = input_ids["tone_ids"]
         elif args.lang == 'en':
-            input_ids = frontend.get_input_ids(sentence, merge_sentences=merge_sentences)
+            input_ids = frontend.get_input_ids(
+                sentence, merge_sentences=merge_sentences)
             phone_ids = input_ids["phone_ids"]
         else:
             print("lang should in {'zh', 'en'}!")
