Spaces: Runtime error
import sys  # needed for the sys.modules check in import_wav2vec2 below
import importlib.util

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample
# Function to dynamically import wav2vec2 module and avoid duplicate registration
def import_wav2vec2():
    if 'wav2vec2' not in sys.modules:
        spec = importlib.util.spec_from_file_location("wav2vec2", "wav2vec2.py")
        wav2vec2 = importlib.util.module_from_spec(spec)
        sys.modules['wav2vec2'] = wav2vec2
        spec.loader.exec_module(wav2vec2)
    else:
        wav2vec2 = sys.modules['wav2vec2']
    Wav2Vec2Model = wav2vec2.Wav2Vec2Model
    Wav2Vec2Config = wav2vec2.Wav2Vec2Config
    return Wav2Vec2Model, Wav2Vec2Config

Wav2Vec2Model, Wav2Vec2Config = import_wav2vec2()
# Model checkpoint URL
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"

# Download the model file
print("Downloading model file...")
torch.hub.download_url_to_file(model_path, 'large.pt')
print("Model file downloaded.")

# Load the model config and initialize the model
config = Wav2Vec2Config()
model = Wav2Vec2Model.build_model(config)

# Load the model checkpoint
print("Loading model checkpoint...")
checkpoint = torch.load('large.pt', map_location=torch.device('cpu'))
print("Checkpoint keys:", checkpoint.keys())
# Print the keys in the model parameters
if 'model' in checkpoint:
    state_dict = checkpoint['model']
    print("Model state_dict keys:", state_dict.keys())
else:
    print("Key 'model' not found in checkpoint.")
    state_dict = checkpoint

# Load the model state dict
try:
    model.load_state_dict(state_dict)
    print("Model state_dict loaded successfully.")
except Exception as e:
    print("Error loading model state_dict:", str(e))

model.eval()
# Transcription function
def transcribe(audio):
    print("Transcribing audio...")
    waveform, sample_rate = torchaudio.load(audio)
    if sample_rate != 16000:
        resample = Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resample(waveform).squeeze()
    else:
        waveform = waveform.squeeze()
    # Reshape the input to the shape the model expects: (batch_size, seq_len)
    input_values = waveform.unsqueeze(0)
    with torch.no_grad():
        outputs = model.extract_features(input_values, padding_mask=None)
        logits = outputs["x"]
    predicted_ids = torch.argmax(logits, dim=-1)
    # Decode the predicted IDs to characters (naive mapping via chr())
    transcription = ''.join([chr(i) for i in predicted_ids[0].tolist()])
    print("Transcription:", transcription)
    return transcription
# Build the Gradio interface
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
    title="TeleSpeech ASR",
    description="Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model."
)

print("Launching Gradio interface...")
iface.launch()
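
For local debugging, it can help to exercise transcribe() once outside of Gradio before launching the Space. The snippet below is a minimal sketch under the same setup as the script above (wav2vec2.py present, checkpoint downloaded); the file name test.wav and the generated sine tone are placeholders, not part of the original app.

import math

import torch
import torchaudio

def make_test_clip(path="test.wav", seconds=1.0, sample_rate=16000):
    # Write one second of a 440 Hz tone so transcribe() has a file to load.
    t = torch.arange(0, seconds, 1.0 / sample_rate)
    waveform = 0.1 * torch.sin(2 * math.pi * 440.0 * t).unsqueeze(0)
    torchaudio.save(path, waveform, sample_rate)
    return path

if __name__ == "__main__":
    clip = make_test_clip()
    # The output will be meaningless for a pure tone; the goal is only to
    # confirm that loading, resampling, and the forward pass run without errors.
    print(transcribe(clip))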