Skip to content

Commit

Permalink
format code
Browse files Browse the repository at this point in the history
  • Loading branch information
zh794390558 committed Nov 30, 2021
1 parent d395c2b commit 3922886
Show file tree
Hide file tree
Showing 20 changed files with 51 additions and 32 deletions.
3 changes: 0 additions & 3 deletions examples/aishell/asr1/READEME.md
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,3 @@ You need to prepare an audio file, please confirm the sample rate of the audio i
```bash
CUDA_VISIBLE_DEVICES= ./local/test_hub.sh conf/transformer.yaml exp/transformer/checkpoints/avg_20 data/test_audio.wav
```



10 changes: 6 additions & 4 deletions paddlespeech/s2t/exps/u2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@ def train_batch(self, batch_index, batch_data, msg):
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag='train/'+key, value=val, step=self.iteration-1)

self.visualizer.add_scalar(
tag='train/' + key, value=val, step=self.iteration - 1)

@paddle.no_grad()
def valid(self):
Expand Down Expand Up @@ -238,8 +238,10 @@ def do_train(self):
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)

self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()
Expand Down
11 changes: 7 additions & 4 deletions paddlespeech/s2t/exps/u2_kaldi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ def train_batch(self, batch_index, batch_data, msg):
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag="train/"+key, value=val, step=self.iteration - 1)
self.visualizer.add_scalar(
tag="train/" + key, value=val, step=self.iteration - 1)

@paddle.no_grad()
def valid(self):
Expand Down Expand Up @@ -222,9 +223,11 @@ def do_train(self):
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)

self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)

self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()

Expand Down
11 changes: 7 additions & 4 deletions paddlespeech/s2t/exps/u2_st/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ def train_batch(self, batch_index, batch_data, msg):
losses_np_v = losses_np.copy()
losses_np_v.update({"lr": self.lr_scheduler()})
for key, val in losses_np_v.items():
self.visualizer.add_scalar(tag="train/"+key, value=val, step=self.iteration - 1)
self.visualizer.add_scalar(
tag="train/" + key, value=val, step=self.iteration - 1)

@paddle.no_grad()
def valid(self):
Expand Down Expand Up @@ -235,9 +236,11 @@ def do_train(self):
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)

self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)

self.save(tag=self.epoch, infos={'val_loss': cv_loss})
self.new_epoch()

Expand Down
1 change: 1 addition & 0 deletions paddlespeech/s2t/frontend/augmentor/impulse_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Contains the impulse response augmentation model."""
import jsonlines

from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase

Expand Down
1 change: 1 addition & 0 deletions paddlespeech/s2t/frontend/augmentor/noise_perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Contains the noise perturb augmentation model."""
import jsonlines

from paddlespeech.s2t.frontend.audio import AudioSegment
from paddlespeech.s2t.frontend.augmentor.base import AugmentorBase

Expand Down
6 changes: 4 additions & 2 deletions paddlespeech/s2t/frontend/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Contains feature normalizers."""
import json

import jsonlines
import numpy as np
import paddle
Expand All @@ -26,7 +27,8 @@
__all__ = ["FeatureNormalizer"]

logger = Log(__name__).getlog()



# https://github.com/PaddlePaddle/Paddle/pull/31481
class CollateFunc(object):
def __init__(self, feature_func):
Expand Down Expand Up @@ -62,7 +64,7 @@ def __init__(self, manifest_path, num_samples=-1, rng=None, random_seed=0):

with jsonlines.open(manifest_path, 'r') as reader:
manifest = list(reader)

if num_samples == -1:
sampled_manifest = manifest
else:
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/s2t/frontend/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def load_dict(dict_path: Optional[Text], maskctc=False) -> Optional[List[Text]]:
char_list.append(MASKCTC)
return char_list


def read_manifest(
manifest_path,
max_input_len=float('inf'),
Expand Down
4 changes: 2 additions & 2 deletions paddlespeech/s2t/io/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
from typing import Dict
from typing import List
from typing import Text
import jsonlines

import jsonlines
import numpy as np
from paddle.io import DataLoader

Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(self,
# read json data
with jsonlines.open(json_file, 'r') as reader:
self.data_json = list(reader)

self.feat_dim, self.vocab_size = feat_dim_and_vocab_size(
self.data_json, mode='asr')

Expand Down
1 change: 1 addition & 0 deletions paddlespeech/s2t/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# Modified from espnet(https://github.com/espnet/espnet)
# Modified from wenet(https://github.com/wenet-e2e/wenet)
from typing import Optional

import jsonlines
from paddle.io import Dataset
from yacs.config import CfgNode
Expand Down
2 changes: 1 addition & 1 deletion paddlespeech/s2t/io/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _batch_shuffle(indices, batch_size, epoch, clipped=False):
"""
rng = np.random.RandomState(epoch)
shift_len = rng.randint(0, batch_size - 1)
batch_indices = list(zip(*[iter(indices[shift_len:])] * batch_size))
batch_indices = list(zip(* [iter(indices[shift_len:])] * batch_size))
rng.shuffle(batch_indices)
batch_indices = [item for batch in batch_indices for item in batch]
assert clipped is False
Expand Down
6 changes: 4 additions & 2 deletions paddlespeech/s2t/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,10 @@ def do_train(self):
logger.info(
'Epoch {} Val info val_loss {}'.format(self.epoch, cv_loss))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)

