# rag-chat / app.py
# kn29's picture
# Update app.py
# f98f6b7 verified
import asyncio
import logging
import os
import re
import sys
import threading
import time
from contextlib import asynccontextmanager
from datetime import datetime
from typing import Any, Dict, List, Optional

import pymongo
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
# Optional dependencies: record whether each one imported instead of failing
# at module import time. init_rag_models() checks both flags before use.
try:
    import faiss
except ImportError:
    FAISS_AVAILABLE = False
else:
    FAISS_AVAILABLE = True

try:
    from rag import SessionRAG, initialize_models
except ImportError:
    RAG_AVAILABLE = False
else:
    RAG_AVAILABLE = True
# Configure logging
# Records go both to stdout and, appended, to a local file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        # Explicit UTF-8: messages logged by this module contain non-ASCII
        # characters (e.g. "✓"), which a locale-default encoding such as
        # cp1252 cannot encode.
        logging.FileHandler('rag_app.log', mode='a', encoding='utf-8'),
    ]
)
logger = logging.getLogger(__name__)
# --- Global State ---
# Process-wide state shared by all request handlers.

# PyMongo client and database handle; both stay None until connect_mongodb()
# succeeds.
MONGO_CLIENT: Optional[Any] = None
DB: Optional[Any] = None

# Flipped to True by init_rag_models() once the shared models are loaded.
RAG_MODELS_INITIALIZED: bool = False

# In-memory cache: {session_id: {session_rag, metadata, indexed}}
SESSION_STORES: Dict[str, Dict[str, Any]] = {}

# Guards SESSION_STORES; re-entrant so a holder may take it again.
STORE_LOCK = threading.RLock()

# Counters and status flags reported by the health endpoint.
APP_STATE: Dict[str, Any] = {
    "startup_time": None,
    "mongodb_connected": False,
    "rag_models_ready": False,
    "total_queries": 0,
    "errors": [],
}
# --- Pydantic Models ---
class ChatRequest(BaseModel):
    """Request body accepted by the chat endpoint."""

    # The user's message; required, between 1 and 5000 characters.
    message: str = Field(..., min_length=1, max_length=5000)
class ChatResponse(BaseModel):
    """Response body returned by the chat endpoint."""

    # True when the query was answered.
    success: bool
    # Generated answer text.
    answer: str
    # Source entries reported by the RAG pipeline; empty list when none.
    sources: List[Dict[str, Any]] = Field(default_factory=list)
    # Wall-clock seconds spent handling the request.
    processing_time: float
    # Session the request was addressed to.
    session_id: str
    # Optional extras passed through from the RAG result when present.
    query_analysis: Optional[Dict[str, Any]] = None
    confidence: Optional[float] = None
    # Optional error description; not populated by the code visible in this file.
    error_details: Optional[str] = None
class HealthResponse(BaseModel):
    """Response body returned by the health endpoint."""

    # "healthy" or "degraded" as computed by the health endpoint.
    status: str
    mongodb_connected: bool
    rag_models_initialized: bool
    faiss_available: bool
    # Number of sessions currently held in the in-memory cache.
    active_sessions: int
    # Free-form counters (currently loaded/indexed session counts).
    memory_usage: Dict[str, Any]
    # Seconds elapsed since application startup.
    uptime_seconds: float
    # Most recent recorded error message, if any.
    last_error: Optional[str] = None
# --- Helper Functions ---
def create_session_logger(session_id: str):
    """Wrap the module logger in an adapter carrying a shortened session ID.

    Args:
        session_id: Full session identifier; only its first 8 characters are
            attached as the ``session_id`` extra.

    Returns:
        A ``logging.LoggerAdapter`` around the module-level logger.
    """
    short_id = session_id[:8]
    extra = {'session_id': short_id}
    return logging.LoggerAdapter(logger, extra)
