
Commit

fix batch sampler set_epoch when epoch starts
zh794390558 committed Jan 5, 2022
1 parent 680eac0 commit 6f651d7
Showing 5 changed files with 20 additions and 15 deletions.
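
The change wires two new options into the U2 data loaders (dist_sampler=True, shortest_first=False), renames the loader's sampler attribute to batch_sampler so the trainer can find it, and corrects the log line around the trainer's set_epoch call. The reason set_epoch matters: a DistributedBatchSampler derives its shuffle order from the epoch number, so skipping the call replays the same batch order every epoch. A minimal sketch of that behavior against the public paddle.io API (the toy dataset is illustrative):

    # Sketch: a shuffling DistributedBatchSampler seeds its permutation with
    # the epoch set via set_epoch; omit the call and the order never changes.
    import paddle
    from paddle.io import DataLoader, DistributedBatchSampler, TensorDataset

    dataset = TensorDataset([paddle.arange(8, dtype='float32').unsqueeze(-1)])
    sampler = DistributedBatchSampler(dataset, batch_size=2, shuffle=True)
    loader = DataLoader(dataset, batch_sampler=sampler)

    for epoch in range(3):
        sampler.set_epoch(epoch)  # reseed the shuffle for this epoch
        for batch in loader:
            pass  # training step would go here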
8 changes: 6 additions & 2 deletions paddlespeech/s2t/exps/u2/model.py
@@ -240,7 +240,9 @@ def setup_dataloader(self):
                 preprocess_conf=config.preprocess_config,
                 n_iter_processes=config.num_workers,
                 subsampling_factor=1,
-                num_encs=1)
+                num_encs=1,
+                dist_sampler=True,
+                shortest_first=False)
 
             self.valid_loader = BatchDataLoader(
                 json_file=config.dev_manifest,
@@ -259,7 +261,9 @@ def setup_dataloader(self):
                 preprocess_conf=config.preprocess_config,
                 n_iter_processes=config.num_workers,
                 subsampling_factor=1,
-                num_encs=1)
+                num_encs=1,
+                dist_sampler=True,
+                shortest_first=False)
             logger.info("Setup train/valid Dataloader!")
         else:
             decode_batch_size = config.get('decode', dict()).get(
13 changes: 8 additions & 5 deletions paddlespeech/s2t/io/dataloader.py
@@ -78,7 +78,8 @@ def __init__(self,
                  load_aux_input: bool=False,
                  load_aux_output: bool=False,
                  num_encs: int=1,
-                 dist_sampler: bool=False):
+                 dist_sampler: bool=False,
+                 shortest_first: bool=False):
         self.json_file = json_file
         self.train_mode = train_mode
         self.use_sortagrad = sortagrad == -1 or sortagrad > 0
@@ -97,6 +98,7 @@ def __init__(self,
         self.load_aux_input = load_aux_input
         self.load_aux_output = load_aux_output
         self.dist_sampler = dist_sampler
+        self.shortest_first = shortest_first
 
         # read json data
         with jsonlines.open(json_file, 'r') as reader:
@@ -113,7 +115,7 @@ def __init__(self,
             maxlen_out,
             minibatches,  # for debug
             min_batch_size=mini_batch_size,
-            shortest_first=self.use_sortagrad,
+            shortest_first=self.shortest_first or self.use_sortagrad,
             count=batch_count,
             batch_bins=batch_bins,
             batch_frames_in=batch_frames_in,
@@ -149,21 +151,21 @@ def __init__(self,
             self.reader)
 
         if self.dist_sampler:
-            self.sampler = DistributedBatchSampler(
+            self.batch_sampler = DistributedBatchSampler(
                 dataset=self.dataset,
                 batch_size=1,
                 shuffle=not self.use_sortagrad if self.train_mode else False,
                 drop_last=False, )
         else:
-            self.sampler = BatchSampler(
+            self.batch_sampler = BatchSampler(
                 dataset=self.dataset,
                 batch_size=1,
                 shuffle=not self.use_sortagrad if self.train_mode else False,
                 drop_last=False, )
 
         self.dataloader = DataLoader(
             dataset=self.dataset,
-            batch_sampler=self.sampler,
+            batch_sampler=self.batch_sampler,
             collate_fn=batch_collate,
             num_workers=self.n_iter_processes, )
 
@@ -194,5 +196,6 @@ def __repr__(self):
         echo += f"load_aux_input: {self.load_aux_input}, "
         echo += f"load_aux_output: {self.load_aux_output}, "
         echo += f"dist_sampler: {self.dist_sampler}, "
+        echo += f"shortest_first: {self.shortest_first}, "
         echo += f"file: {self.json_file}"
         return echo
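
The rename from self.sampler to self.batch_sampler is what lets the trainer reach the sampler as train_loader.batch_sampler before calling set_epoch (see the trainer.py hunk below). A minimal sketch of that contract, with the wrapper heavily simplified (everything except the batch_sampler attribute name is illustrative):

    from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler

    class LoaderWrapper:  # simplified stand-in for BatchDataLoader
        def __init__(self, dataset, dist_sampler: bool, train_mode: bool):
            cls = DistributedBatchSampler if dist_sampler else BatchSampler
            # Expose the sampler under the name the trainer looks up.
            self.batch_sampler = cls(
                dataset=dataset, batch_size=1,
                shuffle=train_mode, drop_last=False)
            self.dataloader = DataLoader(
                dataset=dataset, batch_sampler=self.batch_sampler)

Only the distributed variant provides set_epoch, which is why the trainer guards the call with an isinstance check.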
3 changes: 0 additions & 3 deletions paddlespeech/s2t/modules/ctc.py
@@ -39,9 +39,6 @@
 except Exception as e:
     logger.info("paddlespeech_ctcdecoders not installed!")
 
-#try:
-#except Exception as e:
-#    logger.info("ctcdecoder not installed!")
 
 __all__ = ['CTCDecoder']
 
9 changes: 5 additions & 4 deletions paddlespeech/s2t/training/scheduler.py
@@ -67,18 +67,19 @@ def __init__(self,
         super().__init__(learning_rate, last_epoch, verbose)
 
     def __repr__(self):
-        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
+        return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps}, lr={self.base_lr}, last_epoch={self.last_epoch})"
 
     def get_lr(self):
         # self.last_epoch start from zero
         step_num = self.last_epoch + 1
         return self.base_lr * self.warmup_steps**0.5 * min(
             step_num**-0.5, step_num * self.warmup_steps**-1.5)
 
     def set_step(self, step: int=None):
         '''
-        It will update the learning rate in optimizer according to current ``epoch`` .
+        It will update the learning rate in optimizer according to current ``epoch`` .
+        The new learning rate will take effect on next ``optimizer.step`` .
         Args:
             step (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.
         Returns:
@@ -94,7 +95,7 @@ class ConstantLR(LRScheduler):
     learning_rate (float): The initial learning rate. It is a python float number.
     last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
     verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
     Returns:
         ``ConstantLR`` instance to schedule learning rate.
     """
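
For reference, get_lr above implements the standard Transformer (Noam) warmup: the rate grows linearly until warmup_steps, peaks at base_lr where the two min() branches meet, then decays as step**-0.5. A quick numeric check of the formula (the base_lr and warmup_steps values here are arbitrary, not PaddleSpeech defaults):

    # Evaluate the WarmupLR formula from get_lr() at a few steps.
    base_lr, warmup_steps = 0.002, 25000  # illustrative values only

    def lr_at(step_num: int) -> float:
        return base_lr * warmup_steps**0.5 * min(
            step_num**-0.5, step_num * warmup_steps**-1.5)

    print(lr_at(1))        # 8e-08  -- near zero at the first step
    print(lr_at(25000))    # 0.002  -- peak: exactly base_lr at warmup_steps
    print(lr_at(100000))   # 0.001  -- step**-0.5 decay; half the peak at 4x warmup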
2 changes: 1 addition & 1 deletion paddlespeech/s2t/training/trainer.py
@@ -222,7 +222,7 @@ def maybe_batch_sampler_step(self):
         batch_sampler = self.train_loader.batch_sampler
         if isinstance(batch_sampler, paddle.io.DistributedBatchSampler):
             logger.debug(
-                f"train_loader.batch_sample set epoch: {self.epoch}")
+                f"train_loader.batch_sample.set_epoch: {self.epoch}")
             batch_sampler.set_epoch(self.epoch)
 
     def before_train(self):
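
maybe_batch_sampler_step is a no-op unless the loader was built with dist_sampler=True, since only DistributedBatchSampler has set_epoch. A sketch of where it sits in the epoch loop (the loop shape and the n_epoch/train_batch names are assumptions, not the actual trainer code):

    # Illustrative placement: reseed the distributed shuffle at each epoch start.
    while self.epoch < self.config.n_epoch:  # n_epoch is a hypothetical config key
        self.maybe_batch_sampler_step()      # calls batch_sampler.set_epoch(self.epoch)
        for batch in self.train_loader:
            self.train_batch(batch)          # hypothetical per-batch step
        self.epoch += 1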
