nisacayir committed on
Commit
ae87233
·
verified ·
1 Parent(s): 0725bb1

update torch and meta

Browse files
Files changed (1) hide show
  1. app.py +24 -10
app.py CHANGED
@@ -7,7 +7,8 @@ from datetime import datetime
7
  import plotly.express as px
8
 
9
  # external ASR
10
- from omnilingual_asr.models.inference.pipeline import ASRInferencePipeline
 
11
 
12
  # internal modules
13
  from analysis import (
@@ -25,8 +26,9 @@ from region_assets import REGION_COLORS, REGION_ICONS
25
  # ---------------------------
26
  # ASR Pipeline
27
  # ---------------------------
28
- PIPELINE = ASRInferencePipeline(model_card="omniASR_LLM_1B")
29
- LANG_CODE = "tur_Latn"
 
30
 
31
 
32
  # ---------------------------
@@ -166,16 +168,28 @@ def full_process(audio, region):
166
 
167
  ref = TEST_SENTENCES[region]
168
 
169
- # ASR
170
  try:
171
- hyp_list = PIPELINE.transcribe(
172
- [{"waveform": data, "sample_rate": sr}],
173
- lang=[LANG_CODE],
174
- batch_size=1
 
 
 
 
 
 
 
 
175
  )
176
- hyp = hyp_list[0] if hyp_list else ""
 
 
 
177
  except Exception as e:
178
- return (f"ASR hata: {e}", None, None, None, None, None, None, None, None, None)
 
179
 
180
  # basic metrics
181
  _cer = cer(ref, hyp)
 
7
  import plotly.express as px
8
 
9
  # external ASR
10
+ import torch
11
+ from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
12
 
13
  # internal modules
14
  from analysis import (
 
26
  # ---------------------------
27
  # ASR Pipeline
28
  # ---------------------------
29
+ MODEL_ID = "facebook/omnilingual_asr_llm_1b"
30
+ processor = AutoProcessor.from_pretrained(MODEL_ID)
31
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(MODEL_ID, torch_dtype=torch.float16).to("cuda")
32
 
33
 
34
  # ---------------------------
 
168
 
169
  ref = TEST_SENTENCES[region]
170
 
171
+ # ASR via transformers pipeline
172
  try:
173
+ # 1. Input features hazırlama
174
+ inputs = processor(
175
+ data,
176
+ sampling_rate=sr,
177
+ return_tensors="pt"
178
+ ).to("cuda")
179
+
180
+ # 2. Model ile inference
181
+ with torch.no_grad():
182
+ generated_ids = model.generate(
183
+ inputs.input_features,
184
+ max_length=400
185
  )
186
+
187
+ # 3. Decode
188
+ hyp = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
189
+
190
  except Exception as e:
191
+ return (f"ASR Hatası: {e}", None, None, None, None, None, None, None, None, None)
192
+
193
 
194
  # basic metrics
195
  _cer = cer(ref, hyp)