def connect_mongodb():
    """Connect to MongoDB, publish the client/database globals and ensure indexes.

    The connection string is read from ``MONGODB_URL`` (defaulting to a local
    instance). The globals ``MONGO_CLIENT`` and ``DB`` are assigned only after
    the server has answered a ping, so a failed attempt never leaves a
    half-initialised client behind.

    Index creation is best-effort: a failure there is logged and recorded in
    ``APP_STATE["errors"]`` but does not mark the connection as failed, because
    the database handle is still usable.

    Returns:
        True when the server answered the ping, False otherwise.
    """
    global MONGO_CLIENT, DB

    mongodb_url = os.getenv("MONGODB_URL", "mongodb://localhost:27017/")
    client = None
    try:
        logger.info("Connecting to MongoDB...")
        client = pymongo.MongoClient(
            mongodb_url,
            serverSelectionTimeoutMS=5000
        )
        # Force a round trip so an unreachable server fails here, within the
        # 5 second server-selection timeout, rather than on first use.
        client.admin.command('ping')
    except Exception as e:
        logger.error(f"MongoDB connection failed: {e}")
        APP_STATE["errors"].append(f"MongoDB error: {str(e)}")
        APP_STATE["mongodb_connected"] = False
        # Release the unusable client instead of leaving it assigned and open.
        if client is not None:
            client.close()
        return False

    MONGO_CLIENT = client
    DB = client["legal_rag_system"]
    APP_STATE["mongodb_connected"] = True

    try:
        # Create indexes
        DB.chats.create_index("session_id", background=True)
        # TTL index: chat messages expire 24 hours after their created_at.
        DB.chats.create_index("created_at", expireAfterSeconds=24 * 60 * 60, background=True)
        DB.sessions.create_index("session_id", unique=True, background=True)
        DB.chunks.create_index("session_id", background=True)
    except Exception as e:
        logger.warning(f"MongoDB index creation failed: {e}")
        APP_STATE["errors"].append(f"MongoDB index error: {str(e)}")

    logger.info("MongoDB connected successfully")
    return True
def init_rag_models():
    """Load the shared RAG models once for the whole process.

    Requires both the ``rag`` module and FAISS to have imported successfully.
    The embedding model ID comes from ``EMBEDDING_MODEL_ID`` and the API key
    from ``GROQ_API_KEY``.

    Returns:
        True when initialisation succeeded; False when a dependency is missing
        or initialisation raised (the error is logged and recorded in
        ``APP_STATE["errors"]``).
    """
    global RAG_MODELS_INITIALIZED

    dependencies_ready = RAG_AVAILABLE and FAISS_AVAILABLE
    if not dependencies_ready:
        logger.error("RAG module or FAISS not available")
        return False

    try:
        embedding_model = os.getenv(
            "EMBEDDING_MODEL_ID",
            "sentence-transformers/all-MiniLM-L6-v2",
        )
        api_key = os.getenv("GROQ_API_KEY")

        logger.info(f"Initializing shared RAG models with embedding model: {embedding_model}")
        initialize_models(embedding_model, api_key)

        # Publish readiness through both the module flag and the state dict.
        RAG_MODELS_INITIALIZED = True
        APP_STATE["rag_models_ready"] = True
        logger.info("Shared RAG models initialized successfully")
        return True
    except Exception as exc:
        logger.error(f"RAG model initialization failed: {exc}", exc_info=True)
        APP_STATE["errors"].append(f"RAG init failed: {str(exc)}")
        return False
