"""Argus: Region-Aware Query-Conditioned Mixture of Experts for Visual Document Retrieval.

Self-contained model implementation for the Argus-Colqwen3.5-9B release.

Usage
-----
>>> from transformers import AutoModel, AutoProcessor
>>> model = AutoModel.from_pretrained(
...     "DataScience-UIBK/Argus-Colqwen3.5-9B-v0",
...     trust_remote_code=True,
...     torch_dtype="bfloat16",
... ).eval().cuda()
>>> proc = AutoProcessor.from_pretrained(
...     "DataScience-UIBK/Argus-Colqwen3.5-9B-v0",
...     trust_remote_code=True,
... )
>>> q_emb = model.encode_queries(proc, ["what is the revenue in 2019?"])
>>> d_emb = model.encode_images(proc, [pil_image_1, pil_image_2])
>>> scores = model.score(q_emb, d_emb)  # shape [num_queries, num_docs]
"""

from __future__ import annotations

from dataclasses import dataclass
from math import ceil
from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers.utils import ModelOutput

try:
    from transformers.models.qwen3_5 import Qwen3_5Config, Qwen3_5Model
except ImportError:
    try:
        from transformers.models.qwen3_5 import Qwen35Config as Qwen3_5Config
        from transformers.models.qwen3_5 import Qwen35Model as Qwen3_5Model
    except ImportError as exc:
        raise ImportError(
            "Argus requires a transformers build that exposes the Qwen3.5 VL "
            "classes (transformers.models.qwen3_5). Upgrade to transformers "
            ">= 4.57.0.dev0."
        ) from exc

from .configuration_argus import ArgusConfig


# --------------------------------------------------------------------------- #
# Output container
# --------------------------------------------------------------------------- #


@dataclass
class ArgusOutput(ModelOutput):
    """Output of :meth:`ArgusForRetrieval.forward`.

    Attributes:
        embeddings: multi-vector token embeddings ``[B, T, D]``, L2-normalized
            per token and zeroed at padding positions. Use ``score`` /
            ``score_multi_vector`` against queries encoded the same way.
        region_embeddings: pooled, L2-normalized region-level document
            embeddings ``[B, R, D]`` (only populated when images are in the
            batch).
        region_mask: valid mask for ``region_embeddings``, shape ``[B, R]``.
        routing_logits: raw MoE router logits ``[B, R, E]`` (per-region,
            per-expert).
    """

    embeddings: torch.Tensor
    region_embeddings: Optional[torch.Tensor] = None
    region_mask: Optional[torch.Tensor] = None
    routing_logits: Optional[torch.Tensor] = None


# --------------------------------------------------------------------------- #
# MoE building blocks
# --------------------------------------------------------------------------- #


def _ceil_to_multiple(value: int, multiple: int) -> int:
    """Round ``value`` up to the nearest multiple of ``multiple``."""
    return int(ceil(value / multiple) * multiple)


class SharedDenseExpert(nn.Module):
    """Shared expert applied to every spatial location.

    A pre-LayerNorm two-layer GELU MLP (``hidden -> hidden * expansion ->
    hidden``). The residual connection is added by the caller.
    """

    def __init__(self, hidden_dim: int, expansion: int = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim * expansion),
            nn.GELU(),
            nn.Linear(hidden_dim * expansion, hidden_dim),
        )

    def forward(self, grid: torch.Tensor) -> torch.Tensor:
        return self.net(grid)


class LatentSpatialExpert(nn.Module):
    """One of ``num_specialists`` region-level experts routed by the query.

    Same pre-LayerNorm MLP shape as :class:`SharedDenseExpert`, with a
    smaller default expansion factor.
    """

    def __init__(self, hidden_dim: int, expansion: int = 2):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim * expansion),
            nn.GELU(),
            nn.Linear(hidden_dim * expansion, hidden_dim),
        )

    def forward(self, grid: torch.Tensor) -> torch.Tensor:
        return self.net(grid)


