# Argus-Colqwen3.5-4b-v0 / modeling_argus.py
# Commit fedffec ("Initial fp32 release") by abdoelsayed (verified).
"""Argus: Region-Aware Query-Conditioned Mixture of Experts for Visual Document Retrieval.
Self-contained model implementation for the Argus-Colqwen3.5-9B release.
Usage
-----
>>> from transformers import AutoModel, AutoProcessor
>>> model = AutoModel.from_pretrained(
... "DataScience-UIBK/Argus-Colqwen3.5-9B-v0",
... trust_remote_code=True,
... torch_dtype="bfloat16",
... ).eval().cuda()
>>> proc = AutoProcessor.from_pretrained(
... "DataScience-UIBK/Argus-Colqwen3.5-9B-v0",
... trust_remote_code=True,
... )
>>> q_emb = model.encode_queries(proc, ["what is the revenue in 2019?"])
>>> d_emb = model.encode_images(proc, [pil_image_1, pil_image_2])
>>> scores = model.score(q_emb, d_emb) # shape [num_queries, num_docs]
"""
from __future__ import annotations
from dataclasses import dataclass
from math import ceil
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers.utils import ModelOutput
try:
from transformers.models.qwen3_5 import Qwen3_5Config, Qwen3_5Model
except ImportError:
try:
from transformers.models.qwen3_5 import Qwen35Config as Qwen3_5Config
from transformers.models.qwen3_5 import Qwen35Model as Qwen3_5Model
except ImportError as exc:
raise ImportError(
"Argus requires a transformers build that exposes the Qwen3.5 VL "
"classes (transformers.models.qwen3_5). Upgrade to transformers "
">= 4.57.0.dev0."
) from exc
from .configuration_argus import ArgusConfig
# --------------------------------------------------------------------------- #
# Output container
# --------------------------------------------------------------------------- #
@dataclass
class ArgusOutput(ModelOutput):
    """Output of :meth:`ArgusForRetrieval.forward`.

    Attributes:
        embeddings: multi-vector token embeddings [B, T, D]. Each token vector
            is L2-normalized and zeroed where ``attention_mask`` is 0 (and, when
            ``mask_non_image_embeddings`` is enabled and the batch has images,
            at every non-image token). Use ``score`` / ``score_multi_vector``
            against queries encoded the same way.
        region_embeddings: pooled, L2-normalized region-level document
            embeddings [B, R, D]. Only populated when the image branch of
            ``forward`` ran; ``None`` for text-only batches. Rows past a
            sample's own region count are zero padding.
        region_mask: valid mask for region_embeddings, shape [B, R] (bool).
        routing_logits: raw MoE router logits [B, R, E] (per-region,
            per-expert), taken before top-k sparsification / softmax.
    """

    # NOTE(review): presumably the declaration order matters for ModelOutput's
    # tuple-style access (``to_tuple()`` / integer indexing); keep it stable.
    embeddings: torch.Tensor
    # The remaining fields default to None so text-only (query) batches can omit them.
    region_embeddings: Optional[torch.Tensor] = None
    region_mask: Optional[torch.Tensor] = None
    routing_logits: Optional[torch.Tensor] = None