def load_session_from_mongodb(session_id: str) -> Dict[str, Any]:
    """Build an in-memory session store from the data persisted in MongoDB.

    Looks up the session document, requires its status to be "completed",
    reads the session's chunks ordered by ``created_at`` and hands them to a
    fresh ``OptimizedSessionRAG`` instance.

    Args:
        session_id: Identifier of the session to load.

    Returns:
        A dict with keys ``session_rag``, ``indexed`` and ``metadata``.

    Raises:
        ValueError: If the database is not connected, or wrapping any failure
            that occurs while loading (missing session, wrong status, no
            chunks, import or rebuild errors).
    """
    log = create_session_logger(session_id)
    log.info(f"Loading session from MongoDB: {session_id}")

    if DB is None:
        raise ValueError("Database not connected")

    try:
        # Step 1: session metadata, which must exist and be fully processed.
        record = DB.sessions.find_one({"session_id": session_id})
        if not record:
            raise ValueError(f"Session {session_id} not found in database")

        status = record.get("status")
        if status != "completed":
            raise ValueError(f"Session not ready - status: {status}")

        # Step 2: the session's chunks, oldest first.
        log.info(f"Loading chunks for: {record.get('filename', 'unknown')}")
        chunk_query = DB.chunks.find({"session_id": session_id})
        chunk_docs = list(chunk_query.sort("created_at", 1))
        if not chunk_docs:
            raise ValueError(f"No chunks found for session {session_id}")
        log.info(f"Found {len(chunk_docs)} chunks with pre-computed embeddings")

        # Step 3: a RAG instance for this session. The class is imported here,
        # inside the try, so an import failure is reported like any other
        # load error.
        api_key = os.getenv("GROQ_API_KEY")
        from rag import OptimizedSessionRAG
        rag_instance = OptimizedSessionRAG(session_id, api_key)

        # Step 4: hand the stored chunks to the RAG instance.
        log.info("Rebuilding search indices from existing embeddings...")
        rag_instance.load_existing_session_data(chunk_docs)

        # Step 5: assemble the cache entry.
        document_info = {
            "filename": record.get("filename", "Unknown"),
            "upload_date": record.get("created_at"),
        }
        metadata = {
            "session_id": session_id,
            "title": record.get("filename", "Document"),
            "chunk_count": len(chunk_docs),
            "loaded_at": datetime.utcnow(),
            "document_info": document_info,
        }
        loaded_store = {
            "session_rag": rag_instance,
            "indexed": True,
            "metadata": metadata,
        }

        log.info("✓ Session loaded successfully with existing embeddings")
        return loaded_store
    except Exception as exc:
        log.error(f"Failed to load session from MongoDB: {exc}", exc_info=True)
        raise ValueError(f"Failed to load session {session_id}: {str(exc)}")
# One lock per session that is currently being loaded. Entries are created and
# removed while holding STORE_LOCK.
_SESSION_LOAD_LOCKS: Dict[str, threading.Lock] = {}


def get_or_load_session(session_id: str) -> Dict[str, Any]:
    """Return the cached session store, loading it from MongoDB on a miss.

    Thread-safe. ``STORE_LOCK`` is only held for short dictionary operations;
    the slow MongoDB load runs under a per-session lock instead. Async
    handlers acquire ``STORE_LOCK`` on the event loop thread, so holding it
    for the whole load would stall every request while one session loads.
    Concurrent requests for the same session still load it only once.

    NOTE(review): entries are never evicted from ``SESSION_STORES`` here; they
    are only removed through the cache-clearing endpoint.

    Args:
        session_id: Identifier of the session to fetch.

    Returns:
        The session store dict held in ``SESSION_STORES``.

    Raises:
        ValueError: Propagated from ``load_session_from_mongodb``.
    """
    with STORE_LOCK:
        # Check if already loaded in memory
        if session_id in SESSION_STORES:
            logger.info(f"Session {session_id[:8]} already in memory")
            return SESSION_STORES[session_id]
        load_lock = _SESSION_LOAD_LOCKS.setdefault(session_id, threading.Lock())

    # Lock order is always load_lock -> STORE_LOCK, never the reverse.
    with load_lock:
        try:
            # Another thread may have finished loading while we were waiting.
            with STORE_LOCK:
                if session_id in SESSION_STORES:
                    logger.info(f"Session {session_id[:8]} already in memory")
                    return SESSION_STORES[session_id]

            # Not in memory - load from MongoDB
            logger.info(f"Session {session_id[:8]} not in memory, loading from MongoDB...")
            session_store = load_session_from_mongodb(session_id)

            with STORE_LOCK:
                SESSION_STORES[session_id] = session_store
            logger.info(f"Session {session_id[:8]} loaded and cached in memory")
            return session_store
        finally:
            # Drop the per-session lock so failed or finished loads do not
            # accumulate entries; waiters already hold their own reference.
            with STORE_LOCK:
                if _SESSION_LOAD_LOCKS.get(session_id) is load_lock:
                    del _SESSION_LOAD_LOCKS[session_id]
