"""Semantic search API over Hugging Face model and dataset card summaries.

Pre-computed embeddings are loaded into two ChromaDB collections
(``dataset_cards`` and ``model_cards``) at startup.  The FastAPI endpoints
then offer text search, "similar to this item" search and trending listings,
optionally filtered by likes, downloads and (for models) parameter count.
"""

import asyncio
import logging
import os
import sys
from contextlib import asynccontextmanager
from typing import List, Literal, Optional

import chromadb
import httpx
import polars as pl
import torch
from cashews import cache
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import login
from pydantic import BaseModel

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

load_dotenv(override=True)
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    # Calling login() without a token is skipped so an unconfigured
    # deployment starts anonymously instead of failing at import time.
    logger.warning("HF_TOKEN is not set; skipping Hugging Face Hub login")

# Configuration constants
EMBEDDING_MODEL = "Qwen/Qwen3-Embedding-0.6B"
EMBEDDINGS_REPO = "davanstrien/search-v2-embeddings"
CACHE_TTL = "24h"
TRENDING_CACHE_TTL = "1h"
BATCH_SIZE = 2000

if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

# NOTE(review): huggingface_hub is already imported at this point; confirm
# that setting this variable here still has the intended effect.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

LOCAL = sys.platform == "darwin"
CHROMA_PATH = "/tmp/chroma" if not LOCAL else "data/chroma"
EMBEDDINGS_MOUNT = "/embeddings" if not LOCAL else None

# Origins allowed by CORS: any sub-domain of hf.space or huggingface.co.
ALLOWED_ORIGIN_REGEX = r"https://([A-Za-z0-9-]+\.)+(hf\.space|huggingface\.co)"

# Allowed values of the ``sort_by`` query parameter.
SortBy = Literal["similarity", "likes", "downloads", "trending"]

# Configure cache
cache.setup("mem://", size_limit="8gb")

# Initialize ChromaDB in /tmp (writable)
client = chromadb.PersistentClient(path=CHROMA_PATH)


def _to_int(value, default=0):
    """Return ``int(value)``, or ``default`` when ``value`` is ``None``."""
    return int(value) if value is not None else default


def load_embeddings_into_collection(
    config_name, collection_name, id_col, extra_meta_cols=None
):
    """Load pre-computed embeddings from a Hub dataset config into a ChromaDB collection.

    Args:
        config_name: Dataset config (and mount sub-directory) to read.
        collection_name: Name of the ChromaDB collection to fill.
        id_col: Column of the source data holding the record ids.
        extra_meta_cols: Optional extra integer columns to copy into the
            record metadata; missing values are stored as ``0``.

    Returns:
        The ChromaDB collection.  If it already contains records it is
        returned untouched.
    """
    collection = client.get_or_create_collection(
        name=collection_name,
        metadata={"hnsw:space": "cosine"},
    )

    if collection.count() > 0:
        logger.info(
            f"{collection_name}: already has {collection.count():,} records, skipping load"
        )
        return collection

    # Read from mounted dataset or from Hub
    if EMBEDDINGS_MOUNT and os.path.exists(EMBEDDINGS_MOUNT):
        # Debug: log mount contents to find correct path
        logger.info(
            f"Mount contents at {EMBEDDINGS_MOUNT}: {os.listdir(EMBEDDINGS_MOUNT)}"
        )
        mount_config_path = f"{EMBEDDINGS_MOUNT}/{config_name}"
        if os.path.exists(mount_config_path):
            logger.info(f"Config dir contents: {os.listdir(mount_config_path)}")
            parquet_path = f"{mount_config_path}/train-*.parquet"
        else:
            # Maybe files are at root with config prefix
            parquet_path = f"{EMBEDDINGS_MOUNT}/train-*.parquet"
        logger.info(f"Loading {config_name} from mount: {parquet_path}")
        df = pl.read_parquet(parquet_path)
    else:
        logger.info(f"Loading {config_name} from Hub: {EMBEDDINGS_REPO}")
        from datasets import load_dataset

        ds = load_dataset(EMBEDDINGS_REPO, config_name, split="train")
        df = ds.to_polars()

    total = len(df)
    logger.info(f"Loaded {total:,} records for {collection_name}")

    meta_cols = extra_meta_cols or []
    for offset in range(0, total, BATCH_SIZE):
        batch = df.slice(offset, min(BATCH_SIZE, total - offset))
        metadatas = []
        for row in batch.iter_rows(named=True):
            meta = {
                "likes": _to_int(row["likes"]),
                "downloads": _to_int(row["downloads"]),
                "last_modified": str(row["last_modified"]),
            }
            for col in meta_cols:
                meta[col] = _to_int(row.get(col))
            metadatas.append(meta)

        collection.upsert(
            # Honour the caller-supplied id column instead of assuming "id".
            ids=batch[id_col].to_list(),
            documents=batch["summary"].to_list(),
            embeddings=batch["embedding"].to_list(),
            metadatas=metadatas,
        )
        logger.info(
            f"{collection_name}: loaded {min(offset + BATCH_SIZE, total):,} / {total:,}"
        )

    logger.info(f"{collection_name}: {collection.count():,} records loaded")
    return collection


