Skip to content

Commit

Permalink
Update webui_mix.py
Browse files Browse the repository at this point in the history
  • Loading branch information
v3ucn committed Jun 20, 2024
1 parent 659cbd4 commit a030df8
Showing 1 changed file with 116 additions and 0 deletions.
116 changes: 116 additions & 0 deletions webui_mix.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from tts_model import load_chat_tts_model, clear_cuda_cache, generate_audio_for_seed
from config import DEFAULT_BATCH_SIZE, DEFAULT_SPEED, DEFAULT_TEMPERATURE, DEFAULT_TOP_K, DEFAULT_TOP_P, DEFAULT_ORAL, \
DEFAULT_LAUGH, DEFAULT_BK, DEFAULT_SEG_LENGTH
import torch

parser = argparse.ArgumentParser(description="Gradio ChatTTS MIX")
parser.add_argument("--source", type=str, default="huggingface", help="Model source: 'huggingface' or 'local'.")
Expand Down Expand Up @@ -303,6 +304,92 @@ def generate_tts_audio(text_file, num_seeds, seed, speed, oral, laugh, bk, min_l
except Exception as e:
raise e

def generate_tts_audio_stream(text_file, num_seeds, seed, speed, oral, laugh, bk, min_length, batch_size, temperature, top_P,
top_K, roleid=None, refine_text=True, speaker_type="seed", pt_file=None):
from utils import split_text, replace_tokens, restore_tokens
from tts_model import deterministic
if seed in [0, -1, None]:
seed = random.randint(1, 9999)
content = ''
if os.path.isfile(text_file):
content = ""
elif isinstance(text_file, str):
content = text_file
# 将 [uv_break] [laugh] 替换为 _uv_break_ _laugh_ 处理后再还原
content = replace_tokens(content)
texts = [normalize_zh(_) for _ in content.split('\n') if _.strip()]

for i, text in enumerate(texts):
texts[i] = restore_tokens(text)

print(texts)

if oral < 0 or oral > 9 or laugh < 0 or laugh > 2 or bk < 0 or bk > 7:
raise ValueError("oral_(0-9), laugh_(0-2), break_(0-7) out of range")

refine_text_prompt = f"[oral_{oral}][laugh_{laugh}][break_{bk}]"

print(f"speaker_type: {speaker_type}")
if speaker_type == "seed":
if seed in [None, -1, 0, "", "random"]:
seed = np.random.randint(0, 9999)
deterministic(seed)
rnd_spk_emb = chat.sample_random_speaker()
elif speaker_type == "role":
# 从 JSON 文件中读取数据
with open('./slct_voice_240605.json', 'r', encoding='utf-8') as json_file:
slct_idx_loaded = json.load(json_file)
# 将包含 Tensor 数据的部分转换回 Tensor 对象
for key in slct_idx_loaded:
tensor_list = slct_idx_loaded[key]["tensor"]
slct_idx_loaded[key]["tensor"] = torch.tensor(tensor_list)
# 将音色 tensor 打包进params_infer_code,固定使用此音色发音,调低temperature
rnd_spk_emb = slct_idx_loaded[roleid]["tensor"]
# temperature = 0.001
elif speaker_type == "pt":
print(pt_file)
rnd_spk_emb = torch.load(pt_file)
print(rnd_spk_emb.shape)
if rnd_spk_emb.shape != (768,):
raise ValueError("维度应为 768。")
else:
raise ValueError(f"Invalid speaker_type: {speaker_type}. ")

params_infer_code = {
'spk_emb': rnd_spk_emb,
'prompt': f'[speed_{speed}]',
'top_P': top_P,
'top_K': top_K,
'temperature': temperature
}
params_refine_text = {
'prompt': refine_text_prompt,
'top_P': top_P,
'top_K': top_K,
'temperature': temperature
}


for text in texts:

wavs_gen = chat.infer(text, params_infer_code=params_infer_code, params_refine_text=params_refine_text,
use_decoder=True, skip_refine_text=True,stream=True)

for gen in wavs_gen:
wavs = [np.array([[]])]
wavs[0] = np.hstack([wavs[0], np.array(gen[0])])
audio = wavs[0][0]

max_audio = np.abs(audio).max() # 简单防止16bit爆音
if max_audio > 1:
audio /= max_audio

yield 24000,(audio * 32768).astype(np.int16)

clear_cuda_cache()




def generate_refine(text_file, oral, laugh, bk, temperature, top_P, top_K, progress=gr.Progress()):
from tts_model import generate_refine_text
Expand Down Expand Up @@ -514,9 +601,15 @@ def inser_token(text, btn):

with gr.Row():
generate_button = gr.Button("生成音频", variant="primary")
generate_button_stream = gr.Button("流式生成音频(一边播放一边推理)", variant="primary")

with gr.Row():
output_audio = gr.Audio(label="生成的音频文件")
output_audio_stream = gr.Audio(label="流式音频",value=None,
streaming=True,
autoplay=True, # disable auto play for Windows, due to https://developer.chrome.com/blog/autoplay#webaudio
interactive=False,
show_label=True)

generate_audio_seed.click(generate_seed,
inputs=[],
Expand Down Expand Up @@ -590,6 +683,29 @@ def do_style_select(x):
outputs=[output_audio]
)

generate_button_stream.click(
fn=generate_tts_audio_stream,
inputs=[
text_file_input,
num_seeds_input,
seed_input,
speed_input,
oral_input,
laugh_input,
bk_input,
min_length_input,
batch_size_input,
temperature_input,
top_P_input,
top_K_input,
roleid_input,
refine_text_input,
speaker_stat,
pt_input
],
outputs=[output_audio_stream]
)

break_button.click(
inser_token,
inputs=[text_file_input, break_button],
Expand Down

0 comments on commit a030df8

Please sign in to comment.