mazesmazes committed
Commit 5fe7953 · verified · 1 parent: e425e43

Training in progress - step 1000

Files changed (3)
  1. .gitattributes +1 -0
  2. asr_modeling.py +7 -3
  3. projectors.py +161 -140
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 tokenizer_config.json -filter -diff -merge text
+tokenizer.json filter=lfs diff=lfs merge=lfs -text
asr_modeling.py CHANGED
@@ -145,10 +145,12 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         self.generation_config.length_penalty = config.length_penalty
         self.generation_config.repetition_penalty = config.repetition_penalty
         self.generation_config.no_repeat_ngram_size = config.no_repeat_ngram_size
-        self.generation_config.eos_token_id = [
+        # Set EOS tokens, filtering out any that don't exist in the tokenizer
+        eos_candidates = [
             self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
             self.tokenizer.convert_tokens_to_ids("<|endoftext|>"),
         ]
+        self.generation_config.eos_token_id = [t for t in eos_candidates if t is not None]
         self.generation_config.pad_token_id = self.tokenizer.pad_token_id
 
         # Feature extractor for audio preprocessing
@@ -233,7 +235,6 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         decoder_kwargs = {
             "attn_implementation": config.attn_implementation,
             "trust_remote_code": True,
-            "tie_word_embeddings": False,
             "low_cpu_mem_usage": True,
             "dtype": dtype,
         }
@@ -419,7 +420,10 @@ class ASRModel(PreTrainedModel, GenerationMixin):
         # Compute per-sample encoder output lengths using conv formulas
        encoder_lengths = self._compute_encoder_output_lengths(audio_attention_mask)
         token_counts = torch.tensor(
-            [self.projector.get_output_length(int(length.item())) for length in encoder_lengths],
+            [
+                self.projector.get_output_length(int(length.item()))
+                for length in encoder_lengths
+            ],
             device=audio_embeds.device,
         )
 
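The new EOS handling only keeps ids the tokenizer actually knows about. A minimal sketch of the same pattern outside the model class, assuming a fast tokenizer (fast tokenizers return None from convert_tokens_to_ids for tokens missing from the vocabulary; the checkpoint name below is only an example):

from transformers import AutoTokenizer

# Example checkpoint; any chat-style tokenizer exposing <|im_end|> / <|endoftext|>
# behaves the same way under this pattern.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Same pattern as the commit: collect candidate EOS ids, then drop the missing ones.
eos_candidates = [
    tokenizer.convert_tokens_to_ids("<|im_end|>"),
    tokenizer.convert_tokens_to_ids("<|endoftext|>"),
]
eos_token_id = [t for t in eos_candidates if t is not None]
print(eos_token_id)  # only ids that exist in this vocabulary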
projectors.py CHANGED
@@ -36,12 +36,13 @@ class MLPAudioProjector(nn.Module):
         self.k = getattr(config, "projector_pool_stride", 4)
 
         # Frame stacking: concat k adjacent frames then project
-        # Hidden dim uses 2x expansion like GLM-ASR's GlmAsrMultiModalProjector
         in_dim = encoder_dim * self.k
-        hidden_dim = llm_dim * 2
-        self.linear_1 = nn.Linear(in_dim, hidden_dim)
+        # Hidden dim defaults to llm_dim, can be overridden via config
+        hidden_dim = getattr(config, "projector_hidden_dim", None) or llm_dim
+        self.linear_1 = nn.Linear(in_dim, hidden_dim, bias=False)
+        self.norm = LlamaRMSNorm(hidden_dim, eps=1e-6)
         self.act = nn.GELU()
-        self.linear_2 = nn.Linear(hidden_dim, llm_dim)
+        self.linear_2 = nn.Linear(hidden_dim, llm_dim, bias=False)
 
     def get_output_length(self, input_length: int) -> int:
         """Calculate output sequence length given input length (matches GLM-ASR)."""
@@ -65,6 +66,7 @@
         x = x.reshape(batch, out_len, dim * self.k)
 
         x = self.linear_1(x)
+        x = self.norm(x)
         x = self.act(x)
         return self.linear_2(x)
 
@@ -87,6 +89,34 @@ class SimpleAdapter(nn.Module):
         return self.fc2(self.act(self.fc1(x)))
 
 
+class SwiGLU(nn.Module):
+    """SwiGLU activation with gated linear units (used in LLaMA, Mistral, etc.)."""
+
+    def __init__(self, dim: int, hidden_dim: int, bias: bool = False):
+        super().__init__()
+        self.w1 = nn.Linear(dim, hidden_dim, bias=bias)  # Gate
+        self.w2 = nn.Linear(dim, hidden_dim, bias=bias)  # Value
+        self.w3 = nn.Linear(hidden_dim, dim, bias=bias)  # Output
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.w3(F.silu(self.w1(x)) * self.w2(x))
+
+
+class AsymmetricSwiGLU(nn.Module):
+    """SwiGLU that handles different input and output dimensions."""
+
+    def __init__(
+        self, in_features: int, hidden_features: int, out_features: int, bias: bool = False
+    ):
+        super().__init__()
+        self.w1 = nn.Linear(in_features, hidden_features, bias=bias)  # Gate
+        self.w2 = nn.Linear(in_features, hidden_features, bias=bias)  # Value
+        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)  # Output
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.w3(F.silu(self.w1(x)) * self.w2(x))
+
+
 class MOSAProjector(nn.Module):
     """MOSA-Base projector: simple 2-layer ReLU router with 4 simple adapters.
 
@@ -166,109 +196,18 @@ class MOSAProjector(nn.Module):
 
 
 # =============================================================================
-# MoE Projector (Shared Expert + Sparse Routed Experts)
+# MoE Projector (Pure PyTorch with Shared Expert)
 # =============================================================================
 
 
-class SharedMoEBlock(nn.Module):
-    """MoE block with Shared + Sigmoid-Routed Experts."""
-
-    def __init__(
-        self,
-        input_dim: int,
-        hidden_dim: int,
-        output_dim: int,
-        num_experts: int = 4,
-        top_k: int = 2,
-    ):
-        super().__init__()
-        self.num_experts = num_experts
-        self.top_k = top_k
-        self.output_dim = output_dim
-
-        # RMSNorm before routing
-        self.norm = LlamaRMSNorm(input_dim, eps=1e-8)
-
-        self.router = nn.Linear(input_dim, num_experts, bias=False)
-        nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
-
-        self.shared_expert = SimpleAdapter(input_dim, hidden_dim, output_dim)
-        self.experts = nn.ModuleList(
-            [SimpleAdapter(input_dim, hidden_dim, output_dim) for _ in range(num_experts)]
-        )
-
-        self.last_router_logits = None
-        self.last_router_probs = None
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        batch_size, seq_len, dim = hidden_states.shape
-
-        # 1. Apply Shared Expert
-        normed_states = self.norm(hidden_states)
-        shared_out = self.shared_expert(normed_states)
-
-        # 2. Router Logic (Sigmoid Style)
-        flat_hidden = normed_states.view(-1, dim)
-        router_logits = self.router(flat_hidden)
-
-        # Sigmoid routing
-        router_probs = torch.sigmoid(router_logits)
-
-        self.last_router_logits = router_logits
-        self.last_router_probs = router_probs
-
-        # 3. Top-K Selection
-        top_k_scores, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)
-
-        # Normalize weights
-        top_k_weights = top_k_scores / (top_k_scores.sum(dim=-1, keepdim=True) + 1e-6)
-        top_k_weights = top_k_weights.to(hidden_states.dtype)
-
-        # 4. Dispatch
-        routed_out = self._dispatch_experts(flat_hidden, top_k_indices, top_k_weights)
-        routed_out = routed_out.view(batch_size, seq_len, -1)
-
-        return shared_out + routed_out
-
-    def _dispatch_experts(
-        self,
-        hidden_states: torch.Tensor,
-        top_k_indices: torch.Tensor,
-        top_k_weights: torch.Tensor,
-    ) -> torch.Tensor:
-        num_tokens = hidden_states.shape[0]
-        output = torch.zeros(
-            num_tokens, self.output_dim, device=hidden_states.device, dtype=hidden_states.dtype
-        )
-
-        for expert_idx, expert in enumerate(self.experts):
-            expert_mask = top_k_indices == expert_idx
-            if not expert_mask.any():
-                continue
-
-            token_indices, slot_indices = torch.where(expert_mask)
-            expert_input = hidden_states[token_indices]
-            expert_output = expert(expert_input).to(output.dtype)
-            weights = top_k_weights[token_indices, slot_indices].unsqueeze(-1)
-            output.index_add_(0, token_indices, expert_output * weights)
-
-        return output
-
-
-def load_balancing_loss(router_probs: torch.Tensor, num_experts: int, top_k: int) -> torch.Tensor:
-    """Auxiliary loss to encourage balanced expert usage."""
-    prob_per_expert = router_probs.mean(dim=0)
-    target_mean = prob_per_expert.mean()
-    return (prob_per_expert - target_mean).square().sum() * num_experts
-
-
-def z_loss(router_logits: torch.Tensor) -> torch.Tensor:
-    """Z-loss to prevent router logits from growing too large."""
-    return torch.logsumexp(router_logits.float(), dim=-1).square().mean()
-
-
 class MoEAudioProjector(nn.Module):
-    """MoE projector with shared expert + sparse routed experts."""
+    """MoE projector with shared expert (DeepSeek-style), pure PyTorch implementation.
+
+    Uses 4 sparse experts with top-2 routing plus a shared expert that processes all tokens.
+    No external dependencies (megablocks removed).
+
+    Architecture matches main branch: norm → experts(in_dim → hidden → out_dim)
+    """
 
     def __init__(self, config):
         """Initialize MoE projector.
@@ -279,40 +218,59 @@ class MoEAudioProjector(nn.Module):
         super().__init__()
 
         self.k = getattr(config, "projector_pool_stride", 4)
-        encoder_dim = config.encoder_dim
+        self.aux_coef = getattr(config, "router_aux_loss_coef", 0.01)
 
-        # Depthwise Conv for temporal mixing
-        self.temporal_conv = nn.Conv1d(
-            encoder_dim, encoder_dim, kernel_size=3, padding=1, groups=encoder_dim
-        )
+        # Stability coefficients
+        self.router_z_loss_coef = getattr(
+            config, "router_z_loss_coef", 1e-4
+        )  # Prevents logit explosion
+        self.router_jitter_noise = getattr(
+            config, "router_jitter_noise", 0.01
+        )  # Prevents expert collapse
 
-        in_dim = encoder_dim * self.k
+        in_dim = config.encoder_dim * self.k
         out_dim = config.llm_dim
-        hidden_dim = getattr(config, "projector_hidden_dim", None) or in_dim
 
+        # Expert hidden dim (default = output dim)
+        hidden_dim = getattr(config, "projector_hidden_dim", None) or out_dim
+
+        # Number of experts and top-k selection
         self.num_experts = getattr(config, "num_experts", 4)
         self.top_k = getattr(config, "num_experts_per_tok", 2)
-        self.aux_loss_coef = getattr(config, "router_aux_loss_coef", 0.02)
-        self.z_loss_coef = getattr(config, "router_z_loss_coef", 0.001)
 
-        self.moe = SharedMoEBlock(in_dim, hidden_dim, out_dim, self.num_experts, self.top_k)
+        # A. Normalize stacked input (like main branch SharedMoEBlock)
+        self.norm = LlamaRMSNorm(in_dim, eps=1e-6)
+
+        # B. Router (operates on stacked input)
+        self.router = nn.Linear(in_dim, self.num_experts, bias=False)
+
+        # C. Experts: simple 2-layer MLP (same as MLPAudioProjector)
+        self.experts = nn.ModuleList(
+            [SimpleAdapter(in_dim, hidden_dim, out_dim) for _ in range(self.num_experts)]
+        )
+
+        # D. Shared Expert (same architecture)
+        self.shared_expert = SimpleAdapter(in_dim, hidden_dim, out_dim)
+
+        # E. Initialize weights for stable training
         self._init_weights()
 
+        self.last_aux_loss = torch.tensor(0.0)
+
     def _init_weights(self):
+        """Initialize weights for stable training start."""
         with torch.no_grad():
-            nn.init.orthogonal_(self.moe.shared_expert.fc1.weight)
-            nn.init.orthogonal_(self.moe.shared_expert.fc2.weight, gain=0.5)
+            # Router: small weights -> uniform probability
+            nn.init.normal_(self.router.weight, mean=0.0, std=0.02)
 
-            for expert in self.moe.experts:
-                nn.init.orthogonal_(expert.fc1.weight)
-                nn.init.orthogonal_(expert.fc2.weight, gain=0.01)
+            # Experts: xavier for fc1, small for fc2 (output)
+            for expert in [self.shared_expert, *self.experts]:
+                nn.init.xavier_uniform_(expert.fc1.weight)
+                nn.init.normal_(expert.fc2.weight, mean=0.0, std=0.01)  # Small init
 
     def get_output_length(self, input_length: int) -> int:
-        """Calculate output sequence length given input length."""
-        # Temporal pooling with stride k
-        if input_length % self.k:
-            input_length += self.k - input_length % self.k
-        return input_length // self.k
+        """Calculate output sequence length given input length (matches MLP projector)."""
+        return (input_length - self.k) // self.k + 1
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Project audio features using shared + sparse MoE.
@@ -323,32 +281,95 @@
         Returns:
             Projected features of shape [batch, out_len, llm_dim]
         """
-        batch_size, seq_len, dim = x.size()
-
-        target_dtype = self.moe.shared_expert.fc1.weight.dtype
-        if x.dtype != target_dtype:
-            x = x.to(target_dtype)
-
-        # Temporal Context Injection
-        x_ctx = x.transpose(1, 2)
-        x_ctx = self.temporal_conv(x_ctx)
-        x = x + x_ctx.transpose(1, 2)
-
-        if seq_len % self.k:
-            x = F.pad(x, (0, 0, 0, self.k - seq_len % self.k))
-
-        x = x.view(batch_size, -1, dim * self.k)
-
-        return self.moe(x)
-
-    def get_aux_loss(self) -> torch.Tensor:
-        if self.moe.last_router_logits is None:
-            return torch.tensor(0.0, device=self.moe.router.weight.device)
-
-        balance = load_balancing_loss(self.moe.last_router_probs, self.num_experts, self.top_k)
-        z = z_loss(self.moe.last_router_logits)
-
-        return self.aux_loss_coef * balance + self.z_loss_coef * z
+        # 1. Frame Stacking
+        batch, seq, dim = x.shape
+        out_len = (seq - self.k) // self.k + 1
+        x = x[:, : out_len * self.k, :]
+        x = x.reshape(batch, out_len, dim * self.k)
+
+        # 2. Normalize stacked input (like main branch SharedMoEBlock)
+        x = self.norm(x)
+        flat_x = x.view(-1, x.size(-1))  # [tokens, in_dim]
+
+        # 3. Shared Expert (compute first, creates output tensor)
+        output = self.shared_expert(flat_x)
+
+        # 4. Sparse Experts (in-place add to shared output)
+        self.last_aux_loss = self._forward_sparse(flat_x, output)
+
+        return output.view(batch, out_len, -1)
+
+    def _forward_sparse(self, x: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
+        """Stability-hardened sparse expert dispatch (in-place add to output).
+
+        Args:
+            x: Flattened input of shape [tokens, dim]
+            output: Output tensor to add sparse expert results into (in-place)
+
+        Returns:
+            Auxiliary loss tensor
+        """
+        # A. Router Logic with Jitter
+        logits = self.router(x)
+
+        if self.training and self.router_jitter_noise > 0:
+            # Jitter: multiply by uniform noise (1-eps, 1+eps) to shake decision boundary
+            # Prevents router from getting stuck on one expert early in training
+            noise = torch.empty_like(logits).uniform_(
+                1.0 - self.router_jitter_noise, 1.0 + self.router_jitter_noise
+            )
+            logits = logits * noise
+
+        # Force float32 for softmax (bf16/fp16 exponentials can overflow)
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(x)
+
+        # B. Top-K Selection
+        top_k_weights, top_k_indices = torch.topk(probs, self.top_k, dim=-1)
+
+        # Normalize weights so they sum to 1.0
+        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-6)
+
+        # C. Aux Loss + Z-Loss
+        aux_loss = torch.tensor(0.0, device=x.device)
+
+        if self.training:
+            # Load balancing loss (batch-size invariant)
+            prob_per_expert = probs.mean(0)  # [num_experts]
+            target = 1.0 / self.num_experts
+            balance_loss = (
+                self.aux_coef * ((prob_per_expert - target) ** 2).mean() * self.num_experts
+            )
+
+            # Z-loss: penalty on large logits to prevent softmax saturation
+            z_loss = self.router_z_loss_coef * torch.logsumexp(logits, dim=-1).pow(2).mean()
+
+            aux_loss = balance_loss + z_loss
+
+        # D. Dispatch Loop (in-place add to output)
+        for i, expert in enumerate(self.experts):
+            # Create boolean mask for tokens that selected Expert 'i'
+            mask = top_k_indices == i
+
+            if mask.any():
+                # token_idx = which tokens, k_idx = 1st or 2nd choice
+                token_idx, k_idx = torch.where(mask)
+
+                # Gather inputs and compute
+                expert_input = x[token_idx]
+                expert_output = expert(expert_input)
+
+                # Apply routing weight
+                weight = top_k_weights[token_idx, k_idx].unsqueeze(-1)
+                weighted_output = (expert_output * weight).type_as(output)
+
+                # Scatter back in-place (index_add_ is atomic and deterministic)
+                output.index_add_(0, token_idx, weighted_output)
+
+        return aux_loss
+
+    def get_aux_loss(self) -> torch.Tensor:
+        """Return auxiliary load balancing loss."""
+        return self.last_aux_loss
 
 
  # =============================================================================
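
For orientation, a minimal usage sketch of the new MoEAudioProjector, assuming projectors.py is importable from the repository and that a plain SimpleNamespace can stand in for the model config; the dimensions below are illustrative and not taken from this checkpoint:

from types import SimpleNamespace

import torch

from projectors import MoEAudioProjector

# Illustrative config: the attribute names follow the getattr() calls above,
# the concrete values are made up for this example.
config = SimpleNamespace(
    encoder_dim=1280,           # audio encoder hidden size
    llm_dim=2048,               # decoder hidden size
    projector_pool_stride=4,    # frames stacked per output token
    num_experts=4,
    num_experts_per_tok=2,
    projector_hidden_dim=None,  # falls back to llm_dim inside the projector
    router_aux_loss_coef=0.01,
    router_z_loss_coef=1e-4,
    router_jitter_noise=0.01,
)

projector = MoEAudioProjector(config).train()

# [batch, frames, encoder_dim] -> [batch, (frames - k) // k + 1, llm_dim]
audio_features = torch.randn(2, 100, config.encoder_dim)
projected = projector(audio_features)

print(projected.shape)                    # torch.Size([2, 25, 2048])
print(projector.get_output_length(100))   # 25
print(projector.get_aux_loss())           # add this to the training loss

With projector_pool_stride=4, 100 encoder frames collapse to (100 - 4) // 4 + 1 = 25 output tokens, which matches get_output_length and the token_counts computation in asr_modeling.py above.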