Feature Extraction
Transformers
Safetensors
ColPali
English
argus_colqwen35
visual-document-retrieval
colqwen
text
image
multimodal-embedding
vidore
mixture-of-experts
late-interaction
query-conditioned-routing
custom_code
Instructions to use DataScience-UIBK/Argus-Colqwen3.5-4b-v0 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use DataScience-UIBK/Argus-Colqwen3.5-4b-v0 with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="DataScience-UIBK/Argus-Colqwen3.5-4b-v0", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("DataScience-UIBK/Argus-Colqwen3.5-4b-v0", trust_remote_code=True, dtype="auto")

- ColPali
How to use DataScience-UIBK/Argus-Colqwen3.5-4b-v0 with ColPali:
# No code snippets available yet for this library.
# To use this model, check the repository files and the library's documentation.
# Want to help? PRs adding snippets are welcome at:
# https://github.com/huggingface/huggingface.js
- Notebooks
- Google Colab
- Kaggle
| """Argus: Region-Aware Query-Conditioned Mixture of Experts for Visual Document Retrieval. | |
| Self-contained model implementation for the Argus-Colqwen3.5-9B release. | |
| Usage | |
| ----- | |
| from transformers import AutoModel, AutoProcessor | |
| model = AutoModel.from_pretrained( | |
| "DataScience-UIBK/Argus-Colqwen3.5-9B-v0", | |
| trust_remote_code=True, | |
| torch_dtype="bfloat16", | |
| ).eval().cuda() | |
| proc = AutoProcessor.from_pretrained( | |
| "DataScience-UIBK/Argus-Colqwen3.5-9B-v0", | |
| trust_remote_code=True, | |
| ) | |
| q_emb = model.encode_queries(proc, ["what is the revenue in 2019?"]) | |
| d_emb = model.encode_images(proc, [pil_image_1, pil_image_2]) | |
| scores = model.score(q_emb, d_emb) # shape [num_queries, num_docs] | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from math import ceil | |
| from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from transformers.utils import ModelOutput | |
| try: | |
| from transformers.models.qwen3_5 import Qwen3_5Config, Qwen3_5Model | |
| except ImportError: | |
| try: | |
| from transformers.models.qwen3_5 import Qwen35Config as Qwen3_5Config | |
| from transformers.models.qwen3_5 import Qwen35Model as Qwen3_5Model | |
| except ImportError as exc: | |
| raise ImportError( | |
| "Argus requires a transformers build that exposes the Qwen3.5 VL " | |
| "classes (transformers.models.qwen3_5). Upgrade to transformers " | |
| ">= 4.57.0.dev0." | |
| ) from exc | |
| from .configuration_argus import ArgusConfig | |
| # --------------------------------------------------------------------------- # | |
| # Output container | |
| # --------------------------------------------------------------------------- # | |
@dataclass
class ArgusOutput(ModelOutput):
    """Output of :meth:`ArgusForRetrieval.forward`.

    The class is declared as a dataclass because ``ModelOutput`` subclasses
    are required to be dataclasses; the annotated fields below become the
    constructor keywords and the attribute/key names of the output.

    Attributes:
        embeddings: multi-vector token embeddings [B, T, D]. Use ``score`` /
            ``score_multi_vector`` against queries encoded the same way.
        region_embeddings: pooled region-level document embeddings [B, R, D]
            (only populated when images are in the batch).
        region_mask: valid mask for region_embeddings, shape [B, R].
        routing_logits: raw MoE router logits [B, R, E] (per-region, per-expert).
    """

    embeddings: torch.Tensor
    region_embeddings: Optional[torch.Tensor] = None
    region_mask: Optional[torch.Tensor] = None
    routing_logits: Optional[torch.Tensor] = None
| # --------------------------------------------------------------------------- # | |
| # MoE building blocks | |
| # --------------------------------------------------------------------------- # | |
| def _ceil_to_multiple(value: int, multiple: int) -> int: | |
| return int(ceil(value / multiple) * multiple) | |
class SharedDenseExpert(nn.Module):
    """Feed-forward expert that is shared by every spatial location.

    The block is LayerNorm -> Linear (widening by ``expansion``) -> GELU ->
    Linear (back to ``hidden_dim``), stored as ``self.net``.
    """

    def __init__(self, hidden_dim: int, expansion: int = 4):
        super().__init__()
        inner_dim = hidden_dim * expansion
        # Layer order (and therefore the ``net.<idx>`` parameter names) is
        # part of the checkpoint layout; keep it stable.
        layers = [
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, inner_dim),
            nn.GELU(),
            nn.Linear(inner_dim, hidden_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, grid: torch.Tensor) -> torch.Tensor:
        """Apply the MLP over the last dimension of ``grid``."""
        transformed = self.net(grid)
        return transformed
class LatentSpatialExpert(nn.Module):
    """A single specialist expert; the model holds ``num_specialists`` of these.

    Same layout as the shared expert (LayerNorm -> Linear -> GELU -> Linear)
    but with a narrower default expansion factor.
    """

    def __init__(self, hidden_dim: int, expansion: int = 2):
        super().__init__()
        widened = hidden_dim * expansion
        norm = nn.LayerNorm(hidden_dim)
        up_proj = nn.Linear(hidden_dim, widened)
        activation = nn.GELU()
        down_proj = nn.Linear(widened, hidden_dim)
        # Positional order fixes the ``net.<idx>`` state-dict keys.
        self.net = nn.Sequential(norm, up_proj, activation, down_proj)

    def forward(self, grid: torch.Tensor) -> torch.Tensor:
        """Apply the MLP over the last dimension of ``grid``."""
        output = self.net(grid)
        return output
class GateScalars(nn.Module):
    """Two learnable scalars whose sigmoids weight shared / specialist expert
    contributions onto the final hidden states.

    Both scalars are kept in float32 regardless of the dtype the enclosing
    model is cast to.
    """

    def __init__(self, shared_init: float = 0.0, specialist_init: float = 0.0):
        super().__init__()
        self.shared = nn.Parameter(torch.tensor(float(shared_init), dtype=torch.float32))
        self.specialist = nn.Parameter(torch.tensor(float(specialist_init), dtype=torch.float32))

    def _apply(self, fn, *args, **kwargs):  # noqa: D401 - keep fp32 even after .to(dtype)
        """Apply ``fn`` like ``nn.Module._apply`` and then restore float32.

        Extra positional/keyword arguments (e.g. ``recurse``) are forwarded
        unchanged so this override stays compatible with whatever signature
        the installed ``nn.Module._apply`` exposes.
        """
        super()._apply(fn, *args, **kwargs)
        for name in ("shared", "specialist"):
            param = getattr(self, name)
            if param.dtype != torch.float32:
                # Undo only the dtype change; device moves done by ``fn`` are kept.
                param.data = param.data.to(torch.float32)
        return self

    def sigmoid(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(sigmoid(shared), sigmoid(specialist))``."""
        return torch.sigmoid(self.shared), torch.sigmoid(self.specialist)
