import json
from pathlib import Path
from typing import Optional, Union

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
)
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

try:
    from .asr_config import ASRConfig
    from .projectors import PROJECTOR_CLASSES
except ImportError:
    from asr_config import ASRConfig  # type: ignore[no-redef]
    from projectors import PROJECTOR_CLASSES  # type: ignore[no-redef]


class ASRModel(PreTrainedModel, GenerationMixin):
    """Audio-to-text model combining an audio encoder, projector, and language model."""

    config_class = ASRConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True
    _is_loading_from_pretrained: bool = False
    _pretrained_model_path: Optional[str] = None

    TRANSCRIBE_PROMPT = "Transcribe: "

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Load model from pretrained, handling device placement correctly."""
        from safetensors.torch import load_file
        from transformers.utils.hub import cached_file

        config = kwargs.pop("config", None)
        if config is None:
            config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Set flag to avoid device_map="auto" in sub-model loaders
        cls._is_loading_from_pretrained = True
        cls._pretrained_model_path = pretrained_model_name_or_path
        try:
            model = cls(config, **kwargs)

            # Load projector weights from safetensors
            subfolder = kwargs.get("subfolder")
            revision = kwargs.get("revision")
            cache_kwargs = {}
            if subfolder:
                cache_kwargs["subfolder"] = subfolder
            if revision:
                cache_kwargs["revision"] = revision
            model_file = cached_file(
                pretrained_model_name_or_path,
                "model.safetensors",
                _raise_exceptions_for_missing_entries=False,
                **cache_kwargs,
            )
            if model_file is not None:
                state_dict = load_file(model_file)
                model.load_state_dict(state_dict, strict=False)

            # Load LoRA adapter if present
            adapter_config = cached_file(
                pretrained_model_name_or_path,
                "adapter_config.json",
                _raise_exceptions_for_missing_entries=False,
                **cache_kwargs,
            )
            if adapter_config is not None:
                from peft import PeftModel

                # Pass original repo ID to PEFT, let it handle caching
                model.language_model = PeftModel.from_pretrained(
                    model.language_model, pretrained_model_name_or_path, is_trainable=False
                )

            return model
        finally:
            cls._is_loading_from_pretrained = False
            cls._pretrained_model_path = None

    def __init__(self, config: ASRConfig, **kwargs):
        super().__init__(config)
        self.system_prompt = config.system_prompt
        target_dtype = getattr(torch, config.model_dtype)

        # Audio encoder (frozen)
        self.audio_tower = self._load_audio_encoder(config, target_dtype)

        # Language model (frozen)
        self.language_model = self._load_language_model(config, target_dtype)

        # Initialize tokenizer and special tokens
        self._init_tokenizer(config)

        # Set up generation config with greedy decoding defaults
        self.generation_config = self.language_model.generation_config
        self.generation_config.max_new_tokens = config.max_new_tokens
        self.generation_config.num_beams = config.num_beams
        self.generation_config.do_sample = False
        # Clear sampling params (inherited from LLM) since we use greedy decoding
        self.generation_config.temperature = None
        self.generation_config.top_p = None
        self.generation_config.top_k = None
        self.generation_config.use_cache = config.use_cache
        self.generation_config.length_penalty = config.length_penalty
        self.generation_config.repetition_penalty = config.repetition_penalty
        self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
        self.generation_config.eos_token_id = [
            self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
            self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
        ]
        self.generation_config.pad_token_id = self.tokenizer.pad_token_id

        # Feature extractor for audio preprocessing
        self.feature_extractor = self._create_feature_extractor(config)

        # Audio projector (trainable)
        self.projector = self._create_projector(config, target_dtype)

        # For model parallelism
        self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])

    def _create_feature_extractor(self, config: ASRConfig):
        """Create the appropriate feature extractor for the audio encoder."""
        from transformers import AutoFeatureExtractor

        return AutoFeatureExtractor.from_pretrained(config.audio_model_id)

    @classmethod
    def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
        """Load and freeze the audio encoder."""
        encoder_kwargs = {
            "attn_implementation": config.attn_implementation,
            "low_cpu_mem_usage": True,
            "dtype": dtype,
        }
        if "whisper" in config.audio_model_id.lower():
            from transformers import WhisperModel

            full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
            encoder = full_model.encoder
            del full_model
        elif "glm" in config.audio_model_id.lower():
            # GLM-ASR models use audio_tower as the encoder
            # Requires transformers >= 5.x or installed from source
            from transformers import AutoModelForSeq2SeqLM

            full_model = AutoModelForSeq2SeqLM.from_pretrained(
                config.audio_model_id, trust_remote_code=True, **encoder_kwargs
            )
            # GLM stores encoder at audio_tower (GlmAsrEncoder)
            encoder = full_model.audio_tower
            # Clear references to free VRAM from the LLM decoder
            full_model.language_model = None
            full_model.multi_modal_projector = None
            del full_model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        else:
            encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)

        encoder.requires_grad_(False)
        encoder.eval()
        return encoder

    @classmethod
    def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
        """Load and freeze the language model."""
        decoder_kwargs = {
            "attn_implementation": config.attn_implementation,
            "trust_remote_code": True,
            "tie_word_embeddings": False,
            "low_cpu_mem_usage": True,
            "dtype": dtype,
        }
        decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
        decoder.config.use_cache = getattr(config, "use_cache", True)
        decoder.requires_grad_(False)
        decoder.eval()
        return decoder

    def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
        """Create the trainable audio projector."""
        # Auto-detect dimensions if not specified
        if config.encoder_dim is None:
            enc_cfg = self.audio_tower.config
            config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
                enc_cfg, "d_model", None
            )
            if config.encoder_dim is None:
                raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")
        if config.llm_dim is None:
            dec_cfg = self.language_model.config
            config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
                dec_cfg, "d_model", None
            )
            if config.llm_dim is None:
                raise ValueError("Could not auto-detect llm_dim. Please specify in config.")

        # Select projector type based on config
        projector_type = getattr(config, "projector_type", "mlp")
        projector_class = PROJECTOR_CLASSES.get(projector_type)
        if projector_class is None:
            raise ValueError(
                f"Unknown projector_type: {projector_type}. "
                f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
            )
        projector = projector_class(config)
        # Move projector to same device as language model (important when using quantization)
        device = next(self.language_model.parameters()).device
        return projector.to(device=device, dtype=dtype)

    def _init_tokenizer(self, config: ASRConfig):
        """Initialize tokenizer with audio token."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.text_model_id, trust_remote_code=True
        )
        # Set pad token
        if (
            self.tokenizer.pad_token is None
            or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
        ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
            self.tokenizer.pad_token = "<|finetune_right_pad_id|>"
        # Add audio token
        existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
        if "