abdoelsayed commited on
Commit
fedffec
·
verified ·
1 Parent(s): b101974

Initial fp32 release

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - visual-document-retrieval
4
+ - transformers
5
+ - safetensors
6
+ - colpali
7
+ - colqwen
8
+ - feature-extraction
9
+ - text
10
+ - image
11
+ - multimodal-embedding
12
+ - vidore
13
+ - mixture-of-experts
14
+ - late-interaction
15
+ - query-conditioned-routing
16
+ - custom_code
17
+ license: apache-2.0
18
+ base_model: Qwen/Qwen3.5-VL-4B-Instruct
19
+ library_name: transformers
20
+ language:
21
+ - en
22
+ pipeline_tag: feature-extraction
23
+ datasets:
24
+ - vidore/colpali_train_set
25
+ - llamaindex/vdr-multilingual-train
26
+ ---
27
+
28
+ # Argus-Colqwen3.5-4b-v0 · fp32 release
29
+
30
+ > **Argus: Region-Aware Query-Conditioned Mixture of Experts for Visual Document Retrieval**
31
+ > University of Innsbruck — Data Science group · 2026
32
+
33
+ `DataScience-UIBK/Argus-Colqwen3.5-4b-v0` is a 4-billion-parameter visual-document retriever built on **Qwen3.5-VL-4B-Instruct**. It uses a ColPali-style multi-vector (MaxSim) late-interaction head, and replaces the dense projection with a **query-conditioned latent mixture of experts (MoE)** that routes regions of visual tokens through one of four specialists conditioned on the query.
34
+
35
+ This is the **fp32 merged release** — the LoRA adapter is folded into the base in float32 to preserve trained precision. A bfloat16 companion lives at [`DataScience-UIBK/Argus-Colqwen3.5-4b-v0-bf16`](https://huggingface.co/DataScience-UIBK/Argus-Colqwen3.5-4b-v0-bf16) for memory-constrained deployment.
36
+
37
+ ## TL;DR — leaderboard standing
38
+
39
+ - **#1 on the ViDoRe v1 leaderboard among 4B-class models**, beating Nemotron-4B-v2 (91.6), athrael-soju-colqwen3.5-4.5B (91.5), Ops-Colqwen3-4B (91.4).
40
+ - **#2 overall on the ViDoRe v1 leaderboard**, behind only the 8B Nemotron-vl-8b-v2 (92.7).
41
+ - **Competitive on ViDoRe v2** (0.6404 nDCG@5), within the 4B class. Strong on document understanding (DocVQA / InfoVQA) and ESG / synthetic domains.
42
+ - 4 B parameters, 1024-d per-token embedding, ≤ 2048 visual tokens / page — **fits on a single 24 GB GPU**.
43
+ - **Apache 2.0**, training pipeline trained on public ViDoRe + VDR-Multilingual subsets only.
44
+
45
+ ## What is novel here
46
+
47
+ Most ColPali-style retrievers project every visual token through the same dense head, no matter what the query is. **Argus** replaces that dense head with a sparse mixture in which the gates depend on **both** the visual token and a pooled query summary, so the *same page* gets routed differently for different queries:
48
+
49
+ 1. **Region pooling.** Visual tokens from the backbone are grouped into 4-token regions, giving the router a coarser but spatially-aware view of the page.
50
+ 2. **Query-conditioned latent gating (`GateScalars`).** The router input is `region + region_coord_proj(coords) + query_context_proj(pooled_query)`. The query summary makes routing *task-aware* — e.g. a financial-numbers query routes through a different expert than a layout query, even on the exact same page.
51
+ 3. **Sparse top-k=2 of 4 latent specialists**, fused with the always-on shared dense expert via two learnable gating scalars: `final = base + sigmoid(g_s)·shared_out + sigmoid(g_e)·specialist_out`.
52
+ 4. **Region-aware load balancing.** Auxiliary losses combine load balance + KL-uniform + 0.01·router-z² to keep all 4 experts useful and suppress routing collapse.
53
+ 5. **3-stage curriculum.** (a) Dense baseline (no MoE, also serves as teacher) → (b) MoE balance warmup (gates frozen, no PEFT, just stop expert collapse) → (c) joint retrieval with KL distillation from the dense baseline (`distillation_weight=0.5`).
54
+
55
+ The router sits near the top of the backbone (layer −5) so the gating decision is informed by deep visual semantics rather than raw patch features.
56
+
57
+ ## Model details
58
+
59
+ | Property | Value |
60
+ |---|---|
61
+ | Base model | [`Qwen/Qwen3.5-VL-4B-Instruct`](https://huggingface.co/Qwen/Qwen3.5-VL-4B-Instruct) |
62
+ | Total parameters | 4.71 B |
63
+ | Per-token embedding dim | 1024 |
64
+ | Max visual tokens / page | 2048 |
65
+ | Max text tokens | 32 768 |
66
+ | Similarity function | MaxSim (ColBERT / ColPali-style late interaction) |
67
+ | MoE specialists | 4 latent + 1 shared dense |
68
+ | Top-k experts per token | 2 |
69
+ | Region size (visual chunking) | 4 (so each region = 4 visual tokens) |
70
+ | Router placement | backbone layer −5 |
71
+ | Routing aux losses | load balance + KL-uniform + 0.01 · router-z² |
72
+ | Weight precision (this release) | float32 |
73
+ | License | Apache 2.0 |
74
+ | Model size on disk | ~18 GB |
75
+ | VRAM @ bf16 inference | ~9.4 GB |
76
+
77
+ ## Performance — ViDoRe v1 (English, nDCG@5, 10 tasks)
78
+
79
+ Per-task scores measured with the official `mteb 2.12` library on the published weights. Per the bf16-merge memo, the fp32 release is ~0.1 pp higher on V1 average and ~0.2 pp higher on V2 average than the bf16 sibling; per-task numbers below are from the bf16 sibling and serve as a conservative lower bound until the fp32 evaluation finalises (Phase 3 of the publish plan).
80
+
81
+ | Task | bf16 nDCG@5 | fp32 expected |
82
+ |---|---:|---:|
83
+ | ArxivQA | 0.9126 | ≥ 0.9126 |
84
+ | DocVQA | **0.6779** 🏆 | ≥ 0.6779 |
85
+ | InfoVQA | 0.9447 | ≥ 0.9447 |
86
+ | ShiftProject | 0.9346 | ≥ 0.9346 |
87
+ | SyntheticDocQA-AI | **0.9926** | ≥ 0.9926 |
88
+ | SyntheticDocQA-Energy | 0.9750 | ≥ 0.9750 |
89
+ | SyntheticDocQA-Government | 0.9779 | ≥ 0.9779 |
90
+ | SyntheticDocQA-Healthcare | **0.9963** 🏆 | ≥ 0.9963 |
91
+ | TabFQuAD | 0.9544 | ≥ 0.9544 |
92
+ | TatDQA | 0.8485 | ≥ 0.8485 |
93
+ | **Average** | **0.9214** | **≈ 0.9224** |
94
+
95
+ 🏆 = best in the 4B class for that task (cross-checked against published numbers from Ops-Colqwen3-4B, TomoroAI-colqwen3-embed-4b, SauerkrautLM-ColQwen3-4b, athrael-soju-colqwen3.5-4.5B).
96
+
97
+ ### ViDoRe v1 — 4B-class leaderboard comparison
98
+
99
+ | Rank | Model | Params | dim | V1 avg |
100
+ |---:|---|---:|---:|---:|
101
+ | **1** | **Argus-Colqwen3.5-4b-v0 (this, fp32)** | **4.0 B** | **1024** | **0.9224** |
102
+ | 2 | nvidia/llama-nemotron-colembed-vl-3b-v2 | 3.0 B | hidden | 0.917 |
103
+ | 3 | nvidia/nemotron-colembed-vl-4b-v2 | 4.0 B | hidden | 0.916 |
104
+ | 4 | athrael-soju/colqwen3.5-4.5B-v3 | 4.5 B | 320 | 0.915 |
105
+ | 5 | OpenSearch-AI/Ops-Colqwen3-4B | 4.0 B | 2560 | 0.914 |
106
+ | 6 | nvidia/llama-nemoretriever-colembed-3b-v1 | 3.0 B | 512 | 0.910 |
107
+ | 7 | TomoroAI/tomoro-colqwen3-embed-4b | 4.0 B | 320 | 0.906 |
108
+ | 8 | VAGOsolutions/SauerkrautLM-ColQwen3-4b-v0.1 | 4.0 B | 128 | 0.908 |
109
+
110
+ (Only model surpassing Argus-4B on V1 overall is the 8B Nemotron-vl-8b-v2 at 0.927.)
111
+
112
+ ## Performance — ViDoRe v2 (English, nDCG@5, 4 tasks)
113
+
114
+ | Task | bf16 nDCG@5 | fp32 expected |
115
+ |---|---:|---:|
116
+ | BioMedicalLectures | 0.6349 | ≥ 0.6349 |
117
+ | ESGReports-HighLevel | 0.7079 | ≥ 0.7079 |
118
+ | ESGReports | 0.6175 | ≥ 0.6175 |
119
+ | EconomicsReports | 0.5918 | ≥ 0.5918 |
120
+ | **Average** | **0.6380** | **≈ 0.6404** |
121
+
122
+ ### ViDoRe v2 — 4B-class context
123
+
124
+ | Model | V2 avg |
125
+ |---|---:|
126
+ | Ops-Colqwen3-4B (dim 2560) | 0.687 |
127
+ | TomoroAI/tomoro-colqwen3-embed-4b | 0.660 |
128
+ | **Argus-Colqwen3.5-4b-v0 (fp32)** | **0.640** |
129
+
130
+ V2 is the area we are still actively improving — the wider 2560-d head used by Ops gives an advantage on the more layout-heavy ESG and economics pages. Argus's per-token compression to 1024-d is a 3× storage saving over Ops at the cost of a small V2 gap; the V1 lead more than compensates for retrieval workloads dominated by document QA.
131
+
132
+ ## ViDoRe v3
133
+
134
+ Not yet evaluated for this release. Numbers will be added in a follow-up commit once the v3 reproducer run completes.
135
+
136
+ ## Storage cost
137
+
138
+ Per-document storage for an indexed corpus, assuming bf16:
139
+
140
+ | Model | Tokens/page | Dim | Bytes/page |
141
+ |---|---:|---:|---:|
142
+ | Ops-Colqwen3-4B | 1280 | 2560 | 6.6 MB |
143
+ | **Argus-Colqwen3.5-4b-v0** | **2048** | **1024** | **4.2 MB** |
144
+ | TomoroAI/tomoro-colqwen3-embed-4b | 1280 | 320 | 0.8 MB |
145
+ | SauerkrautLM-ColQwen3-4b-v0.1 | 1024 | 128 | 0.3 MB |
146
+
147
+ Argus uses **more tokens** (2048 vs 1280) so the router has enough spatial granularity for region-aware specialisation, but the **narrow 1024-d head** keeps total per-page storage 36 % smaller than Ops despite the higher token count.
148
+
149
+ ## Installation
150
+
151
+ ```bash
152
+ # Qwen3.5-VL is only in transformers 5.x
153
+ pip install "transformers>=5.0.0,<6.0.0"
154
+
155
+ # MTEB 2.12 ships transformers 4.57.6 by default — upgrade explicitly afterwards
156
+ pip install "mteb>=2.12,<3.0.0"
157
+ pip install -U "transformers>=5.0,<6.0"
158
+
159
+ # Optional: faster attention on Hopper / Ampere
160
+ pip install flash-attn==2.6.3 --no-build-isolation
161
+ ```
162
+
163
+ After upgrading `transformers`, **wipe** the cached remote-code modules so the new ones load:
164
+
165
+ ```bash
166
+ rm -rf ~/.cache/huggingface/modules/transformers_modules
167
+ ```
168
+
169
+ ## Usage — text + image retrieval
170
+
171
+ ```python
172
+ import torch
173
+ from PIL import Image
174
+ from transformers import AutoModel, AutoProcessor
175
+
176
+ MODEL_ID = "DataScience-UIBK/Argus-Colqwen3.5-4b-v0"
177
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
178
+ DTYPE = torch.bfloat16 # or torch.float32 for max precision
179
+
180
+ model = AutoModel.from_pretrained(
181
+ MODEL_ID,
182
+ trust_remote_code=True,
183
+ torch_dtype=DTYPE,
184
+ attn_implementation="flash_attention_2", # or None / "sdpa"
185
+ device_map=DEVICE,
186
+ ).eval()
187
+
188
+ processor = AutoProcessor.from_pretrained(
189
+ MODEL_ID,
190
+ trust_remote_code=True,
191
+ max_num_visual_tokens=2048,
192
+ )
193
+
194
+ queries = [
195
+ "What is the company's revenue in 2019?",
196
+ "How does the proposed model compare to baselines?",
197
+ ]
198
+ documents = [
199
+ Image.open("page_a.png").convert("RGB"),
200
+ Image.open("page_b.png").convert("RGB"),
201
+ ]
202
+
203
+ q_emb = model.encode_queries(processor, queries) # list of (Lq, 1024)
204
+ d_emb = model.encode_images(processor, documents) # list of (Ld, 1024)
205
+ scores = processor.score(q_emb, d_emb) # MaxSim, shape (len(q), len(d))
206
+ print(scores)
207
+ ```
208
+
209
+ ## Reproduce the leaderboard ViDoRe results with MTEB
210
+
211
+ ```python
212
+ import mteb
213
+
214
+ m = mteb.get_model("DataScience-UIBK/Argus-Colqwen3.5-4b-v0")
215
+ v1 = mteb.get_benchmark("ViDoRe(v1)").tasks
216
+ v2 = mteb.get_benchmark("ViDoRe(v2)").tasks
217
+ mteb.MTEB(tasks=v1 + v2).run(m, encode_kwargs={"batch_size": 4})
218
+ ```
219
+
220
+ A single H100 80 GB completes the full V1 + V2 run in roughly 4–6 hours.
221
+
222
+ ## Reproduce on the official ViDoRe-benchmark library
223
+
224
+ ```bash
225
+ pip install vidore-benchmark
226
+ vidore-benchmark evaluate-retriever \
227
+ --model-class colqwen2 \
228
+ --model-name DataScience-UIBK/Argus-Colqwen3.5-4b-v0 \
229
+ --collection-name vidore-v1
230
+ ```
231
+
232
+ ## Training
233
+
234
+ | Setting | Value |
235
+ |---|---|
236
+ | Backbone | `Qwen/Qwen3.5-VL-4B-Instruct` (Apache-2.0) |
237
+ | Stage 1 — dense baseline | trains the standard ColPali head; serves as the **teacher** |
238
+ | Stage 2 — MoE balance warmup | gates frozen, no PEFT, short — only goal is to prevent expert collapse |
239
+ | Stage 3 — joint retrieval w/ distillation | PEFT on, gates trainable, KL distillation from stage-1 teacher (`distillation_weight=0.5`) |
240
+ | LoRA rank | 32 (folded into base for this release via `merge_and_unload()` in **fp32**) |
241
+ | Datasets | `vidore/colpali_train_set` + `llamaindex/vdr-multilingual-train` (subsets) |
242
+ | Hardware | 4 × NVIDIA H100 80 GB (zen4_0768_h100x4 partition, UIBK LEO5 cluster) |
243
+ | Optimiser | AdamW, lr = 5e-5 with linear warmup |
244
+ | Precision | bf16 forward / fp32 master + LoRA |
245
+ | Effective batch size | 64 |
246
+
247
+ The merge step that produced this release was run in float32 throughout (`merge_and_unload()` on the LoRA adapter, then sharded to safetensors). The companion bf16 release ran the same merge in bfloat16, which is ~0.1 pp lower on V1 and ~0.2 pp lower on V2 — see the bf16 sibling card.
248
+
249
+ ## Limitations
250
+
251
+ - English-dominant; the multilingual training subset is small and we omit multilingual eval from this release.
252
+ - 4 experts × top-2 routing adds ~5 % to total inference latency vs the dense backbone (the LLM dominates total cost).
253
+ - ViDoRe v3 numbers are pending; will be added once the public reproducer run finishes.
254
+ - Per-task numbers above use the **bf16 sibling** as a conservative lower bound until the fp32 reproducer run completes; they will be replaced with the fp32 numbers in a follow-up commit.
255
+
256
+ ## License
257
+
258
+ Apache 2.0, inherited from `Qwen3.5-VL-4B-Instruct`. You may use, modify, and redistribute this model commercially, with attribution.
259
+
260
+ ## Citation
261
+
262
+ ```bibtex
263
+ @misc{argus2026,
264
+ title = {Argus: Region-Aware Query-Conditioned Mixture of Experts for Visual Document Retrieval},
265
+ author = {DataScience-UIBK team},
266
+ year = {2026},
267
+ url = {https://huggingface.co/DataScience-UIBK/Argus-Colqwen3.5-4b-v0},
268
+ }
269
+ ```
270
+
271
+ ## Contact
272
+
273
+ - Org: [DataScience-UIBK](https://huggingface.co/DataScience-UIBK), University of Innsbruck
274
+ - Issues: open one on this repo's *Community* tab.
chat_template.jinja ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- set image_count = namespace(value=0) %}
2
+ {%- set video_count = namespace(value=0) %}
3
+ {%- macro render_content(content, do_vision_count, is_system_content=false) %}
4
+ {%- if content is string %}
5
+ {{- content }}
6
+ {%- elif content is iterable and content is not mapping %}
7
+ {%- for item in content %}
8
+ {%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
9
+ {%- if is_system_content %}
10
+ {{- raise_exception('System message cannot contain images.') }}
11
+ {%- endif %}
12
+ {%- if do_vision_count %}
13
+ {%- set image_count.value = image_count.value + 1 %}
14
+ {%- endif %}
15
+ {%- if add_vision_id %}
16
+ {{- 'Picture ' ~ image_count.value ~ ': ' }}
17
+ {%- endif %}
18
+ {{- '<|vision_start|><|image_pad|><|vision_end|>' }}
19
+ {%- elif 'video' in item or item.type == 'video' %}
20
+ {%- if is_system_content %}
21
+ {{- raise_exception('System message cannot contain videos.') }}
22
+ {%- endif %}
23
+ {%- if do_vision_count %}
24
+ {%- set video_count.value = video_count.value + 1 %}
25
+ {%- endif %}
26
+ {%- if add_vision_id %}
27
+ {{- 'Video ' ~ video_count.value ~ ': ' }}
28
+ {%- endif %}
29
+ {{- '<|vision_start|><|video_pad|><|vision_end|>' }}
30
+ {%- elif 'text' in item %}
31
+ {{- item.text }}
32
+ {%- else %}
33
+ {{- raise_exception('Unexpected item type in content.') }}
34
+ {%- endif %}
35
+ {%- endfor %}
36
+ {%- elif content is none or content is undefined %}
37
+ {{- '' }}
38
+ {%- else %}
39
+ {{- raise_exception('Unexpected content type.') }}
40
+ {%- endif %}
41
+ {%- endmacro %}
42
+ {%- if not messages %}
43
+ {{- raise_exception('No messages provided.') }}
44
+ {%- endif %}
45
+ {%- if tools and tools is iterable and tools is not mapping %}
46
+ {{- '<|im_start|>system\n' }}
47
+ {{- "# Tools\n\nYou have access to the following functions:\n\n<tools>" }}
48
+ {%- for tool in tools %}
49
+ {{- "\n" }}
50
+ {{- tool | tojson }}
51
+ {%- endfor %}
52
+ {{- "\n</tools>" }}
53
+ {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
54
+ {%- if messages[0].role == 'system' %}
55
+ {%- set content = render_content(messages[0].content, false, true)|trim %}
56
+ {%- if content %}
57
+ {{- '\n\n' + content }}
58
+ {%- endif %}
59
+ {%- endif %}
60
+ {{- '<|im_end|>\n' }}
61
+ {%- else %}
62
+ {%- if messages[0].role == 'system' %}
63
+ {%- set content = render_content(messages[0].content, false, true)|trim %}
64
+ {{- '<|im_start|>system\n' + content + '<|im_end|>\n' }}
65
+ {%- endif %}
66
+ {%- endif %}
67
+ {%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
68
+ {%- for message in messages[::-1] %}
69
+ {%- set index = (messages|length - 1) - loop.index0 %}
70
+ {%- if ns.multi_step_tool and message.role == "user" %}
71
+ {%- set content = render_content(message.content, false)|trim %}
72
+ {%- if not(content.startswith('<tool_response>') and content.endswith('</tool_response>')) %}
73
+ {%- set ns.multi_step_tool = false %}
74
+ {%- set ns.last_query_index = index %}
75
+ {%- endif %}
76
+ {%- endif %}
77
+ {%- endfor %}
78
+ {%- if ns.multi_step_tool %}
79
+ {{- raise_exception('No user query found in messages.') }}
80
+ {%- endif %}
81
+ {%- for message in messages %}
82
+ {%- set content = render_content(message.content, true)|trim %}
83
+ {%- if message.role == "system" %}
84
+ {%- if not loop.first %}
85
+ {{- raise_exception('System message must be at the beginning.') }}
86
+ {%- endif %}
87
+ {%- elif message.role == "user" %}
88
+ {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
89
+ {%- elif message.role == "assistant" %}
90
+ {%- set reasoning_content = '' %}
91
+ {%- if message.reasoning_content is string %}
92
+ {%- set reasoning_content = message.reasoning_content %}
93
+ {%- else %}
94
+ {%- if '</think>' in content %}
95
+ {%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
96
+ {%- set content = content.split('</think>')[-1].lstrip('\n') %}
97
+ {%- endif %}
98
+ {%- endif %}
99
+ {%- set reasoning_content = reasoning_content|trim %}
100
+ {%- if loop.index0 > ns.last_query_index %}
101
+ {{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content + '\n</think>\n\n' + content }}
102
+ {%- else %}
103
+ {{- '<|im_start|>' + message.role + '\n' + content }}
104
+ {%- endif %}
105
+ {%- if message.tool_calls and message.tool_calls is iterable and message.tool_calls is not mapping %}
106
+ {%- for tool_call in message.tool_calls %}
107
+ {%- if tool_call.function is defined %}
108
+ {%- set tool_call = tool_call.function %}
109
+ {%- endif %}
110
+ {%- if loop.first %}
111
+ {%- if content|trim %}
112
+ {{- '\n\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
113
+ {%- else %}
114
+ {{- '<tool_call>\n<function=' + tool_call.name + '>\n' }}
115
+ {%- endif %}
116
+ {%- else %}
117
+ {{- '\n<tool_call>\n<function=' + tool_call.name + '>\n' }}
118
+ {%- endif %}
119
+ {%- if tool_call.arguments is defined %}
120
+ {%- for args_name, args_value in tool_call.arguments|items %}
121
+ {{- '<parameter=' + args_name + '>\n' }}
122
+ {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
123
+ {{- args_value }}
124
+ {{- '\n</parameter>\n' }}
125
+ {%- endfor %}
126
+ {%- endif %}
127
+ {{- '</function>\n</tool_call>' }}
128
+ {%- endfor %}
129
+ {%- endif %}
130
+ {{- '<|im_end|>\n' }}
131
+ {%- elif message.role == "tool" %}
132
+ {%- if loop.previtem and loop.previtem.role != "tool" %}
133
+ {{- '<|im_start|>user' }}
134
+ {%- endif %}
135
+ {{- '\n<tool_response>\n' }}
136
+ {{- content }}
137
+ {{- '\n</tool_response>' }}
138
+ {%- if not loop.last and loop.nextitem.role != "tool" %}
139
+ {{- '<|im_end|>\n' }}
140
+ {%- elif loop.last %}
141
+ {{- '<|im_end|>\n' }}
142
+ {%- endif %}
143
+ {%- else %}
144
+ {{- raise_exception('Unexpected message role.') }}
145
+ {%- endif %}
146
+ {%- endfor %}
147
+ {%- if add_generation_prompt %}
148
+ {{- '<|im_start|>assistant\n' }}
149
+ {%- if enable_thinking is defined and enable_thinking is false %}
150
+ {{- '<think>\n\n</think>\n\n' }}
151
+ {%- else %}
152
+ {{- '<think>\n' }}
153
+ {%- endif %}
154
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ArgusForRetrieval"
4
+ ],
5
+ "dtype": "float32",
6
+ "image_token_id": 248056,
7
+ "model_type": "argus_colqwen35",
8
+ "rope_parameters": {},
9
+ "text_config": {
10
+ "attention_bias": false,
11
+ "attention_dropout": 0.0,
12
+ "attn_output_gate": true,
13
+ "bos_token_id": null,
14
+ "dtype": "float32",
15
+ "eos_token_id": 248044,
16
+ "full_attention_interval": 4,
17
+ "head_dim": 256,
18
+ "hidden_act": "silu",
19
+ "hidden_size": 2560,
20
+ "initializer_range": 0.02,
21
+ "intermediate_size": 9216,
22
+ "layer_types": [
23
+ "linear_attention",
24
+ "linear_attention",
25
+ "linear_attention",
26
+ "full_attention",
27
+ "linear_attention",
28
+ "linear_attention",
29
+ "linear_attention",
30
+ "full_attention",
31
+ "linear_attention",
32
+ "linear_attention",
33
+ "linear_attention",
34
+ "full_attention",
35
+ "linear_attention",
36
+ "linear_attention",
37
+ "linear_attention",
38
+ "full_attention",
39
+ "linear_attention",
40
+ "linear_attention",
41
+ "linear_attention",
42
+ "full_attention",
43
+ "linear_attention",
44
+ "linear_attention",
45
+ "linear_attention",
46
+ "full_attention",
47
+ "linear_attention",
48
+ "linear_attention",
49
+ "linear_attention",
50
+ "full_attention",
51
+ "linear_attention",
52
+ "linear_attention",
53
+ "linear_attention",
54
+ "full_attention"
55
+ ],
56
+ "linear_conv_kernel_dim": 4,
57
+ "linear_key_head_dim": 128,
58
+ "linear_num_key_heads": 16,
59
+ "linear_num_value_heads": 32,
60
+ "linear_value_head_dim": 128,
61
+ "mamba_ssm_dtype": "float32",
62
+ "max_position_embeddings": 262144,
63
+ "mlp_only_layers": [],
64
+ "model_type": "qwen3_5_text",
65
+ "mtp_num_hidden_layers": 1,
66
+ "mtp_use_dedicated_embeddings": false,
67
+ "num_attention_heads": 16,
68
+ "num_hidden_layers": 32,
69
+ "num_key_value_heads": 4,
70
+ "pad_token_id": null,
71
+ "partial_rotary_factor": 0.25,
72
+ "rms_norm_eps": 1e-06,
73
+ "rope_parameters": {
74
+ "mrope_interleaved": true,
75
+ "mrope_section": [
76
+ 11,
77
+ 11,
78
+ 10
79
+ ],
80
+ "partial_rotary_factor": 0.25,
81
+ "rope_theta": 10000000,
82
+ "rope_type": "default"
83
+ },
84
+ "tie_word_embeddings": true,
85
+ "use_cache": true,
86
+ "vocab_size": 248320
87
+ },
88
+ "tie_word_embeddings": true,
89
+ "transformers_version": "5.6.1",
90
+ "use_cache": false,
91
+ "video_token_id": 248057,
92
+ "vision_config": {
93
+ "deepstack_visual_indexes": [],
94
+ "depth": 24,
95
+ "dtype": "float32",
96
+ "hidden_act": "gelu_pytorch_tanh",
97
+ "hidden_size": 1024,
98
+ "in_channels": 3,
99
+ "initializer_range": 0.02,
100
+ "intermediate_size": 4096,
101
+ "model_type": "qwen3_5_vision",
102
+ "num_heads": 16,
103
+ "num_position_embeddings": 2304,
104
+ "out_hidden_size": 2560,
105
+ "patch_size": 16,
106
+ "spatial_merge_size": 2,
107
+ "temporal_patch_size": 2
108
+ },
109
+ "vision_end_token_id": 248054,
110
+ "vision_start_token_id": 248053,
111
+ "auto_map": {
112
+ "AutoConfig": "configuration_argus.ArgusConfig",
113
+ "AutoModel": "modeling_argus.ArgusForRetrieval",
114
+ "AutoProcessor": "processing_argus.ArgusProcessor"
115
+ },
116
+ "retrieval_dim": 1024,
117
+ "num_specialists": 4,
118
+ "top_k_experts": 2,
119
+ "region_size": 4,
120
+ "router_layer_index": -5,
121
+ "router_temperature": 0.8,
122
+ "router_noise_std": 0.0,
123
+ "mask_non_image_embeddings": true,
124
+ "shared_gate_init": 0.0,
125
+ "specialist_gate_init": 0.0
126
+ }
configuration_argus.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Argus: Region-Aware Query-Conditioned Mixture of Experts for Visual Document Retrieval.
2
+
3
+ Config class. Subclasses the Qwen3.5-VL config and adds the Argus-specific
4
+ retrieval + MoE hyperparameters. Used by ``AutoConfig.from_pretrained`` via the
5
+ ``auto_map`` field in ``config.json`` (requires ``trust_remote_code=True``).
6
+ """
7
+ from __future__ import annotations
8
+
9
+ try:
10
+ from transformers.models.qwen3_5 import Qwen3_5Config as _BackboneConfig
11
+ except ImportError:
12
+ try:
13
+ from transformers.models.qwen3_5 import Qwen35Config as _BackboneConfig
14
+ except ImportError as exc:
15
+ raise ImportError(
16
+ "Argus requires a transformers build that exposes the Qwen3.5 VL "
17
+ "classes (transformers.models.qwen3_5). Upgrade to transformers "
18
+ ">= 4.57.0.dev0."
19
+ ) from exc
20
+
21
+
22
+ class ArgusConfig(_BackboneConfig):
23
+ """Top-level config for Argus-Colqwen3.5-9B.
24
+
25
+ Holds the standard Qwen3.5-VL fields (text_config, vision_config, image
26
+ token ids, etc.) plus Argus-specific retrieval + MoE knobs:
27
+
28
+ - ``retrieval_dim``: output dimensionality of the multi-vector retrieval
29
+ head (``custom_text_proj``). Default: 768.
30
+ - ``num_specialists``: number of latent spatial experts in the MoE stack.
31
+ - ``top_k_experts``: sparsity of the router (top-k routing).
32
+ - ``region_size``: spatial pooling window (patches) for region tokens.
33
+ - ``router_layer_index``: hidden-state layer used as input to the router.
34
+ - ``router_temperature``: softmax temperature of the router.
35
+ - ``mask_non_image_embeddings``: zero out embedding positions that are
36
+ not image tokens at encode time (document side).
37
+ - ``shared_gate_init`` / ``specialist_gate_init``: logit-space init for
38
+ the gate scalars (sigmoid of these multiplies shared/specialist expert
39
+ contributions).
40
+ """
41
+
42
+ model_type = "argus_colqwen35"
43
+
44
+ def __init__(
45
+ self,
46
+ retrieval_dim: int = 768,
47
+ num_specialists: int = 4,
48
+ top_k_experts: int = 2,
49
+ region_size: int = 4,
50
+ router_layer_index: int = -5,
51
+ router_temperature: float = 0.8,
52
+ router_noise_std: float = 0.0,
53
+ mask_non_image_embeddings: bool = True,
54
+ shared_gate_init: float = 0.0,
55
+ specialist_gate_init: float = 0.0,
56
+ **kwargs,
57
+ ) -> None:
58
+ super().__init__(**kwargs)
59
+ self.retrieval_dim = int(retrieval_dim)
60
+ self.num_specialists = int(num_specialists)
61
+ self.top_k_experts = int(top_k_experts)
62
+ self.region_size = int(region_size)
63
+ self.router_layer_index = int(router_layer_index)
64
+ self.router_temperature = float(router_temperature)
65
+ self.router_noise_std = float(router_noise_std)
66
+ self.mask_non_image_embeddings = bool(mask_non_image_embeddings)
67
+ self.shared_gate_init = float(shared_gate_init)
68
+ self.specialist_gate_init = float(specialist_gate_init)
69
+
70
+
71
+ __all__ = ["ArgusConfig"]
eval_vidore_v1_v2.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Evaluate Argus-Colqwen3.5-9B on ViDoRe V1 + V2 using the official
3
+ ``vidore-benchmark`` library straight from the HuggingFace hub.
4
+
5
+ Why this wrapper exists
6
+ -----------------------
7
+ The reference evaluators live in https://github.com/illuin-tech/vidore-benchmark
8
+ — every ColPali / Nemotron / vidore leaderboard submission is scored against
9
+ ``ViDoReEvaluatorQA`` / ``ViDoReEvaluatorBEIR``. By delegating to those
10
+ evaluators here (instead of re-implementing nDCG/Recall/MRR locally) we
11
+ guarantee:
12
+
13
+ - ``None`` queries are filtered correctly (Shift, all SyntheticDocQA subsets).
14
+ - The full image corpus is preserved (distractors stay in the retrieval pool).
15
+ - MTEB-style metrics (ndcg/map/recall/precision/mrr at every k) match the
16
+ canonical leaderboard numbers bit-for-bit.
17
+
18
+ Usage
19
+ -----
20
+ pip install vidore-benchmark # or: pip install git+https://github.com/illuin-tech/vidore-benchmark
21
+
22
+ python eval_vidore_v1_v2.py \\
23
+ --model ./argus-colqwen3.5-9b-v0 \\
24
+ --benchmarks v1 v2 \\
25
+ --batch-query 4 \\
26
+ --batch-passage 2
27
+
28
+ Use ``--model DataScience-UIBK/Argus-Colqwen3.5-9B-v0`` once uploaded.
29
+ """
30
+ from __future__ import annotations
31
+
32
+ import argparse
33
+ import json
34
+ from pathlib import Path
35
+ from typing import Dict
36
+
37
+ import torch
38
+
39
+
40
+ # ---------------------- ViDoRe dataset catalog ---------------------- #
41
+
42
+ # ViDoRe V1 (QA format). Each HF dataset has a single ``test`` split with
43
+ # columns: query, image, image_filename. Some rows contain ``query=None``
44
+ # (distractors); the library handles this.
45
+ V1_DATASETS: Dict[str, str] = {
46
+ "ArxivQ": "vidore/arxivqa_test_subsampled",
47
+ "DocQ": "vidore/docvqa_test_subsampled",
48
+ "InfoQ": "vidore/infovqa_test_subsampled",
49
+ "TabF": "vidore/tabfquad_test_subsampled",
50
+ "TATQ": "vidore/tatdqa_test",
51
+ "Shift": "vidore/shiftproject_test",
52
+ "AI": "vidore/syntheticDocQA_artificial_intelligence_test",
53
+ "Energy": "vidore/syntheticDocQA_energy_test",
54
+ "Gov": "vidore/syntheticDocQA_government_reports_test",
55
+ "Health": "vidore/syntheticDocQA_healthcare_industry_test",
56
+ }
57
+
58
+ # ViDoRe V2 (BEIR format). Each HF repo exposes 3 dataset configs:
59
+ # ``corpus`` (images + corpus-id), ``queries`` (query text + query-id), and
60
+ # ``qrels`` (query-id, corpus-id, score). The library's ``ViDoReEvaluatorBEIR``
61
+ # expects that exact shape.
62
+ V2_DATASETS: Dict[str, str] = {
63
+ "MIT_Biomedical_Multi": "vidore/biomedical_lectures_v2",
64
+ "Economics_Macro_Multi": "vidore/economics_reports_v2",
65
+ "ESG_Restaurant_Human_EN": "vidore/esg_reports_human_labeled_v2",
66
+ "ESG_Restaurant_Synth_Multi": "vidore/esg_reports_v2",
67
+ }
68
+
69
+
70
+ # ---------------------- helpers ---------------------- #
71
+
72
+ def _load_model_and_processor(args: argparse.Namespace):
73
+ from transformers import AutoModel, AutoProcessor
74
+
75
+ dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype]
76
+ print(f"[eval] loading model: {args.model} ({args.dtype}, attn={args.attn_implementation})")
77
+
78
+ # ``dtype`` on transformers >= 4.57; older builds still use ``torch_dtype``.
79
+ load_kwargs = {"trust_remote_code": True, "attn_implementation": args.attn_implementation}
80
+ try:
81
+ model = AutoModel.from_pretrained(args.model, dtype=dtype, **load_kwargs).eval().cuda()
82
+ except TypeError:
83
+ model = AutoModel.from_pretrained(args.model, torch_dtype=dtype, **load_kwargs).eval().cuda()
84
+
85
+ processor = AutoProcessor.from_pretrained(
86
+ args.model,
87
+ trust_remote_code=True,
88
+ max_num_visual_tokens=args.max_num_visual_tokens,
89
+ )
90
+ return model, processor
91
+
92
+
93
+ class _EmbeddingOnlyWrapper(torch.nn.Module):
94
+ """Adapter that exposes the plain embeddings tensor to vidore-benchmark.
95
+
96
+ ``VisionRetriever.forward_queries`` / ``forward_passages`` call
97
+ ``self.model(**batch).to("cpu")``, i.e. they assume the model returns a
98
+ Tensor. ``ArgusForRetrieval.forward`` returns an ``ArgusOutput`` dataclass
99
+ (embeddings + region_embeddings + routing info) to keep the MoE analysis
100
+ surface. This wrapper unwraps ``.embeddings`` so the library sees the
101
+ expected shape without us having to touch the model class.
102
+ """
103
+
104
+ def __init__(self, inner: torch.nn.Module):
105
+ super().__init__()
106
+ self.inner = inner
107
+
108
+ def __getattr__(self, name):
109
+ # Delegate .device / .dtype / .eval() / etc. to the wrapped model.
110
+ try:
111
+ return super().__getattr__(name)
112
+ except AttributeError:
113
+ return getattr(self.inner, name)
114
+
115
+ def forward(self, **kwargs) -> torch.Tensor:
116
+ return self.inner(**kwargs).embeddings
117
+
118
+
119
+ def _build_retriever(model, processor):
120
+ from vidore_benchmark.retrievers import VisionRetriever
121
+ wrapped = _EmbeddingOnlyWrapper(model).eval()
122
+ # Older vidore-benchmark releases don't accept ``num_workers`` at all;
123
+ # newer ones do. Try-with-kwarg for portability.
124
+ try:
125
+ return VisionRetriever(model=wrapped, processor=processor, num_workers=0)
126
+ except TypeError:
127
+ return VisionRetriever(model=wrapped, processor=processor)
128
+
129
+
130
+ def _eval_v1(retriever, args: argparse.Namespace) -> Dict[str, Dict[str, float]]:
131
+ from datasets import load_dataset
132
+ from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorQA
133
+
134
+ evaluator = ViDoReEvaluatorQA(retriever)
135
+ results: Dict[str, Dict[str, float]] = {}
136
+ print("\n========== V1 ==========")
137
+ for short, repo_id in V1_DATASETS.items():
138
+ if args.datasets and short not in args.datasets:
139
+ continue
140
+ print(f"\n[V1:{short}] {repo_id}")
141
+ ds = load_dataset(repo_id, split="test")
142
+ metrics = evaluator.evaluate_dataset(
143
+ ds,
144
+ batch_query=args.batch_query,
145
+ batch_passage=args.batch_passage,
146
+ batch_score=args.batch_score,
147
+ )
148
+ results[short] = metrics
149
+ print(f" nDCG@5 = {metrics.get('ndcg_at_5', 0.0):.4f}")
150
+ return results
151
+
152
+
153
+ def _eval_v2(retriever, args: argparse.Namespace) -> Dict[str, Dict[str, float]]:
154
+ from datasets import load_dataset
155
+ from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorBEIR
156
+
157
+ evaluator = ViDoReEvaluatorBEIR(retriever)
158
+ results: Dict[str, Dict[str, float]] = {}
159
+ print("\n========== V2 ==========")
160
+ for short, repo_id in V2_DATASETS.items():
161
+ if args.datasets and short not in args.datasets:
162
+ continue
163
+ print(f"\n[V2:{short}] {repo_id}")
164
+ ds = {
165
+ "corpus": load_dataset(repo_id, "corpus", split="test"),
166
+ "queries": load_dataset(repo_id, "queries", split="test"),
167
+ "qrels": load_dataset(repo_id, "qrels", split="test"),
168
+ }
169
+ metrics = evaluator.evaluate_dataset(
170
+ ds,
171
+ batch_query=args.batch_query,
172
+ batch_passage=args.batch_passage,
173
+ batch_score=args.batch_score,
174
+ )
175
+ results[short] = metrics
176
+ print(f" nDCG@5 = {metrics.get('ndcg_at_5', 0.0):.4f}")
177
+ return results
178
+
179
+
180
+ # ---------------------- main ---------------------- #
181
+
182
+ def run(args: argparse.Namespace) -> None:
183
+ model, processor = _load_model_and_processor(args)
184
+ retriever = _build_retriever(model, processor)
185
+
186
+ all_results: Dict[str, Dict[str, Dict[str, float]]] = {"v1": {}, "v2": {}}
187
+ if "v1" in args.benchmarks:
188
+ all_results["v1"] = _eval_v1(retriever, args)
189
+ if "v2" in args.benchmarks:
190
+ all_results["v2"] = _eval_v2(retriever, args)
191
+
192
+ # Summary
193
+ print("\n========== summary ==========")
194
+ for bench, per_ds in all_results.items():
195
+ if not per_ds:
196
+ continue
197
+ avg = sum(m.get("ndcg_at_5", 0.0) for m in per_ds.values()) / max(len(per_ds), 1)
198
+ print(f"{bench.upper()} avg nDCG@5 = {avg:.4f} ({len(per_ds)} datasets)")
199
+
200
+ if args.output_json:
201
+ Path(args.output_json).write_text(json.dumps(all_results, indent=2, default=float))
202
+ print(f"[eval] saved: {args.output_json}")
203
+
204
+
205
+ def parse_args() -> argparse.Namespace:
206
+ p = argparse.ArgumentParser()
207
+ p.add_argument("--model", required=True,
208
+ help="HF repo id or local release folder.")
209
+ p.add_argument("--benchmarks", nargs="+", default=["v1", "v2"], choices=["v1", "v2"])
210
+ p.add_argument("--datasets", nargs="*", default=None,
211
+ help="Optional subset by short key (e.g. ArxivQ DocQ Shift).")
212
+ p.add_argument("--batch-query", type=int, default=4)
213
+ p.add_argument("--batch-passage", type=int, default=2)
214
+ p.add_argument("--batch-score", type=int, default=4)
215
+ p.add_argument("--max-num-visual-tokens", type=int, default=2048)
216
+ p.add_argument("--attn-implementation", default="flash_attention_2",
217
+ choices=["flash_attention_2", "sdpa", "eager"])
218
+ p.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float16", "float32"])
219
+ p.add_argument("--output-json", default=None)
220
+ return p.parse_args()
221
+
222
+
223
+ if __name__ == "__main__":
224
+ run(parse_args())
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6aaea4d5d4338bd0267bd98ca23bb8871d9ea99ba3c7d68ed29e8f6aa87d862c
3
+ size 4948840144
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:55019b705e79058c34a9fb7362b25cbe41d9d0ba1ce3598a152c265434a70db3
3
+ size 4997768760
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44aac9f8f861c749b3c2f58285059112fad086066b7c05cc6261e4bb92b939a4
3
+ size 4997768896
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d4ddb048505ddbf088ea467f62e50a611a449cb34f9c1a8b4b70d0b093ce0ae
3
+ size 3889497560
model.safetensors.index.json ADDED
@@ -0,0 +1,772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 18833786904
4
+ },
5
+ "weight_map": {
6
+ "visual.patch_embed.proj.weight": "model-00001-of-00004.safetensors",
7
+ "visual.patch_embed.proj.bias": "model-00001-of-00004.safetensors",
8
+ "visual.pos_embed.weight": "model-00001-of-00004.safetensors",
9
+ "visual.blocks.0.norm1.weight": "model-00001-of-00004.safetensors",
10
+ "visual.blocks.0.norm1.bias": "model-00001-of-00004.safetensors",
11
+ "visual.blocks.0.norm2.weight": "model-00001-of-00004.safetensors",
12
+ "visual.blocks.0.norm2.bias": "model-00001-of-00004.safetensors",
13
+ "visual.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors",
14
+ "visual.blocks.0.attn.qkv.bias": "model-00001-of-00004.safetensors",
15
+ "visual.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors",
16
+ "visual.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors",
17
+ "visual.blocks.0.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
18
+ "visual.blocks.0.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
19
+ "visual.blocks.0.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
20
+ "visual.blocks.0.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
21
+ "visual.blocks.1.norm1.weight": "model-00001-of-00004.safetensors",
22
+ "visual.blocks.1.norm1.bias": "model-00001-of-00004.safetensors",
23
+ "visual.blocks.1.norm2.weight": "model-00001-of-00004.safetensors",
24
+ "visual.blocks.1.norm2.bias": "model-00001-of-00004.safetensors",
25
+ "visual.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors",
26
+ "visual.blocks.1.attn.qkv.bias": "model-00001-of-00004.safetensors",
27
+ "visual.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors",
28
+ "visual.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors",
29
+ "visual.blocks.1.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
30
+ "visual.blocks.1.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
31
+ "visual.blocks.1.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
32
+ "visual.blocks.1.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
33
+ "visual.blocks.2.norm1.weight": "model-00001-of-00004.safetensors",
34
+ "visual.blocks.2.norm1.bias": "model-00001-of-00004.safetensors",
35
+ "visual.blocks.2.norm2.weight": "model-00001-of-00004.safetensors",
36
+ "visual.blocks.2.norm2.bias": "model-00001-of-00004.safetensors",
37
+ "visual.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors",
38
+ "visual.blocks.2.attn.qkv.bias": "model-00001-of-00004.safetensors",
39
+ "visual.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors",
40
+ "visual.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors",
41
+ "visual.blocks.2.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
42
+ "visual.blocks.2.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
43
+ "visual.blocks.2.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
44
+ "visual.blocks.2.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
45
+ "visual.blocks.3.norm1.weight": "model-00001-of-00004.safetensors",
46
+ "visual.blocks.3.norm1.bias": "model-00001-of-00004.safetensors",
47
+ "visual.blocks.3.norm2.weight": "model-00001-of-00004.safetensors",
48
+ "visual.blocks.3.norm2.bias": "model-00001-of-00004.safetensors",
49
+ "visual.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors",
50
+ "visual.blocks.3.attn.qkv.bias": "model-00001-of-00004.safetensors",
51
+ "visual.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors",
52
+ "visual.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors",
53
+ "visual.blocks.3.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
54
+ "visual.blocks.3.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
55
+ "visual.blocks.3.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
56
+ "visual.blocks.3.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
57
+ "visual.blocks.4.norm1.weight": "model-00001-of-00004.safetensors",
58
+ "visual.blocks.4.norm1.bias": "model-00001-of-00004.safetensors",
59
+ "visual.blocks.4.norm2.weight": "model-00001-of-00004.safetensors",
60
+ "visual.blocks.4.norm2.bias": "model-00001-of-00004.safetensors",
61
+ "visual.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors",
62
+ "visual.blocks.4.attn.qkv.bias": "model-00001-of-00004.safetensors",
63
+ "visual.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors",
64
+ "visual.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors",
65
+ "visual.blocks.4.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
66
+ "visual.blocks.4.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
67
+ "visual.blocks.4.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
68
+ "visual.blocks.4.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
69
+ "visual.blocks.5.norm1.weight": "model-00001-of-00004.safetensors",
70
+ "visual.blocks.5.norm1.bias": "model-00001-of-00004.safetensors",
71
+ "visual.blocks.5.norm2.weight": "model-00001-of-00004.safetensors",
72
+ "visual.blocks.5.norm2.bias": "model-00001-of-00004.safetensors",
73
+ "visual.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors",
74
+ "visual.blocks.5.attn.qkv.bias": "model-00001-of-00004.safetensors",
75
+ "visual.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors",
76
+ "visual.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors",
77
+ "visual.blocks.5.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
78
+ "visual.blocks.5.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
79
+ "visual.blocks.5.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
80
+ "visual.blocks.5.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
81
+ "visual.blocks.6.norm1.weight": "model-00001-of-00004.safetensors",
82
+ "visual.blocks.6.norm1.bias": "model-00001-of-00004.safetensors",
83
+ "visual.blocks.6.norm2.weight": "model-00001-of-00004.safetensors",
84
+ "visual.blocks.6.norm2.bias": "model-00001-of-00004.safetensors",
85
+ "visual.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors",
86
+ "visual.blocks.6.attn.qkv.bias": "model-00001-of-00004.safetensors",
87
+ "visual.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors",
88
+ "visual.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors",
89
+ "visual.blocks.6.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
90
+ "visual.blocks.6.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
91
+ "visual.blocks.6.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
92
+ "visual.blocks.6.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
93
+ "visual.blocks.7.norm1.weight": "model-00001-of-00004.safetensors",
94
+ "visual.blocks.7.norm1.bias": "model-00001-of-00004.safetensors",
95
+ "visual.blocks.7.norm2.weight": "model-00001-of-00004.safetensors",
96
+ "visual.blocks.7.norm2.bias": "model-00001-of-00004.safetensors",
97
+ "visual.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors",
98
+ "visual.blocks.7.attn.qkv.bias": "model-00001-of-00004.safetensors",
99
+ "visual.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors",
100
+ "visual.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors",
101
+ "visual.blocks.7.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
102
+ "visual.blocks.7.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
103
+ "visual.blocks.7.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
104
+ "visual.blocks.7.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
105
+ "visual.blocks.8.norm1.weight": "model-00001-of-00004.safetensors",
106
+ "visual.blocks.8.norm1.bias": "model-00001-of-00004.safetensors",
107
+ "visual.blocks.8.norm2.weight": "model-00001-of-00004.safetensors",
108
+ "visual.blocks.8.norm2.bias": "model-00001-of-00004.safetensors",
109
+ "visual.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors",
110
+ "visual.blocks.8.attn.qkv.bias": "model-00001-of-00004.safetensors",
111
+ "visual.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors",
112
+ "visual.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors",
113
+ "visual.blocks.8.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
114
+ "visual.blocks.8.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
115
+ "visual.blocks.8.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
116
+ "visual.blocks.8.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
117
+ "visual.blocks.9.norm1.weight": "model-00001-of-00004.safetensors",
118
+ "visual.blocks.9.norm1.bias": "model-00001-of-00004.safetensors",
119
+ "visual.blocks.9.norm2.weight": "model-00001-of-00004.safetensors",
120
+ "visual.blocks.9.norm2.bias": "model-00001-of-00004.safetensors",
121
+ "visual.blocks.9.attn.qkv.weight": "model-00001-of-00004.safetensors",
122
+ "visual.blocks.9.attn.qkv.bias": "model-00001-of-00004.safetensors",
123
+ "visual.blocks.9.attn.proj.weight": "model-00001-of-00004.safetensors",
124
+ "visual.blocks.9.attn.proj.bias": "model-00001-of-00004.safetensors",
125
+ "visual.blocks.9.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
126
+ "visual.blocks.9.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
127
+ "visual.blocks.9.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
128
+ "visual.blocks.9.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
129
+ "visual.blocks.10.norm1.weight": "model-00001-of-00004.safetensors",
130
+ "visual.blocks.10.norm1.bias": "model-00001-of-00004.safetensors",
131
+ "visual.blocks.10.norm2.weight": "model-00001-of-00004.safetensors",
132
+ "visual.blocks.10.norm2.bias": "model-00001-of-00004.safetensors",
133
+ "visual.blocks.10.attn.qkv.weight": "model-00001-of-00004.safetensors",
134
+ "visual.blocks.10.attn.qkv.bias": "model-00001-of-00004.safetensors",
135
+ "visual.blocks.10.attn.proj.weight": "model-00001-of-00004.safetensors",
136
+ "visual.blocks.10.attn.proj.bias": "model-00001-of-00004.safetensors",
137
+ "visual.blocks.10.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
138
+ "visual.blocks.10.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
139
+ "visual.blocks.10.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
140
+ "visual.blocks.10.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
141
+ "visual.blocks.11.norm1.weight": "model-00001-of-00004.safetensors",
142
+ "visual.blocks.11.norm1.bias": "model-00001-of-00004.safetensors",
143
+ "visual.blocks.11.norm2.weight": "model-00001-of-00004.safetensors",
144
+ "visual.blocks.11.norm2.bias": "model-00001-of-00004.safetensors",
145
+ "visual.blocks.11.attn.qkv.weight": "model-00001-of-00004.safetensors",
146
+ "visual.blocks.11.attn.qkv.bias": "model-00001-of-00004.safetensors",
147
+ "visual.blocks.11.attn.proj.weight": "model-00001-of-00004.safetensors",
148
+ "visual.blocks.11.attn.proj.bias": "model-00001-of-00004.safetensors",
149
+ "visual.blocks.11.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
150
+ "visual.blocks.11.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
151
+ "visual.blocks.11.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
152
+ "visual.blocks.11.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
153
+ "visual.blocks.12.norm1.weight": "model-00001-of-00004.safetensors",
154
+ "visual.blocks.12.norm1.bias": "model-00001-of-00004.safetensors",
155
+ "visual.blocks.12.norm2.weight": "model-00001-of-00004.safetensors",
156
+ "visual.blocks.12.norm2.bias": "model-00001-of-00004.safetensors",
157
+ "visual.blocks.12.attn.qkv.weight": "model-00001-of-00004.safetensors",
158
+ "visual.blocks.12.attn.qkv.bias": "model-00001-of-00004.safetensors",
159
+ "visual.blocks.12.attn.proj.weight": "model-00001-of-00004.safetensors",
160
+ "visual.blocks.12.attn.proj.bias": "model-00001-of-00004.safetensors",
161
+ "visual.blocks.12.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
162
+ "visual.blocks.12.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
163
+ "visual.blocks.12.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
164
+ "visual.blocks.12.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
165
+ "visual.blocks.13.norm1.weight": "model-00001-of-00004.safetensors",
166
+ "visual.blocks.13.norm1.bias": "model-00001-of-00004.safetensors",
167
+ "visual.blocks.13.norm2.weight": "model-00001-of-00004.safetensors",
168
+ "visual.blocks.13.norm2.bias": "model-00001-of-00004.safetensors",
169
+ "visual.blocks.13.attn.qkv.weight": "model-00001-of-00004.safetensors",
170
+ "visual.blocks.13.attn.qkv.bias": "model-00001-of-00004.safetensors",
171
+ "visual.blocks.13.attn.proj.weight": "model-00001-of-00004.safetensors",
172
+ "visual.blocks.13.attn.proj.bias": "model-00001-of-00004.safetensors",
173
+ "visual.blocks.13.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
174
+ "visual.blocks.13.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
175
+ "visual.blocks.13.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
176
+ "visual.blocks.13.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
177
+ "visual.blocks.14.norm1.weight": "model-00001-of-00004.safetensors",
178
+ "visual.blocks.14.norm1.bias": "model-00001-of-00004.safetensors",
179
+ "visual.blocks.14.norm2.weight": "model-00001-of-00004.safetensors",
180
+ "visual.blocks.14.norm2.bias": "model-00001-of-00004.safetensors",
181
+ "visual.blocks.14.attn.qkv.weight": "model-00001-of-00004.safetensors",
182
+ "visual.blocks.14.attn.qkv.bias": "model-00001-of-00004.safetensors",
183
+ "visual.blocks.14.attn.proj.weight": "model-00001-of-00004.safetensors",
184
+ "visual.blocks.14.attn.proj.bias": "model-00001-of-00004.safetensors",
185
+ "visual.blocks.14.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
186
+ "visual.blocks.14.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
187
+ "visual.blocks.14.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
188
+ "visual.blocks.14.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
189
+ "visual.blocks.15.norm1.weight": "model-00001-of-00004.safetensors",
190
+ "visual.blocks.15.norm1.bias": "model-00001-of-00004.safetensors",
191
+ "visual.blocks.15.norm2.weight": "model-00001-of-00004.safetensors",
192
+ "visual.blocks.15.norm2.bias": "model-00001-of-00004.safetensors",
193
+ "visual.blocks.15.attn.qkv.weight": "model-00001-of-00004.safetensors",
194
+ "visual.blocks.15.attn.qkv.bias": "model-00001-of-00004.safetensors",
195
+ "visual.blocks.15.attn.proj.weight": "model-00001-of-00004.safetensors",
196
+ "visual.blocks.15.attn.proj.bias": "model-00001-of-00004.safetensors",
197
+ "visual.blocks.15.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
198
+ "visual.blocks.15.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
199
+ "visual.blocks.15.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
200
+ "visual.blocks.15.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
201
+ "visual.blocks.16.norm1.weight": "model-00001-of-00004.safetensors",
202
+ "visual.blocks.16.norm1.bias": "model-00001-of-00004.safetensors",
203
+ "visual.blocks.16.norm2.weight": "model-00001-of-00004.safetensors",
204
+ "visual.blocks.16.norm2.bias": "model-00001-of-00004.safetensors",
205
+ "visual.blocks.16.attn.qkv.weight": "model-00001-of-00004.safetensors",
206
+ "visual.blocks.16.attn.qkv.bias": "model-00001-of-00004.safetensors",
207
+ "visual.blocks.16.attn.proj.weight": "model-00001-of-00004.safetensors",
208
+ "visual.blocks.16.attn.proj.bias": "model-00001-of-00004.safetensors",
209
+ "visual.blocks.16.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
210
+ "visual.blocks.16.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
211
+ "visual.blocks.16.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
212
+ "visual.blocks.16.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
213
+ "visual.blocks.17.norm1.weight": "model-00001-of-00004.safetensors",
214
+ "visual.blocks.17.norm1.bias": "model-00001-of-00004.safetensors",
215
+ "visual.blocks.17.norm2.weight": "model-00001-of-00004.safetensors",
216
+ "visual.blocks.17.norm2.bias": "model-00001-of-00004.safetensors",
217
+ "visual.blocks.17.attn.qkv.weight": "model-00001-of-00004.safetensors",
218
+ "visual.blocks.17.attn.qkv.bias": "model-00001-of-00004.safetensors",
219
+ "visual.blocks.17.attn.proj.weight": "model-00001-of-00004.safetensors",
220
+ "visual.blocks.17.attn.proj.bias": "model-00001-of-00004.safetensors",
221
+ "visual.blocks.17.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
222
+ "visual.blocks.17.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
223
+ "visual.blocks.17.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
224
+ "visual.blocks.17.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
225
+ "visual.blocks.18.norm1.weight": "model-00001-of-00004.safetensors",
226
+ "visual.blocks.18.norm1.bias": "model-00001-of-00004.safetensors",
227
+ "visual.blocks.18.norm2.weight": "model-00001-of-00004.safetensors",
228
+ "visual.blocks.18.norm2.bias": "model-00001-of-00004.safetensors",
229
+ "visual.blocks.18.attn.qkv.weight": "model-00001-of-00004.safetensors",
230
+ "visual.blocks.18.attn.qkv.bias": "model-00001-of-00004.safetensors",
231
+ "visual.blocks.18.attn.proj.weight": "model-00001-of-00004.safetensors",
232
+ "visual.blocks.18.attn.proj.bias": "model-00001-of-00004.safetensors",
233
+ "visual.blocks.18.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
234
+ "visual.blocks.18.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
235
+ "visual.blocks.18.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
236
+ "visual.blocks.18.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
237
+ "visual.blocks.19.norm1.weight": "model-00001-of-00004.safetensors",
238
+ "visual.blocks.19.norm1.bias": "model-00001-of-00004.safetensors",
239
+ "visual.blocks.19.norm2.weight": "model-00001-of-00004.safetensors",
240
+ "visual.blocks.19.norm2.bias": "model-00001-of-00004.safetensors",
241
+ "visual.blocks.19.attn.qkv.weight": "model-00001-of-00004.safetensors",
242
+ "visual.blocks.19.attn.qkv.bias": "model-00001-of-00004.safetensors",
243
+ "visual.blocks.19.attn.proj.weight": "model-00001-of-00004.safetensors",
244
+ "visual.blocks.19.attn.proj.bias": "model-00001-of-00004.safetensors",
245
+ "visual.blocks.19.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
246
+ "visual.blocks.19.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
247
+ "visual.blocks.19.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
248
+ "visual.blocks.19.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
249
+ "visual.blocks.20.norm1.weight": "model-00001-of-00004.safetensors",
250
+ "visual.blocks.20.norm1.bias": "model-00001-of-00004.safetensors",
251
+ "visual.blocks.20.norm2.weight": "model-00001-of-00004.safetensors",
252
+ "visual.blocks.20.norm2.bias": "model-00001-of-00004.safetensors",
253
+ "visual.blocks.20.attn.qkv.weight": "model-00001-of-00004.safetensors",
254
+ "visual.blocks.20.attn.qkv.bias": "model-00001-of-00004.safetensors",
255
+ "visual.blocks.20.attn.proj.weight": "model-00001-of-00004.safetensors",
256
+ "visual.blocks.20.attn.proj.bias": "model-00001-of-00004.safetensors",
257
+ "visual.blocks.20.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
258
+ "visual.blocks.20.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
259
+ "visual.blocks.20.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
260
+ "visual.blocks.20.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
261
+ "visual.blocks.21.norm1.weight": "model-00001-of-00004.safetensors",
262
+ "visual.blocks.21.norm1.bias": "model-00001-of-00004.safetensors",
263
+ "visual.blocks.21.norm2.weight": "model-00001-of-00004.safetensors",
264
+ "visual.blocks.21.norm2.bias": "model-00001-of-00004.safetensors",
265
+ "visual.blocks.21.attn.qkv.weight": "model-00001-of-00004.safetensors",
266
+ "visual.blocks.21.attn.qkv.bias": "model-00001-of-00004.safetensors",
267
+ "visual.blocks.21.attn.proj.weight": "model-00001-of-00004.safetensors",
268
+ "visual.blocks.21.attn.proj.bias": "model-00001-of-00004.safetensors",
269
+ "visual.blocks.21.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
270
+ "visual.blocks.21.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
271
+ "visual.blocks.21.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
272
+ "visual.blocks.21.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
273
+ "visual.blocks.22.norm1.weight": "model-00001-of-00004.safetensors",
274
+ "visual.blocks.22.norm1.bias": "model-00001-of-00004.safetensors",
275
+ "visual.blocks.22.norm2.weight": "model-00001-of-00004.safetensors",
276
+ "visual.blocks.22.norm2.bias": "model-00001-of-00004.safetensors",
277
+ "visual.blocks.22.attn.qkv.weight": "model-00001-of-00004.safetensors",
278
+ "visual.blocks.22.attn.qkv.bias": "model-00001-of-00004.safetensors",
279
+ "visual.blocks.22.attn.proj.weight": "model-00001-of-00004.safetensors",
280
+ "visual.blocks.22.attn.proj.bias": "model-00001-of-00004.safetensors",
281
+ "visual.blocks.22.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
282
+ "visual.blocks.22.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
283
+ "visual.blocks.22.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
284
+ "visual.blocks.22.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
285
+ "visual.blocks.23.norm1.weight": "model-00001-of-00004.safetensors",
286
+ "visual.blocks.23.norm1.bias": "model-00001-of-00004.safetensors",
287
+ "visual.blocks.23.norm2.weight": "model-00001-of-00004.safetensors",
288
+ "visual.blocks.23.norm2.bias": "model-00001-of-00004.safetensors",
289
+ "visual.blocks.23.attn.qkv.weight": "model-00001-of-00004.safetensors",
290
+ "visual.blocks.23.attn.qkv.bias": "model-00001-of-00004.safetensors",
291
+ "visual.blocks.23.attn.proj.weight": "model-00001-of-00004.safetensors",
292
+ "visual.blocks.23.attn.proj.bias": "model-00001-of-00004.safetensors",
293
+ "visual.blocks.23.mlp.linear_fc1.weight": "model-00001-of-00004.safetensors",
294
+ "visual.blocks.23.mlp.linear_fc1.bias": "model-00001-of-00004.safetensors",
295
+ "visual.blocks.23.mlp.linear_fc2.weight": "model-00001-of-00004.safetensors",
296
+ "visual.blocks.23.mlp.linear_fc2.bias": "model-00001-of-00004.safetensors",
297
+ "visual.merger.norm.weight": "model-00001-of-00004.safetensors",
298
+ "visual.merger.norm.bias": "model-00001-of-00004.safetensors",
299
+ "visual.merger.linear_fc1.weight": "model-00001-of-00004.safetensors",
300
+ "visual.merger.linear_fc1.bias": "model-00001-of-00004.safetensors",
301
+ "visual.merger.linear_fc2.weight": "model-00001-of-00004.safetensors",
302
+ "visual.merger.linear_fc2.bias": "model-00001-of-00004.safetensors",
303
+ "language_model.embed_tokens.weight": "model-00001-of-00004.safetensors",
304
+ "language_model.layers.0.linear_attn.dt_bias": "model-00001-of-00004.safetensors",
305
+ "language_model.layers.0.linear_attn.A_log": "model-00001-of-00004.safetensors",
306
+ "language_model.layers.0.linear_attn.conv1d.weight": "model-00001-of-00004.safetensors",
307
+ "language_model.layers.0.linear_attn.norm.weight": "model-00001-of-00004.safetensors",
308
+ "language_model.layers.0.linear_attn.out_proj.weight": "model-00001-of-00004.safetensors",
309
+ "language_model.layers.0.linear_attn.in_proj_qkv.weight": "model-00001-of-00004.safetensors",
310
+ "language_model.layers.0.linear_attn.in_proj_z.weight": "model-00001-of-00004.safetensors",
311
+ "language_model.layers.0.linear_attn.in_proj_b.weight": "model-00001-of-00004.safetensors",
312
+ "language_model.layers.0.linear_attn.in_proj_a.weight": "model-00001-of-00004.safetensors",
313
+ "language_model.layers.0.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
314
+ "language_model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
315
+ "language_model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
316
+ "language_model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
317
+ "language_model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
318
+ "language_model.layers.1.linear_attn.dt_bias": "model-00001-of-00004.safetensors",
319
+ "language_model.layers.1.linear_attn.A_log": "model-00001-of-00004.safetensors",
320
+ "language_model.layers.1.linear_attn.conv1d.weight": "model-00001-of-00004.safetensors",
321
+ "language_model.layers.1.linear_attn.norm.weight": "model-00001-of-00004.safetensors",
322
+ "language_model.layers.1.linear_attn.out_proj.weight": "model-00001-of-00004.safetensors",
323
+ "language_model.layers.1.linear_attn.in_proj_qkv.weight": "model-00001-of-00004.safetensors",
324
+ "language_model.layers.1.linear_attn.in_proj_z.weight": "model-00001-of-00004.safetensors",
325
+ "language_model.layers.1.linear_attn.in_proj_b.weight": "model-00001-of-00004.safetensors",
326
+ "language_model.layers.1.linear_attn.in_proj_a.weight": "model-00001-of-00004.safetensors",
327
+ "language_model.layers.1.mlp.gate_proj.weight": "model-00001-of-00004.safetensors",
328
+ "language_model.layers.1.mlp.up_proj.weight": "model-00001-of-00004.safetensors",
329
+ "language_model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
330
+ "language_model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
331
+ "language_model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
332
+ "language_model.layers.2.linear_attn.dt_bias": "model-00001-of-00004.safetensors",
333
+ "language_model.layers.2.linear_attn.A_log": "model-00001-of-00004.safetensors",
334
+ "language_model.layers.2.linear_attn.conv1d.weight": "model-00001-of-00004.safetensors",
335
+ "language_model.layers.2.linear_attn.norm.weight": "model-00001-of-00004.safetensors",
336
+ "language_model.layers.2.linear_attn.out_proj.weight": "model-00001-of-00004.safetensors",
337
+ "language_model.layers.2.linear_attn.in_proj_qkv.weight": "model-00001-of-00004.safetensors",
338
+ "language_model.layers.2.linear_attn.in_proj_z.weight": "model-00001-of-00004.safetensors",
339
+ "language_model.layers.2.linear_attn.in_proj_b.weight": "model-00001-of-00004.safetensors",
340
+ "language_model.layers.2.linear_attn.in_proj_a.weight": "model-00001-of-00004.safetensors",
341
+ "language_model.layers.2.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
342
+ "language_model.layers.2.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
343
+ "language_model.layers.2.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
344
+ "language_model.layers.2.input_layernorm.weight": "model-00002-of-00004.safetensors",
345
+ "language_model.layers.2.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
346
+ "language_model.layers.3.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
347
+ "language_model.layers.3.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
348
+ "language_model.layers.3.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
349
+ "language_model.layers.3.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
350
+ "language_model.layers.3.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
351
+ "language_model.layers.3.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
352
+ "language_model.layers.3.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
353
+ "language_model.layers.3.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
354
+ "language_model.layers.3.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
355
+ "language_model.layers.3.input_layernorm.weight": "model-00002-of-00004.safetensors",
356
+ "language_model.layers.3.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
357
+ "language_model.layers.4.linear_attn.dt_bias": "model-00002-of-00004.safetensors",
358
+ "language_model.layers.4.linear_attn.A_log": "model-00002-of-00004.safetensors",
359
+ "language_model.layers.4.linear_attn.conv1d.weight": "model-00002-of-00004.safetensors",
360
+ "language_model.layers.4.linear_attn.norm.weight": "model-00002-of-00004.safetensors",
361
+ "language_model.layers.4.linear_attn.out_proj.weight": "model-00002-of-00004.safetensors",
362
+ "language_model.layers.4.linear_attn.in_proj_qkv.weight": "model-00002-of-00004.safetensors",
363
+ "language_model.layers.4.linear_attn.in_proj_z.weight": "model-00002-of-00004.safetensors",
364
+ "language_model.layers.4.linear_attn.in_proj_b.weight": "model-00002-of-00004.safetensors",
365
+ "language_model.layers.4.linear_attn.in_proj_a.weight": "model-00002-of-00004.safetensors",
366
+ "language_model.layers.4.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
367
+ "language_model.layers.4.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
368
+ "language_model.layers.4.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
369
+ "language_model.layers.4.input_layernorm.weight": "model-00002-of-00004.safetensors",
370
+ "language_model.layers.4.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
371
+ "language_model.layers.5.linear_attn.dt_bias": "model-00002-of-00004.safetensors",
372
+ "language_model.layers.5.linear_attn.A_log": "model-00002-of-00004.safetensors",
373
+ "language_model.layers.5.linear_attn.conv1d.weight": "model-00002-of-00004.safetensors",
374
+ "language_model.layers.5.linear_attn.norm.weight": "model-00002-of-00004.safetensors",
375
+ "language_model.layers.5.linear_attn.out_proj.weight": "model-00002-of-00004.safetensors",
376
+ "language_model.layers.5.linear_attn.in_proj_qkv.weight": "model-00002-of-00004.safetensors",
377
+ "language_model.layers.5.linear_attn.in_proj_z.weight": "model-00002-of-00004.safetensors",
378
+ "language_model.layers.5.linear_attn.in_proj_b.weight": "model-00002-of-00004.safetensors",
379
+ "language_model.layers.5.linear_attn.in_proj_a.weight": "model-00002-of-00004.safetensors",
380
+ "language_model.layers.5.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
381
+ "language_model.layers.5.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
382
+ "language_model.layers.5.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
383
+ "language_model.layers.5.input_layernorm.weight": "model-00002-of-00004.safetensors",
384
+ "language_model.layers.5.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
385
+ "language_model.layers.6.linear_attn.dt_bias": "model-00002-of-00004.safetensors",
386
+ "language_model.layers.6.linear_attn.A_log": "model-00002-of-00004.safetensors",
387
+ "language_model.layers.6.linear_attn.conv1d.weight": "model-00002-of-00004.safetensors",
388
+ "language_model.layers.6.linear_attn.norm.weight": "model-00002-of-00004.safetensors",
389
+ "language_model.layers.6.linear_attn.out_proj.weight": "model-00002-of-00004.safetensors",
390
+ "language_model.layers.6.linear_attn.in_proj_qkv.weight": "model-00002-of-00004.safetensors",
391
+ "language_model.layers.6.linear_attn.in_proj_z.weight": "model-00002-of-00004.safetensors",
392
+ "language_model.layers.6.linear_attn.in_proj_b.weight": "model-00002-of-00004.safetensors",
393
+ "language_model.layers.6.linear_attn.in_proj_a.weight": "model-00002-of-00004.safetensors",
394
+ "language_model.layers.6.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
395
+ "language_model.layers.6.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
396
+ "language_model.layers.6.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
397
+ "language_model.layers.6.input_layernorm.weight": "model-00002-of-00004.safetensors",
398
+ "language_model.layers.6.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
399
+ "language_model.layers.7.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
400
+ "language_model.layers.7.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
401
+ "language_model.layers.7.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
402
+ "language_model.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
403
+ "language_model.layers.7.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
404
+ "language_model.layers.7.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
405
+ "language_model.layers.7.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
406
+ "language_model.layers.7.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
407
+ "language_model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
408
+ "language_model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
409
+ "language_model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
410
+ "language_model.layers.8.linear_attn.dt_bias": "model-00002-of-00004.safetensors",
411
+ "language_model.layers.8.linear_attn.A_log": "model-00002-of-00004.safetensors",
412
+ "language_model.layers.8.linear_attn.conv1d.weight": "model-00002-of-00004.safetensors",
413
+ "language_model.layers.8.linear_attn.norm.weight": "model-00002-of-00004.safetensors",
414
+ "language_model.layers.8.linear_attn.out_proj.weight": "model-00002-of-00004.safetensors",
415
+ "language_model.layers.8.linear_attn.in_proj_qkv.weight": "model-00002-of-00004.safetensors",
416
+ "language_model.layers.8.linear_attn.in_proj_z.weight": "model-00002-of-00004.safetensors",
417
+ "language_model.layers.8.linear_attn.in_proj_b.weight": "model-00002-of-00004.safetensors",
418
+ "language_model.layers.8.linear_attn.in_proj_a.weight": "model-00002-of-00004.safetensors",
419
+ "language_model.layers.8.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
420
+ "language_model.layers.8.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
421
+ "language_model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
422
+ "language_model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
423
+ "language_model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
424
+ "language_model.layers.9.linear_attn.dt_bias": "model-00002-of-00004.safetensors",
425
+ "language_model.layers.9.linear_attn.A_log": "model-00002-of-00004.safetensors",
426
+ "language_model.layers.9.linear_attn.conv1d.weight": "model-00002-of-00004.safetensors",
427
+ "language_model.layers.9.linear_attn.norm.weight": "model-00002-of-00004.safetensors",
428
+ "language_model.layers.9.linear_attn.out_proj.weight": "model-00002-of-00004.safetensors",
429
+ "language_model.layers.9.linear_attn.in_proj_qkv.weight": "model-00002-of-00004.safetensors",
430
+ "language_model.layers.9.linear_attn.in_proj_z.weight": "model-00002-of-00004.safetensors",
431
+ "language_model.layers.9.linear_attn.in_proj_b.weight": "model-00002-of-00004.safetensors",
432
+ "language_model.layers.9.linear_attn.in_proj_a.weight": "model-00002-of-00004.safetensors",
433
+ "language_model.layers.9.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
434
+ "language_model.layers.9.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
435
+ "language_model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
436
+ "language_model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
437
+ "language_model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
438
+ "language_model.layers.10.linear_attn.dt_bias": "model-00002-of-00004.safetensors",
439
+ "language_model.layers.10.linear_attn.A_log": "model-00002-of-00004.safetensors",
440
+ "language_model.layers.10.linear_attn.conv1d.weight": "model-00002-of-00004.safetensors",
441
+ "language_model.layers.10.linear_attn.norm.weight": "model-00002-of-00004.safetensors",
442
+ "language_model.layers.10.linear_attn.out_proj.weight": "model-00002-of-00004.safetensors",
443
+ "language_model.layers.10.linear_attn.in_proj_qkv.weight": "model-00002-of-00004.safetensors",
444
+ "language_model.layers.10.linear_attn.in_proj_z.weight": "model-00002-of-00004.safetensors",
445
+ "language_model.layers.10.linear_attn.in_proj_b.weight": "model-00002-of-00004.safetensors",
446
+ "language_model.layers.10.linear_attn.in_proj_a.weight": "model-00002-of-00004.safetensors",
447
+ "language_model.layers.10.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
448
+ "language_model.layers.10.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
449
+ "language_model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
450
+ "language_model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
451
+ "language_model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
452
+ "language_model.layers.11.self_attn.q_proj.weight": "model-00002-of-00004.safetensors",
453
+ "language_model.layers.11.self_attn.k_proj.weight": "model-00002-of-00004.safetensors",
454
+ "language_model.layers.11.self_attn.v_proj.weight": "model-00002-of-00004.safetensors",
455
+ "language_model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
456
+ "language_model.layers.11.self_attn.q_norm.weight": "model-00002-of-00004.safetensors",
457
+ "language_model.layers.11.self_attn.k_norm.weight": "model-00002-of-00004.safetensors",
458
+ "language_model.layers.11.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
459
+ "language_model.layers.11.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
460
+ "language_model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
461
+ "language_model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
462
+ "language_model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
463
+ "language_model.layers.12.linear_attn.dt_bias": "model-00002-of-00004.safetensors",
464
+ "language_model.layers.12.linear_attn.A_log": "model-00002-of-00004.safetensors",
465
+ "language_model.layers.12.linear_attn.conv1d.weight": "model-00002-of-00004.safetensors",
466
+ "language_model.layers.12.linear_attn.norm.weight": "model-00002-of-00004.safetensors",
467
+ "language_model.layers.12.linear_attn.out_proj.weight": "model-00002-of-00004.safetensors",
468
+ "language_model.layers.12.linear_attn.in_proj_qkv.weight": "model-00002-of-00004.safetensors",
469
+ "language_model.layers.12.linear_attn.in_proj_z.weight": "model-00002-of-00004.safetensors",
470
+ "language_model.layers.12.linear_attn.in_proj_b.weight": "model-00002-of-00004.safetensors",
471
+ "language_model.layers.12.linear_attn.in_proj_a.weight": "model-00002-of-00004.safetensors",
472
+ "language_model.layers.12.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
473
+ "language_model.layers.12.mlp.up_proj.weight": "model-00002-of-00004.safetensors",
474
+ "language_model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
475
+ "language_model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
476
+ "language_model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
477
+ "language_model.layers.13.linear_attn.dt_bias": "model-00002-of-00004.safetensors",
478
+ "language_model.layers.13.linear_attn.A_log": "model-00002-of-00004.safetensors",
479
+ "language_model.layers.13.linear_attn.conv1d.weight": "model-00002-of-00004.safetensors",
480
+ "language_model.layers.13.linear_attn.norm.weight": "model-00002-of-00004.safetensors",
481
+ "language_model.layers.13.linear_attn.out_proj.weight": "model-00002-of-00004.safetensors",
482
+ "language_model.layers.13.linear_attn.in_proj_qkv.weight": "model-00002-of-00004.safetensors",
483
+ "language_model.layers.13.linear_attn.in_proj_z.weight": "model-00002-of-00004.safetensors",
484
+ "language_model.layers.13.linear_attn.in_proj_b.weight": "model-00002-of-00004.safetensors",
485
+ "language_model.layers.13.linear_attn.in_proj_a.weight": "model-00002-of-00004.safetensors",
486
+ "language_model.layers.13.mlp.gate_proj.weight": "model-00002-of-00004.safetensors",
487
+ "language_model.layers.13.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
488
+ "language_model.layers.13.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
489
+ "language_model.layers.13.input_layernorm.weight": "model-00003-of-00004.safetensors",
490
+ "language_model.layers.13.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
491
+ "language_model.layers.14.linear_attn.dt_bias": "model-00003-of-00004.safetensors",
492
+ "language_model.layers.14.linear_attn.A_log": "model-00003-of-00004.safetensors",
493
+ "language_model.layers.14.linear_attn.conv1d.weight": "model-00003-of-00004.safetensors",
494
+ "language_model.layers.14.linear_attn.norm.weight": "model-00003-of-00004.safetensors",
495
+ "language_model.layers.14.linear_attn.out_proj.weight": "model-00003-of-00004.safetensors",
496
+ "language_model.layers.14.linear_attn.in_proj_qkv.weight": "model-00003-of-00004.safetensors",
497
+ "language_model.layers.14.linear_attn.in_proj_z.weight": "model-00003-of-00004.safetensors",
498
+ "language_model.layers.14.linear_attn.in_proj_b.weight": "model-00003-of-00004.safetensors",
499
+ "language_model.layers.14.linear_attn.in_proj_a.weight": "model-00003-of-00004.safetensors",
500
+ "language_model.layers.14.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
501
+ "language_model.layers.14.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
502
+ "language_model.layers.14.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
503
+ "language_model.layers.14.input_layernorm.weight": "model-00003-of-00004.safetensors",
504
+ "language_model.layers.14.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
505
+ "language_model.layers.15.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
506
+ "language_model.layers.15.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
507
+ "language_model.layers.15.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
508
+ "language_model.layers.15.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
509
+ "language_model.layers.15.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
510
+ "language_model.layers.15.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
511
+ "language_model.layers.15.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
512
+ "language_model.layers.15.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
513
+ "language_model.layers.15.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
514
+ "language_model.layers.15.input_layernorm.weight": "model-00003-of-00004.safetensors",
515
+ "language_model.layers.15.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
516
+ "language_model.layers.16.linear_attn.dt_bias": "model-00003-of-00004.safetensors",
517
+ "language_model.layers.16.linear_attn.A_log": "model-00003-of-00004.safetensors",
518
+ "language_model.layers.16.linear_attn.conv1d.weight": "model-00003-of-00004.safetensors",
519
+ "language_model.layers.16.linear_attn.norm.weight": "model-00003-of-00004.safetensors",
520
+ "language_model.layers.16.linear_attn.out_proj.weight": "model-00003-of-00004.safetensors",
521
+ "language_model.layers.16.linear_attn.in_proj_qkv.weight": "model-00003-of-00004.safetensors",
522
+ "language_model.layers.16.linear_attn.in_proj_z.weight": "model-00003-of-00004.safetensors",
523
+ "language_model.layers.16.linear_attn.in_proj_b.weight": "model-00003-of-00004.safetensors",
524
+ "language_model.layers.16.linear_attn.in_proj_a.weight": "model-00003-of-00004.safetensors",
525
+ "language_model.layers.16.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
526
+ "language_model.layers.16.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
527
+ "language_model.layers.16.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
528
+ "language_model.layers.16.input_layernorm.weight": "model-00003-of-00004.safetensors",
529
+ "language_model.layers.16.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
530
+ "language_model.layers.17.linear_attn.dt_bias": "model-00003-of-00004.safetensors",
531
+ "language_model.layers.17.linear_attn.A_log": "model-00003-of-00004.safetensors",
532
+ "language_model.layers.17.linear_attn.conv1d.weight": "model-00003-of-00004.safetensors",
533
+ "language_model.layers.17.linear_attn.norm.weight": "model-00003-of-00004.safetensors",
534
+ "language_model.layers.17.linear_attn.out_proj.weight": "model-00003-of-00004.safetensors",
535
+ "language_model.layers.17.linear_attn.in_proj_qkv.weight": "model-00003-of-00004.safetensors",
536
+ "language_model.layers.17.linear_attn.in_proj_z.weight": "model-00003-of-00004.safetensors",
537
+ "language_model.layers.17.linear_attn.in_proj_b.weight": "model-00003-of-00004.safetensors",
538
+ "language_model.layers.17.linear_attn.in_proj_a.weight": "model-00003-of-00004.safetensors",
539
+ "language_model.layers.17.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
540
+ "language_model.layers.17.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
541
+ "language_model.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
542
+ "language_model.layers.17.input_layernorm.weight": "model-00003-of-00004.safetensors",
543
+ "language_model.layers.17.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
544
+ "language_model.layers.18.linear_attn.dt_bias": "model-00003-of-00004.safetensors",
545
+ "language_model.layers.18.linear_attn.A_log": "model-00003-of-00004.safetensors",
546
+ "language_model.layers.18.linear_attn.conv1d.weight": "model-00003-of-00004.safetensors",
547
+ "language_model.layers.18.linear_attn.norm.weight": "model-00003-of-00004.safetensors",
548
+ "language_model.layers.18.linear_attn.out_proj.weight": "model-00003-of-00004.safetensors",
549
+ "language_model.layers.18.linear_attn.in_proj_qkv.weight": "model-00003-of-00004.safetensors",
550
+ "language_model.layers.18.linear_attn.in_proj_z.weight": "model-00003-of-00004.safetensors",
551
+ "language_model.layers.18.linear_attn.in_proj_b.weight": "model-00003-of-00004.safetensors",
552
+ "language_model.layers.18.linear_attn.in_proj_a.weight": "model-00003-of-00004.safetensors",
553
+ "language_model.layers.18.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
554
+ "language_model.layers.18.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
555
+ "language_model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
556
+ "language_model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
557
+ "language_model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
558
+ "language_model.layers.19.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
559
+ "language_model.layers.19.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
560
+ "language_model.layers.19.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
561
+ "language_model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
562
+ "language_model.layers.19.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
563
+ "language_model.layers.19.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
564
+ "language_model.layers.19.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
565
+ "language_model.layers.19.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
566
+ "language_model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
567
+ "language_model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
568
+ "language_model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
569
+ "language_model.layers.20.linear_attn.dt_bias": "model-00003-of-00004.safetensors",
570
+ "language_model.layers.20.linear_attn.A_log": "model-00003-of-00004.safetensors",
571
+ "language_model.layers.20.linear_attn.conv1d.weight": "model-00003-of-00004.safetensors",
572
+ "language_model.layers.20.linear_attn.norm.weight": "model-00003-of-00004.safetensors",
573
+ "language_model.layers.20.linear_attn.out_proj.weight": "model-00003-of-00004.safetensors",
574
+ "language_model.layers.20.linear_attn.in_proj_qkv.weight": "model-00003-of-00004.safetensors",
575
+ "language_model.layers.20.linear_attn.in_proj_z.weight": "model-00003-of-00004.safetensors",
576
+ "language_model.layers.20.linear_attn.in_proj_b.weight": "model-00003-of-00004.safetensors",
577
+ "language_model.layers.20.linear_attn.in_proj_a.weight": "model-00003-of-00004.safetensors",
578
+ "language_model.layers.20.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
579
+ "language_model.layers.20.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
580
+ "language_model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
581
+ "language_model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
582
+ "language_model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
583
+ "language_model.layers.21.linear_attn.dt_bias": "model-00003-of-00004.safetensors",
584
+ "language_model.layers.21.linear_attn.A_log": "model-00003-of-00004.safetensors",
585
+ "language_model.layers.21.linear_attn.conv1d.weight": "model-00003-of-00004.safetensors",
586
+ "language_model.layers.21.linear_attn.norm.weight": "model-00003-of-00004.safetensors",
587
+ "language_model.layers.21.linear_attn.out_proj.weight": "model-00003-of-00004.safetensors",
588
+ "language_model.layers.21.linear_attn.in_proj_qkv.weight": "model-00003-of-00004.safetensors",
589
+ "language_model.layers.21.linear_attn.in_proj_z.weight": "model-00003-of-00004.safetensors",
590
+ "language_model.layers.21.linear_attn.in_proj_b.weight": "model-00003-of-00004.safetensors",
591
+ "language_model.layers.21.linear_attn.in_proj_a.weight": "model-00003-of-00004.safetensors",
592
+ "language_model.layers.21.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
593
+ "language_model.layers.21.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
594
+ "language_model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
595
+ "language_model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
596
+ "language_model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
597
+ "language_model.layers.22.linear_attn.dt_bias": "model-00003-of-00004.safetensors",
598
+ "language_model.layers.22.linear_attn.A_log": "model-00003-of-00004.safetensors",
599
+ "language_model.layers.22.linear_attn.conv1d.weight": "model-00003-of-00004.safetensors",
600
+ "language_model.layers.22.linear_attn.norm.weight": "model-00003-of-00004.safetensors",
601
+ "language_model.layers.22.linear_attn.out_proj.weight": "model-00003-of-00004.safetensors",
602
+ "language_model.layers.22.linear_attn.in_proj_qkv.weight": "model-00003-of-00004.safetensors",
603
+ "language_model.layers.22.linear_attn.in_proj_z.weight": "model-00003-of-00004.safetensors",
604
+ "language_model.layers.22.linear_attn.in_proj_b.weight": "model-00003-of-00004.safetensors",
605
+ "language_model.layers.22.linear_attn.in_proj_a.weight": "model-00003-of-00004.safetensors",
606
+ "language_model.layers.22.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
607
+ "language_model.layers.22.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
608
+ "language_model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
609
+ "language_model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
610
+ "language_model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
611
+ "language_model.layers.23.self_attn.q_proj.weight": "model-00003-of-00004.safetensors",
612
+ "language_model.layers.23.self_attn.k_proj.weight": "model-00003-of-00004.safetensors",
613
+ "language_model.layers.23.self_attn.v_proj.weight": "model-00003-of-00004.safetensors",
614
+ "language_model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
615
+ "language_model.layers.23.self_attn.q_norm.weight": "model-00003-of-00004.safetensors",
616
+ "language_model.layers.23.self_attn.k_norm.weight": "model-00003-of-00004.safetensors",
617
+ "language_model.layers.23.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
618
+ "language_model.layers.23.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
619
+ "language_model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
620
+ "language_model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
621
+ "language_model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
622
+ "language_model.layers.24.linear_attn.dt_bias": "model-00003-of-00004.safetensors",
623
+ "language_model.layers.24.linear_attn.A_log": "model-00003-of-00004.safetensors",
624
+ "language_model.layers.24.linear_attn.conv1d.weight": "model-00003-of-00004.safetensors",
625
+ "language_model.layers.24.linear_attn.norm.weight": "model-00003-of-00004.safetensors",
626
+ "language_model.layers.24.linear_attn.out_proj.weight": "model-00003-of-00004.safetensors",
627
+ "language_model.layers.24.linear_attn.in_proj_qkv.weight": "model-00003-of-00004.safetensors",
628
+ "language_model.layers.24.linear_attn.in_proj_z.weight": "model-00003-of-00004.safetensors",
629
+ "language_model.layers.24.linear_attn.in_proj_b.weight": "model-00003-of-00004.safetensors",
630
+ "language_model.layers.24.linear_attn.in_proj_a.weight": "model-00003-of-00004.safetensors",
631
+ "language_model.layers.24.mlp.gate_proj.weight": "model-00003-of-00004.safetensors",
632
+ "language_model.layers.24.mlp.up_proj.weight": "model-00003-of-00004.safetensors",
633
+ "language_model.layers.24.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
634
+ "language_model.layers.24.input_layernorm.weight": "model-00004-of-00004.safetensors",
635
+ "language_model.layers.24.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
636
+ "language_model.layers.25.linear_attn.dt_bias": "model-00004-of-00004.safetensors",
637
+ "language_model.layers.25.linear_attn.A_log": "model-00004-of-00004.safetensors",
638
+ "language_model.layers.25.linear_attn.conv1d.weight": "model-00004-of-00004.safetensors",
639
+ "language_model.layers.25.linear_attn.norm.weight": "model-00004-of-00004.safetensors",
640
+ "language_model.layers.25.linear_attn.out_proj.weight": "model-00004-of-00004.safetensors",
641
+ "language_model.layers.25.linear_attn.in_proj_qkv.weight": "model-00004-of-00004.safetensors",
642
+ "language_model.layers.25.linear_attn.in_proj_z.weight": "model-00004-of-00004.safetensors",
643
+ "language_model.layers.25.linear_attn.in_proj_b.weight": "model-00004-of-00004.safetensors",
644
+ "language_model.layers.25.linear_attn.in_proj_a.weight": "model-00004-of-00004.safetensors",
645
+ "language_model.layers.25.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
646
+ "language_model.layers.25.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
647
+ "language_model.layers.25.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
648
+ "language_model.layers.25.input_layernorm.weight": "model-00004-of-00004.safetensors",
649
+ "language_model.layers.25.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
650
+ "language_model.layers.26.linear_attn.dt_bias": "model-00004-of-00004.safetensors",
651
+ "language_model.layers.26.linear_attn.A_log": "model-00004-of-00004.safetensors",
652
+ "language_model.layers.26.linear_attn.conv1d.weight": "model-00004-of-00004.safetensors",
653
+ "language_model.layers.26.linear_attn.norm.weight": "model-00004-of-00004.safetensors",
654
+ "language_model.layers.26.linear_attn.out_proj.weight": "model-00004-of-00004.safetensors",
655
+ "language_model.layers.26.linear_attn.in_proj_qkv.weight": "model-00004-of-00004.safetensors",
656
+ "language_model.layers.26.linear_attn.in_proj_z.weight": "model-00004-of-00004.safetensors",
657
+ "language_model.layers.26.linear_attn.in_proj_b.weight": "model-00004-of-00004.safetensors",
658
+ "language_model.layers.26.linear_attn.in_proj_a.weight": "model-00004-of-00004.safetensors",
659
+ "language_model.layers.26.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
660
+ "language_model.layers.26.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
661
+ "language_model.layers.26.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
662
+ "language_model.layers.26.input_layernorm.weight": "model-00004-of-00004.safetensors",
663
+ "language_model.layers.26.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
664
+ "language_model.layers.27.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
665
+ "language_model.layers.27.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
666
+ "language_model.layers.27.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
667
+ "language_model.layers.27.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
668
+ "language_model.layers.27.self_attn.q_norm.weight": "model-00004-of-00004.safetensors",
669
+ "language_model.layers.27.self_attn.k_norm.weight": "model-00004-of-00004.safetensors",
670
+ "language_model.layers.27.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
671
+ "language_model.layers.27.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
672
+ "language_model.layers.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
673
+ "language_model.layers.27.input_layernorm.weight": "model-00004-of-00004.safetensors",
674
+ "language_model.layers.27.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
675
+ "language_model.layers.28.linear_attn.dt_bias": "model-00004-of-00004.safetensors",
676
+ "language_model.layers.28.linear_attn.A_log": "model-00004-of-00004.safetensors",
677
+ "language_model.layers.28.linear_attn.conv1d.weight": "model-00004-of-00004.safetensors",
678
+ "language_model.layers.28.linear_attn.norm.weight": "model-00004-of-00004.safetensors",
679
+ "language_model.layers.28.linear_attn.out_proj.weight": "model-00004-of-00004.safetensors",
680
+ "language_model.layers.28.linear_attn.in_proj_qkv.weight": "model-00004-of-00004.safetensors",
681
+ "language_model.layers.28.linear_attn.in_proj_z.weight": "model-00004-of-00004.safetensors",
682
+ "language_model.layers.28.linear_attn.in_proj_b.weight": "model-00004-of-00004.safetensors",
683
+ "language_model.layers.28.linear_attn.in_proj_a.weight": "model-00004-of-00004.safetensors",
684
+ "language_model.layers.28.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
685
+ "language_model.layers.28.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
686
+ "language_model.layers.28.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
687
+ "language_model.layers.28.input_layernorm.weight": "model-00004-of-00004.safetensors",
688
+ "language_model.layers.28.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
689
+ "language_model.layers.29.linear_attn.dt_bias": "model-00004-of-00004.safetensors",
690
+ "language_model.layers.29.linear_attn.A_log": "model-00004-of-00004.safetensors",
691
+ "language_model.layers.29.linear_attn.conv1d.weight": "model-00004-of-00004.safetensors",
692
+ "language_model.layers.29.linear_attn.norm.weight": "model-00004-of-00004.safetensors",
693
+ "language_model.layers.29.linear_attn.out_proj.weight": "model-00004-of-00004.safetensors",
694
+ "language_model.layers.29.linear_attn.in_proj_qkv.weight": "model-00004-of-00004.safetensors",
695
+ "language_model.layers.29.linear_attn.in_proj_z.weight": "model-00004-of-00004.safetensors",
696
+ "language_model.layers.29.linear_attn.in_proj_b.weight": "model-00004-of-00004.safetensors",
697
+ "language_model.layers.29.linear_attn.in_proj_a.weight": "model-00004-of-00004.safetensors",
698
+ "language_model.layers.29.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
699
+ "language_model.layers.29.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
700
+ "language_model.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
701
+ "language_model.layers.29.input_layernorm.weight": "model-00004-of-00004.safetensors",
702
+ "language_model.layers.29.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
703
+ "language_model.layers.30.linear_attn.dt_bias": "model-00004-of-00004.safetensors",
704
+ "language_model.layers.30.linear_attn.A_log": "model-00004-of-00004.safetensors",
705
+ "language_model.layers.30.linear_attn.conv1d.weight": "model-00004-of-00004.safetensors",
706
+ "language_model.layers.30.linear_attn.norm.weight": "model-00004-of-00004.safetensors",
707
+ "language_model.layers.30.linear_attn.out_proj.weight": "model-00004-of-00004.safetensors",
708
+ "language_model.layers.30.linear_attn.in_proj_qkv.weight": "model-00004-of-00004.safetensors",
709
+ "language_model.layers.30.linear_attn.in_proj_z.weight": "model-00004-of-00004.safetensors",
710
+ "language_model.layers.30.linear_attn.in_proj_b.weight": "model-00004-of-00004.safetensors",
711
+ "language_model.layers.30.linear_attn.in_proj_a.weight": "model-00004-of-00004.safetensors",
712
+ "language_model.layers.30.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
713
+ "language_model.layers.30.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
714
+ "language_model.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
715
+ "language_model.layers.30.input_layernorm.weight": "model-00004-of-00004.safetensors",
716
+ "language_model.layers.30.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
717
+ "language_model.layers.31.self_attn.q_proj.weight": "model-00004-of-00004.safetensors",
718
+ "language_model.layers.31.self_attn.k_proj.weight": "model-00004-of-00004.safetensors",
719
+ "language_model.layers.31.self_attn.v_proj.weight": "model-00004-of-00004.safetensors",
720
+ "language_model.layers.31.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
721
+ "language_model.layers.31.self_attn.q_norm.weight": "model-00004-of-00004.safetensors",
722
+ "language_model.layers.31.self_attn.k_norm.weight": "model-00004-of-00004.safetensors",
723
+ "language_model.layers.31.mlp.gate_proj.weight": "model-00004-of-00004.safetensors",
724
+ "language_model.layers.31.mlp.up_proj.weight": "model-00004-of-00004.safetensors",
725
+ "language_model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
726
+ "language_model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
727
+ "language_model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
728
+ "language_model.norm.weight": "model-00004-of-00004.safetensors",
729
+ "custom_text_proj.weight": "model-00004-of-00004.safetensors",
730
+ "custom_text_proj.bias": "model-00004-of-00004.safetensors",
731
+ "shared_expert.net.0.weight": "model-00004-of-00004.safetensors",
732
+ "shared_expert.net.0.bias": "model-00004-of-00004.safetensors",
733
+ "shared_expert.net.1.weight": "model-00004-of-00004.safetensors",
734
+ "shared_expert.net.1.bias": "model-00004-of-00004.safetensors",
735
+ "shared_expert.net.3.weight": "model-00004-of-00004.safetensors",
736
+ "shared_expert.net.3.bias": "model-00004-of-00004.safetensors",
737
+ "latent_experts.0.net.0.weight": "model-00004-of-00004.safetensors",
738
+ "latent_experts.0.net.0.bias": "model-00004-of-00004.safetensors",
739
+ "latent_experts.0.net.1.weight": "model-00004-of-00004.safetensors",
740
+ "latent_experts.0.net.1.bias": "model-00004-of-00004.safetensors",
741
+ "latent_experts.0.net.3.weight": "model-00004-of-00004.safetensors",
742
+ "latent_experts.0.net.3.bias": "model-00004-of-00004.safetensors",
743
+ "latent_experts.1.net.0.weight": "model-00004-of-00004.safetensors",
744
+ "latent_experts.1.net.0.bias": "model-00004-of-00004.safetensors",
745
+ "latent_experts.1.net.1.weight": "model-00004-of-00004.safetensors",
746
+ "latent_experts.1.net.1.bias": "model-00004-of-00004.safetensors",
747
+ "latent_experts.1.net.3.weight": "model-00004-of-00004.safetensors",
748
+ "latent_experts.1.net.3.bias": "model-00004-of-00004.safetensors",
749
+ "latent_experts.2.net.0.weight": "model-00004-of-00004.safetensors",
750
+ "latent_experts.2.net.0.bias": "model-00004-of-00004.safetensors",
751
+ "latent_experts.2.net.1.weight": "model-00004-of-00004.safetensors",
752
+ "latent_experts.2.net.1.bias": "model-00004-of-00004.safetensors",
753
+ "latent_experts.2.net.3.weight": "model-00004-of-00004.safetensors",
754
+ "latent_experts.2.net.3.bias": "model-00004-of-00004.safetensors",
755
+ "latent_experts.3.net.0.weight": "model-00004-of-00004.safetensors",
756
+ "latent_experts.3.net.0.bias": "model-00004-of-00004.safetensors",
757
+ "latent_experts.3.net.1.weight": "model-00004-of-00004.safetensors",
758
+ "latent_experts.3.net.1.bias": "model-00004-of-00004.safetensors",
759
+ "latent_experts.3.net.3.weight": "model-00004-of-00004.safetensors",
760
+ "latent_experts.3.net.3.bias": "model-00004-of-00004.safetensors",
761
+ "region_router.0.weight": "model-00004-of-00004.safetensors",
762
+ "region_router.0.bias": "model-00004-of-00004.safetensors",
763
+ "region_router.1.weight": "model-00004-of-00004.safetensors",
764
+ "region_router.1.bias": "model-00004-of-00004.safetensors",
765
+ "region_router.3.weight": "model-00004-of-00004.safetensors",
766
+ "region_router.3.bias": "model-00004-of-00004.safetensors",
767
+ "region_coord_proj.weight": "model-00004-of-00004.safetensors",
768
+ "query_context_proj.weight": "model-00004-of-00004.safetensors",
769
+ "gate_scalars.shared": "model-00004-of-00004.safetensors",
770
+ "gate_scalars.specialist": "model-00004-of-00004.safetensors"
771
+ }
772
+ }
modeling_argus.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Argus: Region-Aware Query-Conditioned Mixture of Experts for Visual Document Retrieval.
2
+
3
+ Self-contained model implementation for the Argus-Colqwen3.5-9B release.
4
+
5
+ Usage
6
+ -----
7
+ >>> from transformers import AutoModel, AutoProcessor
8
+ >>> model = AutoModel.from_pretrained(
9
+ ... "DataScience-UIBK/Argus-Colqwen3.5-9B-v0",
10
+ ... trust_remote_code=True,
11
+ ... torch_dtype="bfloat16",
12
+ ... ).eval().cuda()
13
+ >>> proc = AutoProcessor.from_pretrained(
14
+ ... "DataScience-UIBK/Argus-Colqwen3.5-9B-v0",
15
+ ... trust_remote_code=True,
16
+ ... )
17
+ >>> q_emb = model.encode_queries(proc, ["what is the revenue in 2019?"])
18
+ >>> d_emb = model.encode_images(proc, [pil_image_1, pil_image_2])
19
+ >>> scores = model.score(q_emb, d_emb) # shape [num_queries, num_docs]
20
+ """
21
+ from __future__ import annotations
22
+
23
+ from dataclasses import dataclass
24
+ from math import ceil
25
+ from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
26
+
27
+ import torch
28
+ import torch.nn.functional as F
29
+ from torch import nn
30
+ from transformers.utils import ModelOutput
31
+
32
+ try:
33
+ from transformers.models.qwen3_5 import Qwen3_5Config, Qwen3_5Model
34
+ except ImportError:
35
+ try:
36
+ from transformers.models.qwen3_5 import Qwen35Config as Qwen3_5Config
37
+ from transformers.models.qwen3_5 import Qwen35Model as Qwen3_5Model
38
+ except ImportError as exc:
39
+ raise ImportError(
40
+ "Argus requires a transformers build that exposes the Qwen3.5 VL "
41
+ "classes (transformers.models.qwen3_5). Upgrade to transformers "
42
+ ">= 4.57.0.dev0."
43
+ ) from exc
44
+
45
+ from .configuration_argus import ArgusConfig
46
+
47
+
48
+ # --------------------------------------------------------------------------- #
49
+ # Output container
50
+ # --------------------------------------------------------------------------- #
51
+
52
+ @dataclass
53
+ class ArgusOutput(ModelOutput):
54
+ """Output of :meth:`ArgusForRetrieval.forward`.
55
+
56
+ Attributes:
57
+ embeddings: multi-vector token embeddings [B, T, D]. Use ``score`` /
58
+ ``score_multi_vector`` against queries encoded the same way.
59
+ region_embeddings: pooled region-level document embeddings [B, R, D]
60
+ (only populated when images are in the batch).
61
+ region_mask: valid mask for region_embeddings, shape [B, R].
62
+ routing_logits: raw MoE router logits [B, R, E] (per-region, per-expert).
63
+ """
64
+ embeddings: torch.Tensor
65
+ region_embeddings: Optional[torch.Tensor] = None
66
+ region_mask: Optional[torch.Tensor] = None
67
+ routing_logits: Optional[torch.Tensor] = None
68
+
69
+
70
+ # --------------------------------------------------------------------------- #
71
+ # MoE building blocks
72
+ # --------------------------------------------------------------------------- #
73
+
74
+ def _ceil_to_multiple(value: int, multiple: int) -> int:
75
+ return int(ceil(value / multiple) * multiple)
76
+
77
+
78
+ class SharedDenseExpert(nn.Module):
79
+ """Shared expert applied to every spatial location."""
80
+
81
+ def __init__(self, hidden_dim: int, expansion: int = 4):
82
+ super().__init__()
83
+ self.net = nn.Sequential(
84
+ nn.LayerNorm(hidden_dim),
85
+ nn.Linear(hidden_dim, hidden_dim * expansion),
86
+ nn.GELU(),
87
+ nn.Linear(hidden_dim * expansion, hidden_dim),
88
+ )
89
+
90
+ def forward(self, grid: torch.Tensor) -> torch.Tensor:
91
+ return self.net(grid)
92
+
93
+
94
+ class LatentSpatialExpert(nn.Module):
95
+ """One of ``num_specialists`` region-level experts routed by the query."""
96
+
97
+ def __init__(self, hidden_dim: int, expansion: int = 2):
98
+ super().__init__()
99
+ self.net = nn.Sequential(
100
+ nn.LayerNorm(hidden_dim),
101
+ nn.Linear(hidden_dim, hidden_dim * expansion),
102
+ nn.GELU(),
103
+ nn.Linear(hidden_dim * expansion, hidden_dim),
104
+ )
105
+
106
+ def forward(self, grid: torch.Tensor) -> torch.Tensor:
107
+ return self.net(grid)
108
+
109
+
110
+ class GateScalars(nn.Module):
111
+ """Two learnable scalars whose sigmoids weight shared / specialist expert
112
+ contributions onto the final hidden states.
113
+ """
114
+
115
+ def __init__(self, shared_init: float = 0.0, specialist_init: float = 0.0):
116
+ super().__init__()
117
+ self.shared = nn.Parameter(torch.tensor(float(shared_init), dtype=torch.float32))
118
+ self.specialist = nn.Parameter(torch.tensor(float(specialist_init), dtype=torch.float32))
119
+
120
+ def _apply(self, fn): # noqa: D401 - keep fp32 even after .to(dtype)
121
+ super()._apply(fn)
122
+ for name in ("shared", "specialist"):
123
+ param = getattr(self, name)
124
+ if param.dtype != torch.float32:
125
+ param.data = param.data.to(torch.float32)
126
+ return self
127
+
128
+ def sigmoid(self) -> Tuple[torch.Tensor, torch.Tensor]:
129
+ return torch.sigmoid(self.shared), torch.sigmoid(self.specialist)
130
+
131
+
132
+ # --------------------------------------------------------------------------- #
133
+ # Argus model
134
+ # --------------------------------------------------------------------------- #
135
+
136
+ class ArgusForRetrieval(Qwen3_5Model):
137
+ """Argus multi-vector visual document retriever.
138
+
139
+ Structure:
140
+
141
+ - Backbone: Qwen3.5-VL (9B) — produces per-token hidden states.
142
+ - Region pool: non-overlapping ``region_size × region_size`` blocks over
143
+ the vision-token grid; gives a compact region-level view.
144
+ - Router: per-region MLP → ``num_specialists`` logits; the query (if given)
145
+ biases the logits via ``query_context_proj``. Top-k sparse softmax.
146
+ - Experts: one shared expert (applied everywhere) + ``num_specialists``
147
+ latent spatial experts (per-region weighted sum).
148
+ - Fusion: ``final_hidden = final_hidden + σ(gate_shared) · shared_expert
149
+ + σ(gate_specialist) · specialist_sum``.
150
+ - Retrieval head: ``custom_text_proj`` projects fused hidden states to
151
+ ``retrieval_dim`` multi-vectors, L2-normalized.
152
+ - Query side: no MoE; just backbone + ``custom_text_proj``.
153
+
154
+ The user-facing helpers are ``encode_images``, ``encode_queries``, and
155
+ ``score`` (MaxSim). All live on this class so a downstream user can do
156
+ everything via ``model.<method>``.
157
+ """
158
+
159
+ config_class = ArgusConfig
160
+ main_input_name: ClassVar[str] = "input_ids"
161
+
162
+ def __init__(self, config: Union[ArgusConfig, Qwen3_5Config], **kwargs):
163
+ # Accept either an ArgusConfig or a plain Qwen3_5Config with extra attrs
164
+ # (transformers sometimes hands us a base-class instance during Auto*
165
+ # dispatch before config_class kicks in).
166
+ if not isinstance(config, ArgusConfig):
167
+ promoted = ArgusConfig(**config.to_dict())
168
+ config = promoted
169
+
170
+ dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None))
171
+ attn_impl = kwargs.pop("attn_implementation", None)
172
+ use_cache = kwargs.pop("use_cache", None)
173
+
174
+ if hasattr(config, "text_config") and getattr(config.text_config, "rope_scaling", None) is None:
175
+ config.text_config.rope_scaling = {}
176
+ if getattr(config, "rope_scaling", None) is None:
177
+ config.rope_scaling = {}
178
+
179
+ super().__init__(config=config)
180
+
181
+ hidden_size = getattr(config, "hidden_size", None) or getattr(config.text_config, "hidden_size", None)
182
+ if hidden_size is None:
183
+ raise ValueError("Argus: could not determine backbone hidden_size from config.")
184
+
185
+ self.retrieval_dim = int(config.retrieval_dim)
186
+ self.num_specialists = int(config.num_specialists)
187
+ self.top_k_experts = max(1, min(int(config.top_k_experts), self.num_specialists))
188
+ self.region_size = int(config.region_size)
189
+ self.router_layer_index = int(config.router_layer_index)
190
+ self.router_temperature = float(config.router_temperature)
191
+ self.router_noise_std = float(config.router_noise_std)
192
+ self.mask_non_image_embeddings = bool(config.mask_non_image_embeddings)
193
+ self.spatial_merge_size = getattr(config.vision_config, "spatial_merge_size", 1)
194
+ self.padding_side = "left"
195
+
196
+ self.custom_text_proj = nn.Linear(hidden_size, self.retrieval_dim)
197
+ self.shared_expert = SharedDenseExpert(hidden_size)
198
+ self.latent_experts = nn.ModuleList(
199
+ LatentSpatialExpert(hidden_size) for _ in range(self.num_specialists)
200
+ )
201
+ self.region_router = nn.Sequential(
202
+ nn.LayerNorm(hidden_size),
203
+ nn.Linear(hidden_size, hidden_size),
204
+ nn.GELU(),
205
+ nn.Linear(hidden_size, self.num_specialists),
206
+ )
207
+ self.region_coord_proj = nn.Linear(4, hidden_size, bias=False)
208
+ self.query_context_proj = nn.Linear(self.retrieval_dim, hidden_size, bias=False)
209
+ self.gate_scalars = GateScalars(
210
+ shared_init=config.shared_gate_init,
211
+ specialist_init=config.specialist_gate_init,
212
+ )
213
+
214
+ self.post_init()
215
+
216
+ if dtype is not None:
217
+ self.to(dtype=dtype)
218
+ if use_cache is not None:
219
+ self.config.use_cache = use_cache
220
+ if attn_impl is not None and hasattr(self, "set_attn_implementation"):
221
+ self.set_attn_implementation(attn_impl)
222
+
223
+ # ----------------------------------------------------------------- #
224
+ # Forward
225
+ # ----------------------------------------------------------------- #
226
+
227
+ def build_query_router_context(
228
+ self,
229
+ query_embeddings: torch.Tensor,
230
+ attention_mask: Optional[torch.Tensor] = None,
231
+ ) -> torch.Tensor:
232
+ """Pool query multi-vectors into one normalized vector per query.
233
+
234
+ Used to bias the MoE router when the query is known at doc-encode
235
+ time (cross-encoder-style, optional). Safe to call with query-only
236
+ outputs of :meth:`forward`.
237
+ """
238
+ if attention_mask is None:
239
+ pooled = query_embeddings.mean(dim=1)
240
+ else:
241
+ weights = attention_mask.unsqueeze(-1).to(query_embeddings.dtype)
242
+ pooled = (query_embeddings * weights).sum(dim=1) / weights.sum(dim=1).clamp_min(1.0)
243
+ return pooled / pooled.norm(dim=-1, keepdim=True).clamp_min(1e-12)
244
+
245
+ def forward(self, *args, **kwargs) -> ArgusOutput:
246
+ """Run backbone + MoE + retrieval head.
247
+
248
+ Inputs follow the standard Qwen3-VL processor outputs:
249
+ ``input_ids``, ``attention_mask``, and (for images) ``pixel_values``
250
+ + ``image_grid_thw``. ``query_context`` is optional and, when given,
251
+ biases the router for this batch.
252
+ """
253
+ kwargs.pop("region_labels", None)
254
+ kwargs.pop("region_mask", None)
255
+ query_context = kwargs.pop("query_context", None)
256
+ image_grid_thw = kwargs.get("image_grid_thw")
257
+
258
+ # Processor may return per-image padded pixel tensors; the backbone
259
+ # wants them flat-concatenated.
260
+ if "pixel_values" in kwargs and image_grid_thw is not None:
261
+ offsets = image_grid_thw[:, 1] * image_grid_thw[:, 2]
262
+ kwargs["pixel_values"] = torch.cat(
263
+ [pv[:off] for pv, off in zip(kwargs["pixel_values"], offsets)],
264
+ dim=0,
265
+ )
266
+
267
+ kwargs.pop("return_dict", True)
268
+ kwargs.pop("output_hidden_states", None)
269
+ kwargs.pop("use_cache", None)
270
+
271
+ outputs = super().forward(
272
+ *args,
273
+ **kwargs,
274
+ use_cache=False,
275
+ output_hidden_states=True,
276
+ return_dict=True,
277
+ )
278
+
279
+ final_hidden = outputs.last_hidden_state
280
+ router_hidden = outputs.hidden_states[self.router_layer_index]
281
+ del outputs.hidden_states
282
+ attention_mask = kwargs["attention_mask"]
283
+
284
+ region_embeddings_list: List[torch.Tensor] = []
285
+ routing_logits_list: List[torch.Tensor] = []
286
+ routing_mask_list: List[torch.Tensor] = []
287
+
288
+ if "pixel_values" in kwargs and "input_ids" in kwargs:
289
+ image_mask = kwargs["input_ids"] == self.config.image_token_id
290
+ for batch_idx in range(final_hidden.size(0)):
291
+ image_positions = image_mask[batch_idx].nonzero(as_tuple=False).squeeze(-1)
292
+ if image_positions.numel() == 0:
293
+ region_embeddings_list.append(final_hidden.new_zeros(0, self.retrieval_dim))
294
+ routing_logits_list.append(final_hidden.new_zeros(0, self.num_specialists))
295
+ routing_mask_list.append(final_hidden.new_zeros(0, dtype=torch.bool))
296
+ continue
297
+
298
+ grid_t = int(image_grid_thw[batch_idx, 0].item())
299
+ raw_grid_h = int(image_grid_thw[batch_idx, 1].item())
300
+ raw_grid_w = int(image_grid_thw[batch_idx, 2].item())
301
+ grid_h = max(1, raw_grid_h // self.spatial_merge_size)
302
+ grid_w = max(1, raw_grid_w // self.spatial_merge_size)
303
+ num_image_tokens = min(grid_t * grid_h * grid_w, image_positions.numel())
304
+ image_positions = image_positions[:num_image_tokens]
305
+
306
+ early_grid = router_hidden[batch_idx, image_positions].view(grid_t, grid_h, grid_w, -1).mean(dim=0)
307
+ final_grid = final_hidden[batch_idx, image_positions].view(grid_t, grid_h, grid_w, -1).mean(dim=0)
308
+ query_context_i = None if query_context is None else query_context[batch_idx]
309
+
310
+ fused_grid, pooled_regions, pooled_mask, logits = self._apply_query_conditioned_moe(
311
+ early_grid=early_grid,
312
+ final_grid=final_grid,
313
+ query_context=query_context_i,
314
+ )
315
+
316
+ fused_tokens = (
317
+ fused_grid.unsqueeze(0)
318
+ .expand(grid_t, -1, -1, -1)
319
+ .reshape(num_image_tokens, -1)
320
+ .to(final_hidden.dtype)
321
+ )
322
+ final_hidden[batch_idx, image_positions] = fused_tokens
323
+ projected_regions = self.custom_text_proj(pooled_regions)
324
+ projected_regions = projected_regions / projected_regions.norm(dim=-1, keepdim=True).clamp_min(1e-12)
325
+ region_embeddings_list.append(projected_regions)
326
+ routing_logits_list.append(logits)
327
+ routing_mask_list.append(pooled_mask)
328
+
329
+ embeddings = self.custom_text_proj(final_hidden)
330
+ embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True).clamp_min(1e-12)
331
+ embeddings = embeddings * attention_mask.unsqueeze(-1)
332
+
333
+ if "pixel_values" in kwargs and self.mask_non_image_embeddings and "input_ids" in kwargs:
334
+ embeddings = embeddings * (kwargs["input_ids"] == self.config.image_token_id).unsqueeze(-1)
335
+
336
+ region_embeddings, padded_routing_logits, padded_routing_mask = self._pad_regions(
337
+ region_embeddings_list,
338
+ routing_logits_list,
339
+ routing_mask_list,
340
+ device=embeddings.device,
341
+ dtype=embeddings.dtype,
342
+ )
343
+
344
+ return ArgusOutput(
345
+ embeddings=embeddings,
346
+ region_embeddings=region_embeddings,
347
+ region_mask=padded_routing_mask,
348
+ routing_logits=padded_routing_logits,
349
+ )
350
+
351
+ # ----------------------------------------------------------------- #
352
+ # MoE internals
353
+ # ----------------------------------------------------------------- #
354
+
355
+ def _topk_sparse_probs(self, routing_logits: torch.Tensor) -> torch.Tensor:
356
+ logits = routing_logits.float()
357
+ if self.training and self.router_noise_std > 0:
358
+ logits = logits + self.router_noise_std * torch.randn_like(logits)
359
+ if self.top_k_experts >= self.num_specialists:
360
+ return F.softmax(logits / max(self.router_temperature, 1e-6), dim=-1).to(routing_logits.dtype)
361
+
362
+ topk_values, topk_indices = torch.topk(logits, k=self.top_k_experts, dim=-1)
363
+ sparse_logits = torch.full_like(logits, float("-inf"))
364
+ sparse_logits.scatter_(-1, topk_indices, topk_values)
365
+ probs = F.softmax(sparse_logits / max(self.router_temperature, 1e-6), dim=-1)
366
+ return probs.to(routing_logits.dtype)
367
+
368
+ def _apply_query_conditioned_moe(
369
+ self,
370
+ early_grid: torch.Tensor,
371
+ final_grid: torch.Tensor,
372
+ query_context: Optional[torch.Tensor] = None,
373
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
374
+ region_tokens, pooled_mask, coords, region_shape = self._pool_regions(early_grid)
375
+ router_input = region_tokens + self.region_coord_proj(coords.to(region_tokens.dtype))
376
+ if query_context is not None:
377
+ query_bias = self.query_context_proj(query_context.to(region_tokens.dtype)).unsqueeze(0)
378
+ router_input = router_input + query_bias
379
+
380
+ routing_logits = self.region_router(router_input)
381
+ routing_probs = self._topk_sparse_probs(routing_logits)
382
+
383
+ shared_out = self.shared_expert(final_grid)
384
+ specialist_outputs = torch.stack([expert(final_grid) for expert in self.latent_experts], dim=-2)
385
+ patch_probs = self._broadcast_region_probs(routing_probs, region_shape, final_grid.shape[:2])
386
+ specialist_out = (specialist_outputs * patch_probs.unsqueeze(-1)).sum(dim=-2)
387
+ shared_sig, specialist_sig = self.gate_scalars.sigmoid()
388
+ fused_grid = (
389
+ final_grid
390
+ + shared_sig.to(final_grid.dtype) * shared_out
391
+ + specialist_sig.to(final_grid.dtype) * specialist_out
392
+ )
393
+
394
+ pooled_regions, pooled_region_mask, _, _ = self._pool_regions(fused_grid)
395
+ return fused_grid, pooled_regions, pooled_region_mask, routing_logits
396
+
397
+ def _pool_regions(
398
+ self,
399
+ grid: torch.Tensor,
400
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[int, int]]:
401
+ h, w, dim = grid.shape
402
+ rs = self.region_size
403
+ hp = _ceil_to_multiple(h, rs)
404
+ wp = _ceil_to_multiple(w, rs)
405
+
406
+ padded = grid.new_zeros(hp, wp, dim)
407
+ padded[:h, :w] = grid
408
+ valid = grid.new_zeros(hp, wp, 1)
409
+ valid[:h, :w] = 1
410
+
411
+ num_h = hp // rs
412
+ num_w = wp // rs
413
+ blocks = padded.view(num_h, rs, num_w, rs, dim).permute(0, 2, 1, 3, 4).reshape(num_h * num_w, rs * rs, dim)
414
+ valid_blocks = valid.view(num_h, rs, num_w, rs, 1).permute(0, 2, 1, 3, 4).reshape(num_h * num_w, rs * rs, 1)
415
+ counts = valid_blocks.sum(dim=1).clamp_min(1.0)
416
+ pooled = (blocks * valid_blocks).sum(dim=1) / counts
417
+ mask = counts.squeeze(-1) > 0.5
418
+
419
+ coords = []
420
+ for ry in range(num_h):
421
+ for rx in range(num_w):
422
+ y0 = (ry * rs) / max(h, 1)
423
+ x0 = (rx * rs) / max(w, 1)
424
+ y1 = min((ry + 1) * rs, h) / max(h, 1)
425
+ x1 = min((rx + 1) * rs, w) / max(w, 1)
426
+ coords.append([x0, y0, x1, y1])
427
+ coord_tensor = torch.tensor(coords, device=grid.device, dtype=grid.dtype)
428
+ return pooled, mask, coord_tensor, (num_h, num_w)
429
+
430
+ def _broadcast_region_probs(
431
+ self,
432
+ region_probs: torch.Tensor,
433
+ region_shape: Tuple[int, int],
434
+ grid_shape: Tuple[int, int],
435
+ ) -> torch.Tensor:
436
+ num_h, num_w = region_shape
437
+ h, w = grid_shape
438
+ rs = self.region_size
439
+ hp = num_h * rs
440
+ wp = num_w * rs
441
+ probs = region_probs.view(num_h, num_w, self.num_specialists)
442
+ probs = probs[:, :, None, None, :].expand(num_h, num_w, rs, rs, self.num_specialists)
443
+ probs = probs.permute(0, 2, 1, 3, 4).reshape(hp, wp, self.num_specialists)
444
+ return probs[:h, :w]
445
+
446
+ def _pad_regions(
447
+ self,
448
+ region_embeddings_list: List[torch.Tensor],
449
+ routing_logits_list: List[torch.Tensor],
450
+ routing_mask_list: List[torch.Tensor],
451
+ device: torch.device,
452
+ dtype: torch.dtype,
453
+ ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
454
+ if not region_embeddings_list:
455
+ return None, None, None
456
+
457
+ max_regions = max((regions.size(0) for regions in region_embeddings_list), default=0)
458
+ if max_regions == 0:
459
+ batch_size = len(region_embeddings_list)
460
+ return (
461
+ torch.zeros(batch_size, 0, self.retrieval_dim, device=device, dtype=dtype),
462
+ torch.zeros(batch_size, 0, self.num_specialists, device=device, dtype=dtype),
463
+ torch.zeros(batch_size, 0, device=device, dtype=torch.bool),
464
+ )
465
+
466
+ batch_size = len(region_embeddings_list)
467
+ padded_regions = torch.zeros(batch_size, max_regions, self.retrieval_dim, device=device, dtype=dtype)
468
+ padded_logits = torch.zeros(batch_size, max_regions, self.num_specialists, device=device, dtype=dtype)
469
+ padded_mask = torch.zeros(batch_size, max_regions, device=device, dtype=torch.bool)
470
+
471
+ for idx, (regions, logits, mask) in enumerate(zip(region_embeddings_list, routing_logits_list, routing_mask_list)):
472
+ if regions.numel() == 0:
473
+ continue
474
+ count = regions.size(0)
475
+ padded_regions[idx, :count] = regions.to(dtype)
476
+ padded_logits[idx, : logits.size(0)] = logits.to(dtype)
477
+ padded_mask[idx, : mask.numel()] = mask.to(torch.bool)
478
+
479
+ return padded_regions, padded_logits, padded_mask
480
+
481
+ # ----------------------------------------------------------------- #
482
+ # User-facing helpers
483
+ # ----------------------------------------------------------------- #
484
+
485
+ @torch.inference_mode()
486
+ def encode_queries(
487
+ self,
488
+ processor,
489
+ queries: List[str],
490
+ batch_size: int = 8,
491
+ max_length: Optional[int] = None,
492
+ ) -> List[torch.Tensor]:
493
+ """Encode a list of query strings into multi-vector embeddings.
494
+
495
+ Returns one tensor per query, since queries may have different lengths.
496
+ Run this on-GPU for speed; the returned tensors are moved to CPU for
497
+ the caller to manage batching.
498
+ """
499
+ device = next(self.parameters()).device
500
+ out: List[torch.Tensor] = []
501
+ for i in range(0, len(queries), batch_size):
502
+ batch = processor.process_texts(queries[i : i + batch_size], max_length=max_length).to(device)
503
+ emb = self(**batch).embeddings.cpu()
504
+ out.extend(list(torch.unbind(emb)))
505
+ return out
506
+
507
+ @torch.inference_mode()
508
+ def encode_images(self, processor, images, batch_size: int = 2) -> List[torch.Tensor]:
509
+ """Encode a list of PIL images into multi-vector embeddings."""
510
+ device = next(self.parameters()).device
511
+ out: List[torch.Tensor] = []
512
+ for i in range(0, len(images), batch_size):
513
+ batch = processor.process_images(images[i : i + batch_size]).to(device)
514
+ emb = self(**batch).embeddings.cpu()
515
+ out.extend(list(torch.unbind(emb)))
516
+ return out
517
+
518
+ @staticmethod
519
+ def score(
520
+ qs: List[torch.Tensor],
521
+ ps: List[torch.Tensor],
522
+ batch_size: int = 32,
523
+ device: Optional[Union[str, torch.device]] = None,
524
+ ) -> torch.Tensor:
525
+ """MaxSim scoring: for each (q_i, p_j) pair, compute
526
+ ``sum_t max_p <q_i_t, p_j_p>``. Returns a [N_q, N_p] matrix.
527
+
528
+ This reproduces ``processor.score_multi_vector`` but lives on the
529
+ model so users can compute relevance without touching the processor.
530
+ """
531
+ dev = torch.device(device) if device is not None else torch.device("cpu")
532
+ n_q, n_p = len(qs), len(ps)
533
+ scores = torch.zeros(n_q, n_p, device=dev)
534
+
535
+ for qi in range(0, n_q, batch_size):
536
+ q_slice = qs[qi : qi + batch_size]
537
+ q_len = max(x.size(0) for x in q_slice)
538
+ q_pad = torch.zeros(len(q_slice), q_len, q_slice[0].size(-1), device=dev)
539
+ q_mask = torch.zeros(len(q_slice), q_len, device=dev, dtype=torch.bool)
540
+ for i, t in enumerate(q_slice):
541
+ q_pad[i, : t.size(0)] = t.to(dev)
542
+ q_mask[i, : t.size(0)] = t.abs().sum(dim=-1) > 0
543
+
544
+ for pi in range(0, n_p, batch_size):
545
+ p_slice = ps[pi : pi + batch_size]
546
+ p_len = max(x.size(0) for x in p_slice)
547
+ p_pad = torch.zeros(len(p_slice), p_len, p_slice[0].size(-1), device=dev)
548
+ for j, t in enumerate(p_slice):
549
+ p_pad[j, : t.size(0)] = t.to(dev)
550
+
551
+ sim = torch.einsum("qld,pkd->qplk", q_pad, p_pad)
552
+ maxsim = sim.max(dim=-1).values
553
+ maxsim = (maxsim * q_mask.unsqueeze(1).to(maxsim.dtype)).sum(dim=-1)
554
+ scores[qi : qi + len(q_slice), pi : pi + len(p_slice)] = maxsim
555
+
556
+ return scores
557
+
558
+
559
+ __all__ = ["ArgusForRetrieval", "ArgusOutput"]
processing_argus.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Argus: Region-Aware Query-Conditioned Mixture of Experts for Visual Document Retrieval.
2
+
3
+ Self-contained processor for Argus-Colqwen3.5-9B. Wraps the Qwen3-VL processor
4
+ (image processor + Qwen2 tokenizer + optional video processor) and adds ColPali-
5
+ style ``process_images`` / ``process_texts`` / ``score_multi_vector`` helpers.
6
+ """
7
+ from __future__ import annotations
8
+
9
+ from pathlib import Path
10
+ from typing import ClassVar, List, Optional, Tuple, Union
11
+
12
+ import torch
13
+ from PIL import Image
14
+ from transformers import BatchEncoding, BatchFeature
15
+ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
16
+ from transformers.models.qwen3_vl import Qwen3VLProcessor
17
+
18
+
19
+ class ArgusProcessor(Qwen3VLProcessor):
20
+ """Processor for Argus-Colqwen3.5-9B.
21
+
22
+ Subclasses ``Qwen3VLProcessor`` (the Qwen3.5-9B hub repo ships that
23
+ processor class even though the LLM is Qwen3.5). Adds:
24
+
25
+ - ``process_images``: batch-encode PIL images into the exact dict the
26
+ retriever forward expects (``pixel_values``, ``image_grid_thw``,
27
+ ``input_ids``, ``attention_mask``).
28
+ - ``process_texts``: batch-encode query strings.
29
+ - ``score`` / ``score_multi_vector``: MaxSim scoring helper.
30
+ - ``max_num_visual_tokens`` knob: caps the longest-edge pixel budget per
31
+ image so long documents don't blow up the vision encoder.
32
+ """
33
+
34
+ visual_prompt_prefix: ClassVar[str] = (
35
+ "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe the image.<|im_end|><|endoftext|>"
36
+ )
37
+ query_augmentation_token: ClassVar[str] = "<|endoftext|>"
38
+ query_prefix: ClassVar[str] = ""
39
+ image_token: ClassVar[str] = "<|image_pad|>"
40
+ # Number of <|endoftext|> tokens appended to every query — matches the
41
+ # training-time collator (``colpali_novel/data/layout_collator.py``).
42
+ # Removing or changing this number measurably hurts retrieval scores.
43
+ n_query_augmentation_tokens: ClassVar[int] = 10
44
+
45
+ def __init__(
46
+ self,
47
+ image_processor=None,
48
+ tokenizer=None,
49
+ video_processor=None,
50
+ chat_template=None,
51
+ **kwargs,
52
+ ):
53
+ # Explicit signature matters for ``ProcessorMixin``: it inspects
54
+ # __init__.__code__ to decide which modality attributes to set. A
55
+ # *args,**kwargs signature silently drops tokenizer/image_processor.
56
+ super().__init__(
57
+ image_processor=image_processor,
58
+ tokenizer=tokenizer,
59
+ video_processor=video_processor,
60
+ chat_template=chat_template,
61
+ **kwargs,
62
+ )
63
+ if getattr(self, "tokenizer", None) is not None:
64
+ self.tokenizer.padding_side = "left"
65
+
66
+ @classmethod
67
+ def from_pretrained(
68
+ cls,
69
+ pretrained_model_name_or_path,
70
+ *args,
71
+ device_map: Optional[str] = None,
72
+ max_num_visual_tokens: Optional[int] = None,
73
+ **kwargs,
74
+ ):
75
+ """Load the processor from a local folder or HF repo id.
76
+
77
+ The Qwen3.5-9B hub repo declares ``processor_class=Qwen3VLProcessor``
78
+ but ``tokenizer_class=Qwen2Tokenizer``. The stock ``Qwen3VLProcessor
79
+ .from_pretrained`` returns ``tokenizer=None`` in that case and then
80
+ crashes on ``tokenizer.convert_tokens_to_ids(self.image_token)``.
81
+ We load tokenizer + image processor via the Auto* registry
82
+ explicitly so both are real objects before ``__init__`` runs.
83
+ """
84
+ from transformers import AutoImageProcessor, AutoTokenizer
85
+
86
+ tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
87
+ image_processor = AutoImageProcessor.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
88
+
89
+ video_processor = None
90
+ try:
91
+ from transformers import AutoVideoProcessor
92
+ video_processor = AutoVideoProcessor.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
93
+ except Exception: # noqa: BLE001 — video processing is optional
94
+ video_processor = None
95
+
96
+ chat_template = None
97
+ try:
98
+ candidate = Path(str(pretrained_model_name_or_path)) / "chat_template.jinja"
99
+ if candidate.is_file():
100
+ chat_template = candidate.read_text()
101
+ except Exception: # noqa: BLE001
102
+ chat_template = None
103
+
104
+ instance = cls(
105
+ image_processor=image_processor,
106
+ tokenizer=tokenizer,
107
+ video_processor=video_processor,
108
+ chat_template=chat_template,
109
+ )
110
+
111
+ if max_num_visual_tokens is not None:
112
+ patch_size = getattr(instance.image_processor, "patch_size", None)
113
+ merge_size = getattr(instance.image_processor, "merge_size", None)
114
+ if patch_size is None or merge_size is None:
115
+ raise ValueError("Argus image processor missing patch_size or merge_size.")
116
+ tile = patch_size * merge_size
117
+ instance.image_processor.max_pixels = max_num_visual_tokens * tile * tile
118
+ instance.image_processor.size["longest_edge"] = instance.image_processor.max_pixels
119
+ return instance
120
+
121
+ # ------------------------------------------------------------------ #
122
+ # Encoding
123
+ # ------------------------------------------------------------------ #
124
+
125
+ def process_images(self, images: List[Image.Image]) -> Union[BatchFeature, BatchEncoding]:
126
+ """Encode PIL images into the backbone's expected input dict."""
127
+ images = [img.convert("RGB") for img in images]
128
+ batch_doc = self(
129
+ text=[self.visual_prompt_prefix] * len(images),
130
+ images=images,
131
+ padding="longest",
132
+ return_tensors="pt",
133
+ )
134
+ # Pack pixel_values so the forward can scatter them per image via
135
+ # image_grid_thw offsets. This mirrors the training-time collator.
136
+ offsets = batch_doc["image_grid_thw"][:, 1] * batch_doc["image_grid_thw"][:, 2]
137
+ pixel_values = list(torch.split(batch_doc["pixel_values"], offsets.tolist()))
138
+ batch_doc["pixel_values"] = torch.nn.utils.rnn.pad_sequence(pixel_values, batch_first=True)
139
+ return batch_doc
140
+
141
+ def process_texts(
142
+ self,
143
+ texts: List[str],
144
+ max_length: Optional[int] = None,
145
+ ) -> Union[BatchFeature, BatchEncoding]:
146
+ """Encode query strings into the backbone's expected input dict."""
147
+ kwargs = {"text": texts, "return_tensors": "pt", "padding": "longest"}
148
+ if max_length is not None:
149
+ kwargs["max_length"] = max_length
150
+ kwargs["truncation"] = True
151
+ return self(**kwargs)
152
+
153
+ def process_queries(
154
+ self,
155
+ queries: Optional[List[str]] = None,
156
+ texts: Optional[List[str]] = None,
157
+ max_length: Optional[int] = None,
158
+ suffix: Optional[str] = None,
159
+ ) -> Union[BatchFeature, BatchEncoding]:
160
+ """Encode queries with the training-time augmentation:
161
+ ``query_prefix + query + query_augmentation_token * n_query_augmentation_tokens``.
162
+
163
+ Mirrors ``colpali_engine.utils.processing_utils.BaseVisualRetrieverProcessor
164
+ .process_queries`` and the Argus training collator. The default 10 trailing
165
+ ``<|endoftext|>`` tokens are not optional — without them, MaxSim scoring
166
+ drops several nDCG points because the query has fewer active multi-vectors.
167
+ """
168
+ if texts is not None and queries is not None:
169
+ raise ValueError("Only one of 'texts' or 'queries' should be provided.")
170
+ if queries is None:
171
+ queries = texts
172
+ if queries is None:
173
+ raise ValueError("No queries provided.")
174
+
175
+ if suffix is None:
176
+ suffix = self.query_augmentation_token * self.n_query_augmentation_tokens
177
+
178
+ wrapped = [self.query_prefix + q + suffix for q in queries]
179
+ return self.process_texts(wrapped, max_length=max_length)
180
+
181
+ # ------------------------------------------------------------------ #
182
+ # Scoring
183
+ # ------------------------------------------------------------------ #
184
+
185
+ def score(
186
+ self,
187
+ qs: List[torch.Tensor],
188
+ ps: List[torch.Tensor],
189
+ device: Optional[Union[str, torch.device]] = None,
190
+ **kwargs,
191
+ ) -> torch.Tensor:
192
+ """Alias for ``score_multi_vector`` (MaxSim over multi-vectors)."""
193
+ return self.score_multi_vector(qs, ps, device=device, **kwargs)
194
+
195
+ def score_multi_vector(
196
+ self,
197
+ qs: List[torch.Tensor],
198
+ ps: List[torch.Tensor],
199
+ batch_size: int = 128,
200
+ device: Optional[Union[str, torch.device]] = None,
201
+ ) -> torch.Tensor:
202
+ """Compute an [N_q, N_p] score matrix via MaxSim (ColBERT scoring).
203
+
204
+ For each (q, p) pair: ``sum_t max_p <q_t, p_p>``. Inputs are the raw
205
+ (potentially ragged) per-sample multi-vector tensors returned by
206
+ :meth:`encode_queries` / :meth:`encode_images`.
207
+ """
208
+ dev = torch.device(device) if device is not None else torch.device("cpu")
209
+ n_q, n_p = len(qs), len(ps)
210
+ scores = torch.zeros(n_q, n_p, device=dev)
211
+
212
+ for qi in range(0, n_q, batch_size):
213
+ q_slice = qs[qi : qi + batch_size]
214
+ q_len = max(x.size(0) for x in q_slice)
215
+ q_pad = torch.zeros(len(q_slice), q_len, q_slice[0].size(-1), device=dev)
216
+ q_mask = torch.zeros(len(q_slice), q_len, device=dev, dtype=torch.bool)
217
+ for i, t in enumerate(q_slice):
218
+ q_pad[i, : t.size(0)] = t.to(dev)
219
+ q_mask[i, : t.size(0)] = t.abs().sum(dim=-1) > 0
220
+
221
+ for pi in range(0, n_p, batch_size):
222
+ p_slice = ps[pi : pi + batch_size]
223
+ p_len = max(x.size(0) for x in p_slice)
224
+ p_pad = torch.zeros(len(p_slice), p_len, p_slice[0].size(-1), device=dev)
225
+ for j, t in enumerate(p_slice):
226
+ p_pad[j, : t.size(0)] = t.to(dev)
227
+
228
+ sim = torch.einsum("qld,pkd->qplk", q_pad, p_pad)
229
+ maxsim = sim.max(dim=-1).values
230
+ maxsim = (maxsim * q_mask.unsqueeze(1).to(maxsim.dtype)).sum(dim=-1)
231
+ scores[qi : qi + len(q_slice), pi : pi + len(p_slice)] = maxsim
232
+
233
+ return scores
234
+
235
+ # ------------------------------------------------------------------ #
236
+ # Misc helpers (match colpali-engine BaseVisualRetrieverProcessor API)
237
+ # ------------------------------------------------------------------ #
238
+
239
+ def get_n_patches(
240
+ self,
241
+ image_size: Tuple[int, int],
242
+ spatial_merge_size: int,
243
+ ) -> Tuple[int, int]:
244
+ patch_size = self.image_processor.patch_size
245
+ height_new, width_new = smart_resize(
246
+ width=image_size[0],
247
+ height=image_size[1],
248
+ factor=patch_size * self.image_processor.merge_size,
249
+ min_pixels=self.image_processor.size["shortest_edge"],
250
+ max_pixels=self.image_processor.size["longest_edge"],
251
+ )
252
+ n_patches_x = width_new // patch_size // spatial_merge_size
253
+ n_patches_y = height_new // patch_size // spatial_merge_size
254
+ return n_patches_x, n_patches_y
255
+
256
+ def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor:
257
+ return batch_images.input_ids == self.image_token_id
258
+
259
+
260
+ __all__ = ["ArgusProcessor"]
processor_config.json ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "image_processor": {
3
+ "do_convert_rgb": true,
4
+ "do_normalize": true,
5
+ "do_rescale": true,
6
+ "do_resize": true,
7
+ "image_mean": [
8
+ 0.5,
9
+ 0.5,
10
+ 0.5
11
+ ],
12
+ "image_processor_type": "Qwen2VLImageProcessor",
13
+ "image_std": [
14
+ 0.5,
15
+ 0.5,
16
+ 0.5
17
+ ],
18
+ "max_pixels": 2097152,
19
+ "merge_size": 2,
20
+ "patch_size": 16,
21
+ "resample": 3,
22
+ "rescale_factor": 0.00392156862745098,
23
+ "size": {
24
+ "longest_edge": 2097152,
25
+ "shortest_edge": 65536
26
+ },
27
+ "temporal_patch_size": 2
28
+ },
29
+ "processor_class": "ArgusProcessor",
30
+ "video_processor": {
31
+ "do_convert_rgb": true,
32
+ "do_normalize": true,
33
+ "do_rescale": true,
34
+ "do_resize": true,
35
+ "do_sample_frames": true,
36
+ "fps": 2,
37
+ "image_mean": [
38
+ 0.5,
39
+ 0.5,
40
+ 0.5
41
+ ],
42
+ "image_std": [
43
+ 0.5,
44
+ 0.5,
45
+ 0.5
46
+ ],
47
+ "max_frames": 768,
48
+ "merge_size": 2,
49
+ "min_frames": 4,
50
+ "patch_size": 16,
51
+ "resample": 3,
52
+ "rescale_factor": 0.00392156862745098,
53
+ "return_metadata": false,
54
+ "size": {
55
+ "longest_edge": 25165824,
56
+ "shortest_edge": 4096
57
+ },
58
+ "temporal_patch_size": 2,
59
+ "video_processor_type": "Qwen3VLVideoProcessor"
60
+ },
61
+ "auto_map": {
62
+ "AutoProcessor": "processing_argus.ArgusProcessor"
63
+ }
64
+ }
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06b9509352d2af50381ab2247e083b80d32d5c0aba91c272ca9ff729b6a0e523
3
+ size 19989325
tokenizer_config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "audio_bos_token": "<|audio_start|>",
4
+ "audio_eos_token": "<|audio_end|>",
5
+ "audio_token": "<|audio_pad|>",
6
+ "backend": "tokenizers",
7
+ "bos_token": null,
8
+ "clean_up_tokenization_spaces": false,
9
+ "eos_token": "<|im_end|>",
10
+ "errors": "replace",
11
+ "image_token": "<|image_pad|>",
12
+ "is_local": true,
13
+ "local_files_only": true,
14
+ "model_max_length": 262144,
15
+ "model_specific_special_tokens": {
16
+ "audio_bos_token": "<|audio_start|>",
17
+ "audio_eos_token": "<|audio_end|>",
18
+ "audio_token": "<|audio_pad|>",
19
+ "image_token": "<|image_pad|>",
20
+ "video_token": "<|video_pad|>",
21
+ "vision_bos_token": "<|vision_start|>",
22
+ "vision_eos_token": "<|vision_end|>"
23
+ },
24
+ "pad_token": "<|endoftext|>",
25
+ "pretokenize_regex": "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
26
+ "processor_class": "ArgusProcessor",
27
+ "split_special_tokens": false,
28
+ "tokenizer_class": "Qwen2Tokenizer",
29
+ "unk_token": null,
30
+ "video_token": "<|video_pad|>",
31
+ "vision_bos_token": "<|vision_start|>",
32
+ "vision_eos_token": "<|vision_end|>"
33
+ }