| # --------------------------------------------------------------------------- # | |
| # Argus model | |
| # --------------------------------------------------------------------------- # | |
class ArgusForRetrieval(Qwen3_5Model):
    """Argus multi-vector visual document retriever.

    Structure:
    - Backbone: Qwen3.5-VL (9B) — produces per-token hidden states.
    - Region pool: non-overlapping ``region_size × region_size`` blocks over
      the vision-token grid; gives a compact region-level view.
    - Router: per-region MLP → ``num_specialists`` logits; the query (if given)
      biases the logits via ``query_context_proj``. Top-k sparse softmax.
    - Experts: one shared expert (applied everywhere) + ``num_specialists``
      latent spatial experts (per-region weighted sum).
    - Fusion: ``final_hidden = final_hidden + σ(gate_shared) · shared_expert
      + σ(gate_specialist) · specialist_sum``.
    - Retrieval head: ``custom_text_proj`` projects fused hidden states to
      ``retrieval_dim`` multi-vectors, L2-normalized.
    - Query side: no MoE; just backbone + ``custom_text_proj``.

    The user-facing helpers are ``encode_images``, ``encode_queries``, and
    ``score`` (MaxSim). All live on this class so a downstream user can do
    everything via ``model.<method>``.
    """

    # Configuration class used for this model.
    config_class = ArgusConfig
    # Name of the primary model input.
    main_input_name: ClassVar[str] = "input_ids"
