Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -2,18 +2,7 @@ import gradio as gr
|
|
| 2 |
import torch
|
| 3 |
import torchaudio
|
| 4 |
from torchaudio.transforms import Resample
|
| 5 |
-
|
| 6 |
-
# 定义一个假设的 ASR 模型结构
|
| 7 |
-
class ASRModel(torch.nn.Module):
|
| 8 |
-
def __init__(self):
|
| 9 |
-
super(ASRModel, self).__init__()
|
| 10 |
-
self.lstm = torch.nn.LSTM(input_size=160, hidden_size=256, num_layers=3, batch_first=True)
|
| 11 |
-
self.linear = torch.nn.Linear(256, 29) # 假设有 29 个输出类用于字符
|
| 12 |
-
|
| 13 |
-
def forward(self, x):
|
| 14 |
-
x, _ = self.lstm(x)
|
| 15 |
-
x = self.linear(x)
|
| 16 |
-
return x
|
| 17 |
|
| 18 |
# 定义模型路径
|
| 19 |
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"
|
|
@@ -23,8 +12,9 @@ print("Downloading model file...")
|
|
| 23 |
torch.hub.download_url_to_file(model_path, 'large.pt')
|
| 24 |
print("Model file downloaded.")
|
| 25 |
|
| 26 |
-
#
|
| 27 |
-
|
|
|
|
| 28 |
|
| 29 |
# 加载模型参数
|
| 30 |
print("Loading model checkpoint...")
|
|
@@ -56,16 +46,11 @@ def transcribe(audio):
|
|
| 56 |
waveform = resample(waveform).squeeze()
|
| 57 |
|
| 58 |
# 将输入数据转换为符合模型预期的形状
|
| 59 |
-
|
| 60 |
-
if num_frames % 160 != 0:
|
| 61 |
-
# 如果样本数量不是160的倍数,则填充样本
|
| 62 |
-
num_frames_padded = ((num_frames // 160) + 1) * 160
|
| 63 |
-
padding = num_frames_padded - num_frames
|
| 64 |
-
waveform = torch.nn.functional.pad(waveform, (0, padding))
|
| 65 |
-
input_values = waveform.view(-1, 160).unsqueeze(0) # 确保输入形状为 (batch_size, seq_len, input_size)
|
| 66 |
|
| 67 |
with torch.no_grad():
|
| 68 |
-
|
|
|
|
| 69 |
predicted_ids = torch.argmax(logits, dim=-1)
|
| 70 |
transcription = ''.join([chr(i) for i in predicted_ids[0].tolist()]) # 解码预测到字符
|
| 71 |
print("Transcription:", transcription)
|
|
|
|
| 2 |
import torch
|
| 3 |
import torchaudio
|
| 4 |
from torchaudio.transforms import Resample
|
| 5 |
+
from data2vec2 import Data2VecMultiModel, Data2VecMultiConfig, Modality
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# 定义模型路径
|
| 8 |
model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"
|
|
|
|
| 12 |
torch.hub.download_url_to_file(model_path, 'large.pt')
|
| 13 |
print("Model file downloaded.")
|
| 14 |
|
| 15 |
+
# 加载模型配置和初始化模型
|
| 16 |
+
config = Data2VecMultiConfig()
|
| 17 |
+
model = Data2VecMultiModel(config, modalities=[Modality.AUDIO])
|
| 18 |
|
| 19 |
# 加载模型参数
|
| 20 |
print("Loading model checkpoint...")
|
|
|
|
| 46 |
waveform = resample(waveform).squeeze()
|
| 47 |
|
| 48 |
# 将输入数据转换为符合模型预期的形状
|
| 49 |
+
input_values = waveform.unsqueeze(0) # (batch_size, seq_len)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
with torch.no_grad():
|
| 52 |
+
outputs = model.extract_features(input_values, mode='AUDIO')
|
| 53 |
+
logits = outputs["x"]
|
| 54 |
predicted_ids = torch.argmax(logits, dim=-1)
|
| 55 |
transcription = ''.join([chr(i) for i in predicted_ids[0].tolist()]) # 解码预测到字符
|
| 56 |
print("Transcription:", transcription)
|