# after epoch
self.save(tag=self.epoch, infos={'val_loss': cv_loss})
Expand Down
1 change: 1 addition & 0 deletions paddlespeech/s2t/utils/socket_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import wave
from time import gmtime
from time import strftime

import jsonlines

__all__ = ["socket_send", "warm_up_test", "AsrTCPServer", "AsrRequestHandler"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,10 @@ def train(self):
self.logger.info("Epoch {} Val info val_loss {}, F1_score {}".
format(self.epoch, total_loss, F1_score))
if self.visualizer:
self.visualizer.add_scalar(tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)
self.visualizer.add_scalar(
tag='eval/cv_loss', value=cv_loss, step=self.epoch)
self.visualizer.add_scalar(
tag='eval/lr', value=self.lr_scheduler(), step=self.epoch)

self.save(
tag=self.epoch, infos={"val_loss": total_loss,
Expand Down
7 changes: 4 additions & 3 deletions utils/build_vocab.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
import functools
import os
import tempfile
import jsonlines
from collections import Counter

import jsonlines

from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import BLANK
from paddlespeech.s2t.frontend.utility import SOS
Expand Down Expand Up @@ -63,7 +64,7 @@ def count_manifest(counter, text_feature, manifest_path):
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)

for line_json in manifest_jsons:
line = text_feature.tokenize(line_json['text'], replace_space=False)
counter.update(line)
Expand All @@ -73,7 +74,7 @@ def dump_text_manifest(fileobj, manifest_path, key='text'):
with jsonlines.open(manifest_path, 'r') as reader:
for json_data in reader:
manifest_jsons.append(json_data)

for line_json in manifest_jsons:
fileobj.write(line_json[key] + "\n")

Expand Down
3 changes: 2 additions & 1 deletion utils/dump_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import argparse
from pathlib import Path
from typing import Union

import jsonlines

key_whitelist = set(['feat', 'text', 'syllable', 'phone'])
Expand All @@ -34,7 +35,7 @@ def dump_manifest(manifest_path, output_dir: Union[str, Path]):

with jsonlines.open(str(manifest_path), 'r') as reader:
manifest_jsons = list(reader)

first_line = manifest_jsons[0]
file_map = {}

Expand Down
5 changes: 3 additions & 2 deletions utils/format_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
"""format manifest with more metadata."""
import argparse
import functools
import jsonlines
import json

import jsonlines

from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
from paddlespeech.s2t.frontend.utility import load_cmvn
from paddlespeech.s2t.io.utility import feat_type
Expand Down Expand Up @@ -73,7 +74,7 @@ def main():
for manifest_path in args.manifest_paths:
with jsonlines.open(str(manifest_path), 'r') as reader:
manifest_jsons = list(reader)

for line_json in manifest_jsons:
output_json = {
"input": [],
Expand Down
1 change: 1 addition & 0 deletions utils/format_triplet_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import argparse
import functools
import json

import jsonlines

from paddlespeech.s2t.frontend.featurizer.text_featurizer import TextFeaturizer
Expand Down
1 change: 1 addition & 0 deletions utils/manifest_key_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import argparse
import functools
from pathlib import Path

import jsonlines

from utils.utility import add_arguments
Expand Down
1 change: 0 additions & 1 deletion utils/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import json
import os
import sys
import tarfile
Expand Down

0 comments on commit 3922886

Please sign in to comment.