import asyncio
import logging
import os
import sys
import threading
import time
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, List, Optional

import pymongo
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

try:
    import faiss
    FAISS_AVAILABLE = True
except ImportError:
    FAISS_AVAILABLE = False

try:
    from rag import SessionRAG, initialize_models
    RAG_AVAILABLE = True
except ImportError:
    RAG_AVAILABLE = False

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('rag_app.log', mode='a')
    ]
)
logger = logging.getLogger(__name__)

# --- Global State ---
MONGO_CLIENT = None
DB = None
RAG_MODELS_INITIALIZED = False
SESSION_STORES = {}  # In-memory cache: {session_id: {session_rag, metadata, indexed}}
STORE_LOCK = threading.RLock()
APP_STATE = {
    "startup_time": None,
    "mongodb_connected": False,
    "rag_models_ready": False,
    "total_queries": 0,
    "errors": []
}

# --- Pydantic Models ---
class ChatRequest(BaseModel):
    message: str = Field(..., min_length=1, max_length=5000)

class ChatResponse(BaseModel):
    success: bool
    answer: str
    sources: List[Dict[str, Any]] = Field(default_factory=list)
    processing_time: float
    session_id: str
    query_analysis: Optional[Dict[str, Any]] = None
    confidence: Optional[float] = None
    error_details: Optional[str] = None

class HealthResponse(BaseModel):
    status: str
    mongodb_connected: bool
    rag_models_initialized: bool
    faiss_available: bool
    active_sessions: int
    memory_usage: Dict[str, Any]
    uptime_seconds: float
    last_error: Optional[str] = None

# --- Helper Functions ---
def create_session_logger(session_id: str):
    return logging.LoggerAdapter(logger, {'session_id': session_id[:8]})

def connect_mongodb():
    """Connect to MongoDB Atlas and create collection indexes."""
    global MONGO_CLIENT, DB
    try:
        mongodb_url = os.getenv("MONGODB_URL", "mongodb://localhost:27017/")
        logger.info("Connecting to MongoDB...")
        MONGO_CLIENT = pymongo.MongoClient(
            mongodb_url,
            serverSelectionTimeoutMS=5000
        )
        MONGO_CLIENT.admin.command('ping')
        DB = MONGO_CLIENT["legal_rag_system"]

        # Create indexes; chats expire after 24 hours via a TTL index
        DB.chats.create_index("session_id", background=True)
        DB.chats.create_index("created_at", expireAfterSeconds=24 * 60 * 60, background=True)
        DB.sessions.create_index("session_id", unique=True, background=True)
        DB.chunks.create_index("session_id", background=True)

        APP_STATE["mongodb_connected"] = True
        logger.info("MongoDB connected successfully")
        return True
    except Exception as e:
        logger.error(f"MongoDB connection failed: {e}")
        APP_STATE["errors"].append(f"MongoDB error: {str(e)}")
        return False

def init_rag_models():
    """Initialize shared RAG models (embedding model, NLP model, etc.)."""
    global RAG_MODELS_INITIALIZED

    if not RAG_AVAILABLE or not FAISS_AVAILABLE:
        logger.error("RAG module or FAISS not available")
        return False

    try:
        model_id = os.getenv("EMBEDDING_MODEL_ID", "sentence-transformers/all-MiniLM-L6-v2")
        groq_api_key = os.getenv("GROQ_API_KEY")

        logger.info(f"Initializing shared RAG models with embedding model: {model_id}")
        initialize_models(model_id, groq_api_key)

        RAG_MODELS_INITIALIZED = True
        APP_STATE["rag_models_ready"] = True
        logger.info("Shared RAG models initialized successfully")
        return True
    except Exception as e:
        logger.error(f"RAG model initialization failed: {e}", exc_info=True)
        APP_STATE["errors"].append(f"RAG init failed: {str(e)}")
        return False
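# Example environment configuration. This is a sketch: the variable names are
# the ones this module actually reads, but the values shown are illustrative
# placeholders.
#
#   MONGODB_URL=mongodb+srv://<user>:<password>@<cluster>/   # defaults to mongodb://localhost:27017/
#   EMBEDDING_MODEL_ID=sentence-transformers/all-MiniLM-L6-v2
#   GROQ_API_KEY=<your Groq API key>
#   PORT=7860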
def load_session_from_mongodb(session_id: str) -> Dict[str, Any]:
    """Load session data from MongoDB with explicit error handling."""
    session_logger = create_session_logger(session_id)
    session_logger.info(f"Loading session from MongoDB: {session_id}")

    # pymongo Database objects do not support truth-value testing,
    # so compare against None explicitly.
    if DB is None:
        raise ValueError("Database not connected")

    try:
        # 1. Load session metadata
        session_doc = DB.sessions.find_one({"session_id": session_id})
        if not session_doc:
            raise ValueError(f"Session {session_id} not found in database")

        # Check session status
        if session_doc.get("status") != "completed":
            raise ValueError(f"Session not ready - status: {session_doc.get('status')}")

        # 2. Load chunks with embeddings from MongoDB
        session_logger.info(f"Loading chunks for: {session_doc.get('filename', 'unknown')}")
        chunks_cursor = DB.chunks.find({"session_id": session_id}).sort("created_at", 1)
        chunks_list = list(chunks_cursor)

        if not chunks_list:
            raise ValueError(f"No chunks found for session {session_id}")

        session_logger.info(f"Found {len(chunks_list)} chunks with pre-computed embeddings")

        # 3. Create a SessionRAG instance (using the module-level import from rag,
        # which is also what gates RAG_AVAILABLE above)
        groq_api_key = os.getenv("GROQ_API_KEY")
        session_rag = SessionRAG(session_id, groq_api_key)

        # 4. Load existing session data (rebuilds indices from stored embeddings)
        session_logger.info("Rebuilding search indices from existing embeddings...")
        session_rag.load_existing_session_data(chunks_list)

        # 5. Create session store object
        session_store = {
            "session_rag": session_rag,
            "indexed": True,
            "metadata": {
                "session_id": session_id,
                "title": session_doc.get("filename", "Document"),
                "chunk_count": len(chunks_list),
                "loaded_at": datetime.utcnow(),
                "document_info": {
                    "filename": session_doc.get("filename", "Unknown"),
                    "upload_date": session_doc.get("created_at")
                }
            }
        }

        session_logger.info("✓ Session loaded successfully with existing embeddings")
        return session_store
    except Exception as e:
        session_logger.error(f"Failed to load session from MongoDB: {e}", exc_info=True)
        raise ValueError(f"Failed to load session {session_id}: {str(e)}")

def get_or_load_session(session_id: str) -> Dict[str, Any]:
    """
    Get session from memory cache, or load from MongoDB if not in memory.
    Thread-safe with locking.
""" with STORE_LOCK: # Check if already loaded in memory if session_id in SESSION_STORES: logger.info(f"Session {session_id[:8]} already in memory") return SESSION_STORES[session_id] # Not in memory - load from MongoDB logger.info(f"Session {session_id[:8]} not in memory, loading from MongoDB...") session_store = load_session_from_mongodb(session_id) SESSION_STORES[session_id] = session_store logger.info(f"Session {session_id[:8]} loaded and cached in memory") return session_store async def save_chat_message_safely(session_id: str, role: str, message: str): """Save chat messages to MongoDB asynchronously""" if DB is None: # ✅ CORRECT return try: await asyncio.to_thread( DB.chats.insert_one, { "session_id": session_id, "role": role, "message": message, "created_at": datetime.utcnow() } ) except Exception as e: logger.error(f"Failed to save chat message for session {session_id}: {e}") def get_chat_history_safely(session_id: str, limit: int = 50) -> List[Dict[str, Any]]: """Get chat history from MongoDB with error handling""" if DB is None: # ✅ CORRECT return [] try: chats_cursor = DB.chats.find({"session_id": session_id}).sort("created_at", -1).limit(limit) return list(chats_cursor)[::-1] # Reverse for chronological order except Exception as e: logger.error(f"Failed to get chat history for session {session_id}: {e}") return [] # --- Application Lifespan --- @asynccontextmanager async def lifespan(app: FastAPI): """Application startup and shutdown""" APP_STATE["startup_time"] = datetime.utcnow() logger.info("Starting RAG Chat Service...") # Initialize MongoDB and models connect_mongodb() init_rag_models() logger.info("✓ Service ready") yield # Cleanup on shutdown logger.info("Shutting down...") if MONGO_CLIENT: MONGO_CLIENT.close() logger.info("✓ Shutdown complete") # --- FastAPI App --- app = FastAPI( title="Session-based RAG Chat Service", description="RAG system with MongoDB session persistence", version="4.0.0", lifespan=lifespan ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/") async def root(): return { "service": "Session-based RAG Chat Service", "version": "4.0.0", "description": "Embeddings stored in MongoDB, lazy-loaded on demand" } @app.get("/health", response_model=HealthResponse) async def health_check(): """Health check endpoint""" uptime = (datetime.utcnow() - APP_STATE["startup_time"]).total_seconds() with STORE_LOCK: active_sessions = len(SESSION_STORES) indexed_sessions = sum(1 for s in SESSION_STORES.values() if s.get("indexed", False)) status = "healthy" if not RAG_MODELS_INITIALIZED or not APP_STATE["mongodb_connected"]: status = "degraded" return HealthResponse( status=status, mongodb_connected=APP_STATE["mongodb_connected"], rag_models_initialized=RAG_MODELS_INITIALIZED, faiss_available=FAISS_AVAILABLE, active_sessions=active_sessions, memory_usage={ "loaded_sessions": active_sessions, "indexed_sessions": indexed_sessions }, uptime_seconds=uptime, last_error=APP_STATE["errors"][-1] if APP_STATE["errors"] else None ) @app.post("/chat/{session_id}", response_model=ChatResponse) async def chat_with_document(session_id: str, request: ChatRequest): """ Main chat endpoint: 1. Load session from MongoDB if not in memory (lazy loading) 2. Process query using RAG pipeline 3. Save chat messages to MongoDB 4. 
    4. Return answer with sources
    """
    session_logger = create_session_logger(session_id)
    start_time = time.time()

    try:
        session_logger.info(f"Chat request: {request.message[:100]}...")

        # Get or load session (lazy loading from MongoDB)
        try:
            session_store = await asyncio.to_thread(get_or_load_session, session_id)
            session_rag = session_store["session_rag"]
        except Exception as load_error:
            session_logger.error(f"Failed to load session: {load_error}")
            raise HTTPException(
                status_code=404,
                detail=f"Session not found or failed to load: {str(load_error)}"
            )

        # Process query using RAG pipeline
        session_logger.info("Processing query with RAG...")
        result = await asyncio.to_thread(
            session_rag.query_documents,
            request.message,
            top_k=5
        )

        if 'error' in result:
            session_logger.error(f"Query error: {result['error']}")
            raise HTTPException(status_code=500, detail=result['error'])

        APP_STATE["total_queries"] += 1
        answer = result.get('answer', 'Unable to generate an answer.')

        # Save chat messages to MongoDB as fire-and-forget tasks;
        # save_chat_message_safely logs its own failures rather than
        # failing the request
        asyncio.create_task(save_chat_message_safely(session_id, "user", request.message))
        asyncio.create_task(save_chat_message_safely(session_id, "assistant", answer))

        processing_time = time.time() - start_time
        session_logger.info(f"✓ Query processed in {processing_time:.2f}s")

        return ChatResponse(
            success=True,
            answer=answer,
            sources=result.get('sources', []),
            processing_time=processing_time,
            session_id=session_id,
            query_analysis=result.get('query_analysis'),
            confidence=result.get('confidence')
        )
    except HTTPException:
        raise
    except Exception as e:
        session_logger.error(f"Chat processing failed: {e}", exc_info=True)
        APP_STATE["errors"].append(f"Chat error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Chat processing error: {str(e)}"
        )

@app.get("/history/{session_id}")
async def get_session_history(session_id: str):
    """Get chat history for a session from MongoDB."""
    if DB is None:
        raise HTTPException(status_code=503, detail="Database not connected")

    history = await asyncio.to_thread(get_chat_history_safely, session_id)
    return {
        "session_id": session_id,
        "chat_history": history,
        "count": len(history)
    }

@app.get("/session/{session_id}/info")
async def get_session_info(session_id: str):
    """Get session metadata from MongoDB."""
    if DB is None:
        raise HTTPException(status_code=503, detail="Database not connected")

    session_doc = await asyncio.to_thread(DB.sessions.find_one, {"session_id": session_id})
    if not session_doc:
        raise HTTPException(status_code=404, detail="Session not found")

    # Convert ObjectId to string for JSON serialization
    session_doc['_id'] = str(session_doc['_id'])

    # Check if loaded in memory
    with STORE_LOCK:
        in_memory = session_id in SESSION_STORES

    return {
        "session_id": session_id,
        "metadata": session_doc,
        "in_memory": in_memory
    }

@app.get("/debug/sessions/list")
async def list_all_sessions():
    """List recent sessions in MongoDB to see which session IDs actually exist."""
    if DB is None:
        return {"error": "Database not connected"}

    try:
        # Get the 20 most recent sessions
        sessions = list(DB.sessions.find({}, {
            "session_id": 1, "filename": 1, "status": 1, "created_at": 1, "_id": 0
        }).sort("created_at", -1).limit(20))

        # Get total counts
        total_sessions = DB.sessions.count_documents({})
        total_chunks = DB.chunks.count_documents({})

        return {
            "total_sessions": total_sessions,
            "total_chunks": total_chunks,
            "recent_sessions": sessions,
            "session_ids_only": [s["session_id"] for s in sessions]
        }
    except Exception as e:
        return {"error": f"Failed to list sessions: {str(e)}"}
@app.get("/debug/sessions/search/{partial_id}") async def search_sessions_by_partial_id(partial_id: str): """Search for sessions that contain the partial ID""" if DB is None: return {"error": "Database not connected"} try: # Search for sessions containing the partial ID sessions = list(DB.sessions.find({ "session_id": {"$regex": partial_id, "$options": "i"} }, { "session_id": 1, "filename": 1, "status": 1, "created_at": 1, "_id": 0 }).limit(10)) return { "search_term": partial_id, "matches": sessions, "match_count": len(sessions) } except Exception as e: return {"error": f"Search failed: {str(e)}"} @app.get("/debug/chunks/by-session/{session_id}") async def debug_chunks_for_session(session_id: str): """Check if chunks exist for a session ID (maybe the chunks use a different ID format)""" if DB is None: return {"error": "Database not connected"} try: # Check exact match chunks_exact = DB.chunks.count_documents({"session_id": session_id}) # Check partial matches (in case of ID truncation) chunks_partial = list(DB.chunks.find({ "session_id": {"$regex": session_id[:8], "$options": "i"} # First 8 chars }, { "session_id": 1, "chunk_id": 1, "_id": 0 }).limit(5)) # Check if any chunks exist at all sample_chunks = list(DB.chunks.find({}, { "session_id": 1, "_id": 0 }).limit(5)) return { "searched_session_id": session_id, "chunks_exact_match": chunks_exact, "chunks_partial_matches": chunks_partial, "sample_chunk_session_ids": [c.get("session_id") for c in sample_chunks] } except Exception as e: return {"error": f"Chunk search failed: {str(e)}"} @app.get("/debug/frontend-session") async def debug_frontend_session_issue(): """General debug info to help identify frontend/backend session ID mismatch""" if DB is None: return {"error": "Database not connected"} try: # Get sample of how session IDs are actually stored sessions_sample = list(DB.sessions.find({}, { "session_id": 1, "_id": 0 }).limit(5)) chunks_sample = list(DB.chunks.find({}, { "session_id": 1, "_id": 0 }).limit(5)) # Get session ID patterns session_id_lengths = {} for session in sessions_sample: sid = session.get("session_id", "") length = len(sid) if length not in session_id_lengths: session_id_lengths[length] = [] session_id_lengths[length].append(sid) return { "sessions_collection": { "sample_session_ids": [s.get("session_id") for s in sessions_sample], "session_id_lengths": session_id_lengths }, "chunks_collection": { "sample_session_ids": [c.get("session_id") for c in chunks_sample] }, "analysis": { "frontend_looking_for": "5ca64618-04fb-48c3-bb15-9d06eb720033", "frontend_id_length": len("5ca64618-04fb-48c3-bb15-9d06eb720033"), "frontend_id_format": "UUID with dashes" } } except Exception as e: return {"error": f"Debug failed: {str(e)}"} @app.get("/debug/session/{session_id}") async def debug_session_status(session_id: str): """Enhanced debug endpoint to check session status in MongoDB""" if DB is None: return {"error": "Database not connected"} try: # Check session document session_doc = DB.sessions.find_one({"session_id": session_id}) # Check chunks chunks_count = DB.chunks.count_documents({"session_id": session_id}) # If exact match fails, try partial matching partial_sessions = [] partial_chunks = 0 if not session_doc: # Try searching with first 8 characters (common truncation) short_id = session_id[:8] if len(session_id) >= 8 else session_id partial_sessions = list(DB.sessions.find({ "session_id": {"$regex": f"^{short_id}", "$options": "i"} }, {"session_id": 1, "filename": 1, "_id": 0}).limit(3)) partial_chunks = 
DB.chunks.count_documents({ "session_id": {"$regex": f"^{short_id}", "$options": "i"} }) # Sample chunks for debugging sample_chunks = list(DB.chunks.find( {"session_id": session_id}, {"chunk_id": 1, "content": 1, "embedding": 1, "_id": 0} ).limit(2)) # Check if chunks have embeddings chunks_with_embeddings = DB.chunks.count_documents({ "session_id": session_id, "embedding": {"$exists": True, "$ne": None} }) return { "searched_session_id": session_id, "session_id_length": len(session_id), "exact_match": { "session_exists": session_doc is not None, "session_status": session_doc.get("status") if session_doc else None, "session_filename": session_doc.get("filename") if session_doc else None, "chunks_count": chunks_count, "chunks_with_embeddings": chunks_with_embeddings, }, "partial_matches": { "sessions_found": partial_sessions, "chunks_count": partial_chunks }, "sample_chunks": [ { "chunk_id": chunk.get("chunk_id"), "content_length": len(chunk.get("content", "")), "has_embedding": chunk.get("embedding") is not None } for chunk in sample_chunks ], "in_memory_cache": session_id in SESSION_STORES, "suggestions": [ "Check if session was created with different upload service", "Verify frontend is using correct session ID from upload response", "Check if session creation completed successfully" ] } except Exception as e: return {"error": f"Debug failed: {str(e)}"} @app.delete("/session/{session_id}/cache") async def clear_session_cache(session_id: str): """Remove session from memory cache (data remains in MongoDB)""" with STORE_LOCK: if session_id in SESSION_STORES: store = SESSION_STORES.pop(session_id) session_rag = store.get("session_rag") if hasattr(session_rag, 'cleanup'): session_rag.cleanup() logger.info(f"Session {session_id[:8]} removed from memory cache") return { "success": True, "message": f"Session removed from memory cache", "note": "Data remains in MongoDB" } return { "success": False, "message": "Session not found in memory cache" } if __name__ == "__main__": import uvicorn port = int(os.getenv("PORT", 7860)) logger.info(f"Starting server on http://0.0.0.0:{port}") uvicorn.run("app:app", host="0.0.0.0", port=port, reload=True)
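
# Example usage against a running instance. This is an illustrative sketch:
# <session-id> is a placeholder for a session that already exists in the
# `sessions`/`chunks` collections (created by the separate upload service).
#
#   curl http://localhost:7860/health
#   curl -X POST http://localhost:7860/chat/<session-id> \
#        -H "Content-Type: application/json" \
#        -d '{"message": "What does the contract say about termination?"}'
#   curl http://localhost:7860/history/<session-id>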