# Initialize FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Fill both ChromaDB collections on startup and close the cache on shutdown."""
    logger.info(f"ChromaDB path: {CHROMA_PATH}")
    logger.info("Loading pre-computed embeddings into ChromaDB...")
    load_embeddings_into_collection("dataset_cards", "dataset_cards", "id")
    load_embeddings_into_collection(
        "model_cards", "model_cards", "id", extra_meta_cols=["param_count"]
    )
    logger.info("Index ready")
    yield
    await cache.close()


app = FastAPI(lifespan=lifespan)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    # The sub-domain wildcards ("https://*.hf.space", "https://*.huggingface.co")
    # are expressed as a regex: entries of ``allow_origins`` are compared
    # literally, so a "*" inside a host name would never match a real origin.
    # Local development origin (http://localhost:5500) is intentionally not
    # allowed.  TODO remove before prod if it is ever re-enabled.
    allow_origin_regex=ALLOWED_ORIGIN_REGEX,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Lazy-loaded SentenceTransformer for embedding queries at search time
_query_model = None


def get_query_model():
    """Return the shared query embedding model, loading it on first use."""
    global _query_model
    if _query_model is None:
        from sentence_transformers import SentenceTransformer

        logger.info(f"Loading query embedding model {EMBEDDING_MODEL} on {DEVICE}")
        _query_model = SentenceTransformer(EMBEDDING_MODEL, device=DEVICE)
    return _query_model


def embed_query(text: str) -> list[float]:
    """Embed a search query using the same model used to build the index."""
    model = get_query_model()
    return model.encode(text, prompt_name="query").tolist()


class QueryResult(BaseModel):
    """One dataset hit returned by the dataset endpoints."""

    dataset_id: str
    # NOTE(review): filled from the raw Chroma "distances" value (collections
    # use the cosine space) — confirm clients expect a distance here.
    similarity: float
    summary: str
    likes: int
    downloads: int


class QueryResponse(BaseModel):
    """Response body of the dataset endpoints."""

    results: List[QueryResult]


class ModelQueryResult(BaseModel):
    """One model hit returned by the model endpoints."""

    model_id: str
    # NOTE(review): same caveat as QueryResult.similarity.
    similarity: float
    summary: str
    likes: int
    downloads: int
    param_count: Optional[int] = None


class ModelQueryResponse(BaseModel):
    """Response body of the model endpoints."""

    results: List[ModelQueryResult]