def __init__(self, config: Union[ArgusConfig, Qwen3_5Config], **kwargs):
    """Build the backbone and attach the Argus routing / expert / retrieval heads.

    Args:
        config: an ``ArgusConfig``. Any other config object is promoted to
            ``ArgusConfig`` via ``ArgusConfig(**config.to_dict())``.
        **kwargs: only ``dtype`` / ``torch_dtype``, ``attn_implementation``
            and ``use_cache`` are consumed; they are applied after
            construction. Other keys are ignored.

    Raises:
        ValueError: if neither ``config`` nor ``config.text_config`` exposes
            a ``hidden_size``.
    """
    # Accept either an ArgusConfig or a plain Qwen3_5Config with extra attrs
    # (transformers sometimes hands us a base-class instance during Auto*
    # dispatch before config_class kicks in).
    if not isinstance(config, ArgusConfig):
        config = ArgusConfig(**config.to_dict())

    dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None))
    attn_impl = kwargs.pop("attn_implementation", None)
    use_cache = kwargs.pop("use_cache", None)

    # Replace a missing / None ``rope_scaling`` with an empty dict on both
    # the text sub-config (when there is one) and the top-level config.
    text_config = getattr(config, "text_config", None)
    if text_config is not None and getattr(text_config, "rope_scaling", None) is None:
        text_config.rope_scaling = {}
    if getattr(config, "rope_scaling", None) is None:
        config.rope_scaling = {}

    super().__init__(config=config)

    # Look the sub-configs up defensively so that a config without
    # ``text_config`` reaches the explicit ValueError below instead of
    # failing with an AttributeError.
    text_config = getattr(config, "text_config", None)
    hidden_size = getattr(config, "hidden_size", None) or getattr(text_config, "hidden_size", None)
    if hidden_size is None:
        raise ValueError("Argus: could not determine backbone hidden_size from config.")

    self.retrieval_dim = int(config.retrieval_dim)
    self.num_specialists = int(config.num_specialists)
    # Clamp top-k into [1, num_specialists].
    self.top_k_experts = max(1, min(int(config.top_k_experts), self.num_specialists))
    self.region_size = int(config.region_size)
    self.router_layer_index = int(config.router_layer_index)
    self.router_temperature = float(config.router_temperature)
    self.router_noise_std = float(config.router_noise_std)
    self.mask_non_image_embeddings = bool(config.mask_non_image_embeddings)
    # Falls back to 1 when the vision config (or the attribute) is absent.
    self.spatial_merge_size = getattr(getattr(config, "vision_config", None), "spatial_merge_size", 1)
    self.padding_side = "left"

    # Retrieval head: hidden states -> retrieval_dim multi-vectors.
    self.custom_text_proj = nn.Linear(hidden_size, self.retrieval_dim)

    # Experts.
    self.shared_expert = SharedDenseExpert(hidden_size)
    self.latent_experts = nn.ModuleList(
        LatentSpatialExpert(hidden_size) for _ in range(self.num_specialists)
    )

    # Router: one logit per specialist for every pooled region.
    self.region_router = nn.Sequential(
        nn.LayerNorm(hidden_size),
        nn.Linear(hidden_size, hidden_size),
        nn.GELU(),
        nn.Linear(hidden_size, self.num_specialists),
    )
    # Projects the 4 normalized box coordinates of a region into hidden space.
    self.region_coord_proj = nn.Linear(4, hidden_size, bias=False)
    # Projects a pooled query embedding (retrieval space) into hidden space.
    self.query_context_proj = nn.Linear(self.retrieval_dim, hidden_size, bias=False)
    self.gate_scalars = GateScalars(
        shared_init=config.shared_gate_init,
        specialist_init=config.specialist_gate_init,
    )

    self.post_init()

    # Apply the construction-time options that were popped from kwargs.
    if dtype is not None:
        self.to(dtype=dtype)
    if use_cache is not None:
        self.config.use_cache = use_cache
    if attn_impl is not None and hasattr(self, "set_attn_implementation"):
        self.set_attn_implementation(attn_impl)