class GateScalars(nn.Module):
    """Two learnable scalars whose sigmoids weight the shared / specialist
    expert contributions onto the final hidden states.

    The parameters are pinned to float32 even when the surrounding model is
    cast to a lower precision (see :meth:`_apply`).
    """

    def __init__(self, shared_init: float = 0.0, specialist_init: float = 0.0):
        super().__init__()
        self.shared = nn.Parameter(torch.tensor(float(shared_init), dtype=torch.float32))
        self.specialist = nn.Parameter(torch.tensor(float(specialist_init), dtype=torch.float32))

    def _apply(self, fn, *args, **kwargs):  # noqa: D401 - keep fp32 even after .to(dtype)
        # Forward any extra arguments (e.g. ``recurse=`` used by
        # ``Module.to_empty``) so this override matches nn.Module._apply.
        super()._apply(fn, *args, **kwargs)
        for name in ("shared", "specialist"):
            param = getattr(self, name)
            if param.dtype != torch.float32:
                param.data = param.data.to(torch.float32)
        return self

    def sigmoid(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(sigmoid(shared), sigmoid(specialist))`` as fp32 scalars."""
        return torch.sigmoid(self.shared), torch.sigmoid(self.specialist)


# --------------------------------------------------------------------------- #
# Argus model
# --------------------------------------------------------------------------- #


class ArgusForRetrieval(Qwen3_5Model):
    """Argus multi-vector visual document retriever.

    Structure:
        - Backbone: Qwen3.5-VL (9B) — produces per-token hidden states.
        - Region pool: non-overlapping ``region_size x region_size`` blocks
          over the vision-token grid; gives a compact region-level view.
        - Router: per-region MLP -> ``num_specialists`` logits; the query (if
          given) biases the logits via ``query_context_proj``. Top-k sparse
          softmax.
        - Experts: one shared expert (applied everywhere) +
          ``num_specialists`` latent spatial experts (per-region weighted sum).
        - Fusion: ``final_hidden = final_hidden + sigmoid(gate_shared) *
          shared_expert + sigmoid(gate_specialist) * specialist_sum``.
        - Retrieval head: ``custom_text_proj`` projects fused hidden states to
          ``retrieval_dim`` multi-vectors, L2-normalized.
        - Query side: no MoE; just backbone + ``custom_text_proj``.

    The user-facing helpers are ``encode_images``, ``encode_queries``, and
    ``score`` (MaxSim). All live on this class so a downstream user can do
    everything via ``model.``.
    """

    config_class = ArgusConfig
    main_input_name: ClassVar[str] = "input_ids"

    def __init__(self, config: Union[ArgusConfig, Qwen3_5Config], **kwargs):
        # Accept either an ArgusConfig or a plain Qwen3_5Config with extra attrs
        # (transformers sometimes hands us a base-class instance during Auto*
        # dispatch before config_class kicks in).
        if not isinstance(config, ArgusConfig):
            config = ArgusConfig(**config.to_dict())

        # Loading kwargs are consumed here and applied after post_init();
        # both spellings are popped, and ``dtype`` wins when both are given.
        torch_dtype = kwargs.pop("torch_dtype", None)
        dtype = kwargs.pop("dtype", torch_dtype)
        attn_impl = kwargs.pop("attn_implementation", None)
        use_cache = kwargs.pop("use_cache", None)

        # Replace a missing rope_scaling with an empty dict before the
        # backbone constructor reads it.
        if hasattr(config, "text_config") and getattr(config.text_config, "rope_scaling", None) is None:
            config.text_config.rope_scaling = {}
        if getattr(config, "rope_scaling", None) is None:
            config.rope_scaling = {}

        super().__init__(config=config)

        hidden_size = getattr(config, "hidden_size", None) or getattr(config.text_config, "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Argus: could not determine backbone hidden_size from config.")

        self.retrieval_dim = int(config.retrieval_dim)
        self.num_specialists = int(config.num_specialists)
        # Clamp top-k into [1, num_specialists].
        self.top_k_experts = max(1, min(int(config.top_k_experts), self.num_specialists))
        self.region_size = int(config.region_size)
        self.router_layer_index = int(config.router_layer_index)
        self.router_temperature = float(config.router_temperature)
        self.router_noise_std = float(config.router_noise_std)
        self.mask_non_image_embeddings = bool(config.mask_non_image_embeddings)
        self.spatial_merge_size = getattr(config.vision_config, "spatial_merge_size", 1)
        self.padding_side = "left"

        self.custom_text_proj = nn.Linear(hidden_size, self.retrieval_dim)
        self.shared_expert = SharedDenseExpert(hidden_size)
        self.latent_experts = nn.ModuleList(
            LatentSpatialExpert(hidden_size) for _ in range(self.num_specialists)
        )
        self.region_router = nn.Sequential(
            nn.LayerNorm(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, self.num_specialists),
        )
        # Embeds the normalized (x0, y0, x1, y1) box of each region.
        self.region_coord_proj = nn.Linear(4, hidden_size, bias=False)
        # Lifts a retrieval-space query vector into the router's input space.
        self.query_context_proj = nn.Linear(self.retrieval_dim, hidden_size, bias=False)
        self.gate_scalars = GateScalars(
            shared_init=config.shared_gate_init,
            specialist_init=config.specialist_gate_init,
        )

        self.post_init()

        if dtype is not None:
            self.to(dtype=dtype)
        if use_cache is not None:
            self.config.use_cache = use_cache
        if attn_impl is not None and hasattr(self, "set_attn_implementation"):
            self.set_attn_implementation(attn_impl)

    # ----------------------------------------------------------------- #
    # Forward
    # ----------------------------------------------------------------- #

    def build_query_router_context(
        self,
        query_embeddings: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pool query multi-vectors into one normalized vector per query.

        Used to bias the MoE router when the query is known at doc-encode time
        (cross-encoder-style, optional). Safe to call with query-only outputs
        of :meth:`forward`.

        Args:
            query_embeddings: ``[B, T, D]`` query token embeddings.
            attention_mask: optional ``[B, T]`` mask; when given, a masked
                mean is taken (the divisor is clamped to at least 1).

        Returns:
            ``[B, D]`` L2-normalized pooled vectors.
        """
        if attention_mask is None:
            pooled = query_embeddings.mean(dim=1)
        else:
            weights = attention_mask.unsqueeze(-1).to(query_embeddings.dtype)
            pooled = (query_embeddings * weights).sum(dim=1) / weights.sum(dim=1).clamp_min(1.0)
        return pooled / pooled.norm(dim=-1, keepdim=True).clamp_min(1e-12)

    def forward(self, *args, **kwargs) -> ArgusOutput:
        """Run backbone + MoE + retrieval head.

        Inputs follow the standard Qwen3-VL processor outputs: ``input_ids``,
        ``attention_mask``, and (for images) ``pixel_values`` +
        ``image_grid_thw``. ``query_context`` is optional and, when given as a
        ``[B, retrieval_dim]`` tensor (e.g. from
        :meth:`build_query_router_context`), biases the router for each batch
        sample. ``region_labels`` / ``region_mask`` are accepted and ignored.

        The MoE path runs only when ``pixel_values`` is not None and
        ``input_ids`` is given. Each sample's image tokens are paired with the
        next unused ``image_grid_thw`` row, so samples without image tokens
        may be mixed into the batch.
        """
        kwargs.pop("region_labels", None)
        kwargs.pop("region_mask", None)
        query_context = kwargs.pop("query_context", None)
        image_grid_thw = kwargs.get("image_grid_thw")
        pixel_values = kwargs.get("pixel_values")
        # An explicit ``pixel_values=None`` means "no images": treating it as
        # present would zero every text embedding via the non-image mask.
        has_images = pixel_values is not None

        # Processor may return per-image padded pixel tensors; the backbone
        # wants them flat-concatenated. An already-flat 2-D tensor is passed
        # through unchanged. Each image contributes t * h * w patch rows.
        already_flat = torch.is_tensor(pixel_values) and pixel_values.dim() == 2
        if has_images and image_grid_thw is not None and not already_flat:
            offsets = image_grid_thw.prod(dim=-1)
            kwargs["pixel_values"] = torch.cat(
                [pv[:off] for pv, off in zip(pixel_values, offsets)],
                dim=0,
            )

        # These are forced below; drop caller-supplied values so they cannot
        # clash with the explicit keyword arguments.
        kwargs.pop("return_dict", None)
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("use_cache", None)

        outputs = super().forward(
            *args,
            **kwargs,
            use_cache=False,
            output_hidden_states=True,
            return_dict=True,
        )
        final_hidden = outputs.last_hidden_state
        router_hidden = outputs.hidden_states[self.router_layer_index]
        # Release the per-layer hidden states. ``del outputs.hidden_states``
        # is not enough: ModelOutput also keeps the tuple in its dict storage.
        del outputs

        attention_mask = kwargs.get("attention_mask")
        input_ids = kwargs.get("input_ids")

        region_embeddings_list: List[torch.Tensor] = []
        routing_logits_list: List[torch.Tensor] = []
        routing_mask_list: List[torch.Tensor] = []

        if has_images and input_ids is not None:
            image_mask = input_ids == self.config.image_token_id
            # Row of image_grid_thw for the next sample that has image tokens.
            image_idx = 0
            for batch_idx in range(final_hidden.size(0)):
                image_positions = image_mask[batch_idx].nonzero(as_tuple=False).squeeze(-1)
                if image_positions.numel() == 0:
                    region_embeddings_list.append(final_hidden.new_zeros(0, self.retrieval_dim))
                    routing_logits_list.append(final_hidden.new_zeros(0, self.num_specialists))
                    routing_mask_list.append(final_hidden.new_zeros(0, dtype=torch.bool))
                    continue

                grid_t, raw_grid_h, raw_grid_w = (int(v) for v in image_grid_thw[image_idx].tolist())
                image_idx += 1
                # The LM sees one token per spatial_merge_size x spatial_merge_size patch block.
                grid_h = max(1, raw_grid_h // self.spatial_merge_size)
                grid_w = max(1, raw_grid_w // self.spatial_merge_size)
                num_image_tokens = min(grid_t * grid_h * grid_w, image_positions.numel())
                image_positions = image_positions[:num_image_tokens]

                # Collapse the temporal axis by averaging.
                early_grid = router_hidden[batch_idx, image_positions].view(grid_t, grid_h, grid_w, -1).mean(dim=0)
                final_grid = final_hidden[batch_idx, image_positions].view(grid_t, grid_h, grid_w, -1).mean(dim=0)
                query_context_i = None if query_context is None else query_context[batch_idx]

                fused_grid, pooled_regions, pooled_mask, logits = self._apply_query_conditioned_moe(
                    early_grid=early_grid,
                    final_grid=final_grid,
                    query_context=query_context_i,
                )

                # Write the fused grid back, replicated across the temporal axis.
                fused_tokens = (
                    fused_grid.unsqueeze(0)
                    .expand(grid_t, -1, -1, -1)
                    .reshape(num_image_tokens, -1)
                    .to(final_hidden.dtype)
                )
                final_hidden[batch_idx, image_positions] = fused_tokens

                projected_regions = self.custom_text_proj(pooled_regions)
                projected_regions = projected_regions / projected_regions.norm(dim=-1, keepdim=True).clamp_min(1e-12)
                region_embeddings_list.append(projected_regions)
                routing_logits_list.append(logits)
                routing_mask_list.append(pooled_mask)

        embeddings = self.custom_text_proj(final_hidden)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        if attention_mask is not None:
            embeddings = embeddings * attention_mask.unsqueeze(-1)
        if has_images and self.mask_non_image_embeddings and input_ids is not None:
            embeddings = embeddings * (input_ids == self.config.image_token_id).unsqueeze(-1)

        region_embeddings, padded_routing_logits, padded_routing_mask = self._pad_regions(
            region_embeddings_list,
            routing_logits_list,
            routing_mask_list,
            device=embeddings.device,
            dtype=embeddings.dtype,
        )

        return ArgusOutput(
            embeddings=embeddings,
            region_embeddings=region_embeddings,
            region_mask=padded_routing_mask,
            routing_logits=padded_routing_logits,
        )

    # ----------------------------------------------------------------- #
    # MoE internals
    # ----------------------------------------------------------------- #

    def _topk_sparse_probs(self, routing_logits: torch.Tensor) -> torch.Tensor:
        """Turn router logits into top-k sparse, temperature-scaled probabilities.

        Computed in float32 and cast back to ``routing_logits.dtype``. Gaussian
        noise is added only in training mode when ``router_noise_std > 0``.
        Logits outside the top-k are set to ``-inf`` so their probability is 0.
        """
        temperature = max(self.router_temperature, 1e-6)
        logits = routing_logits.float()
        if self.training and self.router_noise_std > 0:
            logits = logits + self.router_noise_std * torch.randn_like(logits)
        if self.top_k_experts >= self.num_specialists:
            return F.softmax(logits / temperature, dim=-1).to(routing_logits.dtype)
        topk_values, topk_indices = torch.topk(logits, k=self.top_k_experts, dim=-1)
        sparse_logits = torch.full_like(logits, float("-inf"))
        sparse_logits.scatter_(-1, topk_indices, topk_values)
        probs = F.softmax(sparse_logits / temperature, dim=-1)
        return probs.to(routing_logits.dtype)

    def _apply_query_conditioned_moe(
        self,
        early_grid: torch.Tensor,
        final_grid: torch.Tensor,
        query_context: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route regions of one image to experts and fuse their outputs.

        Args:
            early_grid: ``[h, w, hidden]`` grid from ``router_layer_index``;
                drives routing only.
            final_grid: ``[h, w, hidden]`` grid from the last layer; experts
                are applied to it.
            query_context: optional ``[retrieval_dim]`` vector biasing every
                region's router input.

        Returns:
            ``(fused_grid, pooled_regions, pooled_region_mask, routing_logits)``
            where ``fused_grid`` has the shape of ``final_grid``, the pooled
            tensors are over regions of ``fused_grid``, and ``routing_logits``
            is ``[R, num_specialists]``.
        """
        region_tokens, pooled_mask, coords, region_shape = self._pool_regions(early_grid)
        router_input = region_tokens + self.region_coord_proj(coords.to(region_tokens.dtype))
        if query_context is not None:
            query_bias = self.query_context_proj(query_context.to(region_tokens.dtype)).unsqueeze(0)
            router_input = router_input + query_bias

        routing_logits = self.region_router(router_input)
        routing_probs = self._topk_sparse_probs(routing_logits)

        shared_out = self.shared_expert(final_grid)
        # [h, w, num_specialists, hidden]: every expert is evaluated densely
        # and then weighted by the per-patch routing probabilities.
        specialist_outputs = torch.stack([expert(final_grid) for expert in self.latent_experts], dim=-2)
        patch_probs = self._broadcast_region_probs(routing_probs, region_shape, final_grid.shape[:2])
        specialist_out = (specialist_outputs * patch_probs.unsqueeze(-1)).sum(dim=-2)

        shared_sig, specialist_sig = self.gate_scalars.sigmoid()
        fused_grid = (
            final_grid
            + shared_sig.to(final_grid.dtype) * shared_out
            + specialist_sig.to(final_grid.dtype) * specialist_out
        )

        pooled_regions, pooled_region_mask, _, _ = self._pool_regions(fused_grid)
        return fused_grid, pooled_regions, pooled_region_mask, routing_logits

    def _pool_regions(
        self,
        grid: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[int, int]]:
        """Mean-pool an ``[h, w, dim]`` grid into non-overlapping square regions.

        The grid is zero-padded up to a multiple of ``region_size`` on both
        axes; padding cells are excluded from each region's mean.

        Returns:
            pooled: ``[num_h * num_w, dim]`` region means, row-major.
            mask: ``[num_h * num_w]`` bool, True where the region contains at
                least one real (non-padding) cell.
            coords: ``[num_h * num_w, 4]`` region boxes ``(x0, y0, x1, y1)``
                normalized by ``w`` / ``h``, in ``grid.dtype``.
            region_shape: ``(num_h, num_w)``.
        """
        h, w, dim = grid.shape
        rs = self.region_size
        hp = _ceil_to_multiple(h, rs)
        wp = _ceil_to_multiple(w, rs)

        padded = grid.new_zeros(hp, wp, dim)
        padded[:h, :w] = grid
        valid = grid.new_zeros(hp, wp, 1)
        valid[:h, :w] = 1

        num_h = hp // rs
        num_w = wp // rs
        blocks = padded.view(num_h, rs, num_w, rs, dim).permute(0, 2, 1, 3, 4).reshape(num_h * num_w, rs * rs, dim)
        valid_blocks = valid.view(num_h, rs, num_w, rs, 1).permute(0, 2, 1, 3, 4).reshape(num_h * num_w, rs * rs, 1)

        valid_counts = valid_blocks.sum(dim=1)
        # Derive the mask from the raw counts: the divisor below is clamped to
        # >= 1, so a mask taken from it could never be False.
        mask = valid_counts.squeeze(-1) > 0.5
        pooled = (blocks * valid_blocks).sum(dim=1) / valid_counts.clamp_min(1.0)

        coords = []
        for ry in range(num_h):
            for rx in range(num_w):
                y0 = (ry * rs) / max(h, 1)
                x0 = (rx * rs) / max(w, 1)
                y1 = min((ry + 1) * rs, h) / max(h, 1)
                x1 = min((rx + 1) * rs, w) / max(w, 1)
                coords.append([x0, y0, x1, y1])
        coord_tensor = torch.tensor(coords, device=grid.device, dtype=grid.dtype)
        return pooled, mask, coord_tensor, (num_h, num_w)

    def _broadcast_region_probs(
        self,
        region_probs: torch.Tensor,
        region_shape: Tuple[int, int],
        grid_shape: Tuple[int, int],
    ) -> torch.Tensor:
        """Expand ``[num_h * num_w, E]`` region probabilities to ``[h, w, E]``.

        Each region's probabilities are copied to every cell of its
        ``region_size x region_size`` block; padding rows/columns are cropped.
        """
        num_h, num_w = region_shape
        h, w = grid_shape
        rs = self.region_size
        hp = num_h * rs
        wp = num_w * rs
        probs = region_probs.view(num_h, num_w, self.num_specialists)
        probs = probs[:, :, None, None, :].expand(num_h, num_w, rs, rs, self.num_specialists)
        probs = probs.permute(0, 2, 1, 3, 4).reshape(hp, wp, self.num_specialists)
        return probs[:h, :w]

    def _pad_regions(
        self,
        region_embeddings_list: List[torch.Tensor],
        routing_logits_list: List[torch.Tensor],
        routing_mask_list: List[torch.Tensor],
        device: torch.device,
        dtype: torch.dtype,
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Right-pad per-sample region tensors into batched ``[B, R_max, ...]`` tensors.

        Returns ``(None, None, None)`` when the lists are empty (no MoE pass
        ran). Padded slots are zeros with a False mask.
        """
        if not region_embeddings_list:
            return None, None, None

        batch_size = len(region_embeddings_list)
        max_regions = max((regions.size(0) for regions in region_embeddings_list), default=0)
        if max_regions == 0:
            return (
                torch.zeros(batch_size, 0, self.retrieval_dim, device=device, dtype=dtype),
                torch.zeros(batch_size, 0, self.num_specialists, device=device, dtype=dtype),
                torch.zeros(batch_size, 0, device=device, dtype=torch.bool),
            )

        padded_regions = torch.zeros(batch_size, max_regions, self.retrieval_dim, device=device, dtype=dtype)
        padded_logits = torch.zeros(batch_size, max_regions, self.num_specialists, device=device, dtype=dtype)
        padded_mask = torch.zeros(batch_size, max_regions, device=device, dtype=torch.bool)
        for idx, (regions, logits, mask) in enumerate(zip(region_embeddings_list, routing_logits_list, routing_mask_list)):
            if regions.numel() == 0:
                continue
            count = regions.size(0)
            padded_regions[idx, :count] = regions.to(dtype)
            padded_logits[idx, : logits.size(0)] = logits.to(dtype)
            padded_mask[idx, : mask.numel()] = mask.to(torch.bool)
        return padded_regions, padded_logits, padded_mask

    # ----------------------------------------------------------------- #
    # User-facing helpers
    # ----------------------------------------------------------------- #

    @torch.inference_mode()
    def encode_queries(
        self,
        processor,
        queries: List[str],
        batch_size: int = 8,
        max_length: Optional[int] = None,
    ) -> List[torch.Tensor]:
        """Encode a list of query strings into multi-vector embeddings.

        Queries are processed ``batch_size`` at a time via
        ``processor.process_texts`` on the model's device. Returns one CPU
        tensor per query: a row of its padded batch, so padding positions are
        zero vectors (``score`` ignores all-zero query rows).
        """
        device = next(self.parameters()).device
        out: List[torch.Tensor] = []
        for i in range(0, len(queries), batch_size):
            batch = processor.process_texts(queries[i : i + batch_size], max_length=max_length).to(device)
            emb = self(**batch).embeddings.cpu()
            out.extend(list(torch.unbind(emb)))
        return out

    @torch.inference_mode()
    def encode_images(self, processor, images, batch_size: int = 2) -> List[torch.Tensor]:
        """Encode a list of PIL images into multi-vector embeddings.

        Images are processed ``batch_size`` at a time via
        ``processor.process_images`` on the model's device. Returns one CPU
        tensor per image; padding positions (and, when
        ``mask_non_image_embeddings`` is set, non-image tokens) are zero
        vectors.
        """
        device = next(self.parameters()).device
        out: List[torch.Tensor] = []
        for i in range(0, len(images), batch_size):
            batch = processor.process_images(images[i : i + batch_size]).to(device)
            emb = self(**batch).embeddings.cpu()
            out.extend(list(torch.unbind(emb)))
        return out

    @staticmethod
    def score(
        qs: List[torch.Tensor],
        ps: List[torch.Tensor],
        batch_size: int = 32,
        device: Optional[Union[str, torch.device]] = None,
    ) -> torch.Tensor:
        """MaxSim (late-interaction) scoring.

        For each pair ``(q_i, p_j)`` computes ``sum_t max_k <q_i[t], p_j[k]>``
        (dot products; cosine similarities for this model's L2-normalized
        embeddings). All-zero query rows are treated as padding and excluded
        from the sum; zero-padded document rows take part in the max with a
        similarity of 0. Returns a float32 ``[N_q, N_p]`` matrix on ``device``
        (CPU by default).

        This reproduces ``processor.score_multi_vector`` but lives on the
        model so users can compute relevance without touching the processor.
        """
        dev = torch.device(device) if device is not None else torch.device("cpu")
        n_q, n_p = len(qs), len(ps)
        scores = torch.zeros(n_q, n_p, device=dev)
        for qi in range(0, n_q, batch_size):
            q_slice = qs[qi : qi + batch_size]
            q_len = max(x.size(0) for x in q_slice)
            q_pad = torch.zeros(len(q_slice), q_len, q_slice[0].size(-1), device=dev)
            q_mask = torch.zeros(len(q_slice), q_len, device=dev, dtype=torch.bool)
            for i, t in enumerate(q_slice):
                t_dev = t.to(dev)
                q_pad[i, : t.size(0)] = t_dev
                q_mask[i, : t.size(0)] = t_dev.abs().sum(dim=-1) > 0
            for pi in range(0, n_p, batch_size):
                p_slice = ps[pi : pi + batch_size]
                p_len = max(x.size(0) for x in p_slice)
                p_pad = torch.zeros(len(p_slice), p_len, p_slice[0].size(-1), device=dev)
                for j, t in enumerate(p_slice):
                    p_pad[j, : t.size(0)] = t.to(dev)
                # sim[q, p, l, k] = <query q token l, doc p token k>
                sim = torch.einsum("qld,pkd->qplk", q_pad, p_pad)
                maxsim = sim.max(dim=-1).values
                maxsim = (maxsim * q_mask.unsqueeze(1).to(maxsim.dtype)).sum(dim=-1)
                scores[qi : qi + len(q_slice), pi : pi + len(p_slice)] = maxsim
        return scores


__all__ = ["ArgusForRetrieval", "ArgusOutput"]