Spaces:
Runtime error
Runtime error
File size: 2,841 Bytes
1869d7a 71b0e8e 31564d0 602897f d3defc4 602897f d926f75 602897f 31564d0 ac47f83 31564d0 ca625d0 960bf82 31564d0 1869d7a e2bcfc6 71b0e8e 960bf82 eb24c35 04a460e ca625d0 960bf82 1869d7a 960bf82 ca625d0 f8ebe93 31564d0 e2bcfc6 71b0e8e 1869d7a f8ebe93 31564d0 1869d7a e2bcfc6 1869d7a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import importlib.util
import os
import sys

import gradio as gr
import torch
import torchaudio
from torchaudio.transforms import Resample
def import_wav2vec2(module_path="wav2vec2.py"):
    """Load the local ``wav2vec2`` module once and return its model/config classes.

    The module source at ``module_path`` is executed and registered in
    ``sys.modules`` under the name ``wav2vec2``. Later calls reuse the
    registered module instead of executing the file again, so its
    definitions are not registered twice.

    Args:
        module_path: Path of the source file to load. Defaults to
            ``"wav2vec2.py"``, resolved relative to the current working
            directory (same as before this parameter existed).

    Returns:
        A ``(Wav2Vec2Model, Wav2Vec2Config)`` tuple taken from the module.

    Raises:
        ImportError: if no import spec/loader can be built for ``module_path``
            (e.g. the path does not have a Python source suffix).
        Any exception raised while executing the module (for example
        ``FileNotFoundError`` when the file is missing) propagates after the
        partially initialised module has been removed from ``sys.modules``.
    """
    if "wav2vec2" not in sys.modules:
        spec = importlib.util.spec_from_file_location("wav2vec2", module_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot load wav2vec2 module from {module_path!r}")
        module = importlib.util.module_from_spec(spec)
        # Register before executing, per the standard importlib recipe, so
        # the module can be found under its name while it initialises.
        sys.modules["wav2vec2"] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Don't leave a half-initialised module registered: a retry would
            # otherwise reuse it and fail later with a confusing AttributeError.
            sys.modules.pop("wav2vec2", None)
            raise
    else:
        module = sys.modules["wav2vec2"]
    return module.Wav2Vec2Model, module.Wav2Vec2Config
Wav2Vec2Model, Wav2Vec2Config = import_wav2vec2()

# Remote location of the fine-tuned TeleSpeech checkpoint, and the local file
# it is stored in.
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"
checkpoint_file = 'large.pt'

# Download the checkpoint only if it is not already on disk, so restarts do
# not fetch the large file again. download_url_to_file writes to a temporary
# file and moves it into place, so an existing file is a complete download.
if os.path.exists(checkpoint_file):
    print("Model file already present, skipping download.")
else:
    print("Downloading model file...")
    torch.hub.download_url_to_file(model_path, checkpoint_file)
    print("Model file downloaded.")

# Build the model from a default config, then load the checkpoint weights.
config = Wav2Vec2Config()
model = Wav2Vec2Model.build_model(config)

print("Loading model checkpoint...")
# NOTE(review): torch.load unpickles the file. That is only acceptable because
# it comes from the fixed URL above. Newer torch releases default to
# weights_only=True, which may reject checkpoints that pickle config
# objects — confirm against the installed torch version.
checkpoint = torch.load(checkpoint_file, map_location=torch.device('cpu'))
print("Checkpoint keys:", checkpoint.keys())

# The weights may be nested under 'model'; otherwise treat the whole loaded
# object as the state dict.
if 'model' in checkpoint:
    state_dict = checkpoint['model']
    print("Model state_dict keys:", state_dict.keys())
else:
    print("Key 'model' not found in checkpoint.")
    state_dict = checkpoint

# Load the weights. On failure the error is only printed, and the model keeps
# the weights it was built with (i.e. it is NOT the fine-tuned model).
# NOTE(review): if checkpoint keys carry a wrapper prefix, strict loading
# will fail here — confirm the key layout.
try:
    model.load_state_dict(state_dict)
    print("Model state_dict loaded successfully.")
except Exception as e:
    print("Error loading model state_dict:", str(e))

# Switch to inference mode (disables training-only layers such as dropout).
model.eval()
def transcribe(audio):
    """Transcribe one audio file with the module-level ``model``.

    Args:
        audio: Path to an audio file (the Gradio input is configured with
            ``type="filepath"``), or ``None`` when nothing was submitted.

    Returns:
        The decoded transcription string, or ``""`` when ``audio`` is ``None``.
    """
    if audio is None:
        # Gradio invokes fn with None when no audio was uploaded or recorded.
        return ""
    print("Transcribing audio...")
    waveform, sample_rate = torchaudio.load(audio)
    # torchaudio.load returns (channels, frames). Average over channels so
    # stereo input becomes one 1-D signal instead of leaving an extra axis
    # that squeeze() cannot remove. For mono this equals squeeze().
    waveform = waveform.mean(dim=0)
    if sample_rate != 16000:
        # The model input is resampled to 16 kHz.
        waveform = Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

    # Add the batch axis expected by the model: (batch_size=1, seq_len).
    input_values = waveform.unsqueeze(0)
    with torch.no_grad():
        outputs = model.extract_features(input_values, padding_mask=None)
        logits = outputs["x"]
        predicted_ids = torch.argmax(logits, dim=-1)
    # NOTE(review): chr(id) interprets each predicted index as a Unicode code
    # point. A real ASR decoder would map ids through the model's vocabulary
    # (and collapse CTC repeats/blanks) — confirm and replace.
    transcription = ''.join(chr(i) for i in predicted_ids[0].tolist())
    print("Transcription:", transcription)
    return transcription
# Build the Gradio UI: a single audio input (handed to `transcribe` as a file
# path) and a plain-text output, then start the web server.
audio_input = gr.Audio(type="filepath")
iface = gr.Interface(
    fn=transcribe,
    inputs=audio_input,
    outputs="text",
    title="TeleSpeech ASR",
    description=(
        "Upload an audio file or record your voice to transcribe speech "
        "to text using the TeleSpeech ASR model."
    ),
)

print("Launching Gradio interface...")
iface.launch()
|