ndc8
Refactor backend service to support Gemma 3n model and update requirements; remove obsolete test script and add new dependency tests
4b4e9ed
"""
FastAPI Backend AI Service using Gemma-3n-E4B-it

Provides OpenAI-compatible chat completion endpoints powered by google/gemma-3n-E4B-it.
Hugging Face Spaces: only the transformers backend is supported (no vLLM, no llama-cpp/GGUF).
"""
import warnings

# Suppress warnings before the heavier imports below
warnings.filterwarnings("ignore", category=FutureWarning, module="transformers")
warnings.filterwarnings("ignore", message=".*slow image processor.*")
warnings.filterwarnings("ignore", message=".*rope_scaling.*")

import os

from dotenv import load_dotenv

load_dotenv()

# Direct Hugging Face caches to a writable folder under /tmp (use only HF_HOME; TRANSFORMERS_CACHE is deprecated)
os.environ.setdefault("HF_HOME", "/tmp/.cache/huggingface")
# Suppress advisory warnings from transformers (including deprecation warnings)
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
hf_token = os.environ.get("HF_TOKEN")

import asyncio
import logging
import time
from contextlib import asynccontextmanager
from io import BytesIO
from typing import List, Dict, Any, Optional, Union

import httpx
import requests
import torch
import uvicorn
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from PIL import Image
from pydantic import BaseModel, Field, field_validator

# Transformers imports (fallback path for non-Gemma-3n models)
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, AutoConfig  # type: ignore
from transformers import BitsAndBytesConfig  # type: ignore
# Gemma 3n specific imports
from transformers import Gemma3nForConditionalGeneration, AutoProcessor  # type: ignore

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
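
# Environment variables read by this service (defaults are the values used below):
#   AI_MODEL      - chat model id, defaults to "google/gemma-3n-E4B-it"
#   VISION_MODEL  - image captioning model id, defaults to "Salesforce/blip-image-captioning-base"
#   HF_TOKEN      - optional Hugging Face access token
#   HF_HOME       - Hugging Face cache directory, defaults to /tmp/.cache/huggingface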
# Pydantic models for multimodal content
class TextContent(BaseModel):
    type: str = Field(default="text", description="Content type")
    text: str = Field(..., description="Text content")

    @field_validator("type")
    @classmethod
    def validate_type(cls, v: str) -> str:
        if v != "text":
            raise ValueError("Type must be 'text'")
        return v


class ImageContent(BaseModel):
    type: str = Field(default="image", description="Content type")
    url: str = Field(..., description="Image URL")

    @field_validator("type")
    @classmethod
    def validate_type(cls, v: str) -> str:
        if v != "image":
            raise ValueError("Type must be 'image'")
        return v


# Pydantic models for OpenAI-compatible API
class ChatMessage(BaseModel):
    role: str = Field(..., description="The role of the message author")
    content: Union[str, List[Union[TextContent, ImageContent]]] = Field(
        ..., description="The content of the message - either a string or a list of content items"
    )

    @field_validator("role")
    @classmethod
    def validate_role(cls, v: str) -> str:
        if v not in ["system", "user", "assistant"]:
            raise ValueError("Role must be one of: system, user, assistant")
        return v


class ChatCompletionRequest(BaseModel):
    model: str = Field(
        default_factory=lambda: os.environ.get("AI_MODEL", "google/gemma-3n-E4B-it"),
        description="The model to use for completion",
    )
    messages: List[ChatMessage] = Field(..., description="List of messages in the conversation")
    max_tokens: Optional[int] = Field(default=512, ge=1, le=2048, description="Maximum tokens to generate")
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0, description="Sampling temperature")
    stream: Optional[bool] = Field(default=False, description="Whether to stream responses")
    top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0, description="Top-p sampling")


class ChatCompletionChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: str


class ChatCompletionResponse(BaseModel):
    id: str
    object: str = "chat.completion"
    created: int
    model: str
    choices: List[ChatCompletionChoice]


class ChatCompletionChunk(BaseModel):
    id: str
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: List[Dict[str, Any]]


class HealthResponse(BaseModel):
    status: str
    model: str
    version: str


class ModelInfo(BaseModel):
    id: str
    object: str = "model"
    created: int
    owned_by: str = "huggingface"


class ModelsResponse(BaseModel):
    object: str = "list"
    data: List[ModelInfo]


class CompletionRequest(BaseModel):
    prompt: str = Field(..., description="The prompt to complete")
    max_tokens: Optional[int] = Field(default=512, ge=1, le=2048)
    temperature: Optional[float] = Field(default=0.7, ge=0.0, le=2.0)
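
# Example request body accepted by ChatCompletionRequest (illustrative values only,
# not taken from the original code):
#
#   {
#     "model": "google/gemma-3n-E4B-it",
#     "messages": [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": [
#         {"type": "text", "text": "What is in this picture?"},
#         {"type": "image", "url": "https://example.com/cat.jpg"}
#       ]}
#     ],
#     "max_tokens": 256,
#     "temperature": 0.7
#   }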
# Model can be configured via environment variable - defaults to Gemma 3n (transformers format)
current_model = os.environ.get("AI_MODEL", "google/gemma-3n-E4B-it")
vision_model = os.environ.get("VISION_MODEL", "Salesforce/blip-image-captioning-base")

# Transformers model state
processor = None  # For Gemma 3n we use AutoProcessor instead of just a tokenizer
model = None
image_text_pipeline = None  # type: ignore


# Image processing utilities
async def download_image(url: str) -> Image.Image:
    """Download and open an image from a URL."""
    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        return image
    except Exception as e:
        logger.error(f"Failed to download image from {url}: {e}")
        raise HTTPException(status_code=400, detail=f"Failed to download image: {str(e)}")


def extract_text_and_images(content: Union[str, List[Any]]) -> tuple[str, List[str]]:
    """Extract text and image URLs from message content."""
    if isinstance(content, str):
        return content, []
    text_parts: List[str] = []
    image_urls: List[str] = []
    for item in content:
        if hasattr(item, "type"):
            if item.type == "text" and hasattr(item, "text"):
                text_parts.append(str(item.text))
            elif item.type == "image" and hasattr(item, "url"):
                image_urls.append(str(item.url))
    return " ".join(text_parts), image_urls
def has_images(messages: List[ChatMessage]) -> bool:
    """Check if any message contains images."""
    for message in messages:
        if isinstance(message.content, list):
            for item in message.content:
                if hasattr(item, "type") and item.type == "image":
                    return True
    return False


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager for startup and shutdown events."""
    global processor, model, image_text_pipeline, current_model
    logger.info("🚀 Starting AI Backend Service (Hugging Face Spaces mode)...")
    try:
        logger.info(f"📥 Loading model with transformers: {current_model}")
        # For Gemma 3n models, use the model-specific classes
        if "gemma-3n" in current_model.lower():
            processor = AutoProcessor.from_pretrained(current_model)
            model = Gemma3nForConditionalGeneration.from_pretrained(
                current_model,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
                torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            ).eval()
        else:
            # Fallback for other models
            processor = AutoTokenizer.from_pretrained(current_model)
            model = AutoModelForCausalLM.from_pretrained(
                current_model,
                low_cpu_mem_usage=True,
                trust_remote_code=True,
            )
        logger.info(f"✅ Successfully loaded model and processor: {current_model}")
        # Gemma 3n is multimodal, so we don't need a separate image pipeline
        if "gemma-3n" not in current_model.lower():
            # Load image pipeline for multimodal support (only for non-Gemma-3n models)
            try:
                logger.info(f"🖼️ Initializing image captioning pipeline with model: {vision_model}")
                image_text_pipeline = pipeline("image-to-text", model=vision_model)
                logger.info("✅ Image captioning pipeline loaded successfully")
            except Exception as e:
                logger.warning(f"⚠️ Could not load image captioning pipeline: {e}")
                image_text_pipeline = None
        else:
            logger.info("✅ Gemma 3n has built-in multimodal support")
            image_text_pipeline = None
    except Exception as e:
        logger.error(f"❌ Failed to initialize model: {e}")
        raise RuntimeError(f"Service initialization failed: {e}")
    yield
    logger.info("🔄 Shutting down AI Backend Service...")
    processor = None
    model = None
    image_text_pipeline = None
# Initialize FastAPI app
app = FastAPI(
    title="AI Backend Service - Gemma 3n",
    description="OpenAI-compatible chat completion API powered by google/gemma-3n-E4B-it",
    version="1.0.0",
    lifespan=lifespan,
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure appropriately for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def ensure_model_ready():
    """Check that the transformers model is loaded and ready."""
    if processor is None or model is None:
        raise HTTPException(status_code=503, detail="Service not ready - no model initialized (transformers)")


def convert_messages_to_prompt(messages: List[ChatMessage]) -> str:
    """Convert the OpenAI messages format to a single prompt string."""
    prompt_parts: List[str] = []
    for message in messages:
        role = message.role
        # Extract text content (handle both string and list formats)
        if isinstance(message.content, str):
            content = message.content
        else:
            content, _ = extract_text_and_images(message.content)
        if role == "system":
            prompt_parts.append(f"System: {content}")
        elif role == "user":
            prompt_parts.append(f"Human: {content}")
        elif role == "assistant":
            prompt_parts.append(f"Assistant: {content}")
    # Add assistant prompt to continue
    prompt_parts.append("Assistant:")
    return "\n".join(prompt_parts)
async def generate_multimodal_response(
    messages: List[ChatMessage],
    request: ChatCompletionRequest,
) -> str:
    """Generate a response using the image-to-text pipeline for multimodal content."""
    if not image_text_pipeline:
        raise HTTPException(status_code=503, detail="Image processing not available - pipeline not initialized")
    try:
        # Find the last user message with images
        last_user_message = None
        for message in reversed(messages):
            if message.role == "user" and isinstance(message.content, list):
                last_user_message = message
                break
        if not last_user_message:
            raise HTTPException(status_code=400, detail="No user message with images found")
        # Extract text and images from the message
        text_content, image_urls = extract_text_and_images(last_user_message.content)
        if not image_urls:
            raise HTTPException(status_code=400, detail="No images found in the message")
        # Use the first image for now (could be extended to handle multiple images)
        image_url = image_urls[0]
        # Generate response using the image-to-text pipeline
        logger.info(f"🖼️ Processing image: {image_url}")
        try:
            # Use the pipeline directly with the image URL (no messages format needed for image-to-text)
            result = await asyncio.to_thread(lambda: image_text_pipeline(image_url))  # type: ignore
            # Handle response format from the image-to-text pipeline
            if result and hasattr(result, "__len__") and len(result) > 0:  # type: ignore
                first_result = result[0]  # type: ignore
                if hasattr(first_result, "get"):
                    generated_text = first_result.get("generated_text", f"I can see an image at {image_url}.")  # type: ignore
                else:
                    generated_text = str(first_result)
                # Combine with the user's text question if provided
                if text_content:
                    response = f"Looking at this image, I can see: {generated_text}. "
                    if "what" in text_content.lower() or "?" in text_content:
                        response += f"Regarding your question '{text_content}': Based on what I can see, this appears to be {generated_text.lower()}."
                    else:
                        response += f"You mentioned: {text_content}"
                    return response
                else:
                    return f"I can see: {generated_text}"
            else:
                return f"I can see there's an image at {image_url}, but cannot process it right now."
        except Exception as pipeline_error:
            logger.warning(f"Pipeline error: {pipeline_error}")
            return f"I can see there's an image at {image_url}. The image appears to contain visual content that I'm having trouble processing right now."
    except Exception as e:
        logger.error(f"Error in multimodal generation: {e}")
        return f"I'm having trouble processing the image. Error: {str(e)}"


def generate_response_local(messages: List[ChatMessage], max_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95) -> str:
    """Generate a response using the local transformers model with its chat template."""
    ensure_model_ready()
    try:
        logger.info("Generating response using transformers model")
        return generate_response_transformers(messages, max_tokens, temperature, top_p)
    except Exception as e:
        logger.error(f"Local generation failed: {e}")
        return "I apologize, but I'm having trouble generating a response right now. Please try again."


# GGUF/llama-cpp support removed for Hugging Face Spaces


def convert_messages_to_gemma_prompt(messages: List[ChatMessage]) -> str:
    """Convert the OpenAI messages format to the Gemma 3n chat format."""
    # Gemma 3n uses a turn-based format with <start_of_turn> and <end_of_turn> markers
    prompt_parts = ["<bos>"]
    for message in messages:
        role = message.role
        content = message.content
        if role == "system":
            prompt_parts.append(f"<start_of_turn>system\n{content}<end_of_turn>")
        elif role == "user":
            prompt_parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
        elif role == "assistant":
            prompt_parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")
    # Add the start marker for the model's response
    prompt_parts.append("<start_of_turn>model\n")
    return "\n".join(prompt_parts)
def generate_response_transformers(messages: List[ChatMessage], max_tokens: int = 512, temperature: float = 0.7, top_p: float = 0.95) -> str:
    """Generate a response using the transformers model with its chat template."""
    try:
        # Check if we're using Gemma 3n
        if "gemma-3n" in current_model.lower():
            # Gemma 3n specific handling:
            # convert messages to the HuggingFace format expected by the chat template
            chat_messages = []
            for m in messages:
                # Gemma 3n supports multimodal input, but for now we handle text only
                if isinstance(m.content, str):
                    content = [{"type": "text", "text": m.content}]
                else:
                    # Extract text content for now (image support can be added later)
                    text_content, _ = extract_text_and_images(m.content)
                    content = [{"type": "text", "text": text_content}]
                chat_messages.append({"role": m.role, "content": content})
            # Apply the chat template using the processor
            inputs = processor.apply_chat_template(
                chat_messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            # Generate with Gemma 3n
            input_len = inputs["input_ids"].shape[-1]
            with torch.inference_mode():
                generation = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    do_sample=temperature > 0,
                )
            generation = generation[0][input_len:]
            # Decode the response
            generated_text = processor.decode(generation, skip_special_tokens=True)
            return generated_text.strip()
        else:
            # Fallback for other models:
            # convert messages to the HuggingFace format expected by the chat template
            chat_messages = []
            for m in messages:
                content_str = m.content if isinstance(m.content, str) else extract_text_and_images(m.content)[0]
                chat_messages.append({"role": m.role, "content": content_str})
            # Apply the chat template and tokenize
            inputs = processor.apply_chat_template(
                chat_messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            )
            # Generate response
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0,
            )
            # Decode only the newly generated tokens (exclude the input)
            generated_text = processor.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
            return generated_text.strip()
    except Exception as e:
        logger.error(f"Transformers generation failed: {e}")
        return "I apologize, but I'm having trouble generating a response right now. Please try again."
@app.get("/")
async def root() -> Dict[str, Any]:
    """Root endpoint with service information."""
    return {
        "message": "AI Backend Service is running with Gemma 3n!",
        "model": current_model,
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "models": "/v1/models",
            "chat_completions": "/v1/chat/completions",
        },
    }


@app.get("/health")
async def health_check():
    """Health check endpoint."""
    global current_model, processor, model
    return HealthResponse(
        status="healthy" if (processor is not None and model is not None) else "unhealthy",
        model=current_model,
        version="1.0.0",
    )


@app.get("/v1/models")
async def list_models():
    """List available models (OpenAI-compatible)."""
    models = [
        ModelInfo(
            id=current_model,
            created=int(time.time()),
            owned_by="huggingface",
        )
    ]
    # Add the vision model if available
    if image_text_pipeline:
        models.append(
            ModelInfo(
                id=vision_model,
                created=int(time.time()),
                owned_by="huggingface",
            )
        )
    return ModelsResponse(data=models)
# ...existing code...

# --- Hugging Face Spaces: Only the transformers backend is supported ---
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest) -> ChatCompletionResponse:
    """Create a chat completion (OpenAI-compatible) with multimodal support."""
    try:
        if not request.messages:
            raise HTTPException(status_code=400, detail="Messages cannot be empty")
        is_multimodal = has_images(request.messages)
        if is_multimodal:
            if not image_text_pipeline:
                raise HTTPException(status_code=503, detail="Image processing not available")
            response_text = await generate_multimodal_response(request.messages, request)
        else:
            logger.info(f"Generating local response for messages: {request.messages}")
            response_text = await asyncio.to_thread(
                generate_response_local,
                request.messages,
                request.max_tokens or 512,
                request.temperature or 0.7,
                request.top_p or 0.95,
            )
        response_text = response_text.strip() if response_text else "No response generated."
        return ChatCompletionResponse(
            id=f"chatcmpl-{int(time.time())}",
            created=int(time.time()),
            model=request.model,
            choices=[
                ChatCompletionChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content=response_text),
                    finish_reason="stop",
                )
            ],
        )
    except HTTPException:
        # Preserve intentional HTTP errors (400/503) instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error in chat completion: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
async def create_completion(request: CompletionRequest) -> Dict[str, Any]:
    """Create a text completion (OpenAI-compatible)."""
    try:
        if not request.prompt:
            raise HTTPException(status_code=400, detail="Prompt cannot be empty")
        ensure_model_ready()
        # Treat the prompt as a single user message
        messages = [ChatMessage(role="user", content=request.prompt)]
        response_text = await asyncio.to_thread(
            generate_response_local,
            messages,
            request.max_tokens or 512,
            request.temperature or 0.7,
            0.95,
        )
        return {
            "id": f"cmpl-{int(time.time())}",
            "object": "text_completion",
            "created": int(time.time()),
            "model": current_model,
            "choices": [
                {
                    "text": response_text,
                    "index": 0,
                    "finish_reason": "stop",
                }
            ],
        }
    except HTTPException:
        # Preserve intentional HTTP errors instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Error in completion: {e}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


async def api_response(request: Request) -> JSONResponse:
    """Endpoint to receive and send responses via the API."""
    try:
        data = await request.json()
        message = data.get("message", "No message provided")
        return JSONResponse(
            content={
                "status": "success",
                "received_message": message,
                "response_message": f"You sent: {message}",
            }
        )
    except Exception as e:
        logger.error(f"Error processing API response: {e}")
        raise HTTPException(status_code=500, detail="Internal server error")
# Main entry point moved to the end for proper initialization
if __name__ == "__main__":
    uvicorn.run("backend_service:app", host="0.0.0.0", port=8000, reload=True)
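
# Example client call (a minimal sketch; assumes the service is running locally on
# http://localhost:8000 and that /v1/chat/completions is exposed as defined above):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       json={
#           "model": "google/gemma-3n-E4B-it",
#           "messages": [{"role": "user", "content": "Hello, Gemma!"}],
#           "max_tokens": 64,
#       },
#       timeout=120,
#   )
#   print(resp.json()["choices"][0]["message"]["content"])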