# --------------------------------------------------------------------------- #
# MoE building blocks
# --------------------------------------------------------------------------- #
def _ceil_to_multiple(value: int, multiple: int) -> int:
    """Round ``value`` up to the nearest multiple of ``multiple``.

    Uses floor division on negated operands (``-(-a // b)`` is ceiling
    division), so integer inputs are handled exactly at any magnitude. The
    previous ``ceil(value / multiple)`` went through a float and lost precision
    for large ints. ``multiple == 0`` raises ``ZeroDivisionError``.
    """
    return int(-(-value // multiple) * multiple)
class SharedDenseExpert(nn.Module):
    """Shared expert applied to every spatial location.

    A LayerNorm followed by a two-layer GELU MLP whose hidden width is
    ``hidden_dim * expansion``. Only the last dimension of the input is
    transformed, so any leading shape is accepted.
    """

    def __init__(self, hidden_dim: int, expansion: int = 4):
        super().__init__()
        inner_dim = hidden_dim * expansion
        layers = [
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, inner_dim),
            nn.GELU(),
            nn.Linear(inner_dim, hidden_dim),
        ]
        # Attribute name ``net`` and the layer order define the checkpoint keys
        # (``net.0.*``, ``net.1.*``, ``net.3.*``); do not rename or reorder.
        self.net = nn.Sequential(*layers)

    def forward(self, grid: torch.Tensor) -> torch.Tensor:
        """Return the expert MLP applied to the last dimension of ``grid``."""
        transformed = self.net(grid)
        return transformed
class LatentSpatialExpert(nn.Module):
    """One of ``num_specialists`` region-level experts routed by the query.

    Same LayerNorm + GELU-MLP layout as the shared expert, but with a default
    hidden width of ``2 * hidden_dim``.
    """

    def __init__(self, hidden_dim: int, expansion: int = 2):
        super().__init__()
        width = expansion * hidden_dim
        norm = nn.LayerNorm(hidden_dim)
        up = nn.Linear(hidden_dim, width)
        act = nn.GELU()
        down = nn.Linear(width, hidden_dim)
        # Positional Sequential keeps the ``net.<idx>.*`` parameter names.
        self.net = nn.Sequential(norm, up, act, down)

    def forward(self, grid: torch.Tensor) -> torch.Tensor:
        """Apply this expert to every vector along the last axis of ``grid``."""
        return self.net(grid)
class GateScalars(nn.Module):
    """Two learnable scalars whose sigmoids weight shared / specialist expert
    contributions onto the final hidden states.

    Both parameters are kept in float32 even after the owning model is cast
    to a lower-precision dtype (e.g. ``.to(torch.bfloat16)``).
    """

    def __init__(self, shared_init: float = 0.0, specialist_init: float = 0.0):
        super().__init__()
        self.shared = nn.Parameter(torch.tensor(float(shared_init), dtype=torch.float32))
        self.specialist = nn.Parameter(torch.tensor(float(specialist_init), dtype=torch.float32))

    def _apply(self, fn, *args, **kwargs):
        """Run :meth:`nn.Module._apply`, then restore float32 on both scalars.

        Extra arguments (``recurse`` in recent PyTorch, passed for example by
        ``Module.to_empty``) are forwarded unchanged. This keeps the override
        compatible with both the old ``_apply(fn)`` and the newer
        ``_apply(fn, recurse=True)`` signatures.
        """
        super()._apply(fn, *args, **kwargs)
        for name in ("shared", "specialist"):
            param = getattr(self, name)
            if param.dtype != torch.float32:
                # Keep whatever device ``fn`` moved us to; only undo the dtype cast.
                param.data = param.data.to(torch.float32)
        return self

    def sigmoid(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(sigmoid(shared), sigmoid(specialist))`` as float32 scalars."""
        return torch.sigmoid(self.shared), torch.sigmoid(self.specialist)
# --------------------------------------------------------------------------- #
# Argus model
# --------------------------------------------------------------------------- #
class ArgusForRetrieval(Qwen3_5Model):
    """Argus multi-vector visual document retriever.

    Structure:

    - Backbone: Qwen3.5-VL — produces per-token hidden states (model size is
      whatever the loaded config describes).
    - Region pool: non-overlapping ``region_size × region_size`` blocks over
      the vision-token grid; gives a compact region-level view.
    - Router: per-region MLP → ``num_specialists`` logits, computed from the
      hidden states of layer ``router_layer_index`` plus a projection of each
      region's normalized box coordinates. The query (if given) biases the
      logits via ``query_context_proj``. Top-k sparse softmax.
    - Experts: one shared expert (applied everywhere) + ``num_specialists``
      latent spatial experts (per-region weighted sum).
    - Fusion: ``final_hidden = final_hidden + σ(gate_shared) · shared_expert
      + σ(gate_specialist) · specialist_sum``.
    - Retrieval head: ``custom_text_proj`` projects fused hidden states to
      ``retrieval_dim`` multi-vectors, L2-normalized.
    - Query side: no MoE; just backbone + ``custom_text_proj``.

    The user-facing helpers are ``encode_images``, ``encode_queries``, and
    ``score`` (MaxSim). All live on this class so a downstream user can do
    everything via ``model.<method>``.
    """

    config_class = ArgusConfig
    main_input_name: ClassVar[str] = "input_ids"

    def __init__(self, config: Union[ArgusConfig, Qwen3_5Config], **kwargs):
        """Build the backbone and the Argus MoE / retrieval-head modules.

        Args:
            config: an :class:`ArgusConfig`, or a base Qwen3.5 config carrying
                the Argus attributes (it is promoted to ``ArgusConfig``).
            **kwargs: ``dtype`` / ``torch_dtype`` (a ``torch.dtype``, a name such
                as ``"bfloat16"`` / ``"torch.bfloat16"``, or ``"auto"``),
                ``attn_implementation`` and ``use_cache`` are honoured; any
                other keyword is ignored.

        Raises:
            ValueError: if the backbone hidden size cannot be found in the
                config, or the requested dtype is not recognised.
        """
        # Accept either an ArgusConfig or a plain Qwen3_5Config with extra attrs
        # (transformers sometimes hands us a base-class instance during Auto*
        # dispatch before config_class kicks in).
        if not isinstance(config, ArgusConfig):
            config = ArgusConfig(**config.to_dict())
        # Pop both spellings so neither lingers in kwargs; ``dtype`` wins.
        legacy_dtype = kwargs.pop("torch_dtype", None)
        requested_dtype = kwargs.pop("dtype", None)
        dtype = self._resolve_dtype(requested_dtype if requested_dtype is not None else legacy_dtype)
        attn_impl = kwargs.pop("attn_implementation", None)
        use_cache = kwargs.pop("use_cache", None)
        # Replace a missing rope_scaling with an empty dict before the backbone
        # is built (presumably the backbone cannot handle ``None`` here).
        text_config = getattr(config, "text_config", None)
        if text_config is not None and getattr(text_config, "rope_scaling", None) is None:
            text_config.rope_scaling = {}
        if getattr(config, "rope_scaling", None) is None:
            config.rope_scaling = {}
        super().__init__(config=config)
        hidden_size = getattr(config, "hidden_size", None) or getattr(
            getattr(config, "text_config", None), "hidden_size", None
        )
        if hidden_size is None:
            raise ValueError("Argus: could not determine backbone hidden_size from config.")
        self.retrieval_dim = int(config.retrieval_dim)
        self.num_specialists = int(config.num_specialists)
        self.top_k_experts = max(1, min(int(config.top_k_experts), self.num_specialists))
        self.region_size = int(config.region_size)
        self.router_layer_index = int(config.router_layer_index)
        self.router_temperature = float(config.router_temperature)
        self.router_noise_std = float(config.router_noise_std)
        self.mask_non_image_embeddings = bool(config.mask_non_image_embeddings)
        self.spatial_merge_size = getattr(getattr(config, "vision_config", None), "spatial_merge_size", 1)
        self.padding_side = "left"
        # Module attribute names below are checkpoint keys; do not rename.
        self.custom_text_proj = nn.Linear(hidden_size, self.retrieval_dim)
        self.shared_expert = SharedDenseExpert(hidden_size)
        self.latent_experts = nn.ModuleList(
            LatentSpatialExpert(hidden_size) for _ in range(self.num_specialists)
        )
        self.region_router = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, self.num_specialists),
        )
        self.region_coord_proj = nn.Linear(4, hidden_size, bias=False)
        self.query_context_proj = nn.Linear(self.retrieval_dim, hidden_size, bias=False)
        self.gate_scalars = GateScalars(
            shared_init=config.shared_gate_init,
            specialist_init=config.specialist_gate_init,
        )
        self.post_init()
        if dtype is not None:
            self.to(dtype=dtype)
        if use_cache is not None:
            self.config.use_cache = use_cache
        if attn_impl is not None and hasattr(self, "set_attn_implementation"):
            self.set_attn_implementation(attn_impl)

    @staticmethod
    def _resolve_dtype(dtype: Optional[Union[str, torch.dtype]]) -> Optional[torch.dtype]:
        """Map a ``dtype`` / ``torch_dtype`` keyword to a ``torch.dtype``.

        ``None`` and ``"auto"`` mean "leave parameters as initialised". Strings
        such as ``"bfloat16"`` or ``"torch.bfloat16"`` are looked up on
        ``torch``. Anything else raises ``ValueError``.
        """
        if dtype is None or isinstance(dtype, torch.dtype):
            return dtype
        if isinstance(dtype, str):
            if dtype == "auto":
                return None
            resolved = getattr(torch, dtype.rsplit(".", 1)[-1], None)
            if isinstance(resolved, torch.dtype):
                return resolved
        raise ValueError(f"Argus: unsupported dtype {dtype!r}.")

    # ----------------------------------------------------------------- #
    # Forward
    # ----------------------------------------------------------------- #
    def build_query_router_context(
        self,
        query_embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool query multi-vectors into one normalized vector per query.

        Used to bias the MoE router when the query is known at doc-encode
        time (cross-encoder-style, optional). Safe to call with query-only
        outputs of :meth:`forward`.

        Args:
            query_embeddings: [B, T, D] token embeddings.
            attention_mask: optional [B, T] mask; masked tokens are excluded
                from the mean.

        Returns:
            [B, D] L2-normalized pooled vectors.
        """
        if attention_mask is None:
            pooled = query_embeddings.mean(dim=1)
        else:
            weights = attention_mask.unsqueeze(-1).to(query_embeddings.dtype)
            pooled = (query_embeddings * weights).sum(dim=1) / weights.sum(dim=1).clamp_min(1.0)
        return pooled / pooled.norm(dim=-1, keepdim=True).clamp_min(1e-12)

    @staticmethod
    def _flatten_pixel_values(pixel_values, image_grid_thw: torch.Tensor):
        """Undo per-image padding of ``pixel_values`` for the backbone.

        The processor may return a padded ``[num_images, max_patches, dim]``
        tensor, while the backbone wants all patches concatenated. Each image
        contributes ``t * h * w`` patches according to ``image_grid_thw``. An
        already-flat 2-D tensor is returned unchanged.
        """
        if torch.is_tensor(pixel_values) and pixel_values.dim() == 2:
            return pixel_values
        patches_per_image = image_grid_thw.prod(dim=-1).tolist()
        return torch.cat(
            [pv[: int(count)] for pv, count in zip(pixel_values, patches_per_image)],
            dim=0,
        )

    def forward(self, *args, **kwargs) -> ArgusOutput:
        """Run backbone + MoE + retrieval head.

        Inputs follow the standard Qwen3-VL processor outputs:
        ``input_ids``, ``attention_mask``, and (for images) ``pixel_values``
        + ``image_grid_thw``. ``query_context`` is optional and, when given
        as a [B, retrieval_dim] tensor, biases the router for this batch.
        ``region_labels`` / ``region_mask`` are accepted and ignored.
        ``input_ids`` is expected as a keyword argument: the image branch is
        skipped without it.

        Image rows are assumed to hold at most one image each. The n-th row
        that contains image tokens uses the n-th row of ``image_grid_thw``.
        """
        kwargs.pop("region_labels", None)
        kwargs.pop("region_mask", None)
        query_context = kwargs.pop("query_context", None)
        image_grid_thw = kwargs.get("image_grid_thw")
        # A ``pixel_values=None`` entry must not be mistaken for an image batch.
        has_images = kwargs.get("pixel_values") is not None
        if has_images and image_grid_thw is not None:
            kwargs["pixel_values"] = self._flatten_pixel_values(kwargs["pixel_values"], image_grid_thw)
        # These are forced below; drop caller-supplied values to avoid duplicates.
        for key in ("return_dict", "output_hidden_states", "use_cache"):
            kwargs.pop(key, None)
        outputs = super().forward(
            *args,
            **kwargs,
            use_cache=False,
            output_hidden_states=True,
            return_dict=True,
        )
        final_hidden = outputs.last_hidden_state
        router_hidden = outputs.hidden_states[self.router_layer_index]
        # Drop the whole ModelOutput: it stores fields in its underlying dict,
        # so ``del outputs.hidden_states`` would not release the layer tuple.
        del outputs

        input_ids = kwargs.get("input_ids")
        attention_mask = kwargs.get("attention_mask")
        region_embeddings_list: List[torch.Tensor] = []
        routing_logits_list: List[torch.Tensor] = []
        routing_mask_list: List[torch.Tensor] = []
        image_token_mask: Optional[torch.Tensor] = None
        if has_images and input_ids is not None:
            image_token_mask = input_ids == self.config.image_token_id
            # Index into image_grid_thw by image, not by batch row, so rows
            # without image tokens do not shift later rows onto the wrong grid.
            image_idx = 0
            for batch_idx in range(final_hidden.size(0)):
                image_positions = image_token_mask[batch_idx].nonzero(as_tuple=False).squeeze(-1)
                if image_positions.numel() == 0:
                    region_embeddings_list.append(final_hidden.new_zeros(0, self.retrieval_dim))
                    routing_logits_list.append(final_hidden.new_zeros(0, self.num_specialists))
                    routing_mask_list.append(final_hidden.new_zeros(0, dtype=torch.bool))
                    continue
                grid_t, raw_grid_h, raw_grid_w = (int(v) for v in image_grid_thw[image_idx].tolist())
                image_idx += 1
                # Vision tokens are merged spatially before reaching the LM.
                grid_h = max(1, raw_grid_h // self.spatial_merge_size)
                grid_w = max(1, raw_grid_w // self.spatial_merge_size)
                num_image_tokens = min(grid_t * grid_h * grid_w, image_positions.numel())
                image_positions = image_positions[:num_image_tokens]
                # Average over the temporal axis to get one [h, w, D] grid.
                early_grid = router_hidden[batch_idx, image_positions].view(grid_t, grid_h, grid_w, -1).mean(dim=0)
                final_grid = final_hidden[batch_idx, image_positions].view(grid_t, grid_h, grid_w, -1).mean(dim=0)
                query_context_i = None if query_context is None else query_context[batch_idx]
                fused_grid, pooled_regions, pooled_mask, logits = self._apply_query_conditioned_moe(
                    early_grid=early_grid,
                    final_grid=final_grid,
                    query_context=query_context_i,
                )
                # Write the fused grid back to every temporal slice's tokens.
                fused_tokens = (
                    fused_grid.unsqueeze(0)
                    .expand(grid_t, -1, -1, -1)
                    .reshape(num_image_tokens, -1)
                    .to(final_hidden.dtype)
                )
                final_hidden[batch_idx, image_positions] = fused_tokens
                projected_regions = self.custom_text_proj(pooled_regions)
                projected_regions = projected_regions / projected_regions.norm(dim=-1, keepdim=True).clamp_min(1e-12)
                region_embeddings_list.append(projected_regions)
                routing_logits_list.append(logits)
                routing_mask_list.append(pooled_mask)

        embeddings = self.custom_text_proj(final_hidden)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        if attention_mask is not None:
            embeddings = embeddings * attention_mask.unsqueeze(-1)
        if image_token_mask is not None and self.mask_non_image_embeddings:
            embeddings = embeddings * image_token_mask.unsqueeze(-1)
        region_embeddings, padded_routing_logits, padded_routing_mask = self._pad_regions(
            region_embeddings_list,
            routing_logits_list,
            routing_mask_list,
            device=embeddings.device,
            dtype=embeddings.dtype,
        )
        return ArgusOutput(
            embeddings=embeddings,
            region_embeddings=region_embeddings,
            region_mask=padded_routing_mask,
            routing_logits=padded_routing_logits,
        )

    # ----------------------------------------------------------------- #
    # MoE internals
    # ----------------------------------------------------------------- #
    def _topk_sparse_probs(self, routing_logits: torch.Tensor) -> torch.Tensor:
        """Turn router logits into top-k sparse, temperature-scaled softmax probs.

        Computation is done in float32. Gaussian noise is added only in
        training mode. Non-top-k experts get probability 0. The result is cast
        back to the input dtype.
        """
        logits = routing_logits.float()
        if self.training and self.router_noise_std > 0:
            logits = logits + self.router_noise_std * torch.randn_like(logits)
        temperature = max(self.router_temperature, 1e-6)
        if self.top_k_experts < self.num_specialists:
            topk_values, topk_indices = torch.topk(logits, k=self.top_k_experts, dim=-1)
            logits = torch.full_like(logits, float("-inf")).scatter_(-1, topk_indices, topk_values)
        return F.softmax(logits / temperature, dim=-1).to(routing_logits.dtype)

    def _apply_query_conditioned_moe(
        self,
        early_grid: torch.Tensor,
        final_grid: torch.Tensor,
        query_context: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route regions of ``early_grid`` and fuse expert outputs into ``final_grid``.

        Args:
            early_grid: [h, w, D] hidden states from the router layer.
            final_grid: [h, w, D] last-layer hidden states to be fused.
            query_context: optional [retrieval_dim] query vector biasing the router.

        Returns:
            ``(fused_grid [h, w, D], pooled_regions [R, D], region_mask [R],
            routing_logits [R, E])`` where ``R`` is the number of regions.
        """
        region_tokens, pooled_mask, coords, region_shape = self._pool_regions(early_grid)
        router_input = region_tokens + self.region_coord_proj(coords.to(region_tokens.dtype))
        if query_context is not None:
            query_bias = self.query_context_proj(query_context.to(region_tokens.dtype)).unsqueeze(0)
            router_input = router_input + query_bias
        routing_logits = self.region_router(router_input)
        routing_probs = self._topk_sparse_probs(routing_logits)
        shared_out = self.shared_expert(final_grid)
        specialist_outputs = torch.stack([expert(final_grid) for expert in self.latent_experts], dim=-2)
        patch_probs = self._broadcast_region_probs(routing_probs, region_shape, final_grid.shape[:2])
        specialist_out = (specialist_outputs * patch_probs.unsqueeze(-1)).sum(dim=-2)
        shared_sig, specialist_sig = self.gate_scalars.sigmoid()
        fused_grid = (
            final_grid
            + shared_sig.to(final_grid.dtype) * shared_out
            + specialist_sig.to(final_grid.dtype) * specialist_out
        )
        pooled_regions, pooled_region_mask, _, _ = self._pool_regions(fused_grid)
        return fused_grid, pooled_regions, pooled_region_mask, routing_logits

    def _pool_regions(
        self,
        grid: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[int, int]]:
        """Mean-pool an [h, w, D] grid over ``region_size``-square blocks.

        The grid is zero-padded up to a multiple of ``region_size``. Padding
        cells are excluded from each block's mean.

        Returns:
            ``(pooled [R, D], mask [R] bool, coords [R, 4] as normalized
            (x0, y0, x1, y1), (num_h, num_w))`` with regions in row-major order.
        """
        h, w, dim = grid.shape
        rs = self.region_size
        hp = _ceil_to_multiple(h, rs)
        wp = _ceil_to_multiple(w, rs)
        num_h = hp // rs
        num_w = wp // rs
        padded = grid.new_zeros(hp, wp, dim)
        padded[:h, :w] = grid
        valid = grid.new_zeros(hp, wp, 1)
        valid[:h, :w] = 1
        blocks = padded.view(num_h, rs, num_w, rs, dim).permute(0, 2, 1, 3, 4).reshape(num_h * num_w, rs * rs, dim)
        valid_blocks = valid.view(num_h, rs, num_w, rs, 1).permute(0, 2, 1, 3, 4).reshape(num_h * num_w, rs * rs, 1)
        valid_counts = valid_blocks.sum(dim=1)
        pooled = (blocks * valid_blocks).sum(dim=1) / valid_counts.clamp_min(1.0)
        # Derive the mask from the unclamped counts so an all-padding block
        # would be reported invalid (the clamp only guards the division).
        mask = valid_counts.squeeze(-1) > 0.5
        denom_h = max(h, 1)
        denom_w = max(w, 1)
        coords = [
            [
                (rx * rs) / denom_w,
                (ry * rs) / denom_h,
                min((rx + 1) * rs, w) / denom_w,
                min((ry + 1) * rs, h) / denom_h,
            ]
            for ry in range(num_h)
            for rx in range(num_w)
        ]
        coord_tensor = torch.tensor(coords, device=grid.device, dtype=grid.dtype)
        return pooled, mask, coord_tensor, (num_h, num_w)

    def _broadcast_region_probs(
        self,
        region_probs: torch.Tensor,
        region_shape: Tuple[int, int],
        grid_shape: Tuple[int, int],
    ) -> torch.Tensor:
        """Expand per-region probs [R, E] back to per-cell probs [h, w, E].

        Every cell inherits its region's expert distribution; the padded
        border added by ``_pool_regions`` is cropped away.
        """
        num_h, num_w = region_shape
        h, w = grid_shape
        rs = self.region_size
        hp = num_h * rs
        wp = num_w * rs
        probs = region_probs.view(num_h, num_w, self.num_specialists)
        probs = probs[:, :, None, None, :].expand(num_h, num_w, rs, rs, self.num_specialists)
        probs = probs.permute(0, 2, 1, 3, 4).reshape(hp, wp, self.num_specialists)
        return probs[:h, :w]

    def _pad_regions(
        self,
        region_embeddings_list: List[torch.Tensor],
        routing_logits_list: List[torch.Tensor],
        routing_mask_list: List[torch.Tensor],
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Right-pad per-sample region tensors into batch tensors.

        Returns ``(None, None, None)`` when no per-sample entries exist (no
        image branch ran). Otherwise returns zero-padded ``[B, R_max, D]``,
        ``[B, R_max, E]`` and a bool ``[B, R_max]`` mask.
        """
        if not region_embeddings_list:
            return None, None, None
        batch_size = len(region_embeddings_list)
        max_regions = max((regions.size(0) for regions in region_embeddings_list), default=0)
        padded_regions = torch.zeros(batch_size, max_regions, self.retrieval_dim, device=device, dtype=dtype)
        padded_logits = torch.zeros(batch_size, max_regions, self.num_specialists, device=device, dtype=dtype)
        padded_mask = torch.zeros(batch_size, max_regions, device=device, dtype=torch.bool)
        if max_regions == 0:
            return padded_regions, padded_logits, padded_mask
        for idx, (regions, logits, mask) in enumerate(zip(region_embeddings_list, routing_logits_list, routing_mask_list)):
            if regions.numel() == 0:
                continue
            count = regions.size(0)
            padded_regions[idx, :count] = regions.to(dtype)
            padded_logits[idx, : logits.size(0)] = logits.to(dtype)
            padded_mask[idx, : mask.numel()] = mask.to(torch.bool)
        return padded_regions, padded_logits, padded_mask

    # ----------------------------------------------------------------- #
    # User-facing helpers
    # ----------------------------------------------------------------- #
    @torch.inference_mode()
    def encode_queries(
        self,
        processor,
        queries: List[str],
        batch_size: int = 8,
        max_length: Optional[int] = None,
    ) -> List[torch.Tensor]:
        """Encode a list of query strings into multi-vector embeddings.

        Returns one tensor per query, since queries may have different lengths.
        Each tensor keeps its padded length; padding rows are zero vectors,
        which :meth:`score` ignores. The forward pass runs on the model's
        device, and the returned tensors are moved to CPU for the caller to
        manage.
        """
        device = next(self.parameters()).device
        out: List[torch.Tensor] = []
        for i in range(0, len(queries), batch_size):
            batch = processor.process_texts(queries[i : i + batch_size], max_length=max_length).to(device)
            emb = self(**batch).embeddings.cpu()
            out.extend(torch.unbind(emb))
        return out

    @torch.inference_mode()
    def encode_images(self, processor, images, batch_size: int = 2) -> List[torch.Tensor]:
        """Encode a list of PIL images into multi-vector embeddings.

        Returns one CPU tensor per image, in input order.
        """
        device = next(self.parameters()).device
        out: List[torch.Tensor] = []
        for i in range(0, len(images), batch_size):
            batch = processor.process_images(images[i : i + batch_size]).to(device)
            emb = self(**batch).embeddings.cpu()
            out.extend(torch.unbind(emb))
        return out

    @staticmethod
    def score(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        batch_size: int = 32,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """MaxSim scoring: for each (q_i, p_j) pair, compute
        ``sum_t max_p <q_i_t, p_j_p>``. Returns a float32 [N_q, N_p] matrix.

        All-zero query rows (padding) are excluded from the sum. Documents are
        zero-padded to a common length within each chunk. Computation runs on
        ``device`` (CPU by default) in chunks of ``batch_size`` queries ×
        ``batch_size`` documents.

        This reproduces ``processor.score_multi_vector`` but lives on the
        model so users can compute relevance without touching the processor.
        """
        dev = torch.device(device) if device is not None else torch.device("cpu")
        n_q, n_p = len(qs), len(ps)
        scores = torch.zeros(n_q, n_p, device=dev)
        for qi in range(0, n_q, batch_size):
            q_slice = qs[qi : qi + batch_size]
            q_len = max(x.size(0) for x in q_slice)
            q_pad = torch.zeros(len(q_slice), q_len, q_slice[0].size(-1), device=dev)
            q_mask = torch.zeros(len(q_slice), q_len, device=dev, dtype=torch.bool)
            for i, t in enumerate(q_slice):
                q_pad[i, : t.size(0)] = t.to(dev)
                q_mask[i, : t.size(0)] = q_pad[i, : t.size(0)].abs().sum(dim=-1) > 0
            for pi in range(0, n_p, batch_size):
                p_slice = ps[pi : pi + batch_size]
                p_len = max(x.size(0) for x in p_slice)
                p_pad = torch.zeros(len(p_slice), p_len, p_slice[0].size(-1), device=dev)
                for j, t in enumerate(p_slice):
                    p_pad[j, : t.size(0)] = t.to(dev)
                # sim[q, p, l, k] = <query q token l, doc p token k>
                sim = torch.einsum("qld,pkd->qplk", q_pad, p_pad)
                maxsim = sim.max(dim=-1).values
                maxsim = (maxsim * q_mask.unsqueeze(1).to(maxsim.dtype)).sum(dim=-1)
                scores[qi : qi + len(q_slice), pi : pi + len(p_slice)] = maxsim
        return scores
__all__ = ["ArgusForRetrieval", "ArgusOutput"]