def _build_where_clause(
    min_likes: int = 0,
    min_downloads: int = 0,
    min_param_count: int = 0,
    max_param_count: Optional[int] = None,
) -> Optional[dict]:
    """Build the Chroma ``where`` filter for the given thresholds.

    Thresholds left at their defaults add no condition.  As soon as any
    parameter-count bound is given, records with ``param_count == 0``
    (unknown size) are excluded as well.

    Returns:
        ``None`` when nothing has to be filtered, the bare condition when
        there is exactly one, otherwise an ``{"$and": [...]}`` clause.
    """
    conditions = []
    if min_likes > 0:
        conditions.append({"likes": {"$gte": min_likes}})
    if min_downloads > 0:
        conditions.append({"downloads": {"$gte": min_downloads}})

    if min_param_count > 0 or max_param_count is not None:
        # Always exclude zero param count when using any parameter filters
        conditions.append({"param_count": {"$gt": 0}})
        if min_param_count > 0:
            conditions.append({"param_count": {"$gte": min_param_count}})
        if max_param_count is not None:
            conditions.append({"param_count": {"$lte": max_param_count}})

    if not conditions:
        return None
    if len(conditions) == 1:
        return conditions[0]  # Single condition without $and
    return {"$and": conditions}


def _candidate_count(k: int, sort_by: str, *, excluding_self: bool = False) -> int:
    """Return how many neighbours to request from Chroma.

    Non-similarity sorts over-fetch (4x) so re-ranking has material to work
    with; a similarity search that drops the query item itself asks for one
    extra neighbour.
    """
    if sort_by != "similarity":
        return k * 4
    return k + 1 if excluding_self else k


@app.get("/")
async def redirect_to_docs():
    """Redirect the root URL to the interactive API docs."""
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url="/docs")


