---
license: mit
datasets:
- mozilla-foundation/common_voice_17_0
language:
- en
- es
- ar
- fr
- de
- it
- pt
- ru
- zh
- ja
metrics:
- accuracy
base_model:
- hubertsiuzdak/snac_24khz
pipeline_tag: audio-classification
tags:
- audio
- language
- classification
---
# Audio Language Classifier (SNAC backbone, Common Voice 17.0)
First iteration of a lightweight (7M-parameter) model for identifying the language of speech audio. Code is available on [GitHub](https://github.com/surus-lat/audio-language-classification).

In short:
- Identification of spoken language in audio (10 languages)
- Backbone: SNAC (hubertsiuzdak/snac_24khz) with attention pooling
- Dataset used: Mozilla Common Voice 17.0 (streaming)
- Sample rate: 24 kHz; Max audio length: 10 s (pad/trim)
- Mixed precision: FP16
- Best validation accuracy: 0.50 (see Evaluation results below)

Supported languages (labels):
- en, es, fr, de, it, pt, ru, zh-CN, ja, ar

Intended use:
- Classify the language of short speech segments (≤10 s).
- Not for ASR or dialect/variant classification.

Out-of-scope:
- Very long audio, code-switching, overlapping speakers, noisy or music-heavy inputs.

Data:
- Source: Mozilla Common Voice 17.0 (streaming; per-language subset).
- License: CC0 (check the dataset card for details).
- Splits: official validation/test splits used (use_official_splits: true); loaded from the Parquet branch to handle the large split sizes.
- 50% of each split was used during training (see the streaming sketch below).
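
For illustration, a minimal sketch of how such a streaming, per-language load might look with the `datasets` library (the actual loader lives in the GitHub repo; the dataset is gated on the Hub, so authentication may be required, and the interleaving and shuffle parameters here are assumptions):

```python
from datasets import load_dataset, interleave_datasets

LANGS = ["en", "es", "fr", "de", "it", "pt", "ru", "zh-CN", "ja", "ar"]

# Stream each per-language subset so nothing is downloaded up front.
streams = [
    load_dataset(
        "mozilla-foundation/common_voice_17_0",
        lang,
        split="train",
        streaming=True,
    )
    for lang in LANGS
]

# Interleave the language streams into one shuffled training stream.
train_stream = interleave_datasets(streams, stopping_strategy="all_exhausted")
train_stream = train_stream.shuffle(seed=42, buffer_size=1_000)
```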

Model architecture:
- Backbone: SNAC encoder (pretrained).
- Pooling: Attention pooling over time.
- Head:
  - Linear(feature_dim → 512), ReLU, Dropout(0.1)
  - Linear(512 → 256), ReLU, Dropout(0.1)
  - Linear(256 → 10)
- Selective tuning:
  - Start frozen (backbone_tune_strategy: "frozen")
  - Unfreeze strategy at epoch 2: "last_n_blocks" with last_n_blocks: 1
  - Gradient checkpointing enabled for backbone.
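
For concreteness, a minimal sketch of the attention pooling and head described above (`feature_dim` and the module names are assumptions; the actual implementation is in the GitHub repo):

```python
import torch
import torch.nn as nn

class AttentionPooling(nn.Module):
    """Collapse [B, T, D] encoder features to [B, D] with learned attention weights."""
    def __init__(self, dim: int):
        super().__init__()
        self.score = nn.Linear(dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weights = torch.softmax(self.score(x), dim=1)  # [B, T, 1]
        return (weights * x).sum(dim=1)                # [B, D]

class ClassifierHead(nn.Module):
    """The MLP head listed above: feature_dim -> 512 -> 256 -> 10."""
    def __init__(self, feature_dim: int, num_classes: int = 10, p: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(feature_dim, 512), nn.ReLU(), nn.Dropout(p),
            nn.Linear(512, 256), nn.ReLU(), nn.Dropout(p),
            nn.Linear(256, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```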

Training setup:
- Batch size: 48
- Epochs: up to 100 (early stopping patience: 15)
- Streaming steps per epoch: 2000
- Optimizer: AdamW (betas: 0.9, 0.999; eps: 1e-8)
- Learning rate: head 1e-4; backbone 2e-5 (after unfreeze)
- Scheduler: cosine with warmup (num_warmup_steps: 2000)
- Label smoothing: 0.1
- Max grad norm: 1.0
- Seed: 42
- Hardware: 1x RTX 3090; FP16 enabled
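
A sketch of the optimizer/scheduler wiring these settings imply (`model.head` / `model.backbone` and the helper name are illustrative assumptions, not the repo's API):

```python
import math
import torch
from torch.optim.lr_scheduler import LambdaLR

def build_optimizer_and_scheduler(model, total_steps: int, warmup_steps: int = 2000):
    # Per-group learning rates: 1e-4 for the head, 2e-5 for the unfrozen backbone.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.head.parameters(), "lr": 1e-4},
            {"params": model.backbone.parameters(), "lr": 2e-5},
        ],
        betas=(0.9, 0.999),
        eps=1e-8,
    )

    def cosine_with_warmup(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)  # linear warmup
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay to 0

    scheduler = LambdaLR(optimizer, lr_lambda=cosine_with_warmup)
    return optimizer, scheduler

# Loss with label smoothing, and gradient clipping after backward():
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```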

Preprocessing:
- Mono waveform at 24 kHz; pad/trim to 10 s.
- Normalization handled by torchaudio/Tensor transforms in pipeline.
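
A minimal sketch of this pad/trim step with torchaudio (the training pipeline's exact transforms may differ):

```python
import torch
import torch.nn.functional as F
import torchaudio

TARGET_SR = 24_000
MAX_SAMPLES = TARGET_SR * 10  # 10 s

def load_waveform(path: str) -> torch.Tensor:
    wav, sr = torchaudio.load(path)   # [channels, T]
    wav = wav.mean(dim=0)             # downmix to mono -> [T]
    if sr != TARGET_SR:
        wav = torchaudio.functional.resample(wav, sr, TARGET_SR)
    if wav.numel() >= MAX_SAMPLES:
        return wav[:MAX_SAMPLES]                       # trim
    return F.pad(wav, (0, MAX_SAMPLES - wav.numel()))  # pad with zeros
```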

Evaluation results:
- Validation:
  - Best accuracy: 0.5016
- Test:
  - accuracy: 0.3830
  - f1_micro: 0.3830
  - f1_macro: 0.3624
  - f1_weighted: 0.3666
  - loss: 2.2467
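
The F1 variants above correspond to the standard scikit-learn averages; for reference, a sketch of how they can be computed from predictions (illustrative, not the repo's evaluation script):

```python
from sklearn.metrics import accuracy_score, f1_score

def summarize(y_true, y_pred):
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1_micro": f1_score(y_true, y_pred, average="micro"),
        "f1_macro": f1_score(y_true, y_pred, average="macro"),
        "f1_weighted": f1_score(y_true, y_pred, average="weighted"),
    }
```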

Files and checkpoints:
- Checkpoints dir: ./training
  - best_model.pt
  - language_mapping.txt (idx: language)
  - final_results.txt
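
The mapping file is plain text with one `idx: language` pair per line, so it can be read back directly (format assumed from the description above):

```python
def load_language_mapping(path: str = "training/language_mapping.txt") -> dict:
    mapping = {}
    with open(path) as f:
        for line in f:
            idx, lang = line.strip().split(":", 1)
            mapping[int(idx)] = lang.strip()
    return mapping  # e.g. {0: "en", 1: "es", ...}
```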

How to use (inference):
```python
import torch
from models import LanguageClassifier

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build and load from a directory containing best_model.pt and language_mapping.txt
model = LanguageClassifier.from_pretrained("training", device=device)

# Single-file prediction (auto resample to 24k, pad/trim to 10s)
label, prob = model.predict("example.wav", max_length_seconds=10.0, top_k=1)
print(label, prob)

# Top-3
top3 = model.predict("example.wav", top_k=3)
print(top3)  # [('en', 0.62), ('de', 0.21), ('fr', 0.08)]

# If you already have a waveform tensor:
#   wav: torch.Tensor [T] at 24kHz (or provide sample_rate to auto-resample)
#   model.predict handles [T] or [B,T]
# label, prob = model.predict(wav, sample_rate=orig_sr, top_k=1)
```

Limitations and risks:
- Accuracy varies across speakers, accents, microphones, and noise conditions.
- May misclassify short utterances or code-switched speech.
- Not suitable for sensitive decision making without human review.

Reproducibility:
- Default config: ./config.yaml
- Training script: ./train.py
- To visualize internals: `CHECKPOINT_DIR=training/ python visualization.py`

Citation and acknowledgements:
- [SNAC: hubertsiuzdak/snac_24khz](https://github.com/hubertsiuzdak/snac/)
- [Dataset: Mozilla Common Voice 17.0](https://huggingface.co/datasets/mozilla-foundation/common_voice_17_0)