add style_melgan and hifigan in tts cli, test=tts (#1241)
yt605155624 committed Dec 30, 2021
1 parent a232cd8 commit fbe3c05
Showing 4 changed files with 128 additions and 50 deletions.
82 changes: 63 additions & 19 deletions paddlespeech/cli/tts/infer.py
@@ -178,6 +178,32 @@
'speech_stats':
'feats_stats.npy',
},
# style_melgan
"style_melgan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/style_melgan/style_melgan_csmsc_ckpt_0.1.1.zip',
'md5':
'5de2d5348f396de0c966926b8c462755',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_1500000.pdz',
'speech_stats':
'feats_stats.npy',
},
# hifigan
"hifigan_csmsc-zh": {
'url':
'https://paddlespeech.bj.bcebos.com/Parakeet/released_models/hifigan/hifigan_csmsc_ckpt_0.1.1.zip',
'md5':
'dd40a3d88dfcf64513fba2f0f961ada6',
'config':
'default.yaml',
'ckpt':
'snapshot_iter_2500000.pdz',
'speech_stats':
'feats_stats.npy',
},
}

model_alias = {
@@ -199,6 +225,14 @@
"paddlespeech.t2s.models.melgan:MelGANGenerator",
"mb_melgan_inference":
"paddlespeech.t2s.models.melgan:MelGANInference",
"style_melgan":
"paddlespeech.t2s.models.melgan:StyleMelGANGenerator",
"style_melgan_inference":
"paddlespeech.t2s.models.melgan:StyleMelGANInference",
"hifigan":
"paddlespeech.t2s.models.hifigan:HiFiGANGenerator",
"hifigan_inference":
"paddlespeech.t2s.models.hifigan:HiFiGANInference",
}
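
Note (not part of the commit): the sketch below shows one way a pretrained-model entry and a "module:Class" alias like those added above can be consumed: verify the downloaded archive against the 'md5' field, then resolve the alias with importlib. The helper names are illustrative assumptions, not the CLI's actual download code.

import hashlib
import importlib


def verify_md5(path: str, expected_md5: str) -> bool:
    # Check a downloaded archive (e.g. hifigan_csmsc_ckpt_0.1.1.zip) against
    # the 'md5' field of its pretrained-model entry.
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            md5.update(chunk)
    return md5.hexdigest() == expected_md5


def resolve_alias(alias: str):
    # Turn 'paddlespeech.t2s.models.hifigan:HiFiGANGenerator' into the class object.
    module_name, class_name = alias.split(':')
    return getattr(importlib.import_module(module_name), class_name)


# e.g. HiFiGANGenerator = resolve_alias(
#     "paddlespeech.t2s.models.hifigan:HiFiGANGenerator")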


@@ -266,7 +300,7 @@ def __init__(self):
default='pwgan_csmsc',
choices=[
'pwgan_csmsc', 'pwgan_ljspeech', 'pwgan_aishell3', 'pwgan_vctk',
'mb_melgan_csmsc'
'mb_melgan_csmsc', 'style_melgan_csmsc', 'hifigan_csmsc'
],
help='Choose vocoder type of tts task.')
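
With the two new choices registered, style_melgan_csmsc and hifigan_csmsc can also be requested from Python. The sketch below assumes the executor defined in this file is TTSExecutor and that the keyword names (text, am, voc, lang, output) match the existing CLI usage; neither the class name nor the call signature appears in this hunk, so treat this as a sketch rather than the definitive API.

# Sketch only: class name and keyword names are assumptions (see note above).
from paddlespeech.cli.tts.infer import TTSExecutor

tts = TTSExecutor()
tts(text='今天的天气不错啊',
    am='fastspeech2_csmsc',
    voc='hifigan_csmsc',  # or 'style_melgan_csmsc'
    lang='zh',
    output='output.wav')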

@@ -504,37 +538,47 @@ def infer(self,
am_name = am[:am.rindex('_')]
am_dataset = am[am.rindex('_') + 1:]
get_tone_ids = False
merge_sentences = False
if am_name == 'speedyspeech':
get_tone_ids = True
if lang == 'zh':
input_ids = self.frontend.get_input_ids(
text, merge_sentences=True, get_tone_ids=get_tone_ids)
text,
merge_sentences=merge_sentences,
get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
phone_ids = phone_ids[0]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
tone_ids = tone_ids[0]
elif lang == 'en':
input_ids = self.frontend.get_input_ids(text)
input_ids = self.frontend.get_input_ids(
text, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")

# am
if am_name == 'speedyspeech':
mel = self.am_inference(phone_ids, tone_ids)
# fastspeech2
else:
# multi speaker
if am_dataset in {"aishell3", "vctk"}:
mel = self.am_inference(
phone_ids, spk_id=paddle.to_tensor(spk_id))
flags = 0
for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i]
# am
if am_name == 'speedyspeech':
part_tone_ids = tone_ids[i]
mel = self.am_inference(part_phone_ids, part_tone_ids)
# fastspeech2
else:
mel = self.am_inference(phone_ids)

# voc
wav = self.voc_inference(mel)
self._outputs['wav'] = wav
# multi speaker
if am_dataset in {"aishell3", "vctk"}:
mel = self.am_inference(
part_phone_ids, spk_id=paddle.to_tensor(spk_id))
else:
mel = self.am_inference(part_phone_ids)
# voc
wav = self.voc_inference(mel)
if flags == 0:
wav_all = wav
flags = 1
else:
wav_all = paddle.concat([wav_all, wav])
self._outputs['wav'] = wav_all
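
The rewritten loop above synthesizes each split sentence separately and stitches the per-sentence waveforms together with paddle.concat. A minimal self-contained illustration of that pattern follows; fake_am and fake_voc are dummy stand-ins rather than PaddleSpeech APIs, and a None check plays the role of the flags counter used in the diff.

import paddle


def fake_am(phone_ids):
    # stand-in acoustic model: pretend each sentence yields a (T, 80) mel
    return paddle.randn([int(phone_ids.shape[0]) * 4, 80])


def fake_voc(mel):
    # stand-in vocoder: pretend each mel frame becomes 256 waveform samples
    return paddle.randn([int(mel.shape[0]) * 256])


phone_ids = [paddle.to_tensor([1, 2, 3]), paddle.to_tensor([4, 5, 6, 7])]

wav_all = None
for part_phone_ids in phone_ids:
    wav = fake_voc(fake_am(part_phone_ids))
    wav_all = wav if wav_all is None else paddle.concat([wav_all, wav])

print(wav_all.shape)  # all sentences joined into a single waveform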

def postprocess(self, output: str='output.wav') -> Union[str, os.PathLike]:
"""
42 changes: 24 additions & 18 deletions paddlespeech/t2s/exps/synthesize_e2e.py
@@ -196,41 +196,47 @@ def evaluate(args):

output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

merge_sentences = False
for utt_id, sentence in sentences:
get_tone_ids = False
if am_name == 'speedyspeech':
get_tone_ids = True
if args.lang == 'zh':
input_ids = frontend.get_input_ids(
sentence, merge_sentences=True, get_tone_ids=get_tone_ids)
sentence, merge_sentences=merge_sentences, get_tone_ids=get_tone_ids)
phone_ids = input_ids["phone_ids"]
phone_ids = phone_ids[0]
if get_tone_ids:
tone_ids = input_ids["tone_ids"]
tone_ids = tone_ids[0]
elif args.lang == 'en':
input_ids = frontend.get_input_ids(sentence)
input_ids = frontend.get_input_ids(sentence, merge_sentences=merge_sentences)
phone_ids = input_ids["phone_ids"]
else:
print("lang should in {'zh', 'en'}!")

with paddle.no_grad():
# acoustic model
if am_name == 'fastspeech2':
# multi speaker
if am_dataset in {"aishell3", "vctk"}:
spk_id = paddle.to_tensor(args.spk_id)
mel = am_inference(phone_ids, spk_id)
flags = 0
for i in range(len(phone_ids)):
part_phone_ids = phone_ids[i]
# acoustic model
if am_name == 'fastspeech2':
# multi speaker
if am_dataset in {"aishell3", "vctk"}:
spk_id = paddle.to_tensor(args.spk_id)
mel = am_inference(part_phone_ids, spk_id)
else:
mel = am_inference(part_phone_ids)
elif am_name == 'speedyspeech':
part_tone_ids = tone_ids[i]
mel = am_inference(part_phone_ids, part_tone_ids)
# vocoder
wav = voc_inference(mel)
if flags == 0:
wav_all = wav
flags = 1
else:
mel = am_inference(phone_ids)
elif am_name == 'speedyspeech':
mel = am_inference(phone_ids, tone_ids)
# vocoder
wav = voc_inference(mel)
wav_all = paddle.concat([wav_all, wav])
sf.write(
str(output_dir / (utt_id + ".wav")),
wav.numpy(),
wav_all.numpy(),
samplerate=am_config.fs)
print(f"{utt_id} done!")

49 changes: 38 additions & 11 deletions paddlespeech/t2s/frontend/phonectic.py
@@ -13,14 +13,17 @@
# limitations under the License.
from abc import ABC
from abc import abstractmethod
from typing import List

import numpy as np
import paddle
from g2p_en import G2p
from g2pM import G2pM

from paddlespeech.t2s.frontend.normalizer.normalizer import normalize
from paddlespeech.t2s.frontend.punctuation import get_punctuations
from paddlespeech.t2s.frontend.vocab import Vocab
from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer

# discard opencc until we find an easy solution to install it on windows
# from opencc import OpenCC
@@ -53,6 +56,7 @@ def __init__(self, phone_vocab_path=None):
self.vocab = Vocab(self.phonemes + self.punctuations)
self.vocab_phones = {}
self.punc = ":,;。?!“”‘’':,;.?!"
self.text_normalizer = TextNormalizer()
if phone_vocab_path:
with open(phone_vocab_path, 'rt') as f:
phn_id = [line.strip().split() for line in f.readlines()]
@@ -78,19 +82,42 @@ def phoneticize(self, sentence):
phonemes = [item for item in phonemes if item in self.vocab.stoi]
return phonemes

def get_input_ids(self, sentence: str) -> paddle.Tensor:
result = {}
phones = self.phoneticize(sentence)
# remove start_symbol and end_symbol
phones = phones[1:-1]
phones = [phn for phn in phones if not phn.isspace()]
phones = [
def _p2id(self, phonemes: List[str]) -> np.array:
# replace unk phone with sp
phonemes = [
phn if (phn in self.vocab_phones and phn not in self.punc) else "sp"
for phn in phones
for phn in phonemes
]
phone_ids = [self.vocab_phones[phn] for phn in phones]
phone_ids = paddle.to_tensor(phone_ids)
result["phone_ids"] = phone_ids
phone_ids = [self.vocab_phones[item] for item in phonemes]
return np.array(phone_ids, np.int64)

def get_input_ids(self, sentence: str,
merge_sentences: bool=False) -> paddle.Tensor:
result = {}
sentences = self.text_normalizer._split(sentence, lang="en")
phones_list = []
temp_phone_ids = []
for sentence in sentences:
phones = self.phoneticize(sentence)
# remove start_symbol and end_symbol
phones = phones[1:-1]
phones = [phn for phn in phones if not phn.isspace()]
phones_list.append(phones)

if merge_sentences:
merge_list = sum(phones_list, [])
# rm the last 'sp' to avoid the noise at the end
# cause in the training data, no 'sp' in the end
if merge_list[-1] == 'sp':
merge_list = merge_list[:-1]
phones_list = []
phones_list.append(merge_list)

for part_phones_list in phones_list:
phone_ids = self._p2id(part_phones_list)
phone_ids = paddle.to_tensor(phone_ids)
temp_phone_ids.append(phone_ids)
result["phone_ids"] = temp_phone_ids
return result
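
A hedged usage sketch of the new merge_sentences behaviour follows. The class name English and the phone-vocab file are assumptions (only get_input_ids and _p2id are visible in this diff); the return shape is grounded in the code above: a dict whose "phone_ids" entry is a list of per-sentence tensors.

# Illustrative only: `English` and the vocab path are assumptions.
from paddlespeech.t2s.frontend.phonectic import English

frontend = English(phone_vocab_path="phone_id_map.txt")  # hypothetical path
text = "How is the weather today? Let's go for a walk!"

# merge_sentences=False (the new default): one phone-id tensor per sentence.
per_sentence = frontend.get_input_ids(text, merge_sentences=False)["phone_ids"]
print(len(per_sentence))      # 2 sentences after splitting on '?'
print(per_sentence[0].shape)  # phone ids of the first sentence

# merge_sentences=True: a single tensor, with a trailing 'sp' dropped first.
merged = frontend.get_input_ids(text, merge_sentences=True)["phone_ids"]
print(len(merged))            # 1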

def numericalize(self, phonemes):
5 changes: 3 additions & 2 deletions paddlespeech/t2s/frontend/zh_normalization/text_normlization.py
@@ -53,7 +53,7 @@ class TextNormalizer():
def __init__(self):
self.SENTENCE_SPLITOR = re.compile(r'([:,;。?!,;?!][”’]?)')

def _split(self, text: str) -> List[str]:
def _split(self, text: str, lang="zh") -> List[str]:
"""Split long text into sentences with sentence-splitting punctuations.
Parameters
----------
@@ -65,7 +65,8 @@ def _split(self, text: str) -> List[str]:
Sentences.
"""
# Only for pure Chinese here
text = text.replace(" ", "")
if lang == "zh":
text = text.replace(" ", "")
text = self.SENTENCE_SPLITOR.sub(r'\1\n', text)
text = text.strip()
sentences = [sentence.strip() for sentence in re.split(r'\n+', text)]
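
A small usage sketch of the updated _split; the import path is the one added in phonectic.py above, and the expected outputs follow directly from the SENTENCE_SPLITOR pattern shown here (spaces are only stripped for lang="zh").

from paddlespeech.t2s.frontend.zh_normalization.text_normlization import TextNormalizer

tn = TextNormalizer()

# Chinese: spaces are removed before splitting on sentence punctuation.
print(tn._split("你好! 今天 天气不错。", lang="zh"))
# ['你好!', '今天天气不错。']

# English: spaces are preserved, but the text is still split on the same punctuation.
print(tn._split("How is the weather today? Let's go for a walk!", lang="en"))
# ['How is the weather today?', "Let's go for a walk!"]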
