import base64
import io
import logging
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
import transformers

print("transformers version ", transformers.__version__)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EndpointHandler:
    """
    Custom HuggingFace Inference Endpoint Handler for V-JEPA2 Video Embeddings.

    This handler processes videos and returns pooled embeddings suitable for
    similarity search and vector databases like LanceDB.

    Features:
    - Batch processing support (optionally chunked via ``batch_size``)
    - Handles variable-length videos via uniform frame sampling
    - Supports video URLs and base64-encoded videos
    - Returns mean-pooled embeddings whose width is the model's hidden size
      (read from the model config; 1408 is used if the config lacks it)
    """

    def __init__(self, path: str = ""):
        """
        Initialize the V-JEPA2 model and processor.

        Args:
            path: Path to the model weights (provided by HF Inference Endpoints)

        Raises:
            Exception: Any error raised while importing dependencies or loading
                the model/processor is logged and re-raised unchanged.
        """
        try:
            from transformers import AutoVideoProcessor, AutoModel
            # Not used here: imported so a missing torchcodec install fails at
            # startup instead of on the first request.
            from torchcodec.decoders import VideoDecoder  # noqa: F401

            logger.info("Loading V-JEPA2 model from %s", path)

            # Determine device
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info("Using device: %s", self.device)

            # Load model without the classification head to get embeddings
            # We use AutoModel instead of AutoModelForVideoClassification
            self.model = AutoModel.from_pretrained(path).to(self.device)
            self.processor = AutoVideoProcessor.from_pretrained(path)

            # Set model to evaluation mode
            self.model.eval()

            # Store model config (with fallbacks when the attribute is absent)
            self.frames_per_clip = getattr(self.model.config, 'frames_per_clip', 64)
            self.hidden_size = getattr(self.model.config, 'hidden_size', 1408)

            logger.info(
                "Model loaded successfully. Frames per clip: %s, Hidden size: %s",
                self.frames_per_clip,
                self.hidden_size,
            )
        except Exception as e:
            logger.error("Error initializing model: %s", e)
            raise

    def _sample_frame_indices(self, total_frames: int) -> List[int]:
        """
        Choose exactly ``self.frames_per_clip`` frame indices for a video.

        Args:
            total_frames: Number of frames the decoder reports for the video.

        Returns:
            A list of ``frames_per_clip`` integer indices in ``[0, total_frames)``.
            Long videos are sampled uniformly; short videos are cycled from the
            start until enough indices have been produced.

        Raises:
            ValueError: If the video has no frames.
        """
        if total_frames <= 0:
            # Without this guard the tiling below divides by zero.
            raise ValueError("Video contains no decodable frames")

        if total_frames < self.frames_per_clip:
            logger.warning(
                "Video has only %s frames, less than required %s. Repeating frames.",
                total_frames,
                self.frames_per_clip,
            )
            # Repeat the whole index range enough times, then trim to length
            repeats = (self.frames_per_clip // total_frames) + 1
            indices = np.tile(np.arange(total_frames), repeats)[:self.frames_per_clip]
        else:
            # Uniform sampling across the video
            indices = np.linspace(0, total_frames - 1, self.frames_per_clip, dtype=int)

        # Hand the decoder plain Python ints rather than a NumPy array
        return indices.tolist()

    def _decode_clip(self, source: str) -> torch.Tensor:
        """
        Decode ``frames_per_clip`` sampled frames from a video source.

        Args:
            source: URL or file path passed straight to ``VideoDecoder``.

        Returns:
            The ``.data`` tensor of the decoded frame batch.
        """
        from torchcodec.decoders import VideoDecoder

        decoder = VideoDecoder(source)
        frame_indices = self._sample_frame_indices(len(decoder))
        return decoder.get_frames_at(indices=frame_indices).data

    def _load_video_from_url(self, video_url: str) -> torch.Tensor:
        """
        Load video from URL and sample frames.

        Args:
            video_url: URL to the video file

        Returns:
            Video tensor of sampled frames; presumably shaped
            (frames, channels, height, width) with the decoder's default
            dimension order.

        Raises:
            Exception: Any decoding error is logged and re-raised.
        """
        try:
            return self._decode_clip(video_url)
        except Exception as e:
            logger.error("Error loading video from URL %s: %s", video_url, e)
            raise

    def _load_video_from_base64(self, video_b64: str) -> torch.Tensor:
        """
        Load video from base64-encoded data.

        Args:
            video_b64: Base64-encoded video data

        Returns:
            Video tensor of sampled frames; presumably shaped
            (frames, channels, height, width) with the decoder's default
            dimension order.

        Raises:
            Exception: Any base64 or decoding error is logged and re-raised.
        """
        try:
            # Decode base64
            video_bytes = base64.b64decode(video_b64)

            # The decoder is given a file path, so spool the payload to disk.
            # Creation and write sit inside the try so the file is removed even
            # when the write itself fails.
            tmp_path: Optional[str] = None
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_file:
                    tmp_path = tmp_file.name
                    tmp_file.write(video_bytes)
                return self._decode_clip(tmp_path)
            finally:
                # Clean up temporary file; a cleanup failure must not mask the
                # real decoding error (or a successful result).
                if tmp_path is not None:
                    try:
                        os.unlink(tmp_path)
                    except OSError as cleanup_error:
                        logger.warning(
                            "Could not remove temporary file %s: %s",
                            tmp_path,
                            cleanup_error,
                        )
        except Exception as e:
            logger.error("Error loading video from base64: %s", e)
            raise

    def _extract_embeddings(self, videos: List[torch.Tensor]) -> np.ndarray:
        """
        Extract pooled embeddings from a batch of videos.

        Args:
            videos: List of video tensors

        Returns:
            Numpy array with one mean-pooled embedding row per input video

        Raises:
            Exception: Any preprocessing or inference error is logged and
                re-raised.
        """
        try:
            # Process videos through the processor
            inputs = self.processor(videos, return_tensors="pt").to(self.device)

            # Run inference. Only the final hidden state is needed, so the
            # per-layer hidden states are not requested.
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Mean pooling across the sequence dimension (dim=1)
            pooled_embeddings = outputs.last_hidden_state.mean(dim=1)

            # Convert to numpy
            return pooled_embeddings.cpu().numpy()
        except Exception as e:
            logger.error("Error extracting embeddings: %s", e)
            raise

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference request.

        Expected input formats:
        1. Single video URL: {"inputs": "https://example.com/video.mp4"}
        2. Batch of video URLs: {"inputs": ["url1", "url2", "url3"]}
        3. Base64-encoded video: {"inputs": "base64_encoded_string", "encoding": "base64"}
        4. Chunked batch: {"inputs": [...], "batch_size": 4}

        ``encoding`` applies to every input in the request; any value other
        than "base64" is treated as a URL. ``batch_size`` (positive int) caps
        how many videos go through the model per forward pass; when omitted or
        invalid, all loaded videos are processed in a single batch.

        Returns:
            One dictionary per input, in input order. Successful entries are
            {"embedding": [...], "shape": [...], "status": "success"}; entries
            whose video failed to load have "embedding"/"shape" set to None,
            "status": "error" and an "error" message. If the request as a
            whole fails, a single-element list
            [{"error": ..., "status": "error"}] is returned instead.
        """
        try:
            # Extract inputs
            inputs = data.get("inputs")
            encoding = data.get("encoding", "url")

            if inputs is None:
                raise ValueError("No 'inputs' provided in request data")

            # Handle single input vs batch
            if isinstance(inputs, str):
                inputs = [inputs]
            elif not isinstance(inputs, list):
                raise ValueError(f"'inputs' must be a string or list, got {type(inputs)}")

            logger.info("Processing %s video(s)", len(inputs))

            load_video = (
                self._load_video_from_base64
                if encoding == "base64"
                else self._load_video_from_url  # Default to URL
            )

            # Load videos, remembering which positions succeeded and why the
            # others failed so each input gets its own result entry.
            valid_videos: List[torch.Tensor] = []
            valid_indices: List[int] = []
            load_errors: Dict[int, str] = {}
            for idx, inp in enumerate(inputs):
                try:
                    video = load_video(inp)
                except Exception as e:
                    logger.error("Error loading video %s: %s", idx, e)
                    load_errors[idx] = str(e)
                else:
                    valid_videos.append(video)
                    valid_indices.append(idx)

            if not valid_videos:
                raise ValueError("No valid videos could be loaded")

            # Resolve the forward-pass chunk size. An unusable batch_size is
            # ignored (with a warning) rather than rejected, so requests that
            # worked before this option was honoured still work.
            chunk_size = len(valid_videos)
            requested = data.get("batch_size")
            if requested is not None:
                is_valid = (
                    isinstance(requested, int)
                    and not isinstance(requested, bool)
                    and requested > 0
                )
                if is_valid:
                    chunk_size = min(requested, chunk_size)
                else:
                    logger.warning(
                        "Ignoring invalid batch_size %r; processing all videos in one batch",
                        requested,
                    )

            # Extract embeddings for valid videos, chunk by chunk
            embeddings = np.concatenate(
                [
                    self._extract_embeddings(valid_videos[start:start + chunk_size])
                    for start in range(0, len(valid_videos), chunk_size)
                ],
                axis=0,
            )

            # Prepare results
            results: List[Optional[Dict[str, Any]]] = [None] * len(inputs)
            for valid_idx, embedding in zip(valid_indices, embeddings):
                results[valid_idx] = {
                    "embedding": embedding.tolist(),
                    "shape": list(embedding.shape),
                    "status": "success"
                }

            # Fill in errors for failed videos
            for idx, reason in load_errors.items():
                results[idx] = {
                    "embedding": None,
                    "shape": None,
                    "status": "error",
                    "error": f"Failed to load video: {reason}"
                }

            logger.info(
                "Successfully processed %s/%s videos",
                len(valid_videos),
                len(inputs),
            )
            return results
        except Exception as e:
            logger.error("Error in __call__: %s", e)
            return [{"error": str(e), "status": "error"}]