update

Files changed:

- TTS/tts/datasets/formatters.py +1 -0
- TTS/tts/layers/xtts/gpt.py +57 -0
- TTS/tts/layers/xtts/tokenizer.py +3 -2
- TTS/tts/layers/xtts/trainer/dataset.py +36 -24
- TTS/tts/layers/xtts/trainer/dvae_dataset.py +132 -0
- TTS/tts/layers/xtts/trainer/gpt_trainer.py +36 -1
- TTS/tts/models/xtts.py +3 -2
- app.py +50 -23
- local_model/__pycache__/inference.cpython-310.pyc +0 -0
- local_model/inference.py +198 -0
- requirements.txt +66 -17
TTS/tts/datasets/formatters.py
CHANGED

@@ -80,6 +80,7 @@ def coqui(root_path, meta_file, ignored_speakers=None):
             {
                 "text": row.text,
                 "audio_file": audio_path,
+                "ref_file": "null" if "ref_file" not in metadata.columns else os.path.join(root_path, row.ref_file),
                 "speaker_name": speaker_name if speaker_name is not None else row.speaker_name,
                 "emotion_name": emotion_name if emotion_name is not None else row.emotion_name,
                 "root_path": root_path,
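The new optional ref_file column lets each row of a coqui-format metadata file point at a separate reference clip for conditioning. A minimal sketch of the layout this assumes (the "|" separator and the audio_file/text columns follow the upstream coqui formatter; the paths and speaker name are invented):

```python
import io
import pandas as pd

# Hypothetical coqui-format metadata with the new optional ref_file column.
csv_text = """audio_file|text|speaker_name|ref_file
wavs/utt_001.wav|Màngi tuddu Aadama.|anta|refs/anta_sample.wav
wavs/utt_002.wav|Salaam aleekum.|anta|refs/anta_sample.wav
"""
metadata = pd.read_csv(io.StringIO(csv_text), sep="|")
print("ref_file" in metadata.columns)  # True -> the formatter emits a per-row ref_file path
```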
TTS/tts/layers/xtts/gpt.py
CHANGED

@@ -184,6 +184,63 @@ class GPT(nn.Module):
         # XTTS v1
         self.prompt_embedding = nn.Embedding(self.num_audio_tokens, model_dim)
         self.prompt_pos_embedding = LearnedPositionEmbeddings(24 * 9, model_dim)
+
+    def resize_text_embeddings(self, new_num_tokens: int):
+
+        old_embeddings_requires_grad = self.text_embedding.weight.requires_grad
+
+        old_num_tokens, old_embedding_dim = self.text_embedding.weight.size()
+        if old_num_tokens == new_num_tokens:
+            return
+
+        new_embeddings = nn.Embedding(
+            new_num_tokens,
+            old_embedding_dim,
+            device=self.text_embedding.weight.device,
+            dtype=self.text_embedding.weight.dtype,
+        )
+
+        # numbers of tokens to copy
+        n = min(old_num_tokens, new_num_tokens)
+
+        new_embeddings.weight.data[:n, :] = self.text_embedding.weight.data[:n, :]
+
+        self.text_embedding.weight.data = new_embeddings.weight.data
+        self.text_embedding.num_embeddings = new_embeddings.weight.data.shape[0]
+        if self.text_embedding.padding_idx is not None and (new_num_tokens - 1) < self.text_embedding.padding_idx:
+            self.text_embedding.padding_idx = None
+
+        self.text_embedding.requires_grad_(old_embeddings_requires_grad)
+
+    def resize_text_head(self, new_num_tokens: int):
+        old_lm_head_requires_grad = self.text_head.weight.requires_grad
+
+        old_num_tokens, old_lm_head_dim = self.text_head.weight.size()
+
+        new_lm_head_shape = (old_lm_head_dim, new_num_tokens)
+        has_new_lm_head_bias = self.text_head.bias is not None
+
+        new_lm_head = nn.Linear(
+            *new_lm_head_shape,
+            bias=has_new_lm_head_bias,
+            device=self.text_head.weight.device,
+            dtype=self.text_head.weight.dtype,
+        )
+
+        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+
+        new_lm_head.weight.data[:num_tokens_to_copy, :] = self.text_head.weight.data[:num_tokens_to_copy, :]
+
+        # Copy bias weights to new lm head
+        if has_new_lm_head_bias:
+            new_lm_head.bias.data[:num_tokens_to_copy] = self.text_head.bias.data[:num_tokens_to_copy]
+
+        self.text_head = new_lm_head
+
+        self.text_head.requires_grad_(old_lm_head_requires_grad)
+        pass
+
 
     def get_grad_norm_parameter_groups(self):
         return {
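Both helpers copy the rows shared between the old and new vocabulary size and leave any extra rows at their fresh initialisation. A self-contained sketch of the same copy-and-pad idea on standalone layers (the sizes are illustrative, not values taken from this Space):

```python
from torch import nn

# Illustrative sizes only: grow a 6681-token embedding/head to 6700 tokens,
# keeping the learned rows and leaving the new ones at their fresh init,
# mirroring what resize_text_embeddings()/resize_text_head() do inside GPT.
old_emb, new_emb = nn.Embedding(6681, 1024), nn.Embedding(6700, 1024)
n = min(old_emb.num_embeddings, new_emb.num_embeddings)
new_emb.weight.data[:n, :] = old_emb.weight.data[:n, :]

old_head, new_head = nn.Linear(1024, 6681), nn.Linear(1024, 6700)
new_head.weight.data[:n, :] = old_head.weight.data[:n, :]
new_head.bias.data[:n] = old_head.bias.data[:n]
print(new_emb.weight.shape, new_head.weight.shape)  # [6700, 1024] and [6700, 1024]
```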
TTS/tts/layers/xtts/tokenizer.py
CHANGED

@@ -621,7 +621,7 @@ class VoiceBpeTokenizer:
 
     def check_input_length(self, txt, lang):
         lang = lang.split("-")[0]  # remove the region
-        limit = self.char_limits.get(lang, 250)
+        limit = self.char_limits.get(lang, 300)
         if len(txt) > limit:
             print(
                 f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio."

@@ -640,7 +640,8 @@ class VoiceBpeTokenizer:
             # @manmay will implement this
             txt = basic_cleaners(txt)
         else:
-            raise NotImplementedError(f"Language '{lang}' is not supported.")
+            txt = basic_cleaners(txt)
+            # print(f"[!] Warning: Preprocess [Language '{lang}'] text is not implemented, use `basic_cleaners` instead.")
         return txt
 
     def encode(self, txt, lang):
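With this fallback, language codes missing from char_limits (such as Wolof, "wo") no longer raise NotImplementedError: the text goes through basic_cleaners and is checked against a 300-character limit. A quick sketch, assuming a local checkout of this fork and an XTTS v2 vocab file (the path is illustrative):

```python
from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer

# vocab.json path is illustrative; any XTTS v2 vocab file works.
tok = VoiceBpeTokenizer(vocab_file="XTTS_v2.0_original_model_files/vocab.json")
ids = tok.encode("Màngi tuddu Aadama.", lang="wo")  # previously raised for unsupported languages
print(len(ids))
```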
TTS/tts/layers/xtts/trainer/dataset.py
CHANGED

@@ -23,29 +23,41 @@ def key_samples_by_col(samples, col):
     return samples_by_col
 
 
-def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False):
-    rel_clip = load_audio(gt_path, sample_rate)
-    # if eval uses a middle size sample when it is possible to be more reproducible
-    if is_eval:
-        sample_length = int((min_sample_length + max_sample_length) / 2)
-    else:
-        sample_length = random.randint(min_sample_length, max_sample_length)
-    gap = rel_clip.shape[-1] - sample_length
-    if gap < 0:
-        sample_length = rel_clip.shape[-1] // 2
-        gap = rel_clip.shape[-1] - sample_length
-
-    # if eval start always from the position 0 to be more reproducible
-    if is_eval:
-        rand_start = 0
-    else:
-        rand_start = random.randint(0, gap)
-
-    rand_end = rand_start + sample_length
-    rel_clip = rel_clip[:, rand_start:rand_end]
-    rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
-    cond_idxs = [rand_start, rand_end]
-    return rel_clip, rel_clip.shape[-1], cond_idxs
+def get_prompt_slice(gt_path, max_sample_length, min_sample_length, sample_rate, is_eval=False, ref_path="null"):
+    if ref_path == "null":
+        rel_clip = load_audio(gt_path, sample_rate)
+        # if eval uses a middle size sample when it is possible to be more reproducible
+        if is_eval:
+            sample_length = int((min_sample_length + max_sample_length) / 2)
+        else:
+            sample_length = random.randint(min_sample_length, max_sample_length)
+        gap = rel_clip.shape[-1] - sample_length
+        if gap < 0:
+            sample_length = rel_clip.shape[-1] // 2
+            gap = rel_clip.shape[-1] - sample_length
+
+        # if eval start always from the position 0 to be more reproducible
+        if is_eval:
+            rand_start = 0
+        else:
+            rand_start = random.randint(0, gap)
+
+        rand_end = rand_start + sample_length
+        rel_clip = rel_clip[:, rand_start:rand_end]
+        rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
+        cond_idxs = [rand_start, rand_end]
+        return rel_clip, rel_clip.shape[-1], cond_idxs
+    else:
+        rel_clip = load_audio(ref_path, sample_rate)
+
+        sample_length = min(max_sample_length, rel_clip.shape[-1])
+
+        rel_clip = rel_clip[:, :sample_length]
+        rel_clip = F.pad(rel_clip, pad=(0, max_sample_length - rel_clip.shape[-1]))
+        cond_idxs = [0, sample_length]
+        return rel_clip, rel_clip.shape[-1], cond_idxs
+
 
 
 class XTTSDataset(torch.utils.data.Dataset):

@@ -110,14 +122,14 @@ class XTTSDataset(torch.utils.data.Dataset):
         wav = load_audio(audiopath, self.sample_rate)
         if text is None or len(text.strip()) == 0:
             raise ValueError
-        if wav is None or wav.shape[-1] < (0.5 * self.sample_rate):
+        if wav is None or wav.shape[-1] < (0.2 * self.sample_rate):
             # Ultra short clips are also useless (and can cause problems within some models).
             raise ValueError
 
         if self.use_masking_gt_prompt_approach:
             # get a slice from GT to condition the model
             cond, _, cond_idxs = get_prompt_slice(
-                audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
+                audiopath, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval, sample["ref_file"]
             )
             # if use masking do not use cond_len
             cond_len = torch.nan

@@ -128,7 +140,7 @@ class XTTSDataset(torch.utils.data.Dataset):
                 else audiopath
             )
             cond, cond_len, _ = get_prompt_slice(
-                ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval
+                ref_sample, self.max_conditioning_length, self.min_conditioning_length, self.sample_rate, self.is_eval, sample["ref_file"]
             )
             # if do not use masking use cond_len
             cond_idxs = torch.nan
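get_prompt_slice now takes a ref_path argument: any value other than "null" makes the conditioning prompt come from that reference file rather than from a random slice of the ground-truth clip. A usage sketch with invented paths (lengths are in samples at an assumed 22050 Hz):

```python
from TTS.tts.layers.xtts.trainer.dataset import get_prompt_slice

# Paths and lengths are illustrative.
cond, cond_len, cond_idxs = get_prompt_slice(
    gt_path="wavs/utt_001.wav",
    max_sample_length=6 * 22050,
    min_sample_length=3 * 22050,
    sample_rate=22050,
    is_eval=True,
    ref_path="refs/anta_sample.wav",  # "null" keeps the original slice-from-GT behaviour
)
print(cond.shape, cond_idxs)
```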
TTS/tts/layers/xtts/trainer/dvae_dataset.py
ADDED

@@ -0,0 +1,132 @@
+import torch
+import random
+from TTS.tts.models.xtts import load_audio
+
+torch.set_num_threads(1)
+
+def key_samples_by_col(samples, col):
+    """Returns a dictionary of samples keyed by language."""
+    samples_by_col = {}
+    for sample in samples:
+        col_val = sample[col]
+        assert isinstance(col_val, str)
+        if col_val not in samples_by_col:
+            samples_by_col[col_val] = []
+        samples_by_col[col_val].append(sample)
+    return samples_by_col
+
+class DVAEDataset(torch.utils.data.Dataset):
+    def __init__(self, samples, sample_rate, is_eval, max_wav_len=255995):
+        self.sample_rate = sample_rate
+        self.is_eval = is_eval
+        self.max_wav_len = max_wav_len
+        self.samples = samples
+        self.training_seed = 1
+        self.failed_samples = set()
+        if not is_eval:
+            random.seed(self.training_seed)
+            # random.shuffle(self.samples)
+            random.shuffle(self.samples)
+            # order by language
+            self.samples = key_samples_by_col(self.samples, "language")
+            print(" > Sampling by language:", self.samples.keys())
+        else:
+            # for evaluation load and check samples that are corrupted to ensures the reproducibility
+            self.check_eval_samples()
+
+    def check_eval_samples(self):
+        print(" > Filtering invalid eval samples!!")
+        new_samples = []
+        for sample in self.samples:
+            try:
+                _, wav = self.load_item(sample)
+            except:
+                continue
+            # Basically, this audio file is nonexistent or too long to be supported by the dataset.
+            if (
+                wav is None
+                or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
+            ):
+                continue
+            new_samples.append(sample)
+        self.samples = new_samples
+        print(" > Total eval samples after filtering:", len(self.samples))
+
+    def load_item(self, sample):
+        audiopath = sample["audio_file"]
+        wav = load_audio(audiopath, self.sample_rate)
+        if wav is None or wav.shape[-1] < (0.5 * self.sample_rate):
+            # Ultra short clips are also useless (and can cause problems within some models).
+            raise ValueError
+
+        return audiopath, wav
+
+    def __getitem__(self, index):
+        if self.is_eval:
+            sample = self.samples[index]
+            sample_id = str(index)
+        else:
+            # select a random language
+            lang = random.choice(list(self.samples.keys()))
+            # select random sample
+            index = random.randint(0, len(self.samples[lang]) - 1)
+            sample = self.samples[lang][index]
+            # a unique id for each sampel to deal with fails
+            sample_id = lang + "_" + str(index)
+
+        # ignore samples that we already know that is not valid ones
+        if sample_id in self.failed_samples:
+            # call get item again to get other sample
+            return self[1]
+
+        # try to load the sample, if fails added it to the failed samples list
+        try:
+            audiopath, wav = self.load_item(sample)
+        except:
+            self.failed_samples.add(sample_id)
+            return self[1]
+
+        # check if the audio and text size limits and if it out of the limits, added it failed_samples
+        if (
+            wav is None
+            or (self.max_wav_len is not None and wav.shape[-1] > self.max_wav_len)
+        ):
+            # Basically, this audio file is nonexistent or too long to be supported by the dataset.
+            # It's hard to handle this situation properly. Best bet is to return the a random valid token and skew the dataset somewhat as a result.
+            self.failed_samples.add(sample_id)
+            return self[1]
+
+        res = {
+            "wav": wav,
+            "wav_lengths": torch.tensor(wav.shape[-1], dtype=torch.long),
+            "filenames": audiopath,
+        }
+        return res
+
+    def __len__(self):
+        if self.is_eval:
+            return len(self.samples)
+        return sum([len(v) for v in self.samples.values()])
+
+    def collate_fn(self, batch):
+        # convert list of dicts to dict of lists
+        B = len(batch)
+
+        batch = {k: [dic[k] for dic in batch] for k in batch[0]}
+
+        # stack for features that already have the same shape
+        batch["wav_lengths"] = torch.stack(batch["wav_lengths"])
+
+        max_wav_len = batch["wav_lengths"].max()
+
+        # create padding tensors
+        wav_padded = torch.FloatTensor(B, 1, max_wav_len)
+
+        # initialize tensors for zero padding
+        wav_padded = wav_padded.zero_()
+        for i in range(B):
+            wav = batch["wav"][i]
+            wav_padded[i, :, : batch["wav_lengths"][i]] = torch.FloatTensor(wav)
+
+        batch["wav"] = wav_padded
+        return batch
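A minimal sketch of wiring DVAEDataset into a DataLoader with its own collate_fn; the sample dicts, paths and the 22050 Hz rate are assumptions for illustration (in a training script the samples would typically come from load_tts_samples):

```python
from torch.utils.data import DataLoader
from TTS.tts.layers.xtts.trainer.dvae_dataset import DVAEDataset

# Illustrative sample dicts; real entries carry the same keys.
train_samples = [
    {"audio_file": "wavs/utt_001.wav", "language": "wo"},
    {"audio_file": "wavs/utt_002.wav", "language": "wo"},
]
train_ds = DVAEDataset(train_samples, sample_rate=22050, is_eval=False)
loader = DataLoader(train_ds, batch_size=2, collate_fn=train_ds.collate_fn)
for batch in loader:
    print(batch["wav"].shape, batch["wav_lengths"])  # [B, 1, max_wav_len], [B]
    break
```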
TTS/tts/layers/xtts/trainer/gpt_trainer.py
CHANGED

@@ -97,7 +97,8 @@ class GPTTrainer(BaseTTS):
                 states_keys = list(gpt_checkpoint.keys())
                 for key in states_keys:
                     if "gpt." in key:
-                        new_key = key.replace("gpt.", "")
+                        # new_key = key.replace("gpt.", "")
+                        new_key = key[4:]
                         gpt_checkpoint[new_key] = gpt_checkpoint[key]
                         del gpt_checkpoint[key]
                     else:

@@ -484,6 +485,40 @@ class GPTTrainer(BaseTTS):
 
         state = self.xtts.get_compatible_checkpoint_state_dict(checkpoint_path)
 
+        # edit checkpoint if the number of tokens is changed to ensures the better transfer learning possible
+        if (
+            "gpt.text_embedding.weight" in state
+            and state["gpt.text_embedding.weight"].shape != self.xtts.gpt.text_embedding.weight.shape
+        ):
+            num_new_tokens = (
+                self.xtts.gpt.text_embedding.weight.shape[0] - state["gpt.text_embedding.weight"].shape[0]
+            )
+            print(f" > Loading checkpoint with {num_new_tokens} additional tokens.")
+
+            # add new tokens to a linear layer (text_head)
+            emb_g = state["gpt.text_embedding.weight"]
+            new_row = torch.randn(num_new_tokens, emb_g.shape[1])
+            start_token_row = emb_g[-1, :]
+            emb_g = torch.cat([emb_g, new_row], axis=0)
+            emb_g[-1, :] = start_token_row
+            state["gpt.text_embedding.weight"] = emb_g
+
+            # add new weights to the linear layer (text_head)
+            text_head_weight = state["gpt.text_head.weight"]
+            start_token_row = text_head_weight[-1, :]
+            new_entry = torch.randn(num_new_tokens, self.xtts.gpt.text_head.weight.shape[1])
+            text_head_weight = torch.cat([text_head_weight, new_entry], axis=0)
+            text_head_weight[-1, :] = start_token_row
+            state["gpt.text_head.weight"] = text_head_weight
+
+            # add new biases to the linear layer (text_head)
+            text_head_bias = state["gpt.text_head.bias"]
+            start_token_row = text_head_bias[-1]
+            new_bias_entry = torch.zeros(num_new_tokens)
+            text_head_bias = torch.cat([text_head_bias, new_bias_entry], axis=0)
+            text_head_bias[-1] = start_token_row
+            state["gpt.text_head.bias"] = text_head_bias
+
         # load the model weights
         self.xtts.load_state_dict(state, strict=strict)
 
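The switch from key.replace("gpt.", "") to key[4:] presumably guards against keys containing "gpt." more than once: replace() rewrites every occurrence, while slicing strips only the leading prefix. A quick illustration:

```python
# str.replace strips every occurrence, which corrupts nested parameter names;
# slicing off the first four characters removes only the leading "gpt." prefix.
key = "gpt.gpt.h.0.attn.c_attn.weight"   # example key with a nested "gpt." module
print(key.replace("gpt.", ""))  # h.0.attn.c_attn.weight
print(key[4:])                  # gpt.h.0.attn.c_attn.weight
```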
TTS/tts/models/xtts.py
CHANGED

@@ -523,7 +523,7 @@ class Xtts(BaseTTS):
         gpt_cond_latent = gpt_cond_latent.to(self.device)
         speaker_embedding = speaker_embedding.to(self.device)
         if enable_text_splitting:
-            text = split_sentence(text, language, self.tokenizer.char_limits[language])
+            text = split_sentence(text, language, self.tokenizer.char_limits.get(language, 250))
         else:
             text = [text]
 

@@ -553,6 +553,7 @@ class Xtts(BaseTTS):
                 output_attentions=False,
                 **hf_generate_kwargs,
             )
+
             expected_output_len = torch.tensor(
                 [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device
             )

@@ -633,7 +634,7 @@ class Xtts(BaseTTS):
         gpt_cond_latent = gpt_cond_latent.to(self.device)
         speaker_embedding = speaker_embedding.to(self.device)
         if enable_text_splitting:
-            text = split_sentence(text, language, self.tokenizer.char_limits[language])
+            text = split_sentence(text, language, self.tokenizer.char_limits.get(language, 250))
         else:
             text = [text]
 
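Both call sites now fall back to a 250-character limit when the language is missing from char_limits, which is what makes enable_text_splitting usable for Wolof. A tiny sketch of the difference (the dict excerpt is illustrative, not the full XTTS table):

```python
char_limits = {"en": 250, "fr": 273}    # illustrative excerpt of the tokenizer's table
language = "wo"
# old behaviour: char_limits[language] -> KeyError for codes not in the table
limit = char_limits.get(language, 250)  # new behaviour: default to 250
print(limit)
```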
app.py
CHANGED

@@ -6,45 +6,72 @@ import sys
 import soundfile as sf
 import numpy as np
 import logging
+import tempfile
 
 # Logger configuration
 logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
 
 # Local download path for the files (make sure the directory exists)
-LOCAL_DOWNLOAD_PATH = "
-
+LOCAL_DOWNLOAD_PATH = os.path.dirname("/content")  # Use the script's path
 # Download the inference script
 repo_id = "dofbi/galsenai-xtts-v2-wolof-inference"
-inference_file = hf_hub_download(repo_id=repo_id, filename="inference.py",local_dir=LOCAL_DOWNLOAD_PATH)
+inference_file = hf_hub_download(repo_id=repo_id, filename="inference.py", local_dir=LOCAL_DOWNLOAD_PATH)
 
 # Add the directory to the search path
 sys.path.insert(0, LOCAL_DOWNLOAD_PATH)
 
-# Importer la
-from inference import
-    ...
+# Import the class from the downloaded inference script
+from inference import WolofXTTSInference
+
+# Initialise the model only once
+tts_model = WolofXTTSInference()
+
+def tts(text: str, audio_reference: tuple[int, np.ndarray]) -> tuple[int, np.ndarray] | str:
+    """
+    Synthesise speech from a text, using a reference audio clip.
+
+    Args:
+        text (str): The text to synthesise.
+        audio_reference (tuple[int, np.ndarray]): A tuple with the sample rate and the reference audio data.
+
+    Returns:
+        tuple[int, np.ndarray] | str: a tuple with the sample rate and the synthesised audio data, or an error message.
+    """
     logging.debug(f"tts function called with text: {text} and audio_reference: {audio_reference}")
-    ...
-        temp_audio_path = "temp_audio_ref.wav"
-        sf.write(temp_audio_path, audio_reference, 44100)
-        logging.debug(f"Audio reference saved to {temp_audio_path}")
-        audio_output, sample_rate = generate_audio(text, temp_audio_path, LOCAL_DOWNLOAD_PATH)
-        logging.debug(f"Audio generated with sample rate: {sample_rate}")
-        return (sample_rate, audio_output)
-    else:
+
+    if not text or audio_reference is None:
         logging.debug("Text or audio reference is missing")
         return "Veuillez entrer un texte et fournir un audio de référence."
 
-    ...
+    try:
+        sample_rate, audio_data = audio_reference
+
+        # Create a temporary file for the reference audio
+        with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_audio_file:
+            sf.write(temp_audio_file.name, audio_data, sample_rate)
+            logging.debug(f"Audio reference saved to {temp_audio_file.name}")
+
+            # Use the generate_audio method of the new class
+            audio_output, output_sample_rate = tts_model.generate_audio(
+                text,
+                reference_audio=temp_audio_file.name
+            )
+
+            logging.debug(f"Audio generated with sample rate: {output_sample_rate}")
+            return (output_sample_rate, audio_output)
+
+    except Exception as e:
+        logging.error(f"Error during audio generation: {e}")
+        return f"Une erreur s'est produite lors de la génération audio: {e}"
 
 if __name__ == "__main__":
+    demo = gr.Interface(
+        fn=tts,
+        inputs=[
+            gr.Textbox(label="Text to synthesize"),
+            gr.Audio(type="numpy", label="Reference audio")
+        ],
+        outputs=gr.Audio(label="Synthesized audio"),
+    )
+
     demo.launch()
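A hypothetical smoke test of the new tts() handler without launching the Gradio UI; the one-second silent reference and the 22050 Hz rate are made-up inputs, not values from this Space, and running it still triggers the model download:

```python
import numpy as np

sr = 22050
silence = np.zeros(sr, dtype=np.float32)   # 1 s of silence as a stand-in reference clip
out = tts("Màngi tuddu Aadama.", (sr, silence))
if isinstance(out, tuple):
    out_sr, wav = out
    print(out_sr, np.asarray(wav).shape)
else:
    print("error message:", out)
```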
local_model/__pycache__/inference.cpython-310.pyc
ADDED

Binary file (5.07 kB).
local_model/inference.py
ADDED

@@ -0,0 +1,198 @@
+import torch
+import os
+import logging
+import soundfile as sf
+import numpy as np
+from huggingface_hub import hf_hub_download
+from TTS.tts.configs.xtts_config import XttsConfig
+from TTS.tts.models.xtts import Xtts
+
+# --- CONSTANTS ---
+REPO_ID = "dofbi/galsenai-xtts-v2-wolof-inference"
+LOCAL_DIR = "./models"
+
+class WolofXTTSInference:
+    def __init__(self, repo_id=REPO_ID, local_dir=LOCAL_DIR):
+        # Logging configuration
+        logging.basicConfig(
+            level=logging.INFO,
+            format='%(asctime)s - %(levelname)s - %(message)s'
+        )
+        self.logger = logging.getLogger(__name__)
+
+        # Create the local directory if it does not exist
+        os.makedirs(local_dir, exist_ok=True)
+
+        # Download the required files
+        try:
+            # Create the required sub-directories
+            os.makedirs(os.path.join(local_dir, "Anta_GPT_XTTS_Wo"), exist_ok=True)
+            os.makedirs(os.path.join(local_dir, "XTTS_v2.0_original_model_files"), exist_ok=True)
+
+            # Download the checkpoint
+            self.model_path = hf_hub_download(
+                repo_id=repo_id,
+                filename="Anta_GPT_XTTS_Wo/best_model_89250.pth",
+                local_dir=local_dir
+            )
+
+            # Download the configuration file
+            self.config_path = hf_hub_download(
+                repo_id=repo_id,
+                filename="Anta_GPT_XTTS_Wo/config.json",
+                local_dir=local_dir
+            )
+
+            # Download the vocabulary
+            self.vocab_path = hf_hub_download(
+                repo_id=repo_id,
+                filename="XTTS_v2.0_original_model_files/vocab.json",
+                local_dir=local_dir
+            )
+
+            # Download the reference audio
+            self.reference_audio = hf_hub_download(
+                repo_id=repo_id,
+                filename="anta_sample.wav",
+                local_dir=local_dir
+            )
+
+        except Exception as e:
+            self.logger.error(f"Erreur lors du téléchargement des fichiers : {e}")
+            raise
+
+        # Device selection
+        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+        # Model initialisation
+        self.model = self._load_model()
+
+    def _load_model(self):
+        """Load the XTTS model."""
+        try:
+            self.logger.info("Chargement du modèle XTTS...")
+
+            # Initialise the model
+            config = XttsConfig()
+            config.load_json(self.config_path)
+            model = Xtts.init_from_config(config)
+
+            # Load the checkpoint with load_checkpoint
+            model.load_checkpoint(config,
+                                  checkpoint_path=self.model_path,
+                                  vocab_path=self.vocab_path,
+                                  use_deepspeed=False
+                                  )
+
+            model.to(self.device)
+            model.eval()  # Put the model in evaluation mode
+
+            self.logger.info("Modèle chargé avec succès!")
+            return model
+
+        except Exception as e:
+            self.logger.error(f"Erreur lors du chargement du modèle : {e}")
+            raise
+
+    def generate_audio(
+        self,
+        text: str,
+        reference_audio: str = None,
+        speed: float = 1.06,
+        language: str = "wo",
+        output_path: str = None
+    ) -> tuple[np.ndarray, int]:
+        """
+        Generate audio from the given text.
+
+        Args:
+            text (str): Text to convert to audio.
+            reference_audio (str, optional): Path to the reference audio. Defaults to None.
+            speed (float, optional): Playback speed. Defaults to 1.06.
+            language (str, optional): Language of the text. Defaults to "wo".
+            output_path (str, optional): Where to save the generated audio. Defaults to None.
+
+        Returns:
+            tuple[np.ndarray, int]: audio_array, sample_rate
+        """
+        if not text:
+            raise ValueError("Le texte ne peut pas être vide.")
+
+        try:
+            # Use the provided reference audio, or the default one
+            ref_audio = reference_audio or self.reference_audio
+
+            # Get the conditioning embeddings
+            gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents(
+                audio_path=[ref_audio],
+                gpt_cond_len=self.model.config.gpt_cond_len,
+                max_ref_length=self.model.config.max_ref_len,
+                sound_norm_refs=self.model.config.sound_norm_refs
+            )
+
+            # Generate the audio
+            result = self.model.inference(
+                text=text.lower(),
+                gpt_cond_latent=gpt_cond_latent,
+                speaker_embedding=speaker_embedding,
+                do_sample=False,
+                speed=speed,
+                language=language,
+                enable_text_splitting=True
+            )
+
+            # Retrieve the sample rate
+            sample_rate = self.model.config.audio.sample_rate
+
+            # Optional save to disk
+            if output_path:
+                sf.write(output_path, result["wav"], sample_rate)
+                self.logger.info(f"Audio sauvegardé dans {output_path}")
+
+            return result["wav"], sample_rate
+
+        except Exception as e:
+            self.logger.error(f"Erreur lors de la génération de l'audio : {e}")
+            raise
+
+    def generate_audio_from_config(self, text: str, config: dict, output_path: str = None) -> tuple[np.ndarray, int]:
+        """
+        Generate audio from a text and a configuration dictionary.
+
+        Args:
+            text (str): Text to convert to audio.
+            config (dict): Configuration dictionary (speed, language, reference_audio).
+            output_path (str, optional): Where to save the generated audio. Defaults to None.
+
+        Returns:
+            tuple[np.ndarray, int]: audio_array, sample_rate
+        """
+        speed = config.get('speed', 1.06)
+        language = config.get('language', "wo")
+        reference_audio = config.get('reference_audio', None)
+        return self.generate_audio(text=text, reference_audio=reference_audio, speed=speed, language=language, output_path=output_path)
+
+
+# Example usage
+if __name__ == "__main__":
+    tts = WolofXTTSInference()
+
+    # Audio generation example
+    text = "Màngi tuddu Aadama, di baat bii waa Galsen A.I defar ngir wax ak yéen ci wolof!"
+
+    # Simple
+    audio, sr = tts.generate_audio(
+        text,
+        output_path="generated_audio.wav"
+    )
+
+    # With a config
+    config_gen_audio = {
+        "speed": 1.2,
+        "language": "wo",
+    }
+    audio, sr = tts.generate_audio_from_config(
+        text=text,
+        config=config_gen_audio,
+        output_path="generated_audio_config.wav"
+    )
requirements.txt
CHANGED

@@ -1,24 +1,73 @@
-...
+gradio
+# core deps
+numpy==1.23.0;python_version<="3.10"
+numpy>=1.24.3;python_version>"3.10"
+matplotlib
+cython>=0.29.30
+scipy>=1.11.2
+torch>=2.1
 torchaudio
-soundfile
 transformers
-...
-huggingface_hub
-tqdm
-coqpit
+gdown
 trainer
+soundfile>=0.12.0
+librosa>=0.10.0
+scikit-learn>=1.3.0
+numba==0.55.1;python_version<"3.9"
+numba>=0.57.0;python_version>="3.9"
+inflect>=5.6.0
+tqdm>=4.64.1
+anyascii>=0.3.0
+pyyaml>=6.0
+fsspec>=2023.6.0 # <= 2023.9.1 makes aux tests fail
+aiohttp>=3.8.1
+packaging>=23.1
+mutagen==1.47.0
 librosa
-...
+# deps for examples
+flask>=2.0.1
+# deps for inference
+pysbd>=0.3.4
+# deps for notebooks
+umap-learn>=0.5.1
+pandas>=1.4,<2.0
+# deps for training
+matplotlib>=3.7.0
+# coqui stack
+trainer>=0.0.36
+# config management
+coqpit>=0.0.16
+# chinese g2p deps
+jieba
 pypinyin
+# korean
 hangul_romanize
+# gruut+supported langs
+gruut[de,es,fr]==2.2.3
+# deps for korean
+jamo
+nltk
+g2pkk>=0.1.1
+# deps for bangla
+bangla
+bnnumerizer
+bnunicodenormalizer
+# deps for tortoise
+einops>=0.6.0
+transformers>=4.45.2
+# deps for bark
+encodec>=0.1.1
+# deps for XTTS
+unidecode>=1.3.2
 num2words
-spacy
-...
-hmmlearn
-...
+# spacy[ja]>=3
+tokenizers==0.20.1
+vinorm==2.0.7
+underthesea==6.8.4
+# remove silence
+hmmlearn==0.3.3
+eyed3==0.9.7
+pesq==0.0.4
+pydub==0.25.1
+pyAudioAnalysis==0.3.14
+ffmpeg-python==0.2.0