kn29 committed on
Commit df9660d · verified · 1 Parent(s): 05dbe82

Update app.py

Files changed (1): app.py (+642 -284)
app.py CHANGED
@@ -1,206 +1,313 @@
  from fastapi import FastAPI, HTTPException
  from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
  import pymongo
  import os
  import numpy as np
  from datetime import datetime, timedelta
  import logging
  from typing import Dict, Any, Optional, List
- import base64
- import json
  import threading
  import time
  from collections import defaultdict
- import faiss

- # Import our simplified advanced RAG system
- import rag

- # Configure logging
  logging.basicConfig(
      level=logging.INFO,
-     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
  )
  logger = logging.getLogger(__name__)

- # Initialize FastAPI app
- app = FastAPI(title="Advanced RAG Chat Service", version="1.0.0")
-
- # Add CORS middleware
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],  # Configure this properly in production
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # Global variables
  MONGO_CLIENT = None
  DB = None
  RAG_INITIALIZED = False

- # In-memory session stores
- # Format: {session_id: {"chunks": [...], "faiss_index": faiss.Index, "indexed": bool, "metadata": {...}}}
- SESSION_STORES = {}
- STORE_LOCK = threading.RLock()
- CLEANUP_INTERVAL = 3600  # 1 hour cleanup interval
- STORE_TTL = 30 * 60  # 24 hours TTL for in-memory stores

- # Request/Response models
  class ChatRequest(BaseModel):
-     message: str

  class ChatResponse(BaseModel):
      success: bool
      answer: str
-     sources: List[Dict[str, Any]]
-     chat_history: List[Dict[str, Any]]
      processing_time: float
      session_id: str
      query_analysis: Optional[Dict[str, Any]] = None
      confidence: Optional[float] = None

  class InitRequest(BaseModel):
-     pass

  class InitResponse(BaseModel):
      success: bool
      session_id: str
      message: str
-     chunk_count: int
-     title: str
      document_info: Optional[Dict[str, Any]] = None

  class HealthResponse(BaseModel):
      status: str
      mongodb_connected: bool
      rag_initialized: bool
      active_sessions: int
      memory_usage: Dict[str, Any]

  def create_session_logger(session_id: str):
      """Create a logger with session context"""
-     return logging.LoggerAdapter(logger, {'session_id': session_id})

  def connect_mongodb():
-     """Initialize MongoDB connection"""
      global MONGO_CLIENT, DB
      try:
          mongodb_url = os.getenv("MONGODB_URL", "mongodb://localhost:27017/")
-         MONGO_CLIENT = pymongo.MongoClient(mongodb_url)
-         DB = MONGO_CLIENT["legal_rag_system"]

          # Test connection
-         DB.command("ping")

-         # Create indexes for chats collection
-         logger.info("Creating MongoDB indexes for chats...")
-         DB.chats.create_index("session_id")
-         DB.chats.create_index("created_at", expireAfterSeconds=24*60*60)  # 24 hour TTL
-         DB.chats.create_index([("session_id", 1), ("created_at", 1)])  # Compound index

-         logger.info("MongoDB connected successfully")
          return True
      except Exception as e:
          logger.error(f"MongoDB connection failed: {e}")
          return False

  def initialize_rag():
-     """Initialize RAG system"""
      global RAG_INITIALIZED
      try:
          model_id = os.getenv("EMBEDDING_MODEL_ID", "sentence-transformers/all-MiniLM-L6-v2")
          groq_api_key = os.getenv("GROQ_API_KEY")

-         logger.info(f"Initializing RAG system with model: {model_id}")
-         rag.initialize_models(model_id, groq_api_key)

          RAG_INITIALIZED = True
          logger.info("RAG system initialized successfully")
          return True
      except Exception as e:
          logger.error(f"RAG initialization failed: {e}")
          return False

- def decode_embedding_from_storage(embedding_list: List[float]) -> np.ndarray:
-     """Convert embedding from MongoDB list back to numpy array"""
      try:
-         return np.array(embedding_list, dtype=np.float32)
      except Exception as e:
          logger.error(f"Failed to decode embedding: {e}")
          return np.array([])

  def load_session_from_mongodb(session_id: str) -> Dict[str, Any]:
-     """Load session data from MongoDB with precomputed embeddings"""
      session_logger = create_session_logger(session_id)

      try:
-         # Get session metadata
          session_doc = DB.sessions.find_one({"session_id": session_id})
          if not session_doc:
-             raise ValueError(f"Session {session_id} not found")

-         if session_doc.get("status") != "completed":
-             raise ValueError(f"Session {session_id} not completed yet (status: {session_doc.get('status')})")

-         session_logger.info("Loading session chunks with precomputed embeddings from MongoDB")

-         # Get all chunks for this session with embeddings
          chunks_cursor = DB.chunks.find({"session_id": session_id}).sort("created_at", 1)
          chunks_list = list(chunks_cursor)

          if not chunks_list:
              raise ValueError(f"No chunks found for session {session_id}")

-         session_logger.info(f"Found {len(chunks_list)} chunks with embeddings")

-         # Convert MongoDB chunks to format needed by RAG system
          processed_chunks = []
          embeddings_matrix = []

          for i, chunk_doc in enumerate(chunks_list):
-             # Decode the precomputed embedding
-             embedding_list = chunk_doc.get('embedding', [])
-             if not embedding_list:
-                 session_logger.warning(f"Chunk {chunk_doc.get('chunk_id', i)} missing embedding")
-                 continue
-
-             embedding = decode_embedding_from_storage(embedding_list)
-             if embedding.size == 0:
-                 session_logger.warning(f"Failed to decode embedding for chunk {chunk_doc.get('chunk_id', i)}")
                  continue
-
-             # Format chunk for RAG system
-             processed_chunk = {
-                 'id': chunk_doc.get('chunk_id', f'chunk_{i}'),
-                 'text': chunk_doc['text'],
-                 'title': chunk_doc.get('title', session_doc.get('title', 'Document')),
-                 'section_type': chunk_doc.get('section_type', 'content'),
-                 'importance_score': chunk_doc.get('importance_score', 1.0),
-                 'entities': chunk_doc.get('entities', []),
-                 'embedding': embedding  # Precomputed embedding as numpy array
-             }
-
-             processed_chunks.append(processed_chunk)
-             embeddings_matrix.append(embedding)

          if not processed_chunks:
-             raise ValueError(f"No valid chunks with embeddings found for session {session_id}")

-         # Stack embeddings for FAISS index
          embeddings_matrix = np.vstack(embeddings_matrix).astype('float32')

          session_store = {
              "chunks": processed_chunks,
              "embeddings_matrix": embeddings_matrix,
-             "faiss_index": None,  # Will be built in indexing step
              "indexed": False,
              "metadata": {
                  "session_id": session_id,
-                 "title": session_doc.get("title", "Document"),
                  "chunk_count": len(processed_chunks),
                  "loaded_at": datetime.utcnow(),
                  "document_info": {
                      "filename": session_doc.get("filename", "Unknown"),
@@ -212,17 +319,21 @@ def load_session_from_mongodb(session_id: str) -> Dict[str, Any]:
              }
          }

-         session_logger.info(f"Loaded {len(processed_chunks)} chunks with precomputed embeddings")
          return session_store

      except Exception as e:
-         session_logger.error(f"Failed to load session from MongoDB: {e}")
          raise

- def build_faiss_index_from_embeddings(session_id: str) -> Dict[str, Any]:
-     """Build FAISS index from precomputed embeddings"""
      session_logger = create_session_logger(session_id)

      with STORE_LOCK:
          if session_id not in SESSION_STORES:
              raise ValueError(f"Session {session_id} not loaded")
@@ -236,60 +347,76 @@ def build_faiss_index_from_embeddings(session_id: str) -> Dict[str, Any]:
          embeddings_matrix = store["embeddings_matrix"]

      try:
-         session_logger.info(f"Building FAISS index from {len(chunks)} precomputed embeddings...")

-         # Create FAISS index (Inner Product for normalized embeddings)
          dimension = embeddings_matrix.shape[1]
          faiss_index = faiss.IndexFlatIP(dimension)
-
-         # Add embeddings to FAISS index
          faiss_index.add(embeddings_matrix)

-         # Set global RAG data for this session
-         rag.CHUNKS_DATA = chunks
-         rag.DENSE_INDEX = faiss_index
-
-         # Build other indices (BM25, concept graph, etc.) using precomputed chunks
-         session_logger.info("Building additional retrieval indices...")
-
-         # BM25 index for sparse retrieval
-         tokenized_corpus = [chunk['text'].lower().split() for chunk in chunks]
-         rag.BM25_INDEX = rag.BM25Okapi(tokenized_corpus)
-
-         # ColBERT-style token index
-         rag.TOKEN_TO_CHUNKS = defaultdict(set)
-         for i, chunk in enumerate(chunks):
-             tokens = chunk['text'].lower().split()
-             for token in tokens:
-                 rag.TOKEN_TO_CHUNKS[token].add(i)
-
-         # Legal concept graph
-         import networkx as nx
-         rag.CONCEPT_GRAPH = nx.Graph()
-         for i, chunk in enumerate(chunks):
-             rag.CONCEPT_GRAPH.add_node(i, text=chunk['text'][:200], importance=chunk['importance_score'])

-             # Add edges between chunks with shared entities
-             for j, other_chunk in enumerate(chunks[i+1:], i+1):
-                 shared_entities = set(e['text'] for e in chunk['entities']) & \
-                                   set(e['text'] for e in other_chunk['entities'])
-                 if shared_entities:
-                     rag.CONCEPT_GRAPH.add_edge(i, j, weight=len(shared_entities))
-
-         # Mark as indexed and store FAISS index
          with STORE_LOCK:
              SESSION_STORES[session_id]["faiss_index"] = faiss_index
              SESSION_STORES[session_id]["indexed"] = True

-         session_logger.info(f"FAISS index built successfully from precomputed embeddings: {len(chunks)} chunks indexed")
          return SESSION_STORES[session_id]["metadata"]

      except Exception as e:
-         session_logger.error(f"Failed to build FAISS index from embeddings: {e}")
          raise

- def save_chat_message(session_id: str, role: str, message: str):
-     """Save chat message to MongoDB"""
      try:
          chat_doc = {
              "session_id": session_id,
@@ -301,8 +428,11 @@ def save_chat_message(session_id: str, role: str, message: str):
      except Exception as e:
          logger.error(f"Failed to save chat message for session {session_id}: {e}")

- def get_chat_history(session_id: str, limit: int = 50) -> List[Dict[str, Any]]:
-     """Get chat history for a session"""
      try:
          chats_cursor = DB.chats.find(
              {"session_id": session_id}
@@ -322,14 +452,8 @@ def get_chat_history(session_id: str, limit: int = 50) -> List[Dict[str, Any]]:
          logger.error(f"Failed to get chat history for session {session_id}: {e}")
          return []

- import asyncio
- from contextlib import asynccontextmanager
-
- # Global cleanup task
- cleanup_task = None
-
- def cleanup_old_stores():
-     """Background cleanup of old in-memory stores - single run"""
      try:
          current_time = datetime.utcnow()
          expired_sessions = []
@@ -337,63 +461,119 @@ def cleanup_old_stores():
          with STORE_LOCK:
              for session_id, store in SESSION_STORES.items():
                  loaded_at = store["metadata"]["loaded_at"]
-                 if (current_time - loaded_at).total_seconds() > STORE_TTL:
                      expired_sessions.append(session_id)

              for session_id in expired_sessions:
-                 # Clean up FAISS index and other resources
-                 if SESSION_STORES[session_id].get("faiss_index"):
-                     del SESSION_STORES[session_id]["faiss_index"]
-                 del SESSION_STORES[session_id]
-                 logger.info(f"Cleaned up expired store for session: {session_id}")

          if expired_sessions:
-             logger.info(f"Cleaned up {len(expired_sessions)} expired session stores")

      except Exception as e:
-         logger.error(f"Cleanup error: {e}")

  async def periodic_cleanup():
-     """Async periodic cleanup task"""
-     global cleanup_task
      try:
          while True:
-             cleanup_old_stores()
              await asyncio.sleep(CLEANUP_INTERVAL)
      except asyncio.CancelledError:
-         logger.info("Cleanup task cancelled")
          raise
      except Exception as e:
-         logger.error(f"Periodic cleanup error: {e}")

  @asynccontextmanager
  async def lifespan(app: FastAPI):
-     """Application lifespan manager"""
      global cleanup_task

      # Startup
-     logger.info("Starting up Advanced RAG Chat Service...")

-     # Connect to MongoDB
      if not connect_mongodb():
-         logger.error("Failed to connect to MongoDB")
-         raise Exception("MongoDB connection failed")

-     # Initialize RAG system
-     if not initialize_rag():
-         logger.error("Failed to initialize RAG system")
-         raise Exception("RAG initialization failed")

-     # Start background cleanup task
-     cleanup_task = asyncio.create_task(periodic_cleanup())
-     logger.info("Background cleanup task started")

-     logger.info("Startup completed successfully")

      yield

      # Shutdown
-     logger.info("Shutting down Advanced RAG Chat Service...")

      if cleanup_task:
          cleanup_task.cancel()
@@ -407,153 +587,237 @@ async def lifespan(app: FastAPI):

      logger.info("Shutdown completed")

- # Replace the FastAPI app initialization
  app = FastAPI(
-     title="Advanced RAG Chat Service",
-     version="1.0.0",
      lifespan=lifespan
  )

  @app.get("/health", response_model=HealthResponse)
  async def health_check():
-     """Health check endpoint"""
      try:
-         # Check MongoDB connection
          mongodb_connected = False
-         active_sessions = 0
-
-         if DB is not None:
              try:
                  DB.command("ping")
                  mongodb_connected = True
-                 # Count sessions with recent chats
-                 one_hour_ago = datetime.utcnow() - timedelta(hours=1)
-                 active_sessions = len(DB.chats.distinct("session_id", {"created_at": {"$gte": one_hour_ago}}))
              except:
                  pass

-         # Memory usage info
          with STORE_LOCK:
              memory_sessions = len(SESSION_STORES)
              indexed_sessions = sum(1 for store in SESSION_STORES.values() if store["indexed"])

          return HealthResponse(
-             status="healthy" if mongodb_connected and RAG_INITIALIZED else "unhealthy",
              mongodb_connected=mongodb_connected,
              rag_initialized=RAG_INITIALIZED,
-             active_sessions=active_sessions,
              memory_usage={
                  "loaded_sessions": memory_sessions,
                  "indexed_sessions": indexed_sessions,
-                 "store_ttl_hours": STORE_TTL / 3600
-             }
          )
      except Exception as e:
          logger.error(f"Health check failed: {e}")
          return HealthResponse(
              status="unhealthy",
              mongodb_connected=False,
              rag_initialized=False,
              active_sessions=0,
-             memory_usage={}
          )

  @app.post("/init/{session_id}", response_model=InitResponse)
  async def initialize_session(session_id: str, request: InitRequest):
-     """Initialize RAG context for a session using precomputed embeddings"""
      session_logger = create_session_logger(session_id)

-     if DB is None:
-         raise HTTPException(status_code=503, detail="Database not connected")
-
-     if not RAG_INITIALIZED:
-         raise HTTPException(status_code=503, detail="RAG system not initialized")
-
-     # Check if already loaded and indexed
-     with STORE_LOCK:
-         if session_id in SESSION_STORES and SESSION_STORES[session_id]["indexed"]:
-             store = SESSION_STORES[session_id]
-             metadata = store["metadata"]
-             session_logger.info("Session already initialized and indexed with precomputed embeddings")
-             return InitResponse(
-                 success=True,
-                 session_id=session_id,
-                 message="Session already initialized with precomputed embeddings",
-                 chunk_count=metadata["chunk_count"],
-                 title=metadata["title"],
-                 document_info=metadata["document_info"]
-             )
-
      try:
-         session_logger.info("Initializing session with precomputed embeddings from MongoDB")

-         # Load session data with precomputed embeddings from MongoDB
          session_store = load_session_from_mongodb(session_id)

          # Store in memory
          with STORE_LOCK:
              SESSION_STORES[session_id] = session_store

-         # Build FAISS index from precomputed embeddings (no re-embedding!)
-         metadata = build_faiss_index_from_embeddings(session_id)

-         session_logger.info(f"Session initialized with precomputed embeddings: {metadata['chunk_count']} chunks indexed")

          return InitResponse(
              success=True,
              session_id=session_id,
-             message=f"Session initialized with precomputed embeddings: {metadata['chunk_count']} chunks ready for advanced RAG",
              chunk_count=metadata["chunk_count"],
              title=metadata["title"],
              document_info=metadata["document_info"]
          )

      except ValueError as e:
-         session_logger.error(f"Session initialization failed: {e}")
-         raise HTTPException(status_code=404, detail=str(e))
      except Exception as e:
          session_logger.error(f"Session initialization error: {e}")
-         raise HTTPException(status_code=500, detail=f"Failed to initialize session: {str(e)}")

  @app.post("/chat/{session_id}", response_model=ChatResponse)
  async def chat_with_document(session_id: str, request: ChatRequest):
-     """Handle chat query with advanced RAG using precomputed embeddings"""
      session_logger = create_session_logger(session_id)
      start_time = time.time()

-     if DB is None:
-         raise HTTPException(status_code=503, detail="Database not connected")
-
-     if not RAG_INITIALIZED:
-         raise HTTPException(status_code=503, detail="RAG system not initialized")
-
-     # Validate request
-     if not request.message.strip():
-         raise HTTPException(status_code=400, detail="Empty message provided")
-
      try:
-         session_logger.info(f"Processing advanced RAG query: {request.message[:100]}...")

-         # Check if session is initialized and indexed
          with STORE_LOCK:
              if session_id not in SESSION_STORES:
                  raise HTTPException(
-                     status_code=400,
-                     detail=f"Session {session_id} not initialized. Call /init/{session_id} first."
                  )

              if not SESSION_STORES[session_id]["indexed"]:
                  raise HTTPException(
                      status_code=400,
-                     detail=f"Session {session_id} not indexed. Call /init/{session_id} first."
                  )

-         # Query using advanced RAG system (now using precomputed embeddings)
-         result = rag.query_documents(request.message, top_k=5)

-         if 'error' in result:
              raise HTTPException(status_code=500, detail=result['error'])

          answer = result.get('answer', 'Unable to generate answer.')
@@ -561,28 +825,31 @@ async def chat_with_document(session_id: str, request: ChatRequest):
          query_analysis = result.get('query_analysis', {})
          confidence = result.get('confidence', 0.0)

-         # Save chat messages to MongoDB for persistence
-         save_chat_message(session_id, "user", request.message)
-         save_chat_message(session_id, "assistant", answer)

-         # Get updated chat history
-         chat_history = get_chat_history(session_id)

          processing_time = time.time() - start_time
-         session_logger.info(f"Advanced RAG query processed in {processing_time:.2f}s with confidence {confidence:.1f}% using precomputed embeddings")
-
-         # Prepare sources for response
-         formatted_sources = [
-             {
-                 "chunk_id": source.get("chunk_id", ""),
-                 "title": source.get("title", ""),
-                 "section": source.get("section", ""),
-                 "relevance_score": source.get("relevance_score", 0.0),
-                 "text_preview": source.get("excerpt", "")[:300] + "..." if len(source.get("excerpt", "")) > 300 else source.get("excerpt", ""),
-                 "entities": source.get("entities", [])
-             }
-             for source in sources
-         ]

          return ChatResponse(
              success=True,
@@ -598,19 +865,30 @@ async def chat_with_document(session_id: str, request: ChatRequest):
      except HTTPException:
          raise
      except Exception as e:
-         session_logger.error(f"Advanced RAG chat processing failed: {e}")
-         raise HTTPException(status_code=500, detail=f"Chat processing failed: {str(e)}")

  @app.get("/history/{session_id}")
  async def get_session_history(session_id: str):
      """Get chat history for a session"""
      session_logger = create_session_logger(session_id)

-     if DB is None:
          raise HTTPException(status_code=503, detail="Database not connected")

      try:
-         chat_history = get_chat_history(session_id, limit=100)

          session_logger.info(f"Retrieved {len(chat_history)} chat messages")
@@ -631,15 +909,31 @@ async def cleanup_session(session_id: str):
      session_logger = create_session_logger(session_id)

      try:
-         # Remove from memory
          with STORE_LOCK:
              if session_id in SESSION_STORES:
                  # Clean up FAISS index
-                 if SESSION_STORES[session_id].get("faiss_index"):
-                     del SESSION_STORES[session_id]["faiss_index"]
                  del SESSION_STORES[session_id]
                  session_logger.info("Session removed from memory")

          return {
              "success": True,
              "message": f"Session {session_id} cleaned up successfully"
@@ -651,65 +945,128 @@ async def cleanup_session(session_id: str):

  @app.get("/sessions/active")
  async def get_active_sessions():
-     """Get information about active sessions in memory"""
      try:
          with STORE_LOCK:
              active_sessions = []
              for session_id, store in SESSION_STORES.items():
                  metadata = store["metadata"]
                  active_sessions.append({
                      "session_id": session_id,
                      "title": metadata["title"],
                      "chunk_count": metadata["chunk_count"],
                      "indexed": store["indexed"],
-                     "loaded_at": metadata["loaded_at"].isoformat(),
-                     "age_minutes": (datetime.utcnow() - metadata["loaded_at"]).total_seconds() / 60,
-                     "using_precomputed_embeddings": True
                  })

          return {
              "success": True,
              "active_sessions": active_sessions,
-             "total_sessions": len(active_sessions)
          }

      except Exception as e:
          logger.error(f"Failed to get active sessions: {e}")
          raise HTTPException(status_code=500, detail=f"Failed to get active sessions: {str(e)}")

  @app.get("/rag/status")
  async def get_rag_status():
-     """Get advanced RAG system status"""
      try:
          return {
              "success": True,
              "rag_initialized": RAG_INITIALIZED,
              "optimization": {
-                 "using_precomputed_embeddings": True,
-                 "no_reembedding": True,
                  "persistent_faiss_index": True,
-                 "mongodb_persistence": True
              },
              "features": {
                  "multi_stage_retrieval": True,
-                 "dense_retrieval": "FAISS + Precomputed Legal-BERT Embeddings",
-                 "sparse_retrieval": "BM25",
-                 "entity_based_retrieval": "Legal NER + SpaCy",
-                 "graph_based_retrieval": "Legal Concept Graph",
                  "query_analysis": "Legal Intent Classification",
                  "answer_generation": "Groq LLM with IRAC Method"
              },
-             "active_techniques": [
-                 "Dense Embedding Search (FAISS with Precomputed Embeddings)",
-                 "BM25 Sparse Retrieval",
-                 "ColBERT Token Matching",
-                 "Legal Entity Matching",
-                 "Concept Graph Expansion",
-                 "HyDE Query Expansion",
-                 "Multi-Query Retrieval",
-                 "Legal Section Classification",
-                 "Importance-based Ranking"
-             ]
          }

      except Exception as e:
@@ -719,4 +1076,5 @@ async def get_rag_status():
  if __name__ == "__main__":
      import uvicorn
      port = int(os.getenv("PORT", 7861))
      uvicorn.run(app, host="0.0.0.0", port=port)
  from fastapi import FastAPI, HTTPException
  from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel, Field
  import pymongo
  import os
  import numpy as np
  from datetime import datetime, timedelta
  import logging
+ import traceback
  from typing import Dict, Any, Optional, List
+ import asyncio
  import threading
  import time
  from collections import defaultdict
+ from contextlib import asynccontextmanager
+ import sys

+ try:
+     import faiss
+     FAISS_AVAILABLE = True
+ except ImportError:
+     FAISS_AVAILABLE = False
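+ # Optional dependency: guarding the import lets the service boot (and answer
+ # /health) on hosts where faiss is not installed; the FAISS-dependent code
+ # paths below all gate on FAISS_AVAILABLE first.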
+ # Configure comprehensive logging
  logging.basicConfig(
      level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - [%(funcName)s:%(lineno)d] - %(message)s',
+     handlers=[
+         logging.StreamHandler(sys.stdout),
+         logging.FileHandler('rag_app.log', mode='a')
+     ]
  )
  logger = logging.getLogger(__name__)
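  # Note: logging.FileHandler opens rag_app.log eagerly when basicConfig runs,
  # so the process needs write access to its working directory; on a read-only
  # filesystem this raises at import time, before the app starts.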
+ # Global state
  MONGO_CLIENT = None
  DB = None
  RAG_INITIALIZED = False
+ RAG_MODULE = None
+ APP_STATE = {
+     "startup_time": None,
+     "mongodb_connected": False,
+     "rag_ready": False,
+     "active_sessions": 0,
+     "total_queries": 0,
+     "errors": []
+ }

+ # Configuration - session memory management
+ CLEANUP_INTERVAL = 1800  # run cleanup every 30 minutes (1800 seconds)
+ STORE_TTL = 1800  # sessions expire 30 minutes after load (1800 seconds,
+                   # measured from loaded_at; POST /sessions/{id}/extend resets it)

+ # You can adjust these values:
+ # STORE_TTL = 900   # 15 minutes
+ # STORE_TTL = 3600  # 1 hour
+ # STORE_TTL = 7200  # 2 hours
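+ # APP_STATE["errors"] above is append-only; for long uptimes a bounded buffer
+ # (e.g., collections.deque(maxlen=50)) would keep its memory use flat.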
+
+ # Request/Response models with validation
  class ChatRequest(BaseModel):
+     message: str = Field(..., min_length=1, max_length=5000, description="User's query message")

  class ChatResponse(BaseModel):
      success: bool
      answer: str
+     sources: List[Dict[str, Any]] = Field(default_factory=list)
+     chat_history: List[Dict[str, Any]] = Field(default_factory=list)
      processing_time: float
      session_id: str
      query_analysis: Optional[Dict[str, Any]] = None
      confidence: Optional[float] = None
+     error_details: Optional[str] = None

  class InitRequest(BaseModel):
+     force_reload: bool = Field(default=False, description="Force reload session even if already loaded")

  class InitResponse(BaseModel):
      success: bool
      session_id: str
      message: str
+     chunk_count: int = Field(default=0)
+     title: str = Field(default="Unknown Document")
      document_info: Optional[Dict[str, Any]] = None
+     error_details: Optional[str] = None

  class HealthResponse(BaseModel):
      status: str
      mongodb_connected: bool
      rag_initialized: bool
+     faiss_available: bool
      active_sessions: int
      memory_usage: Dict[str, Any]
+     uptime_seconds: float
+     last_error: Optional[str] = None
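+ # With min_length=1, FastAPI now rejects an empty message with a 422
+ # validation error, replacing the old manual 400 check; a whitespace-only
+ # message (" ") still passes, since the old .strip() test was dropped.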

  def create_session_logger(session_id: str):
      """Create a logger with session context"""
+     return logging.LoggerAdapter(logger, {'session_id': session_id[:8]})

+ def safe_import_rag():
+     """Safely import RAG module with error handling"""
+     global RAG_MODULE
+     try:
+         import rag
+         RAG_MODULE = rag
+         logger.info("RAG module imported successfully")
+         return True
+     except ImportError as e:
+         logger.error(f"Failed to import RAG module: {e}")
+         logger.error("Make sure rag.py is in the same directory and all dependencies are installed")
+         return False
+     except Exception as e:
+         logger.error(f"Unexpected error importing RAG module: {e}")
+         logger.error(traceback.format_exc())
+         return False
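+ # Caveat for create_session_logger: the LoggerAdapter attaches session_id to
+ # each record, but it is only rendered if the log format string references
+ # %(session_id)s, which the format configured above does not.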

  def connect_mongodb():
+     """Initialize MongoDB connection with comprehensive error handling"""
      global MONGO_CLIENT, DB

      try:
          mongodb_url = os.getenv("MONGODB_URL", "mongodb://localhost:27017/")
+         if not mongodb_url or mongodb_url == "mongodb://localhost:27017/":
+             logger.warning("Using default MongoDB URL - set MONGODB_URL environment variable for production")

+         logger.info(f"Connecting to MongoDB: {mongodb_url[:20]}...")
+         MONGO_CLIENT = pymongo.MongoClient(
+             mongodb_url,
+             serverSelectionTimeoutMS=10000,  # 10 second timeout
+             connectTimeoutMS=10000,
+             socketTimeoutMS=10000
+         )
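+         # serverSelectionTimeoutMS bounds how long the ping below blocks while
+         # looking for a reachable server, so an unreachable database now fails
+         # fast with ServerSelectionTimeoutError instead of hanging startup.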

          # Test connection
+         MONGO_CLIENT.admin.command('ping')
+         DB = MONGO_CLIENT["legal_rag_system"]

+         logger.info("Creating MongoDB indexes...")
+         # Create indexes with error handling
+         try:
+             DB.chats.create_index("session_id", background=True)
+             DB.chats.create_index("created_at", expireAfterSeconds=24*60*60, background=True)
+             DB.chats.create_index([("session_id", 1), ("created_at", 1)], background=True)
+             logger.info("MongoDB indexes created successfully")
+         except Exception as idx_error:
+             logger.warning(f"Index creation failed (non-critical): {idx_error}")

+         APP_STATE["mongodb_connected"] = True
+         logger.info("MongoDB connected and configured successfully")
          return True

+     except pymongo.errors.ServerSelectionTimeoutError:
+         logger.error("MongoDB connection timeout - check if MongoDB is running and accessible")
+         return False
+     except pymongo.errors.ConfigurationError as e:
+         logger.error(f"MongoDB configuration error: {e}")
+         return False
      except Exception as e:
          logger.error(f"MongoDB connection failed: {e}")
+         logger.error(traceback.format_exc())
          return False

  def initialize_rag():
+     """Initialize RAG system with comprehensive error handling"""
      global RAG_INITIALIZED

+     if not RAG_MODULE:
+         logger.error("RAG module not available - cannot initialize")
+         return False

+     if not FAISS_AVAILABLE:
+         logger.error("FAISS library not available - RAG system requires FAISS")
+         return False

      try:
          model_id = os.getenv("EMBEDDING_MODEL_ID", "sentence-transformers/all-MiniLM-L6-v2")
          groq_api_key = os.getenv("GROQ_API_KEY")

+         logger.info(f"Initializing RAG system with embedding model: {model_id}")

+         if groq_api_key:
+             logger.info("Groq API key found - full RAG capabilities available")
+         else:
+             logger.warning("No Groq API key - some RAG features may be limited")

+         # Load the models (no timeout wrapper here; a slow load delays startup)
+         RAG_MODULE.initialize_models(model_id, groq_api_key)

          RAG_INITIALIZED = True
+         APP_STATE["rag_ready"] = True
          logger.info("RAG system initialized successfully")
          return True

+     except ImportError as e:
+         logger.error(f"Missing dependencies for RAG initialization: {e}")
+         return False
      except Exception as e:
          logger.error(f"RAG initialization failed: {e}")
+         logger.error(traceback.format_exc())
+         APP_STATE["errors"].append(f"RAG init failed: {str(e)}")
          return False
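  # Assuming rag.initialize_models loads the embedding model from the Hugging
  # Face Hub (rag.py is not shown here), the first call downloads and caches
  # the weights, so cold starts can be noticeably slow.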

+ def decode_embedding_safely(embedding_list: List[float]) -> np.ndarray:
+     """Safely convert embedding from storage with validation"""
      try:
+         if not embedding_list or not isinstance(embedding_list, list):
+             raise ValueError("Invalid embedding data")

+         embedding = np.array(embedding_list, dtype=np.float32)

+         if embedding.size == 0:
+             raise ValueError("Empty embedding")

+         if np.isnan(embedding).any() or np.isinf(embedding).any():
+             raise ValueError("Embedding contains invalid values")

+         return embedding

      except Exception as e:
          logger.error(f"Failed to decode embedding: {e}")
          return np.array([])
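  # Expected round-trip, assuming the ingest pipeline stored vectors with
  # ndarray.tolist(): decode_embedding_safely(vec.tolist()) reproduces the
  # float32 vector, while anything malformed comes back as np.array([]) so the
  # caller can skip that chunk instead of aborting the whole load.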

  def load_session_from_mongodb(session_id: str) -> Dict[str, Any]:
+     """Load session with comprehensive error handling and validation"""
      session_logger = create_session_logger(session_id)

+     # PyMongo Database objects forbid truth-value testing; compare with None
+     if DB is None:
+         raise ValueError("Database not connected")

      try:
+         # Get and validate session metadata
          session_doc = DB.sessions.find_one({"session_id": session_id})
          if not session_doc:
+             raise ValueError(f"Session {session_id} not found in database")

+         session_status = session_doc.get("status")
+         if session_status != "completed":
+             raise ValueError(f"Session not ready - status: {session_status}")

+         session_logger.info(f"Loading session: {session_doc.get('filename', 'unknown')}")

+         # Load chunks with validation
          chunks_cursor = DB.chunks.find({"session_id": session_id}).sort("created_at", 1)
          chunks_list = list(chunks_cursor)

          if not chunks_list:
              raise ValueError(f"No chunks found for session {session_id}")

+         session_logger.info(f"Found {len(chunks_list)} chunks")

+         # Process chunks with validation
          processed_chunks = []
          embeddings_matrix = []
+         failed_chunks = 0

          for i, chunk_doc in enumerate(chunks_list):
+             try:
+                 # Validate required fields
+                 if 'text' not in chunk_doc or not chunk_doc['text'].strip():
+                     session_logger.warning(f"Chunk {i} missing or empty text")
+                     failed_chunks += 1
+                     continue

+                 # Decode embedding
+                 embedding_list = chunk_doc.get('embedding', [])
+                 embedding = decode_embedding_safely(embedding_list)

+                 if embedding.size == 0:
+                     session_logger.warning(f"Chunk {i} has invalid embedding")
+                     failed_chunks += 1
+                     continue

+                 # Create processed chunk
+                 processed_chunk = {
+                     'id': chunk_doc.get('chunk_id', f'chunk_{i}'),
+                     'text': chunk_doc['text'],
+                     'title': chunk_doc.get('title', session_doc.get('filename', 'Document')),
+                     'section_type': chunk_doc.get('section_type', 'content'),
+                     'importance_score': float(chunk_doc.get('importance_score', 1.0)),
+                     'entities': chunk_doc.get('entities', []),
+                     'embedding': embedding
+                 }

+                 processed_chunks.append(processed_chunk)
+                 embeddings_matrix.append(embedding)

+             except Exception as chunk_error:
+                 session_logger.error(f"Failed to process chunk {i}: {chunk_error}")
+                 failed_chunks += 1
                  continue

          if not processed_chunks:
+             raise ValueError(f"No valid chunks could be loaded (failed: {failed_chunks})")

+         if failed_chunks > 0:
+             session_logger.warning(f"Failed to load {failed_chunks} chunks, continuing with {len(processed_chunks)}")

+         # Create embeddings matrix
          embeddings_matrix = np.vstack(embeddings_matrix).astype('float32')

+         # Prepare session store
          session_store = {
              "chunks": processed_chunks,
              "embeddings_matrix": embeddings_matrix,
+             "faiss_index": None,
              "indexed": False,
              "metadata": {
                  "session_id": session_id,
+                 "title": session_doc.get("filename", "Document"),
                  "chunk_count": len(processed_chunks),
+                 "failed_chunks": failed_chunks,
                  "loaded_at": datetime.utcnow(),
                  "document_info": {
                      "filename": session_doc.get("filename", "Unknown"),
              }
          }

+         session_logger.info(f"Session loaded successfully: {len(processed_chunks)} chunks")
          return session_store

      except Exception as e:
+         session_logger.error(f"Failed to load session: {e}")
+         session_logger.error(traceback.format_exc())
          raise
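  # The .sort("created_at", 1) above assumes the ingest pipeline (outside this
  # file) stamps created_at on every chunk document; chunks missing the field
  # sort first under MongoDB's ascending null ordering.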

+ def build_faiss_index_safely(session_id: str) -> Dict[str, Any]:
+     """Build FAISS index with error handling"""
      session_logger = create_session_logger(session_id)

+     if not FAISS_AVAILABLE:
+         raise ValueError("FAISS library not available")

      with STORE_LOCK:
          if session_id not in SESSION_STORES:
              raise ValueError(f"Session {session_id} not loaded")

          embeddings_matrix = store["embeddings_matrix"]

      try:
+         session_logger.info(f"Building FAISS index for {len(chunks)} chunks...")

+         # Validate embeddings matrix
+         if embeddings_matrix.shape[0] != len(chunks):
+             raise ValueError("Embeddings matrix size mismatch with chunks")

+         # Create FAISS index
          dimension = embeddings_matrix.shape[1]
          faiss_index = faiss.IndexFlatIP(dimension)
          faiss_index.add(embeddings_matrix)
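          # IndexFlatIP ranks by raw inner product, which equals cosine
          # similarity only if the stored embeddings are L2-normalized. That is
          # assumed of the ingest pipeline here; if it does not hold, a prior
          # faiss.normalize_L2(embeddings_matrix) call would be needed.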

+         # Initialize RAG system components
+         if RAG_MODULE:
+             RAG_MODULE.CHUNKS_DATA = chunks
+             RAG_MODULE.DENSE_INDEX = faiss_index

+             # Build additional indices
+             session_logger.info("Building additional retrieval indices...")

+             try:
+                 # BM25 index
+                 tokenized_corpus = [chunk['text'].lower().split() for chunk in chunks]
+                 RAG_MODULE.BM25_INDEX = RAG_MODULE.BM25Okapi(tokenized_corpus)

+                 # Token index
+                 RAG_MODULE.TOKEN_TO_CHUNKS = defaultdict(set)
+                 for i, chunk in enumerate(chunks):
+                     tokens = chunk['text'].lower().split()
+                     for token in tokens:
+                         RAG_MODULE.TOKEN_TO_CHUNKS[token].add(i)

+                 # Concept graph
+                 import networkx as nx
+                 RAG_MODULE.CONCEPT_GRAPH = nx.Graph()
+                 for i, chunk in enumerate(chunks):
+                     RAG_MODULE.CONCEPT_GRAPH.add_node(
+                         i,
+                         text=chunk['text'][:200],
+                         importance=chunk['importance_score']
+                     )

+                     # Add edges for shared entities
+                     for j, other_chunk in enumerate(chunks[i+1:], i+1):
+                         shared_entities = set(e.get('text', '') for e in chunk['entities']) & \
+                                           set(e.get('text', '') for e in other_chunk['entities'])
+                         if shared_entities:
+                             RAG_MODULE.CONCEPT_GRAPH.add_edge(i, j, weight=len(shared_entities))

+             except Exception as index_error:
+                 session_logger.warning(f"Failed to build some retrieval indices: {index_error}")
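+         # Two caveats here: the shared-entity edge scan is O(n^2) in the chunk
+         # count, acceptable for one document's chunks but worth revisiting for
+         # large sessions; and CHUNKS_DATA / DENSE_INDEX / BM25_INDEX /
+         # TOKEN_TO_CHUNKS / CONCEPT_GRAPH are module-level globals on rag, so
+         # initializing a second session overwrites the first - the per-session
+         # isolation advertised by /rag/status below does not hold for these.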

+         # Mark as indexed
          with STORE_LOCK:
              SESSION_STORES[session_id]["faiss_index"] = faiss_index
              SESSION_STORES[session_id]["indexed"] = True

+         session_logger.info("FAISS index built successfully")
          return SESSION_STORES[session_id]["metadata"]

      except Exception as e:
+         session_logger.error(f"Failed to build FAISS index: {e}")
+         session_logger.error(traceback.format_exc())
          raise

+ def save_chat_message_safely(session_id: str, role: str, message: str):
+     """Save chat message with error handling"""
+     if DB is None:
+         logger.warning("Database not available - chat message not saved")
+         return

      try:
          chat_doc = {
              "session_id": session_id,
      except Exception as e:
          logger.error(f"Failed to save chat message for session {session_id}: {e}")

+ def get_chat_history_safely(session_id: str, limit: int = 50) -> List[Dict[str, Any]]:
+     """Get chat history with error handling"""
+     if DB is None:
+         return []

      try:
          chats_cursor = DB.chats.find(
              {"session_id": session_id}
          logger.error(f"Failed to get chat history for session {session_id}: {e}")
          return []

+ def cleanup_expired_sessions():
+     """Clean up only expired chat sessions from memory, keep server running"""
      try:
          current_time = datetime.utcnow()
          expired_sessions = []

          with STORE_LOCK:
              for session_id, store in SESSION_STORES.items():
                  loaded_at = store["metadata"]["loaded_at"]
+                 age_seconds = (current_time - loaded_at).total_seconds()

+                 # Only expire sessions older than TTL (30 minutes)
+                 if age_seconds > STORE_TTL:
                      expired_sessions.append(session_id)

+             # Clean up expired sessions one by one
              for session_id in expired_sessions:
+                 try:
+                     store = SESSION_STORES[session_id]

+                     # Clean up session-specific RAG instance
+                     if "rag_instance" in store:
+                         store["rag_instance"].cleanup()

+                     # Clean up FAISS index
+                     if store.get("faiss_index"):
+                         del store["faiss_index"]

+                     # Remove session from memory
+                     del SESSION_STORES[session_id]

+                     age_minutes = (current_time - store["metadata"]["loaded_at"]).total_seconds() / 60
+                     logger.info(f"Expired session {session_id[:8]} removed from memory (age: {age_minutes:.1f} minutes)")

+                 except Exception as cleanup_error:
+                     logger.error(f"Error cleaning up session {session_id[:8]}: {cleanup_error}")

+             # Update active session count
+             APP_STATE["active_sessions"] = len(SESSION_STORES)

          if expired_sessions:
+             logger.info(f"Memory cleanup completed: {len(expired_sessions)} expired sessions removed, {len(SESSION_STORES)} sessions still active")
+         else:
+             logger.debug(f"No expired sessions found. {len(SESSION_STORES)} sessions still active in memory")

      except Exception as e:
+         logger.error(f"Session cleanup error: {e}")
+         logger.error(traceback.format_exc())
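+ # Expiry is only enforced when the sweep runs, so a session actually lives
+ # somewhere between STORE_TTL and STORE_TTL + CLEANUP_INTERVAL seconds after
+ # it was loaded.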

  async def periodic_cleanup():
+     """Periodic cleanup of expired sessions - keeps server running"""
+     cleanup_count = 0
+     while True:
+         cleanup_count += 1
+         logger.debug(f"Running session cleanup cycle #{cleanup_count}")
          try:
+             cleanup_expired_sessions()
+             # Sleep for cleanup interval (30 minutes)
              await asyncio.sleep(CLEANUP_INTERVAL)
          except asyncio.CancelledError:
+             logger.info(f"Session cleanup task cancelled after {cleanup_count} cycles")
              raise
          except Exception as e:
+             logger.error(f"Periodic cleanup error in cycle #{cleanup_count}: {e}")
+             logger.error(traceback.format_exc())
+             # Don't break the loop - wait 1 minute before retrying; the
+             # while sits outside the try so an error cannot end cleanup
+             await asyncio.sleep(60)

+ # Global cleanup task
+ cleanup_task = None

  @asynccontextmanager
  async def lifespan(app: FastAPI):
+     """Application lifespan with comprehensive error handling"""
      global cleanup_task

      # Startup
+     logger.info("Starting Advanced RAG Chat Service...")
+     APP_STATE["startup_time"] = datetime.utcnow()

+     startup_success = True

+     # Check FAISS availability
+     if not FAISS_AVAILABLE:
+         logger.error("FAISS library not available - this is required for RAG functionality")
+         startup_success = False

+     # Import RAG module
+     if not safe_import_rag():
+         logger.error("RAG module import failed")
+         startup_success = False

+     # Connect to MongoDB (non-critical failure)
      if not connect_mongodb():
+         logger.error("MongoDB connection failed - continuing with limited functionality")

+     # Initialize RAG system (non-critical failure for basic health checks)
+     if RAG_MODULE and FAISS_AVAILABLE:
+         if not initialize_rag():
+             logger.error("RAG initialization failed - RAG features disabled")

+     # Start cleanup task if MongoDB is available
+     if APP_STATE["mongodb_connected"]:
+         try:
+             cleanup_task = asyncio.create_task(periodic_cleanup())
+             logger.info("Background cleanup task started")
+         except Exception as e:
+             logger.error(f"Failed to start cleanup task: {e}")

+     if startup_success:
+         logger.info("Startup completed successfully")
+     else:
+         logger.warning("Startup completed with errors - some features may be disabled")

      yield

      # Shutdown
+     logger.info("Shutting down...")

      if cleanup_task:
          cleanup_task.cancel()

      logger.info("Shutdown completed")

+ # Initialize FastAPI app
  app = FastAPI(
+     title="Advanced RAG Chat Service",
+     description="Robust RAG-based chat service with comprehensive error handling",
+     version="2.0.0",
      lifespan=lifespan
  )

+ # CORS configuration
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
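+ # The CORS spec disallows a literal wildcard origin on credentialed requests,
+ # so allow_origins=["*"] together with allow_credentials=True is fragile from
+ # browser clients that send cookies; production should list explicit origins
+ # (the pre-rewrite code carried the same TODO).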

+ # Root endpoint
+ @app.get("/")
+ async def root():
+     """Service information endpoint"""
+     uptime = (datetime.utcnow() - APP_STATE["startup_time"]).total_seconds() if APP_STATE["startup_time"] else 0

+     return {
+         "service": "Advanced RAG Chat Service",
+         "version": "2.0.0",
+         "status": "running",
+         "uptime_seconds": uptime,
+         "components": {
+             "mongodb": APP_STATE["mongodb_connected"],
+             "rag_system": APP_STATE["rag_ready"],
+             "faiss": FAISS_AVAILABLE
+         },
+         "active_sessions": len(SESSION_STORES),
+         "total_queries": APP_STATE["total_queries"],
+         "endpoints": {
+             "health": "GET /health",
+             "init": "POST /init/{session_id}",
+             "chat": "POST /chat/{session_id}",
+             "history": "GET /history/{session_id}",
+             "cleanup": "DELETE /session/{session_id}",
+             "status": "GET /sessions/active"
+         }
+     }
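+ # Quick smoke test (assuming the default port set at the bottom of the file):
+ #   curl http://localhost:7861/        -> service and endpoint overview
+ #   curl http://localhost:7861/health  -> per-component status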

  @app.get("/health", response_model=HealthResponse)
  async def health_check():
+     """Comprehensive health check"""
      try:
+         # Test MongoDB connection
          mongodb_connected = False
+         if DB is not None:
              try:
                  DB.command("ping")
                  mongodb_connected = True
              except Exception:
                  pass

+         # Calculate uptime
+         uptime = 0
+         if APP_STATE["startup_time"]:
+             uptime = (datetime.utcnow() - APP_STATE["startup_time"]).total_seconds()

+         # Memory usage
          with STORE_LOCK:
              memory_sessions = len(SESSION_STORES)
              indexed_sessions = sum(1 for store in SESSION_STORES.values() if store["indexed"])

+         # Overall status
+         status = "healthy"
+         if not FAISS_AVAILABLE:
+             status = "degraded"
+         elif not mongodb_connected and not RAG_INITIALIZED:
+             status = "unhealthy"

+         last_error = APP_STATE["errors"][-1] if APP_STATE["errors"] else None

          return HealthResponse(
+             status=status,
              mongodb_connected=mongodb_connected,
              rag_initialized=RAG_INITIALIZED,
+             faiss_available=FAISS_AVAILABLE,
+             active_sessions=memory_sessions,
              memory_usage={
                  "loaded_sessions": memory_sessions,
                  "indexed_sessions": indexed_sessions,
+                 "store_ttl_minutes": STORE_TTL // 60,
+                 "cleanup_interval_minutes": CLEANUP_INTERVAL // 60
+             },
+             uptime_seconds=uptime,
+             last_error=last_error
          )

      except Exception as e:
          logger.error(f"Health check failed: {e}")
          return HealthResponse(
              status="unhealthy",
              mongodb_connected=False,
              rag_initialized=False,
+             faiss_available=False,
              active_sessions=0,
+             memory_usage={},
+             uptime_seconds=0,
+             last_error=str(e)
          )

  @app.post("/init/{session_id}", response_model=InitResponse)
  async def initialize_session(session_id: str, request: InitRequest):
+     """Initialize session with comprehensive validation"""
      session_logger = create_session_logger(session_id)

      try:
+         # Validate prerequisites
+         if DB is None:
+             raise HTTPException(status_code=503, detail="Database not connected")

+         if not RAG_INITIALIZED:
+             raise HTTPException(status_code=503, detail="RAG system not initialized")

+         if not FAISS_AVAILABLE:
+             raise HTTPException(status_code=503, detail="FAISS library not available")

+         # Check if already initialized
+         with STORE_LOCK:
+             if session_id in SESSION_STORES and SESSION_STORES[session_id]["indexed"] and not request.force_reload:
+                 store = SESSION_STORES[session_id]
+                 metadata = store["metadata"]
+                 session_logger.info("Session already initialized")
+                 return InitResponse(
+                     success=True,
+                     session_id=session_id,
+                     message="Session already initialized",
+                     chunk_count=metadata["chunk_count"],
+                     title=metadata["title"],
+                     document_info=metadata["document_info"]
+                 )

+         session_logger.info("Initializing session...")

+         # Load session from MongoDB
          session_store = load_session_from_mongodb(session_id)

          # Store in memory
          with STORE_LOCK:
              SESSION_STORES[session_id] = session_store
+             APP_STATE["active_sessions"] = len(SESSION_STORES)

+         # Build FAISS index
+         metadata = build_faiss_index_safely(session_id)

+         session_logger.info(f"Session initialized: {metadata['chunk_count']} chunks ready")

          return InitResponse(
              success=True,
              session_id=session_id,
+             message=f"Session initialized successfully with {metadata['chunk_count']} chunks",
              chunk_count=metadata["chunk_count"],
              title=metadata["title"],
              document_info=metadata["document_info"]
          )

+     except HTTPException:
+         raise
      except ValueError as e:
+         session_logger.error(f"Session initialization validation error: {e}")
+         return InitResponse(
+             success=False,
+             session_id=session_id,
+             message="Session initialization failed",
+             chunk_count=0,
+             title="Error",
+             error_details=str(e)
+         )
      except Exception as e:
          session_logger.error(f"Session initialization error: {e}")
+         session_logger.error(traceback.format_exc())
+         APP_STATE["errors"].append(f"Init failed for {session_id[:8]}: {str(e)}")
+         return InitResponse(
+             success=False,
+             session_id=session_id,
+             message="Internal server error during initialization",
+             chunk_count=0,
+             title="Error",
+             error_details="Internal server error"
+         )
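+ # Contract change: validation and load failures now come back as HTTP 200
+ # with success=False and error_details set, instead of the old 404/500;
+ # clients must check the success flag rather than the status code alone.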

  @app.post("/chat/{session_id}", response_model=ChatResponse)
  async def chat_with_document(session_id: str, request: ChatRequest):
+     """Chat endpoint with comprehensive error handling"""
      session_logger = create_session_logger(session_id)
      start_time = time.time()

      try:
+         # Validate prerequisites
+         if DB is None:
+             raise HTTPException(status_code=503, detail="Database not connected")

+         if not RAG_INITIALIZED or not RAG_MODULE:
+             raise HTTPException(status_code=503, detail="RAG system not initialized")

+         # Validate session
          with STORE_LOCK:
              if session_id not in SESSION_STORES:
                  raise HTTPException(
+                     status_code=400,
+                     detail=f"Session not initialized. Call /init/{session_id} first."
                  )

              if not SESSION_STORES[session_id]["indexed"]:
                  raise HTTPException(
                      status_code=400,
+                     detail="Session not indexed properly. Try reinitializing."
                  )

+         session_logger.info(f"Processing query: {request.message[:100]}...")

+         # Query RAG system
+         try:
+             result = RAG_MODULE.query_documents(request.message, top_k=5)
+             APP_STATE["total_queries"] += 1
+         except Exception as rag_error:
+             session_logger.error(f"RAG query failed: {rag_error}")
+             result = {
+                 'error': f'RAG processing failed: {str(rag_error)}',
+                 'answer': 'I apologize, but I encountered an error while processing your question. Please try again or rephrase your query.',
+                 'sources': [],
+                 'query_analysis': {},
+                 'confidence': 0.0
+             }

+         if 'error' in result and not result.get('answer'):
              raise HTTPException(status_code=500, detail=result['error'])
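          # The fallback dict above always carries an answer, so this 500 only
          # fires when rag.query_documents itself returns an error with no
          # answer; a crashed RAG call is surfaced as a polite apology with
          # success=True.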

          answer = result.get('answer', 'Unable to generate answer.')
          query_analysis = result.get('query_analysis', {})
          confidence = result.get('confidence', 0.0)

+         # Save chat messages
+         save_chat_message_safely(session_id, "user", request.message)
+         save_chat_message_safely(session_id, "assistant", answer)

+         # Get chat history
+         chat_history = get_chat_history_safely(session_id)

          processing_time = time.time() - start_time
+         session_logger.info(f"Query processed in {processing_time:.2f}s, confidence: {confidence:.1f}%")

+         # Format sources
+         formatted_sources = []
+         for source in sources:
+             try:
+                 formatted_source = {
+                     "chunk_id": source.get("chunk_id", ""),
+                     "title": source.get("title", ""),
+                     "section": source.get("section", ""),
+                     "relevance_score": float(source.get("relevance_score", 0.0)),
+                     "text_preview": source.get("excerpt", "")[:300] + ("..." if len(source.get("excerpt", "")) > 300 else ""),
+                     "entities": source.get("entities", [])
+                 }
+                 formatted_sources.append(formatted_source)
+             except Exception as source_error:
+                 session_logger.warning(f"Failed to format source: {source_error}")

          return ChatResponse(
              success=True,
      except HTTPException:
          raise
      except Exception as e:
+         session_logger.error(f"Chat processing failed: {e}")
+         session_logger.error(traceback.format_exc())
+         APP_STATE["errors"].append(f"Chat failed for {session_id[:8]}: {str(e)}")

+         return ChatResponse(
+             success=False,
+             answer="I apologize, but I encountered an error while processing your question. Please try again.",
+             sources=[],
+             chat_history=get_chat_history_safely(session_id),
+             processing_time=time.time() - start_time,
+             session_id=session_id,
+             error_details="Internal server error"
+         )
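+ # Typical session flow, assuming an ingest-completed session id SESSION:
+ #   curl -X POST http://localhost:7861/init/SESSION \
+ #        -H "Content-Type: application/json" -d "{}"
+ #   curl -X POST http://localhost:7861/chat/SESSION \
+ #        -H "Content-Type: application/json" -d '{"message": "Summarize the contract"}'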

  @app.get("/history/{session_id}")
  async def get_session_history(session_id: str):
      """Get chat history for a session"""
      session_logger = create_session_logger(session_id)

+     if DB is None:
          raise HTTPException(status_code=503, detail="Database not connected")

      try:
+         chat_history = get_chat_history_safely(session_id, limit=100)

          session_logger.info(f"Retrieved {len(chat_history)} chat messages")

      session_logger = create_session_logger(session_id)

      try:
+         cleaned_up = False

          with STORE_LOCK:
              if session_id in SESSION_STORES:
+                 # Clean up session-specific RAG instance
+                 store = SESSION_STORES[session_id]
+                 if "rag_instance" in store:
+                     try:
+                         # Clean up any resources in the RAG instance
+                         store["rag_instance"].cleanup()
+                     except Exception:
+                         pass

                  # Clean up FAISS index
+                 if store.get("faiss_index"):
+                     del store["faiss_index"]

                  del SESSION_STORES[session_id]
+                 APP_STATE["active_sessions"] = len(SESSION_STORES)
+                 cleaned_up = True
                  session_logger.info("Session removed from memory")

+         if not cleaned_up:
+             session_logger.info("Session not found in memory")

          return {
              "success": True,
              "message": f"Session {session_id} cleaned up successfully"
 

  @app.get("/sessions/active")
  async def get_active_sessions():
+     """Get information about active sessions in memory with TTL info"""
      try:
+         current_time = datetime.utcnow()

          with STORE_LOCK:
              active_sessions = []
              for session_id, store in SESSION_STORES.items():
                  metadata = store["metadata"]
+                 loaded_at = metadata["loaded_at"]
+                 age_seconds = (current_time - loaded_at).total_seconds()
+                 remaining_seconds = STORE_TTL - age_seconds

                  active_sessions.append({
                      "session_id": session_id,
                      "title": metadata["title"],
                      "chunk_count": metadata["chunk_count"],
                      "indexed": store["indexed"],
+                     "has_rag_instance": "rag_instance" in store,
+                     "loaded_at": loaded_at.isoformat(),
+                     "age_minutes": age_seconds / 60,
+                     "remaining_minutes": max(0, remaining_seconds / 60),
+                     "expires_at": (loaded_at + timedelta(seconds=STORE_TTL)).isoformat(),
+                     "will_expire_soon": remaining_seconds < 300,  # less than 5 minutes
+                     "failed_chunks": metadata.get("failed_chunks", 0)
                  })

+             # Sort by remaining time (expiring soon first)
+             active_sessions.sort(key=lambda x: x["remaining_minutes"])

          return {
              "success": True,
              "active_sessions": active_sessions,
+             "total_sessions": len(active_sessions),
+             "session_ttl_minutes": STORE_TTL / 60,
+             "cleanup_interval_minutes": CLEANUP_INTERVAL / 60,
+             "next_cleanup_in_minutes": CLEANUP_INTERVAL / 60  # approximate
          }

      except Exception as e:
          logger.error(f"Failed to get active sessions: {e}")
          raise HTTPException(status_code=500, detail=f"Failed to get active sessions: {str(e)}")
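  # Sessions are returned soonest-to-expire first, so a client can watch
  # will_expire_soon and call POST /sessions/{id}/extend (below) to keep a
  # session alive.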

+ @app.post("/sessions/{session_id}/extend")
+ async def extend_session_ttl(session_id: str):
+     """Extend a session's TTL by resetting its load time (keep it alive longer)"""
+     session_logger = create_session_logger(session_id)

+     try:
+         with STORE_LOCK:
+             if session_id not in SESSION_STORES:
+                 raise HTTPException(status_code=404, detail="Session not found in memory")

+             # Reset the loaded_at timestamp to extend TTL
+             old_loaded_at = SESSION_STORES[session_id]["metadata"]["loaded_at"]
+             SESSION_STORES[session_id]["metadata"]["loaded_at"] = datetime.utcnow()

+         session_logger.info(f"Session TTL extended (was loaded at: {old_loaded_at.isoformat()})")

+         return {
+             "success": True,
+             "message": f"Session {session_id} TTL extended for another {STORE_TTL//60} minutes",
+             "new_expiry": (datetime.utcnow() + timedelta(seconds=STORE_TTL)).isoformat()
+         }

+     except HTTPException:
+         raise
+     except Exception as e:
+         session_logger.error(f"Failed to extend session TTL: {e}")
+         raise HTTPException(status_code=500, detail=f"Failed to extend session TTL: {str(e)}")

+ @app.post("/cleanup/run")
+ async def manual_cleanup():
+     """Manually trigger cleanup of expired sessions"""
+     try:
+         before_count = len(SESSION_STORES)
+         cleanup_expired_sessions()
+         after_count = len(SESSION_STORES)
+         cleaned_count = before_count - after_count

+         return {
+             "success": True,
+             "message": "Manual cleanup completed",
+             "sessions_before": before_count,
+             "sessions_after": after_count,
+             "sessions_cleaned": cleaned_count
+         }

+     except Exception as e:
+         logger.error(f"Manual cleanup failed: {e}")
+         raise HTTPException(status_code=500, detail=f"Manual cleanup failed: {str(e)}")
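+ # The before/after counts above are read without STORE_LOCK; they can race
+ # with the background sweep, so treat the reported numbers as approximate (or
+ # take STORE_LOCK around the reads to make them exact).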

  @app.get("/rag/status")
  async def get_rag_status():
+     """Get RAG system status"""
      try:
          return {
              "success": True,
              "rag_initialized": RAG_INITIALIZED,
+             "faiss_available": FAISS_AVAILABLE,
+             "concurrency": {
+                 "session_isolated_rag": True,
+                 "async_processing": True,
+                 "thread_pool_execution": True,
+                 "no_global_state_conflicts": True
+             },
              "optimization": {
+                 "precomputed_embeddings": True,
                  "persistent_faiss_index": True,
+                 "mongodb_persistence": True,
+                 "memory_cleanup": True
              },
              "features": {
                  "multi_stage_retrieval": True,
+                 "dense_retrieval": "FAISS + Session-Isolated Embeddings",
+                 "sparse_retrieval": "BM25 per Session",
+                 "entity_based_retrieval": "Legal NER + SpaCy",
+                 "graph_based_retrieval": "Legal Concept Graph per Session",
                  "query_analysis": "Legal Intent Classification",
                  "answer_generation": "Groq LLM with IRAC Method"
              },
+             "active_sessions": len(SESSION_STORES),
+             "total_queries_processed": APP_STATE["total_queries"]
          }

      except Exception as e:

  if __name__ == "__main__":
      import uvicorn
      port = int(os.getenv("PORT", 7861))
+     logger.info(f"Starting server on port {port}")
      uvicorn.run(app, host="0.0.0.0", port=port)
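  # Local run sketch (dependency list approximate - see rag.py for the full
  # set; assumes MONGODB_URL and GROQ_API_KEY are exported):
  #   pip install fastapi uvicorn pymongo numpy networkx rank_bm25 faiss-cpu
  #   python app.py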