
Commit

… into webdataset
Jackwaterveg committed Jun 29, 2022
2 parents 1dd23a8 + 429221d commit 6ec6921
Showing 6 changed files with 13 additions and 8 deletions.
2 changes: 1 addition & 1 deletion examples/wenetspeech/asr1/conf/conformer.yaml
@@ -67,7 +67,7 @@ maxlen_out: 150 # if output length(number of tokens) > maxlen-out, data is auto
 resample_rate: 16000
 shuffle_size: 1500
 sort_size: 1000
-num_workers: 0
+num_workers: 8
 prefetch_factor: 10
 dist_sampler: True
 num_encs: 1
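The only functional change in this recipe config raises num_workers from 0 (data prepared in the main training process) to 8 (eight loader subprocesses). A minimal sketch of what that knob controls, using a toy paddle.io dataset that is not part of the WenetSpeech recipe:

# Minimal sketch of what num_workers controls; ToyDataset and the batch size
# are illustrative, not part of the recipe.
import paddle
from paddle.io import DataLoader, Dataset

class ToyDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return paddle.ones([80]) * idx  # stand-in for a feature vector

if __name__ == "__main__":
    # num_workers=0: samples are produced in the training process itself.
    # num_workers=8: eight worker subprocesses prepare batches in parallel,
    # which is what the config change enables.
    loader = DataLoader(ToyDataset(), batch_size=4, num_workers=8)
    for batch in loader:
        pass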
3 changes: 1 addition & 2 deletions examples/wenetspeech/asr1/local/train.sh
@@ -45,8 +45,7 @@ python3 -u ${BIN_DIR}/train.py \
 --benchmark-batch-size ${benchmark_batch_size} \
 --benchmark-max-step ${benchmark_max_step}
 else
-#NCCL_SOCKET_IFNAME=eth0
-python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
+NCCL_SOCKET_IFNAME=eth0 python3 -m paddle.distributed.launch --gpus=${CUDA_VISIBLE_DEVICES} ${ips_config} ${BIN_DIR}/train.py \
 --ngpu ${ngpu} \
 --seed ${seed} \
 --config ${config_path} \
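NCCL_SOCKET_IFNAME is a standard NCCL environment variable that pins inter-node communication to a specific network interface; the edit moves it from a dead comment onto the launch command so it is actually in effect for multi-node runs. A hedged Python equivalent, assuming it is set before the distributed backend initializes:

# Sketch only: programmatic equivalent of prefixing the launch command with
# NCCL_SOCKET_IFNAME=eth0. The variable must be set before NCCL initializes,
# and this is assumed to run under paddle.distributed.launch on a multi-node job.
import os

os.environ["NCCL_SOCKET_IFNAME"] = "eth0"  # pin NCCL sockets to eth0

import paddle.distributed as dist

dist.init_parallel_env()  # NCCL picks up the variable when it starts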
2 changes: 2 additions & 0 deletions paddlespeech/audio/streamdata/shardlists.py
@@ -65,6 +65,7 @@ def __iter__(self):
 
 def split_by_node(src, group=None):
     rank, world_size, worker, num_workers = utils.paddle_worker_info(group=group)
+    logger.info(f"world_size:{world_size}, rank:{rank}")
     if world_size > 1:
         for s in islice(src, rank, None, world_size):
             yield s
@@ -83,6 +84,7 @@ def single_node_only(src, group=None):
 
 def split_by_worker(src):
     rank, world_size, worker, num_workers = utils.paddle_worker_info()
+    logger.info(f"num_workers:{num_workers}, worker:{worker}")
     if num_workers > 1:
         for s in islice(src, worker, None, num_workers):
             yield s
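Both functions only gain a diagnostic log line; the sharding itself still relies on itertools.islice, handing each rank (or worker) every world_size-th (or num_workers-th) shard. A self-contained sketch of that pattern, with made-up shard names:

# Sketch of the islice-based sharding used by split_by_node/split_by_worker:
# each rank starts at its own offset and steps by world_size, so the shard
# list is partitioned without any communication. Shard names are made up.
from itertools import islice

shards = [f"shard-{i:04d}.tar" for i in range(10)]

def split(src, rank, world_size):
    if world_size > 1:
        for s in islice(src, rank, None, world_size):
            yield s
    else:
        yield from src

print(list(split(shards, rank=0, world_size=3)))  # shards 0, 3, 6, 9
print(list(split(shards, rank=1, world_size=3)))  # shards 1, 4, 7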
10 changes: 7 additions & 3 deletions paddlespeech/audio/streamdata/utils.py
@@ -16,6 +16,9 @@
 import sys
 from typing import Any, Callable, Iterator, Optional, Union
 
+from ..utils.log import Logger
+
+logger = Logger(__name__)
 
 def make_seed(*args):
     seed = 0
@@ -112,13 +115,14 @@ def paddle_worker_info(group=None):
         num_workers = int(os.environ["NUM_WORKERS"])
     else:
         try:
-            import paddle.io.get_worker_info
+            from paddle.io import get_worker_info
             worker_info = paddle.io.get_worker_info()
             if worker_info is not None:
                 worker = worker_info.id
                 num_workers = worker_info.num_workers
-        except ModuleNotFoundError:
-            pass
+        except ModuleNotFoundError as E:
+            logger.info(f"not found {E}")
+            exit(-1)
 
     return rank, world_size, worker, num_workers

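The original `import paddle.io.get_worker_info` is not a valid import (get_worker_info is a function, not a module), so it always raised ModuleNotFoundError and the silent `pass` left worker/num_workers at their defaults; the fix imports it properly and turns the failure into a logged hard exit. A small sketch of how paddle.io.get_worker_info() is used, following the same pattern as the patched function:

# Sketch: querying worker identity the way paddle_worker_info() now does.
# Outside a DataLoader worker process, get_worker_info() returns None,
# so the defaults (worker 0 of 1) are kept.
from paddle.io import get_worker_info

def worker_identity():
    worker, num_workers = 0, 1
    worker_info = get_worker_info()
    if worker_info is not None:          # only true inside a worker process
        worker = worker_info.id
        num_workers = worker_info.num_workers
    return worker, num_workers

print(worker_identity())  # (0, 1) when called from the main process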
2 changes: 1 addition & 1 deletion paddlespeech/cli/asr/infer.py
@@ -33,7 +33,7 @@
 from ..utils import CLI_TIMER
 from ..utils import stats_wrapper
 from ..utils import timer_register
-from paddlespeech.s2t.audio.transformation import Transformation
+from paddlespeech.audio.transform.transformation import Transformation
 from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
 from paddlespeech.s2t.utils.utility import UpdateConfig

2 changes: 1 addition & 1 deletion paddlespeech/s2t/io/dataloader.py
@@ -104,7 +104,7 @@ def __init__(self,
         if self.dist_sampler:
             base_dataset = streamdata.DataPipeline(
                 streamdata.SimpleShardList(shardlist),
-                streamdata.split_by_node,
+                streamdata.split_by_node if train_mode else streamdata.placeholder(),
                 streamdata.split_by_worker,
                 streamdata.tarfile_to_samples(streamdata.reraise_exception)
             )
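With this change shards are split across nodes only in training mode; for evaluation every node keeps the full shard list, with streamdata.placeholder() assumed here to act as a pass-through stage. A generator-based sketch of the same idea, independent of the streamdata API:

# Sketch of "split across nodes only in train mode", using plain generators
# instead of streamdata pipeline stages; passthrough() stands in for what
# streamdata.placeholder() is assumed to do (no-op stage).
from itertools import islice

def split_by_node(src, rank, world_size):
    yield from islice(src, rank, None, world_size)

def passthrough(src):
    yield from src

def make_shard_stream(shards, rank, world_size, train_mode):
    stage = (lambda s: split_by_node(s, rank, world_size)) if train_mode else passthrough
    return stage(iter(shards))

shards = [f"shard-{i:03d}.tar" for i in range(6)]
print(list(make_shard_stream(shards, rank=1, world_size=2, train_mode=True)))   # odd-numbered shards only
print(list(make_shard_stream(shards, rank=1, world_size=2, train_mode=False)))  # all shards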