@app.get("/search/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def search_datasets(
    query: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: SortBy = Query(default="similarity"),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Search for datasets based on a text query with optional filtering."""
    try:
        collection = client.get_collection(name="dataset_cards")
        # NOTE(review): embed_query() already encodes with prompt_name="query";
        # confirm this additional manual instruction prefix is still intended.
        task_description = "Given a search query, retrieve relevant model and dataset summaries that match the query. "
        query_text = f"Instruct: {task_description}\nQuery:{query}"
        query_embedding = embed_query(query_text)

        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=_candidate_count(k, sort_by),
            where=_build_where_clause(min_likes, min_downloads),
        )

        query_results = await process_search_results(results, "dataset", k, sort_by)
        return QueryResponse(results=query_results)

    except Exception as e:
        logger.exception(f"Search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Search failed")


@app.get("/similarity/datasets", response_model=QueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_datasets(
    dataset_id: str,
    k: int = Query(default=5, ge=1, le=100),
    sort_by: SortBy = Query(default="similarity"),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Find datasets similar to a specified dataset with optional filtering."""
    try:
        collection = client.get_collection("dataset_cards")
        stored = collection.get(ids=[dataset_id], include=["embeddings"])

        if not stored["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Dataset ID '{dataset_id}' not found"
            )

        results = collection.query(
            query_embeddings=[stored["embeddings"][0]],
            n_results=_candidate_count(k, sort_by, excluding_self=True),
            where=_build_where_clause(min_likes, min_downloads),
        )

        query_results = await process_search_results(
            results, "dataset", k, sort_by, dataset_id
        )
        return QueryResponse(results=query_results)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Similarity search failed")


@app.get("/search/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def search_models(
    query: str,
    k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
    sort_by: SortBy = Query(
        default="similarity",
        description="Sort method for results",
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Search for models based on a text query with optional filtering.

    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates missing/unknown parameter count in the dataset
    """
    try:
        collection = client.get_collection(name="model_cards")
        where_clause = _build_where_clause(
            min_likes, min_downloads, min_param_count, max_param_count
        )

        # NOTE(review): this prefix differs from the one used for dataset
        # search and is combined with prompt_name="query" — confirm intended.
        query_embedding = embed_query(f"search_query: {query}")

        results = collection.query(
            query_embeddings=[query_embedding],
            n_results=_candidate_count(k, sort_by),
            where=where_clause,
        )

        query_results = await process_search_results(results, "model", k, sort_by)
        return ModelQueryResponse(results=query_results)

    except Exception as e:
        logger.exception(f"Model search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model search failed")


@app.get("/similarity/models", response_model=ModelQueryResponse)
@cache(ttl=CACHE_TTL)
async def find_similar_models(
    model_id: str,
    k: int = Query(default=5, ge=1, le=100, description="Number of results to return"),
    sort_by: SortBy = Query(
        default="similarity",
        description="Sort method for results",
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 will be excluded if any param filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Find similar models to a specified model with optional filtering.

    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates missing/unknown parameter count in the dataset
    """
    try:
        collection = client.get_collection("model_cards")
        stored = collection.get(ids=[model_id], include=["embeddings"])

        if not stored["ids"]:
            raise HTTPException(
                status_code=404, detail=f"Model ID '{model_id}' not found"
            )

        where_clause = _build_where_clause(
            min_likes, min_downloads, min_param_count, max_param_count
        )

        results = collection.query(
            query_embeddings=[stored["embeddings"][0]],
            n_results=_candidate_count(k, sort_by, excluding_self=True),
            where=where_clause,
        )

        query_results = await process_search_results(
            results, "model", k, sort_by, model_id
        )
        return ModelQueryResponse(results=query_results)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"Model similarity search error: {str(e)}")
        raise HTTPException(status_code=500, detail="Model similarity search failed")


@cache(ttl=TRENDING_CACHE_TTL)
async def get_trending_score(item_id: str, item_type: str) -> float:
    """Fetch trending score for a model or dataset from HuggingFace API.

    Any failure is logged and reported as a score of ``0`` so that a single
    bad lookup cannot break a whole trending sort.
    """
    endpoint = "models" if item_type == "model" else "datasets"
    try:
        async with httpx.AsyncClient() as http_client:
            response = await http_client.get(
                f"https://huggingface.co/api/{endpoint}/{item_id}?expand=trendingScore"
            )
            response.raise_for_status()
            # A missing or null score counts as 0 so results stay sortable.
            return response.json().get("trendingScore") or 0
    except Exception as e:
        logger.error(
            f"Error fetching trending score for {item_type} {item_id}: {str(e)}"
        )
        return 0


async def process_search_results(results, id_field, k, sort_by, exclude_id=None):
    """Process search results into a standardized format.

    Args:
        results: Result of ``collection.query`` for a single query embedding.
        id_field: ``"dataset"`` to build ``QueryResult`` items; any other
            value builds ``ModelQueryResult`` items.
        k: Maximum number of items to return.
        sort_by: ``"similarity"`` keeps Chroma's order, ``"trending"`` ranks
            by the Hub trending score, anything else names the result
            attribute to sort by in descending order.
        exclude_id: Optional id to drop (the query item itself).

    Returns:
        At most ``k`` result models.
    """
    id_key = f"{id_field}_id"
    result_cls = QueryResult if id_field == "dataset" else ModelQueryResult

    # Create base results
    query_results = []
    for current_id, distance, document, metadata in zip(
        results["ids"][0],
        results["distances"][0],
        results["documents"][0],
        results["metadatas"][0],
    ):
        if exclude_id and current_id == exclude_id:
            continue

        fields = {
            id_key: current_id,
            "similarity": float(distance),
            "summary": document,
            "likes": metadata["likes"],
            "downloads": metadata["downloads"],
        }
        # Add param_count for models if it exists in metadata
        if id_field == "model" and "param_count" in metadata:
            fields["param_count"] = metadata["param_count"]

        query_results.append(result_cls(**fields))

    # Handle sorting
    if sort_by == "trending":
        item_type = "model" if id_field == "model" else "dataset"
        scores = await asyncio.gather(
            *(
                get_trending_score(getattr(item, id_key), item_type)
                for item in query_results
            )
        )
        trending_scores = {
            getattr(item, id_key): score for item, score in zip(query_results, scores)
        }
        query_results.sort(
            key=lambda item: trending_scores.get(getattr(item, id_key), 0),
            reverse=True,
        )
    elif sort_by != "similarity":
        query_results.sort(key=lambda item: getattr(item, sort_by), reverse=True)

    # Over-fetched candidates (re-ranking or the excluded item) are trimmed
    # here; a plain similarity search never holds more than k items.
    return query_results[:k]


def _param_count_in_range(
    param_count, min_param_count: int, max_param_count: Optional[int]
) -> bool:
    """Return whether a known parameter count satisfies the given bounds.

    A missing or zero ``param_count`` (unknown size) never passes.
    """
    if not param_count:
        return False
    if min_param_count > 0 and param_count < min_param_count:
        return False
    if max_param_count is not None and param_count > max_param_count:
        return False
    return True


async def fetch_trending_models():
    """Fetch trending models from HuggingFace API"""
    async with httpx.AsyncClient() as http_client:
        response = await http_client.get("https://huggingface.co/api/models")
        response.raise_for_status()
        return response.json()


@cache(ttl=TRENDING_CACHE_TTL)
async def get_trending_models_with_summaries(
    limit: int = 10,
    min_likes: int = 0,
    min_downloads: int = 0,
    min_param_count: int = 0,
    max_param_count: Optional[int] = None,
) -> List[ModelQueryResult]:
    """Fetch trending models and combine with summaries from database.

    Models without a summary in the ``model_cards`` collection are skipped.
    When any parameter-count bound is given, models with an unknown
    (zero) parameter count are skipped too.

    Raises:
        HTTPException: 500 when the Hub request or the lookup fails.
    """
    try:
        trending_models = await fetch_trending_models()

        # Filter by minimum likes/downloads
        candidates = [
            model
            for model in trending_models
            if model.get("likes", 0) >= min_likes
            and model.get("downloads", 0) >= min_downloads
        ]

        # Sort by trending score, then keep up to 3x the limit as a buffer
        # for the summary / parameter filtering below.
        candidates.sort(key=lambda m: m.get("trendingScore") or 0, reverse=True)
        candidates = candidates[: limit * 3]
        if not candidates:
            # Nothing to look up; avoid querying Chroma with an empty id list.
            return []

        model_ids = [model["modelId"] for model in candidates]

        # Fetch summaries from ChromaDB
        collection = client.get_collection("model_cards")
        summaries = collection.get(ids=model_ids, include=["documents", "metadatas"])
        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))
        id_to_metadata = dict(zip(summaries["ids"], summaries["metadatas"]))

        logger.debug(
            f"Filter params - min_param_count: {min_param_count}, max_param_count: {max_param_count}"
        )
        using_param_filters = min_param_count > 0 or max_param_count is not None

        results = []
        for model in candidates:
            model_id = model["modelId"]
            if model_id not in id_to_summary:
                continue

            metadata = id_to_metadata.get(model_id) or {}
            param_count = metadata.get("param_count", 0)
            logger.debug(f"Model: {model_id}, param_count: {param_count}")

            if using_param_filters and not _param_count_in_range(
                param_count, min_param_count, max_param_count
            ):
                logger.debug(
                    f"Filtering out {model_id} - param_count {param_count} outside requested range"
                )
                continue

            results.append(
                ModelQueryResult(
                    model_id=model_id,
                    similarity=1.0,  # Not applicable for trending
                    summary=id_to_summary[model_id],
                    likes=model.get("likes", 0),
                    downloads=model.get("downloads", 0),
                    param_count=param_count,
                )
            )

        logger.debug(f"After filtering: {len(results)} models remain")

        # Finally limit to the requested number
        return results[:limit]

    except Exception as e:
        logger.exception(f"Error fetching trending models: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch trending models")


