Add NVFP4 quantized checkpoint
Browse files- .gitattributes +1 -0
- README.md +23 -0
- added_tokens.json +0 -0
- chat_template.jinja +74 -0
- config.json +358 -0
- configuration_step_audio_2.py +128 -0
- generation_config.json +5 -0
- merges.txt +0 -0
- model-00001-of-00005.safetensors +3 -0
- model-00002-of-00005.safetensors +3 -0
- model-00003-of-00005.safetensors +3 -0
- model-00004-of-00005.safetensors +3 -0
- model-00005-of-00005.safetensors +3 -0
- model.safetensors.index.json +0 -0
- modeling_step_audio_2.py +425 -0
- recipe.yaml +6 -0
- special_tokens_map.json +49 -0
- tokenizer.json +3 -0
- tokenizer_config.json +0 -0
- vocab.json +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
datasets:
|
| 3 |
+
- Rombo-Org/Optimized_Reasoning
|
| 4 |
+
base_model:
|
| 5 |
+
- stepfun-ai/Step-Audio-R1
|
| 6 |
+
---
|
| 7 |
+
# Step-Audio-R1-nvfp4
|
| 8 |
+
|
| 9 |
+
**Format:** NVFP4 — weights & activations quantized to FP4 with dual scaling.
|
| 10 |
+
**Base model:** `stepfun-ai/Step-Audio-R1`
|
| 11 |
+
**How it was made:** One-shot calibration with LLM Compressor (NVFP4 recipe), long-seq calibration with Rombo-Org/Optimized_Reasoning.
|
| 12 |
+
|
| 13 |
+
> Notes: Keep `lm_head` in high precision; calibrate on long, domain-relevant sequences.
|
| 14 |
+
|
| 15 |
+
Check the original model card for information about this model.
|
| 16 |
+
|
| 17 |
+
# Running the model with VLLM in Docker
|
| 18 |
+
```sh
|
| 19 |
+
sudo docker run --runtime nvidia --gpus all -p 8000:8000 --ipc=host vllm/vllm-openai:nightly --model Firworks/Step-Audio-R1-nvfp4 --dtype auto --max-model-len 32768
|
| 20 |
+
```
|
| 21 |
+
This was tested on a B200 cloud instance.
|
| 22 |
+
|
| 23 |
+
If there are other models you're interested in seeing quantized to NVFP4 for use on the DGX Spark, or other modern Blackwell (or newer) cards let me know. I'm trying to make more NVFP4 models available to allow more people to try them out.
|
added_tokens.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
chat_template.jinja
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{%- if tools %}
|
| 2 |
+
{{- '<|BOT|>system
|
| 3 |
+
' }}
|
| 4 |
+
{%- if messages[0]['role'] == 'system' %}
|
| 5 |
+
{{- messages[0]['content'] + '<|EOT|>' }}
|
| 6 |
+
{%- else %}
|
| 7 |
+
{{- 'You are a helpful assistant. Please think step by step and provide your reasoning process within <think> </think> tags, followed by your final answer. Format: <think>your reasoning here</think>your final answer<|EOT|>' }}
|
| 8 |
+
{%- endif %}
|
| 9 |
+
{{- '<|BOT|>' }}
|
| 10 |
+
{{- "tool_json_schemas
|
| 11 |
+
" }}
|
| 12 |
+
{{- tools | tojson }}
|
| 13 |
+
{{- '<|EOT|>' }}
|
| 14 |
+
{%- else %}
|
| 15 |
+
{%- if messages[0]['role'] == 'system' %}
|
| 16 |
+
{{- '<|BOT|>system
|
| 17 |
+
' + messages[0]['content'] + '<|EOT|>' }}
|
| 18 |
+
{%- else %}
|
| 19 |
+
{{- '<|BOT|>system
|
| 20 |
+
You are a helpful assistant. Please think step by step and provide your reasoning process within <think> </think> tags, followed by your final answer. Format: <think>your reasoning here</think>your final answer<|EOT|>' }}
|
| 21 |
+
{%- endif %}
|
| 22 |
+
{%- endif %}
|
| 23 |
+
{%- for message in messages %}
|
| 24 |
+
{%- if message["role"] == "user" %}
|
| 25 |
+
{{- '<|BOT|>human
|
| 26 |
+
' + message["content"] + '<|EOT|>' }}
|
| 27 |
+
{%- elif (message["role"] == "system" and not loop.first) or (message["role"] == "assistant" and not message["tool_calls"]) %}
|
| 28 |
+
{{- '<|BOT|>' + message["role"] + '
|
| 29 |
+
' + message["content"] + '<|EOT|>' }}
|
| 30 |
+
{%- elif message["role"] == "assistant" %}
|
| 31 |
+
{{- '<|BOT|>' + message["role"] + '
|
| 32 |
+
' }}
|
| 33 |
+
{%- if message["content"] %}
|
| 34 |
+
{{- message["content"] }}
|
| 35 |
+
{%- endif %}
|
| 36 |
+
{%- for tool_call in message.tool_calls %}
|
| 37 |
+
{%- if tool_call["function"] is defined %}
|
| 38 |
+
{%- set tool_call = tool_call["function"] %}
|
| 39 |
+
{%- endif %}
|
| 40 |
+
{{- '<|CALL_START|>' + 'function
|
| 41 |
+
' + tool_call["name"] + '
|
| 42 |
+
' }}
|
| 43 |
+
{{- tool_call["arguments"] | tojson }}
|
| 44 |
+
{{- '<|CALL_END|>' }}
|
| 45 |
+
{%- endfor %}
|
| 46 |
+
{{- '<|EOT|>' }}
|
| 47 |
+
{%- elif message["role"] == "tool" %}
|
| 48 |
+
{{- '<|BOT|>' }}
|
| 49 |
+
{%- set ns = namespace(function_name="tool") %}
|
| 50 |
+
{%- if message["tool_call_id"] %}
|
| 51 |
+
{%- for prev_msg in messages %}
|
| 52 |
+
{%- if prev_msg["role"] == "assistant" and prev_msg["tool_calls"] %}
|
| 53 |
+
{%- for tool_call in prev_msg["tool_calls"] %}
|
| 54 |
+
{%- if tool_call["id"] == message["tool_call_id"] %}
|
| 55 |
+
{%- if tool_call["function"] is defined %}
|
| 56 |
+
{%- set ns.function_name = tool_call["function"]["name"] %}
|
| 57 |
+
{%- endif %}
|
| 58 |
+
{%- endif %}
|
| 59 |
+
{%- endfor %}
|
| 60 |
+
{%- endif %}
|
| 61 |
+
{%- endfor %}
|
| 62 |
+
{%- endif %}
|
| 63 |
+
{{- 'function_output
|
| 64 |
+
' + ns.function_name + '
|
| 65 |
+
' }}
|
| 66 |
+
{{- message["content"] }}
|
| 67 |
+
{{- '<|EOT|>' }}
|
| 68 |
+
{%- endif %}
|
| 69 |
+
{%- endfor %}
|
| 70 |
+
{%- if add_generation_prompt %}
|
| 71 |
+
{{- '<|BOT|>assistant
|
| 72 |
+
<think>
|
| 73 |
+
' }}
|
| 74 |
+
{%- endif %}
|
config.json
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"StepAudio2ForCausalLM"
|
| 4 |
+
],
|
| 5 |
+
"audio_encoder_config": {
|
| 6 |
+
"adapter_stride": 2,
|
| 7 |
+
"kernel_size": 3,
|
| 8 |
+
"llm_dim": 5120,
|
| 9 |
+
"model_type": "step_audio_2_encoder",
|
| 10 |
+
"n_audio_ctx": 1500,
|
| 11 |
+
"n_audio_head": 20,
|
| 12 |
+
"n_audio_layer": 32,
|
| 13 |
+
"n_audio_state": 1280,
|
| 14 |
+
"n_codebook_size": 4096,
|
| 15 |
+
"n_mels": 128
|
| 16 |
+
},
|
| 17 |
+
"auto_map": {
|
| 18 |
+
"AutoConfig": "configuration_step_audio_2.StepAudio2Config",
|
| 19 |
+
"AutoModelForCausalLM": "modeling_step_audio_2.StepAudio2ForCausalLM"
|
| 20 |
+
},
|
| 21 |
+
"dtype": "bfloat16",
|
| 22 |
+
"max_window_layers": null,
|
| 23 |
+
"model_type": "step_audio_2",
|
| 24 |
+
"quantization_config": {
|
| 25 |
+
"config_groups": {
|
| 26 |
+
"group_0": {
|
| 27 |
+
"format": "nvfp4-pack-quantized",
|
| 28 |
+
"input_activations": {
|
| 29 |
+
"actorder": null,
|
| 30 |
+
"block_structure": null,
|
| 31 |
+
"dynamic": "local",
|
| 32 |
+
"group_size": 16,
|
| 33 |
+
"num_bits": 4,
|
| 34 |
+
"observer": "minmax",
|
| 35 |
+
"observer_kwargs": {},
|
| 36 |
+
"strategy": "tensor_group",
|
| 37 |
+
"symmetric": true,
|
| 38 |
+
"type": "float"
|
| 39 |
+
},
|
| 40 |
+
"output_activations": null,
|
| 41 |
+
"targets": [
|
| 42 |
+
"Linear"
|
| 43 |
+
],
|
| 44 |
+
"weights": {
|
| 45 |
+
"actorder": null,
|
| 46 |
+
"block_structure": null,
|
| 47 |
+
"dynamic": false,
|
| 48 |
+
"group_size": 16,
|
| 49 |
+
"num_bits": 4,
|
| 50 |
+
"observer": "minmax",
|
| 51 |
+
"observer_kwargs": {},
|
| 52 |
+
"strategy": "tensor_group",
|
| 53 |
+
"symmetric": true,
|
| 54 |
+
"type": "float"
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
},
|
| 58 |
+
"format": "nvfp4-pack-quantized",
|
| 59 |
+
"global_compression_ratio": null,
|
| 60 |
+
"ignore": [
|
| 61 |
+
"encoder.blocks.0.attn.query",
|
| 62 |
+
"encoder.blocks.0.attn.key",
|
| 63 |
+
"encoder.blocks.0.attn.value",
|
| 64 |
+
"encoder.blocks.0.attn.out",
|
| 65 |
+
"encoder.blocks.0.mlp.0",
|
| 66 |
+
"encoder.blocks.0.mlp.2",
|
| 67 |
+
"encoder.blocks.1.attn.query",
|
| 68 |
+
"encoder.blocks.1.attn.key",
|
| 69 |
+
"encoder.blocks.1.attn.value",
|
| 70 |
+
"encoder.blocks.1.attn.out",
|
| 71 |
+
"encoder.blocks.1.mlp.0",
|
| 72 |
+
"encoder.blocks.1.mlp.2",
|
| 73 |
+
"encoder.blocks.2.attn.query",
|
| 74 |
+
"encoder.blocks.2.attn.key",
|
| 75 |
+
"encoder.blocks.2.attn.value",
|
| 76 |
+
"encoder.blocks.2.attn.out",
|
| 77 |
+
"encoder.blocks.2.mlp.0",
|
| 78 |
+
"encoder.blocks.2.mlp.2",
|
| 79 |
+
"encoder.blocks.3.attn.query",
|
| 80 |
+
"encoder.blocks.3.attn.key",
|
| 81 |
+
"encoder.blocks.3.attn.value",
|
| 82 |
+
"encoder.blocks.3.attn.out",
|
| 83 |
+
"encoder.blocks.3.mlp.0",
|
| 84 |
+
"encoder.blocks.3.mlp.2",
|
| 85 |
+
"encoder.blocks.4.attn.query",
|
| 86 |
+
"encoder.blocks.4.attn.key",
|
| 87 |
+
"encoder.blocks.4.attn.value",
|
| 88 |
+
"encoder.blocks.4.attn.out",
|
| 89 |
+
"encoder.blocks.4.mlp.0",
|
| 90 |
+
"encoder.blocks.4.mlp.2",
|
| 91 |
+
"encoder.blocks.5.attn.query",
|
| 92 |
+
"encoder.blocks.5.attn.key",
|
| 93 |
+
"encoder.blocks.5.attn.value",
|
| 94 |
+
"encoder.blocks.5.attn.out",
|
| 95 |
+
"encoder.blocks.5.mlp.0",
|
| 96 |
+
"encoder.blocks.5.mlp.2",
|
| 97 |
+
"encoder.blocks.6.attn.query",
|
| 98 |
+
"encoder.blocks.6.attn.key",
|
| 99 |
+
"encoder.blocks.6.attn.value",
|
| 100 |
+
"encoder.blocks.6.attn.out",
|
| 101 |
+
"encoder.blocks.6.mlp.0",
|
| 102 |
+
"encoder.blocks.6.mlp.2",
|
| 103 |
+
"encoder.blocks.7.attn.query",
|
| 104 |
+
"encoder.blocks.7.attn.key",
|
| 105 |
+
"encoder.blocks.7.attn.value",
|
| 106 |
+
"encoder.blocks.7.attn.out",
|
| 107 |
+
"encoder.blocks.7.mlp.0",
|
| 108 |
+
"encoder.blocks.7.mlp.2",
|
| 109 |
+
"encoder.blocks.8.attn.query",
|
| 110 |
+
"encoder.blocks.8.attn.key",
|
| 111 |
+
"encoder.blocks.8.attn.value",
|
| 112 |
+
"encoder.blocks.8.attn.out",
|
| 113 |
+
"encoder.blocks.8.mlp.0",
|
| 114 |
+
"encoder.blocks.8.mlp.2",
|
| 115 |
+
"encoder.blocks.9.attn.query",
|
| 116 |
+
"encoder.blocks.9.attn.key",
|
| 117 |
+
"encoder.blocks.9.attn.value",
|
| 118 |
+
"encoder.blocks.9.attn.out",
|
| 119 |
+
"encoder.blocks.9.mlp.0",
|
| 120 |
+
"encoder.blocks.9.mlp.2",
|
| 121 |
+
"encoder.blocks.10.attn.query",
|
| 122 |
+
"encoder.blocks.10.attn.key",
|
| 123 |
+
"encoder.blocks.10.attn.value",
|
| 124 |
+
"encoder.blocks.10.attn.out",
|
| 125 |
+
"encoder.blocks.10.mlp.0",
|
| 126 |
+
"encoder.blocks.10.mlp.2",
|
| 127 |
+
"encoder.blocks.11.attn.query",
|
| 128 |
+
"encoder.blocks.11.attn.key",
|
| 129 |
+
"encoder.blocks.11.attn.value",
|
| 130 |
+
"encoder.blocks.11.attn.out",
|
| 131 |
+
"encoder.blocks.11.mlp.0",
|
| 132 |
+
"encoder.blocks.11.mlp.2",
|
| 133 |
+
"encoder.blocks.12.attn.query",
|
| 134 |
+
"encoder.blocks.12.attn.key",
|
| 135 |
+
"encoder.blocks.12.attn.value",
|
| 136 |
+
"encoder.blocks.12.attn.out",
|
| 137 |
+
"encoder.blocks.12.mlp.0",
|
| 138 |
+
"encoder.blocks.12.mlp.2",
|
| 139 |
+
"encoder.blocks.13.attn.query",
|
| 140 |
+
"encoder.blocks.13.attn.key",
|
| 141 |
+
"encoder.blocks.13.attn.value",
|
| 142 |
+
"encoder.blocks.13.attn.out",
|
| 143 |
+
"encoder.blocks.13.mlp.0",
|
| 144 |
+
"encoder.blocks.13.mlp.2",
|
| 145 |
+
"encoder.blocks.14.attn.query",
|
| 146 |
+
"encoder.blocks.14.attn.key",
|
| 147 |
+
"encoder.blocks.14.attn.value",
|
| 148 |
+
"encoder.blocks.14.attn.out",
|
| 149 |
+
"encoder.blocks.14.mlp.0",
|
| 150 |
+
"encoder.blocks.14.mlp.2",
|
| 151 |
+
"encoder.blocks.15.attn.query",
|
| 152 |
+
"encoder.blocks.15.attn.key",
|
| 153 |
+
"encoder.blocks.15.attn.value",
|
| 154 |
+
"encoder.blocks.15.attn.out",
|
| 155 |
+
"encoder.blocks.15.mlp.0",
|
| 156 |
+
"encoder.blocks.15.mlp.2",
|
| 157 |
+
"encoder.blocks.16.attn.query",
|
| 158 |
+
"encoder.blocks.16.attn.key",
|
| 159 |
+
"encoder.blocks.16.attn.value",
|
| 160 |
+
"encoder.blocks.16.attn.out",
|
| 161 |
+
"encoder.blocks.16.mlp.0",
|
| 162 |
+
"encoder.blocks.16.mlp.2",
|
| 163 |
+
"encoder.blocks.17.attn.query",
|
| 164 |
+
"encoder.blocks.17.attn.key",
|
| 165 |
+
"encoder.blocks.17.attn.value",
|
| 166 |
+
"encoder.blocks.17.attn.out",
|
| 167 |
+
"encoder.blocks.17.mlp.0",
|
| 168 |
+
"encoder.blocks.17.mlp.2",
|
| 169 |
+
"encoder.blocks.18.attn.query",
|
| 170 |
+
"encoder.blocks.18.attn.key",
|
| 171 |
+
"encoder.blocks.18.attn.value",
|
| 172 |
+
"encoder.blocks.18.attn.out",
|
| 173 |
+
"encoder.blocks.18.mlp.0",
|
| 174 |
+
"encoder.blocks.18.mlp.2",
|
| 175 |
+
"encoder.blocks.19.attn.query",
|
| 176 |
+
"encoder.blocks.19.attn.key",
|
| 177 |
+
"encoder.blocks.19.attn.value",
|
| 178 |
+
"encoder.blocks.19.attn.out",
|
| 179 |
+
"encoder.blocks.19.mlp.0",
|
| 180 |
+
"encoder.blocks.19.mlp.2",
|
| 181 |
+
"encoder.blocks.20.attn.query",
|
| 182 |
+
"encoder.blocks.20.attn.key",
|
| 183 |
+
"encoder.blocks.20.attn.value",
|
| 184 |
+
"encoder.blocks.20.attn.out",
|
| 185 |
+
"encoder.blocks.20.mlp.0",
|
| 186 |
+
"encoder.blocks.20.mlp.2",
|
| 187 |
+
"encoder.blocks.21.attn.query",
|
| 188 |
+
"encoder.blocks.21.attn.key",
|
| 189 |
+
"encoder.blocks.21.attn.value",
|
| 190 |
+
"encoder.blocks.21.attn.out",
|
| 191 |
+
"encoder.blocks.21.mlp.0",
|
| 192 |
+
"encoder.blocks.21.mlp.2",
|
| 193 |
+
"encoder.blocks.22.attn.query",
|
| 194 |
+
"encoder.blocks.22.attn.key",
|
| 195 |
+
"encoder.blocks.22.attn.value",
|
| 196 |
+
"encoder.blocks.22.attn.out",
|
| 197 |
+
"encoder.blocks.22.mlp.0",
|
| 198 |
+
"encoder.blocks.22.mlp.2",
|
| 199 |
+
"encoder.blocks.23.attn.query",
|
| 200 |
+
"encoder.blocks.23.attn.key",
|
| 201 |
+
"encoder.blocks.23.attn.value",
|
| 202 |
+
"encoder.blocks.23.attn.out",
|
| 203 |
+
"encoder.blocks.23.mlp.0",
|
| 204 |
+
"encoder.blocks.23.mlp.2",
|
| 205 |
+
"encoder.blocks.24.attn.query",
|
| 206 |
+
"encoder.blocks.24.attn.key",
|
| 207 |
+
"encoder.blocks.24.attn.value",
|
| 208 |
+
"encoder.blocks.24.attn.out",
|
| 209 |
+
"encoder.blocks.24.mlp.0",
|
| 210 |
+
"encoder.blocks.24.mlp.2",
|
| 211 |
+
"encoder.blocks.25.attn.query",
|
| 212 |
+
"encoder.blocks.25.attn.key",
|
| 213 |
+
"encoder.blocks.25.attn.value",
|
| 214 |
+
"encoder.blocks.25.attn.out",
|
| 215 |
+
"encoder.blocks.25.mlp.0",
|
| 216 |
+
"encoder.blocks.25.mlp.2",
|
| 217 |
+
"encoder.blocks.26.attn.query",
|
| 218 |
+
"encoder.blocks.26.attn.key",
|
| 219 |
+
"encoder.blocks.26.attn.value",
|
| 220 |
+
"encoder.blocks.26.attn.out",
|
| 221 |
+
"encoder.blocks.26.mlp.0",
|
| 222 |
+
"encoder.blocks.26.mlp.2",
|
| 223 |
+
"encoder.blocks.27.attn.query",
|
| 224 |
+
"encoder.blocks.27.attn.key",
|
| 225 |
+
"encoder.blocks.27.attn.value",
|
| 226 |
+
"encoder.blocks.27.attn.out",
|
| 227 |
+
"encoder.blocks.27.mlp.0",
|
| 228 |
+
"encoder.blocks.27.mlp.2",
|
| 229 |
+
"encoder.blocks.28.attn.query",
|
| 230 |
+
"encoder.blocks.28.attn.key",
|
| 231 |
+
"encoder.blocks.28.attn.value",
|
| 232 |
+
"encoder.blocks.28.attn.out",
|
| 233 |
+
"encoder.blocks.28.mlp.0",
|
| 234 |
+
"encoder.blocks.28.mlp.2",
|
| 235 |
+
"encoder.blocks.29.attn.query",
|
| 236 |
+
"encoder.blocks.29.attn.key",
|
| 237 |
+
"encoder.blocks.29.attn.value",
|
| 238 |
+
"encoder.blocks.29.attn.out",
|
| 239 |
+
"encoder.blocks.29.mlp.0",
|
| 240 |
+
"encoder.blocks.29.mlp.2",
|
| 241 |
+
"encoder.blocks.30.attn.query",
|
| 242 |
+
"encoder.blocks.30.attn.key",
|
| 243 |
+
"encoder.blocks.30.attn.value",
|
| 244 |
+
"encoder.blocks.30.attn.out",
|
| 245 |
+
"encoder.blocks.30.mlp.0",
|
| 246 |
+
"encoder.blocks.30.mlp.2",
|
| 247 |
+
"encoder.blocks.31.attn.query",
|
| 248 |
+
"encoder.blocks.31.attn.key",
|
| 249 |
+
"encoder.blocks.31.attn.value",
|
| 250 |
+
"encoder.blocks.31.attn.out",
|
| 251 |
+
"encoder.blocks.31.mlp.0",
|
| 252 |
+
"encoder.blocks.31.mlp.2",
|
| 253 |
+
"adapter.linear1",
|
| 254 |
+
"adapter.linear2",
|
| 255 |
+
"lm_head"
|
| 256 |
+
],
|
| 257 |
+
"kv_cache_scheme": null,
|
| 258 |
+
"quant_method": "compressed-tensors",
|
| 259 |
+
"quantization_status": "compressed",
|
| 260 |
+
"sparsity_config": {},
|
| 261 |
+
"transform_config": {},
|
| 262 |
+
"version": "0.12.2"
|
| 263 |
+
},
|
| 264 |
+
"sliding_window": 2048,
|
| 265 |
+
"text_config": {
|
| 266 |
+
"architectures": [
|
| 267 |
+
"Qwen2ForCausalLM"
|
| 268 |
+
],
|
| 269 |
+
"attention_dropout": 0.0,
|
| 270 |
+
"dtype": "bfloat16",
|
| 271 |
+
"hidden_act": "silu",
|
| 272 |
+
"hidden_size": 5120,
|
| 273 |
+
"initializer_range": 0.02,
|
| 274 |
+
"intermediate_size": 27648,
|
| 275 |
+
"layer_types": [
|
| 276 |
+
"full_attention",
|
| 277 |
+
"full_attention",
|
| 278 |
+
"full_attention",
|
| 279 |
+
"full_attention",
|
| 280 |
+
"full_attention",
|
| 281 |
+
"full_attention",
|
| 282 |
+
"full_attention",
|
| 283 |
+
"full_attention",
|
| 284 |
+
"full_attention",
|
| 285 |
+
"full_attention",
|
| 286 |
+
"full_attention",
|
| 287 |
+
"full_attention",
|
| 288 |
+
"full_attention",
|
| 289 |
+
"full_attention",
|
| 290 |
+
"full_attention",
|
| 291 |
+
"full_attention",
|
| 292 |
+
"full_attention",
|
| 293 |
+
"full_attention",
|
| 294 |
+
"full_attention",
|
| 295 |
+
"full_attention",
|
| 296 |
+
"full_attention",
|
| 297 |
+
"full_attention",
|
| 298 |
+
"full_attention",
|
| 299 |
+
"full_attention",
|
| 300 |
+
"full_attention",
|
| 301 |
+
"full_attention",
|
| 302 |
+
"full_attention",
|
| 303 |
+
"full_attention",
|
| 304 |
+
"full_attention",
|
| 305 |
+
"full_attention",
|
| 306 |
+
"full_attention",
|
| 307 |
+
"full_attention",
|
| 308 |
+
"full_attention",
|
| 309 |
+
"full_attention",
|
| 310 |
+
"full_attention",
|
| 311 |
+
"full_attention",
|
| 312 |
+
"full_attention",
|
| 313 |
+
"full_attention",
|
| 314 |
+
"full_attention",
|
| 315 |
+
"full_attention",
|
| 316 |
+
"full_attention",
|
| 317 |
+
"full_attention",
|
| 318 |
+
"full_attention",
|
| 319 |
+
"full_attention",
|
| 320 |
+
"full_attention",
|
| 321 |
+
"full_attention",
|
| 322 |
+
"full_attention",
|
| 323 |
+
"full_attention",
|
| 324 |
+
"full_attention",
|
| 325 |
+
"full_attention",
|
| 326 |
+
"full_attention",
|
| 327 |
+
"full_attention",
|
| 328 |
+
"full_attention",
|
| 329 |
+
"full_attention",
|
| 330 |
+
"full_attention",
|
| 331 |
+
"full_attention",
|
| 332 |
+
"full_attention",
|
| 333 |
+
"full_attention",
|
| 334 |
+
"full_attention",
|
| 335 |
+
"full_attention",
|
| 336 |
+
"full_attention",
|
| 337 |
+
"full_attention",
|
| 338 |
+
"full_attention",
|
| 339 |
+
"full_attention"
|
| 340 |
+
],
|
| 341 |
+
"max_position_embeddings": 65536,
|
| 342 |
+
"max_window_layers": 28,
|
| 343 |
+
"model_type": "qwen2",
|
| 344 |
+
"num_attention_heads": 40,
|
| 345 |
+
"num_hidden_layers": 64,
|
| 346 |
+
"num_key_value_heads": 8,
|
| 347 |
+
"rms_norm_eps": 1e-05,
|
| 348 |
+
"rope_scaling": null,
|
| 349 |
+
"rope_theta": 1000000.0,
|
| 350 |
+
"sliding_window": null,
|
| 351 |
+
"use_cache": true,
|
| 352 |
+
"use_sliding_window": false,
|
| 353 |
+
"vocab_size": 158720
|
| 354 |
+
},
|
| 355 |
+
"tie_word_embeddings": false,
|
| 356 |
+
"transformers_version": "4.56.2",
|
| 357 |
+
"use_sliding_window": false
|
| 358 |
+
}
|
configuration_step_audio_2.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional, Union
|
| 2 |
+
|
| 3 |
+
from transformers import Qwen2Config
|
| 4 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class StepAudio2EncoderConfig(PretrainedConfig):
|
| 8 |
+
model_type = "step_audio_2_encoder"
|
| 9 |
+
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
n_mels=128,
|
| 13 |
+
n_audio_ctx=1500,
|
| 14 |
+
n_audio_state=512,
|
| 15 |
+
n_audio_head=8,
|
| 16 |
+
n_audio_layer=6,
|
| 17 |
+
llm_dim=4096,
|
| 18 |
+
kernel_size=3,
|
| 19 |
+
adapter_stride=2,
|
| 20 |
+
**kwargs,
|
| 21 |
+
):
|
| 22 |
+
self.n_mels = n_mels
|
| 23 |
+
self.n_audio_ctx = n_audio_ctx
|
| 24 |
+
self.n_audio_state = n_audio_state
|
| 25 |
+
self.n_audio_head = n_audio_head
|
| 26 |
+
self.n_audio_layer = n_audio_layer
|
| 27 |
+
self.llm_dim = llm_dim
|
| 28 |
+
self.kernel_size = kernel_size
|
| 29 |
+
self.adapter_stride = adapter_stride
|
| 30 |
+
super().__init__(**kwargs)
|
| 31 |
+
|
| 32 |
+
class StepAudio2TextConfig(PretrainedConfig):
|
| 33 |
+
model_type = "step_audio_2_text"
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
vocab_size=64012,
|
| 38 |
+
hidden_size=4096,
|
| 39 |
+
intermediate_size=11008,
|
| 40 |
+
num_hidden_layers=48,
|
| 41 |
+
num_attention_heads=32,
|
| 42 |
+
num_attention_groups=4,
|
| 43 |
+
num_key_value_heads=4,
|
| 44 |
+
hidden_act="silu",
|
| 45 |
+
max_position_embeddings=8192,
|
| 46 |
+
initializer_range=0.02,
|
| 47 |
+
rms_norm_eps=1e-6,
|
| 48 |
+
rope_theta=1000000.0,
|
| 49 |
+
rope_scaling=None,
|
| 50 |
+
eos_token_id=None,
|
| 51 |
+
**kwargs
|
| 52 |
+
):
|
| 53 |
+
|
| 54 |
+
if eos_token_id is not None:
|
| 55 |
+
if isinstance(eos_token_id, list):
|
| 56 |
+
eos_token_id = list(set([151643, 151645, 151665] + eos_token_id))
|
| 57 |
+
else:
|
| 58 |
+
eos_token_id = [151643, 151645, 151665, eos_token_id]
|
| 59 |
+
else:
|
| 60 |
+
eos_token_id = [151643, 151645, 151665]
|
| 61 |
+
|
| 62 |
+
super().__init__(
|
| 63 |
+
eos_token_id=eos_token_id,
|
| 64 |
+
**kwargs)
|
| 65 |
+
|
| 66 |
+
self.vocab_size = vocab_size
|
| 67 |
+
self.hidden_size = hidden_size
|
| 68 |
+
self.intermediate_size = intermediate_size
|
| 69 |
+
self.num_hidden_layers = num_hidden_layers
|
| 70 |
+
self.num_attention_heads = num_attention_heads
|
| 71 |
+
self.num_attention_groups = num_attention_groups
|
| 72 |
+
self.num_key_value_heads = num_key_value_heads
|
| 73 |
+
assert self.num_attention_groups == self.num_key_value_heads, "num_attention_groups must be equal to num_key_value_heads"
|
| 74 |
+
self.hidden_act = hidden_act
|
| 75 |
+
self.max_position_embeddings = max_position_embeddings
|
| 76 |
+
self.initializer_range = initializer_range
|
| 77 |
+
self.rms_norm_eps = rms_norm_eps
|
| 78 |
+
self.rope_theta = rope_theta
|
| 79 |
+
self.rope_scaling = rope_scaling
|
| 80 |
+
|
| 81 |
+
self.text_config = Qwen2Config(
|
| 82 |
+
vocab_size=vocab_size,
|
| 83 |
+
hidden_size=hidden_size,
|
| 84 |
+
intermediate_size=intermediate_size,
|
| 85 |
+
num_hidden_layers=num_hidden_layers,
|
| 86 |
+
num_attention_heads=num_attention_heads,
|
| 87 |
+
num_key_value_heads=num_key_value_heads,
|
| 88 |
+
hidden_act=hidden_act,
|
| 89 |
+
max_position_embeddings=max_position_embeddings,
|
| 90 |
+
initializer_range=initializer_range,
|
| 91 |
+
rms_norm_eps=rms_norm_eps,
|
| 92 |
+
rope_theta=rope_theta,
|
| 93 |
+
rope_scaling=rope_scaling,
|
| 94 |
+
architectures=["Qwen2ForCausalLM"],
|
| 95 |
+
torch_dtype=getattr(self, "torch_dtype", "bfloat16"),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
class StepAudio2Config(PretrainedConfig):
|
| 99 |
+
model_type = "step_audio_2"
|
| 100 |
+
architectures = ["StepAudio2ForCausalLM"]
|
| 101 |
+
|
| 102 |
+
def __init__(
|
| 103 |
+
self,
|
| 104 |
+
audio_encoder_config :Optional[Union[dict, StepAudio2EncoderConfig]] = None,
|
| 105 |
+
text_config: Optional[Union[dict, StepAudio2TextConfig]] = None,
|
| 106 |
+
use_sliding_window: bool = False,
|
| 107 |
+
sliding_window: Optional[int] = 2048,
|
| 108 |
+
max_window_layers: Optional[int] = None,
|
| 109 |
+
**kwargs
|
| 110 |
+
):
|
| 111 |
+
kwargs.setdefault("use_sliding_window", use_sliding_window)
|
| 112 |
+
kwargs.setdefault("sliding_window", sliding_window)
|
| 113 |
+
if max_window_layers is None:
|
| 114 |
+
max_window_layers = kwargs.get("num_hidden_layers", None)
|
| 115 |
+
kwargs.setdefault("max_window_layers", max_window_layers)
|
| 116 |
+
super().__init__(**kwargs)
|
| 117 |
+
|
| 118 |
+
if text_config is None:
|
| 119 |
+
text_config = StepAudio2TextConfig().text_config
|
| 120 |
+
elif isinstance(text_config, dict):
|
| 121 |
+
text_config = StepAudio2TextConfig(**text_config).text_config
|
| 122 |
+
|
| 123 |
+
self.text_config = text_config
|
| 124 |
+
|
| 125 |
+
if audio_encoder_config is None:
|
| 126 |
+
self.audio_encoder_config = StepAudio2EncoderConfig()
|
| 127 |
+
elif isinstance(audio_encoder_config, dict):
|
| 128 |
+
self.audio_encoder_config = StepAudio2EncoderConfig(**audio_encoder_config)
|
generation_config.json
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"do_sample": true,
|
| 4 |
+
"transformers_version": "4.56.2"
|
| 5 |
+
}
|
merges.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model-00001-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9a47a8da85e0da3bc31c230badb7663bae60309d731cb849ba15959dc422d68e
|
| 3 |
+
size 4952380248
|
model-00002-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:95ef95ec980fec815184d6c9f2b2e95661be9e1e063a1a45466e908952803bbc
|
| 3 |
+
size 4937521480
|
model-00003-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:093b77b6b160d2476ca545eb757dbd0ee1d4b554fa152a833b4f60a9526bd26d
|
| 3 |
+
size 4937521480
|
model-00004-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:33090fc271d96985256ae7998d71d93c46756ad16d63c037602529b155428689
|
| 3 |
+
size 4997834160
|
model-00005-of-00005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:2d812abb83c44af00a386341fc563d9c15cda01c08e1851ff1c904eb09d793a8
|
| 3 |
+
size 2291022848
|
model.safetensors.index.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
modeling_step_audio_2.py
ADDED
|
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Iterable, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import librosa
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import torchaudio
|
| 7 |
+
from torch import Tensor, nn
|
| 8 |
+
from transformers import PreTrainedModel, Qwen2Model
|
| 9 |
+
from transformers.generation.utils import GenerationMixin
|
| 10 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
| 11 |
+
|
| 12 |
+
from .configuration_step_audio_2 import StepAudio2Config
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _mel_filters(n_mels: int) -> torch.Tensor:
|
| 16 |
+
"""Load the mel filterbank matrix for projecting STFT into a Mel spectrogram."""
|
| 17 |
+
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
|
| 18 |
+
if n_mels == 128:
|
| 19 |
+
return torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=128))
|
| 20 |
+
else:
|
| 21 |
+
return torch.from_numpy(librosa.filters.mel(sr=16000, n_fft=400, n_mels=80))
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_audio(file_path, target_rate=16000, max_length=None):
|
| 25 |
+
"""
|
| 26 |
+
Open an audio file and read as mono waveform, resampling as necessary
|
| 27 |
+
If max_length is provided, truncate the audio to that length
|
| 28 |
+
"""
|
| 29 |
+
waveform, sample_rate = torchaudio.load(file_path)
|
| 30 |
+
if sample_rate != target_rate:
|
| 31 |
+
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_rate)(waveform)
|
| 32 |
+
audio = waveform[0] # get the first channel
|
| 33 |
+
|
| 34 |
+
# Truncate audio if it exceeds max_length
|
| 35 |
+
if max_length is not None and audio.shape[0] > max_length:
|
| 36 |
+
audio = audio[:max_length]
|
| 37 |
+
|
| 38 |
+
return audio
|
| 39 |
+
|
| 40 |
+
def log_mel_spectrogram(audio, n_mels=128, padding=479, device=None):
|
| 41 |
+
"""
|
| 42 |
+
Compute the log-Mel spectrogram with specific padding for StepAudio
|
| 43 |
+
"""
|
| 44 |
+
if not torch.is_tensor(audio):
|
| 45 |
+
if isinstance(audio, str):
|
| 46 |
+
audio = load_audio(audio)
|
| 47 |
+
audio = torch.from_numpy(audio)
|
| 48 |
+
if device is not None:
|
| 49 |
+
audio = audio.to(device)
|
| 50 |
+
if padding > 0:
|
| 51 |
+
audio = F.pad(audio, (0, padding))
|
| 52 |
+
window = torch.hann_window(400).to(audio.device)
|
| 53 |
+
stft = torch.stft(audio, 400, 160, window=window, return_complex=True)
|
| 54 |
+
magnitudes = stft[..., :-1].abs() ** 2
|
| 55 |
+
filters = _mel_filters(n_mels)
|
| 56 |
+
mel_spec = filters @ magnitudes
|
| 57 |
+
|
| 58 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
| 59 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
| 60 |
+
log_spec = (log_spec + 4.0) / 4.0
|
| 61 |
+
return log_spec
|
| 62 |
+
|
| 63 |
+
def compute_token_num(max_feature_len):
|
| 64 |
+
# First, audio goes through encoder:
|
| 65 |
+
# 1. conv1: kernel=3, stride=1, padding=1 -> size unchanged
|
| 66 |
+
# 2. conv2: kernel=3, stride=2, padding=1 -> size/2
|
| 67 |
+
# 3. avg_pooler: kernel=2, stride=2 -> size/2
|
| 68 |
+
max_feature_len = max_feature_len - 2 # remove padding
|
| 69 |
+
encoder_output_dim = (max_feature_len + 1) // 2 // 2 # after conv2 and avg_pooler
|
| 70 |
+
|
| 71 |
+
# Then through adaptor (parameters from config file):
|
| 72 |
+
padding = 1
|
| 73 |
+
kernel_size = 3 # from config: audio_encoder_config.kernel_size
|
| 74 |
+
stride = 2 # from config: audio_encoder_config.adapter_stride
|
| 75 |
+
adapter_output_dim = (encoder_output_dim + 2 * padding - kernel_size) // stride + 1
|
| 76 |
+
return adapter_output_dim
|
| 77 |
+
|
| 78 |
+
def make_non_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
| 79 |
+
"""Make mask tensor containing indices of non-padded part.
|
| 80 |
+
|
| 81 |
+
The sequences in a batch may have different lengths. To enable
|
| 82 |
+
batch computing, padding is need to make all sequence in same
|
| 83 |
+
size. To avoid the padding part pass value to context dependent
|
| 84 |
+
block such as attention or convolution , this padding part is
|
| 85 |
+
masked.
|
| 86 |
+
|
| 87 |
+
1 for non-padded part and 0 for padded part.
|
| 88 |
+
|
| 89 |
+
Parameters
|
| 90 |
+
----------
|
| 91 |
+
lengths (torch.Tensor): Batch of lengths (B,).
|
| 92 |
+
|
| 93 |
+
Returns:
|
| 94 |
+
-------
|
| 95 |
+
torch.Tensor: Mask tensor containing indices of padded part (B, max_T).
|
| 96 |
+
|
| 97 |
+
Examples:
|
| 98 |
+
>>> import torch
|
| 99 |
+
>>> import s3tokenizer
|
| 100 |
+
>>> lengths = torch.tensor([5, 3, 2])
|
| 101 |
+
>>> masks = s3tokenizer.make_non_pad_mask(lengths)
|
| 102 |
+
masks = [[1, 1, 1, 1, 1],
|
| 103 |
+
[1, 1, 1, 0, 0],
|
| 104 |
+
[1, 1, 0, 0, 0]]
|
| 105 |
+
"""
|
| 106 |
+
batch_size = lengths.size(0)
|
| 107 |
+
max_len = max_len if max_len > 0 else lengths.max().item()
|
| 108 |
+
seq_range = torch.arange(0,
|
| 109 |
+
max_len,
|
| 110 |
+
dtype=torch.int64,
|
| 111 |
+
device=lengths.device)
|
| 112 |
+
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
| 113 |
+
seq_length_expand = lengths.unsqueeze(-1)
|
| 114 |
+
mask = seq_range_expand >= seq_length_expand
|
| 115 |
+
return ~mask
|
| 116 |
+
|
| 117 |
+
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
| 118 |
+
"""Convert bool-tensor to float-tensor for flash attention.
|
| 119 |
+
|
| 120 |
+
Parameters
|
| 121 |
+
----------
|
| 122 |
+
lengths (torch.Tensor): Batch of lengths (B, ?).
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
-------
|
| 126 |
+
torch.Tensor: Mask tensor containing indices of padded part (B, ?).
|
| 127 |
+
|
| 128 |
+
Examples:
|
| 129 |
+
>>> import torch
|
| 130 |
+
>>> import s3tokenizer
|
| 131 |
+
>>> lengths = torch.tensor([5, 3, 2])
|
| 132 |
+
>>> masks = s3tokenizer.make_non_pad_mask(lengths)
|
| 133 |
+
masks = [[1, 1, 1, 1, 1],
|
| 134 |
+
[1, 1, 1, 0, 0],
|
| 135 |
+
[1, 1, 0, 0, 0]]
|
| 136 |
+
>>> new_masks = s3tokenizer.mask_to_bias(masks, torch.float32)
|
| 137 |
+
new_masks = [[-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],
|
| 138 |
+
[-0.0000e+00, -0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10],
|
| 139 |
+
[-0.0000e+00, -0.0000e+00, -1.0000e+10, -1.0000e+10, -1.0000e+10]]
|
| 140 |
+
"""
|
| 141 |
+
assert mask.dtype == torch.bool
|
| 142 |
+
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
| 143 |
+
mask = mask.to(dtype)
|
| 144 |
+
# attention mask bias
|
| 145 |
+
# NOTE(Mddct): torch.finfo jit issues
|
| 146 |
+
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
| 147 |
+
mask = (1.0 - mask) * -1.0e+10
|
| 148 |
+
return mask
|
| 149 |
+
|
| 150 |
+
class LayerNorm(nn.LayerNorm):
|
| 151 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 152 |
+
return super().forward(input).type(input.dtype)
|
| 153 |
+
|
| 154 |
+
class Linear(nn.Linear):
|
| 155 |
+
def forward(self, input: Tensor) -> Tensor:
|
| 156 |
+
return F.linear(
|
| 157 |
+
input,
|
| 158 |
+
self.weight.to(input.dtype),
|
| 159 |
+
None if self.bias is None else self.bias.to(input.dtype),
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
class Conv1d(nn.Conv1d):
|
| 163 |
+
def _conv_forward(
|
| 164 |
+
self, input: Tensor, weight: Tensor, bias: Optional[Tensor]
|
| 165 |
+
) -> Tensor:
|
| 166 |
+
return super()._conv_forward(
|
| 167 |
+
input, weight.to(input.dtype), None if bias is None else bias.to(input.dtype)
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
class MultiHeadAttention(nn.Module):
|
| 171 |
+
def __init__(self, n_state: int, n_head: int):
|
| 172 |
+
super().__init__()
|
| 173 |
+
self.n_head = n_head
|
| 174 |
+
self.query = Linear(n_state, n_state)
|
| 175 |
+
self.key = Linear(n_state, n_state, bias=False)
|
| 176 |
+
self.value = Linear(n_state, n_state)
|
| 177 |
+
self.out = Linear(n_state, n_state)
|
| 178 |
+
|
| 179 |
+
def forward(
|
| 180 |
+
self,
|
| 181 |
+
x: Tensor,
|
| 182 |
+
mask: Optional[Tensor] = None,
|
| 183 |
+
):
|
| 184 |
+
q = self.query(x)
|
| 185 |
+
k = self.key(x)
|
| 186 |
+
v = self.value(x)
|
| 187 |
+
|
| 188 |
+
wv, qk = self.qkv_attention(q, k, v, mask)
|
| 189 |
+
return self.out(wv), qk
|
| 190 |
+
|
| 191 |
+
def qkv_attention(
|
| 192 |
+
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
| 193 |
+
):
|
| 194 |
+
_, T, D = q.shape
|
| 195 |
+
scale = (D // self.n_head) ** -0.25
|
| 196 |
+
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
| 197 |
+
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
| 198 |
+
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
| 199 |
+
|
| 200 |
+
qk = q @ k # (B, n_head, T, T)
|
| 201 |
+
if mask is not None:
|
| 202 |
+
qk = qk + mask
|
| 203 |
+
qk = qk.float()
|
| 204 |
+
|
| 205 |
+
w = F.softmax(qk, dim=-1).to(q.dtype)
|
| 206 |
+
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
| 207 |
+
|
| 208 |
+
class ResidualAttentionBlock(nn.Module):
|
| 209 |
+
def __init__(self, n_state: int, n_head: int):
|
| 210 |
+
super().__init__()
|
| 211 |
+
|
| 212 |
+
self.attn = MultiHeadAttention(n_state, n_head)
|
| 213 |
+
self.attn_ln = LayerNorm(n_state)
|
| 214 |
+
|
| 215 |
+
n_mlp = n_state * 4
|
| 216 |
+
self.mlp = nn.Sequential(
|
| 217 |
+
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
|
| 218 |
+
)
|
| 219 |
+
self.mlp_ln = LayerNorm(n_state)
|
| 220 |
+
|
| 221 |
+
def forward(
|
| 222 |
+
self,
|
| 223 |
+
x: Tensor,
|
| 224 |
+
mask: Optional[Tensor] = None,
|
| 225 |
+
):
|
| 226 |
+
x = x + self.attn(self.attn_ln(x.contiguous()), mask=mask)[0]
|
| 227 |
+
x = x + self.mlp(self.mlp_ln(x.contiguous()))
|
| 228 |
+
return x
|
| 229 |
+
|
| 230 |
+
class AudioEncoder(nn.Module):
|
| 231 |
+
def __init__(
|
| 232 |
+
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
|
| 233 |
+
):
|
| 234 |
+
super().__init__()
|
| 235 |
+
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
| 236 |
+
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
| 237 |
+
self.positional_embedding = nn.Embedding(n_ctx, n_state)
|
| 238 |
+
self.positional_embedding.requires_grad_(False)
|
| 239 |
+
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
| 240 |
+
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
|
| 241 |
+
)
|
| 242 |
+
self.avg_pooler = nn.AvgPool1d(2, stride=2)
|
| 243 |
+
self.after_norm = LayerNorm(n_state)
|
| 244 |
+
self.gradient_checkpointing = False
|
| 245 |
+
|
| 246 |
+
def forward(self, x: Tensor, x_len: Tensor) -> Tuple[Tensor, Tensor]:
|
| 247 |
+
T = x.size(-1)
|
| 248 |
+
x = F.gelu(self.conv1(x))
|
| 249 |
+
x = F.gelu(self.conv2(x))
|
| 250 |
+
x = x.permute(0, 2, 1) # (B, T // 2, n_state)
|
| 251 |
+
mask = make_non_pad_mask(x_len, T).unsqueeze(1) # (B, 1, T)
|
| 252 |
+
mask = mask_to_bias(mask[:, :, (T + 1) % 2::2], x.dtype) # (B, 1, T // 2)
|
| 253 |
+
x = (x + self.positional_embedding.weight[:x.shape[1], :]).to(x.dtype)
|
| 254 |
+
for block in self.blocks:
|
| 255 |
+
if self.gradient_checkpointing and self.training:
|
| 256 |
+
x = torch.utils.checkpoint.checkpoint(block, x, mask.unsqueeze(1))
|
| 257 |
+
else:
|
| 258 |
+
x = block(x, mask.unsqueeze(1))
|
| 259 |
+
x = x.permute(0, 2, 1)
|
| 260 |
+
x = self.avg_pooler(x)
|
| 261 |
+
x = x.permute(0, 2, 1)
|
| 262 |
+
x_len = (x_len + 1) // 2 // 2
|
| 263 |
+
x = self.after_norm(x.contiguous())
|
| 264 |
+
return x, x_len
|
| 265 |
+
|
| 266 |
+
class Adaptor(nn.Module):
|
| 267 |
+
def __init__(
|
| 268 |
+
self,
|
| 269 |
+
n_state: int = 1280,
|
| 270 |
+
n_hidden: int = 3072,
|
| 271 |
+
kernel_size: int = 7,
|
| 272 |
+
stride: int = 4
|
| 273 |
+
):
|
| 274 |
+
super().__init__()
|
| 275 |
+
self.stride = stride
|
| 276 |
+
if self.stride != -1:
|
| 277 |
+
# print("self.stride: {}".format(self.stride))
|
| 278 |
+
self.conv = Conv1d(n_state, n_state, kernel_size, stride, padding=1)
|
| 279 |
+
self.linear1 = nn.Linear(n_state, 2048)
|
| 280 |
+
self.relu = nn.ReLU()
|
| 281 |
+
self.linear2 = nn.Linear(2048, n_hidden)
|
| 282 |
+
self.gradient_checkpointing = False
|
| 283 |
+
|
| 284 |
+
def forward(self, x: Tensor) -> Tuple[Tensor]:
|
| 285 |
+
T = x.size(-1)
|
| 286 |
+
if self.stride != -1:
|
| 287 |
+
if self.gradient_checkpointing and self.training:
|
| 288 |
+
x = torch.utils.checkpoint.checkpoint(self.conv, x.permute(0, 2, 1))
|
| 289 |
+
x = x.permute(0, 2, 1)
|
| 290 |
+
else:
|
| 291 |
+
x = x.permute(0, 2, 1)
|
| 292 |
+
x = F.gelu(self.conv(x))
|
| 293 |
+
x = x.permute(0, 2, 1)
|
| 294 |
+
if self.gradient_checkpointing and self.training:
|
| 295 |
+
x = torch.utils.checkpoint.checkpoint(self.linear1, x)
|
| 296 |
+
x = torch.utils.checkpoint.checkpoint(self.relu, x)
|
| 297 |
+
x = torch.utils.checkpoint.checkpoint(self.linear2, x)
|
| 298 |
+
else:
|
| 299 |
+
x = self.linear1(x)
|
| 300 |
+
x = self.relu(x)
|
| 301 |
+
x = self.linear2(x)
|
| 302 |
+
return x
|
| 303 |
+
|
| 304 |
+
class StepAudio2ForCausalLM(PreTrainedModel, GenerationMixin):
|
| 305 |
+
config_class = StepAudio2Config
|
| 306 |
+
main_input_name = "input_ids"
|
| 307 |
+
# Important: Add this attribute to make HF recognize it as a model with generation capability
|
| 308 |
+
# _keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
| 309 |
+
supports_gradient_checkpointing = True # 新增,声明支持gradient checkpointing
|
| 310 |
+
|
| 311 |
+
def __init__(self, config: StepAudio2Config):
|
| 312 |
+
super().__init__(config)
|
| 313 |
+
if isinstance(config.torch_dtype, str):
|
| 314 |
+
dtype = getattr(torch, config.torch_dtype)
|
| 315 |
+
else:
|
| 316 |
+
dtype = config.torch_dtype
|
| 317 |
+
self.model = Qwen2Model(config.text_config)
|
| 318 |
+
self.bf16 = dtype==torch.bfloat16
|
| 319 |
+
self.encoder = AudioEncoder(
|
| 320 |
+
config.audio_encoder_config.n_mels, config.audio_encoder_config.n_audio_ctx, config.audio_encoder_config.n_audio_state,
|
| 321 |
+
config.audio_encoder_config.n_audio_head, config.audio_encoder_config.n_audio_layer
|
| 322 |
+
)
|
| 323 |
+
self.adapter = Adaptor(
|
| 324 |
+
config.audio_encoder_config.n_audio_state, config.audio_encoder_config.llm_dim,
|
| 325 |
+
config.audio_encoder_config.kernel_size, config.audio_encoder_config.adapter_stride
|
| 326 |
+
)
|
| 327 |
+
if self.bf16:
|
| 328 |
+
self.encoder = self.encoder.bfloat16()
|
| 329 |
+
self.adapter = self.adapter.bfloat16()
|
| 330 |
+
self.lm_head = torch.nn.Linear(
|
| 331 |
+
config.text_config.hidden_size,
|
| 332 |
+
config.text_config.vocab_size,
|
| 333 |
+
bias=False,
|
| 334 |
+
dtype=dtype
|
| 335 |
+
)
|
| 336 |
+
self.post_init()
|
| 337 |
+
|
| 338 |
+
def forward(
|
| 339 |
+
self,
|
| 340 |
+
input_ids=None,
|
| 341 |
+
wavs=None,
|
| 342 |
+
wav_lens=None,
|
| 343 |
+
attention_mask=None,
|
| 344 |
+
**kwargs
|
| 345 |
+
):
|
| 346 |
+
hidden_states = self.model.embed_tokens(input_ids)
|
| 347 |
+
if wavs is not None:
|
| 348 |
+
if self.bf16:
|
| 349 |
+
wavs = wavs.bfloat16()
|
| 350 |
+
out, feat_lens = self.encoder(wavs, wav_lens)
|
| 351 |
+
out = self.adapter(out)
|
| 352 |
+
feat_lens = (feat_lens - 1) // 2 + 1
|
| 353 |
+
insert_location = torch.nonzero(input_ids == 151688)
|
| 354 |
+
insert_location[:,1] += 1
|
| 355 |
+
for idx in range(len(insert_location)):
|
| 356 |
+
i,s = insert_location[idx]
|
| 357 |
+
hidden_states[i][s : s+feat_lens[idx]] = out[idx][:feat_lens[idx]]
|
| 358 |
+
|
| 359 |
+
x = self.model(inputs_embeds=hidden_states, attention_mask=attention_mask)[0]
|
| 360 |
+
logits = self.lm_head(x)
|
| 361 |
+
return CausalLMOutputWithPast(
|
| 362 |
+
logits=logits,
|
| 363 |
+
past_key_values=None,
|
| 364 |
+
hidden_states=None,
|
| 365 |
+
attentions=None
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
def get_input_embeddings(self):
|
| 369 |
+
"""Return the model's input embeddings - required for GenerationMixin"""
|
| 370 |
+
return self.model.embed_tokens
|
| 371 |
+
|
| 372 |
+
def get_output_embeddings(self):
|
| 373 |
+
"""Return the model's output embeddings (LM head) - required for GenerationMixin"""
|
| 374 |
+
return self.lm_head
|
| 375 |
+
|
| 376 |
+
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
|
| 377 |
+
"""Prepare inputs for generation - required for GenerationMixin"""
|
| 378 |
+
# Keep the wavs and wav_lens from the initial call
|
| 379 |
+
wavs = kwargs.get("wavs", None)
|
| 380 |
+
wav_lens = kwargs.get("wav_lens", None)
|
| 381 |
+
|
| 382 |
+
# For generation steps after the first, we don't need to process audio again
|
| 383 |
+
# because the audio tokens have already been replaced in the input sequence
|
| 384 |
+
if "past_key_values" in kwargs and kwargs["past_key_values"] is not None:
|
| 385 |
+
# We're in a generation step, no need to process audio again
|
| 386 |
+
return {
|
| 387 |
+
"input_ids": input_ids,
|
| 388 |
+
"attention_mask": attention_mask,
|
| 389 |
+
"past_key_values": kwargs.get("past_key_values")
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
# First generation step, include audio processing
|
| 393 |
+
return {
|
| 394 |
+
"input_ids": input_ids,
|
| 395 |
+
"attention_mask": attention_mask,
|
| 396 |
+
"wavs": wavs,
|
| 397 |
+
"wav_lens": wav_lens
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
def _reorder_cache(self, past_key_values, beam_idx):
|
| 401 |
+
"""Reorder the cache for beam search - required for GenerationMixin if using beam search"""
|
| 402 |
+
# If you're not using past_key_values or beam search, this can be a simple pass-through
|
| 403 |
+
# Otherwise implement according to your model's cache structure
|
| 404 |
+
return past_key_values
|
| 405 |
+
|
| 406 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 407 |
+
# For Qwen2Model
|
| 408 |
+
if hasattr(self.model, 'gradient_checkpointing'):
|
| 409 |
+
self.model.gradient_checkpointing = value
|
| 410 |
+
|
| 411 |
+
# Add the missing _gradient_checkpointing_func method to Qwen2Model
|
| 412 |
+
# This is what Qwen2Model tries to use when gradient_checkpointing=True
|
| 413 |
+
if value and not hasattr(self.model, '_gradient_checkpointing_func'):
|
| 414 |
+
def _gradient_checkpointing_func(module_to_run, *args, **kwargs):
|
| 415 |
+
# This function wraps torch.utils.checkpoint.checkpoint
|
| 416 |
+
# and is used by Qwen2Model to perform checkpointing
|
| 417 |
+
return torch.utils.checkpoint.checkpoint(module_to_run, *args, **kwargs)
|
| 418 |
+
|
| 419 |
+
self.model._gradient_checkpointing_func = _gradient_checkpointing_func
|
| 420 |
+
|
| 421 |
+
# For custom encoder and adapter
|
| 422 |
+
if hasattr(self.encoder, 'gradient_checkpointing'):
|
| 423 |
+
self.encoder.gradient_checkpointing = value
|
| 424 |
+
if hasattr(self.adapter, 'gradient_checkpointing'):
|
| 425 |
+
self.adapter.gradient_checkpointing = value
|
recipe.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
default_stage:
|
| 2 |
+
default_modifiers:
|
| 3 |
+
QuantizationModifier:
|
| 4 |
+
targets: [Linear]
|
| 5 |
+
ignore: [lm_head, 're:^encoder\.', 're:^adapter\.', 're:^model\.embed_tokens\.', 're:.*layernorm.*']
|
| 6 |
+
scheme: NVFP4
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"additional_special_tokens": [
|
| 3 |
+
"<|EOT|>",
|
| 4 |
+
"<|BOT|>",
|
| 5 |
+
"<|CALL_START|>",
|
| 6 |
+
"<|CALL_END|>",
|
| 7 |
+
"<|THINK_START|>",
|
| 8 |
+
"<|THINK_END|>",
|
| 9 |
+
"<|IMG_START|>",
|
| 10 |
+
"<|IMG_END|>",
|
| 11 |
+
"<|META_START|>",
|
| 12 |
+
"<|META_END|>",
|
| 13 |
+
"<im_patch>",
|
| 14 |
+
"<im_start>",
|
| 15 |
+
"<im_end>",
|
| 16 |
+
"<dream>",
|
| 17 |
+
"<dream_start>",
|
| 18 |
+
"<dream_end>",
|
| 19 |
+
"<|MASK_1e69f|>",
|
| 20 |
+
"<|UNMASK_1e69f|>",
|
| 21 |
+
"<video_start>",
|
| 22 |
+
"<video_end>",
|
| 23 |
+
"<patch_start>",
|
| 24 |
+
"<patch_end>",
|
| 25 |
+
"<patch_newline>",
|
| 26 |
+
"<audio_start>",
|
| 27 |
+
"<audio_end>",
|
| 28 |
+
"<audio_patch>",
|
| 29 |
+
"<audio_patch_pad>",
|
| 30 |
+
"<|SC|>",
|
| 31 |
+
"<tts_start>",
|
| 32 |
+
"<tts_end>",
|
| 33 |
+
"<tts_pad>"
|
| 34 |
+
],
|
| 35 |
+
"eos_token": {
|
| 36 |
+
"content": "<|endoftext|>",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false
|
| 41 |
+
},
|
| 42 |
+
"pad_token": {
|
| 43 |
+
"content": "<|endoftext|>",
|
| 44 |
+
"lstrip": false,
|
| 45 |
+
"normalized": false,
|
| 46 |
+
"rstrip": false,
|
| 47 |
+
"single_word": false
|
| 48 |
+
}
|
| 49 |
+
}
|
tokenizer.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c23796e0498b651e92b0d514d43636d0dfd556534f8dde7b72ed0e2ff1d07744
|
| 3 |
+
size 12684616
|
tokenizer_config.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
vocab.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|