Training in progress - step 1000
- .gitattributes +1 -0
- README.md +199 -0
- alignment.py +283 -0
- asr_config.py +233 -0
- asr_modeling.py +896 -0
- asr_pipeline.py +322 -0
- asr_processing.py +133 -0
- chat_template.jinja +89 -0
- diarization.py +732 -0
- preprocessor_config.json +19 -0
- projectors.py +505 -0
- tokenizer.json +3 -0
- tokenizer_config.json +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,199 @@
---
library_name: transformers
tags: []
---

# Model Card for Model ID

<!-- Provide a quick summary of what the model is/does. -->

## Model Details

### Model Description

<!-- Provide a longer summary of what this model is. -->

This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.

- **Developed by:** [More Information Needed]
- **Funded by [optional]:** [More Information Needed]
- **Shared by [optional]:** [More Information Needed]
- **Model type:** [More Information Needed]
- **Language(s) (NLP):** [More Information Needed]
- **License:** [More Information Needed]
- **Finetuned from model [optional]:** [More Information Needed]

### Model Sources [optional]

<!-- Provide the basic links for the model. -->

- **Repository:** [More Information Needed]
- **Paper [optional]:** [More Information Needed]
- **Demo [optional]:** [More Information Needed]

## Uses

<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->

### Direct Use

<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->

[More Information Needed]

### Downstream Use [optional]

<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->

[More Information Needed]

### Out-of-Scope Use

<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->

[More Information Needed]

## Bias, Risks, and Limitations

<!-- This section is meant to convey both technical and sociotechnical limitations. -->

[More Information Needed]

### Recommendations

<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.

[More Information Needed]

## Training Details

### Training Data

<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->

[More Information Needed]

### Training Procedure

<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->

#### Preprocessing [optional]

[More Information Needed]

#### Training Hyperparameters

- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->

#### Speeds, Sizes, Times [optional]

<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->

[More Information Needed]

## Evaluation

<!-- This section describes the evaluation protocols and provides the results. -->

### Testing Data, Factors & Metrics

#### Testing Data

<!-- This should link to a Dataset Card if possible. -->

[More Information Needed]

#### Factors

<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->

[More Information Needed]

#### Metrics

<!-- These are the evaluation metrics being used, ideally with a description of why. -->

[More Information Needed]

### Results

[More Information Needed]

#### Summary

## Model Examination [optional]

<!-- Relevant interpretability work for the model goes here -->

[More Information Needed]

## Environmental Impact

<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** [More Information Needed]
- **Hours used:** [More Information Needed]
- **Cloud Provider:** [More Information Needed]
- **Compute Region:** [More Information Needed]
- **Carbon Emitted:** [More Information Needed]

## Technical Specifications [optional]

### Model Architecture and Objective

[More Information Needed]

### Compute Infrastructure

[More Information Needed]

#### Hardware

[More Information Needed]

#### Software

[More Information Needed]

## Citation [optional]

<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->

**BibTeX:**

[More Information Needed]

**APA:**

[More Information Needed]

## Glossary [optional]

<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->

[More Information Needed]

## More Information [optional]

[More Information Needed]

## Model Card Authors [optional]

[More Information Needed]

## Model Card Contact

[More Information Needed]
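The "How to Get Started with the Model" section above is still a placeholder. Since asr_config.py (later in this commit) registers a custom automatic-speech-recognition pipeline via `custom_pipelines`, loading should look roughly like the sketch below; the repo id and audio filename are placeholders, and the exact output keys depend on asr_pipeline.ASRPipeline, which is not shown in this diff.

# Hedged sketch, not part of the commit: load the checkpoint through the custom pipeline.
from transformers import pipeline

asr = pipeline(
    "automatic-speech-recognition",
    model="your-username/your-asr-checkpoint",  # placeholder repo id
    trust_remote_code=True,                     # config, model, and pipeline are custom code
)

result = asr("sample.wav")  # placeholder audio file (16 kHz mono)
print(result["text"])       # assumes the standard ASR pipeline output format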
alignment.py
ADDED
@@ -0,0 +1,283 @@
"""Forced alignment for word-level timestamps using Wav2Vec2."""

import numpy as np
import torch


def _get_device() -> str:
    """Get best available device for non-transformers models."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


class ForcedAligner:
    """Lazy-loaded forced aligner for word-level timestamps using torchaudio wav2vec2.

    Uses Viterbi trellis algorithm for optimal alignment path finding.
    """

    _bundle = None
    _model = None
    _labels = None
    _dictionary = None

    @classmethod
    def get_instance(cls, device: str = "cuda"):
        """Get or create the forced alignment model (singleton).

        Args:
            device: Device to run model on ("cuda" or "cpu")

        Returns:
            Tuple of (model, labels, dictionary)
        """
        if cls._model is None:
            import torchaudio

            cls._bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
            cls._model = cls._bundle.get_model().to(device)
            cls._model.eval()
            cls._labels = cls._bundle.get_labels()
            cls._dictionary = {c: i for i, c in enumerate(cls._labels)}
        return cls._model, cls._labels, cls._dictionary

    @staticmethod
    def _get_trellis(emission: torch.Tensor, tokens: list[int], blank_id: int = 0) -> torch.Tensor:
        """Build trellis for forced alignment using forward algorithm.

        The trellis[t, j] represents the log probability of the best path that
        aligns the first j tokens to the first t frames.

        Args:
            emission: Log-softmax emission matrix of shape (num_frames, num_classes)
            tokens: List of target token indices
            blank_id: Index of the blank/CTC token (default 0)

        Returns:
            Trellis matrix of shape (num_frames + 1, num_tokens + 1)
        """
        num_frames = emission.size(0)
        num_tokens = len(tokens)

        trellis = torch.full((num_frames + 1, num_tokens + 1), -float("inf"))
        trellis[0, 0] = 0

        for t in range(num_frames):
            for j in range(num_tokens + 1):
                # Stay: emit blank and stay at j tokens
                stay = trellis[t, j] + emission[t, blank_id]

                # Move: emit token j and advance to j+1 tokens
                move = trellis[t, j - 1] + emission[t, tokens[j - 1]] if j > 0 else -float("inf")

                trellis[t + 1, j] = max(stay, move)  # Viterbi: take best path

        return trellis

    @staticmethod
    def _backtrack(
        trellis: torch.Tensor, emission: torch.Tensor, tokens: list[int], blank_id: int = 0
    ) -> list[tuple[int, float, float]]:
        """Backtrack through trellis to find optimal forced monotonic alignment.

        Guarantees:
        - All tokens are emitted exactly once
        - Strictly monotonic: each token's frames come after previous token's
        - No frame skipping or token teleporting

        Returns list of (token_id, start_frame, end_frame) for each token.
        """
        num_frames = emission.size(0)
        num_tokens = len(tokens)

        if num_tokens == 0:
            return []

        # Find the best ending point (should be at num_tokens)
        # But verify trellis reached a valid state
        if trellis[num_frames, num_tokens] == -float("inf"):
            # Alignment failed - fall back to uniform distribution
            frames_per_token = num_frames / num_tokens
            return [
                (tokens[i], i * frames_per_token, (i + 1) * frames_per_token)
                for i in range(num_tokens)
            ]

        # Backtrack: find where each token transition occurred
        # path[i] = frame where token i was first emitted
        token_frames: list[list[int]] = [[] for _ in range(num_tokens)]

        t = num_frames
        j = num_tokens

        while t > 0 and j > 0:
            # Check: did we transition from j-1 to j at frame t-1?
            stay_score = trellis[t - 1, j] + emission[t - 1, blank_id]
            move_score = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

            if move_score >= stay_score:
                # Token j-1 was emitted at frame t-1
                token_frames[j - 1].insert(0, t - 1)
                j -= 1
            # Always decrement time (monotonic)
            t -= 1

        # Handle any remaining tokens at the start (edge case)
        while j > 0:
            token_frames[j - 1].insert(0, 0)
            j -= 1

        # Convert to spans
        token_spans: list[tuple[int, float, float]] = []
        for token_idx, frames in enumerate(token_frames):
            if not frames:
                # Token never emitted - assign minimal span after previous
                if token_spans:
                    prev_end = token_spans[-1][2]
                    frames = [int(prev_end)]
                else:
                    frames = [0]

            token_id = tokens[token_idx]
            start_frame = float(min(frames))
            end_frame = float(max(frames)) + 1.0
            token_spans.append((token_id, start_frame, end_frame))

        return token_spans

    # Offset compensation for Wav2Vec2-BASE systematic bias (in seconds)
    # Calibrated on librispeech-alignments dataset
    START_OFFSET = 0.06  # Subtract from start times (shift earlier)
    END_OFFSET = -0.03  # Add to end times (shift later)

    @classmethod
    def align(
        cls,
        audio: np.ndarray,
        text: str,
        sample_rate: int = 16000,
        _language: str = "eng",
        _batch_size: int = 16,
    ) -> list[dict]:
        """Align transcript to audio and return word-level timestamps.

        Uses Viterbi trellis algorithm for optimal forced alignment.

        Args:
            audio: Audio waveform as numpy array
            text: Transcript text to align
            sample_rate: Audio sample rate (default 16000)
            _language: ISO-639-3 language code (default "eng" for English, unused)
            _batch_size: Batch size for alignment model (unused)

        Returns:
            List of dicts with 'word', 'start', 'end' keys
        """
        import torchaudio

        device = _get_device()
        model, _labels, dictionary = cls.get_instance(device)
        assert cls._bundle is not None and dictionary is not None  # Initialized by get_instance

        # Convert audio to tensor (copy to ensure array is writable)
        if isinstance(audio, np.ndarray):
            waveform = torch.from_numpy(audio.copy()).float()
        else:
            waveform = audio.clone().float()

        # Ensure 2D (channels, time)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)

        # Resample if needed (wav2vec2 expects 16kHz)
        if sample_rate != cls._bundle.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, cls._bundle.sample_rate
            )

        waveform = waveform.to(device)

        # Get emissions from model
        with torch.inference_mode():
            emissions, _ = model(waveform)
            emissions = torch.log_softmax(emissions, dim=-1)

        emission = emissions[0].cpu()

        # Normalize text: uppercase, keep only valid characters
        transcript = text.upper()

        # Build tokens from transcript (including word separators)
        tokens = []
        for char in transcript:
            if char in dictionary:
                tokens.append(dictionary[char])
            elif char == " ":
                tokens.append(dictionary.get("|", dictionary.get(" ", 0)))

        if not tokens:
            return []

        # Build Viterbi trellis and backtrack for optimal path
        trellis = cls._get_trellis(emission, tokens, blank_id=0)
        alignment_path = cls._backtrack(trellis, emission, tokens, blank_id=0)

        # Convert frame indices to time (model stride is 320 samples at 16kHz = 20ms)
        frame_duration = 320 / cls._bundle.sample_rate

        # Apply separate offset compensation for start/end (Wav2Vec2 systematic bias)
        start_offset = cls.START_OFFSET
        end_offset = cls.END_OFFSET

        # Group aligned tokens into words based on pipe separator
        words = text.split()
        word_timestamps = []
        current_word_start = None
        current_word_end = None
        word_idx = 0
        separator_id = dictionary.get("|", dictionary.get(" ", 0))

        for token_id, start_frame, end_frame in alignment_path:
            if token_id == separator_id:  # Word separator
                if (
                    current_word_start is not None
                    and current_word_end is not None
                    and word_idx < len(words)
                ):
                    start_time = max(0.0, current_word_start * frame_duration - start_offset)
                    end_time = max(0.0, current_word_end * frame_duration - end_offset)
                    word_timestamps.append(
                        {
                            "word": words[word_idx],
                            "start": start_time,
                            "end": end_time,
                        }
                    )
                    word_idx += 1
                current_word_start = None
                current_word_end = None
            else:
                if current_word_start is None:
                    current_word_start = start_frame
                current_word_end = end_frame

        # Don't forget the last word
        if (
            current_word_start is not None
            and current_word_end is not None
            and word_idx < len(words)
        ):
            start_time = max(0.0, current_word_start * frame_duration - start_offset)
            end_time = max(0.0, current_word_end * frame_duration - end_offset)
            word_timestamps.append(
                {
                    "word": words[word_idx],
                    "start": start_time,
                    "end": end_time,
                }
            )

        return word_timestamps
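Usage sketch for the aligner above (not part of the commit). The WAV filename is a placeholder; torchaudio loads the clip, and align() handles resampling, device selection, and lazy model loading internally.

import torchaudio

from alignment import ForcedAligner

# Load a mono clip (placeholder filename); align() resamples to 16 kHz if needed.
waveform, sr = torchaudio.load("sample_16k.wav")
words = ForcedAligner.align(waveform[0].numpy(), "hello world", sample_rate=sr)
for w in words:
    print(f"{w['word']}: {w['start']:.2f}s - {w['end']:.2f}s")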
asr_config.py
ADDED
@@ -0,0 +1,233 @@
from typing import Optional

import transformers


class ASRConfig(transformers.PretrainedConfig):
    """Configuration class for the ASR model.

    This config combines settings for:
    - Audio encoder (GLM-ASR/Whisper)
    - Text decoder (Qwen)
    - Projector (MLP, MOSA, MoE, QFormer)
    - Generation parameters
    - Training options (SpecAugment, LoRA)
    """

    model_type = "asr_model"
    is_composition = True

    def __init__(
        self,
        audio_model_id: str = "zai-org/GLM-ASR-Nano-2512",
        text_model_id: str = "Qwen/Qwen3-0.6B",
        attn_implementation: str = "flash_attention_2",
        model_dtype: str = "bfloat16",
        num_beams: Optional[int] = None,
        system_prompt: str = "You are a helpful assistant.",
        encoder_dim: Optional[int] = None,
        llm_dim: Optional[int] = None,
        # Encoder conv layers: list of (padding, kernel_size, stride) tuples
        # Default is Whisper/GLM-ASR structure: conv1(k=3,s=1,p=1) + conv2(k=3,s=2,p=1)
        encoder_conv_layers: Optional[list] = None,
        audio_sample_rate: int = 16000,
        projector_pool_stride: int = 4,
        downsample_rate: int = 5,  # Granite default
        projector_hidden_dim: Optional[int] = None,
        projector_type: str = "mlp",  # "mlp", "mosa", "moe", "qformer"
        projector_num_layers: int = 2,  # Number of layers in MLP projector
        projector_init_std: float = 0.02,  # Weight initialization std
        projector_dropout: float = 0.0,  # Dropout rate for projector layers
        # MoE-specific configuration
        num_experts: int = 4,  # Number of experts in MoE projectors
        num_experts_per_tok: int = 2,  # Top-k experts per token
        router_aux_loss_coef: float = 0.01,  # Auxiliary loss coefficient for load balancing
        # QFormer-specific configuration (Granite defaults)
        qformer_window_size: int = 15,  # Window size for QFormer processing
        qformer_hidden_size: Optional[int] = None,  # QFormer hidden size (defaults to encoder_dim)
        qformer_num_layers: int = 2,  # Number of QFormer transformer layers
        qformer_num_heads: int = 16,  # Number of attention heads in QFormer
        qformer_intermediate_size: Optional[int] = None,  # FFN size (defaults to 4x hidden)
        label_smoothing: float = 0.0,  # Label smoothing for cross-entropy loss
        inference_warmup_tokens: int = 10,
        # SpecAugment settings
        use_specaugment: bool = False,
        num_time_masks: int = 2,
        time_mask_length: int = 10,
        num_freq_masks: int = 0,
        freq_mask_length: int = 10,
        # LoRA configuration (for Stage 2 fine-tuning)
        use_lora: bool = False,
        lora_rank: int = 8,  # SALMONN default
        lora_alpha: int = 32,  # SALMONN default (scaling factor 4.0)
        lora_dropout: float = 0.0,
        lora_target_modules: Optional[list] = None,  # Default: all linear layers
        freeze_projector: bool = False,  # True for Stage 2 (LoRA-only training)
        do_sample: bool = False,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        max_new_tokens: Optional[int] = None,
        min_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        length_penalty: Optional[float] = None,
        no_repeat_ngram_size: Optional[int] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ):
        """Initialize ASR model configuration.

        Args:
            audio_model_id: HuggingFace model ID for audio encoder (GLM-ASR/Whisper)
            text_model_id: HuggingFace model ID for text decoder (Qwen)
            attn_implementation: Attention implementation ("flash_attention_2", "sdpa", "eager")
            model_dtype: Model dtype ("bfloat16", "float16", "float32")
            projector_type: Projector architecture ("mlp", "mosa", "moe", "qformer")
            use_lora: Enable LoRA adapters for Stage 2 fine-tuning
            use_specaugment: Enable SpecAugment data augmentation
        """
        # Set default generation parameters (greedy decoding only)
        generation_defaults = {
            "num_beams": 1,
            "max_new_tokens": 128,
            "min_new_tokens": 0,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "no_repeat_ngram_size": 0,  # Prevent repeating 3-grams like "so so so"
            "use_cache": True,
        }

        # Apply defaults (config.json values take precedence)
        kwargs = {**generation_defaults, **kwargs}

        self.audio_model_id = audio_model_id
        self.text_model_id = text_model_id
        self.attn_implementation = attn_implementation
        self.model_dtype = model_dtype
        self.system_prompt = system_prompt
        self.encoder_dim = encoder_dim
        self.llm_dim = llm_dim
        # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
        self.encoder_conv_layers = encoder_conv_layers or [(1, 3, 1), (1, 3, 2)]
        self.audio_sample_rate = audio_sample_rate
        self.projector_init_std = projector_init_std
        self.projector_pool_stride = projector_pool_stride
        self.downsample_rate = downsample_rate
        self.projector_hidden_dim = projector_hidden_dim
        self.projector_type = projector_type
        self.projector_num_layers = projector_num_layers
        self.projector_dropout = projector_dropout
        # MoE-specific configuration
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.router_aux_loss_coef = router_aux_loss_coef
        # QFormer-specific configuration
        self.qformer_window_size = qformer_window_size
        self.qformer_hidden_size = qformer_hidden_size
        self.qformer_num_layers = qformer_num_layers
        self.qformer_num_heads = qformer_num_heads
        self.qformer_intermediate_size = qformer_intermediate_size
        self.label_smoothing = label_smoothing
        self.inference_warmup_tokens = inference_warmup_tokens
        # SpecAugment configuration
        self.use_specaugment = use_specaugment
        self.num_time_masks = num_time_masks
        self.time_mask_length = time_mask_length
        self.num_freq_masks = num_freq_masks
        self.freq_mask_length = freq_mask_length
        # LoRA configuration
        self.use_lora = use_lora
        self.lora_rank = lora_rank
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout
        self.lora_target_modules = lora_target_modules or [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
        self.freeze_projector = freeze_projector

        # Generation parameters (use explicit value if provided, else use default)
        self.num_beams = num_beams if num_beams is not None else generation_defaults["num_beams"]
        self.max_new_tokens = (
            max_new_tokens if max_new_tokens is not None else generation_defaults["max_new_tokens"]
        )
        self.min_new_tokens = (
            min_new_tokens if min_new_tokens is not None else generation_defaults["min_new_tokens"]
        )
        self.repetition_penalty = (
            repetition_penalty
            if repetition_penalty is not None
            else generation_defaults["repetition_penalty"]
        )
        self.length_penalty = (
            length_penalty if length_penalty is not None else generation_defaults["length_penalty"]
        )
        self.no_repeat_ngram_size = (
            no_repeat_ngram_size
            if no_repeat_ngram_size is not None
            else generation_defaults["no_repeat_ngram_size"]
        )
        self.use_cache = use_cache if use_cache is not None else generation_defaults["use_cache"]
        self.do_sample = do_sample
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k

        if "audio_config" not in kwargs:
            self.audio_config = transformers.AutoConfig.from_pretrained(audio_model_id)
            # Override dtype to match model_dtype
            self.audio_config.dtype = model_dtype
        else:
            self.audio_config = kwargs.pop("audio_config")

        if "text_config" not in kwargs:
            self.text_config = transformers.AutoConfig.from_pretrained(
                text_model_id, trust_remote_code=True
            )
            # Override dtype to match model_dtype
            self.text_config.dtype = model_dtype
        else:
            self.text_config = kwargs.pop("text_config")

        if isinstance(self.text_config, dict):
            # Reconstruct config from dict using the model_type stored in the dict
            model_type = self.text_config["model_type"]
            config_class = transformers.AutoConfig.for_model(model_type).__class__
            self.text_config = config_class(**self.text_config)

        if isinstance(self.audio_config, dict):
            model_type = self.audio_config.get("model_type")
            if model_type:
                config_class = transformers.AutoConfig.for_model(model_type).__class__
                self.audio_config = config_class(**self.audio_config)

        super().__init__(**kwargs)

        # Point encoder to audio_config so pipeline uses correct feature extractor
        # The pipeline looks for config.encoder._name_or_path for feature extractor
        self.encoder = self.audio_config

        self.auto_map = {
            "AutoConfig": "asr_config.ASRConfig",
            "AutoModel": "asr_modeling.ASRModel",
            "AutoModelForSpeechSeq2Seq": "asr_modeling.ASRModel",
            "AutoProcessor": "asr_processing.ASRProcessor",
        }
        self.custom_pipelines = {
            "automatic-speech-recognition": {
                "impl": "asr_pipeline.ASRPipeline",
                "pt": ["AutoModelForSpeechSeq2Seq"],
                "tf": [],
                "type": "audio",
            }
        }
        self.architectures = ["ASRModel"]
        self.pipeline_tag = "automatic-speech-recognition"


transformers.AutoConfig.register("asr_model", ASRConfig)
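A minimal sketch (not part of the commit) of building and saving this config for projector training. Note that constructing ASRConfig with defaults calls AutoConfig.from_pretrained for both backbone IDs, so it needs network access or cached configs; the output directory is a placeholder.

from asr_config import ASRConfig

config = ASRConfig(
    projector_type="qformer",    # one of "mlp", "mosa", "moe", "qformer"
    use_specaugment=True,
    attn_implementation="sdpa",  # fallback if flash-attention is not installed
)
config.save_pretrained("./asr_checkpoint")  # writes config.json including the auto_map entries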
asr_modeling.py
ADDED
@@ -0,0 +1,896 @@
import json
from pathlib import Path
from threading import Thread
from typing import Iterator, Optional, Union

import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    TextIteratorStreamer,
)
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

try:
    from .asr_config import ASRConfig
    from .projectors import PROJECTOR_CLASSES
except ImportError:
    from asr_config import ASRConfig  # type: ignore[no-redef]
    from projectors import PROJECTOR_CLASSES  # type: ignore[no-redef]


from torchaudio.transforms import SpecAugment


class ASRModel(PreTrainedModel, GenerationMixin):
    """Audio-to-text model combining an audio encoder, projector, and language model."""

    config_class = ASRConfig
    base_model_prefix = "model"
    main_input_name = "input_features"
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True
    _is_loading_from_pretrained: bool = False
    _pretrained_model_path: Optional[str] = None

    TRANSCRIBE_PROMPT = ""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs) -> "ASRModel":
        """Load model from pretrained, handling device placement correctly."""
        from safetensors.torch import load_file
        from transformers.utils.hub import cached_file

        config = kwargs.pop("config", None)
        if config is None:
            config = ASRConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        # Set flag to avoid device_map="auto" in sub-model loaders
        cls._is_loading_from_pretrained = True
        cls._pretrained_model_path = pretrained_model_name_or_path

        try:
            model = cls(config, **kwargs)

            # Load projector weights from safetensors
            subfolder = kwargs.get("subfolder")
            revision = kwargs.get("revision")
            cache_kwargs = {}
            if subfolder:
                cache_kwargs["subfolder"] = subfolder
            if revision:
                cache_kwargs["revision"] = revision

            model_file = cached_file(
                pretrained_model_name_or_path,
                "model.safetensors",
                _raise_exceptions_for_missing_entries=False,
                **cache_kwargs,
            )

            if model_file is not None:
                state_dict = load_file(model_file)
                model.load_state_dict(state_dict, strict=False)

            # Load LoRA adapters if use_lora is enabled
            if getattr(config, "use_lora", False):
                # Check for adapter_config.json (required by PEFT to load adapters)
                adapter_config_file = cached_file(
                    pretrained_model_name_or_path,
                    "adapter_config.json",
                    _raise_exceptions_for_missing_entries=False,
                    **cache_kwargs,
                )
                if adapter_config_file is not None:
                    # Load saved adapter weights using the original repo_id/path
                    # PEFT handles Hub downloads and caching internally
                    from peft import PeftModel

                    model.language_model = PeftModel.from_pretrained(
                        model.language_model,
                        pretrained_model_name_or_path,
                        is_trainable=True,
                        **cache_kwargs,
                    )
                else:
                    # No saved adapters - initialize fresh LLM LoRA for training
                    from peft import LoraConfig, get_peft_model

                    lora_config = LoraConfig(
                        r=config.lora_rank,
                        lora_alpha=config.lora_alpha,
                        target_modules=config.lora_target_modules,
                        lora_dropout=config.lora_dropout,
                        bias="none",
                        task_type="CAUSAL_LM",
                    )
                    model.language_model = get_peft_model(model.language_model, lora_config)

            return model
        finally:
            cls._is_loading_from_pretrained = False
            cls._pretrained_model_path = None

    def __init__(self, config: ASRConfig, **kwargs) -> None:
        super().__init__(config)

        self.system_prompt = config.system_prompt
        target_dtype = getattr(torch, config.model_dtype)

        # Audio encoder (frozen)
        self.audio_tower = self._load_audio_encoder(config, target_dtype)

        # Language model (frozen)
        self.language_model = self._load_language_model(config, target_dtype)

        # Initialize tokenizer and special tokens
        self._init_tokenizer(config)

        # Set up generation config with greedy decoding defaults
        self.generation_config = self.language_model.generation_config
        self.generation_config.max_new_tokens = config.max_new_tokens
        self.generation_config.min_new_tokens = config.min_new_tokens
        self.generation_config.num_beams = config.num_beams
        self.generation_config.do_sample = config.do_sample
        # Set sampling params from config (None means use model defaults)
        self.generation_config.temperature = config.temperature
        self.generation_config.top_p = config.top_p
        self.generation_config.top_k = config.top_k
        self.generation_config.use_cache = config.use_cache
        self.generation_config.length_penalty = config.length_penalty
        self.generation_config.repetition_penalty = config.repetition_penalty
        self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
        # Set EOS tokens, filtering out any that don't exist in the tokenizer
        eos_candidates = [
            self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
            self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
        ]
        self.generation_config.eos_token_id = [t for t in eos_candidates if t is not None]
        self.generation_config.pad_token_id = self.tokenizer.pad_token_id

        # Feature extractor for audio preprocessing
        self.feature_extractor = self._create_feature_extractor(config)

        # Audio projector (trainable unless freeze_projector is set)
        self.projector = self._create_projector(config, target_dtype)

        # Setup LoRA if enabled (Stage 2 fine-tuning)
        # Skip if loading from pretrained - from_pretrained will handle adapter loading
        if getattr(config, "use_lora", False) and not getattr(
            self.__class__, "_is_loading_from_pretrained", False
        ):
            self._setup_lora(config)

        # Freeze projector if specified (for Stage 2 LoRA-only training)
        if getattr(config, "freeze_projector", False):
            self.projector.requires_grad_(False)

        # SpecAugment for data augmentation during training
        if getattr(config, "use_specaugment", False):
            self.spec_augment = SpecAugment(
                n_time_masks=config.num_time_masks,
                time_mask_param=config.time_mask_length,
                n_freq_masks=config.num_freq_masks,
                freq_mask_param=config.freq_mask_length,
            )
        else:
            self.spec_augment = None

        # For model parallelism
        self._no_split_modules = getattr(self.language_model, "_no_split_modules", [])

    def _create_feature_extractor(self, config: ASRConfig):
        """Create the appropriate feature extractor for the audio encoder."""
        from transformers import AutoFeatureExtractor

        feature_extractor = AutoFeatureExtractor.from_pretrained(config.audio_model_id)
        # Disable padding by default - use actual audio length
        feature_extractor.padding = False
        return feature_extractor

    @classmethod
    def _load_audio_encoder(cls, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
        """Load and freeze the audio encoder."""
        encoder_kwargs = {
            "attn_implementation": config.attn_implementation,
            "low_cpu_mem_usage": True,
            "dtype": dtype,
        }

        if "whisper" in config.audio_model_id.lower():
            from transformers import WhisperModel

            full_model = WhisperModel.from_pretrained(config.audio_model_id, **encoder_kwargs)
            encoder = full_model.encoder
            del full_model
        elif "glm" in config.audio_model_id.lower():
            # GLM-ASR models use audio_tower as the encoder
            # Requires transformers >= 5.x or installed from source
            from transformers import AutoModelForSeq2SeqLM

            full_model = AutoModelForSeq2SeqLM.from_pretrained(
                config.audio_model_id, trust_remote_code=True, **encoder_kwargs
            )
            # GLM stores encoder at audio_tower (GlmAsrEncoder)
            encoder = full_model.audio_tower
            # Clear references to free VRAM from the LLM decoder
            full_model.language_model = None
            full_model.multi_modal_projector = None
            del full_model
        else:
            encoder = AutoModel.from_pretrained(config.audio_model_id, **encoder_kwargs)

        encoder.requires_grad_(False)
        encoder.eval()
        return encoder

    @classmethod
    def _load_language_model(cls, config: ASRConfig, dtype: torch.dtype) -> PreTrainedModel:
        """Load and freeze the language model."""
        decoder_kwargs = {
            "attn_implementation": config.attn_implementation,
            "trust_remote_code": True,
            "low_cpu_mem_usage": True,
            "dtype": dtype,
        }

        decoder = AutoModelForCausalLM.from_pretrained(config.text_model_id, **decoder_kwargs)
        decoder.config.use_cache = getattr(config, "use_cache", True)
        decoder.requires_grad_(False)
        decoder.eval()
        return decoder

    def _create_projector(self, config: ASRConfig, dtype: torch.dtype) -> nn.Module:
        """Create the trainable audio projector."""
        # Auto-detect dimensions if not specified
        if config.encoder_dim is None:
            enc_cfg = self.audio_tower.config
            config.encoder_dim = getattr(enc_cfg, "hidden_size", None) or getattr(
                enc_cfg, "d_model", None
            )
            if config.encoder_dim is None:
                raise ValueError("Could not auto-detect encoder_dim. Please specify in config.")

        if config.llm_dim is None:
            dec_cfg = self.language_model.config
            config.llm_dim = getattr(dec_cfg, "hidden_size", None) or getattr(
                dec_cfg, "d_model", None
            )
            if config.llm_dim is None:
                raise ValueError("Could not auto-detect llm_dim. Please specify in config.")

        # Select projector type based on config
        projector_type = getattr(config, "projector_type", "mlp")
        projector_class = PROJECTOR_CLASSES.get(projector_type)
        if projector_class is None:
            raise ValueError(
                f"Unknown projector_type: {projector_type}. "
                f"Valid options: {list(PROJECTOR_CLASSES.keys())}"
            )
        projector = projector_class(config)

        # Move projector to same device as language model (important when using quantization)
        device = next(self.language_model.parameters()).device
        return projector.to(device=device, dtype=dtype)

    def _setup_lora(self, config: ASRConfig):
        """Apply LoRA adapters to the language model for Stage 2 fine-tuning."""
        from peft import LoraConfig, get_peft_model

        lora_config = LoraConfig(
            r=config.lora_rank,
            lora_alpha=config.lora_alpha,
            target_modules=config.lora_target_modules,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self.language_model = get_peft_model(self.language_model, lora_config)

    def _init_tokenizer(self, config: ASRConfig):
        """Initialize tokenizer with audio token."""
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_id, trust_remote_code=True)

        # Set pad token
        if (
            self.tokenizer.pad_token is None
            or self.tokenizer.pad_token_id == self.tokenizer.eos_token_id
        ) and "<|finetune_right_pad_id|>" in self.tokenizer.get_vocab():
            self.tokenizer.pad_token = "<|finetune_right_pad_id|>"

        # Add audio token
        existing_special = getattr(self.tokenizer, "additional_special_tokens", None) or []
        if "<audio>" not in existing_special:
            self.tokenizer.add_special_tokens(
                {"additional_special_tokens": existing_special + ["<audio>"]}
            )
            self.language_model.resize_token_embeddings(len(self.tokenizer), mean_resizing=False)

        self.audio_token_id = self.tokenizer.convert_tokens_to_ids("<audio>")
        self.tokenizer.padding_side = "right"

        # Sync token IDs to configs
        for cfg in [self.config.text_config, self.language_model.config, self.generation_config]:
            if cfg is not None:
                cfg.pad_token_id = self.tokenizer.pad_token_id
                cfg.eos_token_id = self.tokenizer.eos_token_id
                cfg.bos_token_id = self.tokenizer.bos_token_id

    def _init_weights(self, _module):
        """Weight initialization (projector weights are initialized in MoEAudioProjector)."""
        pass

    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func=None):
        """Enable/disable gradient checkpointing for the language model."""
        # The LLM still stores activations during forward for backprop to projector
        # Gradient checkpointing trades compute for memory by recomputing activations
        if hasattr(self.language_model, "_set_gradient_checkpointing"):
            self.language_model._set_gradient_checkpointing(enable, gradient_checkpointing_func)
        elif hasattr(self.language_model, "gradient_checkpointing_enable") and enable:
            self.language_model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )
        elif hasattr(self.language_model, "gradient_checkpointing_disable") and not enable:
            self.language_model.gradient_checkpointing_disable()

    def get_input_embeddings(self) -> nn.Module:
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, value: nn.Module) -> None:
        self.language_model.set_output_embeddings(value)

    def get_processor(self):
        """Get the processor for this model."""
        try:
            from .asr_processing import ASRProcessor
        except ImportError:
            from asr_processing import ASRProcessor  # type: ignore[no-redef]

        return ASRProcessor(
            feature_extractor=self.feature_extractor,
            tokenizer=self.tokenizer,
            projector=self.projector,
            encoder_conv_layers=self.config.encoder_conv_layers,
        )

    def state_dict(self, *args, **kwargs) -> dict[str, torch.Tensor]:
        """Only save trainable projector weights."""
        return {f"projector.{k}": v for k, v in self.projector.state_dict().items()}

    def _compute_encoder_output_lengths(
        self,
        audio_attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute per-sample encoder output lengths using conv layer formulas.

        Args:
            audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)

        Returns:
            Tensor of encoder output lengths per sample (batch,)
        """
        # Get mel frame lengths from attention mask
        lengths = audio_attention_mask.sum(dim=-1)

        # Apply conv layer formulas: output = (input + 2*pad - (kernel-1) - 1) // stride + 1
        for padding, kernel_size, stride in self.config.encoder_conv_layers:
            lengths = (lengths + 2 * padding - (kernel_size - 1) - 1) // stride + 1

        return lengths

    def _encode_audio(
        self,
        audio_features: torch.Tensor,
        audio_attention_mask: torch.Tensor,
        expected_token_counts: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode audio and project to LLM embedding space.

        Args:
            audio_features: Mel spectrogram features (batch, n_mels, mel_len)
            audio_attention_mask: Mask indicating real vs padded mel frames (batch, mel_len)
            expected_token_counts: Expected number of audio tokens per sample from input_ids.
                If provided, output will match these counts exactly (padding/truncating as needed).

        Returns:
            Flattened audio embeddings of shape (total_audio_tokens, hidden_dim).
        """
        with torch.no_grad():
            encoder_out = self.audio_tower(input_features=audio_features)
            hidden_states = encoder_out.last_hidden_state

        # Project to LLM space
        audio_embeds = self.projector(hidden_states)

        # Use expected token counts if provided (from input_ids), otherwise compute from audio
        if expected_token_counts is not None:
            token_counts = expected_token_counts
        else:
            # Compute per-sample encoder output lengths using conv formulas
            encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
            token_counts = torch.tensor(
                [
                    self.projector.get_output_length(int(length.item()))
                    for length in encoder_lengths
                ],
                device=audio_embeds.device,
            )

        # Extract embeddings matching expected token counts per sample
        batch_size = audio_embeds.shape[0]
        hidden_dim = audio_embeds.shape[2]

        result_embeds = []
        for i in range(batch_size):
            count = int(token_counts[i].item())
            sample_embeds = audio_embeds[i, :count, :]  # Take first 'count' embeddings
            # Pad with zeros if we don't have enough embeddings
            if sample_embeds.shape[0] < count:
                padding = torch.zeros(
                    count - sample_embeds.shape[0],
                    hidden_dim,
                    device=audio_embeds.device,
                    dtype=audio_embeds.dtype,
                )
                sample_embeds = torch.cat([sample_embeds, padding], dim=0)
            result_embeds.append(sample_embeds)
return torch.cat(result_embeds, dim=0)
|
| 450 |
+
|
| 451 |
+
def forward(
|
| 452 |
+
self,
|
| 453 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 454 |
+
input_features: Optional[torch.Tensor] = None,
|
| 455 |
+
audio_attention_mask: Optional[torch.Tensor] = None,
|
| 456 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 457 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 458 |
+
past_key_values: Optional[torch.Tensor] = None,
|
| 459 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 460 |
+
labels: Optional[torch.Tensor] = None,
|
| 461 |
+
use_cache: Optional[bool] = None,
|
| 462 |
+
cache_position: Optional[torch.Tensor] = None,
|
| 463 |
+
**kwargs,
|
| 464 |
+
) -> CausalLMOutputWithPast:
|
| 465 |
+
"""Forward pass for training and inference."""
|
| 466 |
+
# Get text embeddings if not provided
|
| 467 |
+
if inputs_embeds is None:
|
| 468 |
+
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
| 469 |
+
|
| 470 |
+
if input_features is not None and input_ids is not None:
|
| 471 |
+
# Apply SpecAugment during training if enabled
|
| 472 |
+
if self.training and self.spec_augment is not None:
|
| 473 |
+
input_features = self.spec_augment(input_features)
|
| 474 |
+
|
| 475 |
+
# Count expected audio tokens from input_ids (ground truth from collator)
|
| 476 |
+
audio_token_counts = (input_ids == self.audio_token_id).sum(dim=-1)
|
| 477 |
+
|
| 478 |
+
# Encode audio -> flattened (total_audio_tokens, hidden_dim)
|
| 479 |
+
audio_embeds = self._encode_audio(
|
| 480 |
+
input_features, audio_attention_mask, audio_token_counts
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
# Replace <audio> token placeholders with audio embeddings using masked_scatter
|
| 484 |
+
audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
|
| 485 |
+
inputs_embeds = inputs_embeds.masked_scatter(
|
| 486 |
+
audio_token_mask.to(inputs_embeds.device),
|
| 487 |
+
audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
# Run through language model (let it compute loss if labels provided)
|
| 491 |
+
outputs = self.language_model(
|
| 492 |
+
attention_mask=attention_mask,
|
| 493 |
+
position_ids=position_ids,
|
| 494 |
+
past_key_values=past_key_values,
|
| 495 |
+
inputs_embeds=inputs_embeds,
|
| 496 |
+
labels=labels,
|
| 497 |
+
use_cache=use_cache,
|
| 498 |
+
cache_position=cache_position,
|
| 499 |
+
**kwargs,
|
| 500 |
+
)
|
| 501 |
+
|
| 502 |
+
# Add auxiliary loss from MoE projectors if available
|
| 503 |
+
if outputs.loss is not None and hasattr(self.projector, "get_aux_loss"):
|
| 504 |
+
aux_loss = self.projector.get_aux_loss()
|
| 505 |
+
if aux_loss is not None and aux_loss.numel() > 0:
|
| 506 |
+
outputs.loss = outputs.loss + aux_loss.to(outputs.loss.device)
|
| 507 |
+
|
| 508 |
+
return outputs
|
| 509 |
+
|
| 510 |
+
def prepare_inputs_for_generation(self, *args, **kwargs):
|
| 511 |
+
"""Prepare inputs for generation, handling audio features for cached decoding."""
|
| 512 |
+
input_features = kwargs.pop("input_features", None)
|
| 513 |
+
cache_position = kwargs.get("cache_position")
|
| 514 |
+
|
| 515 |
+
model_inputs = self.language_model.prepare_inputs_for_generation(*args, **kwargs)
|
| 516 |
+
|
| 517 |
+
# Only pass audio features on the first generation step (cache_position[0] == 0)
|
| 518 |
+
if cache_position is not None and cache_position[0] == 0 and input_features is not None:
|
| 519 |
+
model_inputs["input_features"] = input_features
|
| 520 |
+
|
| 521 |
+
return model_inputs
|
| 522 |
+
|
| 523 |
+
def _get_num_audio_tokens(
|
| 524 |
+
self,
|
| 525 |
+
audio_attention_mask: torch.Tensor,
|
| 526 |
+
) -> int:
|
| 527 |
+
"""Calculate number of audio tokens based on actual audio length.
|
| 528 |
+
|
| 529 |
+
Uses attention mask to get real audio length, then computes:
|
| 530 |
+
mel_frames -> encoder_frames (via conv formulas) -> projector output tokens
|
| 531 |
+
"""
|
| 532 |
+
encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
|
| 533 |
+
# Use max length for batch (all samples should have same token count for generation)
|
| 534 |
+
encoder_output_len = int(encoder_lengths.max().item())
|
| 535 |
+
return int(self.projector.get_output_length(encoder_output_len))
|
| 536 |
+
|
| 537 |
+
@torch.no_grad()
|
| 538 |
+
def generate(
|
| 539 |
+
self,
|
| 540 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 541 |
+
input_features: Optional[torch.Tensor] = None,
|
| 542 |
+
audio_attention_mask: Optional[torch.Tensor] = None,
|
| 543 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 544 |
+
system_prompt: Optional[str] = None,
|
| 545 |
+
**generate_kwargs,
|
| 546 |
+
) -> torch.Tensor:
|
| 547 |
+
"""Generate transcription from audio input.
|
| 548 |
+
|
| 549 |
+
Can be called in two ways:
|
| 550 |
+
1. With input_ids containing <audio> tokens (from processor)
|
| 551 |
+
2. With just audio, and we build the prompt internally
|
| 552 |
+
"""
|
| 553 |
+
if input_features is None:
|
| 554 |
+
raise ValueError("input_features required for generation")
|
| 555 |
+
if audio_attention_mask is None:
|
| 556 |
+
raise ValueError("audio_attention_mask required for generation")
|
| 557 |
+
|
| 558 |
+
device = input_features.device
|
| 559 |
+
batch_size = input_features.shape[0]
|
| 560 |
+
|
| 561 |
+
# Encode audio -> flattened embeddings
|
| 562 |
+
audio_embeds = self._encode_audio(input_features, audio_attention_mask)
|
| 563 |
+
|
| 564 |
+
# If input_ids not provided, build prompt with correct number of audio tokens
|
| 565 |
+
if input_ids is None:
|
| 566 |
+
num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
|
| 567 |
+
audio_placeholder = "<audio>" * num_audio_tokens
|
| 568 |
+
|
| 569 |
+
system_prompt = system_prompt or self.system_prompt
|
| 570 |
+
|
| 571 |
+
messages: list[dict[str, str]] = []
|
| 572 |
+
if system_prompt:
|
| 573 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 574 |
+
# Audio tokens only (instruction-free)
|
| 575 |
+
user_content = audio_placeholder
|
| 576 |
+
if self.TRANSCRIBE_PROMPT:
|
| 577 |
+
user_content += " " + self.TRANSCRIBE_PROMPT
|
| 578 |
+
messages.append({"role": "user", "content": user_content})
|
| 579 |
+
|
| 580 |
+
chat_result = self.tokenizer.apply_chat_template(
|
| 581 |
+
messages,
|
| 582 |
+
tokenize=True,
|
| 583 |
+
add_generation_prompt=True,
|
| 584 |
+
return_tensors="pt",
|
| 585 |
+
enable_thinking=False, # Disable Qwen3 thinking mode for ASR
|
| 586 |
+
)
|
| 587 |
+
input_ids = chat_result.input_ids.to(device)
|
| 588 |
+
|
| 589 |
+
if input_ids.dim() == 1:
|
| 590 |
+
input_ids = input_ids.unsqueeze(0)
|
| 591 |
+
if input_ids.shape[0] == 1 and batch_size > 1:
|
| 592 |
+
input_ids = input_ids.expand(batch_size, -1)
|
| 593 |
+
|
| 594 |
+
attention_mask = torch.ones_like(input_ids)
|
| 595 |
+
|
| 596 |
+
# Get text embeddings and replace audio tokens with audio embeddings
|
| 597 |
+
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
| 598 |
+
audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
|
| 599 |
+
inputs_embeds = inputs_embeds.masked_scatter(
|
| 600 |
+
audio_token_mask.to(inputs_embeds.device),
|
| 601 |
+
audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
# Generate using language model
|
| 605 |
+
# Pass both input_ids and inputs_embeds so repetition_penalty works correctly
|
| 606 |
+
# (it needs input_ids to track which tokens have been used)
|
| 607 |
+
output = self.language_model.generate(
|
| 608 |
+
input_ids=input_ids,
|
| 609 |
+
inputs_embeds=inputs_embeds,
|
| 610 |
+
attention_mask=attention_mask,
|
| 611 |
+
generation_config=self.generation_config,
|
| 612 |
+
**generate_kwargs,
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
# When using inputs_embeds with input_ids, generate returns full sequence
|
| 616 |
+
# Strip the input tokens to return only generated tokens
|
| 617 |
+
sequences = output if isinstance(output, torch.Tensor) else output.sequences
|
| 618 |
+
input_len = input_ids.shape[1]
|
| 619 |
+
return sequences[:, input_len:]
|
| 620 |
+
|
| 621 |
+
def generate_streaming(
|
| 622 |
+
self,
|
| 623 |
+
input_features: torch.Tensor,
|
| 624 |
+
audio_attention_mask: torch.Tensor,
|
| 625 |
+
system_prompt: Optional[str] = None,
|
| 626 |
+
**generate_kwargs,
|
| 627 |
+
) -> Iterator[str]:
|
| 628 |
+
"""Generate transcription with streaming token output.
|
| 629 |
+
|
| 630 |
+
Yields partial transcript strings as tokens are generated.
|
| 631 |
+
Reduces time-to-first-word by streaming tokens as they're decoded.
|
| 632 |
+
|
| 633 |
+
Args:
|
| 634 |
+
input_features: Mel spectrogram features (batch, n_mels, mel_len)
|
| 635 |
+
audio_attention_mask: Mask for real vs padded mel frames (batch, mel_len)
|
| 636 |
+
system_prompt: Optional system prompt override
|
| 637 |
+
**generate_kwargs: Additional generation arguments
|
| 638 |
+
|
| 639 |
+
Yields:
|
| 640 |
+
Partial transcript text as each token is generated
|
| 641 |
+
"""
|
| 642 |
+
device = input_features.device
|
| 643 |
+
batch_size = input_features.shape[0]
|
| 644 |
+
|
| 645 |
+
# Encode audio -> flattened embeddings
|
| 646 |
+
audio_embeds = self._encode_audio(input_features, audio_attention_mask)
|
| 647 |
+
|
| 648 |
+
# Build prompt with correct number of audio tokens
|
| 649 |
+
num_audio_tokens = self._get_num_audio_tokens(audio_attention_mask)
|
| 650 |
+
audio_placeholder = "<audio>" * num_audio_tokens
|
| 651 |
+
|
| 652 |
+
system_prompt = system_prompt or self.system_prompt
|
| 653 |
+
|
| 654 |
+
messages: list[dict[str, str]] = []
|
| 655 |
+
if system_prompt:
|
| 656 |
+
messages.append({"role": "system", "content": system_prompt})
|
| 657 |
+
# Audio tokens only (instruction-free)
|
| 658 |
+
user_content = audio_placeholder
|
| 659 |
+
if self.TRANSCRIBE_PROMPT:
|
| 660 |
+
user_content += " " + self.TRANSCRIBE_PROMPT
|
| 661 |
+
messages.append({"role": "user", "content": user_content})
|
| 662 |
+
|
| 663 |
+
chat_result = self.tokenizer.apply_chat_template(
|
| 664 |
+
messages,
|
| 665 |
+
tokenize=True,
|
| 666 |
+
add_generation_prompt=True,
|
| 667 |
+
return_tensors="pt",
|
| 668 |
+
enable_thinking=False, # Disable Qwen3 thinking mode for ASR
|
| 669 |
+
)
|
| 670 |
+
input_ids = chat_result.input_ids.to(device)
|
| 671 |
+
|
| 672 |
+
if input_ids.dim() == 1:
|
| 673 |
+
input_ids = input_ids.unsqueeze(0)
|
| 674 |
+
if input_ids.shape[0] == 1 and batch_size > 1:
|
| 675 |
+
input_ids = input_ids.expand(batch_size, -1)
|
| 676 |
+
|
| 677 |
+
attention_mask = torch.ones_like(input_ids)
|
| 678 |
+
|
| 679 |
+
# Get text embeddings and replace audio tokens with audio embeddings
|
| 680 |
+
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
| 681 |
+
audio_token_mask = (input_ids == self.audio_token_id).unsqueeze(-1)
|
| 682 |
+
inputs_embeds = inputs_embeds.masked_scatter(
|
| 683 |
+
audio_token_mask.to(inputs_embeds.device),
|
| 684 |
+
audio_embeds.to(inputs_embeds.device, dtype=inputs_embeds.dtype),
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
# Setup streamer for token-by-token output
|
| 688 |
+
streamer = TextIteratorStreamer(
|
| 689 |
+
self.tokenizer,
|
| 690 |
+
skip_prompt=True,
|
| 691 |
+
skip_special_tokens=True,
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
# Prepare generation kwargs
|
| 695 |
+
gen_kwargs = {
|
| 696 |
+
"inputs_embeds": inputs_embeds,
|
| 697 |
+
"attention_mask": attention_mask,
|
| 698 |
+
"generation_config": self.generation_config,
|
| 699 |
+
"streamer": streamer,
|
| 700 |
+
**generate_kwargs,
|
| 701 |
+
}
|
| 702 |
+
|
| 703 |
+
# Run generation in background thread
|
| 704 |
+
thread = Thread(target=self.language_model.generate, kwargs=gen_kwargs)
|
| 705 |
+
thread.start()
|
| 706 |
+
|
| 707 |
+
# Yield tokens as they're generated, filtering out <think>...</think> blocks
|
| 708 |
+
# Start assuming no think block - only filter when we see <think>
|
| 709 |
+
in_think_block = False
|
| 710 |
+
buffer = ""
|
| 711 |
+
|
| 712 |
+
for text in streamer:
|
| 713 |
+
buffer += text
|
| 714 |
+
|
| 715 |
+
# Check for think block start (in case model outputs think blocks)
|
| 716 |
+
while "<think>" in buffer:
|
| 717 |
+
in_think_block = True
|
| 718 |
+
# Yield any text before <think>
|
| 719 |
+
before_think = buffer.split("<think>")[0]
|
| 720 |
+
if before_think:
|
| 721 |
+
yield before_think
|
| 722 |
+
buffer = buffer.split("<think>", 1)[-1]
|
| 723 |
+
|
| 724 |
+
# Check for think block end
|
| 725 |
+
while in_think_block and "</think>" in buffer:
|
| 726 |
+
in_think_block = False
|
| 727 |
+
buffer = buffer.split("</think>", 1)[-1]
|
| 728 |
+
|
| 729 |
+
# Yield text if not in think block
|
| 730 |
+
if not in_think_block and buffer:
|
| 731 |
+
yield buffer
|
| 732 |
+
buffer = ""
|
| 733 |
+
|
| 734 |
+
# Yield any remaining buffer
|
| 735 |
+
if buffer and not in_think_block:
|
| 736 |
+
yield buffer
|
| 737 |
+
|
| 738 |
+
thread.join()
|
| 739 |
+
|
| 740 |
+
@torch.no_grad()
|
| 741 |
+
def generate_text_only(
|
| 742 |
+
self,
|
| 743 |
+
messages: list[dict[str, str]],
|
| 744 |
+
max_new_tokens: int = 256,
|
| 745 |
+
**generate_kwargs,
|
| 746 |
+
) -> str:
|
| 747 |
+
"""Generate text using only the LLM (no audio encoding).
|
| 748 |
+
|
| 749 |
+
Used for SIFT-style response generation from metadata prompts.
|
| 750 |
+
|
| 751 |
+
Args:
|
| 752 |
+
messages: List of chat messages [{"role": "user", "content": "..."}]
|
| 753 |
+
max_new_tokens: Maximum tokens to generate
|
| 754 |
+
**generate_kwargs: Additional generation arguments
|
| 755 |
+
|
| 756 |
+
Returns:
|
| 757 |
+
Generated text response
|
| 758 |
+
"""
|
| 759 |
+
device = next(self.language_model.parameters()).device
|
| 760 |
+
|
| 761 |
+
# Apply chat template
|
| 762 |
+
input_ids = self.tokenizer.apply_chat_template(
|
| 763 |
+
messages,
|
| 764 |
+
tokenize=True,
|
| 765 |
+
add_generation_prompt=True,
|
| 766 |
+
return_tensors="pt",
|
| 767 |
+
enable_thinking=False, # Disable Qwen3 thinking mode for ASR
|
| 768 |
+
).to(device)
|
| 769 |
+
|
| 770 |
+
if input_ids.dim() == 1:
|
| 771 |
+
input_ids = input_ids.unsqueeze(0)
|
| 772 |
+
|
| 773 |
+
attention_mask = torch.ones_like(input_ids)
|
| 774 |
+
|
| 775 |
+
# Generate using language model directly
|
| 776 |
+
output = self.language_model.generate(
|
| 777 |
+
input_ids=input_ids,
|
| 778 |
+
attention_mask=attention_mask,
|
| 779 |
+
max_new_tokens=max_new_tokens,
|
| 780 |
+
do_sample=False,
|
| 781 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
| 782 |
+
eos_token_id=self.tokenizer.eos_token_id,
|
| 783 |
+
**generate_kwargs,
|
| 784 |
+
)
|
| 785 |
+
|
| 786 |
+
# Decode only the new tokens
|
| 787 |
+
new_tokens = output[0, input_ids.shape[1] :]
|
| 788 |
+
response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
|
| 789 |
+
return response.strip()
|
| 790 |
+
|
| 791 |
+
def save_pretrained(self, save_directory: Union[str, Path], **kwargs) -> None:
|
| 792 |
+
"""Save model, tokenizer, and processor."""
|
| 793 |
+
import shutil
|
| 794 |
+
from pathlib import Path as PathlibPath
|
| 795 |
+
|
| 796 |
+
save_dir = PathlibPath(save_directory)
|
| 797 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
| 798 |
+
|
| 799 |
+
# Update config with actual vocab size
|
| 800 |
+
self.config.vocab_size = self.language_model.config.vocab_size
|
| 801 |
+
self.config.text_config.vocab_size = self.language_model.config.vocab_size
|
| 802 |
+
|
| 803 |
+
if hasattr(self.audio_tower.config, "num_mel_bins"):
|
| 804 |
+
self.config.audio_config.num_mel_bins = self.audio_tower.config.num_mel_bins
|
| 805 |
+
|
| 806 |
+
# Save model (temporarily remove non-serializable attributes)
|
| 807 |
+
tokenizer = self.tokenizer
|
| 808 |
+
del self.tokenizer
|
| 809 |
+
|
| 810 |
+
try:
|
| 811 |
+
super().save_pretrained(save_dir, **kwargs)
|
| 812 |
+
finally:
|
| 813 |
+
self.tokenizer = tokenizer
|
| 814 |
+
|
| 815 |
+
# Save tokenizer and feature extractor
|
| 816 |
+
self.tokenizer.save_pretrained(save_dir)
|
| 817 |
+
self.feature_extractor.save_pretrained(save_dir)
|
| 818 |
+
|
| 819 |
+
# Save LoRA adapters if present (creates adapter_model.safetensors and adapter_config.json)
|
| 820 |
+
# Don't save embedding layers - the <audio> token embedding is never used
|
| 821 |
+
# (it's replaced with projected audio embeddings before the LLM sees it)
|
| 822 |
+
if hasattr(self.language_model, "peft_config"):
|
| 823 |
+
self.language_model.save_pretrained(save_dir, save_embedding_layers=False)
|
| 824 |
+
|
| 825 |
+
# Clear base_model_name_or_path in adapter_config.json to prevent HF pipeline
|
| 826 |
+
# from redirecting to the base LLM repo (like Qwen) which breaks feature
|
| 827 |
+
# extractor loading for multimodal models. If a repo_id is provided, use that
|
| 828 |
+
# so the model can be loaded directly from the Hub.
|
| 829 |
+
adapter_config_path = save_dir / "adapter_config.json"
|
| 830 |
+
if adapter_config_path.exists():
|
| 831 |
+
with adapter_config_path.open() as f:
|
| 832 |
+
adapter_config = json.load(f)
|
| 833 |
+
|
| 834 |
+
# Use repo_id if available, otherwise clear to prevent redirect.
|
| 835 |
+
# Use empty string instead of None to avoid str(None) -> "None" bug
|
| 836 |
+
# in some transformers/PEFT versions.
|
| 837 |
+
repo_id = (
|
| 838 |
+
kwargs.get("repo_id")
|
| 839 |
+
or kwargs.get("push_to_hub_model_id")
|
| 840 |
+
or getattr(self.config, "pretrained_model_path", None)
|
| 841 |
+
or "" # Use empty string instead of None
|
| 842 |
+
)
|
| 843 |
+
adapter_config["base_model_name_or_path"] = repo_id
|
| 844 |
+
|
| 845 |
+
with adapter_config_path.open("w") as f:
|
| 846 |
+
json.dump(adapter_config, f, indent=2)
|
| 847 |
+
|
| 848 |
+
# Add processor auto_map to preprocessor_config.json
|
| 849 |
+
config_path = save_dir / "preprocessor_config.json"
|
| 850 |
+
if config_path.exists():
|
| 851 |
+
with config_path.open() as f:
|
| 852 |
+
processor_config = json.load(f)
|
| 853 |
+
else:
|
| 854 |
+
processor_config = {}
|
| 855 |
+
|
| 856 |
+
processor_config.update(
|
| 857 |
+
{
|
| 858 |
+
"processor_class": "ASRProcessor",
|
| 859 |
+
"auto_map": {"AutoProcessor": "asr_processing.ASRProcessor"},
|
| 860 |
+
}
|
| 861 |
+
)
|
| 862 |
+
|
| 863 |
+
with config_path.open("w") as f:
|
| 864 |
+
json.dump(processor_config, f, indent=2)
|
| 865 |
+
|
| 866 |
+
# Copy source files for auto-loading
|
| 867 |
+
src_dir = PathlibPath(__file__).parent
|
| 868 |
+
for asr_file in src_dir.glob("asr_*.py"):
|
| 869 |
+
shutil.copy(asr_file, save_dir / asr_file.name)
|
| 870 |
+
# Copy projectors module
|
| 871 |
+
shutil.copy(src_dir / "projectors.py", save_dir / "projectors.py")
|
| 872 |
+
# Copy alignment module
|
| 873 |
+
shutil.copy(src_dir / "alignment.py", save_dir / "alignment.py")
|
| 874 |
+
# Copy diarization module
|
| 875 |
+
shutil.copy(src_dir / "diarization.py", save_dir / "diarization.py")
|
| 876 |
+
|
| 877 |
+
def push_to_hub(self, repo_id: str, **kwargs) -> str:
|
| 878 |
+
"""Push model to HuggingFace Hub, ensuring adapter_config points to repo.
|
| 879 |
+
|
| 880 |
+
IMPORTANT: Sets base_model_name_or_path in adapter_config.json to repo_id
|
| 881 |
+
so that transformers pipeline() can load the model correctly. Without this,
|
| 882 |
+
the pipeline tries to load from "None" which fails.
|
| 883 |
+
"""
|
| 884 |
+
# Store repo_id in config so save_pretrained can access it
|
| 885 |
+
self.config.pretrained_model_path = repo_id
|
| 886 |
+
# Call parent's push_to_hub
|
| 887 |
+
return super().push_to_hub(repo_id, **kwargs)
|
| 888 |
+
|
| 889 |
+
def create_or_update_model_card(self, output_dir: Union[str, Path]) -> None:
|
| 890 |
+
"""No-op for model card creation - we use MODEL_CARD.md in repo instead."""
|
| 891 |
+
pass
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
# Register with transformers Auto classes
|
| 895 |
+
AutoConfig.register("asr_model", ASRConfig)
|
| 896 |
+
AutoModel.register(ASRConfig, ASRModel)
|
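The audio-token budget above is pure integer arithmetic: mel frames pass through the encoder's conv stack with the standard Conv1d length formula, and the projector's `get_output_length` then maps encoder frames to LLM tokens. A minimal standalone sketch of the conv part, using the default `[(1, 3, 1), (1, 3, 2)]` layer specs from `asr_processing.py` (the projector's additional downsampling is model-specific and not shown here):

```python
def encoder_output_length(mel_len: int, conv_layers=((1, 3, 1), (1, 3, 2))) -> int:
    """Apply output = (input + 2*pad - (kernel - 1) - 1) // stride + 1 per conv layer."""
    length = mel_len
    for padding, kernel_size, stride in conv_layers:
        length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
    return length


# 30 s of audio -> 3000 mel frames -> 1500 encoder frames (the second conv has stride 2)
assert encoder_output_length(3000) == 1500
```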
asr_pipeline.py
ADDED
@@ -0,0 +1,322 @@
"""ASR pipeline for audio-to-text transcription with optional timestamps and diarization."""

import re
from pathlib import Path
from typing import Any

import numpy as np
import torch
import transformers

try:
    from .alignment import ForcedAligner
    from .asr_modeling import ASRModel
    from .diarization import SpeakerDiarizer
except ImportError:
    from alignment import ForcedAligner  # type: ignore[no-redef]
    from asr_modeling import ASRModel  # type: ignore[no-redef]
    from diarization import SpeakerDiarizer  # type: ignore[no-redef]

# Re-export for backwards compatibility
__all__ = ["ForcedAligner", "SpeakerDiarizer", "ASRPipeline"]


class ASRPipeline(transformers.AutomaticSpeechRecognitionPipeline):
    """ASR Pipeline for audio-to-text transcription."""

    model: ASRModel

    def __init__(self, model: ASRModel, **kwargs):
        """Initialize ASR pipeline.

        Args:
            model: ASRModel instance for transcription
            **kwargs: Additional arguments (feature_extractor, tokenizer, device)
        """
        feature_extractor = kwargs.pop("feature_extractor", None)
        tokenizer = kwargs.pop("tokenizer", model.tokenizer)

        if feature_extractor is None:
            feature_extractor = model.get_processor().feature_extractor

        super().__init__(
            model=model, feature_extractor=feature_extractor, tokenizer=tokenizer, **kwargs
        )
        self._current_audio = None

    def _sanitize_parameters(self, **kwargs):
        """Intercept our custom parameters before parent class validates them."""
        # Remove our custom parameters so parent doesn't see them
        kwargs.pop("return_timestamps", None)
        kwargs.pop("return_speakers", None)
        kwargs.pop("num_speakers", None)
        kwargs.pop("min_speakers", None)
        kwargs.pop("max_speakers", None)
        kwargs.pop("hf_token", None)
        kwargs.pop("user_prompt", None)
        kwargs.pop("diarization_backend", None)

        return super()._sanitize_parameters(**kwargs)

    def __call__(
        self,
        inputs,
        **kwargs,
    ):
        """Transcribe audio with optional word-level timestamps and speaker diarization.

        Args:
            inputs: Audio input (file path, dict with array/sampling_rate, etc.)
            return_timestamps: If True, return word-level timestamps using forced alignment
            return_speakers: If True, return speaker labels for each word
            user_prompt: Custom transcription prompt (default: "Transcribe: ")
            num_speakers: Exact number of speakers (if known, for diarization)
            min_speakers: Minimum number of speakers (for diarization)
            max_speakers: Maximum number of speakers (for diarization)
            **kwargs: Additional arguments passed to the pipeline

        Returns:
            Dict with 'text' key, 'words' key if return_timestamps=True,
            and speaker labels on words if return_speakers=True
        """
        # Extract our params before super().__call__ (which will also call _sanitize_parameters)
        return_timestamps = kwargs.pop("return_timestamps", False)
        return_speakers = kwargs.pop("return_speakers", False)
        user_prompt = kwargs.pop("user_prompt", None)
        diarization_params = {
            "num_speakers": kwargs.pop("num_speakers", None),
            "min_speakers": kwargs.pop("min_speakers", None),
            "max_speakers": kwargs.pop("max_speakers", None),
        }

        if return_speakers:
            return_timestamps = True

        # Set custom user prompt if provided
        original_prompt = None
        if user_prompt:
            original_prompt = self.model.TRANSCRIBE_PROMPT
            self.model.TRANSCRIBE_PROMPT = user_prompt

        # Store audio for timestamp alignment and diarization
        if return_timestamps or return_speakers:
            self._current_audio = self._extract_audio(inputs)

        # Run standard transcription
        result = super().__call__(inputs, **kwargs)

        # Add timestamps if requested
        if return_timestamps and self._current_audio is not None:
            text = result.get("text", "")
            if text:
                try:
                    words = ForcedAligner.align(
                        self._current_audio["array"],
                        text,
                        sample_rate=self._current_audio.get("sampling_rate", 16000),
                    )
                    result["words"] = words
                except Exception as e:
                    result["words"] = []
                    result["timestamp_error"] = str(e)
            else:
                result["words"] = []

        # Add speaker diarization if requested
        if return_speakers and self._current_audio is not None:
            try:
                # Run diarization
                speaker_segments = SpeakerDiarizer.diarize(
                    self._current_audio["array"],
                    sample_rate=self._current_audio.get("sampling_rate", 16000),
                    **{k: v for k, v in diarization_params.items() if v is not None},
                )
                result["speaker_segments"] = speaker_segments

                # Assign speakers to words
                if result.get("words"):
                    result["words"] = SpeakerDiarizer.assign_speakers_to_words(
                        result["words"],
                        speaker_segments,
                    )
            except Exception as e:
                result["speaker_segments"] = []
                result["diarization_error"] = str(e)

        # Clean up
        self._current_audio = None
        if original_prompt is not None:
            self.model.TRANSCRIBE_PROMPT = original_prompt

        return result

    def _extract_audio(self, inputs) -> dict | None:
        """Extract audio array from various input formats using HF utilities."""
        from transformers.pipelines.audio_utils import ffmpeg_read

        if isinstance(inputs, dict):
            if "array" in inputs:
                return {
                    "array": inputs["array"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
            if "raw" in inputs:
                return {
                    "array": inputs["raw"],
                    "sampling_rate": inputs.get("sampling_rate", 16000),
                }
        elif isinstance(inputs, str):
            # File path - load audio using ffmpeg (same as HF pipeline)
            with Path(inputs).open("rb") as f:
                audio = ffmpeg_read(f.read(), sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, bytes):
            audio = ffmpeg_read(inputs, sampling_rate=16000)
            return {"array": audio, "sampling_rate": 16000}
        elif isinstance(inputs, np.ndarray):
            return {"array": inputs, "sampling_rate": 16000}

        return None

    def preprocess(self, inputs, **preprocess_params):
        """Preprocess audio inputs for the model.

        Args:
            inputs: Audio input (dict with array, file path, etc.)
            **preprocess_params: Additional preprocessing parameters

        Yields:
            Model input dicts with input_features and attention_mask
        """
        # Handle dict with "array" key (from datasets)
        if isinstance(inputs, dict) and "array" in inputs:
            inputs = {
                "raw": inputs["array"],
                "sampling_rate": inputs.get("sampling_rate", self.feature_extractor.sampling_rate),
            }

        for item in super().preprocess(inputs, **preprocess_params):
            if "is_last" not in item:
                item["is_last"] = True
            yield item

    def _forward(self, model_inputs, **generate_kwargs) -> dict[str, Any]:
        """Run model forward pass to generate transcription.

        Args:
            model_inputs: Dict with input_features and attention_mask
            **generate_kwargs: Generation parameters

        Returns:
            Dict with generated token IDs
        """
        # Extract audio features and is_last flag
        is_last = model_inputs.pop("is_last", True) if isinstance(model_inputs, dict) else True

        input_features = model_inputs["input_features"].to(self.model.device)
        audio_attention_mask = model_inputs["attention_mask"].to(self.model.device)

        generated_ids = self.model.generate(
            input_features=input_features,
            audio_attention_mask=audio_attention_mask,
            **generate_kwargs,
        )

        return {"tokens": generated_ids, "is_last": is_last}

    def postprocess(self, model_outputs, **kwargs) -> dict[str, str]:
        """Convert model output tokens to text.

        Args:
            model_outputs: Dict with 'tokens' key containing generated IDs
            **kwargs: Additional postprocessing parameters

        Returns:
            Dict with 'text' key containing transcription
        """
        # Handle list of outputs (from chunking)
        if isinstance(model_outputs, list):
            model_outputs = model_outputs[0] if model_outputs else {}

        tokens = model_outputs.get("tokens")
        if tokens is None:
            return super().postprocess(model_outputs, **kwargs)

        if torch.is_tensor(tokens):
            tokens = tokens.cpu()
            if tokens.dim() > 1:
                tokens = tokens[0]

        # Filter out eos tokens that the tokenizer doesn't recognize as special
        # (generation_config.eos_token_id may differ from tokenizer.eos_token_id)
        if hasattr(self, "model") and hasattr(self.model, "generation_config"):
            eos_ids = self.model.generation_config.eos_token_id
            if eos_ids is not None:
                eos_set = set(eos_ids) if isinstance(eos_ids, list) else {eos_ids}
                tokens = [t for t in tokens.tolist() if t not in eos_set]

        text = self.tokenizer.decode(tokens, skip_special_tokens=True).strip()
        # Strip <think>...</think> tags (Qwen3 doesn't respect /no_think prompt)
        text = re.sub(r"<think>.*?</think>\s*", "", text, flags=re.DOTALL).strip()
        # Truncate repetitions at end of text
        text = _truncate_repetitions(text)
        return {"text": text}


def _truncate_repetitions(text: str, min_repeats: int = 3) -> str:
    """Truncate repeated words/phrases/characters at end of text.

    Detects patterns like:
    - Repeated words: "the the the the" -> "the"
    - Repeated phrases: "i am sorry i am sorry i am sorry" -> "i am sorry"
    - Repeated characters: "444444" -> "4"

    Args:
        text: Input text to process
        min_repeats: Minimum repetitions to trigger truncation (default 3)

    Returns:
        Text with trailing repetitions removed
    """
    if not text:
        return text

    # 1. Truncate repeated characters at end (e.g., "444444" -> "4")
    char_pattern = re.compile(r"(.)\1{" + str(min_repeats - 1) + r",}$")
    text = char_pattern.sub(r"\1", text)

    # 2. Truncate repeated words at end (e.g., "the the the" -> "the")
    word_pattern = re.compile(
        r"\b(\w+)(?:\s+\1){" + str(min_repeats - 1) + r",}\s*$", re.IGNORECASE
    )
    while word_pattern.search(text):
        text = word_pattern.sub(r"\1", text)

    # 3. Truncate repeated phrases (2-20 words) at end
    # e.g., "i am sorry i am sorry i am sorry" -> "i am sorry"
    words = text.split()
    if len(words) >= min_repeats * 2:
        # Try phrase lengths from 2 to 20 words
        for phrase_len in range(2, min(21, len(words) // min_repeats + 1)):
            # Check if the last phrase_len words repeat
            phrase = " ".join(words[-phrase_len:])
            # Build pattern to match repeated phrases at end
            phrase_escaped = re.escape(phrase)
            phrase_pattern = re.compile(
                r"(^|.*?\s)("
                + phrase_escaped
                + r")(?:\s+"
                + phrase_escaped
                + r"){"
                + str(min_repeats - 1)
                + r",}\s*$",
                re.IGNORECASE,
            )
            match = phrase_pattern.match(text)
            if match:
                # Keep prefix + one instance of the phrase
                text = (match.group(1) + match.group(2)).strip()
                words = text.split()
                break

    return text
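Since `_truncate_repetitions` is a pure function, its three passes (trailing characters, single words, multi-word phrases) can be sanity-checked in isolation; a small illustrative check:

```python
assert _truncate_repetitions("the meeting ended 444444") == "the meeting ended 4"
assert _truncate_repetitions("okay okay okay okay") == "okay"
assert _truncate_repetitions("i am sorry i am sorry i am sorry") == "i am sorry"
```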
asr_processing.py
ADDED
@@ -0,0 +1,133 @@
from typing import Optional, Union

import torch
import transformers
from transformers import ProcessorMixin

try:
    from .asr_config import ASRConfig
except ImportError:
    from asr_config import ASRConfig  # type: ignore[no-redef]


class ASRProcessor(ProcessorMixin):
    """Processor for Whisper-based ASR models."""

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"
    AUDIO_TOKEN = "<audio>"
    TRANSCRIBE_PROMPT = ""
    # Default conv layers for Whisper/GLM-ASR: [(pad, kernel, stride), ...]
    DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]

    def __init__(
        self,
        feature_extractor,
        tokenizer,
        projector=None,
        encoder_conv_layers: Optional[list] = None,
    ):
        """Initialize the ASR processor.

        Args:
            feature_extractor: Audio feature extractor (WhisperFeatureExtractor)
            tokenizer: Text tokenizer for the language model
            projector: Audio projector module (for computing output lengths)
            encoder_conv_layers: Conv layer specs [(pad, kernel, stride), ...]
        """
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
        self.projector = projector
        self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS

    def _compute_encoder_output_length(self, mel_length: int) -> int:
        """Compute encoder output length using conv layer formulas."""
        length = mel_length
        for padding, kernel_size, stride in self.encoder_conv_layers:
            length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
        return length

    def __call__(
        self,
        audio: Optional[Union[list, "torch.Tensor"]] = None,
        text: Optional[str] = None,
        system_prompt: Optional[str] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict:
        """Process audio and text inputs for inference.

        Args:
            audio: Raw audio waveform(s)
            text: Target transcription (optional, for training - but use DataCollator instead)
            system_prompt: Optional system prompt
            return_tensors: Return format ("pt" for PyTorch)

        Returns:
            Dict with input_features, input_ids, attention_mask
        """
        result = {}

        # Process audio
        if audio is not None:
            audio_inputs = self.feature_extractor(
                audio,
                sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
                return_attention_mask=True,
                return_tensors=return_tensors,
                **kwargs,
            )
            result["input_features"] = audio_inputs["input_features"]
            result["audio_attention_mask"] = audio_inputs["attention_mask"]

            # Use actual audio length (from attention mask) for token count
            real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
            encoder_output_len = self._compute_encoder_output_length(real_mel_len)
            num_audio_tokens = self.projector.get_output_length(encoder_output_len)
        else:
            num_audio_tokens = 0

        # Build prompt with audio token placeholders (instruction-free)
        if num_audio_tokens > 0:
            user_content = self.AUDIO_TOKEN * num_audio_tokens
            if self.TRANSCRIBE_PROMPT:
                user_content += " " + self.TRANSCRIBE_PROMPT
        else:
            user_content = self.TRANSCRIBE_PROMPT or ""

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})
        if text is not None:
            messages.append({"role": "assistant", "content": text})

        # Tokenize
        tokenized = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=(text is None),
            return_tensors=return_tensors,
            enable_thinking=False,  # Disable Qwen3 thinking mode for ASR
        )

        # Handle both tensor and BatchEncoding returns
        if isinstance(tokenized, torch.Tensor):
            input_ids = tokenized
        else:
            # BatchEncoding or dict-like object
            input_ids = tokenized.get("input_ids", tokenized.input_ids)

        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        result["input_ids"] = input_ids
        result["attention_mask"] = torch.ones_like(input_ids)

        return result


ASRProcessor.register_for_auto_class()
transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
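A hypothetical end-to-end call, assuming the checkpoint was saved with the `auto_map` entries written by `save_pretrained` above; the repo id is a placeholder:

```python
import numpy as np
import transformers

# Placeholder repo id; trust_remote_code is required because ASRModel/ASRProcessor are custom code.
model = transformers.AutoModel.from_pretrained(
    "your-username/your-asr-checkpoint", trust_remote_code=True
)
processor = model.get_processor()  # carries the projector needed to size the <audio> run

waveform = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
batch = processor(audio=waveform, system_prompt="Transcribe the audio.")
# batch holds input_features, audio_attention_mask, attention_mask, and input_ids in
# which the user turn is a run of <audio> placeholders sized by the conv/projector math.
```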
chat_template.jinja
ADDED
@@ -0,0 +1,89 @@
{%- if tools %}
    {{- '<|im_start|>system\n' }}
    {%- if messages[0].role == 'system' %}
        {{- messages[0].content + '\n\n' }}
    {%- endif %}
    {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
    {%- for tool in tools %}
        {{- "\n" }}
        {{- tool | tojson }}
    {%- endfor %}
    {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
{%- else %}
    {%- if messages[0].role == 'system' %}
        {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
    {%- endif %}
{%- endif %}
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
    {%- set index = (messages|length - 1) - loop.index0 %}
    {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}
        {%- set ns.multi_step_tool = false %}
        {%- set ns.last_query_index = index %}
    {%- endif %}
{%- endfor %}
{%- for message in messages %}
    {%- if message.content is string %}
        {%- set content = message.content %}
    {%- else %}
        {%- set content = '' %}
    {%- endif %}
    {%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
        {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
    {%- elif message.role == "assistant" %}
        {%- set reasoning_content = '' %}
        {%- if message.reasoning_content is string %}
            {%- set reasoning_content = message.reasoning_content %}
        {%- else %}
            {%- if '</think>' in content %}
                {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
                {%- set content = content.split('</think>')[-1].lstrip('\n') %}
            {%- endif %}
        {%- endif %}
        {%- if loop.index0 > ns.last_query_index %}
            {%- if loop.last or (not loop.last and reasoning_content) %}
                {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
            {%- else %}
                {{- '<|im_start|>' + message.role + '\n' + content }}
            {%- endif %}
        {%- else %}
            {{- '<|im_start|>' + message.role + '\n' + content }}
        {%- endif %}
        {%- if message.tool_calls %}
            {%- for tool_call in message.tool_calls %}
                {%- if (loop.first and content) or (not loop.first) %}
                    {{- '\n' }}
                {%- endif %}
                {%- if tool_call.function %}
                    {%- set tool_call = tool_call.function %}
                {%- endif %}
                {{- '<tool_call>\n{"name": "' }}
                {{- tool_call.name }}
                {{- '", "arguments": ' }}
                {%- if tool_call.arguments is string %}
                    {{- tool_call.arguments }}
                {%- else %}
                    {{- tool_call.arguments | tojson }}
                {%- endif %}
                {{- '}\n</tool_call>' }}
            {%- endfor %}
        {%- endif %}
        {{- '<|im_end|>\n' }}
    {%- elif message.role == "tool" %}
        {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
            {{- '<|im_start|>user' }}
        {%- endif %}
        {{- '\n<tool_response>\n' }}
        {{- content }}
        {{- '\n</tool_response>' }}
        {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
            {{- '<|im_end|>\n' }}
        {%- endif %}
    {%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
    {{- '<|im_start|>assistant\n' }}
    {%- if true %}
        {{- '<think>\n\n</think>\n\n' }}
    {%- endif %}
{%- endif %}
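For reference, this template emits plain ChatML-style turns, and because the generation branch is gated by `{%- if true %}`, the assistant turn always opens with an empty `<think>` block, which is why the ASR code also strips `<think>...</think>` at decode time. A sketch of the expected rendering for a simple system + user exchange (assumed output written by hand, not produced by running the template here):

```python
messages = [
    {"role": "system", "content": "Transcribe the audio."},
    {"role": "user", "content": "<audio>"},
]
expected = (
    "<|im_start|>system\nTranscribe the audio.<|im_end|>\n"
    "<|im_start|>user\n<audio><|im_end|>\n"
    "<|im_start|>assistant\n<think>\n\n</think>\n\n"
)
# tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# should match `expected` under these assumptions.
```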
diarization.py
ADDED
@@ -0,0 +1,732 @@
"""Speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.

Spectral clustering implementation adapted from FunASR/3D-Speaker:
https://github.com/alibaba-damo-academy/FunASR
MIT License (https://opensource.org/licenses/MIT)
"""

import warnings

import numpy as np
import scipy
import sklearn.metrics.pairwise
import torch
from sklearn.cluster._kmeans import k_means
from sklearn.preprocessing import normalize


def _get_device() -> torch.device:
    """Get best available device for inference."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


class SpectralCluster:
    """Spectral clustering using unnormalized Laplacian of affinity matrix.

    Adapted from FunASR/3D-Speaker and SpeechBrain implementations.
    Uses eigenvalue gap to automatically determine number of speakers.
    """

    def __init__(self, min_num_spks: int = 1, max_num_spks: int = 15, pval: float = 0.06):
        self.min_num_spks = min_num_spks
        self.max_num_spks = max_num_spks
        self.pval = pval

    def __call__(self, embeddings: np.ndarray, oracle_num: int | None = None) -> np.ndarray:
        """Run spectral clustering on embeddings.

        Args:
            embeddings: Speaker embeddings of shape [N, D]
            oracle_num: Optional known number of speakers

        Returns:
            Cluster labels of shape [N]
        """
        # Similarity matrix computation
        sim_mat = self.get_sim_mat(embeddings)

        # Refining similarity matrix with pval
        prunned_sim_mat = self.p_pruning(sim_mat)

        # Symmetrization
        sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)

        # Laplacian calculation
        laplacian = self.get_laplacian(sym_prund_sim_mat)

        # Get Spectral Embeddings
        emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)

        # Perform clustering
        return self.cluster_embs(emb, num_of_spk)

    def get_sim_mat(self, embeddings: np.ndarray) -> np.ndarray:
        """Compute cosine similarity matrix."""
        return sklearn.metrics.pairwise.cosine_similarity(embeddings, embeddings)

    def p_pruning(self, affinity: np.ndarray) -> np.ndarray:
        """Prune low similarity values in affinity matrix (keep top pval fraction)."""
        n = affinity.shape[0]
        pval = max(self.pval, 6.0 / n)
        k_keep = max(1, int(pval * n))

        # Vectorized: find top-k indices per row and zero out the rest
        top_k_idx = np.argpartition(affinity, -k_keep, axis=1)[:, -k_keep:]
        mask = np.zeros_like(affinity, dtype=bool)
        np.put_along_axis(mask, top_k_idx, True, axis=1)
        affinity[~mask] = 0
        return affinity

    def get_laplacian(self, sim_mat: np.ndarray) -> np.ndarray:
        """Compute unnormalized Laplacian matrix."""
        from scipy.sparse.csgraph import laplacian

        np.fill_diagonal(sim_mat, 0)
        return laplacian(sim_mat, normed=False)

    def get_spec_embs(
        self, laplacian: np.ndarray, k_oracle: int | None = None
    ) -> tuple[np.ndarray, int]:
        """Extract spectral embeddings from Laplacian."""
        lambdas, eig_vecs = scipy.linalg.eigh(laplacian)

        if k_oracle is not None:
            num_of_spk = k_oracle
        else:
            lambda_gap_list = self.get_eigen_gaps(
                lambdas[self.min_num_spks - 1 : self.max_num_spks + 1]
            )
            num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks

        emb = eig_vecs[:, :num_of_spk]
        return emb, num_of_spk

    def cluster_embs(self, emb: np.ndarray, k: int) -> np.ndarray:
        """Cluster spectral embeddings using k-means."""
        _, labels, _ = k_means(emb, k, n_init=10)
        return labels

    def get_eigen_gaps(self, eig_vals: np.ndarray) -> np.ndarray:
        """Compute gaps between consecutive eigenvalues."""
        return np.diff(eig_vals)


class SpeakerClusterer:
    """Speaker clustering backend using spectral clustering with speaker merging.

    Features:
    - Spectral clustering with eigenvalue gap for auto speaker count detection
    - P-pruning for affinity matrix refinement
    - Post-clustering speaker merging by cosine similarity
    """

    def __init__(
        self,
        min_num_spks: int = 2,
        max_num_spks: int = 10,
        merge_thr: float = 0.90,  # Moderate merging
    ):
        self.min_num_spks = min_num_spks
        self.max_num_spks = max_num_spks
        self.merge_thr = merge_thr
        self._spectral_cluster: SpectralCluster | None = None

    def _get_spectral_cluster(self) -> SpectralCluster:
        """Lazy-load spectral clusterer."""
        if self._spectral_cluster is None:
            self._spectral_cluster = SpectralCluster(
|
| 142 |
+
min_num_spks=self.min_num_spks,
|
| 143 |
+
max_num_spks=self.max_num_spks,
|
| 144 |
+
)
|
| 145 |
+
return self._spectral_cluster
|
| 146 |
+
|
| 147 |
+
def __call__(self, embeddings: np.ndarray, num_speakers: int | None = None) -> np.ndarray:
|
| 148 |
+
"""Cluster speaker embeddings and return labels.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
embeddings: Speaker embeddings of shape [N, D]
|
| 152 |
+
num_speakers: Optional oracle number of speakers
|
| 153 |
+
|
| 154 |
+
Returns:
|
| 155 |
+
Cluster labels of shape [N]
|
| 156 |
+
"""
|
| 157 |
+
import warnings
|
| 158 |
+
|
| 159 |
+
if len(embeddings.shape) != 2:
|
| 160 |
+
raise ValueError(f"Expected 2D array, got shape {embeddings.shape}")
|
| 161 |
+
|
| 162 |
+
# Handle edge cases
|
| 163 |
+
if embeddings.shape[0] == 0:
|
| 164 |
+
return np.array([], dtype=int)
|
| 165 |
+
if embeddings.shape[0] == 1:
|
| 166 |
+
return np.array([0], dtype=int)
|
| 167 |
+
if embeddings.shape[0] < 6:
|
| 168 |
+
return np.zeros(embeddings.shape[0], dtype=int)
|
| 169 |
+
|
| 170 |
+
# Normalize embeddings and replace NaN/inf
|
| 171 |
+
embeddings = np.nan_to_num(embeddings, nan=0.0, posinf=0.0, neginf=0.0)
|
| 172 |
+
embeddings = normalize(embeddings)
|
| 173 |
+
|
| 174 |
+
# Run spectral clustering (suppress numerical warnings)
|
| 175 |
+
spectral = self._get_spectral_cluster()
|
| 176 |
+
|
| 177 |
+
# Update min/max for oracle case
|
| 178 |
+
if num_speakers is not None:
|
| 179 |
+
spectral.min_num_spks = num_speakers
|
| 180 |
+
spectral.max_num_spks = num_speakers
|
| 181 |
+
|
| 182 |
+
with warnings.catch_warnings():
|
| 183 |
+
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
| 184 |
+
labels = spectral(embeddings, oracle_num=num_speakers)
|
| 185 |
+
|
| 186 |
+
# Reset min/max
|
| 187 |
+
if num_speakers is not None:
|
| 188 |
+
spectral.min_num_spks = self.min_num_spks
|
| 189 |
+
spectral.max_num_spks = self.max_num_spks
|
| 190 |
+
|
| 191 |
+
# Merge similar speakers if no oracle
|
| 192 |
+
if num_speakers is None:
|
| 193 |
+
labels = self._merge_by_cos(labels, embeddings, self.merge_thr)
|
| 194 |
+
|
| 195 |
+
# Re-index labels sequentially
|
| 196 |
+
_, labels = np.unique(labels, return_inverse=True)
|
| 197 |
+
|
| 198 |
+
return labels
|
| 199 |
+
|
| 200 |
+
def _merge_by_cos(self, labels: np.ndarray, embs: np.ndarray, cos_thr: float) -> np.ndarray:
|
| 201 |
+
"""Merge similar speakers by cosine similarity of centroids."""
|
| 202 |
+
from scipy.cluster.hierarchy import fcluster, linkage
|
| 203 |
+
from scipy.spatial.distance import pdist
|
| 204 |
+
|
| 205 |
+
unique_labels = np.unique(labels)
|
| 206 |
+
if len(unique_labels) <= 1:
|
| 207 |
+
return labels
|
| 208 |
+
|
| 209 |
+
# Compute normalized speaker centroids
|
| 210 |
+
centroids = np.array([embs[labels == lbl].mean(0) for lbl in unique_labels])
|
| 211 |
+
centroids = normalize(centroids)
|
| 212 |
+
|
| 213 |
+
# Hierarchical clustering with cosine distance
|
| 214 |
+
distances = pdist(centroids, metric="cosine")
|
| 215 |
+
linkage_matrix = linkage(distances, method="average")
|
| 216 |
+
merged_labels = fcluster(linkage_matrix, t=1.0 - cos_thr, criterion="distance") - 1
|
| 217 |
+
|
| 218 |
+
# Map original labels to merged labels
|
| 219 |
+
label_map = dict(zip(unique_labels, merged_labels))
|
| 220 |
+
return np.array([label_map[lbl] for lbl in labels])
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class LocalSpeakerDiarizer:
|
| 224 |
+
"""Local speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
|
| 225 |
+
|
| 226 |
+
Pipeline:
|
| 227 |
+
1. TEN-VAD detects speech segments
|
| 228 |
+
2. Sliding window (1.0s, 75% overlap) for uniform embedding extraction
|
| 229 |
+
3. ECAPA-TDNN extracts speaker embeddings per window
|
| 230 |
+
4. Spectral clustering with eigenvalue gap for auto speaker detection
|
| 231 |
+
5. Frame-level consensus voting for segment reconstruction
|
| 232 |
+
6. Post-processing merges short segments to reduce flicker
|
| 233 |
+
|
| 234 |
+
Tunable Parameters (class attributes):
|
| 235 |
+
- WINDOW_SIZE: Embedding extraction window size in seconds
|
| 236 |
+
- STEP_SIZE: Sliding window step size (overlap = WINDOW_SIZE - STEP_SIZE)
|
| 237 |
+
- VAD_THRESHOLD: Speech detection threshold (lower = more sensitive)
|
| 238 |
+
- VAD_MIN_DURATION: Minimum speech segment duration
|
| 239 |
+
- VAD_MAX_GAP: Maximum gap to bridge between segments
|
| 240 |
+
- VAD_PAD_ONSET/OFFSET: Padding added to speech segments
|
| 241 |
+
- VOTING_RATE: Frame resolution for consensus voting
|
| 242 |
+
- MIN_SEGMENT_DURATION: Minimum final segment duration
|
| 243 |
+
- SAME_SPEAKER_GAP: Maximum gap to merge same-speaker segments
|
| 244 |
+
- TAIL_COVERAGE_RATIO: Minimum tail coverage to add extra window
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
_ten_vad_model = None
|
| 248 |
+
_ecapa_model = None
|
| 249 |
+
_device = None
|
| 250 |
+
|
| 251 |
+
# ==================== TUNABLE PARAMETERS ====================
|
| 252 |
+
|
| 253 |
+
# Sliding window for embedding extraction
|
| 254 |
+
WINDOW_SIZE = 0.75 # seconds - shorter window for finer resolution
|
| 255 |
+
STEP_SIZE = 0.15 # seconds (80% overlap for more votes)
|
| 256 |
+
TAIL_COVERAGE_RATIO = 0.1 # Add extra window if tail > this ratio of window
|
| 257 |
+
|
| 258 |
+
# VAD hysteresis parameters
|
| 259 |
+
VAD_THRESHOLD = 0.25 # Balanced threshold
|
| 260 |
+
VAD_MIN_DURATION = 0.05 # Minimum speech segment duration (seconds)
|
| 261 |
+
VAD_MAX_GAP = 0.50 # Bridge gaps shorter than this (seconds)
|
| 262 |
+
VAD_PAD_ONSET = 0.05 # Padding at segment start (seconds)
|
| 263 |
+
VAD_PAD_OFFSET = 0.05 # Padding at segment end (seconds)
|
| 264 |
+
|
| 265 |
+
# Frame-level voting
|
| 266 |
+
VOTING_RATE = 0.01 # 10ms resolution for consensus voting
|
| 267 |
+
|
| 268 |
+
# Post-processing
|
| 269 |
+
MIN_SEGMENT_DURATION = 0.15 # Minimum final segment duration (seconds)
|
| 270 |
+
SHORT_SEGMENT_GAP = 0.1 # Gap threshold for merging short segments
|
| 271 |
+
SAME_SPEAKER_GAP = 0.5 # Gap threshold for merging same-speaker segments
|
| 272 |
+
|
| 273 |
+
# ===========================================================
|
| 274 |
+
|
| 275 |
+
@classmethod
|
| 276 |
+
def _get_ten_vad_model(cls):
|
| 277 |
+
"""Lazy-load TEN-VAD model (singleton)."""
|
| 278 |
+
if cls._ten_vad_model is None:
|
| 279 |
+
from ten_vad import TenVad
|
| 280 |
+
|
| 281 |
+
cls._ten_vad_model = TenVad(hop_size=256, threshold=cls.VAD_THRESHOLD)
|
| 282 |
+
return cls._ten_vad_model
|
| 283 |
+
|
| 284 |
+
@classmethod
|
| 285 |
+
def _get_device(cls) -> torch.device:
|
| 286 |
+
"""Get the best available device."""
|
| 287 |
+
if cls._device is None:
|
| 288 |
+
cls._device = _get_device()
|
| 289 |
+
return cls._device
|
| 290 |
+
|
| 291 |
+
@classmethod
|
| 292 |
+
def _get_ecapa_model(cls):
|
| 293 |
+
"""Lazy-load ECAPA-TDNN speaker embedding model (singleton)."""
|
| 294 |
+
if cls._ecapa_model is None:
|
| 295 |
+
# Suppress torchaudio deprecation warning from SpeechBrain
|
| 296 |
+
with warnings.catch_warnings():
|
| 297 |
+
warnings.filterwarnings("ignore", message="torchaudio._backend")
|
| 298 |
+
from speechbrain.inference.speaker import EncoderClassifier
|
| 299 |
+
|
| 300 |
+
device = cls._get_device()
|
| 301 |
+
cls._ecapa_model = EncoderClassifier.from_hparams(
|
| 302 |
+
source="speechbrain/spkrec-ecapa-voxceleb",
|
| 303 |
+
run_opts={"device": str(device)},
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
return cls._ecapa_model
|
| 307 |
+
|
| 308 |
+
@classmethod
|
| 309 |
+
def diarize(
|
| 310 |
+
cls,
|
| 311 |
+
audio: np.ndarray | str,
|
| 312 |
+
sample_rate: int = 16000,
|
| 313 |
+
num_speakers: int | None = None,
|
| 314 |
+
min_speakers: int = 2,
|
| 315 |
+
max_speakers: int = 10,
|
| 316 |
+
**_kwargs,
|
| 317 |
+
) -> list[dict]:
|
| 318 |
+
"""Run speaker diarization on audio.
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
audio: Audio waveform as numpy array or path to audio file
|
| 322 |
+
sample_rate: Audio sample rate (default 16000)
|
| 323 |
+
num_speakers: Exact number of speakers (if known)
|
| 324 |
+
min_speakers: Minimum number of speakers
|
| 325 |
+
max_speakers: Maximum number of speakers
|
| 326 |
+
|
| 327 |
+
Returns:
|
| 328 |
+
List of dicts with 'speaker', 'start', 'end' keys
|
| 329 |
+
"""
|
| 330 |
+
# Handle file path input
|
| 331 |
+
if isinstance(audio, str):
|
| 332 |
+
import librosa
|
| 333 |
+
|
| 334 |
+
audio, sample_rate = librosa.load(audio, sr=16000)
|
| 335 |
+
|
| 336 |
+
# Ensure correct sample rate
|
| 337 |
+
if sample_rate != 16000:
|
| 338 |
+
import librosa
|
| 339 |
+
|
| 340 |
+
audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=16000)
|
| 341 |
+
sample_rate = 16000
|
| 342 |
+
|
| 343 |
+
audio = audio.astype(np.float32)
|
| 344 |
+
total_duration = len(audio) / sample_rate
|
| 345 |
+
|
| 346 |
+
# Step 1: VAD (returns segments and raw frame-level decisions)
|
| 347 |
+
segments, vad_frames = cls._get_speech_segments(audio, sample_rate)
|
| 348 |
+
if not segments:
|
| 349 |
+
return []
|
| 350 |
+
|
| 351 |
+
# Step 2: Extract embeddings
|
| 352 |
+
embeddings, window_segments = cls._extract_embeddings(audio, segments, sample_rate)
|
| 353 |
+
if len(embeddings) == 0:
|
| 354 |
+
return []
|
| 355 |
+
|
| 356 |
+
# Step 3: Cluster
|
| 357 |
+
clusterer = SpeakerClusterer(min_num_spks=min_speakers, max_num_spks=max_speakers)
|
| 358 |
+
labels = clusterer(embeddings, num_speakers)
|
| 359 |
+
|
| 360 |
+
# Step 4: Post-process with consensus voting (VAD-aware)
|
| 361 |
+
return cls._postprocess_segments(window_segments, labels, total_duration, vad_frames)
|
| 362 |
+
|
| 363 |
+
@classmethod
|
| 364 |
+
def _get_speech_segments(
|
| 365 |
+
cls, audio_array: np.ndarray, sample_rate: int = 16000
|
| 366 |
+
) -> tuple[list[dict], list[bool]]:
|
| 367 |
+
"""Get speech segments using TEN-VAD.
|
| 368 |
+
|
| 369 |
+
Returns:
|
| 370 |
+
Tuple of (segments list, vad_frames list of per-frame speech decisions)
|
| 371 |
+
"""
|
| 372 |
+
vad_model = cls._get_ten_vad_model()
|
| 373 |
+
|
| 374 |
+
# Convert to int16 as required by TEN-VAD
|
| 375 |
+
# Clip to prevent integer overflow
|
| 376 |
+
if audio_array.dtype != np.int16:
|
| 377 |
+
audio_int16 = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
|
| 378 |
+
else:
|
| 379 |
+
audio_int16 = audio_array
|
| 380 |
+
|
| 381 |
+
# Process frame by frame
|
| 382 |
+
hop_size = 256
|
| 383 |
+
frame_duration = hop_size / sample_rate
|
| 384 |
+
speech_frames: list[bool] = []
|
| 385 |
+
|
| 386 |
+
for i in range(0, len(audio_int16) - hop_size, hop_size):
|
| 387 |
+
frame = audio_int16[i : i + hop_size]
|
| 388 |
+
_, is_speech = vad_model.process(frame)
|
| 389 |
+
speech_frames.append(is_speech)
|
| 390 |
+
|
| 391 |
+
# Convert frame-level decisions to segments
|
| 392 |
+
segments = []
|
| 393 |
+
in_speech = False
|
| 394 |
+
start_idx = 0
|
| 395 |
+
|
| 396 |
+
for i, is_speech in enumerate(speech_frames):
|
| 397 |
+
if is_speech and not in_speech:
|
| 398 |
+
start_idx = i
|
| 399 |
+
in_speech = True
|
| 400 |
+
elif not is_speech and in_speech:
|
| 401 |
+
start_time = start_idx * frame_duration
|
| 402 |
+
end_time = i * frame_duration
|
| 403 |
+
segments.append(
|
| 404 |
+
{
|
| 405 |
+
"start": start_time,
|
| 406 |
+
"end": end_time,
|
| 407 |
+
"start_sample": int(start_time * sample_rate),
|
| 408 |
+
"end_sample": int(end_time * sample_rate),
|
| 409 |
+
}
|
| 410 |
+
)
|
| 411 |
+
in_speech = False
|
| 412 |
+
|
| 413 |
+
# Handle trailing speech
|
| 414 |
+
if in_speech:
|
| 415 |
+
start_time = start_idx * frame_duration
|
| 416 |
+
end_time = len(speech_frames) * frame_duration
|
| 417 |
+
segments.append(
|
| 418 |
+
{
|
| 419 |
+
"start": start_time,
|
| 420 |
+
"end": end_time,
|
| 421 |
+
"start_sample": int(start_time * sample_rate),
|
| 422 |
+
"end_sample": int(end_time * sample_rate),
|
| 423 |
+
}
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
return cls._apply_vad_hysteresis(segments, sample_rate), speech_frames
|
| 427 |
+
|
| 428 |
+
@classmethod
|
| 429 |
+
def _apply_vad_hysteresis(cls, segments: list[dict], sample_rate: int = 16000) -> list[dict]:
|
| 430 |
+
"""Apply hysteresis-like post-processing to VAD segments."""
|
| 431 |
+
if not segments:
|
| 432 |
+
return segments
|
| 433 |
+
|
| 434 |
+
segments = sorted(segments, key=lambda x: x["start"])
|
| 435 |
+
|
| 436 |
+
# Fill short gaps
|
| 437 |
+
merged = [segments[0].copy()]
|
| 438 |
+
for seg in segments[1:]:
|
| 439 |
+
gap = seg["start"] - merged[-1]["end"]
|
| 440 |
+
if gap <= cls.VAD_MAX_GAP:
|
| 441 |
+
merged[-1]["end"] = seg["end"]
|
| 442 |
+
merged[-1]["end_sample"] = seg["end_sample"]
|
| 443 |
+
else:
|
| 444 |
+
merged.append(seg.copy())
|
| 445 |
+
|
| 446 |
+
# Remove short segments
|
| 447 |
+
filtered = [seg for seg in merged if (seg["end"] - seg["start"]) >= cls.VAD_MIN_DURATION]
|
| 448 |
+
|
| 449 |
+
# Dilate segments (add padding)
|
| 450 |
+
for seg in filtered:
|
| 451 |
+
seg["start"] = max(0.0, seg["start"] - cls.VAD_PAD_ONSET)
|
| 452 |
+
seg["end"] = seg["end"] + cls.VAD_PAD_OFFSET
|
| 453 |
+
seg["start_sample"] = int(seg["start"] * sample_rate)
|
| 454 |
+
seg["end_sample"] = int(seg["end"] * sample_rate)
|
| 455 |
+
|
| 456 |
+
return filtered
|
| 457 |
+
|
| 458 |
+
@classmethod
|
| 459 |
+
def _extract_embeddings(
|
| 460 |
+
cls, audio_array: np.ndarray, segments: list[dict], sample_rate: int
|
| 461 |
+
) -> tuple[np.ndarray, list[dict]]:
|
| 462 |
+
"""Extract speaker embeddings using sliding windows."""
|
| 463 |
+
speaker_model = cls._get_ecapa_model()
|
| 464 |
+
|
| 465 |
+
window_samples = int(cls.WINDOW_SIZE * sample_rate)
|
| 466 |
+
step_samples = int(cls.STEP_SIZE * sample_rate)
|
| 467 |
+
|
| 468 |
+
embeddings = []
|
| 469 |
+
window_segments = []
|
| 470 |
+
|
| 471 |
+
with torch.no_grad():
|
| 472 |
+
for seg in segments:
|
| 473 |
+
seg_start = seg["start_sample"]
|
| 474 |
+
seg_end = seg["end_sample"]
|
| 475 |
+
seg_len = seg_end - seg_start
|
| 476 |
+
|
| 477 |
+
# Generate window positions
|
| 478 |
+
if seg_len <= window_samples:
|
| 479 |
+
starts = [seg_start]
|
| 480 |
+
ends = [seg_end]
|
| 481 |
+
else:
|
| 482 |
+
starts = list(range(seg_start, seg_end - window_samples + 1, step_samples))
|
| 483 |
+
ends = [s + window_samples for s in starts]
|
| 484 |
+
|
| 485 |
+
# Cover tail if > TAIL_COVERAGE_RATIO of window remains
|
| 486 |
+
if ends and ends[-1] < seg_end:
|
| 487 |
+
remainder = seg_end - ends[-1]
|
| 488 |
+
if remainder > (window_samples * cls.TAIL_COVERAGE_RATIO):
|
| 489 |
+
starts.append(seg_end - window_samples)
|
| 490 |
+
ends.append(seg_end)
|
| 491 |
+
|
| 492 |
+
for c_start, c_end in zip(starts, ends):
|
| 493 |
+
chunk = audio_array[c_start:c_end]
|
| 494 |
+
|
| 495 |
+
# Pad short chunks with reflection
|
| 496 |
+
if len(chunk) < window_samples:
|
| 497 |
+
pad_width = window_samples - len(chunk)
|
| 498 |
+
chunk = np.pad(chunk, (0, pad_width), mode="reflect")
|
| 499 |
+
|
| 500 |
+
# Extract embedding using SpeechBrain's encode_batch
|
| 501 |
+
chunk_tensor = torch.from_numpy(chunk).float().unsqueeze(0)
|
| 502 |
+
embedding = (
|
| 503 |
+
speaker_model.encode_batch(chunk_tensor).squeeze(0).squeeze(0).cpu().numpy()
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
# Validate embedding
|
| 507 |
+
if np.isfinite(embedding).all() and np.linalg.norm(embedding) > 1e-8:
|
| 508 |
+
embeddings.append(embedding)
|
| 509 |
+
window_segments.append(
|
| 510 |
+
{
|
| 511 |
+
"start": c_start / sample_rate,
|
| 512 |
+
"end": c_end / sample_rate,
|
| 513 |
+
}
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
# Normalize all embeddings at once
|
| 517 |
+
if embeddings:
|
| 518 |
+
return normalize(np.array(embeddings)), window_segments
|
| 519 |
+
return np.array([]), []
|
| 520 |
+
|
| 521 |
+
@classmethod
|
| 522 |
+
def _resample_vad(cls, vad_frames: list[bool], num_frames: int) -> np.ndarray:
|
| 523 |
+
"""Resample VAD frame decisions to match voting grid resolution.
|
| 524 |
+
|
| 525 |
+
VAD operates at 256 samples / 16000 Hz = 16ms per frame.
|
| 526 |
+
Voting operates at VOTING_RATE (default 10ms) per frame.
|
| 527 |
+
This maps VAD decisions to the finer voting grid.
|
| 528 |
+
"""
|
| 529 |
+
if not vad_frames:
|
| 530 |
+
return np.zeros(num_frames, dtype=bool)
|
| 531 |
+
|
| 532 |
+
vad_rate = 256 / 16000 # 16ms per VAD frame
|
| 533 |
+
vad_arr = np.array(vad_frames)
|
| 534 |
+
|
| 535 |
+
# Vectorized: compute VAD frame indices for each voting frame
|
| 536 |
+
voting_times = np.arange(num_frames) * cls.VOTING_RATE
|
| 537 |
+
vad_indices = np.clip((voting_times / vad_rate).astype(int), 0, len(vad_arr) - 1)
|
| 538 |
+
return vad_arr[vad_indices]
|
| 539 |
+
|
| 540 |
+
@classmethod
|
| 541 |
+
def _postprocess_segments(
|
| 542 |
+
cls,
|
| 543 |
+
window_segments: list[dict],
|
| 544 |
+
labels: np.ndarray,
|
| 545 |
+
total_duration: float,
|
| 546 |
+
vad_frames: list[bool],
|
| 547 |
+
) -> list[dict]:
|
| 548 |
+
"""Post-process using frame-level consensus voting with VAD-aware silence."""
|
| 549 |
+
if not window_segments or len(labels) == 0:
|
| 550 |
+
return []
|
| 551 |
+
|
| 552 |
+
# Correct labels to be contiguous
|
| 553 |
+
unique_labels = np.unique(labels)
|
| 554 |
+
label_map = {old: new for new, old in enumerate(unique_labels)}
|
| 555 |
+
clean_labels = np.array([label_map[lbl] for lbl in labels])
|
| 556 |
+
num_speakers = len(unique_labels)
|
| 557 |
+
|
| 558 |
+
if num_speakers == 0:
|
| 559 |
+
return []
|
| 560 |
+
|
| 561 |
+
# Create voting grid
|
| 562 |
+
num_frames = int(np.ceil(total_duration / cls.VOTING_RATE)) + 1
|
| 563 |
+
votes = np.zeros((num_frames, num_speakers), dtype=np.float32)
|
| 564 |
+
|
| 565 |
+
# Accumulate votes
|
| 566 |
+
for win, label in zip(window_segments, clean_labels):
|
| 567 |
+
start_frame = int(win["start"] / cls.VOTING_RATE)
|
| 568 |
+
end_frame = int(win["end"] / cls.VOTING_RATE)
|
| 569 |
+
end_frame = min(end_frame, num_frames)
|
| 570 |
+
if start_frame < end_frame:
|
| 571 |
+
votes[start_frame:end_frame, label] += 1.0
|
| 572 |
+
|
| 573 |
+
# Determine winner per frame
|
| 574 |
+
frame_speakers = np.argmax(votes, axis=1)
|
| 575 |
+
max_votes = np.max(votes, axis=1)
|
| 576 |
+
|
| 577 |
+
# Resample VAD to voting grid resolution for silence-aware voting
|
| 578 |
+
vad_resampled = cls._resample_vad(vad_frames, num_frames)
|
| 579 |
+
|
| 580 |
+
# Convert frames to segments
|
| 581 |
+
final_segments = []
|
| 582 |
+
current_speaker = -1
|
| 583 |
+
seg_start = 0.0
|
| 584 |
+
|
| 585 |
+
for f in range(num_frames):
|
| 586 |
+
speaker = int(frame_speakers[f])
|
| 587 |
+
score = max_votes[f]
|
| 588 |
+
|
| 589 |
+
# Force silence if VAD says no speech OR no votes
|
| 590 |
+
if score == 0 or not vad_resampled[f]:
|
| 591 |
+
speaker = -1
|
| 592 |
+
|
| 593 |
+
if speaker != current_speaker:
|
| 594 |
+
if current_speaker != -1:
|
| 595 |
+
final_segments.append(
|
| 596 |
+
{
|
| 597 |
+
"speaker": f"SPEAKER_{current_speaker}",
|
| 598 |
+
"start": seg_start,
|
| 599 |
+
"end": f * cls.VOTING_RATE,
|
| 600 |
+
}
|
| 601 |
+
)
|
| 602 |
+
current_speaker = speaker
|
| 603 |
+
seg_start = f * cls.VOTING_RATE
|
| 604 |
+
|
| 605 |
+
# Close last segment
|
| 606 |
+
if current_speaker != -1:
|
| 607 |
+
final_segments.append(
|
| 608 |
+
{
|
| 609 |
+
"speaker": f"SPEAKER_{current_speaker}",
|
| 610 |
+
"start": seg_start,
|
| 611 |
+
"end": num_frames * cls.VOTING_RATE,
|
| 612 |
+
}
|
| 613 |
+
)
|
| 614 |
+
|
| 615 |
+
return cls._merge_short_segments(final_segments)
|
| 616 |
+
|
| 617 |
+
@classmethod
|
| 618 |
+
def _merge_short_segments(cls, segments: list[dict]) -> list[dict]:
|
| 619 |
+
"""Merge short segments to reduce flicker."""
|
| 620 |
+
if not segments:
|
| 621 |
+
return []
|
| 622 |
+
|
| 623 |
+
clean: list[dict] = []
|
| 624 |
+
for seg in segments:
|
| 625 |
+
dur = seg["end"] - seg["start"]
|
| 626 |
+
if dur < cls.MIN_SEGMENT_DURATION:
|
| 627 |
+
if (
|
| 628 |
+
clean
|
| 629 |
+
and clean[-1]["speaker"] == seg["speaker"]
|
| 630 |
+
and seg["start"] - clean[-1]["end"] < cls.SHORT_SEGMENT_GAP
|
| 631 |
+
):
|
| 632 |
+
clean[-1]["end"] = seg["end"]
|
| 633 |
+
continue
|
| 634 |
+
|
| 635 |
+
if (
|
| 636 |
+
clean
|
| 637 |
+
and clean[-1]["speaker"] == seg["speaker"]
|
| 638 |
+
and seg["start"] - clean[-1]["end"] < cls.SAME_SPEAKER_GAP
|
| 639 |
+
):
|
| 640 |
+
clean[-1]["end"] = seg["end"]
|
| 641 |
+
else:
|
| 642 |
+
clean.append(seg)
|
| 643 |
+
|
| 644 |
+
return clean
|
| 645 |
+
|
| 646 |
+
@classmethod
|
| 647 |
+
def assign_speakers_to_words(
|
| 648 |
+
cls,
|
| 649 |
+
words: list[dict],
|
| 650 |
+
speaker_segments: list[dict],
|
| 651 |
+
) -> list[dict]:
|
| 652 |
+
"""Assign speaker labels to words based on timestamp overlap.
|
| 653 |
+
|
| 654 |
+
Args:
|
| 655 |
+
words: List of word dicts with 'word', 'start', 'end' keys
|
| 656 |
+
speaker_segments: List of speaker dicts with 'speaker', 'start', 'end' keys
|
| 657 |
+
|
| 658 |
+
Returns:
|
| 659 |
+
Words list with 'speaker' key added to each word
|
| 660 |
+
"""
|
| 661 |
+
for word in words:
|
| 662 |
+
word_mid = (word["start"] + word["end"]) / 2
|
| 663 |
+
|
| 664 |
+
# Find the speaker segment that contains this word's midpoint
|
| 665 |
+
best_speaker = None
|
| 666 |
+
for seg in speaker_segments:
|
| 667 |
+
if seg["start"] <= word_mid <= seg["end"]:
|
| 668 |
+
best_speaker = seg["speaker"]
|
| 669 |
+
break
|
| 670 |
+
|
| 671 |
+
# If no exact match, find closest segment
|
| 672 |
+
if best_speaker is None and speaker_segments:
|
| 673 |
+
min_dist = float("inf")
|
| 674 |
+
for seg in speaker_segments:
|
| 675 |
+
seg_mid = (seg["start"] + seg["end"]) / 2
|
| 676 |
+
dist = abs(word_mid - seg_mid)
|
| 677 |
+
if dist < min_dist:
|
| 678 |
+
min_dist = dist
|
| 679 |
+
best_speaker = seg["speaker"]
|
| 680 |
+
|
| 681 |
+
word["speaker"] = best_speaker
|
| 682 |
+
|
| 683 |
+
return words
|
| 684 |
+
|
| 685 |
+
|
| 686 |
+
class SpeakerDiarizer:
|
| 687 |
+
"""Speaker diarization using TEN-VAD + ECAPA-TDNN + spectral clustering.
|
| 688 |
+
|
| 689 |
+
Example:
|
| 690 |
+
>>> segments = SpeakerDiarizer.diarize(audio_array)
|
| 691 |
+
>>> for seg in segments:
|
| 692 |
+
... print(f"{seg['speaker']}: {seg['start']:.2f} - {seg['end']:.2f}")
|
| 693 |
+
"""
|
| 694 |
+
|
| 695 |
+
@classmethod
|
| 696 |
+
def diarize(
|
| 697 |
+
cls,
|
| 698 |
+
audio: np.ndarray | str,
|
| 699 |
+
sample_rate: int = 16000,
|
| 700 |
+
num_speakers: int | None = None,
|
| 701 |
+
min_speakers: int | None = None,
|
| 702 |
+
max_speakers: int | None = None,
|
| 703 |
+
**_kwargs,
|
| 704 |
+
) -> list[dict]:
|
| 705 |
+
"""Run speaker diarization on audio.
|
| 706 |
+
|
| 707 |
+
Args:
|
| 708 |
+
audio: Audio waveform as numpy array or path to audio file
|
| 709 |
+
sample_rate: Audio sample rate (default 16000)
|
| 710 |
+
num_speakers: Exact number of speakers (if known)
|
| 711 |
+
min_speakers: Minimum number of speakers
|
| 712 |
+
max_speakers: Maximum number of speakers
|
| 713 |
+
|
| 714 |
+
Returns:
|
| 715 |
+
List of dicts with 'speaker', 'start', 'end' keys
|
| 716 |
+
"""
|
| 717 |
+
return LocalSpeakerDiarizer.diarize(
|
| 718 |
+
audio,
|
| 719 |
+
sample_rate=sample_rate,
|
| 720 |
+
num_speakers=num_speakers,
|
| 721 |
+
min_speakers=min_speakers or 2,
|
| 722 |
+
max_speakers=max_speakers or 10,
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
@classmethod
|
| 726 |
+
def assign_speakers_to_words(
|
| 727 |
+
cls,
|
| 728 |
+
words: list[dict],
|
| 729 |
+
speaker_segments: list[dict],
|
| 730 |
+
) -> list[dict]:
|
| 731 |
+
"""Assign speaker labels to words based on timestamp overlap."""
|
| 732 |
+
return LocalSpeakerDiarizer.assign_speakers_to_words(words, speaker_segments)
|
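
A minimal usage sketch for the diarization entry points above. It assumes the file is importable as diarization, that the optional dependencies (ten_vad, speechbrain, librosa) are installed, and that "meeting.wav" and the word list are placeholders:

from diarization import SpeakerDiarizer

# Diarize a file path; it is loaded and resampled to 16 kHz internally
segments = SpeakerDiarizer.diarize("meeting.wav", min_speakers=2, max_speakers=6)
for seg in segments:
    print(f"{seg['speaker']}: {seg['start']:.2f}s - {seg['end']:.2f}s")

# Attach speakers to ASR word timestamps via midpoint overlap
words = [{"word": "hello", "start": 0.32, "end": 0.58}]
words = SpeakerDiarizer.assign_speakers_to_words(words, segments)
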
preprocessor_config.json
ADDED
|
@@ -0,0 +1,19 @@
{
  "chunk_length": 30,
  "dither": 0.0,
  "feature_extractor_type": "WhisperFeatureExtractor",
  "feature_size": 128,
  "hop_length": 160,
  "n_fft": 400,
  "n_samples": 480000,
  "nb_max_frames": 3000,
  "padding": false,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": false,
  "sampling_rate": 16000,
  "processor_class": "ASRProcessor",
  "auto_map": {
    "AutoProcessor": "asr_processing.ASRProcessor"
  }
}
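
Because this config maps AutoProcessor to asr_processing.ASRProcessor, loading from the Hub generally requires trust_remote_code. A sketch with a placeholder repo id (the exact preprocessing call signature depends on ASRProcessor itself):

from transformers import AutoProcessor

# "user/asr-checkpoint" is a placeholder for this repository's Hub id
processor = AutoProcessor.from_pretrained("user/asr-checkpoint", trust_remote_code=True)
print(type(processor).__name__)  # expected to resolve to ASRProcessor via auto_map
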
projectors.py
ADDED
|
@@ -0,0 +1,505 @@
"""Audio projector modules for bridging encoder and decoder embeddings.

This module contains all projector architectures:
- MLPAudioProjector: Simple 2-layer MLP with frame stacking downsampling
- MOSAProjector: MOSA-style dense mixture of experts
- MoEAudioProjector: Shared expert + sparse routed experts
- QFormerAudioProjector: BLIP-2 QFormer with learnable queries (Granite-style)
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
from transformers import AutoModel, Blip2QFormerConfig
from transformers.models.llama.modeling_llama import LlamaRMSNorm

# =============================================================================
# MLP Projector
# =============================================================================


class MLPAudioProjector(nn.Module):
    """2-layer MLP projector with frame-stacking downsampling (matches GLM-ASR)."""

    def __init__(self, config):
        """Initialize MLP projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, projector_pool_stride
        """
        super().__init__()

        encoder_dim = getattr(config, "encoder_dim", 768)
        llm_dim = getattr(config, "llm_dim", 2048)
        self.k = getattr(config, "projector_pool_stride", 4)

        # Frame stacking: concat k adjacent frames then project
        in_dim = encoder_dim * self.k
        # Hidden dim defaults to llm_dim, can be overridden via config
        hidden_dim = getattr(config, "projector_hidden_dim", None) or llm_dim
        self.linear_1 = nn.Linear(in_dim, hidden_dim, bias=False)
        self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
        self.act = nn.GELU()
        self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length (matches GLM-ASR)."""
        # GLM-ASR formula: (L - merge_factor) // merge_factor + 1
        return (input_length - self.k) // self.k + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project audio features to LLM embedding space.

        Args:
            x: Audio encoder output of shape [batch, seq_len, encoder_dim]

        Returns:
            Projected features of shape [batch, (seq_len - k) // k + 1, llm_dim]
        """
        batch, seq, dim = x.shape
        # Truncate to match GLM-ASR: use (seq - k) // k + 1 frames
        # This drops trailing frames that don't fill a complete k-frame window
        out_len = (seq - self.k) // self.k + 1
        x = x[:, : out_len * self.k, :]  # Truncate to exact multiple
        x = x.reshape(batch, out_len, dim * self.k)

        x = self.linear_1(x)
        x = self.norm(x)
        x = self.act(x)
        return self.linear_2(x)


# =============================================================================
# MoE Projector (MOSA-style)
# =============================================================================


class SimpleAdapter(nn.Module):
    """Simple 2-layer GELU adapter (from MOSA paper)."""

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))


class SwiGLU(nn.Module):
    """SwiGLU activation with gated linear units (used in LLaMA, Mistral, etc.)."""

    def __init__(self, dim: int, hidden_dim: int, bias: bool = False):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # Gate
        self.w2 = nn.Linear(dim, hidden_dim, bias=bias)  # Value
        self.w3 = nn.Linear(hidden_dim, dim, bias=bias)  # Output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class AsymmetricSwiGLU(nn.Module):
    """SwiGLU that handles different input and output dimensions."""

    def __init__(
        self, in_features: int, hidden_features: int, out_features: int, bias: bool = False
    ):
        super().__init__()
        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)  # Gate
        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)  # Value
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)  # Output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class MOSAProjector(nn.Module):
    """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.

    Based on "MOSA: Mixtures of Simple Adapters" (arXiv:2508.18998).
    Uses softmax gating over all experts (dense MoE) with only cross-entropy loss.
    Uses Conv1d for downsampling (2 layers, stride 2 each = 4x total).
    """

    def __init__(self, config):
        """Initialize MOSA projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, num_experts
        """
        super().__init__()
        self.encoder_dim = getattr(config, "encoder_dim", None) or 1280
        self.llm_dim = getattr(config, "llm_dim", None) or 2048
        self.num_experts = getattr(config, "num_experts", None) or 4  # MOSA-Base uses 4
        adapter_hidden = getattr(config, "adapter_hidden_dim", None) or 4096
        router_hidden = getattr(config, "router_hidden_dim", None) or 512

        # --- 1. Conv1d Downsampler (4x reduction) ---
        # 2 layers of stride-2 convolution
        self.downsampler = nn.Sequential(
            nn.Conv1d(self.encoder_dim, self.encoder_dim, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv1d(self.encoder_dim, self.llm_dim, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
        )

        # --- 2. Simple Router (MOSA-Base: 2 layers with ReLU) ---
        # Takes downsampled features (llm_dim) -> 512 -> num_experts
        self.router = nn.Sequential(
            nn.Linear(self.llm_dim, router_hidden),
            nn.ReLU(),
            nn.Linear(router_hidden, self.num_experts),
        )

        # --- 3. Experts (Simple 2-layer GELU adapters) ---
        # Each expert: llm_dim -> hidden -> llm_dim (much smaller than frame-stacking)
        self.experts = nn.ModuleList(
            [
                SimpleAdapter(self.llm_dim, adapter_hidden, self.llm_dim)
                for _ in range(self.num_experts)
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project audio features using mixture of experts.

        Args:
            x: Audio encoder output of shape [batch, seq_len, encoder_dim]

        Returns:
            Projected features of shape [batch, out_len, llm_dim]
        """
        # --- 1. Conv1d Downsampling ---
        # Permute for Conv1d: [B, S, D] -> [B, D, S]
        x = x.transpose(1, 2)
        x = self.downsampler(x)
        # Permute back: [B, D, S] -> [B, S, D]
        x = x.transpose(1, 2)

        # --- 2. Routing ---
        routing_weights = F.softmax(self.router(x), dim=-1)  # (B, out_len, num_experts)

        # --- 3. Expert Mixture (Dense Execution) ---
        expert_outputs = torch.stack([expert(x) for expert in self.experts])  # (E, B, out_len, D)
        return torch.einsum("ebsd, bse -> bsd", expert_outputs, routing_weights)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length after Conv1d downsampling (4x reduction)."""
        # Conv1d with stride 2, kernel 3, padding 1: out = (in + 2*1 - 3) // 2 + 1 = (in - 1) // 2 + 1
        # Applied twice for 4x total reduction
        after_conv1 = (input_length + 2 * 1 - 3) // 2 + 1
        return (after_conv1 + 2 * 1 - 3) // 2 + 1


# =============================================================================
# MoE Projector (Pure PyTorch with Shared Expert)
# =============================================================================


class MoEAudioProjector(nn.Module):
    """MoE projector with shared expert (DeepSeek-style), pure PyTorch implementation.

    Uses 4 sparse experts with top-2 routing plus a shared expert that processes all tokens.
    No external dependencies (megablocks removed).

    Architecture matches main branch: norm → experts(in_dim → hidden → out_dim)
    """

    def __init__(self, config):
        """Initialize MoE projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, num_experts, num_experts_per_tok
        """
        super().__init__()

        self.k = getattr(config, "projector_pool_stride", 4)
        self.aux_coef = getattr(config, "router_aux_loss_coef", 0.01)

        # Stability coefficients
        self.router_z_loss_coef = getattr(
            config, "router_z_loss_coef", 1e-4
        )  # Prevents logit explosion
        self.router_jitter_noise = getattr(
            config, "router_jitter_noise", 0.01
        )  # Prevents expert collapse

        in_dim = config.encoder_dim * self.k
        out_dim = config.llm_dim

        # Expert hidden dim (default = output dim)
        hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim

        # Number of experts and top-k selection
        self.num_experts = getattr(config, "num_experts", 4)
        self.top_k = getattr(config, "num_experts_per_tok", 2)

        # A. Normalize stacked input (like main branch SharedMoEBlock)
        self.norm = LlamaRMSNorm(in_dim, eps=1e-6)

        # B. Router (operates on stacked input)
        self.router = nn.Linear(in_dim, self.num_experts, bias=False)

        # C. Experts: simple 2-layer MLP (same as MLPAudioProjector)
        self.experts = nn.ModuleList(
            [SimpleAdapter(in_dim, hidden_dim, out_dim) for _ in range(self.num_experts)]
        )

        # D. Shared Expert (same architecture)
        self.shared_expert = SimpleAdapter(in_dim, hidden_dim, out_dim)

        # E. Initialize weights for stable training
        self._init_weights()

        self.last_aux_loss = torch.tensor(0.0)

    def _init_weights(self):
        """Initialize weights for stable training start."""
        with torch.no_grad():
            # Router: small weights -> uniform probability
            nn.init.normal_(self.router.weight, mean=0.0, std=0.02)

            # Experts: xavier for fc1, small for fc2 (output)
            for expert in [self.shared_expert, *self.experts]:
                nn.init.xavier_uniform_(expert.fc1.weight)
                nn.init.normal_(expert.fc2.weight, mean=0.0, std=0.01)  # Small init

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length (matches MLP projector)."""
        return (input_length - self.k) // self.k + 1

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project audio features using shared + sparse MoE.

        Args:
            x: Audio encoder output of shape [batch, seq_len, encoder_dim]

        Returns:
            Projected features of shape [batch, out_len, llm_dim]
        """
        # 1. Frame Stacking
        batch, seq, dim = x.shape
        out_len = (seq - self.k) // self.k + 1
        x = x[:, : out_len * self.k, :]
        x = x.reshape(batch, out_len, dim * self.k)

        # 2. Normalize stacked input (like main branch SharedMoEBlock)
        x = self.norm(x)
        flat_x = x.view(-1, x.size(-1))  # [tokens, in_dim]

        # 3. Shared Expert (compute first, creates output tensor)
        output = self.shared_expert(flat_x)

        # 4. Sparse Experts (in-place add to shared output)
        self.last_aux_loss = self._forward_sparse(flat_x, output)

        return output.view(batch, out_len, -1)

    def _forward_sparse(self, x: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
        """Stability-hardened sparse expert dispatch (in-place add to output).

        Args:
            x: Flattened input of shape [tokens, dim]
            output: Output tensor to add sparse expert results into (in-place)

        Returns:
            Auxiliary loss tensor
        """
        # A. Router Logic with Jitter
        logits = self.router(x)

        if self.training and self.router_jitter_noise > 0:
            # Jitter: multiply by uniform noise (1-eps, 1+eps) to shake decision boundary
            # Prevents router from getting stuck on one expert early in training
            noise = torch.empty_like(logits).uniform_(
                1.0 - self.router_jitter_noise, 1.0 + self.router_jitter_noise
            )
            logits = logits * noise

        # Force float32 for softmax (bf16/fp16 exponentials can overflow)
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(x)

        # B. Top-K Selection
        top_k_weights, top_k_indices = torch.topk(probs, self.top_k, dim=-1)

        # Normalize weights so they sum to 1.0
        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)

        # C. Aux Loss + Z-Loss
        aux_loss = torch.tensor(0.0, device=x.device)

        if self.training:
            # Load balancing loss (batch-size invariant)
            prob_per_expert = probs.mean(0)  # [num_experts]
            target = 1.0 / self.num_experts
            balance_loss = (
                self.aux_coef * ((prob_per_expert - target) ** 2).mean() * self.num_experts
            )

            # Z-loss: penalty on large logits to prevent softmax saturation
            z_loss = self.router_z_loss_coef * torch.logsumexp(logits, dim=-1).pow(2).mean()

            aux_loss = balance_loss + z_loss

        # D. Dispatch Loop (in-place add to output)
        for i, expert in enumerate(self.experts):
            # Create boolean mask for tokens that selected Expert 'i'
            mask = top_k_indices == i

            if mask.any():
                # token_idx = which tokens, k_idx = 1st or 2nd choice
                token_idx, k_idx = torch.where(mask)

                # Gather inputs and compute
                expert_input = x[token_idx]
                expert_output = expert(expert_input)

                # Apply routing weight
                weight = top_k_weights[token_idx, k_idx].unsqueeze(-1)
                weighted_output = (expert_output * weight).type_as(output)

                # Scatter back in-place (index_add_ is atomic and deterministic)
                output.index_add_(0, token_idx, weighted_output)

        return aux_loss

    def get_aux_loss(self) -> torch.Tensor:
        """Return auxiliary load balancing loss."""
        return self.last_aux_loss


# =============================================================================
# QFormer Projector (Granite-style)
# =============================================================================


class QFormerAudioProjector(nn.Module):
    """
    BLIP-2 QFormer projector with learnable queries.

    Based on GraniteSpeechEncoderProjector - uses a QFormer model with learnable
    query embeddings to compress and project audio encoder outputs. The audio
    sequence is processed in windows and downsampled via cross-attention.
    """

    def __init__(self, config):
        """Initialize QFormer projector.

        Args:
            config: ASRConfig with encoder_dim, llm_dim, qformer_* settings
        """
        super().__init__()

        encoder_dim = config.encoder_dim
        llm_dim = config.llm_dim

        # Window and downsampling parameters (Granite defaults: window=15, downsample=5)
        self.window_size = getattr(config, "qformer_window_size", 15)
        self.downsample_rate = getattr(config, "downsample_rate", 5)
        self.num_queries = self.window_size // self.downsample_rate

        # QFormer hidden size (matches encoder for cross-attention)
        qformer_hidden = getattr(config, "qformer_hidden_size", None) or encoder_dim
        qformer_num_layers = getattr(config, "qformer_num_layers", 2)
        qformer_num_heads = getattr(config, "qformer_num_heads", 16)
        qformer_intermediate = getattr(config, "qformer_intermediate_size", None) or (
            qformer_hidden * 4
        )

        # Learnable query embeddings (Granite uses std=1.0)
        self.query = nn.Parameter(torch.zeros(1, self.num_queries, qformer_hidden))
        self.query.data.normal_(mean=0.0, std=1.0)

        # Optional projection if encoder dim != qformer hidden
        if encoder_dim != qformer_hidden:
            self.encoder_proj = nn.Linear(encoder_dim, qformer_hidden, bias=False)
        else:
            self.encoder_proj = None

        # Configure QFormer to match Granite's exact config
        qformer_config = Blip2QFormerConfig(
            hidden_size=qformer_hidden,
            num_hidden_layers=qformer_num_layers,
            num_attention_heads=qformer_num_heads,
            intermediate_size=qformer_intermediate,
            encoder_hidden_size=qformer_hidden,
            cross_attention_frequency=1,
            # Granite-specific settings
            hidden_act="gelu",
            attention_probs_dropout_prob=0.1,
            hidden_dropout_prob=0.1,
            layer_norm_eps=1e-12,
            initializer_range=0.02,
        )
        self.qformer = AutoModel.from_config(qformer_config)

        # Final projection to LLM dimension (Granite uses bias=True)
        self.linear = nn.Linear(qformer_hidden, llm_dim)

    def get_output_length(self, input_length: int) -> int:
        """Calculate output sequence length given input length."""
        # QFormer uses window-based processing with num_queries per window
        nblocks = math.ceil(input_length / self.window_size)
        return nblocks * self.num_queries

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: [batch_size, seq_len, encoder_dim]

        Returns:
            projected: [batch_size, num_output_tokens, llm_dim]
        """
        batch_size, seq_len, dim = hidden_states.size()

        # Ensure float dtype for QFormer
        target_dtype = self.query.dtype
        if hidden_states.dtype != target_dtype:
            hidden_states = hidden_states.to(target_dtype)

        # Optional encoder projection
        if self.encoder_proj is not None:
            hidden_states = self.encoder_proj(hidden_states)

        # Compute number of windows and pad to fit
        nblocks = math.ceil(seq_len / self.window_size)
        pad = nblocks * self.window_size - seq_len
        if pad > 0:
            hidden_states = F.pad(hidden_states, (0, 0, 0, pad), "constant", 0)

        # Reshape to process each window: [batch*nblocks, window_size, dim]
        effective_batch = batch_size * nblocks
        hidden_states = hidden_states.view(effective_batch, self.window_size, -1)

        # Expand queries to match batch size
        query_embeds = self.query.expand(effective_batch, -1, -1)

        # QFormer cross-attention
        query_output = self.qformer(
            query_embeds=query_embeds,
            encoder_hidden_states=hidden_states,
            return_dict=True,
        )

        # Reshape back: [batch, nblocks * num_queries, hidden]
        output_tokens = nblocks * self.num_queries
        query_proj = query_output.last_hidden_state.view(batch_size, output_tokens, -1)

        # Project to LLM dimension
        return self.linear(query_proj)


# =============================================================================
# Projector Registry
# =============================================================================

PROJECTOR_CLASSES = {
    "mlp": MLPAudioProjector,
    "mosa": MOSAProjector,
    "moe": MoEAudioProjector,
    "qformer": QFormerAudioProjector,
}
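
A minimal sketch of how a projector from the registry above might be instantiated and exercised. The SimpleNamespace config and the dimensions are illustrative stand-ins for the repo's ASRConfig, transformers must be installed for LlamaRMSNorm, and the file is assumed importable as projectors:

import torch
from types import SimpleNamespace

from projectors import PROJECTOR_CLASSES

# Illustrative dimensions; in this repo they would come from asr_config.ASRConfig
cfg = SimpleNamespace(
    encoder_dim=1280,
    llm_dim=2048,
    projector_pool_stride=4,
    projector_hidden_dim=None,
)
projector = PROJECTOR_CLASSES["mlp"](cfg)

feats = torch.randn(2, 1500, 1280)  # [batch, encoder frames, encoder_dim]
out = projector(feats)
print(out.shape, projector.get_output_length(1500))  # torch.Size([2, 375, 2048]) 375
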
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33b674fb8444e2553eae8f1b261093371920a28ef75b5c18f4deb3f9217ed0ba
size 11422834
tokenizer_config.json
ADDED
Binary file (396 Bytes)