@app.get("/trending/models", response_model=ModelQueryResponse)
async def get_trending_models(
    limit: int = Query(
        default=10, ge=1, le=100, description="Number of results to return"
    ),
    min_likes: int = Query(default=0, ge=0, description="Minimum likes filter"),
    min_downloads: int = Query(default=0, ge=0, description="Minimum downloads filter"),
    min_param_count: int = Query(
        default=0,
        ge=0,
        description="Minimum parameter count (models with param_count=0 will be excluded if any parameter filter is used)",
    ),
    max_param_count: Optional[int] = Query(
        default=None,
        ge=0,
        description="Maximum parameter count (None means no upper limit)",
    ),
):
    """
    Get trending models with their summaries and optional filtering.

    - When min_param_count > 0 or max_param_count is specified, models with param_count=0 are excluded
    - param_count=0 indicates missing/unknown parameter count in the dataset
    """
    logger.debug(
        f"Request for trending models with params: limit={limit}, min_likes={min_likes}, min_downloads={min_downloads}, min_param_count={min_param_count}, max_param_count={max_param_count}"
    )

    results = await get_trending_models_with_summaries(
        limit=limit,
        min_likes=min_likes,
        min_downloads=min_downloads,
        min_param_count=min_param_count,
        max_param_count=max_param_count,
    )

    logger.debug(f"Returning {len(results)} trending model results")
    return ModelQueryResponse(results=results)


