Kartik Narang committed
Commit fc6a53f · 1 Parent(s): da63606

first commit

Files changed (3)
  1. app.py +678 -0
  2. rag.py +593 -0
  3. requirements.txt +23 -0
app.py ADDED
@@ -0,0 +1,678 @@
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ import pymongo
+ import os
+ import numpy as np
+ from datetime import datetime, timedelta
+ import logging
+ from typing import Dict, Any, Optional, List
+ import base64
+ import json
+ import threading
+ import time
+ from collections import defaultdict
+ import faiss
+
+ # Import our simplified advanced RAG system
+ import rag
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+ # Initialize FastAPI app
+ app = FastAPI(title="Advanced RAG Chat Service", version="1.0.0")
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # Configure this properly in production
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Global variables
+ MONGO_CLIENT = None
+ DB = None
+ RAG_INITIALIZED = False
+
+ # In-memory session stores
+ # Format: {session_id: {"chunks": [...], "faiss_index": faiss.Index, "indexed": bool, "metadata": {...}}}
+ SESSION_STORES = {}
+ STORE_LOCK = threading.RLock()
+ CLEANUP_INTERVAL = 3600  # 1 hour cleanup interval
+ STORE_TTL = 24 * 3600  # 24 hours TTL for in-memory stores
+
+ # Request/Response models
+ class ChatRequest(BaseModel):
+     message: str
+
+ class ChatResponse(BaseModel):
+     success: bool
+     answer: str
+     sources: List[Dict[str, Any]]
+     chat_history: List[Dict[str, Any]]
+     processing_time: float
+     session_id: str
+     query_analysis: Optional[Dict[str, Any]] = None
+     confidence: Optional[float] = None
+
+ class InitRequest(BaseModel):
+     pass
+
+ class InitResponse(BaseModel):
+     success: bool
+     session_id: str
+     message: str
+     chunk_count: int
+     title: str
+     document_info: Optional[Dict[str, Any]] = None
+
+ class HealthResponse(BaseModel):
+     status: str
+     mongodb_connected: bool
+     rag_initialized: bool
+     active_sessions: int
+     memory_usage: Dict[str, Any]
+
+ def create_session_logger(session_id: str):
+     """Create a logger with session context"""
+     return logging.LoggerAdapter(logger, {'session_id': session_id})
+
+ def connect_mongodb():
+     """Initialize MongoDB connection"""
+     global MONGO_CLIENT, DB
+     try:
+         mongodb_url = os.getenv("MONGODB_URL", "mongodb://localhost:27017/")
+         MONGO_CLIENT = pymongo.MongoClient(mongodb_url)
+         DB = MONGO_CLIENT["legal_rag_db"]
+
+         # Test connection
+         DB.command("ping")
+
+         # Create indexes for chats collection
+         logger.info("Creating MongoDB indexes for chats...")
+         DB.chats.create_index("session_id")
+         DB.chats.create_index("created_at", expireAfterSeconds=24*60*60)  # 24 hour TTL
+         DB.chats.create_index([("session_id", 1), ("created_at", 1)])  # Compound index
+
+         logger.info("MongoDB connected successfully")
+         return True
+     except Exception as e:
+         logger.error(f"MongoDB connection failed: {e}")
+         return False
+
+ def initialize_rag():
+     """Initialize RAG system"""
+     global RAG_INITIALIZED
+     try:
+         model_id = os.getenv("EMBEDDING_MODEL_ID", "sentence-transformers/all-MiniLM-L6-v2")
+         groq_api_key = os.getenv("GROQ_API_KEY")
+
+         logger.info(f"Initializing RAG system with model: {model_id}")
+         rag.initialize_models(model_id, groq_api_key)
+
+         RAG_INITIALIZED = True
+         logger.info("RAG system initialized successfully")
+         return True
+     except Exception as e:
+         logger.error(f"RAG initialization failed: {e}")
+         return False
+
+ def decode_embedding_from_storage(embedding_list: List[float]) -> np.ndarray:
+     """Convert embedding from MongoDB list back to numpy array"""
+     try:
+         return np.array(embedding_list, dtype=np.float32)
+     except Exception as e:
+         logger.error(f"Failed to decode embedding: {e}")
+         return np.array([])
+
+ def load_session_from_mongodb(session_id: str) -> Dict[str, Any]:
+     """Load session data from MongoDB with precomputed embeddings"""
+     session_logger = create_session_logger(session_id)
+
+     try:
+         # Get session metadata
+         session_doc = DB.sessions.find_one({"session_id": session_id})
+         if not session_doc:
+             raise ValueError(f"Session {session_id} not found")
+
+         if session_doc.get("status") != "completed":
+             raise ValueError(f"Session {session_id} not completed yet (status: {session_doc.get('status')})")
+
+         session_logger.info("Loading session chunks with precomputed embeddings from MongoDB")
+
+         # Get all chunks for this session with embeddings
+         chunks_cursor = DB.chunks.find({"session_id": session_id}).sort("created_at", 1)
+         chunks_list = list(chunks_cursor)
+
+         if not chunks_list:
+             raise ValueError(f"No chunks found for session {session_id}")
+
+         session_logger.info(f"Found {len(chunks_list)} chunks with embeddings")
+
+         # Convert MongoDB chunks to format needed by RAG system
+         processed_chunks = []
+         embeddings_matrix = []
+
+         for i, chunk_doc in enumerate(chunks_list):
+             # Decode the precomputed embedding
+             embedding_list = chunk_doc.get('embedding', [])
+             if not embedding_list:
+                 session_logger.warning(f"Chunk {chunk_doc.get('chunk_id', i)} missing embedding")
+                 continue
+
+             embedding = decode_embedding_from_storage(embedding_list)
+             if embedding.size == 0:
+                 session_logger.warning(f"Failed to decode embedding for chunk {chunk_doc.get('chunk_id', i)}")
+                 continue
+
+             # Format chunk for RAG system
+             processed_chunk = {
+                 'id': chunk_doc.get('chunk_id', f'chunk_{i}'),
+                 'text': chunk_doc['text'],
+                 'title': chunk_doc.get('title', session_doc.get('title', 'Document')),
+                 'section_type': chunk_doc.get('section_type', 'content'),
+                 'importance_score': chunk_doc.get('importance_score', 1.0),
+                 'entities': chunk_doc.get('entities', []),
+                 'embedding': embedding  # Precomputed embedding as numpy array
+             }
+
+             processed_chunks.append(processed_chunk)
+             embeddings_matrix.append(embedding)
+
+         if not processed_chunks:
+             raise ValueError(f"No valid chunks with embeddings found for session {session_id}")
+
+         # Stack embeddings for FAISS index
+         embeddings_matrix = np.vstack(embeddings_matrix).astype('float32')
+
+         session_store = {
+             "chunks": processed_chunks,
+             "embeddings_matrix": embeddings_matrix,
+             "faiss_index": None,  # Will be built in indexing step
+             "indexed": False,
+             "metadata": {
+                 "session_id": session_id,
+                 "title": session_doc.get("title", "Document"),
+                 "chunk_count": len(processed_chunks),
+                 "loaded_at": datetime.utcnow(),
+                 "document_info": {
+                     "filename": session_doc.get("filename", "Unknown"),
+                     "text_length": session_doc.get("text_length", 0),
+                     "word_count": session_doc.get("word_count", 0),
+                     "file_size": session_doc.get("file_size", 0),
+                     "processing_completed_at": session_doc.get("processing_completed_at")
+                 }
+             }
+         }
+
+         session_logger.info(f"Loaded {len(processed_chunks)} chunks with precomputed embeddings")
+         return session_store
+
+     except Exception as e:
+         session_logger.error(f"Failed to load session from MongoDB: {e}")
+         raise
+
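+ # For reference, a chunk document in the `chunks` collection is assumed to
+ # look roughly like this (inferred from the reads above; values illustrative):
+ # {
+ #     "session_id": "abc123",
+ #     "chunk_id": "chunk_0",
+ #     "text": "FACTS: The parties entered into a written agreement...",
+ #     "title": "Smith v. Jones",
+ #     "section_type": "facts",
+ #     "importance_score": 1.2,
+ #     "entities": [{"text": "Smith", "type": "PERSON", "importance": 1.0}],
+ #     "embedding": [0.012, -0.044, ...],  # list of floats from the ingestion side
+ #     "created_at": <datetime>
+ # }
+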
+ def build_faiss_index_from_embeddings(session_id: str) -> Dict[str, Any]:
+     """Build FAISS index from precomputed embeddings"""
+     session_logger = create_session_logger(session_id)
+
+     with STORE_LOCK:
+         if session_id not in SESSION_STORES:
+             raise ValueError(f"Session {session_id} not loaded")
+
+         store = SESSION_STORES[session_id]
+         if store["indexed"]:
+             session_logger.info("Session already indexed")
+             return store["metadata"]
+
+         chunks = store["chunks"]
+         embeddings_matrix = store["embeddings_matrix"]
+
+     try:
+         session_logger.info(f"Building FAISS index from {len(chunks)} precomputed embeddings...")
+
+         # Create FAISS index (Inner Product for normalized embeddings)
+         dimension = embeddings_matrix.shape[1]
+         faiss_index = faiss.IndexFlatIP(dimension)
+
+         # Add embeddings to FAISS index
+         faiss_index.add(embeddings_matrix)
+
+         # Set global RAG data for this session
+         rag.CHUNKS_DATA = chunks
+         rag.DENSE_INDEX = faiss_index
+
+         # Build other indices (BM25, concept graph, etc.) using precomputed chunks
+         session_logger.info("Building additional retrieval indices...")
+
+         # BM25 index for sparse retrieval
+         tokenized_corpus = [chunk['text'].lower().split() for chunk in chunks]
+         rag.BM25_INDEX = rag.BM25Okapi(tokenized_corpus)
+
+         # ColBERT-style token index
+         rag.TOKEN_TO_CHUNKS = defaultdict(set)
+         for i, chunk in enumerate(chunks):
+             tokens = chunk['text'].lower().split()
+             for token in tokens:
+                 rag.TOKEN_TO_CHUNKS[token].add(i)
+
+         # Legal concept graph
+         import networkx as nx
+         rag.CONCEPT_GRAPH = nx.Graph()
+         for i, chunk in enumerate(chunks):
+             rag.CONCEPT_GRAPH.add_node(i, text=chunk['text'][:200], importance=chunk['importance_score'])
+
+             # Add edges between chunks with shared entities
+             for j, other_chunk in enumerate(chunks[i+1:], i+1):
+                 shared_entities = set(e['text'] for e in chunk['entities']) & \
+                                   set(e['text'] for e in other_chunk['entities'])
+                 if shared_entities:
+                     rag.CONCEPT_GRAPH.add_edge(i, j, weight=len(shared_entities))
+
+         # Mark as indexed and store FAISS index
+         with STORE_LOCK:
+             SESSION_STORES[session_id]["faiss_index"] = faiss_index
+             SESSION_STORES[session_id]["indexed"] = True
+
+         session_logger.info(f"FAISS index built successfully from precomputed embeddings: {len(chunks)} chunks indexed")
+         return SESSION_STORES[session_id]["metadata"]
+
+     except Exception as e:
+         session_logger.error(f"Failed to build FAISS index from embeddings: {e}")
+         raise
+
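+ # Note on the index choice above: IndexFlatIP ranks by inner product, which
+ # equals cosine similarity when vectors are L2-normalized (rag.create_embedding
+ # normalizes its outputs). A minimal sanity-check sketch:
+ #
+ #     import faiss, numpy as np
+ #     v = np.random.rand(4, 384).astype('float32')
+ #     faiss.normalize_L2(v)                  # in-place L2 normalization
+ #     index = faiss.IndexFlatIP(v.shape[1])
+ #     index.add(v)
+ #     scores, ids = index.search(v[:1], 2)   # top hit is v[0] itself, score ~1.0
+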
+ def save_chat_message(session_id: str, role: str, message: str):
+     """Save chat message to MongoDB"""
+     try:
+         chat_doc = {
+             "session_id": session_id,
+             "role": role,
+             "message": message,
+             "created_at": datetime.utcnow()
+         }
+         DB.chats.insert_one(chat_doc)
+     except Exception as e:
+         logger.error(f"Failed to save chat message for session {session_id}: {e}")
+
+ def get_chat_history(session_id: str, limit: int = 50) -> List[Dict[str, Any]]:
+     """Get chat history for a session"""
+     try:
+         chats_cursor = DB.chats.find(
+             {"session_id": session_id}
+         ).sort("created_at", 1).limit(limit)
+
+         chat_history = []
+         for chat_doc in chats_cursor:
+             chat_history.append({
+                 "role": chat_doc["role"],
+                 "message": chat_doc["message"],
+                 "timestamp": chat_doc["created_at"].isoformat()
+             })
+
+         return chat_history
+
+     except Exception as e:
+         logger.error(f"Failed to get chat history for session {session_id}: {e}")
+         return []
+
+ def cleanup_old_stores():
+     """Background cleanup of old in-memory stores"""
+     while True:
+         try:
+             current_time = datetime.utcnow()
+             expired_sessions = []
+
+             with STORE_LOCK:
+                 for session_id, store in SESSION_STORES.items():
+                     loaded_at = store["metadata"]["loaded_at"]
+                     if (current_time - loaded_at).total_seconds() > STORE_TTL:
+                         expired_sessions.append(session_id)
+
+                 for session_id in expired_sessions:
+                     # Clean up FAISS index and other resources
+                     if SESSION_STORES[session_id].get("faiss_index"):
+                         del SESSION_STORES[session_id]["faiss_index"]
+                     del SESSION_STORES[session_id]
+                     logger.info(f"Cleaned up expired store for session: {session_id}")
+
+             if expired_sessions:
+                 logger.info(f"Cleaned up {len(expired_sessions)} expired session stores")
+
+         except Exception as e:
+             logger.error(f"Cleanup error: {e}")
+
+         time.sleep(CLEANUP_INTERVAL)
+
+ @app.on_event("startup")
+ async def startup_event():
+     """Initialize connections on startup"""
+     logger.info("Starting up Advanced RAG Chat Service...")
+
+     # Connect to MongoDB
+     if not connect_mongodb():
+         logger.error("Failed to connect to MongoDB")
+         raise Exception("MongoDB connection failed")
+
+     # Initialize RAG system
+     if not initialize_rag():
+         logger.error("Failed to initialize RAG system")
+         raise Exception("RAG initialization failed")
+
+     # Start background cleanup thread
+     cleanup_thread = threading.Thread(target=cleanup_old_stores, daemon=True)
+     cleanup_thread.start()
+     logger.info("Background cleanup thread started")
+
+     logger.info("Startup completed successfully")
+
+ @app.get("/health", response_model=HealthResponse)
+ async def health_check():
+     """Health check endpoint"""
+     try:
+         # Check MongoDB connection
+         mongodb_connected = False
+         active_sessions = 0
+
+         if DB is not None:
+             try:
+                 DB.command("ping")
+                 mongodb_connected = True
+                 # Count sessions with recent chats
+                 one_hour_ago = datetime.utcnow() - timedelta(hours=1)
+                 active_sessions = len(DB.chats.distinct("session_id", {"created_at": {"$gte": one_hour_ago}}))
+             except Exception:
+                 pass
+
+         # Memory usage info
+         with STORE_LOCK:
+             memory_sessions = len(SESSION_STORES)
+             indexed_sessions = sum(1 for store in SESSION_STORES.values() if store["indexed"])
+
+         return HealthResponse(
+             status="healthy" if mongodb_connected and RAG_INITIALIZED else "unhealthy",
+             mongodb_connected=mongodb_connected,
+             rag_initialized=RAG_INITIALIZED,
+             active_sessions=active_sessions,
+             memory_usage={
+                 "loaded_sessions": memory_sessions,
+                 "indexed_sessions": indexed_sessions,
+                 "store_ttl_hours": STORE_TTL / 3600
+             }
+         )
+     except Exception as e:
+         logger.error(f"Health check failed: {e}")
+         return HealthResponse(
+             status="unhealthy",
+             mongodb_connected=False,
+             rag_initialized=False,
+             active_sessions=0,
+             memory_usage={}
+         )
+
+ @app.post("/init/{session_id}", response_model=InitResponse)
+ async def initialize_session(session_id: str, request: InitRequest):
+     """Initialize RAG context for a session using precomputed embeddings"""
+     session_logger = create_session_logger(session_id)
+
+     if DB is None:
+         raise HTTPException(status_code=503, detail="Database not connected")
+
+     if not RAG_INITIALIZED:
+         raise HTTPException(status_code=503, detail="RAG system not initialized")
+
+     # Check if already loaded and indexed
+     with STORE_LOCK:
+         if session_id in SESSION_STORES and SESSION_STORES[session_id]["indexed"]:
+             store = SESSION_STORES[session_id]
+             metadata = store["metadata"]
+             session_logger.info("Session already initialized and indexed with precomputed embeddings")
+             return InitResponse(
+                 success=True,
+                 session_id=session_id,
+                 message="Session already initialized with precomputed embeddings",
+                 chunk_count=metadata["chunk_count"],
+                 title=metadata["title"],
+                 document_info=metadata["document_info"]
+             )
+
+     try:
+         session_logger.info("Initializing session with precomputed embeddings from MongoDB")
+
+         # Load session data with precomputed embeddings from MongoDB
+         session_store = load_session_from_mongodb(session_id)
+
+         # Store in memory
+         with STORE_LOCK:
+             SESSION_STORES[session_id] = session_store
+
+         # Build FAISS index from precomputed embeddings (no re-embedding!)
+         metadata = build_faiss_index_from_embeddings(session_id)
+
+         session_logger.info(f"Session initialized with precomputed embeddings: {metadata['chunk_count']} chunks indexed")
+
+         return InitResponse(
+             success=True,
+             session_id=session_id,
+             message=f"Session initialized with precomputed embeddings: {metadata['chunk_count']} chunks ready for advanced RAG",
+             chunk_count=metadata["chunk_count"],
+             title=metadata["title"],
+             document_info=metadata["document_info"]
+         )
+
+     except ValueError as e:
+         session_logger.error(f"Session initialization failed: {e}")
+         raise HTTPException(status_code=404, detail=str(e))
+     except Exception as e:
+         session_logger.error(f"Session initialization error: {e}")
+         raise HTTPException(status_code=500, detail=f"Failed to initialize session: {str(e)}")
+
+ @app.post("/chat/{session_id}", response_model=ChatResponse)
+ async def chat_with_document(session_id: str, request: ChatRequest):
+     """Handle chat query with advanced RAG using precomputed embeddings"""
+     session_logger = create_session_logger(session_id)
+     start_time = time.time()
+
+     if DB is None:
+         raise HTTPException(status_code=503, detail="Database not connected")
+
+     if not RAG_INITIALIZED:
+         raise HTTPException(status_code=503, detail="RAG system not initialized")
+
+     # Validate request
+     if not request.message.strip():
+         raise HTTPException(status_code=400, detail="Empty message provided")
+
+     try:
+         session_logger.info(f"Processing advanced RAG query: {request.message[:100]}...")
+
+         # Check if session is initialized and indexed
+         with STORE_LOCK:
+             if session_id not in SESSION_STORES:
+                 raise HTTPException(
+                     status_code=400,
+                     detail=f"Session {session_id} not initialized. Call /init/{session_id} first."
+                 )
+
+             if not SESSION_STORES[session_id]["indexed"]:
+                 raise HTTPException(
+                     status_code=400,
+                     detail=f"Session {session_id} not indexed. Call /init/{session_id} first."
+                 )
+
+         # Query using advanced RAG system (now using precomputed embeddings)
+         result = rag.query_documents(request.message, top_k=5)
+
+         if 'error' in result:
+             raise HTTPException(status_code=500, detail=result['error'])
+
+         answer = result.get('answer', 'Unable to generate answer.')
+         sources = result.get('sources', [])
+         query_analysis = result.get('query_analysis', {})
+         confidence = result.get('confidence', 0.0)
+
+         # Save chat messages to MongoDB for persistence
+         save_chat_message(session_id, "user", request.message)
+         save_chat_message(session_id, "assistant", answer)
+
+         # Get updated chat history
+         chat_history = get_chat_history(session_id)
+
+         processing_time = time.time() - start_time
+         session_logger.info(f"Advanced RAG query processed in {processing_time:.2f}s with confidence {confidence:.1f}% using precomputed embeddings")
+
+         # Prepare sources for response
+         formatted_sources = [
+             {
+                 "chunk_id": source.get("chunk_id", ""),
+                 "title": source.get("title", ""),
+                 "section": source.get("section", ""),
+                 "relevance_score": source.get("relevance_score", 0.0),
+                 "text_preview": source.get("excerpt", "")[:300] + "..." if len(source.get("excerpt", "")) > 300 else source.get("excerpt", ""),
+                 "entities": source.get("entities", [])
+             }
+             for source in sources
+         ]
+
+         return ChatResponse(
+             success=True,
+             answer=answer,
+             sources=formatted_sources,
+             chat_history=chat_history,
+             processing_time=processing_time,
+             session_id=session_id,
+             query_analysis=query_analysis,
+             confidence=confidence
+         )
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         session_logger.error(f"Advanced RAG chat processing failed: {e}")
+         raise HTTPException(status_code=500, detail=f"Chat processing failed: {str(e)}")
+
+ @app.get("/history/{session_id}")
+ async def get_session_history(session_id: str):
+     """Get chat history for a session"""
+     session_logger = create_session_logger(session_id)
+
+     if DB is None:
+         raise HTTPException(status_code=503, detail="Database not connected")
+
+     try:
+         chat_history = get_chat_history(session_id, limit=100)
+
+         session_logger.info(f"Retrieved {len(chat_history)} chat messages")
+
+         return {
+             "success": True,
+             "session_id": session_id,
+             "chat_history": chat_history,
+             "total_messages": len(chat_history)
+         }
+
+     except Exception as e:
+         session_logger.error(f"Failed to get chat history: {e}")
+         raise HTTPException(status_code=500, detail=f"Failed to retrieve chat history: {str(e)}")
+
+ @app.delete("/session/{session_id}")
+ async def cleanup_session(session_id: str):
+     """Clean up session from memory"""
+     session_logger = create_session_logger(session_id)
+
+     try:
+         # Remove from memory
+         with STORE_LOCK:
+             if session_id in SESSION_STORES:
+                 # Clean up FAISS index
+                 if SESSION_STORES[session_id].get("faiss_index"):
+                     del SESSION_STORES[session_id]["faiss_index"]
+                 del SESSION_STORES[session_id]
+                 session_logger.info("Session removed from memory")
+
+         return {
+             "success": True,
+             "message": f"Session {session_id} cleaned up successfully"
+         }
+
+     except Exception as e:
+         session_logger.error(f"Session cleanup failed: {e}")
+         raise HTTPException(status_code=500, detail=f"Failed to cleanup session: {str(e)}")
+
+ @app.get("/sessions/active")
+ async def get_active_sessions():
+     """Get information about active sessions in memory"""
+     try:
+         with STORE_LOCK:
+             active_sessions = []
+             for session_id, store in SESSION_STORES.items():
+                 metadata = store["metadata"]
+                 active_sessions.append({
+                     "session_id": session_id,
+                     "title": metadata["title"],
+                     "chunk_count": metadata["chunk_count"],
+                     "indexed": store["indexed"],
+                     "loaded_at": metadata["loaded_at"].isoformat(),
+                     "age_minutes": (datetime.utcnow() - metadata["loaded_at"]).total_seconds() / 60,
+                     "using_precomputed_embeddings": True
+                 })
+
+         return {
+             "success": True,
+             "active_sessions": active_sessions,
+             "total_sessions": len(active_sessions)
+         }
+
+     except Exception as e:
+         logger.error(f"Failed to get active sessions: {e}")
+         raise HTTPException(status_code=500, detail=f"Failed to get active sessions: {str(e)}")
+
+ @app.get("/rag/status")
+ async def get_rag_status():
+     """Get advanced RAG system status"""
+     try:
+         return {
+             "success": True,
+             "rag_initialized": RAG_INITIALIZED,
+             "optimization": {
+                 "using_precomputed_embeddings": True,
+                 "no_reembedding": True,
+                 "persistent_faiss_index": True,
+                 "mongodb_persistence": True
+             },
+             "features": {
+                 "multi_stage_retrieval": True,
+                 "dense_retrieval": "FAISS + Precomputed Legal-BERT Embeddings",
+                 "sparse_retrieval": "BM25",
+                 "entity_based_retrieval": "Legal NER + SpaCy",
+                 "graph_based_retrieval": "Legal Concept Graph",
+                 "query_analysis": "Legal Intent Classification",
+                 "answer_generation": "Groq LLM with IRAC Method"
+             },
+             "active_techniques": [
+                 "Dense Embedding Search (FAISS with Precomputed Embeddings)",
+                 "BM25 Sparse Retrieval",
+                 "ColBERT Token Matching",
+                 "Legal Entity Matching",
+                 "Concept Graph Expansion",
+                 "HyDE Query Expansion",
+                 "Multi-Query Retrieval",
+                 "Legal Section Classification",
+                 "Importance-based Ranking"
+             ]
+         }
+
+     except Exception as e:
+         logger.error(f"Failed to get RAG status: {e}")
+         raise HTTPException(status_code=500, detail=f"Failed to get RAG status: {str(e)}")
+
+ if __name__ == "__main__":
+     import uvicorn
+     port = int(os.getenv("PORT", 7861))
+     uvicorn.run(app, host="0.0.0.0", port=port)
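+
+ # Example interaction, assuming the service runs locally on the default port
+ # and a session "abc123" was already ingested upstream (hypothetical IDs):
+ #
+ #     curl -X POST http://localhost:7861/init/abc123 \
+ #          -H "Content-Type: application/json" -d '{}'
+ #     curl -X POST http://localhost:7861/chat/abc123 \
+ #          -H "Content-Type: application/json" \
+ #          -d '{"message": "What did the court hold on liability?"}'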
rag.py ADDED
@@ -0,0 +1,593 @@
+ import torch
+ import numpy as np
+ from transformers import AutoTokenizer, AutoModel
+ from typing import List, Dict, Any, Tuple, Optional
+ import faiss
+ import hashlib
+ from tqdm import tqdm
+ from groq import Groq
+ import re
+ import nltk
+ from sklearn.metrics.pairwise import cosine_similarity
+ import networkx as nx
+ from collections import defaultdict
+ import spacy
+ from rank_bm25 import BM25Okapi
+
+ # Global variables for models
+ MODEL = None
+ TOKENIZER = None
+ GROQ_CLIENT = None
+ NLP_MODEL = None
+ DEVICE = None
+
+ # Global indices
+ DENSE_INDEX = None
+ BM25_INDEX = None
+ CONCEPT_GRAPH = None
+ TOKEN_TO_CHUNKS = None
+ CHUNKS_DATA = []
+
+ # Legal knowledge base
+ LEGAL_CONCEPTS = {
+     'liability': ['negligence', 'strict liability', 'vicarious liability', 'product liability'],
+     'contract': ['breach', 'consideration', 'offer', 'acceptance', 'damages', 'specific performance'],
+     'criminal': ['mens rea', 'actus reus', 'intent', 'malice', 'premeditation'],
+     'procedure': ['jurisdiction', 'standing', 'statute of limitations', 'res judicata'],
+     'evidence': ['hearsay', 'relevance', 'privilege', 'burden of proof', 'admissibility'],
+     'constitutional': ['due process', 'equal protection', 'free speech', 'search and seizure']
+ }
+
+ QUERY_PATTERNS = {
+     'precedent': ['case', 'precedent', 'ruling', 'held', 'decision'],
+     'statute_interpretation': ['statute', 'section', 'interpretation', 'meaning', 'definition'],
+     'factual': ['what happened', 'facts', 'circumstances', 'events'],
+     'procedure': ['how to', 'procedure', 'process', 'filing', 'requirements']
+ }
+
+ def initialize_models(model_id: str, groq_api_key: str = None):
+     """Initialize all models and components"""
+     global MODEL, TOKENIZER, GROQ_CLIENT, NLP_MODEL, DEVICE
+
+     try:
+         nltk.download('punkt', quiet=True)
+         nltk.download('stopwords', quiet=True)
+     except Exception:
+         pass
+
+     DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     print(f"Using device: {DEVICE}")
+
+     print(f"Loading model: {model_id}")
+     TOKENIZER = AutoTokenizer.from_pretrained(model_id)
+     MODEL = AutoModel.from_pretrained(model_id).to(DEVICE)
+     MODEL.eval()
+
+     if groq_api_key:
+         GROQ_CLIENT = Groq(api_key=groq_api_key)
+
+     try:
+         NLP_MODEL = spacy.load("en_core_web_sm")
+     except OSError:
+         print("SpaCy model not found, using basic NER")
+         NLP_MODEL = None
+
+ def create_embedding(text: str) -> np.ndarray:
+     """Create dense embedding for text"""
+     inputs = TOKENIZER(text, padding=True, truncation=True,
+                        max_length=512, return_tensors='pt').to(DEVICE)
+
+     with torch.no_grad():
+         outputs = MODEL(**inputs)
+         attention_mask = inputs['attention_mask']
+         token_embeddings = outputs.last_hidden_state
+         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+         embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+     # Normalize embeddings
+     embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
+
+     return embeddings.cpu().numpy()[0]
+
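+ # create_embedding mean-pools the token vectors (masked by attention) and then
+ # L2-normalizes. Shape walkthrough for one input, with illustrative sizes:
+ #     last_hidden_state: (1, seq_len, hidden), e.g. (1, 128, 384)
+ #     masked sum over dim 1 -> (1, hidden), divided by the token count
+ #     normalize(p=2) -> unit vector, so FAISS inner product equals cosine
+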
+ def extract_legal_entities(text: str) -> List[Dict[str, Any]]:
+     """Extract legal entities from text"""
+     entities = []
+
+     if NLP_MODEL:
+         doc = NLP_MODEL(text[:5000])  # Limit for performance
+         for ent in doc.ents:
+             if ent.label_ in ['PERSON', 'ORG', 'LAW', 'GPE']:
+                 entities.append({
+                     'text': ent.text,
+                     'type': ent.label_,
+                     'importance': 1.0
+                 })
+
+     # Legal citations
+     citation_pattern = r'\b\d+\s+[A-Z][a-z]+\.?\s+\d+\b'
+     for match in re.finditer(citation_pattern, text):
+         entities.append({
+             'text': match.group(),
+             'type': 'case_citation',
+             'importance': 2.0
+         })
+
+     # Statute references
+     statute_pattern = r'§\s*\d+[\.\d]*|\bSection\s+\d+'
+     for match in re.finditer(statute_pattern, text):
+         entities.append({
+             'text': match.group(),
+             'type': 'statute',
+             'importance': 1.5
+         })
+
+     return entities
+
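+ # Illustrative matches for the regexes above (hypothetical strings):
+ #     citation_pattern matches reporter-style cites such as "123 Cal. 456"
+ #     statute_pattern matches "§ 1983" or "Section 230"
+ # The citation pattern will not catch every format; "410 U.S. 113" fails
+ # because "U.S." contains no lowercase letters.
+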
+ def analyze_query(query: str) -> Dict[str, Any]:
+     """Analyze query to understand intent"""
+     query_lower = query.lower()
+
+     # Classify query type
+     query_type = 'general'
+     for qtype, patterns in QUERY_PATTERNS.items():
+         if any(pattern in query_lower for pattern in patterns):
+             query_type = qtype
+             break
+
+     # Extract entities
+     entities = extract_legal_entities(query)
+
+     # Extract key concepts
+     key_concepts = []
+     for concept_category, concepts in LEGAL_CONCEPTS.items():
+         for concept in concepts:
+             if concept in query_lower:
+                 key_concepts.append(concept)
+
+     # Generate expanded queries
+     expanded_queries = [query]
+
+     # Concept expansion
+     if key_concepts:
+         expanded_queries.append(f"{query} {' '.join(key_concepts[:3])}")
+
+     # Type-based expansion
+     if query_type == 'precedent':
+         expanded_queries.append(f"legal precedent case law {query}")
+     elif query_type == 'statute_interpretation':
+         expanded_queries.append(f"statutory interpretation meaning {query}")
+
+     # HyDE - Hypothetical document generation
+     if GROQ_CLIENT:
+         hyde_doc = generate_hypothetical_document(query)
+         if hyde_doc:
+             expanded_queries.append(hyde_doc)
+
+     return {
+         'original_query': query,
+         'query_type': query_type,
+         'entities': entities,
+         'key_concepts': key_concepts,
+         'expanded_queries': expanded_queries[:4]  # Limit to 4 queries
+     }
+
+ def generate_hypothetical_document(query: str) -> Optional[str]:
+     """Generate hypothetical answer document (HyDE technique)"""
+     if not GROQ_CLIENT:
+         return None
+
+     try:
+         prompt = f"""Generate a brief hypothetical legal document excerpt that would answer this question: {query}
+
+ Write it as if it's from an actual legal case or statute. Be specific and use legal language.
+ Keep it under 100 words."""
+
+         response = GROQ_CLIENT.chat.completions.create(
+             messages=[
+                 {"role": "system", "content": "You are a legal expert generating hypothetical legal text."},
+                 {"role": "user", "content": prompt}
+             ],
+             model="llama-3.1-8b-instant",
+             temperature=0.3,
+             max_tokens=150
+         )
+
+         return response.choices[0].message.content
+     except Exception:
+         return None
+
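+ # HyDE rationale: a hypothetical *answer* usually sits closer to the relevant
+ # passages in embedding space than the short question itself, so embedding the
+ # generated excerpt as an extra expanded query (see analyze_query above) tends
+ # to improve dense recall, at the cost of one extra LLM call.
+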
+ def chunk_text_hierarchical(text: str, title: str = "") -> List[Dict[str, Any]]:
+     """Create hierarchical chunks with legal structure awareness"""
+     chunks = []
+
+     # Clean text
+     text = re.sub(r'\s+', ' ', text)
+
+     # Identify legal sections
+     section_patterns = [
+         (r'(?i)\bFACTS?\b[:\s]', 'facts'),
+         (r'(?i)\bHOLDING\b[:\s]', 'holding'),
+         (r'(?i)\bREASONING\b[:\s]', 'reasoning'),
+         (r'(?i)\bDISSENT\b[:\s]', 'dissent'),
+         (r'(?i)\bCONCLUSION\b[:\s]', 'conclusion')
+     ]
+
+     sections = []
+     for pattern, section_type in section_patterns:
+         matches = list(re.finditer(pattern, text))
+         for match in matches:
+             sections.append((match.start(), section_type))
+
+     sections.sort(key=lambda x: x[0])
+
+     # Split into sentences (nltk is imported at module level)
+     try:
+         sentences = nltk.sent_tokenize(text)
+     except Exception:
+         sentences = text.split('. ')
+
+     # Create chunks
+     current_section = 'introduction'
+     section_sentences = []
+     chunk_size = 500  # words
+
+     for sent in sentences:
+         # Check section type
+         sent_pos = text.find(sent)
+         for pos, stype in sections:
+             if sent_pos >= pos:
+                 current_section = stype
+
+         section_sentences.append(sent)
+
+         # Create chunk when we have enough content
+         chunk_text = ' '.join(section_sentences)
+         if len(chunk_text.split()) >= chunk_size or len(section_sentences) >= 10:
+             chunk_id = hashlib.md5(f"{title}_{len(chunks)}_{chunk_text[:50]}".encode()).hexdigest()[:12]
+
+             # Calculate importance
+             importance = 1.0
+             section_weights = {
+                 'holding': 2.0, 'conclusion': 1.8, 'reasoning': 1.5,
+                 'facts': 1.2, 'dissent': 0.8
+             }
+             importance *= section_weights.get(current_section, 1.0)
+
+             # Entity importance
+             entities = extract_legal_entities(chunk_text)
+             if entities:
+                 entity_score = sum(e['importance'] for e in entities) / len(entities)
+                 importance *= (1 + entity_score * 0.5)
+
+             chunks.append({
+                 'id': chunk_id,
+                 'text': chunk_text,
+                 'title': title,
+                 'section_type': current_section,
+                 'importance_score': importance,
+                 'entities': entities,
+                 'embedding': None  # Will be filled during indexing
+             })
+
+             section_sentences = []
+
+     # Add remaining sentences
+     if section_sentences:
+         chunk_text = ' '.join(section_sentences)
+         chunk_id = hashlib.md5(f"{title}_{len(chunks)}_{chunk_text[:50]}".encode()).hexdigest()[:12]
+         chunks.append({
+             'id': chunk_id,
+             'text': chunk_text,
+             'title': title,
+             'section_type': current_section,
+             'importance_score': 1.0,
+             'entities': extract_legal_entities(chunk_text),
+             'embedding': None
+         })
+
+     return chunks
+
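+ # Worked example of the chunking rules above (hypothetical document): a
+ # 2,000-word opinion with "HOLDING:" and "CONCLUSION:" headers splits into
+ # several chunks, each closed at >= 500 words or 10 sentences, whichever
+ # comes first. Chunks after "HOLDING:" start at importance 2.0 (before the
+ # entity boost), so they outrank 'facts' chunks (1.2) at equal retrieval
+ # scores.
+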
+ def build_all_indices(chunks: List[Dict[str, Any]]):
+     """Build all retrieval indices"""
+     global DENSE_INDEX, BM25_INDEX, CONCEPT_GRAPH, TOKEN_TO_CHUNKS, CHUNKS_DATA
+
+     CHUNKS_DATA = chunks
+     print(f"Building indices for {len(chunks)} chunks...")
+
+     # 1. Dense embeddings + FAISS index
+     print("Building FAISS index...")
+     embeddings = []
+     for chunk in tqdm(chunks, desc="Creating embeddings"):
+         embedding = create_embedding(chunk['text'])
+         chunk['embedding'] = embedding
+         embeddings.append(embedding)
+
+     embeddings_matrix = np.vstack(embeddings)
+     DENSE_INDEX = faiss.IndexFlatIP(embeddings_matrix.shape[1])  # Inner product for normalized vectors
+     DENSE_INDEX.add(embeddings_matrix.astype('float32'))
+
+     # 2. BM25 index for sparse retrieval
+     print("Building BM25 index...")
+     tokenized_corpus = [chunk['text'].lower().split() for chunk in chunks]
+     BM25_INDEX = BM25Okapi(tokenized_corpus)
+
+     # 3. ColBERT-style token index
+     print("Building ColBERT token index...")
+     TOKEN_TO_CHUNKS = defaultdict(set)
+     for i, chunk in enumerate(chunks):
+         # Simple tokenization for token-level matching
+         tokens = chunk['text'].lower().split()
+         for token in tokens:
+             TOKEN_TO_CHUNKS[token].add(i)
+
+     # 4. Legal concept graph
+     print("Building legal concept graph...")
+     CONCEPT_GRAPH = nx.Graph()
+
+     for i, chunk in enumerate(chunks):
+         CONCEPT_GRAPH.add_node(i, text=chunk['text'][:200], importance=chunk['importance_score'])
+
+         # Add edges between chunks with shared entities
+         for j, other_chunk in enumerate(chunks[i+1:], i+1):
+             shared_entities = set(e['text'] for e in chunk['entities']) & \
+                               set(e['text'] for e in other_chunk['entities'])
+             if shared_entities:
+                 CONCEPT_GRAPH.add_edge(i, j, weight=len(shared_entities))
+
+     print("All indices built successfully!")
+
+ def multi_stage_retrieval(query_analysis: Dict[str, Any], top_k: int = 10) -> List[Tuple[Dict[str, Any], float]]:
+     """Perform multi-stage retrieval combining all techniques"""
+     candidates = {}
+
+     print("Performing multi-stage retrieval...")
+
+     # Stage 1: Dense retrieval with expanded queries
+     print("Stage 1: Dense retrieval...")
+     for query in query_analysis['expanded_queries'][:3]:
+         query_emb = create_embedding(query)
+         scores, indices = DENSE_INDEX.search(
+             query_emb.reshape(1, -1).astype('float32'),
+             top_k * 2
+         )
+
+         for idx, score in zip(indices[0], scores[0]):
+             if idx < len(CHUNKS_DATA):
+                 chunk_id = CHUNKS_DATA[idx]['id']
+                 if chunk_id not in candidates:
+                     candidates[chunk_id] = {'chunk': CHUNKS_DATA[idx], 'scores': {}}
+                 candidates[chunk_id]['scores']['dense'] = float(score)
+
+     # Stage 2: Sparse retrieval (BM25)
+     print("Stage 2: Sparse retrieval...")
+     query_tokens = query_analysis['original_query'].lower().split()
+     bm25_scores = BM25_INDEX.get_scores(query_tokens)
+     top_bm25_indices = np.argsort(bm25_scores)[-top_k*2:][::-1]
+
+     for idx in top_bm25_indices:
+         if idx < len(CHUNKS_DATA):
+             chunk_id = CHUNKS_DATA[idx]['id']
+             if chunk_id not in candidates:
+                 candidates[chunk_id] = {'chunk': CHUNKS_DATA[idx], 'scores': {}}
+             candidates[chunk_id]['scores']['bm25'] = float(bm25_scores[idx])
+
+     # Stage 3: Entity-based retrieval
+     print("Stage 3: Entity-based retrieval...")
+     for entity in query_analysis['entities']:
+         for chunk in CHUNKS_DATA:
+             chunk_entity_texts = [e['text'].lower() for e in chunk['entities']]
+             if entity['text'].lower() in chunk_entity_texts:
+                 chunk_id = chunk['id']
+                 if chunk_id not in candidates:
+                     candidates[chunk_id] = {'chunk': chunk, 'scores': {}}
+                 candidates[chunk_id]['scores']['entity'] = \
+                     candidates[chunk_id]['scores'].get('entity', 0) + entity['importance']
+
+     # Stage 4: Graph-based retrieval
+     print("Stage 4: Graph-based retrieval...")
+     if candidates and CONCEPT_GRAPH:
+         seed_chunks = []
+         for chunk_id, data in list(candidates.items())[:5]:
+             for i, chunk in enumerate(CHUNKS_DATA):
+                 if chunk['id'] == chunk_id:
+                     seed_chunks.append(i)
+                     break
+
+         for seed_idx in seed_chunks:
+             if seed_idx in CONCEPT_GRAPH:
+                 neighbors = list(CONCEPT_GRAPH.neighbors(seed_idx))[:3]
+                 for neighbor_idx in neighbors:
+                     if neighbor_idx < len(CHUNKS_DATA):
+                         chunk = CHUNKS_DATA[neighbor_idx]
+                         chunk_id = chunk['id']
+                         if chunk_id not in candidates:
+                             candidates[chunk_id] = {'chunk': chunk, 'scores': {}}
+                         candidates[chunk_id]['scores']['graph'] = 0.5
+
+     # Combine scores
+     print("Combining scores...")
+     weights = {'dense': 0.35, 'bm25': 0.25, 'entity': 0.25, 'graph': 0.15}
+     final_scores = []
+
+     for chunk_id, data in candidates.items():
+         chunk = data['chunk']
+         scores = data['scores']
+
+         final_score = 0
+         for method, weight in weights.items():
+             if method in scores:
+                 # Normalize scores
+                 if method == 'dense':
+                     normalized = (scores[method] + 1) / 2  # [-1, 1] to [0, 1]
+                 elif method == 'bm25':
+                     normalized = min(scores[method] / 10, 1)
+                 elif method == 'entity':
+                     normalized = min(scores[method] / 3, 1)
+                 else:
+                     normalized = scores[method]
+
+                 final_score += weight * normalized
+
+         # Boost by importance and section relevance
+         final_score *= chunk['importance_score']
+
+         if query_analysis['query_type'] == 'precedent' and chunk['section_type'] == 'holding':
+             final_score *= 1.5
+         elif query_analysis['query_type'] == 'factual' and chunk['section_type'] == 'facts':
+             final_score *= 1.5
+
+         final_scores.append((chunk, final_score))
+
+     # Sort and return top-k
+     final_scores.sort(key=lambda x: x[1], reverse=True)
+     return final_scores[:top_k]
+
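+ # Worked example of the score fusion above: a chunk with dense score 0.6,
+ # BM25 score 8.0, no entity or graph hits, and importance 1.2 scores
+ #     dense: 0.35 * ((0.6 + 1) / 2)  = 0.35 * 0.8 = 0.28
+ #     bm25:  0.25 * min(8.0 / 10, 1) = 0.25 * 0.8 = 0.20
+ #     total: (0.28 + 0.20) * 1.2 = 0.576, times 1.5 more if the section type
+ #     matches the query type.
+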
+ def generate_answer_with_reasoning(query: str, retrieved_chunks: List[Tuple[Dict[str, Any], float]]) -> Dict[str, Any]:
+     """Generate answer with legal reasoning"""
+     if not GROQ_CLIENT:
+         return {'error': 'Groq client not initialized'}
+
+     # Prepare context (use .get() so trimmed chunks from search_chunks_simple,
+     # which lack 'entities'/'section_type', don't raise KeyError)
+     context_parts = []
+     for i, (chunk, score) in enumerate(retrieved_chunks, 1):
+         entities = ', '.join([e['text'] for e in chunk.get('entities', [])[:3]])
+         context_parts.append(f"""
+ Document {i} [{chunk['title']}] - Relevance: {score:.2f}
+ Section: {chunk.get('section_type', 'content')}
+ Key Entities: {entities}
+ Content: {chunk['text'][:800]}
+ """)
+
+     context = "\n---\n".join(context_parts)
+
+     system_prompt = """You are an expert legal analyst. Provide thorough legal analysis using the IRAC method:
+ 1. ISSUE: Identify the legal issue(s)
+ 2. RULE: State the applicable legal rules/precedents
+ 3. APPLICATION: Apply the rules to the facts
+ 4. CONCLUSION: Provide a clear conclusion
+
+ CRITICAL: Base ALL responses on the provided document excerpts only. Quote directly when making claims.
+ If information is not in the excerpts, state "This information is not provided in the available documents."
+ """
+
+     user_prompt = f"""Query: {query}
+
+ Retrieved Legal Documents:
+ {context}
+
+ Please provide a comprehensive legal analysis using IRAC method. Cite the documents when making claims."""
+
+     try:
+         response = GROQ_CLIENT.chat.completions.create(
+             messages=[
+                 {"role": "system", "content": system_prompt},
+                 {"role": "user", "content": user_prompt}
+             ],
+             model="llama-3.1-8b-instant",
+             temperature=0.1,
+             max_tokens=1000
+         )
+
+         answer = response.choices[0].message.content
+
+         # Calculate confidence
+         avg_score = sum(score for _, score in retrieved_chunks[:3]) / min(3, len(retrieved_chunks))
+         confidence = min(avg_score * 100, 100)
+
+         return {
+             'answer': answer,
+             'confidence': confidence,
+             'sources': [
+                 {
+                     'chunk_id': chunk['id'],
+                     'title': chunk['title'],
+                     'section': chunk.get('section_type', 'content'),
+                     'relevance_score': float(score),
+                     'excerpt': chunk['text'][:200] + '...',
+                     'entities': [e['text'] for e in chunk.get('entities', [])[:5]]
+                 }
+                 for chunk, score in retrieved_chunks
+             ]
+         }
+
+     except Exception as e:
+         return {
+             'error': f'Error generating answer: {str(e)}',
+             'sources': [{'chunk': c['text'][:200], 'score': s} for c, s in retrieved_chunks[:3]]
+         }
+
+ # Main functions for external use
+ def process_documents(documents: List[Dict[str, str]]) -> Dict[str, Any]:
+     """Process documents and build indices"""
+     all_chunks = []
+
+     for doc in documents:
+         chunks = chunk_text_hierarchical(doc['text'], doc.get('title', 'Document'))
+         all_chunks.extend(chunks)
+
+     build_all_indices(all_chunks)
+
+     return {
+         'success': True,
+         'chunk_count': len(all_chunks),
+         'message': f'Processed {len(documents)} documents into {len(all_chunks)} chunks'
+     }
+
+ def query_documents(query: str, top_k: int = 5) -> Dict[str, Any]:
+     """Main query function - takes query, returns answer with sources"""
+     if not CHUNKS_DATA:
+         return {'error': 'No documents indexed. Call process_documents first.'}
+
+     # Analyze query
+     query_analysis = analyze_query(query)
+
+     # Multi-stage retrieval
+     retrieved_chunks = multi_stage_retrieval(query_analysis, top_k)
+
+     if not retrieved_chunks:
+         return {
+             'error': 'No relevant documents found',
+             'query_analysis': query_analysis
+         }
+
+     # Generate answer
+     result = generate_answer_with_reasoning(query, retrieved_chunks)
+     result['query_analysis'] = query_analysis
+
+     return result
+
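+ # Minimal standalone usage sketch, assuming local document strings
+ # (hypothetical values; the FastAPI service instead loads precomputed chunks
+ # from MongoDB and skips process_documents):
+ #
+ #     import rag
+ #     rag.initialize_models("sentence-transformers/all-MiniLM-L6-v2", groq_api_key="...")
+ #     rag.process_documents([{"title": "Smith v. Jones", "text": opinion_text}])
+ #     result = rag.query_documents("What did the court hold on negligence?", top_k=5)
+ #     print(result["answer"], result["confidence"])
+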
+ def search_chunks_simple(query: str, top_k: int = 3) -> List[Dict[str, Any]]:
+     """Simple search function for compatibility"""
+     if not CHUNKS_DATA:
+         return []
+
+     query_analysis = analyze_query(query)
+     retrieved_chunks = multi_stage_retrieval(query_analysis, top_k)
+
+     results = []
+     for chunk, score in retrieved_chunks:
+         results.append({
+             'chunk': {
+                 'id': chunk['id'],
+                 'text': chunk['text'],
+                 'title': chunk['title']
+             },
+             'score': score
+         })
+
+     return results
+
+ def generate_conservative_answer(query: str, context_chunks: List[Dict[str, Any]]) -> str:
+     """Generate conservative answer - for compatibility"""
+     if not context_chunks:
+         return "No relevant information found."
+
+     # Convert format (these trimmed chunks lack 'entities'/'section_type';
+     # generate_answer_with_reasoning handles that with .get() defaults)
+     retrieved_chunks = [(chunk['chunk'], chunk['score']) for chunk in context_chunks]
+     result = generate_answer_with_reasoning(query, retrieved_chunks)
+
+     if 'error' in result:
+         return result['error']
+
+     return result.get('answer', 'Unable to generate answer.')
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ # Hugging Face Spaces requirements
+ gradio==4.44.0
+ requests==2.31.0
+ fastapi==0.115.6
+ uvicorn==0.32.1
+ python-multipart==0.0.9  # ✅ needed for FastAPI endpoints
+
+ # Core ML/NLP
+ torch==2.2.2
+ transformers==4.44.2
+ sentence-transformers==2.2.2
+ spacy==3.8.2
+ scikit-learn==1.5.2
+ numpy==1.26.4
+ pandas==2.2.3
+ nltk==3.9.1
+
+ # Retrieval / Search
+ faiss-cpu==1.7.4
+ rank-bm25==0.2.2
+
+ # API clients
+ groq==0.13.0
+
+ # Imported by app.py / rag.py but previously missing here
+ # (left unpinned; pin as appropriate for the deployment)
+ pymongo
+ networkx
+ tqdm