Firworks committed on
Commit aadbb6b · verified · 1 Parent(s): d802356

Add NVFP4 quantized checkpoint

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,23 @@
+ ---
+ datasets:
+ - Rombo-Org/Optimized_Reasoning
+ base_model:
+ - stepfun-ai/Step-Audio-R1
+ ---
+ # Step-Audio-R1-nvfp4
+
+ - **Format:** NVFP4 (weights and activations quantized to FP4 with dual scaling).
+ - **Base model:** `stepfun-ai/Step-Audio-R1`
+ - **How it was made:** One-shot quantization with LLM Compressor (NVFP4 recipe), calibrated on long sequences from Rombo-Org/Optimized_Reasoning.
+
+ > Notes: Keep `lm_head` in high precision; calibrate on long, domain-relevant sequences.
+
+ Check the original model card for more information about this model.
+
+ # Running the model with vLLM in Docker
+ ```sh
+ sudo docker run --runtime nvidia --gpus all -p 8000:8000 --ipc=host vllm/vllm-openai:nightly --model Firworks/Step-Audio-R1-nvfp4 --dtype auto --max-model-len 32768
+ ```
+ This was tested on a B200 cloud instance.
+
+ If there are other models you'd like to see quantized to NVFP4 for use on the DGX Spark or other Blackwell (or newer) cards, let me know. I'm trying to make more NVFP4 models available so more people can try them out.
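For reference, the container above exposes vLLM's OpenAI-compatible API on port 8000. Below is a minimal sketch of a chat request against it; the `openai` Python package and the localhost endpoint are assumptions, and the model name must match the `--model` argument passed to the server.

```python
# Minimal sketch: query the vLLM OpenAI-compatible server started by the Docker command above.
# Assumes `pip install openai` and that the container is reachable at localhost:8000.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # vLLM ignores the key by default

response = client.chat.completions.create(
    model="Firworks/Step-Audio-R1-nvfp4",
    messages=[{"role": "user", "content": "Briefly explain NVFP4 quantization."}],
    max_tokens=512,
)
print(response.choices[0].message.content)
```

Any other OpenAI-compatible client (curl against `/v1/chat/completions`, for example) should work the same way.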
added_tokens.json ADDED
The diff for this file is too large to render. See raw diff
 
chat_template.jinja ADDED
@@ -0,0 +1,74 @@
+ {%- if tools %}
+ {{- '<|BOT|>system
+ ' }}
+ {%- if messages[0]['role'] == 'system' %}
+ {{- messages[0]['content'] + '<|EOT|>' }}
+ {%- else %}
+ {{- 'You are a helpful assistant. Please think step by step and provide your reasoning process within <think> </think> tags, followed by your final answer. Format: <think>your reasoning here</think>your final answer<|EOT|>' }}
+ {%- endif %}
+ {{- '<|BOT|>' }}
+ {{- "tool_json_schemas
+ " }}
+ {{- tools | tojson }}
+ {{- '<|EOT|>' }}
+ {%- else %}
+ {%- if messages[0]['role'] == 'system' %}
+ {{- '<|BOT|>system
+ ' + messages[0]['content'] + '<|EOT|>' }}
+ {%- else %}
+ {{- '<|BOT|>system
+ You are a helpful assistant. Please think step by step and provide your reasoning process within <think> </think> tags, followed by your final answer. Format: <think>your reasoning here</think>your final answer<|EOT|>' }}
+ {%- endif %}
+ {%- endif %}
+ {%- for message in messages %}
+ {%- if message["role"] == "user" %}
+ {{- '<|BOT|>human
+ ' + message["content"] + '<|EOT|>' }}
+ {%- elif (message["role"] == "system" and not loop.first) or (message["role"] == "assistant" and not message["tool_calls"]) %}
+ {{- '<|BOT|>' + message["role"] + '
+ ' + message["content"] + '<|EOT|>' }}
+ {%- elif message["role"] == "assistant" %}
+ {{- '<|BOT|>' + message["role"] + '
+ ' }}
+ {%- if message["content"] %}
+ {{- message["content"] }}
+ {%- endif %}
+ {%- for tool_call in message.tool_calls %}
+ {%- if tool_call["function"] is defined %}
+ {%- set tool_call = tool_call["function"] %}
+ {%- endif %}
+ {{- '<|CALL_START|>' + 'function
+ ' + tool_call["name"] + '
+ ' }}
+ {{- tool_call["arguments"] | tojson }}
+ {{- '<|CALL_END|>' }}
+ {%- endfor %}
+ {{- '<|EOT|>' }}
+ {%- elif message["role"] == "tool" %}
+ {{- '<|BOT|>' }}
+ {%- set ns = namespace(function_name="tool") %}
+ {%- if message["tool_call_id"] %}
+ {%- for prev_msg in messages %}
+ {%- if prev_msg["role"] == "assistant" and prev_msg["tool_calls"] %}
+ {%- for tool_call in prev_msg["tool_calls"] %}
+ {%- if tool_call["id"] == message["tool_call_id"] %}
+ {%- if tool_call["function"] is defined %}
+ {%- set ns.function_name = tool_call["function"]["name"] %}
+ {%- endif %}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {%- endfor %}
+ {%- endif %}
+ {{- 'function_output
+ ' + ns.function_name + '
+ ' }}
+ {{- message["content"] }}
+ {{- '<|EOT|>' }}
+ {%- endif %}
+ {%- endfor %}
+ {%- if add_generation_prompt %}
+ {{- '<|BOT|>assistant
+ <think>
+ ' }}
+ {%- endif %}
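For reference, a minimal sketch of how the template above is typically rendered through `transformers` (assuming a recent release that loads `chat_template.jinja` alongside the tokenizer; the example messages are illustrative):

```python
# Minimal sketch: render the chat template shipped with this checkpoint.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Firworks/Step-Audio-R1-nvfp4", trust_remote_code=True)

messages = [
    {"role": "system", "content": "You are a concise assistant."},
    {"role": "user", "content": "What is 17 * 23?"},
]

# add_generation_prompt=True appends the final '<|BOT|>assistant\n<think>\n' block of the
# template, so generation starts inside the model's reasoning section.
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)
```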
config.json ADDED
@@ -0,0 +1,358 @@
1
+ {
2
+ "architectures": [
3
+ "StepAudio2ForCausalLM"
4
+ ],
5
+ "audio_encoder_config": {
6
+ "adapter_stride": 2,
7
+ "kernel_size": 3,
8
+ "llm_dim": 5120,
9
+ "model_type": "step_audio_2_encoder",
10
+ "n_audio_ctx": 1500,
11
+ "n_audio_head": 20,
12
+ "n_audio_layer": 32,
13
+ "n_audio_state": 1280,
14
+ "n_codebook_size": 4096,
15
+ "n_mels": 128
16
+ },
17
+ "auto_map": {
18
+ "AutoConfig": "configuration_step_audio_2.StepAudio2Config",
19
+ "AutoModelForCausalLM": "modeling_step_audio_2.StepAudio2ForCausalLM"
20
+ },
21
+ "dtype": "bfloat16",
22
+ "max_window_layers": null,
23
+ "model_type": "step_audio_2",
24
+ "quantization_config": {
25
+ "config_groups": {
26
+ "group_0": {
27
+ "format": "nvfp4-pack-quantized",
28
+ "input_activations": {
29
+ "actorder": null,
30
+ "block_structure": null,
31
+ "dynamic": "local",
32
+ "group_size": 16,
33
+ "num_bits": 4,
34
+ "observer": "minmax",
35
+ "observer_kwargs": {},
36
+ "strategy": "tensor_group",
37
+ "symmetric": true,
38
+ "type": "float"
39
+ },
40
+ "output_activations": null,
41
+ "targets": [
42
+ "Linear"
43
+ ],
44
+ "weights": {
45
+ "actorder": null,
46
+ "block_structure": null,
47
+ "dynamic": false,
48
+ "group_size": 16,
49
+ "num_bits": 4,
50
+ "observer": "minmax",
51
+ "observer_kwargs": {},
52
+ "strategy": "tensor_group",
53
+ "symmetric": true,
54
+ "type": "float"
55
+ }
56
+ }
57
+ },
58
+ "format": "nvfp4-pack-quantized",
59
+ "global_compression_ratio": null,
60
+ "ignore": [
61
+ "encoder.blocks.0.attn.query",
62
+ "encoder.blocks.0.attn.key",
63
+ "encoder.blocks.0.attn.value",
64
+ "encoder.blocks.0.attn.out",
65
+ "encoder.blocks.0.mlp.0",
66
+ "encoder.blocks.0.mlp.2",
67
+ "encoder.blocks.1.attn.query",
68
+ "encoder.blocks.1.attn.key",
69
+ "encoder.blocks.1.attn.value",
70
+ "encoder.blocks.1.attn.out",
71
+ "encoder.blocks.1.mlp.0",
72
+ "encoder.blocks.1.mlp.2",
73
+ "encoder.blocks.2.attn.query",
74
+ "encoder.blocks.2.attn.key",
75
+ "encoder.blocks.2.attn.value",
76
+ "encoder.blocks.2.attn.out",
77
+ "encoder.blocks.2.mlp.0",
78
+ "encoder.blocks.2.mlp.2",
79
+ "encoder.blocks.3.attn.query",
80
+ "encoder.blocks.3.attn.key",
81
+ "encoder.blocks.3.attn.value",
82
+ "encoder.blocks.3.attn.out",
83
+ "encoder.blocks.3.mlp.0",
84
+ "encoder.blocks.3.mlp.2",
85
+ "encoder.blocks.4.attn.query",
86
+ "encoder.blocks.4.attn.key",
87
+ "encoder.blocks.4.attn.value",
88
+ "encoder.blocks.4.attn.out",
89
+ "encoder.blocks.4.mlp.0",
90
+ "encoder.blocks.4.mlp.2",
91
+ "encoder.blocks.5.attn.query",
92
+ "encoder.blocks.5.attn.key",
93
+ "encoder.blocks.5.attn.value",
94
+ "encoder.blocks.5.attn.out",
95
+ "encoder.blocks.5.mlp.0",
96
+ "encoder.blocks.5.mlp.2",
97
+ "encoder.blocks.6.attn.query",
98
+ "encoder.blocks.6.attn.key",
99
+ "encoder.blocks.6.attn.value",
100
+ "encoder.blocks.6.attn.out",
101
+ "encoder.blocks.6.mlp.0",
102
+ "encoder.blocks.6.mlp.2",
103
+ "encoder.blocks.7.attn.query",
104
+ "encoder.blocks.7.attn.key",
105
+ "encoder.blocks.7.attn.value",
106
+ "encoder.blocks.7.attn.out",
107
+ "encoder.blocks.7.mlp.0",
108
+ "encoder.blocks.7.mlp.2",
109
+ "encoder.blocks.8.attn.query",
110
+ "encoder.blocks.8.attn.key",
111
+ "encoder.blocks.8.attn.value",
112
+ "encoder.blocks.8.attn.out",
113
+ "encoder.blocks.8.mlp.0",
114
+ "encoder.blocks.8.mlp.2",
115
+ "encoder.blocks.9.attn.query",
116
+ "encoder.blocks.9.attn.key",
117
+ "encoder.blocks.9.attn.value",
118
+ "encoder.blocks.9.attn.out",
119
+ "encoder.blocks.9.mlp.0",
120
+ "encoder.blocks.9.mlp.2",
121
+ "encoder.blocks.10.attn.query",
122
+ "encoder.blocks.10.attn.key",
123
+ "encoder.blocks.10.attn.value",
124
+ "encoder.blocks.10.attn.out",
125
+ "encoder.blocks.10.mlp.0",
126
+ "encoder.blocks.10.mlp.2",
127
+ "encoder.blocks.11.attn.query",
128
+ "encoder.blocks.11.attn.key",
129
+ "encoder.blocks.11.attn.value",
130
+ "encoder.blocks.11.attn.out",
131
+ "encoder.blocks.11.mlp.0",
132
+ "encoder.blocks.11.mlp.2",
133
+ "encoder.blocks.12.attn.query",
134
+ "encoder.blocks.12.attn.key",
135
+ "encoder.blocks.12.attn.value",
136
+ "encoder.blocks.12.attn.out",
137
+ "encoder.blocks.12.mlp.0",
138
+ "encoder.blocks.12.mlp.2",
139
+ "encoder.blocks.13.attn.query",
140
+ "encoder.blocks.13.attn.key",
141
+ "encoder.blocks.13.attn.value",
142
+ "encoder.blocks.13.attn.out",
143
+ "encoder.blocks.13.mlp.0",
144
+ "encoder.blocks.13.mlp.2",
145
+ "encoder.blocks.14.attn.query",
146
+ "encoder.blocks.14.attn.key",
147
+ "encoder.blocks.14.attn.value",
148
+ "encoder.blocks.14.attn.out",
149
+ "encoder.blocks.14.mlp.0",
150
+ "encoder.blocks.14.mlp.2",
151
+ "encoder.blocks.15.attn.query",
152
+ "encoder.blocks.15.attn.key",
153
+ "encoder.blocks.15.attn.value",
154
+ "encoder.blocks.15.attn.out",
155
+ "encoder.blocks.15.mlp.0",
156
+ "encoder.blocks.15.mlp.2",
157
+ "encoder.blocks.16.attn.query",
158
+ "encoder.blocks.16.attn.key",
159
+ "encoder.blocks.16.attn.value",
160
+ "encoder.blocks.16.attn.out",
161
+ "encoder.blocks.16.mlp.0",
162
+ "encoder.blocks.16.mlp.2",
163
+ "encoder.blocks.17.attn.query",
164
+ "encoder.blocks.17.attn.key",
165
+ "encoder.blocks.17.attn.value",
166
+ "encoder.blocks.17.attn.out",
167
+ "encoder.blocks.17.mlp.0",
168
+ "encoder.blocks.17.mlp.2",
169
+ "encoder.blocks.18.attn.query",
170
+ "encoder.blocks.18.attn.key",
171
+ "encoder.blocks.18.attn.value",
172
+ "encoder.blocks.18.attn.out",
173
+ "encoder.blocks.18.mlp.0",
174
+ "encoder.blocks.18.mlp.2",
175
+ "encoder.blocks.19.attn.query",
176
+ "encoder.blocks.19.attn.key",
177
+ "encoder.blocks.19.attn.value",
178
+ "encoder.blocks.19.attn.out",
179
+ "encoder.blocks.19.mlp.0",
180
+ "encoder.blocks.19.mlp.2",
181
+ "encoder.blocks.20.attn.query",
182
+ "encoder.blocks.20.attn.key",
183
+ "encoder.blocks.20.attn.value",
184
+ "encoder.blocks.20.attn.out",
185
+ "encoder.blocks.20.mlp.0",
186
+ "encoder.blocks.20.mlp.2",
187
+ "encoder.blocks.21.attn.query",
188
+ "encoder.blocks.21.attn.key",
189
+ "encoder.blocks.21.attn.value",
190
+ "encoder.blocks.21.attn.out",
191
+ "encoder.blocks.21.mlp.0",
192
+ "encoder.blocks.21.mlp.2",
193
+ "encoder.blocks.22.attn.query",
194
+ "encoder.blocks.22.attn.key",
195
+ "encoder.blocks.22.attn.value",
196
+ "encoder.blocks.22.attn.out",
197
+ "encoder.blocks.22.mlp.0",
198
+ "encoder.blocks.22.mlp.2",
199
+ "encoder.blocks.23.attn.query",
200
+ "encoder.blocks.23.attn.key",
201
+ "encoder.blocks.23.attn.value",
202
+ "encoder.blocks.23.attn.out",
203
+ "encoder.blocks.23.mlp.0",
204
+ "encoder.blocks.23.mlp.2",
205
+ "encoder.blocks.24.attn.query",
206
+ "encoder.blocks.24.attn.key",
207
+ "encoder.blocks.24.attn.value",
208
+ "encoder.blocks.24.attn.out",
209
+ "encoder.blocks.24.mlp.0",
210
+ "encoder.blocks.24.mlp.2",
211
+ "encoder.blocks.25.attn.query",
212
+ "encoder.blocks.25.attn.key",
213
+ "encoder.blocks.25.attn.value",
214
+ "encoder.blocks.25.attn.out",
215
+ "encoder.blocks.25.mlp.0",
216
+ "encoder.blocks.25.mlp.2",
217
+ "encoder.blocks.26.attn.query",
218
+ "encoder.blocks.26.attn.key",
219
+ "encoder.blocks.26.attn.value",
220
+ "encoder.blocks.26.attn.out",
221
+ "encoder.blocks.26.mlp.0",
222
+ "encoder.blocks.26.mlp.2",
223
+ "encoder.blocks.27.attn.query",
224
+ "encoder.blocks.27.attn.key",
225
+ "encoder.blocks.27.attn.value",
226
+ "encoder.blocks.27.attn.out",
227
+ "encoder.blocks.27.mlp.0",
228
+ "encoder.blocks.27.mlp.2",
229
+ "encoder.blocks.28.attn.query",
230
+ "encoder.blocks.28.attn.key",
231
+ "encoder.blocks.28.attn.value",
232
+ "encoder.blocks.28.attn.out",
233
+ "encoder.blocks.28.mlp.0",
234
+ "encoder.blocks.28.mlp.2",
235
+ "encoder.blocks.29.attn.query",
236
+ "encoder.blocks.29.attn.key",
237
+ "encoder.blocks.29.attn.value",
238
+ "encoder.blocks.29.attn.out",
239
+ "encoder.blocks.29.mlp.0",
240
+ "encoder.blocks.29.mlp.2",
241
+ "encoder.blocks.30.attn.query",
242
+ "encoder.blocks.30.attn.key",
243
+ "encoder.blocks.30.attn.value",
244
+ "encoder.blocks.30.attn.out",
245
+ "encoder.blocks.30.mlp.0",
246
+ "encoder.blocks.30.mlp.2",
247
+ "encoder.blocks.31.attn.query",
248
+ "encoder.blocks.31.attn.key",
249
+ "encoder.blocks.31.attn.value",
250
+ "encoder.blocks.31.attn.out",
251
+ "encoder.blocks.31.mlp.0",
252
+ "encoder.blocks.31.mlp.2",
253
+ "adapter.linear1",
254
+ "adapter.linear2",
255
+ "lm_head"
256
+ ],
257
+ "kv_cache_scheme": null,
258
+ "quant_method": "compressed-tensors",
259
+ "quantization_status": "compressed",
260
+ "sparsity_config": {},
261
+ "transform_config": {},
262
+ "version": "0.12.2"
263
+ },
264
+ "sliding_window": 2048,
265
+ "text_config": {
266
+ "architectures": [
267
+ "Qwen2ForCausalLM"
268
+ ],
269
+ "attention_dropout": 0.0,
270
+ "dtype": "bfloat16",
271
+ "hidden_act": "silu",
272
+ "hidden_size": 5120,
273
+ "initializer_range": 0.02,
274
+ "intermediate_size": 27648,
275
+ "layer_types": [
276
+ "full_attention",
277
+ "full_attention",
278
+ "full_attention",
279
+ "full_attention",
280
+ "full_attention",
281
+ "full_attention",
282
+ "full_attention",
283
+ "full_attention",
284
+ "full_attention",
285
+ "full_attention",
286
+ "full_attention",
287
+ "full_attention",
288
+ "full_attention",
289
+ "full_attention",
290
+ "full_attention",
291
+ "full_attention",
292
+ "full_attention",
293
+ "full_attention",
294
+ "full_attention",
295
+ "full_attention",
296
+ "full_attention",
297
+ "full_attention",
298
+ "full_attention",
299
+ "full_attention",
300
+ "full_attention",
301
+ "full_attention",
302
+ "full_attention",
303
+ "full_attention",
304
+ "full_attention",
305
+ "full_attention",
306
+ "full_attention",
307
+ "full_attention",
308
+ "full_attention",
309
+ "full_attention",
310
+ "full_attention",
311
+ "full_attention",
312
+ "full_attention",
313
+ "full_attention",
314
+ "full_attention",
315
+ "full_attention",
316
+ "full_attention",
317
+ "full_attention",
318
+ "full_attention",
319
+ "full_attention",
320
+ "full_attention",
321
+ "full_attention",
322
+ "full_attention",
323
+ "full_attention",
324
+ "full_attention",
325
+ "full_attention",
326
+ "full_attention",
327
+ "full_attention",
328
+ "full_attention",
329
+ "full_attention",
330
+ "full_attention",
331
+ "full_attention",
332
+ "full_attention",
333
+ "full_attention",
334
+ "full_attention",
335
+ "full_attention",
336
+ "full_attention",
337
+ "full_attention",
338
+ "full_attention",
339
+ "full_attention"
340
+ ],
341
+ "max_position_embeddings": 65536,
342
+ "max_window_layers": 28,
343
+ "model_type": "qwen2",
344
+ "num_attention_heads": 40,
345
+ "num_hidden_layers": 64,
346
+ "num_key_value_heads": 8,
347
+ "rms_norm_eps": 1e-05,
348
+ "rope_scaling": null,
349
+ "rope_theta": 1000000.0,
350
+ "sliding_window": null,
351
+ "use_cache": true,
352
+ "use_sliding_window": false,
353
+ "vocab_size": 158720
354
+ },
355
+ "tie_word_embeddings": false,
356
+ "transformers_version": "4.56.2",
357
+ "use_sliding_window": false
358
+ }
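For reference, the `quantization_config` above describes a compressed-tensors NVFP4 scheme: FP4 weights and activations in groups of 16, with the audio encoder, adapter, and `lm_head` kept in high precision via the `ignore` list. A minimal sketch for inspecting those settings without downloading the weights (the `huggingface_hub` dependency is an assumption):

```python
# Minimal sketch: inspect the NVFP4 quantization settings straight from config.json.
import json

from huggingface_hub import hf_hub_download

path = hf_hub_download("Firworks/Step-Audio-R1-nvfp4", "config.json")
with open(path) as f:
    cfg = json.load(f)

qcfg = cfg["quantization_config"]
group = qcfg["config_groups"]["group_0"]
print(qcfg["format"])                                                # nvfp4-pack-quantized
print(group["weights"]["num_bits"], group["weights"]["group_size"])  # 4 16
print(len(qcfg["ignore"]), "modules kept unquantized")               # encoder, adapter and lm_head entries
```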
configuration_step_audio_2.py ADDED
@@ -0,0 +1,128 @@
+ from typing import Optional, Union
+
+ from transformers import Qwen2Config
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class StepAudio2EncoderConfig(PretrainedConfig):
+     model_type = "step_audio_2_encoder"
+
+     def __init__(
+         self,
+         n_mels=128,
+         n_audio_ctx=1500,
+         n_audio_state=512,
+         n_audio_head=8,
+         n_audio_layer=6,
+         llm_dim=4096,
+         kernel_size=3,
+         adapter_stride=2,
+         **kwargs,
+     ):
+         self.n_mels = n_mels
+         self.n_audio_ctx = n_audio_ctx
+         self.n_audio_state = n_audio_state
+         self.n_audio_head = n_audio_head
+         self.n_audio_layer = n_audio_layer
+         self.llm_dim = llm_dim
+         self.kernel_size = kernel_size
+         self.adapter_stride = adapter_stride
+         super().__init__(**kwargs)
+
+ class StepAudio2TextConfig(PretrainedConfig):
+     model_type = "step_audio_2_text"
+
+     def __init__(
+         self,
+         vocab_size=64012,
+         hidden_size=4096,
+         intermediate_size=11008,
+         num_hidden_layers=48,
+         num_attention_heads=32,
+         num_attention_groups=4,
+         num_key_value_heads=4,
+         hidden_act="silu",
+         max_position_embeddings=8192,
+         initializer_range=0.02,
+         rms_norm_eps=1e-6,
+         rope_theta=1000000.0,
+         rope_scaling=None,
+         eos_token_id=None,
+         **kwargs
+     ):
+
+         if eos_token_id is not None:
+             if isinstance(eos_token_id, list):
+                 eos_token_id = list(set([151643, 151645, 151665] + eos_token_id))
+             else:
+                 eos_token_id = [151643, 151645, 151665, eos_token_id]
+         else:
+             eos_token_id = [151643, 151645, 151665]
+
+         super().__init__(
+             eos_token_id=eos_token_id,
+             **kwargs)
+
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_attention_groups = num_attention_groups
+         self.num_key_value_heads = num_key_value_heads
+         assert self.num_attention_groups == self.num_key_value_heads, "num_attention_groups must be equal to num_key_value_heads"
+         self.hidden_act = hidden_act
+         self.max_position_embeddings = max_position_embeddings
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+
+         self.text_config = Qwen2Config(
+             vocab_size=vocab_size,
+             hidden_size=hidden_size,
+             intermediate_size=intermediate_size,
+             num_hidden_layers=num_hidden_layers,
+             num_attention_heads=num_attention_heads,
+             num_key_value_heads=num_key_value_heads,
+             hidden_act=hidden_act,
+             max_position_embeddings=max_position_embeddings,
+             initializer_range=initializer_range,
+             rms_norm_eps=rms_norm_eps,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             architectures=["Qwen2ForCausalLM"],
+             torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
+         )
+
+ class StepAudio2Config(PretrainedConfig):
+     model_type = "step_audio_2"
+     architectures = ["StepAudio2ForCausalLM"]
+
+     def __init__(
+         self,
+         audio_encoder_config: Optional[Union[dict, StepAudio2EncoderConfig]] = None,
+         text_config: Optional[Union[dict, StepAudio2TextConfig]] = None,
+         use_sliding_window: bool = False,
+         sliding_window: Optional[int] = 2048,
+         max_window_layers: Optional[int] = None,
+         **kwargs
+     ):
+         kwargs.setdefault("use_sliding_window", use_sliding_window)
+         kwargs.setdefault("sliding_window", sliding_window)
+         if max_window_layers is None:
+             max_window_layers = kwargs.get("num_hidden_layers", None)
+         kwargs.setdefault("max_window_layers", max_window_layers)
+         super().__init__(**kwargs)
+
+         if text_config is None:
+             text_config = StepAudio2TextConfig().text_config
+         elif isinstance(text_config, dict):
+             text_config = StepAudio2TextConfig(**text_config).text_config
+
+         self.text_config = text_config
+
+         if audio_encoder_config is None:
+             self.audio_encoder_config = StepAudio2EncoderConfig()
+         elif isinstance(audio_encoder_config, dict):
+             self.audio_encoder_config = StepAudio2EncoderConfig(**audio_encoder_config)
generation_config.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "_from_model_config": true,
+   "do_sample": true,
+   "transformers_version": "4.56.2"
+ }
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
model-00001-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9a47a8da85e0da3bc31c230badb7663bae60309d731cb849ba15959dc422d68e
+ size 4952380248
model-00002-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:95ef95ec980fec815184d6c9f2b2e95661be9e1e063a1a45466e908952803bbc
+ size 4937521480
model-00003-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:093b77b6b160d2476ca545eb757dbd0ee1d4b554fa152a833b4f60a9526bd26d
+ size 4937521480
model-00004-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:33090fc271d96985256ae7998d71d93c46756ad16d63c037602529b155428689
+ size 4997834160
model-00005-of-00005.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2d812abb83c44af00a386341fc563d9c15cda01c08e1851ff1c904eb09d793a8
+ size 2291022848
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_step_audio_2.py ADDED
@@ -0,0 +1,425 @@
1
+ from typing import Iterable, Optional, Tuple
2
+
3
+ import librosa
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import torchaudio
7
+ from torch import Tensor, nn
8
+ from transformers import PreTrainedModel, Qwen2Model
9
+ from transformers.generation.utils import GenerationMixin
10
+ from transformers.modeling_outputs import CausalLMOutputWithPast
11
+
12
+ from .configuration_step_audio_2 import StepAudio2Config
13
+
14
+
15
+ def _mel_filters(n_mels: int) -> torch.Tensor:
16
+ """Load the mel filterbank matrix for projecting STFT into a Mel spectrogram."""
17
+ assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
18
+ if n_mels == 128:
19
+ return torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=128))
20
+ else:
21
+ return torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=80))
22
+
23
+
24
+ def load_audio(file_path, target_rate=16000, max_length=None):
25
+ """
26
+ Open an audio file and read as mono waveform, resampling as necessary
27
+ If max_length is provided, truncate the audio to that length
28
+ """
29
+ waveform, sample_rate = torchaudio.load(file_path)
30
+ if sample_rate != target_rate:
31
+ waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)(waveform)
32
+ audio = waveform[0] # get the first channel
33
+
34
+ # Truncate audio if it exceeds max_length
35
+ if max_length is not None and audio.shape[0] > max_length:
36
+ audio = audio[:max_length]
37
+
38
+ return audio
39
+
40
+ def log_mel_spectrogram(audio, n_mels=128, padding=479, device=None):
41
+ """
42
+ Compute the log-Mel spectrogram with specific padding for StepAudio
43
+ """
44
+ if not torch.is_tensor(audio):
45
+ if isinstance(audio, str):
46
+ audio = load_audio(audio)
47
+ audio = torch.from_numpy(audio)
48
+ if device is not None:
49
+ audio = audio.to(device)
50
+ if padding > 0:
51
+ audio = F.pad(audio, (0, padding))
52
+ window = torch.hann_window(400).to(audio.device)
53
+ stft = torch.stft(audio, 400, 160, window=window, return_complex=True)
54
+ magnitudes = stft[..., :-1].abs() ** 2
55
+ filters = _mel_filters(n_mels)
56
+ mel_spec = filters @ magnitudes
57
+
58
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
59
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
60
+ log_spec = (log_spec + 4.0) / 4.0
61
+ return log_spec
62
+
63
+ def compute_token_num(max_feature_len):
64
+ # First, audio goes through encoder:
65
+ # 1. conv1: kernel=3, stride=1, padding=1 -> size unchanged
66
+ # 2. conv2: kernel=3, stride=2, padding=1 -> size/2
67
+ # 3. avg_pooler: kernel=2, stride=2 -> size/2
68
+ max_feature_len = max_feature_len - 2 # remove padding
69
+ encoder_output_dim = (max_feature_len + 1) // 2 // 2 # after conv2 and avg_pooler
70
+
71
+ # Then through adaptor (parameters from config file):
72
+ padding = 1
73
+ kernel_size = 3 # from config: audio_encoder_config.kernel_size
74
+ stride = 2 # from config: audio_encoder_config.adapter_stride
75
+ adapter_output_dim = (encoder_output_dim + 2 * padding - kernel_size) // stride + 1
76
+ return adapter_output_dim
77
+
78
+ def make_non_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
79
+ """Make mask tensor containing indices of non-padded part.
80
+
81
+ The sequences in a batch may have different lengths. To enable
82
+ batch computing, padding is need to make all sequence in same
83
+ size. To avoid the padding part pass value to context dependent
84
+ block such as attention or convolution , this padding part is
85
+ masked.
86
+
87
+ 1 for non-padded part and 0 for padded part.
88
+
89
+ Parameters
90
+ ----------
91
+ lengths (torch.Tensor): Batch of lengths (B,).
92
+
93
+ Returns:
94
+ -------
95
+ torch.Tensor: Mask tensor containing indices of padded part (B, max_T).
96
+
97
+ Examples:
98
+ >>> import torch
99
+ >>> import s3tokenizer
100
+ >>> lengths = torch.tensor([5, 3, 2])
101
+ >>> masks = s3tokenizer.make_non_pad_mask(lengths)
102
+ masks = [[1, 1, 1, 1, 1],
103
+ [1, 1, 1, 0, 0],
104
+ [1, 1, 0, 0, 0]]
105
+ """
106
+ batch_size = lengths.size(0)
107
+ max_len = max_len if max_len > 0 else lengths.max().item()
108
+ seq_range = torch.arange(0,
109
+ max_len,
110
+ dtype=torch.int64,
111
+ device=lengths.device)
112
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
113
+ seq_length_expand = lengths.unsqueeze(-1)
114
+ mask = seq_range_expand >= seq_length_expand
115
+ return ~mask
116
+
117
+ def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
118
+ """Convert bool-tensor to float-tensor for flash attention.
119
+
120
+ Parameters
121
+ ----------
122
+ lengths (torch.Tensor): Batch of lengths (B, ?).
123
+
124
+ Returns:
125
+ -------
126
+ torch.Tensor: Mask tensor containing indices of padded part (B, ?).
127
+
128
+ Examples:
129
+ >>> import torch
130
+ >>> import s3tokenizer
131
+ >>> lengths = torch.tensor([5, 3, 2])
132
+ >>> masks = s3tokenizer.make_non_pad_mask(lengths)
133
+ masks = [[1, 1, 1, 1, 1],
134
+ [1, 1, 1, 0, 0],
135
+ [1, 1, 0, 0, 0]]
136
+ >>> new_masks = s3tokenizer.mask_to_bias(masks, torch.float32)
137
+ new_masks = [[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
138
+ [-0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10],
139
+ [-0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10, -1.0000e+10]]
140
+ """
141
+ assert mask.dtype == torch.bool
142
+ assert dtype in [torch.float32, torch.bfloat16, torch.float16]
143
+ mask = mask.to(dtype)
144
+ # attention mask bias
145
+ # NOTE(Mddct): torch.finfo jit issues
146
+ # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
147
+ mask = (1.0 - mask) * -1.0e+10
148
+ return mask
149
+
150
+ class LayerNorm(nn.LayerNorm):
151
+ def forward(self, input: Tensor) -> Tensor:
152
+ return super().forward(input).type(input.dtype)
153
+
154
+ class Linear(nn.Linear):
155
+ def forward(self, input: Tensor) -> Tensor:
156
+ return F.linear(
157
+ input,
158
+ self.weight.to(input.dtype),
159
+ None if self.bias is None else self.bias.to(input.dtype),
160
+ )
161
+
162
+ class Conv1d(nn.Conv1d):
163
+ def _conv_forward(
164
+ self, input: Tensor, weight: Tensor, bias: Optional[Tensor]
165
+ ) -> Tensor:
166
+ return super()._conv_forward(
167
+ input, weight.to(input.dtype), None if bias is None else bias.to(input.dtype)
168
+ )
169
+
170
+ class MultiHeadAttention(nn.Module):
171
+ def __init__(self, n_state: int, n_head: int):
172
+ super().__init__()
173
+ self.n_head = n_head
174
+ self.query = Linear(n_state, n_state)
175
+ self.key = Linear(n_state, n_state, bias=False)
176
+ self.value = Linear(n_state, n_state)
177
+ self.out = Linear(n_state, n_state)
178
+
179
+ def forward(
180
+ self,
181
+ x: Tensor,
182
+ mask: Optional[Tensor] = None,
183
+ ):
184
+ q = self.query(x)
185
+ k = self.key(x)
186
+ v = self.value(x)
187
+
188
+ wv, qk = self.qkv_attention(q, k, v, mask)
189
+ return self.out(wv), qk
190
+
191
+ def qkv_attention(
192
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
193
+ ):
194
+ _, T, D = q.shape
195
+ scale = (D // self.n_head) ** -0.25
196
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
197
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
198
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
199
+
200
+ qk = q @ k # (B, n_head, T, T)
201
+ if mask is not None:
202
+ qk = qk + mask
203
+ qk = qk.float()
204
+
205
+ w = F.softmax(qk, dim=-1).to(q.dtype)
206
+ return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
207
+
208
+ class ResidualAttentionBlock(nn.Module):
209
+ def __init__(self, n_state: int, n_head: int):
210
+ super().__init__()
211
+
212
+ self.attn = MultiHeadAttention(n_state, n_head)
213
+ self.attn_ln = LayerNorm(n_state)
214
+
215
+ n_mlp = n_state * 4
216
+ self.mlp = nn.Sequential(
217
+ Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
218
+ )
219
+ self.mlp_ln = LayerNorm(n_state)
220
+
221
+ def forward(
222
+ self,
223
+ x: Tensor,
224
+ mask: Optional[Tensor] = None,
225
+ ):
226
+ x = x + self.attn(self.attn_ln(x.contiguous()), mask=mask)[0]
227
+ x = x + self.mlp(self.mlp_ln(x.contiguous()))
228
+ return x
229
+
230
+ class AudioEncoder(nn.Module):
231
+ def __init__(
232
+ self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
233
+ ):
234
+ super().__init__()
235
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
236
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
237
+ self.positional_embedding = nn.Embedding(n_ctx, n_state)
238
+ self.positional_embedding.requires_grad_(False)
239
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
240
+ [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
241
+ )
242
+ self.avg_pooler = nn.AvgPool1d(2, stride=2)
243
+ self.after_norm = LayerNorm(n_state)
244
+ self.gradient_checkpointing = False
245
+
246
+ def forward(self, x: Tensor, x_len: Tensor) -> Tuple[Tensor, Tensor]:
247
+ T = x.size(-1)
248
+ x = F.gelu(self.conv1(x))
249
+ x = F.gelu(self.conv2(x))
250
+ x = x.permute(0, 2, 1) # (B, T // 2, n_state)
251
+ mask = make_non_pad_mask(x_len, T).unsqueeze(1) # (B, 1, T)
252
+ mask = mask_to_bias(mask[:, :, (T + 1) % 2::2], x.dtype) # (B, 1, T // 2)
253
+ x = (x + self.positional_embedding.weight[:x.shape[1], :]).to(x.dtype)
254
+ for block in self.blocks:
255
+ if self.gradient_checkpointing and self.training:
256
+ x = torch.utils.checkpoint.checkpoint(block, x, mask.unsqueeze(1))
257
+ else:
258
+ x = block(x, mask.unsqueeze(1))
259
+ x = x.permute(0, 2, 1)
260
+ x = self.avg_pooler(x)
261
+ x = x.permute(0, 2, 1)
262
+ x_len = (x_len + 1) // 2 // 2
263
+ x = self.after_norm(x.contiguous())
264
+ return x, x_len
265
+
266
+ class Adaptor(nn.Module):
267
+ def __init__(
268
+ self,
269
+ n_state: int = 1280,
270
+ n_hidden: int = 3072,
271
+ kernel_size: int = 7,
272
+ stride: int = 4
273
+ ):
274
+ super().__init__()
275
+ self.stride = stride
276
+ if self.stride != -1:
277
+ # print("self.stride: {}".format(self.stride))
278
+ self.conv = Conv1d(n_state, n_state, kernel_size, stride, padding=1)
279
+ self.linear1 = nn.Linear(n_state, 2048)
280
+ self.relu = nn.ReLU()
281
+ self.linear2 = nn.Linear(2048, n_hidden)
282
+ self.gradient_checkpointing = False
283
+
284
+ def forward(self, x: Tensor) -> Tuple[Tensor]:
285
+ T = x.size(-1)
286
+ if self.stride != -1:
287
+ if self.gradient_checkpointing and self.training:
288
+ x = torch.utils.checkpoint.checkpoint(self.conv, x.permute(0, 2, 1))
289
+ x = x.permute(0, 2, 1)
290
+ else:
291
+ x = x.permute(0, 2, 1)
292
+ x = F.gelu(self.conv(x))
293
+ x = x.permute(0, 2, 1)
294
+ if self.gradient_checkpointing and self.training:
295
+ x = torch.utils.checkpoint.checkpoint(self.linear1, x)
296
+ x = torch.utils.checkpoint.checkpoint(self.relu, x)
297
+ x = torch.utils.checkpoint.checkpoint(self.linear2, x)
298
+ else:
299
+ x = self.linear1(x)
300
+ x = self.relu(x)
301
+ x = self.linear2(x)
302
+ return x
303
+
304
+ class StepAudio2ForCausalLM(PreTrainedModel, GenerationMixin):
305
+ config_class = StepAudio2Config
306
+ main_input_name = "input_ids"
307
+ # Important: Add this attribute to make HF recognize it as a model with generation capability
308
+ # _keys_to_ignore_on_load_missing = ["lm_head.weight"]
309
+ supports_gradient_checkpointing = True # 新增,声明支持gradient checkpointing
310
+
311
+ def __init__(self, config: StepAudio2Config):
312
+ super().__init__(config)
313
+ if isinstance(config.torch_dtype, str):
314
+ dtype = getattr(torch, config.torch_dtype)
315
+ else:
316
+ dtype = config.torch_dtype
317
+ self.model = Qwen2Model(config.text_config)
318
+ self.bf16 = dtype==torch.bfloat16
319
+ self.encoder = AudioEncoder(
320
+ config.audio_encoder_config.n_mels, config.audio_encoder_config.n_audio_ctx, config.audio_encoder_config.n_audio_state,
321
+ config.audio_encoder_config.n_audio_head, config.audio_encoder_config.n_audio_layer
322
+ )
323
+ self.adapter = Adaptor(
324
+ config.audio_encoder_config.n_audio_state, config.audio_encoder_config.llm_dim,
325
+ config.audio_encoder_config.kernel_size, config.audio_encoder_config.adapter_stride
326
+ )
327
+ if self.bf16:
328
+ self.encoder = self.encoder.bfloat16()
329
+ self.adapter = self.adapter.bfloat16()
330
+ self.lm_head = torch.nn.Linear(
331
+ config.text_config.hidden_size,
332
+ config.text_config.vocab_size,
333
+ bias=False,
334
+ dtype=dtype
335
+ )
336
+ self.post_init()
337
+
338
+ def forward(
339
+ self,
340
+ input_ids=None,
341
+ wavs=None,
342
+ wav_lens=None,
343
+ attention_mask=None,
344
+ **kwargs
345
+ ):
346
+ hidden_states = self.model.embed_tokens(input_ids)
347
+ if wavs is not None:
348
+ if self.bf16:
349
+ wavs = wavs.bfloat16()
350
+ out, feat_lens = self.encoder(wavs, wav_lens)
351
+ out = self.adapter(out)
352
+ feat_lens = (feat_lens - 1) // 2 + 1
353
+ insert_location = torch.nonzero(input_ids == 151688)
354
+ insert_location[:,1] += 1
355
+ for idx in range(len(insert_location)):
356
+ i,s = insert_location[idx]
357
+ hidden_states[i][s : s+feat_lens[idx]] = out[idx][:feat_lens[idx]]
358
+
359
+ x = self.model(inputs_embeds=hidden_states, attention_mask=attention_mask)[0]
360
+ logits = self.lm_head(x)
361
+ return CausalLMOutputWithPast(
362
+ logits=logits,
363
+ past_key_values=None,
364
+ hidden_states=None,
365
+ attentions=None
366
+ )
367
+
368
+ def get_input_embeddings(self):
369
+ """Return the model's input embeddings - required for GenerationMixin"""
370
+ return self.model.embed_tokens
371
+
372
+ def get_output_embeddings(self):
373
+ """Return the model's output embeddings (LM head) - required for GenerationMixin"""
374
+ return self.lm_head
375
+
376
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
377
+ """Prepare inputs for generation - required for GenerationMixin"""
378
+ # Keep the wavs and wav_lens from the initial call
379
+ wavs = kwargs.get("wavs", None)
380
+ wav_lens = kwargs.get("wav_lens", None)
381
+
382
+ # For generation steps after the first, we don't need to process audio again
383
+ # because the audio tokens have already been replaced in the input sequence
384
+ if "past_key_values" in kwargs and kwargs["past_key_values"] is not None:
385
+ # We're in a generation step, no need to process audio again
386
+ return {
387
+ "input_ids": input_ids,
388
+ "attention_mask": attention_mask,
389
+ "past_key_values": kwargs.get("past_key_values")
390
+ }
391
+
392
+ # First generation step, include audio processing
393
+ return {
394
+ "input_ids": input_ids,
395
+ "attention_mask": attention_mask,
396
+ "wavs": wavs,
397
+ "wav_lens": wav_lens
398
+ }
399
+
400
+ def _reorder_cache(self, past_key_values, beam_idx):
401
+ """Reorder the cache for beam search - required for GenerationMixin if using beam search"""
402
+ # If you're not using past_key_values or beam search, this can be a simple pass-through
403
+ # Otherwise implement according to your model's cache structure
404
+ return past_key_values
405
+
406
+ def _set_gradient_checkpointing(self, module, value=False):
407
+ # For Qwen2Model
408
+ if hasattr(self.model, 'gradient_checkpointing'):
409
+ self.model.gradient_checkpointing = value
410
+
411
+ # Add the missing _gradient_checkpointing_func method to Qwen2Model
412
+ # This is what Qwen2Model tries to use when gradient_checkpointing=True
413
+ if value and not hasattr(self.model, '_gradient_checkpointing_func'):
414
+ def _gradient_checkpointing_func(module_to_run, *args, **kwargs):
415
+ # This function wraps torch.utils.checkpoint.checkpoint
416
+ # and is used by Qwen2Model to perform checkpointing
417
+ return torch.utils.checkpoint.checkpoint(module_to_run, *args, **kwargs)
418
+
419
+ self.model._gradient_checkpointing_func = _gradient_checkpointing_func
420
+
421
+ # For custom encoder and adapter
422
+ if hasattr(self.encoder, 'gradient_checkpointing'):
423
+ self.encoder.gradient_checkpointing = value
424
+ if hasattr(self.adapter, 'gradient_checkpointing'):
425
+ self.adapter.gradient_checkpointing = value
recipe.yaml ADDED
@@ -0,0 +1,6 @@
+ default_stage:
+   default_modifiers:
+     QuantizationModifier:
+       targets: [Linear]
+       ignore: [lm_head, 're:^encoder\.', 're:^adapter\.', 're:^model\.embed_tokens\.', 're:.*layernorm.*']
+       scheme: NVFP4
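For reference, a recipe like the one above is normally applied with LLM Compressor's one-shot flow before the compressed checkpoint is saved. The sketch below is hedged: entry points differ slightly between llmcompressor releases, dataset preprocessing is omitted, and the calibration sizes are illustrative rather than the exact settings used for this checkpoint.

```python
# Hedged sketch of a one-shot NVFP4 run with LLM Compressor (settings illustrative).
from llmcompressor import oneshot
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "stepfun-ai/Step-Audio-R1"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

oneshot(
    model=model,
    dataset="Rombo-Org/Optimized_Reasoning",  # calibration data; preprocessing/tokenization omitted here
    recipe="recipe.yaml",                     # the NVFP4 recipe shown above
    max_seq_length=8192,                      # illustrative: calibrate on long sequences
    num_calibration_samples=512,              # illustrative
)

model.save_pretrained("Step-Audio-R1-nvfp4", save_compressed=True)
tokenizer.save_pretrained("Step-Audio-R1-nvfp4")
```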
special_tokens_map.json ADDED
@@ -0,0 +1,49 @@
+ {
+   "additional_special_tokens": [
+     "<|EOT|>",
+     "<|BOT|>",
+     "<|CALL_START|>",
+     "<|CALL_END|>",
+     "<|THINK_START|>",
+     "<|THINK_END|>",
+     "<|IMG_START|>",
+     "<|IMG_END|>",
+     "<|META_START|>",
+     "<|META_END|>",
+     "<im_patch>",
+     "<im_start>",
+     "<im_end>",
+     "<dream>",
+     "<dream_start>",
+     "<dream_end>",
+     "<|MASK_1e69f|>",
+     "<|UNMASK_1e69f|>",
+     "<video_start>",
+     "<video_end>",
+     "<patch_start>",
+     "<patch_end>",
+     "<patch_newline>",
+     "<audio_start>",
+     "<audio_end>",
+     "<audio_patch>",
+     "<audio_patch_pad>",
+     "<|SC|>",
+     "<tts_start>",
+     "<tts_end>",
+     "<tts_pad>"
+   ],
+   "eos_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": {
+     "content": "<|endoftext|>",
+     "lstrip": false,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c23796e0498b651e92b0d514d43636d0dfd556534f8dde7b72ed0e2ff1d07744
+ size 12684616
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff
 
vocab.json ADDED
The diff for this file is too large to render. See raw diff