async def save_chat_message_safely(session_id: str, role: str, message: str):
    """Persist one chat message without ever raising to the caller.

    Does nothing when the database is not connected. The insert runs in a
    worker thread; any failure is logged and swallowed.

    Args:
        session_id: Session the message belongs to.
        role: Author of the message (callers in this file pass "user" or
            "assistant").
        message: Message text.
    """
    if DB is None:
        return

    try:
        chat_doc = {
            "session_id": session_id,
            "role": role,
            "message": message,
            "created_at": datetime.utcnow(),
        }
        await asyncio.to_thread(DB.chats.insert_one, chat_doc)
    except Exception as exc:
        logger.error(f"Failed to save chat message for session {session_id}: {exc}")
def get_chat_history_safely(session_id: str, limit: int = 50) -> List[Dict[str, Any]]:
    """Fetch the most recent chat messages of a session, oldest first.

    Args:
        session_id: Session whose messages are requested.
        limit: Maximum number of messages to return.

    Returns:
        Up to ``limit`` message documents in chronological order; an empty
        list when the database is not connected or the query fails.
    """
    if DB is None:
        return []

    try:
        newest_first = (
            DB.chats.find({"session_id": session_id})
            .sort("created_at", -1)
            .limit(limit)
        )
        messages = list(newest_first)
        # The query returns newest first; flip to chronological order.
        messages.reverse()
        return messages
    except Exception as exc:
        logger.error(f"Failed to get chat history for session {session_id}: {exc}")
        return []
