update torch and meta
Browse files
app.py
CHANGED
|
@@ -7,7 +7,8 @@ from datetime import datetime
|
|
| 7 |
import plotly.express as px
|
| 8 |
|
| 9 |
# external ASR
|
| 10 |
-
|
|
|
|
| 11 |
|
| 12 |
# internal modules
|
| 13 |
from analysis import (
|
|
@@ -25,8 +26,9 @@ from region_assets import REGION_COLORS, REGION_ICONS
|
|
| 25 |
# ---------------------------
|
| 26 |
# ASR Pipeline
|
| 27 |
# ---------------------------
|
| 28 |
-
|
| 29 |
-
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
# ---------------------------
|
|
@@ -166,16 +168,28 @@ def full_process(audio, region):
|
|
| 166 |
|
| 167 |
ref = TEST_SENTENCES[region]
|
| 168 |
|
| 169 |
-
# ASR
|
| 170 |
try:
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
)
|
| 176 |
-
|
|
|
|
|
|
|
|
|
|
| 177 |
except Exception as e:
|
| 178 |
-
return (f"ASR
|
|
|
|
| 179 |
|
| 180 |
# basic metrics
|
| 181 |
_cer = cer(ref, hyp)
|
|
|
|
| 7 |
import plotly.express as px
|
| 8 |
|
| 9 |
# external ASR
|
| 10 |
+
import torch
|
| 11 |
+
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
|
| 12 |
|
| 13 |
# internal modules
|
| 14 |
from analysis import (
|
|
|
|
| 26 |
# ---------------------------
|
| 27 |
# ASR Pipeline
|
| 28 |
# ---------------------------
|
| 29 |
+
MODEL_ID = "facebook/omnilingual_asr_llm_1b"
|
| 30 |
+
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
| 31 |
+
model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_ID, torch_dtype=torch.float16).to("cuda")
|
| 32 |
|
| 33 |
|
| 34 |
# ---------------------------
|
|
|
|
| 168 |
|
| 169 |
ref = TEST_SENTENCES[region]
|
| 170 |
|
| 171 |
+
# ASR via transformers pipeline
|
| 172 |
try:
|
| 173 |
+
# 1. Input features hazırlama
|
| 174 |
+
inputs = processor(
|
| 175 |
+
data,
|
| 176 |
+
sampling_rate=sr,
|
| 177 |
+
return_tensors="pt"
|
| 178 |
+
).to("cuda")
|
| 179 |
+
|
| 180 |
+
# 2. Model ile inference
|
| 181 |
+
with torch.no_grad():
|
| 182 |
+
generated_ids = model.generate(
|
| 183 |
+
inputs.input_features,
|
| 184 |
+
max_length=400
|
| 185 |
)
|
| 186 |
+
|
| 187 |
+
# 3. Decode
|
| 188 |
+
hyp = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
| 189 |
+
|
| 190 |
except Exception as e:
|
| 191 |
+
return (f"ASR Hatası: {e}", None, None, None, None, None, None, None, None, None)
|
| 192 |
+
|
| 193 |
|
| 194 |
# basic metrics
|
| 195 |
_cer = cer(ref, hyp)
|