import tempfile
import os
import time
import librosa
import torch
import argparse
import soundfile as sf

import cn2an
import requests
import re
import numpy as np
import onnxruntime as ort
import axengine as axe
import threading
import queue
from collections import deque

from model import SinusoidalPositionEncoder
from utils.ax_model_bin import AX_SenseVoiceSmall
from utils.ax_vad_bin import AX_Fsmn_vad
from utils.vad_utils import merge_vad
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer

from libmelotts.python.split_utils import split_sentence
from libmelotts.python.text import cleaned_text_to_sequence
from libmelotts.python.text.cleaner import clean_text
from libmelotts.python.symbols import LANG_TO_SYMBOL_MAP

TTS_MODEL_DIR = "libmelotts/models"
TTS_MODEL_FILES = {
    "g": "g-zh_mix_en.bin",
    "encoder": "encoder-zh.onnx",
    "decoder": "decoder-zh.axmodel",
}

QWEN_API_URL = ""  # placeholder default; the actual URL is supplied via --api_url

def intersperse(lst, item):
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

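# Quick sketch of the interleaving (illustrative values):
#   intersperse([5, 9, 2], 0)  ->  [0, 5, 0, 9, 0, 2, 0]
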
def get_text_for_tts_infer(text, language_str, symbol_to_id=None):
    """Patched phoneme processing: keeps all output arrays the same length."""
    try:
        norm_text, phone, tone, word2ph = clean_text(text, language_str)

        # IPA symbols the model's symbol table does not cover; they are dropped.
        phone_mapping = {
            'ɛ': '', 'æ': '', 'ʌ': '', 'ʊ': '', 'ɔ': '', 'ɪ': '', 'ɝ': '', 'ɚ': '', 'ɑ': '',
            'ʒ': '', 'θ': '', 'ð': '', 'ŋ': '', 'ʃ': '', 'ʧ': '', 'ʤ': '', 'ː': '', 'ˈ': '',
            'ˌ': '', 'ʰ': '', 'ʲ': '', 'ʷ': '', 'ʔ': '', 'ɾ': '', 'ɹ': '', 'ɫ': '', 'ɡ': '',
        }

        processed_phone = []
        processed_tone = []
        removed_symbols = set()

        for p, t in zip(phone, tone):
            if p in phone_mapping:
                removed_symbols.add(p)
            elif p in symbol_to_id:
                processed_phone.append(p)
                processed_tone.append(t)
            else:
                removed_symbols.add(p)

        if removed_symbols:
            print(f"[phoneme filter] removed {len(removed_symbols)} special phonemes: {sorted(removed_symbols)}")
            print(f"[phoneme filter] phoneme sequence length after filtering: {len(processed_phone)}")
            print(f"[phoneme filter] tone sequence length after filtering: {len(processed_tone)}")

        if not processed_phone:
            print("[warning] no valid phonemes; falling back to default Chinese phonemes")
            processed_phone = ['ni', 'hao']
            processed_tone = ['1', '3']
            word2ph = [1, 1]

        if len(processed_phone) != len(phone):
            print(f"[warning] phoneme sequence length changed: {len(phone)} -> {len(processed_phone)}")
            # Rebuild word2ph so its sum matches the filtered phoneme count.
            word2ph = [1] * len(processed_phone)

        phone, tone, language = cleaned_text_to_sequence(processed_phone, processed_tone, language_str, symbol_to_id)

        phone = intersperse(phone, 0)
        tone = intersperse(tone, 0)
        language = intersperse(language, 0)

        phone = np.array(phone, dtype=np.int32)
        tone = np.array(tone, dtype=np.int32)
        language = np.array(language, dtype=np.int32)
        word2ph = np.array(word2ph, dtype=np.int32) * 2
        word2ph[0] += 1

        return phone, tone, language, norm_text, word2ph

    except Exception as e:
        print(f"[error] text processing failed: {e}")
        import traceback
        traceback.print_exc()
        raise

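# Note on the word2ph bookkeeping above: interspersing a 0 after every phoneme
# doubles each word's phoneme count (hence the `* 2`), and the one extra
# leading 0 is credited to the first word (hence `word2ph[0] += 1`).
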
def audio_numpy_concat(segment_data_list, sr, speed=1.):
    audio_segments = []
    for segment_data in segment_data_list:
        audio_segments += segment_data.reshape(-1).tolist()
        audio_segments += [0] * int((sr * 0.05) / speed)
    audio_segments = np.array(audio_segments).astype(np.float32)
    return audio_segments

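# Example (illustrative numbers): at sr=44100 and speed=1.0, each segment is
# followed by int(44100 * 0.05) = 2205 zero samples, i.e. a 50 ms pause.
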
def merge_sub_audio(sub_audio_list, pad_size, audio_len):
    # Average the overlapping pad region between consecutive chunks.
    if pad_size > 0:
        for i in range(len(sub_audio_list) - 1):
            sub_audio_list[i][-pad_size:] += sub_audio_list[i+1][:pad_size]
            sub_audio_list[i][-pad_size:] /= 2
            if i > 0:
                sub_audio_list[i] = sub_audio_list[i][pad_size:]

    sub_audio = np.concatenate(sub_audio_list, axis=-1)
    return sub_audio[:audio_len]

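# Sketch of the pad handling: with pad_size=P, the last P samples of chunk i
# and the first P samples of chunk i+1 are averaged to smooth the seam.
# run_tts below calls this with pad_size=0, which reduces to a plain
# concatenation truncated to audio_len.
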
def calc_word2pronoun(word2ph, pronoun_lens):
    # Prefix offsets of each word's phonemes into pronoun_lens.
    indice = [0]
    for ph in word2ph[:-1]:
        indice.append(indice[-1] + ph)
    word2pronoun = []
    for i, ph in zip(indice, word2ph):
        word2pronoun.append(np.sum(pronoun_lens[i : i + ph]))
    return word2pronoun

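# Worked example (illustrative values): word2ph = [2, 1] and
# pronoun_lens = [3, 4, 5] give indice = [0, 2], so
# calc_word2pronoun -> [3 + 4, 5] = [7, 5] frames per word.
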
def generate_slices(word2pronoun, dec_len):
    pn_start, pn_end = 0, 0
    zp_start, zp_end = 0, 0
    zp_len = 0
    pn_slices = []
    zp_slices = []
    while pn_end < len(word2pronoun):
        # Re-use the last two words as left context when they still fit.
        if pn_end - pn_start > 2 and np.sum(word2pronoun[pn_end - 2 : pn_end + 1]) <= dec_len:
            zp_len = np.sum(word2pronoun[pn_end - 2 : pn_end])
            zp_start = zp_end - zp_len
            pn_start = pn_end - 2
        else:
            zp_len = 0
            zp_start = zp_end
            pn_start = pn_end

        while pn_end < len(word2pronoun) and zp_len + word2pronoun[pn_end] <= dec_len:
            zp_len += word2pronoun[pn_end]
            pn_end += 1
        zp_end = zp_start + zp_len
        pn_slices.append(slice(pn_start, pn_end))
        zp_slices.append(slice(zp_start, zp_end))
    return pn_slices, zp_slices

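# Worked example (illustrative): word2pronoun = [3, 3, 3, 3, 3] with dec_len=9
# yields pn_slices [0:3, 1:4, 2:5] and zp_slices [0:9, 3:12, 6:15] -- each new
# window carries the previous two words (6 frames) as left context.
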
def lang_detect_with_regex(text):
    # Digits are language-neutral; strip them before detecting.
    text_without_digits = re.sub(r'\d+', '', text)

    if not text_without_digits:
        return 'unknown'

    if re.search(r'[\u4e00-\u9fff]', text_without_digits):
        return 'chinese'
    elif re.search(r'[a-zA-Z]', text_without_digits):
        return 'english'
    else:
        return 'unknown'

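# Examples: "你好world" -> 'chinese' (CJK wins), "hello 123" -> 'english',
# "123" -> 'unknown' (digits are stripped first).
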
class QwenTranslationAPI:
    def __init__(self, api_url=QWEN_API_URL):
        self.api_url = api_url
        self.session_id = f"speech_translate_{int(time.time())}"
        self.last_reset_time = time.time()
        self.request_count = 0
        self.max_requests_before_reset = 10

    def reset_context(self):
        """Reset the API conversation context."""
        try:
            reset_url = f"{self.api_url}/api/reset"
            response = requests.post(reset_url, json={}, timeout=5)
            if response.status_code == 200:
                print("[translation API] ✓ context reset succeeded")
                self.last_reset_time = time.time()
                self.request_count = 0
                return True
            else:
                print(f"[translation API] ✗ reset failed, status: {response.status_code}, body: {response.text}")
        except Exception as e:
            print(f"[translation API] ✗ failed to reset context: {e}")
        return False

    def check_and_reset_if_needed(self):
        """Reset the context when the request count or elapsed time gets too high."""
        current_time = time.time()
        if (self.request_count >= self.max_requests_before_reset or
                current_time - self.last_reset_time > 120):
            print(f"[translation API] resetting (requests: {self.request_count}, elapsed: {current_time - self.last_reset_time:.1f}s)")
            return self.reset_context()
        return True

    def translate(self, text_content, max_retries=3, timeout=120):
        if not text_content or text_content.strip() == "":
            return "input text is empty"

        # Very short inputs are returned unchanged.
        if len(text_content.strip()) < 3:
            return text_content

        if lang_detect_with_regex(text_content) == 'chinese':
            prompt_f = "翻译成英文"   # "translate into English"
        else:
            prompt_f = "翻译成中文"   # "translate into Chinese"

        prompt = f"{prompt_f}:{text_content}"
        print(f"[translation API] sending request: {prompt}")

        self.check_and_reset_if_needed()

        for attempt in range(max_retries):
            try:
                generate_url = f"{self.api_url}/api/generate"
                payload = {
                    "prompt": prompt,
                    "temperature": 0.1,
                    "repetition_penalty": 1.0,
                    "top-p": 0.9,
                    "top-k": 40,
                    "max_new_tokens": 512
                }

                print(f"[translation API] starting generate request (attempt {attempt + 1}/{max_retries})")
                response = requests.post(generate_url, json=payload, timeout=30)
                response.raise_for_status()
                print("[translation API] generate request accepted")

                # Poll the provider endpoint until generation is done.
                result_url = f"{self.api_url}/api/generate_provider"
                start_time = time.time()
                full_translation = ""
                error_detected = False

                while time.time() - start_time < timeout:
                    try:
                        result_response = requests.get(result_url, timeout=10)
                        result_data = result_response.json()

                        current_chunk = result_data.get("response", "")

                        if "error:" in current_chunk.lower() or "setkvcache failed" in current_chunk.lower():
                            print(f"[translation API] ✗ error detected: {current_chunk}")
                            error_detected = True
                            print("[translation API] resetting context immediately...")
                            self.reset_context()
                            break

                        full_translation += current_chunk

                        if result_data.get("done", False):
                            if full_translation and len(full_translation.strip()) > 0:
                                self.request_count += 1
                                print(f"[translation API] ✓ translation finished: {full_translation}")
                                return full_translation
                            else:
                                print("[translation API] ✗ translation result is empty")
                                break

                        time.sleep(0.05)

                    except requests.exceptions.RequestException as e:
                        print(f"[translation API] polling request failed: {e}")
                        if time.time() - start_time > timeout:
                            break
                        time.sleep(0.05)
                        continue

                if error_detected:
                    if attempt < max_retries - 1:
                        wait_time = 1
                        print(f"[translation API] retrying in {wait_time}s...")
                        time.sleep(wait_time)
                        continue
                    else:
                        print("[translation API] max retries reached, returning source text")
                        return text_content

                print(f"[translation API] polling timed out, starting retry {attempt + 1}")

            except requests.exceptions.RequestException as e:
                print(f"[translation API] request failed (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt < max_retries - 1:
                    wait_time = 2 ** attempt
                    print(f"[translation API] retrying in {wait_time}s...")
                    time.sleep(wait_time)
                else:
                    return text_content
            except Exception as e:
                print(f"[translation API] translation error: {e}")
                if attempt < max_retries - 1:
                    time.sleep(1)
                    continue
                return text_content

        print("[translation API] translation timed out, returning source text")
        return text_content

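# Minimal usage sketch (hypothetical URL; assumes the Qwen HTTP service from
# --api_url is running):
#
#   translator = QwenTranslationAPI(api_url="http://localhost:8000")
#   translator.reset_context()
#   print(translator.translate("今天天气不错"))   # -> English translation
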
class AudioResampler:
    """Audio resampler."""

    def __init__(self, target_sr=16000):
        self.target_sr = target_sr

    def resample_audio(self, audio_data, original_sr):
        """Resample audio to the target rate; ASR always takes 16000 Hz input."""
        if original_sr == self.target_sr:
            return audio_data

        print(f"[resample] {original_sr}Hz -> {self.target_sr}Hz")
        return librosa.resample(y=audio_data, orig_sr=original_sr, target_sr=self.target_sr)

    def resample_chunk(self, audio_chunk, original_sr):
        """Resample one chunk; if the long audio was resampled up front, chunks can skip this."""
        if original_sr == self.target_sr:
            return audio_chunk

        # Very short chunks use cheap linear interpolation; longer ones use librosa.
        if len(audio_chunk) < 1000:
            return self._linear_resample(audio_chunk, original_sr, self.target_sr)
        else:
            return librosa.resample(y=audio_chunk, orig_sr=original_sr, target_sr=self.target_sr)

    def _linear_resample(self, audio_chunk, original_sr, target_sr):
        """Linear-interpolation resampling."""
        ratio = target_sr / original_sr
        old_length = len(audio_chunk)
        new_length = int(old_length * ratio)

        old_indices = np.arange(old_length)
        new_indices = np.linspace(0, old_length - 1, new_length)

        resampled = np.interp(new_indices, old_indices, audio_chunk)
        return resampled

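    # Sketch with illustrative numbers: resampling a 480-sample chunk from
    # 8000 Hz to 16000 Hz gives new_length = int(480 * 2) = 960 samples, each
    # produced by np.interp between the two nearest original samples.
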
class StreamProcessor:
    """Streaming processor."""

    def __init__(self, pipeline, chunk_duration=7.0, overlap_duration=0.01, target_sr=16000):
        self.pipeline = pipeline
        self.chunk_duration = chunk_duration
        self.overlap_duration = overlap_duration
        self.target_sr = target_sr
        self.chunk_samples = int(chunk_duration * target_sr)
        self.overlap_samples = int(overlap_duration * target_sr)
        self.audio_buffer = deque()
        self.result_queue = queue.Queue()
        self.is_running = False
        self.processing_thread = None
        self.resampler = AudioResampler(target_sr=target_sr)
        self.segment_counter = 0
        self.processed_texts = set()

    def start_processing(self):
        """Start streaming processing."""
        self.is_running = True
        self.processing_thread = threading.Thread(target=self._process_loop)
        self.processing_thread.daemon = True
        self.processing_thread.start()

    def stop_processing(self):
        """Stop streaming processing."""
        self.is_running = False
        if self.processing_thread:
            self.processing_thread.join(timeout=5)

    def add_audio_chunk(self, audio_chunk, original_sr=None):
        """Append an audio chunk to the buffer, resampling if needed."""
        if original_sr and original_sr != self.target_sr:
            audio_chunk = self.resampler.resample_chunk(audio_chunk, original_sr)

        self.audio_buffer.append(audio_chunk)

    def get_next_result(self, timeout=1.0):
        """Fetch the next processing result, or None on timeout."""
        try:
            return self.result_queue.get(timeout=timeout)
        except queue.Empty:
            return None

    def _process_loop(self):
        """Worker loop: accumulate audio, then run ASR -> translation -> TTS."""
        accumulated_audio = np.array([], dtype=np.float32)
        last_asr_result = ""

        while self.is_running:
            if len(self.audio_buffer) > 0:
                audio_chunk = self.audio_buffer.popleft()
                accumulated_audio = np.concatenate([accumulated_audio, audio_chunk])

                if len(accumulated_audio) >= self.chunk_samples:
                    # Keep overlap_samples of context for the next window.
                    process_chunk = accumulated_audio[:self.chunk_samples]
                    accumulated_audio = accumulated_audio[self.chunk_samples - self.overlap_samples:]

                    try:
                        asr_result = self._stream_asr(process_chunk)

                        if (asr_result and asr_result.strip() and
                                asr_result != last_asr_result and
                                asr_result not in self.processed_texts):

                            print(f"[live ASR] {asr_result}")
                            last_asr_result = asr_result
                            self.processed_texts.add(asr_result)

                            try:
                                translation_result = self.pipeline.run_translation(asr_result)

                                if (translation_result and
                                        translation_result != asr_result and
                                        "translation failed" not in translation_result and
                                        "error:" not in translation_result.lower() and
                                        "input text is empty" not in translation_result):

                                    print(f"[live translation] {translation_result}")

                                    try:
                                        self.segment_counter += 1
                                        tts_filename = f"stream_segment_{self.segment_counter:04d}.wav"
                                        tts_start_time = time.time()

                                        tts_path = self.pipeline.run_tts(
                                            translation_result,
                                            self.pipeline.output_dir,
                                            tts_filename
                                        )

                                        tts_time = time.time() - tts_start_time
                                        print(f"[live TTS] audio saved: {tts_path} (took {tts_time:.2f}s)")

                                        self.result_queue.put({
                                            'type': 'complete',
                                            'original': asr_result,
                                            'translated': translation_result,
                                            'audio_path': tts_path,
                                            'timestamp': time.time(),
                                            'segment_id': self.segment_counter
                                        })

                                    except Exception as tts_error:
                                        print(f"[live TTS error] {tts_error}")
                                        import traceback
                                        traceback.print_exc()
                                else:
                                    print("[live translation] invalid translation result, skipped")

                            except Exception as translation_error:
                                print(f"[live translation error] {translation_error}")
                        else:
                            if asr_result == last_asr_result:
                                print(f"[live ASR] duplicate content skipped: {asr_result}")

                    except Exception as e:
                        print(f"[stream processing error] {e}")
                        import traceback
                        traceback.print_exc()

            time.sleep(0.01)

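    # Window math for the loop above (with the defaults): chunk_duration=7.0 s
    # and overlap_duration=0.01 s at target_sr=16000 give chunk_samples=112000
    # and overlap_samples=160, so each window re-reads the last 160 samples
    # (10 ms) of the previous one.
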
    def _stream_asr(self, audio_chunk):
        """Streaming ASR with VAD pre-segmentation."""
        try:
            # Run VAD first and merge the raw segments (threshold in ms).
            res_vad = self.pipeline.model_vad(audio_chunk)[0]
            vad_segments = merge_vad(res_vad, 15 * 1000)

            if not vad_segments or len(vad_segments) == 0:
                print("[VAD] no speech activity detected, skipping this chunk")
                return ""

            print(f"[VAD] detected {len(vad_segments)} speech segments")

            all_results = ""

            for i, segment in enumerate(vad_segments):
                segment_start, segment_end = segment
                start_sample = int(segment_start / 1000 * self.target_sr)
                end_sample = min(int(segment_end / 1000 * self.target_sr), len(audio_chunk))
                segment_audio = audio_chunk[start_sample:end_sample]

                # Skip segments shorter than 0.3 s.
                if len(segment_audio) < int(0.3 * self.target_sr):
                    continue

                with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as temp_file:
                    sf.write(temp_file.name, segment_audio, self.target_sr)
                    temp_filename = temp_file.name

                try:
                    segment_result = self.pipeline.model_bin(
                        temp_filename,
                        "auto",
                        True,
                        self.pipeline.position_encoding,
                        tokenizer=self.pipeline.tokenizer,
                    )

                    if segment_result and segment_result.strip():
                        all_results += segment_result + " "

                    os.unlink(temp_filename)

                except Exception as e:
                    print(f"[ASR error] failed on VAD segment {i}: {e}")
                    if os.path.exists(temp_filename):
                        os.unlink(temp_filename)
                    continue

            return all_results.strip()

        except Exception as e:
            print(f"[ASR error] {e}")
            return ""

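    # Sample-index math used in _stream_asr (illustrative numbers): a VAD
    # segment of (250, 1480) ms at 16 kHz maps to samples 4000..23680.
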
class SpeechTranslationPipeline:
    def __init__(self,
                 tts_model_dir, tts_model_files,
                 asr_model_dir="ax_model", seq_len=132,
                 tts_dec_len=128, sample_rate=44100, tts_speed=0.8,
                 qwen_api_url=QWEN_API_URL, target_sr=16000,
                 output_dir="./output"):
        self.tts_model_dir = tts_model_dir
        self.tts_model_files = tts_model_files
        self.asr_model_dir = asr_model_dir
        self.seq_len = seq_len
        self.tts_dec_len = tts_dec_len
        self.sample_rate = sample_rate
        self.tts_speed = tts_speed
        self.qwen_api_url = qwen_api_url
        self.target_sr = target_sr
        self.output_dir = output_dir

        os.makedirs(self.output_dir, exist_ok=True)

        self.resampler = AudioResampler(target_sr=target_sr)

        self._init_asr_models()
        self._init_tts_models()

        self.translator = QwenTranslationAPI(api_url=qwen_api_url)

        self.stream_processor = StreamProcessor(self, target_sr=target_sr)

        self._validate_files()

        print("[init] resetting API context...")
        self.translator.reset_context()

    def _init_asr_models(self):
        """Initialize the speech-recognition models."""
        print("Initializing SenseVoice models...")

        self.model_vad = AX_Fsmn_vad(self.asr_model_dir)

        # Precompute sinusoidal position encodings for the fixed sequence length.
        self.embed = SinusoidalPositionEncoder()
        self.position_encoding = self.embed.get_position_encoding(
            torch.randn(1, self.seq_len, 560)).numpy()

        self.model_bin = AX_SenseVoiceSmall(self.asr_model_dir, seq_len=self.seq_len)

        tokenizer_path = os.path.join(self.asr_model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
        self.tokenizer = SentencepiecesTokenizer(bpemodel=tokenizer_path)

        print("SenseVoice models initialized successfully.")

    def _init_tts_models(self):
        """Initialize the TTS models."""
        print("Initializing MeloTTS models...")
        init_start = time.time()

        enc_model = os.path.join(self.tts_model_dir, self.tts_model_files["encoder"])
        dec_model = os.path.join(self.tts_model_dir, self.tts_model_files["decoder"])

        model_load_start = time.time()
        self.sess_enc = ort.InferenceSession(enc_model, providers=["CPUExecutionProvider"], sess_options=ort.SessionOptions())
        self.sess_dec = axe.InferenceSession(dec_model)
        print(f"  Load encoder/decoder models: {(time.time() - model_load_start)*1000:.2f}ms")

        # Speaker embedding g: a float32 blob of shape (1, 256, 1).
        g_file = os.path.join(self.tts_model_dir, self.tts_model_files["g"])
        self.tts_g = np.fromfile(g_file, dtype=np.float32).reshape(1, 256, 1)

        self.tts_language = "ZH_MIX_EN"
        self.symbol_to_id = {s: i for i, s in enumerate(LANG_TO_SYMBOL_MAP[self.tts_language])}

        print("  Warming up TTS modules...")
        warmup_start = time.time()

        try:
            warmup_text_mix = "这是一个test测试。"
            _, _, _, _, _ = get_text_for_tts_infer(warmup_text_mix, self.tts_language, symbol_to_id=self.symbol_to_id)
            print(f"  Mixed ZH-EN warm-up: {(time.time() - warmup_start)*1000:.2f}ms")
        except Exception as e:
            print(f"  Warning: Mixed warm-up failed: {e}")

        total_init_time = (time.time() - init_start) * 1000
        print(f"MeloTTS models initialized successfully. Total init time: {total_init_time:.2f}ms")

    def _validate_files(self):
        """Verify that all required files exist and the API is reachable."""
        for key, filename in self.tts_model_files.items():
            filepath = os.path.join(self.tts_model_dir, filename)
            if not os.path.exists(filepath):
                raise FileNotFoundError(f"TTS model file missing: {filepath}")

        try:
            requests.get(f"{self.qwen_api_url}/api/generate_provider", timeout=5)
            print("[API check] Qwen API service reachable")
        except Exception:
            print("[API warning] cannot reach the Qwen API service; make sure it is running")

    def start_stream_processing(self):
        """Start streaming processing."""
        self.stream_processor.start_processing()
        print("[streaming] started")

    def stop_stream_processing(self):
        """Stop streaming processing."""
        self.stream_processor.stop_processing()
        print("[streaming] stopped")

    def process_audio_stream(self, audio_chunk, original_sr=None):
        """Feed one chunk of streaming audio."""
        self.stream_processor.add_audio_chunk(audio_chunk, original_sr)

    def get_stream_results(self):
        """Fetch the next streaming result, if any."""
        return self.stream_processor.get_next_result()

    def load_and_resample_audio(self, audio_file):
        """Load an audio file and resample it to the target rate."""
        print(f"Loading audio file: {audio_file}")
        speech, original_sr = librosa.load(audio_file, sr=None)

        audio_duration = len(speech) / original_sr
        print(f"Original audio: {original_sr}Hz, duration: {audio_duration:.2f}s")

        if original_sr != self.target_sr:
            speech = self.resampler.resample_audio(speech, original_sr)
            print(f"Resampled: {self.target_sr}Hz, duration: {len(speech)/self.target_sr:.2f}s")

        return speech, self.target_sr

    def run_translation(self, text_content):
        """Translate between Chinese and English via the Qwen LLM API."""
        print("Starting translation via API...")
        translation_start_time = time.time()

        translate_content = self.translator.translate(text_content)

        translation_time_cost = time.time() - translation_start_time
        print(f"Translation processing time: {translation_time_cost:.2f} seconds")
        print(f"Translation Result: {translate_content}")

        return translate_content

    def run_tts(self, translate_content, output_dir, output_wav=None):
        """Synthesize speech for the translated text with the TTS models."""
        output_path = os.path.join(output_dir, output_wav)

        try:
            # Convert Arabic numerals to Chinese numerals for Chinese text.
            if lang_detect_with_regex(translate_content) == "chinese":
                translate_content = cn2an.transform(translate_content, "an2cn")

            print(f"TTS synthesis for text: {translate_content}")

            sens = split_sentence(translate_content, language_str=self.tts_language)
            print(f"Text split into {len(sens)} sentences")

            audio_list = []

            for n, se in enumerate(sens):
                if self.tts_language in ['EN', 'ZH_MIX_EN']:
                    # Insert a space at lower-to-upper camel-case boundaries.
                    se = re.sub(r'([a-z])([A-Z])', r'\1 \2', se)

                print(f"Processing sentence[{n}]: {se}")

                phones, tones, lang_ids, norm_text, word2ph = get_text_for_tts_infer(
                    se, self.tts_language, symbol_to_id=self.symbol_to_id)

                encoder_start = time.time()
                z_p, pronoun_lens, audio_len = self.sess_enc.run(None, input_feed={
                    'phone': phones, 'g': self.tts_g,
                    'tone': tones, 'language': lang_ids,
                    'noise_scale': np.array([0], dtype=np.float32),
                    'length_scale': np.array([1.0 / self.tts_speed], dtype=np.float32),
                    'noise_scale_w': np.array([0], dtype=np.float32),
                    'sdp_ratio': np.array([0], dtype=np.float32)})
                print(f"Encoder run time: {1000 * (time.time() - encoder_start):.2f}ms")

                word2pronoun = calc_word2pronoun(word2ph, pronoun_lens)
                pn_slices, zp_slices = generate_slices(word2pronoun, self.tts_dec_len)

                audio_len = audio_len[0]
                sub_audio_list = []

                for i, (ps, zs) in enumerate(zip(pn_slices, zp_slices)):
                    zp_slice = z_p[..., zs]

                    sub_dec_len = zp_slice.shape[-1]
                    sub_audio_len = 512 * sub_dec_len

                    # Zero-pad the last slice up to the fixed decoder length.
                    if zp_slice.shape[-1] < self.tts_dec_len:
                        zp_slice = np.concatenate((zp_slice, np.zeros((*zp_slice.shape[:-1], self.tts_dec_len - zp_slice.shape[-1]), dtype=np.float32)), axis=-1)

                    decoder_start = time.time()
                    audio = self.sess_dec.run(None, input_feed={"z_p": zp_slice, "g": self.tts_g})[0].flatten()

                    # Trim the context frames that overlap the neighbouring slices.
                    audio_start = 0
                    if len(sub_audio_list) > 0:
                        if pn_slices[i - 1].stop > ps.start:
                            audio_start = 512 * word2pronoun[ps.start]

                    audio_end = sub_audio_len
                    if i < len(pn_slices) - 1:
                        if ps.stop > pn_slices[i + 1].start:
                            audio_end = sub_audio_len - 512 * word2pronoun[ps.stop - 1]

                    audio = audio[audio_start:audio_end]
                    print(f"Decode slice[{i}]: decoder run time {1000 * (time.time() - decoder_start):.2f}ms")
                    sub_audio_list.append(audio)

                sub_audio = merge_sub_audio(sub_audio_list, 0, audio_len)
                audio_list.append(sub_audio)

            audio = audio_numpy_concat(audio_list, sr=self.sample_rate, speed=self.tts_speed)

            sf.write(output_path, audio, self.sample_rate)
            print(f"TTS audio saved to {output_path}")

            return output_path

        except Exception as e:
            print(f"TTS synthesis failed: {e}")
            import traceback
            traceback.print_exc()
            raise

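    # Decoder window math (illustrative): with tts_dec_len=128 and the
    # 512-samples-per-frame upsampling used above, one full slice decodes to
    # at most 128 * 512 = 65536 samples, about 1.49 s at 44100 Hz.
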
    def process_long_audio_stream(self, audio_file, chunk_size=64000):
        """
        Simulate streaming over a long audio file.

        chunk_size defaults to 64000 (4 s * 16000 Hz) to match StreamProcessor's
        chunk_duration; 4 s felt a bit short and 7 s works better, which is
        what main() passes in.
        """
        print(f"[streaming] processing long audio: {audio_file}")

        speech, fs = self.load_and_resample_audio(audio_file)

        self.start_stream_processing()

        total_chunks = (len(speech) + chunk_size - 1) // chunk_size
        print(f"[streaming] total audio length: {len(speech)/fs:.2f}s, chunks: {total_chunks}")

        all_results = []

        chunk_count = 0
        for i in range(0, len(speech), chunk_size):
            chunk = speech[i:i+chunk_size]
            chunk_count += 1

            # Zero-pad the final chunk to the full chunk size.
            if len(chunk) < chunk_size:
                padding_size = chunk_size - len(chunk)
                chunk = np.concatenate([chunk, np.zeros(padding_size, dtype=np.float32)])
                print(f"\n[streaming] processing chunk {chunk_count}/{total_chunks} (last chunk, zero-padded by {padding_size} samples)")
            else:
                print(f"\n[streaming] processing chunk {chunk_count}/{total_chunks}")

            self.process_audio_stream(chunk, fs)

            # Drain any results that are already available.
            result = self.get_stream_results()
            while result:
                print(f"\n{'='*70}")
                print(f"[live result #{len(all_results) + 1}]")
                print(f"Segment ID: {result['segment_id']}")
                print(f"Source: {result['original']}")
                print(f"Translation: {result['translated']}")
                print(f"Audio: {result['audio_path']}")
                print(f"{'='*70}")
                all_results.append(result)
                result = self.get_stream_results()

            time.sleep(0.01)

        # Wait for the remaining in-flight results to drain.
        max_wait_time = 20
        wait_start = time.time()

        while time.time() - wait_start < max_wait_time:
            result = self.get_stream_results()
            if result:
                print(f"\n{'='*70}")
                print(f"[live result #{len(all_results) + 1}]")
                print(f"Segment ID: {result['segment_id']}")
                print(f"Source: {result['original']}")
                print(f"Translation: {result['translated']}")
                print(f"Audio: {result['audio_path']}")
                print(f"{'='*70}")
                all_results.append(result)
                wait_start = time.time()
            else:
                time.sleep(0.02)

        self.stop_stream_processing()

        print(f"\n[streaming] done! processed {len(all_results)} valid results")
        return all_results

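    # Chunk-count math used above (illustrative): a 375 s clip at 16 kHz has
    # 6,000,000 samples; with chunk_size = 112000 (7 s),
    # (6000000 + 112000 - 1) // 112000 = 54 chunks.
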
def main():
    parser = argparse.ArgumentParser(description="Real-time speech translation pipeline")
    parser.add_argument("--audio_file", type=str, default="./wav/en_6mins.wav", help="input audio file path")
    parser.add_argument("--output_dir", type=str, default="./output", help="output directory")
    parser.add_argument("--api_url", type=str, default="http://10.126.29.158:8000", help="Qwen API server URL")
    parser.add_argument("--target_sr", type=int, default=16000, help="ASR target sample rate (default: 16000)")
    parser.add_argument("--chunk_duration", type=float, default=7.0, help="audio chunk duration in seconds (default: 7.0)")
    parser.add_argument("--overlap_duration", type=float, default=0.01, help="overlap duration in seconds (default: 0.01)")

    args = parser.parse_args()
    print("-------------------Real-time speech translation pipeline-------------------\n")
    os.makedirs(args.output_dir, exist_ok=True)

    print(f"Audio file: {args.audio_file}")
    print(f"Output directory: {args.output_dir}")
    print(f"Chunk duration: {args.chunk_duration}s")
    print(f"Overlap duration: {args.overlap_duration}s\n")

    pipeline = SpeechTranslationPipeline(
        tts_model_dir=TTS_MODEL_DIR,
        tts_model_files=TTS_MODEL_FILES,
        asr_model_dir="ax_model",
        seq_len=132,
        tts_dec_len=128,
        sample_rate=44100,
        tts_speed=0.8,
        qwen_api_url=args.api_url,
        target_sr=args.target_sr,
        output_dir=args.output_dir
    )

    start_time = time.time()
    try:
        print("="*70 + "\n")

        chunk_size = int(args.chunk_duration * args.target_sr)
        results = pipeline.process_long_audio_stream(args.audio_file, chunk_size=chunk_size)

        print("\n" + "="*70)
        print(" Processing complete")
        print("="*70)
        print(f"\n Successfully processed {len(results)} translated segments\n")

        if results:
            print("All translation results:")
            print("-" * 70)
            for idx, result in enumerate(results, 1):
                print(f"\n[Segment {idx}] (ID: {result['segment_id']})")
                print(f"  Source: {result['original']}")
                print(f"  Translation: {result['translated']}")
                print(f"  Audio: {result['audio_path']}")
                print(f"  Time: {time.strftime('%H:%M:%S', time.localtime(result['timestamp']))}")
                print("-" * 70)

            # Persist the results alongside the generated audio.
            result_file = os.path.join(args.output_dir, "stream_results.txt")
            with open(result_file, 'w', encoding='utf-8') as f:
                f.write(f"Streaming translation + TTS results - {args.audio_file}\n")
                f.write(f"Processed at: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
                f.write(f"Chunk duration: {args.chunk_duration}s, overlap duration: {args.overlap_duration}s\n")
                f.write("="*70 + "\n\n")
                for idx, result in enumerate(results, 1):
                    f.write(f"[Segment {idx}] (ID: {result['segment_id']})\n")
                    f.write(f"Source: {result['original']}\n")
                    f.write(f"Translation: {result['translated']}\n")
                    f.write(f"Audio: {result['audio_path']}\n")
                    f.write(f"Time: {time.strftime('%H:%M:%S', time.localtime(result['timestamp']))}\n")
                    f.write("\n" + "-"*70 + "\n\n")
            print(f"\n✓ Results saved to: {result_file}")

            audio_files = [r['audio_path'] for r in results]
            print(f"\n Generated {len(audio_files)} TTS audio files:")
            for audio_file in audio_files:
                print(f"  - {audio_file}")
        else:
            print("\n No valid translation results")

        print("="*70)

        total_time = time.time() - start_time
        print(f"\nTotal processing time: {total_time:.2f} seconds")

    except Exception as e:
        print(f"Pipeline failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()