# tiny-audio-mlp-stage2 / asr_processing.py
# Author: mazesmazes
# Commit 949f02b (verified): "Training in progress - step 500"
from typing import Optional, Union
import torch
import transformers
from transformers import ProcessorMixin
try:
from .asr_config import ASRConfig
except ImportError:
from asr_config import ASRConfig # type: ignore[no-redef]
class ASRProcessor(ProcessorMixin):
    """Processor for Whisper-based ASR models.

    Pairs an audio feature extractor with a chat tokenizer. The user turn of
    the prompt is ``TRANSCRIBE_PROMPT`` followed by one ``AUDIO_TOKEN``
    placeholder per projected audio frame. Presumably the model replaces those
    placeholders with projected audio embeddings; verify against the model code.
    """

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"

    AUDIO_TOKEN = "<audio>"
    TRANSCRIBE_PROMPT = "Transcribe: "

    def __init__(self, feature_extractor, tokenizer, projector=None):
        """Store the components and resolve the audio placeholder token id.

        Args:
            feature_extractor: Audio feature extractor. It must return
                ``input_features`` and an ``attention_mask``.
            tokenizer: Tokenizer with a chat template. ``AUDIO_TOKEN`` is
                expected to be in its vocabulary.
            projector: Object exposing ``get_output_length(n) -> int``, which
                maps an encoder sequence length to a number of audio tokens.
                It is required whenever ``__call__`` receives audio.
        """
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        # Not used inside this class; presumably read by the model to locate
        # the placeholder positions (verify against the caller).
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
        self.projector = projector

    def _num_audio_tokens(self, audio_attention_mask: torch.Tensor) -> int:
        """Return how many ``AUDIO_TOKEN`` placeholders the prompt needs.

        The prompt is built once for the whole batch, so the count comes from
        the longest real (unpadded) length in the mask.
        """
        # Use the mask rather than the padded feature length, so only real
        # frames are counted.
        real_mel_len = audio_attention_mask.sum(dim=-1).max().item()
        # Assumes the audio encoder halves the frame count (as Whisper's
        # stride-2 conv does). This must match the model's own length
        # computation; TODO confirm.
        encoder_output_len = real_mel_len // 2
        return self.projector.get_output_length(encoder_output_len)

    def __call__(
        self,
        audio: Optional[Union[list, torch.Tensor]] = None,
        text: Optional[str] = None,
        system_prompt: Optional[str] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict:
        """Process audio and text inputs for inference.

        Args:
            audio: Raw audio waveform(s). When given, ``self.projector`` must
                be set.
            text: Target transcription (optional, for training; prefer the
                DataCollator). When given, it is added as the assistant turn
                and no generation prompt is appended.
            system_prompt: Optional system prompt. It is skipped when empty.
            return_tensors: Return format. Only ``"pt"`` is supported, because
                the audio mask and the text attention mask go through torch ops.
            **kwargs: Forwarded to the feature extractor. A ``sampling_rate``
                here overrides the extractor's default. ``return_attention_mask``
                is always forced to True, because the mask is needed to count
                the audio tokens.

        Returns:
            Dict with ``input_ids`` and ``attention_mask``. When audio is
            given, it also contains ``input_features`` and
            ``audio_attention_mask``. ``input_ids`` holds a single
            conversation, even when ``audio`` is a batch.

        Raises:
            ValueError: If audio is given but no projector was configured.
        """
        result = {}

        num_audio_tokens = 0
        if audio is not None:
            if self.projector is None:
                raise ValueError(
                    "ASRProcessor requires a `projector` to compute the number of "
                    "audio tokens; pass one to the constructor before processing audio."
                )
            # Merge the defaults with the caller's kwargs. Passing them
            # explicitly *and* via **kwargs would raise "got multiple values
            # for keyword argument".
            extractor_kwargs = dict(kwargs)
            extractor_kwargs.setdefault(
                "sampling_rate",
                getattr(self.feature_extractor, "sampling_rate", 16000),
            )
            extractor_kwargs["return_attention_mask"] = True
            audio_inputs = self.feature_extractor(
                audio,
                return_tensors=return_tensors,
                **extractor_kwargs,
            )
            result["input_features"] = audio_inputs["input_features"]
            result["audio_attention_mask"] = audio_inputs["attention_mask"]
            num_audio_tokens = self._num_audio_tokens(audio_inputs["attention_mask"])

        # `str * n` is "" for n <= 0, so no audio means no placeholders.
        user_content = self.TRANSCRIBE_PROMPT + self.AUDIO_TOKEN * num_audio_tokens

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})
        if text is not None:
            messages.append({"role": "assistant", "content": text})

        input_ids = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            # Open an assistant turn only for inference. In training, the
            # target text is already present as the assistant message.
            add_generation_prompt=(text is None),
            return_tensors=return_tensors,
        )
        # Guarantee a leading batch dimension of 1.
        if isinstance(input_ids, torch.Tensor) and input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)
        result["input_ids"] = input_ids
        # A single unpadded conversation, so every position is attended.
        result["attention_mask"] = torch.ones_like(input_ids)
        return result
# Import-time registration side effects. The first call marks ASRProcessor for
# auto-class export (presumably so it can be saved/loaded as custom code; see
# the transformers docs). The second maps ASRConfig to ASRProcessor, so that
# transformers.AutoProcessor can resolve this processor from that config type.
ASRProcessor.register_for_auto_class()
transformers.AutoProcessor.register(ASRConfig, ASRProcessor)