| # ----------------------------------------------------------------- # | |
| # Forward | |
| # ----------------------------------------------------------------- # | |
def build_query_router_context(
    self,
    query_embeddings: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Collapse query multi-vectors into a single unit-norm vector per query.

    The result can be passed to :meth:`forward` as ``query_context`` to bias
    the MoE router when the query is known at doc-encode time
    (cross-encoder-style, optional).

    Args:
        query_embeddings: query multi-vectors; dimension 1 is averaged away.
        attention_mask: optional per-token weights over dimension 1. When
            given, a weighted mean is taken (the weight sum is floored at
            1.0); otherwise a plain mean is used.

    Returns:
        The pooled vectors, L2-normalized over the last dimension (the norm
        is floored at 1e-12, so an all-zero input stays all-zero).
    """
    if attention_mask is not None:
        token_weights = attention_mask.to(query_embeddings.dtype).unsqueeze(-1)
        weight_total = token_weights.sum(dim=1).clamp_min(1.0)
        pooled = (query_embeddings * token_weights).sum(dim=1) / weight_total
    else:
        pooled = query_embeddings.mean(dim=1)
    length = pooled.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    return pooled / length
def forward(self, *args, **kwargs) -> ArgusOutput:
    """Run backbone + MoE + retrieval head.

    Inputs follow the standard Qwen3-VL processor outputs:
    ``input_ids``, ``attention_mask``, and (for images) ``pixel_values``
    + ``image_grid_thw``. They are read from ``kwargs`` by name, so pass
    them as keyword arguments.

    Keyword Args:
        query_context: optional per-row query vectors; when given, row ``i``
            biases the router for batch row ``i``.
        attention_mask: optional; when given, embeddings at masked positions
            are zeroed.
        pixel_values: either flat 2-D patches (passed through untouched) or
            per-image padded patches, which are trimmed to
            ``grid_h * grid_w`` rows per image and concatenated.
        region_labels, region_mask: accepted and discarded.
        return_dict, output_hidden_states, use_cache: accepted and
            overridden; the backbone always runs with ``return_dict=True``,
            ``output_hidden_states=True`` and ``use_cache=False``.

    Returns:
        ArgusOutput with L2-normalized ``embeddings`` and, when image inputs
        are present, padded region embeddings / mask / router logits
        (otherwise those three fields are ``None``).

    Raises:
        ValueError: if a row contains image tokens but ``image_grid_thw``
            was not provided.
    """
    # Accepted for call-site compatibility; not used by this method.
    kwargs.pop("region_labels", None)
    kwargs.pop("region_mask", None)
    query_context = kwargs.pop("query_context", None)
    # These flags are forced in the backbone call below.
    for forced_flag in ("return_dict", "output_hidden_states", "use_cache"):
        kwargs.pop(forced_flag, None)

    input_ids = kwargs.get("input_ids")
    attention_mask = kwargs.get("attention_mask")
    pixel_values = kwargs.get("pixel_values")
    image_grid_thw = kwargs.get("image_grid_thw")

    # Processor may return per-image padded pixel tensors; the backbone
    # wants them flat-concatenated. A tensor that is already 2-D is taken
    # to be flat and left alone (iterating it would slice feature columns
    # instead of patches).
    if pixel_values is not None and image_grid_thw is not None:
        already_flat = isinstance(pixel_values, torch.Tensor) and pixel_values.dim() == 2
        if not already_flat:
            patch_counts = image_grid_thw[:, 1] * image_grid_thw[:, 2]
            kwargs["pixel_values"] = torch.cat(
                [pv[:count] for pv, count in zip(pixel_values, patch_counts)],
                dim=0,
            )

    outputs = super().forward(
        *args,
        **kwargs,
        use_cache=False,
        output_hidden_states=True,
        return_dict=True,
    )
    final_hidden = outputs.last_hidden_state
    router_hidden = outputs.hidden_states[self.router_layer_index]
    # Drop the reference to the per-layer states; only the two above are used.
    del outputs.hidden_states

    region_embeddings_list: List[torch.Tensor] = []
    routing_logits_list: List[torch.Tensor] = []
    routing_mask_list: List[torch.Tensor] = []

    # An explicit ``pixel_values=None`` is treated like a text-only batch.
    image_mask = None
    if pixel_values is not None and input_ids is not None:
        image_mask = input_ids == self.config.image_token_id
        # ``image_grid_thw`` has one row per image, so it is indexed by a
        # running image counter rather than by batch row. Like the rest of
        # this loop, this assumes at most one image per sequence; when every
        # row has an image the counter equals the batch index.
        image_idx = 0
        for batch_idx in range(final_hidden.size(0)):
            image_positions = image_mask[batch_idx].nonzero(as_tuple=False).squeeze(-1)
            if image_positions.numel() == 0:
                # Row without image tokens: contribute empty region tensors.
                region_embeddings_list.append(final_hidden.new_zeros(0, self.retrieval_dim))
                routing_logits_list.append(final_hidden.new_zeros(0, self.num_specialists))
                routing_mask_list.append(final_hidden.new_zeros(0, dtype=torch.bool))
                continue
            if image_grid_thw is None:
                raise ValueError(
                    "Argus: `image_grid_thw` is required when the batch contains image tokens."
                )

            # One host sync for the three grid sizes instead of three.
            grid_t, raw_grid_h, raw_grid_w = (int(v) for v in image_grid_thw[image_idx].tolist())
            image_idx += 1
            # Token grid after the vision tower's spatial merge.
            grid_h = max(1, raw_grid_h // self.spatial_merge_size)
            grid_w = max(1, raw_grid_w // self.spatial_merge_size)
            num_image_tokens = min(grid_t * grid_h * grid_w, image_positions.numel())
            image_positions = image_positions[:num_image_tokens]

            # Average over the temporal axis to get one [H, W, C] grid each
            # for the router layer and for the final layer.
            early_grid = router_hidden[batch_idx, image_positions].view(grid_t, grid_h, grid_w, -1).mean(dim=0)
            final_grid = final_hidden[batch_idx, image_positions].view(grid_t, grid_h, grid_w, -1).mean(dim=0)

            query_context_i = None if query_context is None else query_context[batch_idx]
            fused_grid, pooled_regions, pooled_mask, logits = self._apply_query_conditioned_moe(
                early_grid=early_grid,
                final_grid=final_grid,
                query_context=query_context_i,
            )

            # Write the fused grid back over the image-token positions,
            # repeating it along the temporal axis.
            fused_tokens = (
                fused_grid.unsqueeze(0)
                .expand(grid_t, -1, -1, -1)
                .reshape(num_image_tokens, -1)
                .to(final_hidden.dtype)
            )
            final_hidden[batch_idx, image_positions] = fused_tokens

            projected_regions = self.custom_text_proj(pooled_regions)
            projected_regions = projected_regions / projected_regions.norm(dim=-1, keepdim=True).clamp_min(1e-12)
            region_embeddings_list.append(projected_regions)
            routing_logits_list.append(logits)
            routing_mask_list.append(pooled_mask)

    embeddings = self.custom_text_proj(final_hidden)
    embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True).clamp_min(1e-12)
    # Zero padded positions; skipped when the caller supplied no mask.
    if attention_mask is not None:
        embeddings = embeddings * attention_mask.unsqueeze(-1)
    # Optionally keep only the image-token vectors for document batches.
    if image_mask is not None and self.mask_non_image_embeddings:
        embeddings = embeddings * image_mask.unsqueeze(-1)

    region_embeddings, padded_routing_logits, padded_routing_mask = self._pad_regions(
        region_embeddings_list,
        routing_logits_list,
        routing_mask_list,
        device=embeddings.device,
        dtype=embeddings.dtype,
    )
    return ArgusOutput(
        embeddings=embeddings,
        region_embeddings=region_embeddings,
        region_mask=padded_routing_mask,
        routing_logits=padded_routing_logits,
    )
| # ----------------------------------------------------------------- # | |
| # MoE internals | |
| # ----------------------------------------------------------------- # | |
| def _topk_sparse_probs(self, routing_logits: torch.Tensor) -> torch.Tensor: | |
| logits = routing_logits.float() | |
| if self.training and self.router_noise_std > 0: | |
| logits = logits + self.router_noise_std * torch.randn_like(logits) | |
| if self.top_k_experts >= self.num_specialists: | |
| return F.softmax(logits / max(self.router_temperature, 1e-6), dim=-1).to(routing_logits.dtype) | |
| topk_values, topk_indices = torch.topk(logits, k=self.top_k_experts, dim=-1) | |
| sparse_logits = torch.full_like(logits, float("-inf")) | |
| sparse_logits.scatter_(-1, topk_indices, topk_values) | |
| probs = F.softmax(sparse_logits / max(self.router_temperature, 1e-6), dim=-1) | |
| return probs.to(routing_logits.dtype) | |
def _apply_query_conditioned_moe(
    self,
    early_grid: torch.Tensor,
    final_grid: torch.Tensor,
    query_context: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Route regions of one image grid through the experts and fuse the result.

    Args:
        early_grid: [H, W, C] grid used only to compute the routing.
        final_grid: [H, W, C] grid the experts are applied to.
        query_context: optional query vector; when given it is projected by
            ``query_context_proj`` and added to every region's router input.

    Returns:
        ``(fused_grid, pooled_regions, pooled_region_mask, routing_logits)``
        where ``fused_grid`` is ``final_grid`` plus the gated shared and
        specialist expert outputs, ``pooled_regions`` / ``pooled_region_mask``
        come from region-pooling ``fused_grid``, and ``routing_logits`` are
        the raw per-region router outputs.
    """
    # --- Routing: pooled early features + box-coordinate embedding (+ query bias).
    region_feats, _, region_coords, region_shape = self._pool_regions(early_grid)
    coord_embedding = self.region_coord_proj(region_coords.to(region_feats.dtype))
    router_input = region_feats + coord_embedding
    if query_context is not None:
        projected_query = self.query_context_proj(query_context.to(region_feats.dtype))
        # Same bias for every region of this image.
        router_input = router_input + projected_query.unsqueeze(0)
    routing_logits = self.region_router(router_input)
    routing_probs = self._topk_sparse_probs(routing_logits)

    # --- Experts: every expert runs on the full grid; specialists are then
    # mixed per patch with the probabilities of the region the patch is in.
    shared_out = self.shared_expert(final_grid)
    per_expert = [expert(final_grid) for expert in self.latent_experts]
    specialist_outputs = torch.stack(per_expert, dim=-2)
    patch_probs = self._broadcast_region_probs(routing_probs, region_shape, final_grid.shape[:2])
    specialist_out = (specialist_outputs * patch_probs.unsqueeze(-1)).sum(dim=-2)

    # --- Fusion: residual plus sigmoid-gated expert contributions.
    shared_gate, specialist_gate = self.gate_scalars.sigmoid()
    fused_grid = (
        final_grid
        + shared_gate.to(final_grid.dtype) * shared_out
        + specialist_gate.to(final_grid.dtype) * specialist_out
    )

    pooled_regions, pooled_region_mask, _, _ = self._pool_regions(fused_grid)
    return fused_grid, pooled_regions, pooled_region_mask, routing_logits
def _pool_regions(
    self,
    grid: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[int, int]]:
    """Average-pool an [H, W, C] grid over ``region_size``-square blocks.

    The grid is zero-padded up to a multiple of ``region_size`` on the
    bottom/right; padded cells are excluded from the averages.

    Returns:
        ``(pooled, mask, coords, (num_h, num_w))`` where ``pooled`` is
        [num_h * num_w, C] in row-major region order, ``mask`` is a boolean
        [num_h * num_w] vector, and ``coords`` holds one ``[x0, y0, x1, y1]``
        box per region, normalized by the unpadded width / height.
    """
    height, width, channels = grid.shape
    block = self.region_size
    padded_h = _ceil_to_multiple(height, block)
    padded_w = _ceil_to_multiple(width, block)
    rows = padded_h // block
    cols = padded_w // block

    canvas = grid.new_zeros(padded_h, padded_w, channels)
    canvas[:height, :width] = grid
    # 1 where a real grid cell lives, 0 in the padding.
    occupancy = grid.new_zeros(padded_h, padded_w, 1)
    occupancy[:height, :width] = 1

    def _to_blocks(tensor: torch.Tensor) -> torch.Tensor:
        # [Hp, Wp, D] -> [rows * cols, block * block, D]
        depth = tensor.shape[-1]
        tiled = tensor.view(rows, block, cols, block, depth).permute(0, 2, 1, 3, 4)
        return tiled.reshape(rows * cols, block * block, depth)

    blocks = _to_blocks(canvas)
    valid_blocks = _to_blocks(occupancy)

    counts = valid_blocks.sum(dim=1).clamp_min(1.0)
    pooled = (blocks * valid_blocks).sum(dim=1) / counts
    # NOTE: ``counts`` is floored at 1.0 above, so this comparison is True
    # for every region.
    mask = counts.squeeze(-1) > 0.5

    safe_h = max(height, 1)
    safe_w = max(width, 1)
    coords = [
        [
            (rx * block) / safe_w,
            (ry * block) / safe_h,
            min((rx + 1) * block, width) / safe_w,
            min((ry + 1) * block, height) / safe_h,
        ]
        for ry in range(rows)
        for rx in range(cols)
    ]
    coord_tensor = torch.tensor(coords, device=grid.device, dtype=grid.dtype)
    return pooled, mask, coord_tensor, (rows, cols)
| def _broadcast_region_probs( | |
| self, | |
| region_probs: torch.Tensor, | |
| region_shape: Tuple[int, int], | |
| grid_shape: Tuple[int, int], | |
| ) -> torch.Tensor: | |
| num_h, num_w = region_shape | |
| h, w = grid_shape | |
| rs = self.region_size | |
| hp = num_h * rs | |
| wp = num_w * rs | |
| probs = region_probs.view(num_h, num_w, self.num_specialists) | |
| probs = probs[:, :, None, None, :].expand(num_h, num_w, rs, rs, self.num_specialists) | |
| probs = probs.permute(0, 2, 1, 3, 4).reshape(hp, wp, self.num_specialists) | |
| return probs[:h, :w] | |
| def _pad_regions( | |
| self, | |
| region_embeddings_list: List[torch.Tensor], | |
| routing_logits_list: List[torch.Tensor], | |
| routing_mask_list: List[torch.Tensor], | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: | |
| if not region_embeddings_list: | |
| return None, None, None | |
| max_regions = max((regions.size(0) for regions in region_embeddings_list), default=0) | |
| if max_regions == 0: | |
| batch_size = len(region_embeddings_list) | |
| return ( | |
| torch.zeros(batch_size, 0, self.retrieval_dim, device=device, dtype=dtype), | |
| torch.zeros(batch_size, 0, self.num_specialists, device=device, dtype=dtype), | |
| torch.zeros(batch_size, 0, device=device, dtype=torch.bool), | |
| ) | |
| batch_size = len(region_embeddings_list) | |
| padded_regions = torch.zeros(batch_size, max_regions, self.retrieval_dim, device=device, dtype=dtype) | |
| padded_logits = torch.zeros(batch_size, max_regions, self.num_specialists, device=device, dtype=dtype) | |
| padded_mask = torch.zeros(batch_size, max_regions, device=device, dtype=torch.bool) | |
| for idx, (regions, logits, mask) in enumerate(zip(region_embeddings_list, routing_logits_list, routing_mask_list)): | |
| if regions.numel() == 0: | |
| continue | |
| count = regions.size(0) | |
| padded_regions[idx, :count] = regions.to(dtype) | |
| padded_logits[idx, : logits.size(0)] = logits.to(dtype) | |
| padded_mask[idx, : mask.numel()] = mask.to(torch.bool) | |
| return padded_regions, padded_logits, padded_mask | |
| # ----------------------------------------------------------------- # | |
| # User-facing helpers | |
| # ----------------------------------------------------------------- # | |
def encode_queries(
    self,
    processor,
    queries: List[str],
    batch_size: int = 8,
    max_length: Optional[int] = None,
) -> List[torch.Tensor]:
    """Encode a list of query strings into multi-vector embeddings.

    Runs under ``torch.no_grad()`` so no autograd graph is built or kept
    alive by the returned tensors.

    Args:
        processor: object exposing ``process_texts(texts, max_length=...)``
            whose result supports ``.to(device)`` and ``**`` unpacking.
        queries: the query strings.
        batch_size: number of queries per forward pass; must be >= 1.
        max_length: forwarded to ``processor.process_texts``.

    Returns:
        One CPU tensor per query, in input order. Tensors produced by the
        same batch share that batch's token length.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    device = next(self.parameters()).device
    out: List[torch.Tensor] = []
    with torch.no_grad():
        for start in range(0, len(queries), batch_size):
            chunk = queries[start : start + batch_size]
            batch = processor.process_texts(chunk, max_length=max_length).to(device)
            # Move results to CPU right away so accelerator memory is freed
            # between batches.
            emb = self(**batch).embeddings.cpu()
            out.extend(torch.unbind(emb))
    return out
def encode_images(self, processor, images, batch_size: int = 2) -> List[torch.Tensor]:
    """Encode a list of PIL images into multi-vector embeddings.

    Runs under ``torch.no_grad()`` so no autograd graph is built or kept
    alive by the returned tensors.

    Args:
        processor: object exposing ``process_images(images)`` whose result
            supports ``.to(device)`` and ``**`` unpacking.
        images: the images to encode.
        batch_size: number of images per forward pass; must be >= 1.

    Returns:
        One CPU tensor per image, in input order.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    device = next(self.parameters()).device
    out: List[torch.Tensor] = []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = processor.process_images(images[start : start + batch_size]).to(device)
            # Move results to CPU right away so accelerator memory is freed
            # between batches.
            emb = self(**batch).embeddings.cpu()
            out.extend(torch.unbind(emb))
    return out
@staticmethod
def score(
    qs: List[torch.Tensor],
    ps: List[torch.Tensor],
    batch_size: int = 32,
    device: Optional[Union[str, torch.device]] = None,
) -> torch.Tensor:
    """MaxSim scoring: for each (q_i, p_j) pair, compute
    ``sum_t max_p <q_i_t, p_j_p>``. Returns a [N_q, N_p] matrix.

    This reproduces ``processor.score_multi_vector`` but lives on the
    model so users can compute relevance without touching the processor.
    It takes no ``self`` and is therefore declared a static method, so both
    ``model.score(qs, ps)`` and ``ArgusForRetrieval.score(qs, ps)`` work.

    Args:
        qs: one [T_q, D] tensor per query.
        ps: one [T_p, D] tensor per document.
        batch_size: number of queries / documents scored per chunk.
        device: where to compute and return the scores (default: CPU).

    Returns:
        A float tensor of shape [len(qs), len(ps)] on ``device``.
    """
    dev = torch.device(device) if device is not None else torch.device("cpu")
    n_q, n_p = len(qs), len(ps)
    scores = torch.zeros(n_q, n_p, device=dev)

    for q_start in range(0, n_q, batch_size):
        q_slice = qs[q_start : q_start + batch_size]
        q_len = max(x.size(0) for x in q_slice)
        q_pad = torch.zeros(len(q_slice), q_len, q_slice[0].size(-1), device=dev)
        q_mask = torch.zeros(len(q_slice), q_len, device=dev, dtype=torch.bool)
        for i, t in enumerate(q_slice):
            q_pad[i, : t.size(0)] = t.to(dev)
            # All-zero query rows (padding / masked tokens) must not
            # contribute to the sum.
            q_mask[i, : t.size(0)] = t.abs().sum(dim=-1) > 0

        for p_start in range(0, n_p, batch_size):
            p_slice = ps[p_start : p_start + batch_size]
            p_len = max(x.size(0) for x in p_slice)
            if p_len == 0:
                # No document tokens in this chunk: the max over an empty
                # axis is undefined, and the scores are already zero.
                continue
            p_pad = torch.zeros(len(p_slice), p_len, p_slice[0].size(-1), device=dev)
            for j, t in enumerate(p_slice):
                p_pad[j, : t.size(0)] = t.to(dev)

            # [q, p, query_token, doc_token] similarities.
            sim = torch.einsum("qld,pkd->qplk", q_pad, p_pad)
            maxsim = sim.max(dim=-1).values
            maxsim = (maxsim * q_mask.unsqueeze(1).to(maxsim.dtype)).sum(dim=-1)
            scores[q_start : q_start + len(q_slice), p_start : p_start + len(p_slice)] = maxsim
    return scores
# Names exported by ``from <this module> import *``.
__all__ = ["ArgusForRetrieval", "ArgusOutput"]