import logging
import os
from functools import lru_cache

import gradio as gr
import librosa
import numpy as np
import torch
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    pipeline,
    VitsTokenizer,
    VitsModel,
    set_seed,
)
from TTS.api import TTS
from faster_whisper import WhisperModel
from evaluate import load

from enum_ import trans_languages, tts_languages, whisper_languages
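# Note: enum_ is a local module whose contents are not shown here. The code below
# assumes it exposes three dicts mapping display names to model-specific language
# codes, roughly like this sketch (NLLB-200, TTS, and Whisper codes respectively):
#   trans_languages   = {"English": "eng_Latn", "Traditional Chinese": "zho_Hant", ...}
#   tts_languages     = {"English": "en", "Korean": "ko", "Vietnamese": "vie", ...}
#   whisper_languages = {"English": "en", "Traditional Chinese": "zh", ...}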
# Translation
translation_model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)

wer_metric = load("wer")
cer_metric = load("cer")

def translate_sentence(sentence, src_lang, tgt_lang):
    logging.info(f"Translating from {src_lang} to {tgt_lang}")
    if not sentence:
        return "Error: no input sentence"
    try:
        translator = pipeline(
            "translation",
            model=translation_model,
            tokenizer=tokenizer,
            src_lang=trans_languages[src_lang],
            tgt_lang=trans_languages[tgt_lang],
            max_length=400,
        )
        result = translator(sentence)
        logging.info(f"Translation: {result}")
    except Exception as e:
        return f"Translation error: {e}"
    if len(result) == 0:
        return "No output from translator"
    return result[0].get("translation_text", "No translation_text key in output")

@lru_cache(maxsize=1)
def load_tts():
    # Pick a device and load the XTTS v2 multilingual model once (cached).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tts_model = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)
    return tts_model


@lru_cache(maxsize=None)
def load_mms_tts(language):
    tokenizer = VitsTokenizer.from_pretrained(f"facebook/mms-tts-{language}")
    model = VitsModel.from_pretrained(f"facebook/mms-tts-{language}")
    return model, tokenizer
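# The per-language MMS-TTS model IDs above are assumed to follow the
# "facebook/mms-tts-<iso-639-3>" naming on the Hugging Face Hub, e.g.
# "facebook/mms-tts-kor" (Korean) or "facebook/mms-tts-vie" (Vietnamese),
# so tts_languages must map display names to those ISO codes.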

def convert_vits_output_to_wav(vits_output):
    """
    Convert VITS model output to a 24 kHz float32 waveform.

    Parameters:
        vits_output: torch.Tensor or np.ndarray
            The audio output from the VITS model (float32, 16 kHz).

    Returns:
        np.ndarray: the squeezed, clipped waveform resampled to 24 kHz.
    """
    if isinstance(vits_output, torch.Tensor):
        arr = vits_output.detach().cpu().numpy()
    else:
        arr = np.asarray(vits_output)
    arr = np.squeeze(arr)
    # Clip to the valid [-1, 1] range before resampling.
    arr = np.clip(arr, -1.0, 1.0).astype(np.float32)
    # MMS-TTS generates 16 kHz audio; resample so both TTS paths return 24 kHz.
    arr = librosa.resample(arr, orig_sr=16000, target_sr=24000)
    return arr

def tts(sentence, language):
    if not sentence or sentence.strip() == "":
        return None
    try:
        language_code = tts_languages[language]
        if language_code in ["en", "ko", "ja", "zh-cn"]:
            tts_model = load_tts()
            base_dir = os.path.dirname(os.path.abspath(__file__))
            wav_path = os.path.join(base_dir, "example.mp3")
            wav = tts_model.tts(
                text=sentence, speaker_wav=wav_path, language=language_code
            )
            # Return as (sample_rate, audio_array) tuple for Gradio
            return (24000, np.array(wav))
        else:
            model, tokenizer = load_mms_tts(tts_languages[language])
            inputs = tokenizer(text=sentence, return_tensors="pt")
            set_seed(555)  # make deterministic
            with torch.no_grad():
                outputs = model(inputs["input_ids"])
            outputs_resample = convert_vits_output_to_wav(outputs.waveform)
            return (24000, outputs_resample)
    except Exception as e:
        logging.error(f"TTS error: {e}")
        return None
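# Note on sample rates (assumption based on the models used): XTTS v2 generates
# 24 kHz audio natively, while the MMS/VITS path generates 16 kHz audio that
# convert_vits_output_to_wav() resamples to 24 kHz, so both branches hand Gradio
# a (24000, waveform) tuple it can play directly.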

@lru_cache(maxsize=1)
def load_whisper(model_size):
    # faster-whisper model, e.g. "large-v2"; cached so it is only loaded once.
    model = WhisperModel(model_size)
    return model

def transcribe(audio, language=None):
    if audio is None:
        return ""
    sr, y = audio
    # Gradio's numpy audio is 16-bit PCM; convert to mono float32 in [-1, 1].
    if y.ndim > 1:
        y = y.mean(axis=1)
    y = y.astype(np.float32) / 32768.0
    if sr != 16000:
        y = librosa.resample(y, orig_sr=sr, target_sr=16000)
        sr = 16000
    model = load_whisper("large-v2")
    if language:
        segments, info = model.transcribe(y, language=whisper_languages[language])
    else:
        segments, info = model.transcribe(y)
    logging.info(f"Detected language: {info.language}")
    transcription = ""
    for segment in segments:
        logging.info(segment.text)
        transcription += f"{segment.text}\n"
    return transcription

def evaluate(language, reference, prediction):
    # Word error rate for these languages, character error rate otherwise.
    if language in ["Traditional Chinese", "Vietnamese"]:
        wer = wer_metric.compute(predictions=[prediction], references=[reference])
        return str((1 - wer) * 100) + "%"
    else:
        cer = cer_metric.compute(predictions=[prediction], references=[reference])
        return str((1 - cer) * 100) + "%"
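# Worked example (hypothetical numbers): if the metric returns an error rate of
# 0.25, evaluate() reports str((1 - 0.25) * 100) + "%" == "75.0%". Error rates
# above 1.0 are possible, in which case this "accuracy" goes negative.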

with gr.Blocks() as demo:
    gr.Markdown(
        """
        ## Language Learning Assistant

        Learn a new language interactively:

        1. **Type a Sentence**: Enter a sentence you want to learn and get an instant translation.
        2. **Listen to Pronunciation**: Generate and listen to the correct pronunciation.
        3. **Practice Speaking**: Record your pronunciation and compare it to the audio.
        4. **Speech-to-Text Feedback**: Check whether your pronunciation is recognized by speech-to-text and get real-time feedback.

        Improve your speaking and comprehension skills, all in one place!
        """
    )
    with gr.Row():
        # Left column: input sentence and translation controls
        with gr.Column(scale=1, min_width=300):
            with gr.Row():
                src = gr.Dropdown(
                    list(trans_languages.keys()),
                    label="Input Language",
                    value="Traditional Chinese",
                )
                tgt = gr.Dropdown(
                    list(trans_languages.keys()),
                    label="Output Language",
                    value="English",
                )
            sentence = gr.Textbox(label="Sentence", interactive=True)
            translate_btn = gr.Button("Translate Sentence")

        # Middle column: translation text and generated speech
        with gr.Column(scale=1, min_width=300):
            translation = gr.Textbox(label="Translation", interactive=False)
            speech = gr.Audio()

        # Right column: user recording, transcription, and accuracy feedback
        with gr.Column(scale=1, min_width=300):
            mic = gr.Audio(
                sources=["microphone"], type="numpy", label="Record yourself"
            )
            transcription = gr.Textbox(label="Your transcription")
            accuracy = gr.Textbox(label="Accuracy")

    translate_btn.click(
        fn=translate_sentence,
        inputs=[sentence, src, tgt],
        outputs=translation,
    )
    translation.change(fn=tts, inputs=[translation, tgt], outputs=speech)
    mic.stop_recording(fn=transcribe, inputs=[mic, tgt], outputs=[transcription])
    transcription.change(
        fn=evaluate, inputs=[tgt, translation, transcription], outputs=[accuracy]
    )
    # More callbacks could be added here, e.g. translating a generated practice sentence.

demo.launch(share=True)