jasspier committed on
Commit
ca625d0
·
verified ·
1 Parent(s): 4181e49

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -22
app.py CHANGED
@@ -2,18 +2,7 @@ import gradio as gr
2
  import torch
3
  import torchaudio
4
  from torchaudio.transforms import Resample
5
-
6
- # 定义一个假设的 ASR 模型结构
7
- class ASRModel(torch.nn.Module):
8
- def __init__(self):
9
- super(ASRModel, self).__init__()
10
- self.lstm = torch.nn.LSTM(input_size=160, hidden_size=256, num_layers=3, batch_first=True)
11
- self.linear = torch.nn.Linear(256, 29) # 假设有 29 个输出类用于字符
12
-
13
- def forward(self, x):
14
- x, _ = self.lstm(x)
15
- x = self.linear(x)
16
- return x
17
 
18
  # 定义模型路径
19
  model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"
@@ -23,8 +12,9 @@ print("Downloading model file...")
23
  torch.hub.download_url_to_file(model_path, 'large.pt')
24
  print("Model file downloaded.")
25
 
26
- # 初始化模型
27
- model = ASRModel()
 
28
 
29
  # 加载模型参数
30
  print("Loading model checkpoint...")
@@ -56,16 +46,11 @@ def transcribe(audio):
56
  waveform = resample(waveform).squeeze()
57
 
58
  # 将输入数据转换为符合模型预期的形状
59
- num_frames = waveform.size(0)
60
- if num_frames % 160 != 0:
61
- # 如果样本数量不是160的倍数,则填充样本
62
- num_frames_padded = ((num_frames // 160) + 1) * 160
63
- padding = num_frames_padded - num_frames
64
- waveform = torch.nn.functional.pad(waveform, (0, padding))
65
- input_values = waveform.view(-1, 160).unsqueeze(0) # 确保输入形状为 (batch_size, seq_len, input_size)
66
 
67
  with torch.no_grad():
68
- logits = model(input_values)
 
69
  predicted_ids = torch.argmax(logits, dim=-1)
70
  transcription = ''.join([chr(i) for i in predicted_ids[0].tolist()]) # 解码预测到字符
71
  print("Transcription:", transcription)
 
2
  import torch
3
  import torchaudio
4
  from torchaudio.transforms import Resample
5
+ from data2vec2 import Data2VecMultiModel, Data2VecMultiConfig, Modality
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  # 定义模型路径
8
  model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/finetune_large_kespeech.pt"
 
12
  torch.hub.download_url_to_file(model_path, 'large.pt')
13
  print("Model file downloaded.")
14
 
15
+ # 加载模型配置和初始化模型
16
+ config = Data2VecMultiConfig()
17
+ model = Data2VecMultiModel(config, modalities=[Modality.AUDIO])
18
 
19
  # 加载模型参数
20
  print("Loading model checkpoint...")
 
46
  waveform = resample(waveform).squeeze()
47
 
48
  # 将输入数据转换为符合模型预期的形状
49
+ input_values = waveform.unsqueeze(0) # (batch_size, seq_len)
 
 
 
 
 
 
50
 
51
  with torch.no_grad():
52
+ outputs = model.extract_features(input_values, mode='AUDIO')
53
+ logits = outputs["x"]
54
  predicted_ids = torch.argmax(logits, dim=-1)
55
  transcription = ''.join([chr(i) for i in predicted_ids[0].tolist()]) # 解码预测到字符
56
  print("Transcription:", transcription)