async def fetch_trending_datasets():
    """Fetch trending datasets from HuggingFace API"""
    async with httpx.AsyncClient() as http_client:
        response = await http_client.get("https://huggingface.co/api/datasets")
        response.raise_for_status()
        return response.json()


@cache(ttl=TRENDING_CACHE_TTL)
async def get_trending_datasets_with_summaries(
    limit: int = 10,
    min_likes: int = 0,
    min_downloads: int = 0,
) -> List[QueryResult]:
    """Fetch trending datasets and combine with summaries from database.

    Datasets without a summary in the ``dataset_cards`` collection are
    skipped, so fewer than ``limit`` items may be returned.

    Raises:
        HTTPException: 500 when the Hub request or the lookup fails.
    """
    try:
        trending_datasets = await fetch_trending_datasets()

        # Filter by minimum likes/downloads
        candidates = [
            dataset
            for dataset in trending_datasets
            if dataset.get("likes", 0) >= min_likes
            and dataset.get("downloads", 0) >= min_downloads
        ]

        # Sort by trending score and limit
        candidates.sort(key=lambda d: d.get("trendingScore") or 0, reverse=True)
        candidates = candidates[:limit]
        if not candidates:
            # Nothing to look up; avoid querying Chroma with an empty id list.
            return []

        dataset_ids = [dataset["id"] for dataset in candidates]

        # Fetch summaries from ChromaDB
        collection = client.get_collection("dataset_cards")
        summaries = collection.get(ids=dataset_ids, include=["documents"])
        id_to_summary = dict(zip(summaries["ids"], summaries["documents"]))

        # Combine data
        return [
            QueryResult(
                dataset_id=dataset["id"],
                similarity=1.0,  # Not applicable for trending
                summary=id_to_summary[dataset["id"]],
                likes=dataset.get("likes", 0),
                downloads=dataset.get("downloads", 0),
            )
            for dataset in candidates
            if dataset["id"] in id_to_summary
        ]

    except Exception as e:
        logger.exception(f"Error fetching trending datasets: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to fetch trending datasets")


@app.get("/trending/datasets", response_model=QueryResponse)
async def get_trending_datasets(
    limit: int = Query(default=10, ge=1, le=100),
    min_likes: int = Query(default=0, ge=0),
    min_downloads: int = Query(default=0, ge=0),
):
    """Get trending datasets with their summaries"""
    results = await get_trending_datasets_with_summaries(
        limit=limit, min_likes=min_likes, min_downloads=min_downloads
    )
    return QueryResponse(results=results)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)