# --- Application Lifespan ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application startup before the yield and cleanup after it.

    Startup records the start time, connects to MongoDB and initialises the
    shared RAG models. Neither step aborts startup on failure; each helper
    reports its own errors and the health endpoint reflects the outcome.

    Shutdown closes the MongoDB client. The cleanup sits in ``finally`` so it
    also runs when the lifespan is exited through an exception.
    """
    APP_STATE["startup_time"] = datetime.utcnow()
    logger.info("Starting RAG Chat Service...")

    # Initialize MongoDB and models
    connect_mongodb()
    init_rag_models()

    logger.info("✓ Service ready")
    try:
        yield
    finally:
        # Cleanup on shutdown
        logger.info("Shutting down...")
        # Compare with None explicitly, matching the DB checks in this file.
        if MONGO_CLIENT is not None:
            MONGO_CLIENT.close()
        logger.info("✓ Shutdown complete")
# --- FastAPI App ---
# The ASGI application; startup and shutdown are handled by `lifespan`.
app = FastAPI(
    title="Session-based RAG Chat Service",
    description="RAG system with MongoDB session persistence",
    version="4.0.0",
    lifespan=lifespan
)
# CORS: every origin, method and header is allowed.
# NOTE(review): wildcard origins are combined with allow_credentials=True -
# confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
    """Return static identification of the service."""
    service_info = {
        "service": "Session-based RAG Chat Service",
        "version": "4.0.0",
        "description": "Embeddings stored in MongoDB, lazy-loaded on demand",
    }
    return service_info
@app.get("/health", response_model=HealthResponse)
async def health_check():
    """Report service health.

    The status is "healthy" only when the RAG models are initialised and
    MongoDB is connected; otherwise it is "degraded".
    """
    uptime = (datetime.utcnow() - APP_STATE["startup_time"]).total_seconds()

    # Take both counts under the lock so they describe the same snapshot.
    with STORE_LOCK:
        loaded_count = len(SESSION_STORES)
        indexed_count = len(
            [store for store in SESSION_STORES.values() if store.get("indexed", False)]
        )

    mongo_ok = APP_STATE["mongodb_connected"]
    overall = "healthy" if (RAG_MODELS_INITIALIZED and mongo_ok) else "degraded"

    recorded_errors = APP_STATE["errors"]
    latest_error = recorded_errors[-1] if recorded_errors else None

    return HealthResponse(
        status=overall,
        mongodb_connected=mongo_ok,
        rag_models_initialized=RAG_MODELS_INITIALIZED,
        faiss_available=FAISS_AVAILABLE,
        active_sessions=loaded_count,
        memory_usage={
            "loaded_sessions": loaded_count,
            "indexed_sessions": indexed_count,
        },
        uptime_seconds=uptime,
        last_error=latest_error,
    )
# Strong references to in-flight background tasks. The event loop keeps only
# weak references to tasks, so a task nobody else references can be
# garbage-collected before it finishes.
_BACKGROUND_TASKS: set = set()


async def _persist_exchange(session_id: str, user_message: str, answer: str) -> None:
    """Store the user message and then the assistant answer, in that order."""
    # Sequential on purpose: history is sorted by created_at and MongoDB keeps
    # datetimes at millisecond precision, so two saves started at the same
    # instant could tie and come back in either order.
    await save_chat_message_safely(session_id, "user", user_message)
    await save_chat_message_safely(session_id, "assistant", answer)


@app.post("/chat/{session_id}", response_model=ChatResponse)
async def chat_with_document(session_id: str, request: ChatRequest):
    """
    Main chat endpoint:
    1. Load session from MongoDB if not in memory (lazy loading)
    2. Process query using RAG pipeline
    3. Save chat messages to MongoDB (in the background, after answering)
    4. Return answer with sources

    Raises:
        HTTPException: 404 when the session cannot be found or loaded; 500
            when the RAG pipeline reports an error or anything else fails.
    """
    session_logger = create_session_logger(session_id)
    start_time = time.time()

    try:
        session_logger.info(f"Chat request: {request.message[:100]}...")

        # Get or load session (lazy loading from MongoDB)
        try:
            session_store = await asyncio.to_thread(get_or_load_session, session_id)
            session_rag = session_store["session_rag"]
        except Exception as load_error:
            session_logger.error(f"Failed to load session: {load_error}")
            raise HTTPException(
                status_code=404,
                detail=f"Session not found or failed to load: {str(load_error)}"
            )

        # Process query using RAG pipeline (blocking, so run in a worker thread)
        session_logger.info("Processing query with RAG...")
        result = await asyncio.to_thread(
            session_rag.query_documents,
            request.message,
            top_k=5
        )

        if 'error' in result:
            session_logger.error(f"Query error: {result['error']}")
            raise HTTPException(status_code=500, detail=result['error'])

        APP_STATE["total_queries"] += 1
        answer = result.get('answer', 'Unable to generate an answer.')

        # Save both messages in the background; keep a reference to the task
        # until it completes so it cannot be collected mid-flight.
        save_task = asyncio.create_task(
            _persist_exchange(session_id, request.message, answer)
        )
        _BACKGROUND_TASKS.add(save_task)
        save_task.add_done_callback(_BACKGROUND_TASKS.discard)

        processing_time = time.time() - start_time
        session_logger.info(f"✓ Query processed in {processing_time:.2f}s")

        return ChatResponse(
            success=True,
            answer=answer,
            sources=result.get('sources', []),
            processing_time=processing_time,
            session_id=session_id,
            query_analysis=result.get('query_analysis'),
            confidence=result.get('confidence')
        )
    except HTTPException:
        # Already carries the intended status code; pass through untouched.
        raise
    except Exception as e:
        session_logger.error(f"Chat processing failed: {e}", exc_info=True)
        APP_STATE["errors"].append(f"Chat error: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"Chat processing error: {str(e)}"
        )
@app.get("/history/{session_id}")
async def get_session_history(session_id: str):
    """Get chat history for a session from MongoDB.

    Returns:
        A dict with the session ID, the messages in chronological order and
        their count.

    Raises:
        HTTPException: 503 when the database is not connected.
    """
    if DB is None:
        raise HTTPException(status_code=503, detail="Database not connected")

    history = await asyncio.to_thread(get_chat_history_safely, session_id)

    # Convert ObjectId to string for JSON serialization, as the session info
    # endpoint does; the raw ObjectId cannot be encoded in the response.
    for entry in history:
        if "_id" in entry:
            entry["_id"] = str(entry["_id"])

    return {
        "session_id": session_id,
        "chat_history": history,
        "count": len(history)
    }
@app.get("/session/{session_id}/info")
async def get_session_info(session_id: str):
    """Return a session's stored metadata and whether it is cached in memory.

    Raises:
        HTTPException: 503 when the database is not connected, 404 when no
            session document matches.
    """
    if DB is None:
        raise HTTPException(status_code=503, detail="Database not connected")

    lookup = {"session_id": session_id}
    document = await asyncio.to_thread(DB.sessions.find_one, lookup)
    if not document:
        raise HTTPException(status_code=404, detail="Session not found")

    # ObjectId is not JSON-serializable; expose it as a string.
    document['_id'] = str(document['_id'])

    with STORE_LOCK:
        cached = session_id in SESSION_STORES

    return {
        "session_id": session_id,
        "metadata": document,
        "in_memory": cached,
    }
def _list_sessions_sync() -> Dict[str, Any]:
    """Run the blocking MongoDB queries behind the session listing endpoint."""
    # The 20 most recently created sessions, without the ObjectId.
    sessions = list(DB.sessions.find({}, {
        "session_id": 1,
        "filename": 1,
        "status": 1,
        "created_at": 1,
        "_id": 0
    }).sort("created_at", -1).limit(20))

    # Get total counts
    total_sessions = DB.sessions.count_documents({})
    total_chunks = DB.chunks.count_documents({})

    return {
        "total_sessions": total_sessions,
        "total_chunks": total_chunks,
        "recent_sessions": sessions,
        "session_ids_only": [s["session_id"] for s in sessions]
    }


@app.get("/debug/sessions/list")
async def list_all_sessions():
    """List all sessions in MongoDB to see what session IDs actually exist.

    Returns collection totals plus the 20 most recent sessions, or an
    ``{"error": ...}`` dict when the database is unavailable or a query fails.
    """
    if DB is None:
        return {"error": "Database not connected"}

    try:
        # PyMongo calls block; run them in a worker thread like the other
        # endpoints do so the event loop stays responsive.
        return await asyncio.to_thread(_list_sessions_sync)
    except Exception as e:
        return {"error": f"Failed to list sessions: {str(e)}"}
def _search_sessions_sync(partial_id: str) -> Dict[str, Any]:
    """Run the blocking substring search behind the session search endpoint."""
    # The value comes straight from the URL: escape it so it is matched as
    # literal text rather than interpreted as a regular expression.
    pattern = re.escape(partial_id)
    sessions = list(DB.sessions.find({
        "session_id": {"$regex": pattern, "$options": "i"}
    }, {
        "session_id": 1,
        "filename": 1,
        "status": 1,
        "created_at": 1,
        "_id": 0
    }).limit(10))

    return {
        "search_term": partial_id,
        "matches": sessions,
        "match_count": len(sessions)
    }


@app.get("/debug/sessions/search/{partial_id}")
async def search_sessions_by_partial_id(partial_id: str):
    """Search for sessions whose ID contains the partial ID (case-insensitive).

    Returns up to 10 matches, or an ``{"error": ...}`` dict when the database
    is unavailable or the query fails.
    """
    if DB is None:
        return {"error": "Database not connected"}

    try:
        # Blocking PyMongo work runs in a worker thread.
        return await asyncio.to_thread(_search_sessions_sync, partial_id)
    except Exception as e:
        return {"error": f"Search failed: {str(e)}"}
def _chunks_for_session_sync(session_id: str) -> Dict[str, Any]:
    """Run the blocking chunk lookups behind the chunk debug endpoint."""
    # Check exact match
    chunks_exact = DB.chunks.count_documents({"session_id": session_id})

    # Check partial matches (in case of ID truncation) on the first 8
    # characters. Slice first, then escape, so the URL-supplied text is
    # matched literally and no escape sequence is cut in half.
    prefix_pattern = re.escape(session_id[:8])
    chunks_partial = list(DB.chunks.find({
        "session_id": {"$regex": prefix_pattern, "$options": "i"}
    }, {
        "session_id": 1,
        "chunk_id": 1,
        "_id": 0
    }).limit(5))

    # Check if any chunks exist at all
    sample_chunks = list(DB.chunks.find({}, {
        "session_id": 1,
        "_id": 0
    }).limit(5))

    return {
        "searched_session_id": session_id,
        "chunks_exact_match": chunks_exact,
        "chunks_partial_matches": chunks_partial,
        "sample_chunk_session_ids": [c.get("session_id") for c in sample_chunks]
    }


@app.get("/debug/chunks/by-session/{session_id}")
async def debug_chunks_for_session(session_id: str):
    """Check if chunks exist for a session ID (maybe the chunks use a different ID format).

    Returns exact and partial match information plus a small sample of chunk
    session IDs, or an ``{"error": ...}`` dict when the database is
    unavailable or a query fails.
    """
    if DB is None:
        return {"error": "Database not connected"}

    try:
        # Blocking PyMongo work runs in a worker thread.
        return await asyncio.to_thread(_chunks_for_session_sync, session_id)
    except Exception as e:
        return {"error": f"Chunk search failed: {str(e)}"}
def _frontend_session_debug_sync() -> Dict[str, Any]:
    """Run the blocking sample queries behind the frontend debug endpoint."""
    # Get sample of how session IDs are actually stored
    sessions_sample = list(DB.sessions.find({}, {
        "session_id": 1,
        "_id": 0
    }).limit(5))

    chunks_sample = list(DB.chunks.find({}, {
        "session_id": 1,
        "_id": 0
    }).limit(5))

    # Group the sampled session IDs by their length.
    session_id_lengths: Dict[int, List[str]] = {}
    for session in sessions_sample:
        sid = session.get("session_id", "")
        session_id_lengths.setdefault(len(sid), []).append(sid)

    # Hard-coded sample ID the frontend was observed requesting.
    frontend_session_id = "5ca64618-04fb-48c3-bb15-9d06eb720033"

    return {
        "sessions_collection": {
            "sample_session_ids": [s.get("session_id") for s in sessions_sample],
            "session_id_lengths": session_id_lengths
        },
        "chunks_collection": {
            "sample_session_ids": [c.get("session_id") for c in chunks_sample]
        },
        "analysis": {
            "frontend_looking_for": frontend_session_id,
            "frontend_id_length": len(frontend_session_id),
            "frontend_id_format": "UUID with dashes"
        }
    }


@app.get("/debug/frontend-session")
async def debug_frontend_session_issue():
    """General debug info to help identify frontend/backend session ID mismatch.

    Returns sampled session IDs from both collections, or an
    ``{"error": ...}`` dict when the database is unavailable or a query fails.
    """
    if DB is None:
        return {"error": "Database not connected"}

    try:
        # Blocking PyMongo work runs in a worker thread.
        return await asyncio.to_thread(_frontend_session_debug_sync)
    except Exception as e:
        return {"error": f"Debug failed: {str(e)}"}
def _session_status_sync(session_id: str) -> Dict[str, Any]:
    """Run the blocking lookups behind the session status debug endpoint."""
    # Check session document
    session_doc = DB.sessions.find_one({"session_id": session_id})

    # Check chunks
    chunks_count = DB.chunks.count_documents({"session_id": session_id})

    # If exact match fails, try partial matching
    partial_sessions = []
    partial_chunks = 0

    if not session_doc:
        # Try searching with first 8 characters (common truncation). Slicing
        # a shorter ID simply yields the whole ID. The text comes from the
        # URL, so escape it to match literally.
        prefix_filter = {
            "session_id": {"$regex": f"^{re.escape(session_id[:8])}", "$options": "i"}
        }
        partial_sessions = list(DB.sessions.find(
            prefix_filter,
            {"session_id": 1, "filename": 1, "_id": 0}
        ).limit(3))
        partial_chunks = DB.chunks.count_documents(prefix_filter)

    # Sample chunks for debugging
    sample_chunks = list(DB.chunks.find(
        {"session_id": session_id},
        {"chunk_id": 1, "content": 1, "embedding": 1, "_id": 0}
    ).limit(2))

    # Check if chunks have embeddings
    chunks_with_embeddings = DB.chunks.count_documents({
        "session_id": session_id,
        "embedding": {"$exists": True, "$ne": None}
    })

    return {
        "searched_session_id": session_id,
        "session_id_length": len(session_id),
        "exact_match": {
            "session_exists": session_doc is not None,
            "session_status": session_doc.get("status") if session_doc else None,
            "session_filename": session_doc.get("filename") if session_doc else None,
            "chunks_count": chunks_count,
            "chunks_with_embeddings": chunks_with_embeddings,
        },
        "partial_matches": {
            "sessions_found": partial_sessions,
            "chunks_count": partial_chunks
        },
        "sample_chunks": [
            {
                "chunk_id": chunk.get("chunk_id"),
                # A stored null content counts as length 0 instead of raising.
                "content_length": len(chunk.get("content") or ""),
                "has_embedding": chunk.get("embedding") is not None
            }
            for chunk in sample_chunks
        ],
        "in_memory_cache": session_id in SESSION_STORES,
        "suggestions": [
            "Check if session was created with different upload service",
            "Verify frontend is using correct session ID from upload response",
            "Check if session creation completed successfully"
        ]
    }


@app.get("/debug/session/{session_id}")
async def debug_session_status(session_id: str):
    """Enhanced debug endpoint to check session status in MongoDB.

    Reports whether the session and its chunks exist, falls back to a prefix
    search when the exact ID is missing, and returns an ``{"error": ...}``
    dict when the database is unavailable or a query fails.
    """
    if DB is None:
        return {"error": "Database not connected"}

    try:
        # Blocking PyMongo work runs in a worker thread.
        return await asyncio.to_thread(_session_status_sync, session_id)
    except Exception as e:
        return {"error": f"Debug failed: {str(e)}"}
@app.delete("/session/{session_id}/cache")
async def clear_session_cache(session_id: str):
    """Evict a session from the in-memory cache; MongoDB data is untouched.

    Returns:
        A dict whose ``success`` flag tells whether the session was cached.
    """
    with STORE_LOCK:
        if session_id not in SESSION_STORES:
            return {
                "success": False,
                "message": "Session not found in memory cache",
            }

        evicted = SESSION_STORES.pop(session_id)
        rag_instance = evicted.get("session_rag")
        # Let the RAG object release its resources when it offers a hook.
        if hasattr(rag_instance, 'cleanup'):
            rag_instance.cleanup()

        logger.info(f"Session {session_id[:8]} removed from memory cache")
        return {
            "success": True,
            "message": "Session removed from memory cache",
            "note": "Data remains in MongoDB",
        }
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    # PORT may be overridden through the environment; 7860 is the default.
    listen_port = int(os.getenv("PORT", 7860))
    logger.info(f"Starting server on http://0.0.0.0:{listen_port}")
    uvicorn.run("app:app", host="0.0.0.0", port=listen_port, reload=True)