jasspier commited on
Commit
31564d0
·
verified ·
1 Parent(s): ce27feb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -10
app.py CHANGED
@@ -1,32 +1,73 @@
1
  import gradio as gr
2
  import torch
3
  import torchaudio
4
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
5
 
6
- # 定义模型路径和加载模型
7
- model_name = "Tele-AI/TeleSpeech-ASR1.0"
8
- processor = Wav2Vec2Processor.from_pretrained(model_name)
9
- model = Wav2Vec2ForCTC.from_pretrained(model_name)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # 定义处理函数
12
  def transcribe(audio):
13
  print("Transcribing audio...")
14
  waveform, sample_rate = torchaudio.load(audio)
15
- resample = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
16
  waveform = resample(waveform).squeeze()
17
 
18
- input_values = processor(waveform, return_tensors="pt", sampling_rate=16000).input_values
19
  with torch.no_grad():
20
- logits = model(input_values).logits
21
  predicted_ids = torch.argmax(logits, dim=-1)
22
- transcription = processor.batch_decode(predicted_ids)[0]
23
  print("Transcription:", transcription)
24
  return transcription
25
 
26
  # 创建 Gradio 界面
27
  iface = gr.Interface(
28
  fn=transcribe,
29
- inputs=gr.Audio(source="upload", type="filepath"),
30
  outputs="text",
31
  title="TeleSpeech ASR",
32
  description="Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model."
 
1
  import gradio as gr
2
  import torch
3
  import torchaudio
4
+ from torchaudio.transforms import Resample
5
 
6
+ # 定义一个假设的 ASR 模型结构
7
+ class ASRModel(torch.nn.Module):
8
+ def __init__(self):
9
+ super(ASRModel, self).__init__()
10
+ # 这里假设模型架构是一个简单的 LSTM
11
+ self.lstm = torch.nn.LSTM(input_size=160, hidden_size=256, num_layers=3, batch_first=True)
12
+ self.linear = torch.nn.Linear(256, 29) # 假设有 29 个输出类用于字符
13
+
14
+ def forward(self, x):
15
+ x, _ = self.lstm(x)
16
+ x = self.linear(x)
17
+ return x
18
+
19
+ # 定义模型路径
20
+ model_path = "https://huggingface.co/Tele-AI/TeleSpeech-ASR1.0/resolve/main/base.pt"
21
+
22
+ # 下载模型文件
23
+ print("Downloading model file...")
24
+ torch.hub.download_url_to_file(model_path, 'large.pt')
25
+ print("Model file downloaded.")
26
+
27
+ # 初始化模型
28
+ model = ASRModel()
29
+
30
+ # 加载模型参数
31
+ print("Loading model checkpoint...")
32
+ checkpoint = torch.load('large.pt', map_location=torch.device('cpu'))
33
+ print("Checkpoint keys:", checkpoint.keys())
34
+
35
+ # 打印模型参数中的键
36
+ if 'model' in checkpoint:
37
+ state_dict = checkpoint['model']
38
+ print("Model state_dict keys:", state_dict.keys())
39
+ else:
40
+ print("Key 'model' not found in checkpoint.")
41
+ state_dict = checkpoint
42
+
43
+ # 加载模型状态字典
44
+ try:
45
+ model.load_state_dict(state_dict)
46
+ print("Model state_dict loaded successfully.")
47
+ except Exception as e:
48
+ print("Error loading model state_dict:", str(e))
49
+
50
+ model.eval()
51
 
52
  # 定义处理函数
53
  def transcribe(audio):
54
  print("Transcribing audio...")
55
  waveform, sample_rate = torchaudio.load(audio)
56
+ resample = Resample(orig_freq=sample_rate, new_freq=16000)
57
  waveform = resample(waveform).squeeze()
58
 
59
+ input_values = waveform.unsqueeze(0)
60
  with torch.no_grad():
61
+ logits = model(input_values)
62
  predicted_ids = torch.argmax(logits, dim=-1)
63
+ transcription = ''.join([chr(i) for i in predicted_ids[0].tolist()]) # 解码预测到字符
64
  print("Transcription:", transcription)
65
  return transcription
66
 
67
  # 创建 Gradio 界面
68
  iface = gr.Interface(
69
  fn=transcribe,
70
+ inputs=gr.Audio(type="filepath"),
71
  outputs="text",
72
  title="TeleSpeech ASR",
73
  description="Upload an audio file or record your voice to transcribe speech to text using the TeleSpeech ASR model."