Commit

Ignore ffmpeg in .gitignore
	Enable batch inference for the t2s model:   GPT_SoVITS/AR/models/t2s_model.py
	Fix a batching bug:   GPT_SoVITS/AR/models/utils.py
	Refactored TTS inference:   GPT_SoVITS/TTS_infer_pack/TTS.py
	Text preprocessing module:   GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
	new file:   GPT_SoVITS/TTS_infer_pack/__init__.py
	Text segmentation method module:   GPT_SoVITS/TTS_infer_pack/text_segmentation_method.py
	TTS inference config file:   GPT_SoVITS/configs/tts_infer.yaml
	modified:   GPT_SoVITS/feature_extractor/cnhubert.py
	modified:   GPT_SoVITS/inference_gui.py
	Refactored WebUI:   GPT_SoVITS/inference_webui.py
	new file:   GPT_SoVITS/inference_webui_old.py
ChasonJiang committed Mar 8, 2024
1 parent e04b3f6 commit 17832e5
Showing 12 changed files with 1,588 additions and 492 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -10,5 +10,6 @@ reference
GPT_weights
SoVITS_weights
TEMP

ffmpeg.exe
ffprobe.exe

53 changes: 44 additions & 9 deletions GPT_SoVITS/AR/models/t2s_model.py
@@ -1,5 +1,4 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
# modified from https://github.com/feng-yufei/shared_debugging_code/blob/main/model/t2s_model.py
import torch
from tqdm import tqdm

@@ -386,7 +385,9 @@ def infer_panel(
x.device
)


y_list = [None]*y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
for idx in tqdm(range(1500)):

xy_dec, _ = self.h((xy_pos, None), mask=xy_attn_mask, cache=cache)
@@ -397,17 +398,45 @@
if(idx==0):### the first step must not emit EOS, otherwise nothing is left
logits = logits[:, :-1] ### drop the probability of the 1024 EOS token
samples = sample(
logits[0], y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0].unsqueeze(0)
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=1.35, temperature=temperature
)[0]
# the semantic_ids generated this step are concatenated with the previous y to form the new y
# print(samples.shape)#[1,1]# the first 1 is the batch size
y = torch.concat([y, samples], dim=1)

# remove sequences that have finished generating
reserved_idx_of_batch_for_y = None
if (self.EOS in torch.argmax(logits, dim=-1)) or \
(self.EOS in samples[:, 0]): ### stop once a sequence generates EOS
l = samples[:, 0]==self.EOS
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx - 1
y_list[batch_index] = y[i, :-1]

batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]

# keep only the sequences that are still generating
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
if cache["y_emb"] is not None:
cache["y_emb"] = torch.index_select(cache["y_emb"], dim=0, index=reserved_idx_of_batch_for_y)
if cache["k"] is not None:
for i in range(self.num_layers):
# k/v are transposed, so the batch dim is 1
cache["k"][i] = torch.index_select(cache["k"][i], dim=1, index=reserved_idx_of_batch_for_y)
cache["v"][i] = torch.index_select(cache["v"][i], dim=1, index=reserved_idx_of_batch_for_y)


if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
stop = True

if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
if not (None in idx_list):
# print(torch.argmax(logits, dim=-1)[0] == self.EOS, samples[0, 0] == self.EOS)
stop = True
if stop:
@@ -443,6 +472,12 @@ def infer_panel(
xy_attn_mask = torch.zeros(
(1, x_len + y_len), dtype=torch.bool, device=xy_pos.device
)

if (None in idx_list):
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500-1 ### if a sequence never emitted EOS, substitute the maximum length

if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx-1
return y_list, [0]*x.shape[0]
return y_list, idx_list
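
The hunk above interleaves the batch bookkeeping with the transformer forward pass and KV-cache plumbing. Stripped of those details, the pattern is: remember each live row's original batch index, park finished rows in y_list/idx_list as soon as they emit EOS, and compact the live batch (and its caches) with torch.index_select. Here is a minimal, self-contained sketch of that pattern under stated assumptions — step_fn, eos_id, and the greedy argmax are illustrative stand-ins, not the commit's actual interfaces:

import torch

def batched_decode_sketch(step_fn, y, max_steps=1500, eos_id=1024):
    # step_fn(y) -> logits [B, vocab] stands in for the transformer forward
    # pass; eos_id=1024 mirrors self.EOS above. Greedy argmax keeps the
    # sketch short; the commit samples with top-k/top-p instead.
    bsz = y.shape[0]
    y_list = [None] * bsz                # finished sequences, original order
    idx_list = [None] * bsz              # stop step per sequence
    batch_idx_map = list(range(bsz))     # live row -> original batch index

    for idx in range(max_steps):
        logits = step_fn(y)                                   # [B, vocab]
        samples = torch.argmax(logits, dim=-1, keepdim=True)  # [B, 1]
        y = torch.concat([y, samples], dim=1)

        finished = samples[:, 0] == eos_id
        if finished.any():
            for i in torch.where(finished)[0].tolist():
                orig = batch_idx_map[i]
                idx_list[orig] = idx
                y_list[orig] = y[i, :-1]                      # strip the EOS token
            keep = torch.where(~finished)[0]
            batch_idx_map = [batch_idx_map[i] for i in keep.tolist()]
            y = torch.index_select(y, dim=0, index=keep)
            # a real KV cache is shrunk the same way, with index_select
            # along its batch dimension (dim=1 in the commit, since k/v
            # are transposed)
        if y.shape[0] == 0:
            break

    # rows that never emitted EOS keep whatever they generated so far
    for row, orig in enumerate(batch_idx_map):
        idx_list[orig] = max_steps - 1
        y_list[orig] = y[row]
    return y_list, idx_list

Tracking batch_idx_map is what lets the loop hand results back in the caller's original order even as index_select keeps shrinking the live batch.
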
12 changes: 6 additions & 6 deletions GPT_SoVITS/AR/models/utils.py
@@ -115,27 +115,27 @@ def logits_to_probs(
top_p: Optional[int] = None,
repetition_penalty: float = 1.0,
):
if previous_tokens is not None:
previous_tokens = previous_tokens.squeeze()
# if previous_tokens is not None:
# previous_tokens = previous_tokens.squeeze()
# print(logits.shape,previous_tokens.shape)
# pdb.set_trace()
if previous_tokens is not None and repetition_penalty != 1.0:
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
)
logits.scatter_(dim=0, index=previous_tokens, src=score)
logits.scatter_(dim=1, index=previous_tokens, src=score)

if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0, index=sorted_indices, src=sorted_indices_to_remove
dim=1, index=sorted_indices, src=sorted_indices_to_remove
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))

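The dim changes above are the substance of the batch fix: with batched inference, logits is [B, vocab] rather than a flat [vocab] vector, so the repetition-penalty gather/scatter and the top-p mask must operate along dim=1 and clear column 0 of every row, not element 0 of a single vector. A self-contained sketch of the corrected behavior — the function name, defaults, and the multinomial tail are illustrative assumptions, not the repo's sample() signature:

import torch
import torch.nn.functional as F

def batched_top_p_with_penalty(logits, previous_tokens=None,
                               top_p=0.9, repetition_penalty=1.35):
    # logits: [B, vocab]; previous_tokens: [B, T] of already-emitted ids.
    logits = logits.clone()

    if previous_tokens is not None and repetition_penalty != 1.0:
        prev = previous_tokens.long()
        # gather/scatter along dim=1: one row of history per batch element
        score = torch.gather(logits, dim=1, index=prev)
        score = torch.where(score < 0,
                            score * repetition_penalty,
                            score / repetition_penalty)
        logits.scatter_(dim=1, index=prev, src=score)

    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=1)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        to_remove = cum_probs > top_p
        to_remove[:, 0] = False          # keep at least one option per row
        indices_to_remove = to_remove.scatter(dim=1, index=sorted_indices,
                                              src=to_remove)
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)       # [B, 1]

# e.g. a batch of two rows over a 6-token vocabulary:
# next_ids = batched_top_p_with_penalty(
#     torch.randn(2, 6), previous_tokens=torch.tensor([[1, 2], [3, 3]]))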
