From a030df86d3dd7867b8113f6760324f959d96f331 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=82=A6?= Date: Thu, 20 Jun 2024 23:26:08 +0800 Subject: [PATCH] Update webui_mix.py --- webui_mix.py | 116 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) diff --git a/webui_mix.py b/webui_mix.py index 1108062..99aa0a6 100644 --- a/webui_mix.py +++ b/webui_mix.py @@ -16,6 +16,7 @@ from tts_model import load_chat_tts_model, clear_cuda_cache, generate_audio_for_seed from config import DEFAULT_BATCH_SIZE, DEFAULT_SPEED, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P, DEFAULT_ORAL, \ DEFAULT_LAUGH, DEFAULT_BK, DEFAULT_SEG_LENGTH +import torch parser = argparse.ArgumentParser(description="Gradio ChatTTS MIX") parser.add_argument("--source", type=str, default="huggingface", help="Model source: 'huggingface' or 'local'.") @@ -303,6 +304,92 @@ def generate_tts_audio(text_file, num_seeds, seed, speed, oral, laugh, bk, min_l except Exception as e: raise e +def generate_tts_audio_stream(text_file, num_seeds, seed, speed, oral, laugh, bk, min_length, batch_size, temperature, top_P, + top_K, roleid=None, refine_text=True, speaker_type="seed", pt_file=None): + from utils import split_text, replace_tokens, restore_tokens + from tts_model import deterministic + if seed in [0, -1, None]: + seed = random.randint(1, 9999) + content = '' + if os.path.isfile(text_file): + content = "" + elif isinstance(text_file, str): + content = text_file + # 将 [uv_break] [laugh] 替换为 _uv_break_ _laugh_ 处理后再还原 + content = replace_tokens(content) + texts = [normalize_zh(_) for _ in content.split('\n') if _.strip()] + + for i, text in enumerate(texts): + texts[i] = restore_tokens(text) + + print(texts) + + if oral < 0 or oral > 9 or laugh < 0 or laugh > 2 or bk < 0 or bk > 7: + raise ValueError("oral_(0-9), laugh_(0-2), break_(0-7) out of range") + + refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]" + + print(f"speaker_type: {speaker_type}") + if speaker_type == "seed": + if seed in [None, -1, 0, "", "random"]: + seed = np.random.randint(0, 9999) + deterministic(seed) + rnd_spk_emb = chat.sample_random_speaker() + elif speaker_type == "role": + # 从 JSON 文件中读取数据 + with open('./slct_voice_240605.json', 'r', encoding='utf-8') as json_file: + slct_idx_loaded = json.load(json_file) + # 将包含 Tensor 数据的部分转换回 Tensor 对象 + for key in slct_idx_loaded: + tensor_list = slct_idx_loaded[key]["tensor"] + slct_idx_loaded[key]["tensor"] = torch.tensor(tensor_list) + # 将音色 tensor 打包进params_infer_code,固定使用此音色发音,调低temperature + rnd_spk_emb = slct_idx_loaded[roleid]["tensor"] + # temperature = 0.001 + elif speaker_type == "pt": + print(pt_file) + rnd_spk_emb = torch.load(pt_file) + print(rnd_spk_emb.shape) + if rnd_spk_emb.shape != (768,): + raise ValueError("维度应为 768。") + else: + raise ValueError(f"Invalid speaker_type: {speaker_type}. ") + + params_infer_code = { + 'spk_emb': rnd_spk_emb, + 'prompt': f'[speed_{speed}]', + 'top_P': top_P, + 'top_K': top_K, + 'temperature': temperature + } + params_refine_text = { + 'prompt': refine_text_prompt, + 'top_P': top_P, + 'top_K': top_K, + 'temperature': temperature + } + + + for text in texts: + + wavs_gen = chat.infer(text, params_infer_code=params_infer_code, params_refine_text=params_refine_text, + use_decoder=True, skip_refine_text=True,stream=True) + + for gen in wavs_gen: + wavs = [np.array([[]])] + wavs[0] = np.hstack([wavs[0], np.array(gen[0])]) + audio = wavs[0][0] + + max_audio = np.abs(audio).max() # 简单防止16bit爆音 + if max_audio > 1: + audio /= max_audio + + yield 24000,(audio * 32768).astype(np.int16) + + clear_cuda_cache() + + + def generate_refine(text_file, oral, laugh, bk, temperature, top_P, top_K, progress=gr.Progress()): from tts_model import generate_refine_text @@ -514,9 +601,15 @@ def inser_token(text, btn): with gr.Row(): generate_button = gr.Button("生成音频", variant="primary") + generate_button_stream = gr.Button("流式生成音频(一边播放一边推理)", variant="primary") with gr.Row(): output_audio = gr.Audio(label="生成的音频文件") + output_audio_stream = gr.Audio(label="流式音频",value=None, + streaming=True, + autoplay=True, # disable auto play for Windows, due to https://developer.chrome.com/blog/autoplay#webaudio + interactive=False, + show_label=True) generate_audio_seed.click(generate_seed, inputs=[], @@ -590,6 +683,29 @@ def do_style_select(x): outputs=[output_audio] ) + generate_button_stream.click( + fn=generate_tts_audio_stream, + inputs=[ + text_file_input, + num_seeds_input, + seed_input, + speed_input, + oral_input, + laugh_input, + bk_input, + min_length_input, + batch_size_input, + temperature_input, + top_P_input, + top_K_input, + roleid_input, + refine_text_input, + speaker_stat, + pt_input + ], + outputs=[output_audio_stream] + ) + break_button.click( inser_token, inputs=[text_file_input, break_button],