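"""Language Learning Assistant (Gradio app).

Translates a sentence with NLLB-200, reads the translation aloud with
XTTS v2 or MMS-TTS, transcribes the user's recording with faster-whisper,
and scores pronunciation accuracy with WER/CER from the `evaluate` library.
"""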
import gradio as gr
from transformers import (
AutoModelForSeq2SeqLM,
AutoTokenizer,
pipeline,
VitsTokenizer,
VitsModel,
set_seed,
)
from enum_ import trans_languages, tts_languages, whisper_languages
import logging
import os
from functools import lru_cache

import librosa
import numpy as np
import torch
from evaluate import load
from faster_whisper import WhisperModel
from TTS.api import TTS

## Translation model
translation_model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)

## Evaluation metrics (word / character error rate)
wer_metric = load("wer")
cer_metric = load("cer")

@lru_cache(maxsize=10)
def translate_sentence(sentence, src_lang, tgt_lang):
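    """Translate `sentence` from `src_lang` to `tgt_lang` using the NLLB-200 pipeline."""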
    logging.info(f"Translating from {src_lang} to {tgt_lang}")
if not sentence:
return "Error: no input sentence"
try:
translator = pipeline(
"translation",
model=translation_model,
tokenizer=tokenizer,
src_lang=trans_languages[src_lang],
tgt_lang=trans_languages[tgt_lang],
max_length=400,
)
result = translator(sentence)
logging.info(f"Translation: {result}")
except Exception as e:
return f"Translation error: {e}"
if len(result) == 0:
return "No output from translator"
return result[0].get("translation_text", "No translation_text key in output")
@lru_cache(maxsize=10)
def load_tts():
# Get device
device = "cuda" if torch.cuda.is_available() else "cpu"
# Init TTS
tts_model = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
return tts_model
@lru_cache(maxsize=10)
def load_mms_tts(language):
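    """Load the Facebook MMS-TTS (VITS) model and tokenizer for the given language code."""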
tokenizer = VitsTokenizer.from_pretrained(f"facebook/mms-tts-{language}")
model = VitsModel.from_pretrained(f"facebook/mms-tts-{language}")
return model, tokenizer
def convert_vits_output_to_wav(vits_output):
    """
    Convert raw MMS-TTS (VITS) output into a float32 waveform at 24 kHz.

    Parameters:
        vits_output: torch.Tensor or np.ndarray
            The audio output from the VITS model (float32, 16 kHz).

    Returns:
        np.ndarray: mono float32 audio resampled from 16 kHz to 24 kHz.
    """
    if isinstance(vits_output, torch.Tensor):
        arr = vits_output.detach().cpu().numpy()
    else:
        arr = np.asarray(vits_output)
    arr = np.squeeze(arr)
    # Clip to the valid [-1, 1] range before resampling
    arr = np.clip(arr, -1.0, 1.0).astype(np.float32)
    # MMS-TTS generates 16 kHz audio; resample to 24 kHz for the Gradio Audio output
    arr = librosa.resample(arr, orig_sr=16000, target_sr=24000)
    return arr
def tts(sentence, language):
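    """Synthesize `sentence` and return a (sample_rate, waveform) tuple for Gradio.

    Uses XTTS v2 for en/ko/ja/zh-cn and MMS-TTS for the other supported languages;
    returns None if the sentence is empty or synthesis fails.
    """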
if not sentence or sentence.strip() == "":
return None
try:
language_code = tts_languages[language]
if language_code in ["en", "ko", "ja", "zh-cn"]:
tts_model = load_tts()
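            # XTTS clones the voice from a reference speaker clip shipped alongside the app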
base_dir = os.path.dirname(os.path.abspath(__file__))
wav_path = os.path.join(base_dir, "example.mp3")
wav = tts_model.tts(
text=sentence, speaker_wav=wav_path, language=language_code
)
# Return as (sample_rate, audio_array) tuple for Gradio
return (24000, np.array(wav))
else:
model, tokenizer = load_mms_tts(tts_languages[language])
inputs = tokenizer(text=sentence, return_tensors="pt")
set_seed(555) # make deterministic
with torch.no_grad():
outputs = model(inputs["input_ids"])
outputs_resample = convert_vits_output_to_wav(outputs.waveform)
return (24000, outputs_resample)
except Exception as e:
logging.error(f"TTS error: {e}")
return None
@lru_cache(maxsize=10)
def load_whisper(model_size):
    model = WhisperModel(model_size)
    return model
def transcribe(audio, language=None):
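    """Transcribe a Gradio (sample_rate, samples) audio tuple with faster-whisper."""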
if audio is None:
return ""
sr, y = audio
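    # Downmix to mono, scale int16 samples to [-1, 1], and resample to the 16 kHz Whisper expects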
if y.ndim > 1:
y = y.mean(axis=1)
y = y.astype(np.float32) / 32768.0
if sr != 16000:
y = librosa.resample(y, orig_sr=sr, target_sr=16000)
sr = 16000
model = load_whisper("large-v2")
if language:
segments, info = model.transcribe(y, language=whisper_languages[language])
else:
segments, info = model.transcribe(y)
logging.info(f"Detected language: {info.language}")
transcription = ""
for segment in segments:
logging.info(segment.text)
transcription += f"{segment.text}\n"
return f"{transcription}"
def evaluate(language, reference, prediction):
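    """Score the user's transcription against the reference translation, returned as a percentage string."""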
    ### word error rate
    if language in ["Traditional Chinese", "Vietnamese"]:
        wer = wer_metric.compute(predictions=[prediction], references=[reference])
        return str((1 - wer) * 100) + "%"
    ### character error rate
    else:
        cer = cer_metric.compute(predictions=[prediction], references=[reference])
        return str((1 - cer) * 100) + "%"
with gr.Blocks() as demo:
gr.Markdown(
"""
## Language Learning Assistant
Learn a new language interactively:
1. **Type a Sentence**: Enter a sentence you want to learn and get an instant translation.
2. **Listen to Pronunciation**: Generate and listen to the correct pronunciation.
3. **Practice Speaking**: Record your pronunciation and compare it to the audio.
4. **Speech-to-Text Feedback**: Check if your pronunciation is recognized using speech-to-text and get real-time feedback.
Improve your speaking and comprehension skills, all in one place!
"""
)
with gr.Row():
# Left column: translation / text output
with gr.Column(scale=1, min_width=300):
with gr.Row():
src = gr.Dropdown(
list(trans_languages.keys()),
label="Input Language",
value="Traditional Chinese",
)
tgt = gr.Dropdown(
list(trans_languages.keys()),
label="Output Language",
value="English",
)
sentence = gr.Textbox(label="Sentence", interactive=True)
translate_btn = gr.Button("Translate Sentence")
with gr.Column(scale=1, min_width=300):
translation = gr.Textbox(label="Translation", interactive=False)
speech = gr.Audio()
with gr.Column(scale=1, min_width=300):
mic = gr.Audio(
sources=["microphone"], type="numpy", label="Record yourself"
)
transcription = gr.Textbox(label="Your transcription")
accuracy = gr.Textbox(label="Accuracy")
    translate_btn.click(
        fn=translate_sentence,
        inputs=[sentence, src, tgt],
        outputs=translation,
    )
translation.change(fn=tts, inputs=[translation, tgt], outputs=speech)
    # Run transcription as soon as the user stops recording
    mic.stop_recording(fn=transcribe, inputs=[mic, tgt], outputs=[transcription])
transcription.change(
fn=evaluate, inputs=[tgt, translation, transcription], outputs=[accuracy]
)
# You could add more callbacks: e.g. after generating sentence, allow translation etc.
demo.launch(share=True)