kn29 committed on
Commit fef6ed9 · verified · 1 Parent(s): d5bcbe3

Update rag.py

Files changed (1)
  1. rag.py +159 -497
rag.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import numpy as np
3
  from transformers import AutoTokenizer, AutoModel
@@ -13,14 +14,22 @@ import networkx as nx
13
  from collections import defaultdict
14
  import spacy
15
  from rank_bm25 import BM25Okapi
16
 
17
  # Global model instances (shared across sessions)
18
  _SHARED_MODEL = None
19
  _SHARED_TOKENIZER = None
20
  _SHARED_NLP_MODEL = None
21
  _DEVICE = None
 
22
 
23
- # Legal knowledge base (shared constants)
24
  LEGAL_CONCEPTS = {
25
  'liability': ['negligence', 'strict liability', 'vicarious liability', 'product liability'],
26
  'contract': ['breach', 'consideration', 'offer', 'acceptance', 'damages', 'specific performance'],
@@ -39,7 +48,7 @@ QUERY_PATTERNS = {
39
 
40
  def initialize_models(model_id: str, groq_api_key: str = None):
41
  """Initialize shared models (call once at startup)"""
42
- global _SHARED_MODEL, _SHARED_TOKENIZER, _SHARED_NLP_MODEL, _DEVICE
43
 
44
  try:
45
  nltk.download('punkt', quiet=True)
@@ -48,21 +57,24 @@ def initialize_models(model_id: str, groq_api_key: str = None):
48
  pass
49
 
50
  _DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
51
- print(f"Using device: {_DEVICE}")
52
 
53
- print(f"Loading model: {model_id}")
54
  _SHARED_TOKENIZER = AutoTokenizer.from_pretrained(model_id)
55
  _SHARED_MODEL = AutoModel.from_pretrained(model_id).to(_DEVICE)
56
  _SHARED_MODEL.eval()
57
 
58
  try:
59
  _SHARED_NLP_MODEL = spacy.load("en_core_web_sm")
60
  except:
61
- print("SpaCy model not found, using basic NER")
62
  _SHARED_NLP_MODEL = None
63
 
64
- class SessionRAG:
65
- """Session-specific RAG instance"""
66
 
67
  def __init__(self, session_id: str, groq_api_key: str = None):
68
  self.session_id = session_id
@@ -71,495 +83,209 @@ class SessionRAG:
71
  # Session-specific indices and data
72
  self.dense_index = None
73
  self.bm25_index = None
74
- self.concept_graph = None
75
  self.token_to_chunks = None
76
  self.chunks_data = []
77
 
78
  # Verify shared models are initialized
79
  if _SHARED_MODEL is None or _SHARED_TOKENIZER is None:
80
  raise ValueError("Models not initialized. Call initialize_models() first.")
81
 
82
- def create_embedding(self, text: str) -> np.ndarray:
83
- """Create dense embedding for text"""
84
- inputs = _SHARED_TOKENIZER(text, padding=True, truncation=True,
85
- max_length=512, return_tensors='pt').to(_DEVICE)
86
-
87
- with torch.no_grad():
88
- outputs = _SHARED_MODEL(**inputs)
89
- attention_mask = inputs['attention_mask']
90
- token_embeddings = outputs.last_hidden_state
91
- input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
92
- embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
93
-
94
- # Normalize embeddings
95
- embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
96
-
97
- return embeddings.cpu().numpy()[0]
98
-
99
  def load_existing_session_data(self, chunks_from_db: List[Dict[str, Any]]):
100
- """Load pre-existing chunks with embeddings from database"""
101
- print(f"Loading existing session data for {self.session_id}: {len(chunks_from_db)} chunks...")
 
102
 
103
- # Process chunks from MongoDB format
104
- self.chunks_data = self.process_db_chunks(chunks_from_db)
105
 
106
- # Rebuild indices from existing embeddings (don't recreate embeddings)
107
- self.rebuild_indices_from_existing_embeddings()
108
 
109
- print(f"Session {self.session_id} loaded with existing embeddings!")
 
110
 
111
- def rebuild_indices_from_existing_embeddings(self):
112
- """Rebuild search indices using existing embeddings from database"""
113
  if not self.chunks_data:
114
  raise ValueError("No chunks data available")
115
 
116
- print(f"Rebuilding indices from existing embeddings...")
 
117
 
118
- # Extract existing embeddings
119
  embeddings = []
120
  for chunk in self.chunks_data:
121
- if 'embedding' in chunk and chunk['embedding'] is not None:
122
- embeddings.append(chunk['embedding'])
123
- else:
124
  raise ValueError(f"Missing embedding for chunk {chunk.get('id', 'unknown')}")
125
 
126
- # Build FAISS index from existing embeddings
127
- embeddings_matrix = np.vstack(embeddings)
128
  self.dense_index = faiss.IndexFlatIP(embeddings_matrix.shape[1])
129
- self.dense_index.add(embeddings_matrix.astype('float32'))
130
 
131
- # Build other indices
132
  tokenized_corpus = [chunk['text'].lower().split() for chunk in self.chunks_data]
133
  self.bm25_index = BM25Okapi(tokenized_corpus)
134
 
135
- # 3. ColBERT-style token index
136
  self.token_to_chunks = defaultdict(set)
137
  for i, chunk in enumerate(self.chunks_data):
138
  tokens = chunk['text'].lower().split()
139
  for token in tokens:
140
  self.token_to_chunks[token].add(i)
141
 
142
- # 4. Legal concept graph
143
- self.concept_graph = nx.Graph()
144
- for i, chunk in enumerate(self.chunks_data):
145
- self.concept_graph.add_node(i, text=chunk['text'][:200], importance=chunk['importance_score'])
146
-
147
- for j, other_chunk in enumerate(self.chunks_data[i+1:], i+1):
148
- shared_entities = set(e['text'] for e in chunk['entities']) & \
149
- set(e['text'] for e in other_chunk['entities'])
150
- if shared_entities:
151
- self.concept_graph.add_edge(i, j, weight=len(shared_entities))
152
-
153
- print(f"All indices rebuilt from existing embeddings for session {self.session_id}!")
154
 
155
- def process_db_chunks(self, chunks_from_db: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
156
- """Convert MongoDB chunk format to internal format"""
157
- processed_chunks = []
158
- for chunk in chunks_from_db:
159
- # Convert embedding from list to numpy array if needed
160
- embedding = chunk.get('embedding')
161
- if embedding and isinstance(embedding, list):
162
- embedding = np.array(embedding)
163
-
164
- processed_chunk = {
165
- 'id': chunk.get('chunk_id', chunk.get('id')),
166
- 'text': chunk.get('content', chunk.get('text', '')),
167
- 'title': chunk.get('title', 'Document'),
168
- 'section_type': chunk.get('section_type', 'general'),
169
- 'importance_score': chunk.get('importance_score', 1.0),
170
- 'entities': chunk.get('entities', []),
171
- 'embedding': embedding
172
- }
173
- processed_chunks.append(processed_chunk)
174
 
175
- return processed_chunks
176
-
177
- def extract_legal_entities(self, text: str) -> List[Dict[str, Any]]:
178
- """Extract legal entities from text"""
179
- entities = []
180
-
181
- if _SHARED_NLP_MODEL:
182
- doc = _SHARED_NLP_MODEL(text[:5000]) # Limit for performance
183
- for ent in doc.ents:
184
- if ent.label_ in ['PERSON', 'ORG', 'LAW', 'GPE']:
185
- entities.append({
186
- 'text': ent.text,
187
- 'type': ent.label_,
188
- 'importance': 1.0
189
- })
190
-
191
- # Legal citations
192
- citation_pattern = r'\b\d+\s+[A-Z][a-z]+\.?\s+\d+\b'
193
- for match in re.finditer(citation_pattern, text):
194
- entities.append({
195
- 'text': match.group(),
196
- 'type': 'case_citation',
197
- 'importance': 2.0
198
- })
199
-
200
- # Statute references
201
- statute_pattern = r'§\s*\d+[\.\d]*|\bSection\s+\d+'
202
- for match in re.finditer(statute_pattern, text):
203
- entities.append({
204
- 'text': match.group(),
205
- 'type': 'statute',
206
- 'importance': 1.5
207
- })
208
-
209
- return entities
210
 
211
- def analyze_query(self, query: str) -> Dict[str, Any]:
212
- """Analyze query to understand intent"""
213
  query_lower = query.lower()
214
 
215
- # Classify query type
216
  query_type = 'general'
217
  for qtype, patterns in QUERY_PATTERNS.items():
218
  if any(pattern in query_lower for pattern in patterns):
219
  query_type = qtype
220
  break
221
 
222
- # Extract entities
223
- entities = self.extract_legal_entities(query)
224
-
225
- # Extract key concepts
226
  key_concepts = []
227
  for concept_category, concepts in LEGAL_CONCEPTS.items():
228
  for concept in concepts:
229
  if concept in query_lower:
230
  key_concepts.append(concept)
231
 
232
- # Generate expanded queries
233
  expanded_queries = [query]
234
-
235
- # Concept expansion
236
  if key_concepts:
237
- expanded_queries.append(f"{query} {' '.join(key_concepts[:3])}")
238
-
239
- # Type-based expansion
240
- if query_type == 'precedent':
241
- expanded_queries.append(f"legal precedent case law {query}")
242
- elif query_type == 'statute_interpretation':
243
- expanded_queries.append(f"statutory interpretation meaning {query}")
244
-
245
- # HyDE - Hypothetical document generation
246
- if self.groq_client:
247
- hyde_doc = self.generate_hypothetical_document(query)
248
- if hyde_doc:
249
- expanded_queries.append(hyde_doc)
250
 
251
  return {
252
  'original_query': query,
253
  'query_type': query_type,
254
- 'entities': entities,
255
  'key_concepts': key_concepts,
256
- 'expanded_queries': expanded_queries[:4] # Limit to 4 queries
257
  }
258
 
259
- def generate_hypothetical_document(self, query: str) -> Optional[str]:
260
- """Generate hypothetical answer document (HyDE technique)"""
261
- if not self.groq_client:
262
- return None
263
-
264
- try:
265
- prompt = f"""Generate a brief hypothetical legal document excerpt that would answer this question: {query}
266
-
267
- Write it as if it's from an actual legal case or statute. Be specific and use legal language.
268
- Keep it under 100 words."""
269
-
270
- response = self.groq_client.chat.completions.create(
271
- messages=[
272
- {"role": "system", "content": "You are a legal expert generating hypothetical legal text."},
273
- {"role": "user", "content": prompt}
274
- ],
275
- model="llama-3.1-8b-instant",
276
- temperature=0.3,
277
- max_tokens=150
278
- )
279
-
280
- return response.choices[0].message.content
281
- except:
282
- return None
283
-
284
- def chunk_text_hierarchical(self, text: str, title: str = "") -> List[Dict[str, Any]]:
285
- """Create hierarchical chunks with legal structure awareness"""
286
- chunks = []
287
-
288
- # Clean text
289
- text = re.sub(r'\s+', ' ', text)
290
-
291
- # Identify legal sections
292
- section_patterns = [
293
- (r'(?i)\bFACTS?\b[:\s]', 'facts'),
294
- (r'(?i)\bHOLDING\b[:\s]', 'holding'),
295
- (r'(?i)\bREASONING\b[:\s]', 'reasoning'),
296
- (r'(?i)\bDISSENT\b[:\s]', 'dissent'),
297
- (r'(?i)\bCONCLUSION\b[:\s]', 'conclusion')
298
- ]
299
-
300
- sections = []
301
- for pattern, section_type in section_patterns:
302
- matches = list(re.finditer(pattern, text))
303
- for match in matches:
304
- sections.append((match.start(), section_type))
305
-
306
- sections.sort(key=lambda x: x[0])
307
-
308
- # Split into sentences
309
- import nltk
310
- try:
311
- sentences = nltk.sent_tokenize(text)
312
- except:
313
- sentences = text.split('. ')
314
-
315
- # Create chunks
316
- current_section = 'introduction'
317
- section_sentences = []
318
- chunk_size = 500 # words
319
-
320
- for sent in sentences:
321
- # Check section type
322
- sent_pos = text.find(sent)
323
- for pos, stype in sections:
324
- if sent_pos >= pos:
325
- current_section = stype
326
-
327
- section_sentences.append(sent)
328
-
329
- # Create chunk when we have enough content
330
- chunk_text = ' '.join(section_sentences)
331
- if len(chunk_text.split()) >= chunk_size or len(section_sentences) >= 10:
332
- chunk_id = hashlib.md5(f"{title}_{len(chunks)}_{chunk_text[:50]}".encode()).hexdigest()[:12]
333
-
334
- # Calculate importance
335
- importance = 1.0
336
- section_weights = {
337
- 'holding': 2.0, 'conclusion': 1.8, 'reasoning': 1.5,
338
- 'facts': 1.2, 'dissent': 0.8
339
- }
340
- importance *= section_weights.get(current_section, 1.0)
341
-
342
- # Entity importance
343
- entities = self.extract_legal_entities(chunk_text)
344
- if entities:
345
- entity_score = sum(e['importance'] for e in entities) / len(entities)
346
- importance *= (1 + entity_score * 0.5)
347
-
348
- chunks.append({
349
- 'id': chunk_id,
350
- 'text': chunk_text,
351
- 'title': title,
352
- 'section_type': current_section,
353
- 'importance_score': importance,
354
- 'entities': entities,
355
- 'embedding': None # Will be filled during indexing
356
- })
357
-
358
- section_sentences = []
359
-
360
- # Add remaining sentences
361
- if section_sentences:
362
- chunk_text = ' '.join(section_sentences)
363
- chunk_id = hashlib.md5(f"{title}_{len(chunks)}_{chunk_text[:50]}".encode()).hexdigest()[:12]
364
- chunks.append({
365
- 'id': chunk_id,
366
- 'text': chunk_text,
367
- 'title': title,
368
- 'section_type': current_section,
369
- 'importance_score': 1.0,
370
- 'entities': self.extract_legal_entities(chunk_text),
371
- 'embedding': None
372
- })
373
-
374
- return chunks
375
-
376
- def build_all_indices(self, chunks: List[Dict[str, Any]]):
377
- """Build all retrieval indices for this session"""
378
- self.chunks_data = chunks
379
- print(f"Building indices for session {self.session_id}: {len(chunks)} chunks...")
380
-
381
- # 1. Dense embeddings + FAISS index
382
- print("Building FAISS index...")
383
- embeddings = []
384
- for chunk in tqdm(chunks, desc="Creating embeddings"):
385
- embedding = self.create_embedding(chunk['text'])
386
- chunk['embedding'] = embedding
387
- embeddings.append(embedding)
388
-
389
- embeddings_matrix = np.vstack(embeddings)
390
- self.dense_index = faiss.IndexFlatIP(embeddings_matrix.shape[1]) # Inner product for normalized vectors
391
- self.dense_index.add(embeddings_matrix.astype('float32'))
392
-
393
- # 2. BM25 index for sparse retrieval
394
- print("Building BM25 index...")
395
- tokenized_corpus = [chunk['text'].lower().split() for chunk in chunks]
396
- self.bm25_index = BM25Okapi(tokenized_corpus)
397
-
398
- # 3. ColBERT-style token index
399
- print("Building ColBERT token index...")
400
- self.token_to_chunks = defaultdict(set)
401
- for i, chunk in enumerate(chunks):
402
- # Simple tokenization for token-level matching
403
- tokens = chunk['text'].lower().split()
404
- for token in tokens:
405
- self.token_to_chunks[token].add(i)
406
-
407
- # 4. Legal concept graph
408
- print("Building legal concept graph...")
409
- self.concept_graph = nx.Graph()
410
-
411
- for i, chunk in enumerate(chunks):
412
- self.concept_graph.add_node(i, text=chunk['text'][:200], importance=chunk['importance_score'])
413
-
414
- # Add edges between chunks with shared entities
415
- for j, other_chunk in enumerate(chunks[i+1:], i+1):
416
- shared_entities = set(e['text'] for e in chunk['entities']) & \
417
- set(e['text'] for e in other_chunk['entities'])
418
- if shared_entities:
419
- self.concept_graph.add_edge(i, j, weight=len(shared_entities))
420
-
421
- print(f"All indices built successfully for session {self.session_id}!")
422
-
423
- def multi_stage_retrieval(self, query_analysis: Dict[str, Any], top_k: int = 10) -> List[Tuple[Dict[str, Any], float]]:
424
- """Perform multi-stage retrieval combining all techniques"""
425
  candidates = {}
426
 
427
- print(f"Performing multi-stage retrieval for session {self.session_id}...")
428
 
429
- # Stage 1: Dense retrieval with expanded queries
430
- print("Stage 1: Dense retrieval...")
431
- for query in query_analysis['expanded_queries'][:3]:
432
- query_emb = self.create_embedding(query)
433
- scores, indices = self.dense_index.search(
434
- query_emb.reshape(1, -1).astype('float32'),
435
- top_k * 2
436
- )
437
 
438
- for idx, score in zip(indices[0], scores[0]):
439
  if idx < len(self.chunks_data):
440
- chunk_id = self.chunks_data[idx]['id']
441
- if chunk_id not in candidates:
442
- candidates[chunk_id] = {'chunk': self.chunks_data[idx], 'scores': {}}
443
- candidates[chunk_id]['scores']['dense'] = float(score)
444
-
445
- # Stage 2: Sparse retrieval (BM25)
446
- print("Stage 2: Sparse retrieval...")
447
- query_tokens = query_analysis['original_query'].lower().split()
448
- bm25_scores = self.bm25_index.get_scores(query_tokens)
449
- top_bm25_indices = np.argsort(bm25_scores)[-top_k*2:][::-1]
450
-
451
- for idx in top_bm25_indices:
452
- if idx < len(self.chunks_data):
453
- chunk_id = self.chunks_data[idx]['id']
454
- if chunk_id not in candidates:
455
- candidates[chunk_id] = {'chunk': self.chunks_data[idx], 'scores': {}}
456
- candidates[chunk_id]['scores']['bm25'] = float(bm25_scores[idx])
457
-
458
- # Stage 3: Entity-based retrieval
459
- print("Stage 3: Entity-based retrieval...")
460
- for entity in query_analysis['entities']:
461
- for chunk in self.chunks_data:
462
- chunk_entity_texts = [e['text'].lower() for e in chunk['entities']]
463
- if entity['text'].lower() in chunk_entity_texts:
464
  chunk_id = chunk['id']
465
  if chunk_id not in candidates:
466
- candidates[chunk_id] = {'chunk': chunk, 'scores': {}}
467
- candidates[chunk_id]['scores']['entity'] = \
468
- candidates[chunk_id]['scores'].get('entity', 0) + entity['importance']
469
-
470
- # Stage 4: Graph-based retrieval
471
- print("Stage 4: Graph-based retrieval...")
472
- if candidates and self.concept_graph:
473
- seed_chunks = []
474
- for chunk_id, data in list(candidates.items())[:5]:
475
- for i, chunk in enumerate(self.chunks_data):
476
- if chunk['id'] == chunk_id:
477
- seed_chunks.append(i)
478
- break
479
-
480
- for seed_idx in seed_chunks:
481
- if seed_idx in self.concept_graph:
482
- neighbors = list(self.concept_graph.neighbors(seed_idx))[:3]
483
- for neighbor_idx in neighbors:
484
- if neighbor_idx < len(self.chunks_data):
485
- chunk = self.chunks_data[neighbor_idx]
486
- chunk_id = chunk['id']
487
- if chunk_id not in candidates:
488
- candidates[chunk_id] = {'chunk': chunk, 'scores': {}}
489
- candidates[chunk_id]['scores']['graph'] = 0.5
490
-
491
- # Combine scores
492
- print("Combining scores...")
493
- weights = {'dense': 0.35, 'bm25': 0.25, 'entity': 0.25, 'graph': 0.15}
494
- final_scores = []
495
-
496
- for chunk_id, data in candidates.items():
497
- chunk = data['chunk']
498
- scores = data['scores']
499
-
500
- final_score = 0
501
- for method, weight in weights.items():
502
- if method in scores:
503
- # Normalize scores
504
- if method == 'dense':
505
- normalized = (scores[method] + 1) / 2 # [-1, 1] to [0, 1]
506
- elif method == 'bm25':
507
- normalized = min(scores[method] / 10, 1)
508
- elif method == 'entity':
509
- normalized = min(scores[method] / 3, 1)
510
  else:
511
- normalized = scores[method]
512
-
513
- final_score += weight * normalized
514
-
515
- # Boost by importance and section relevance
516
- final_score *= chunk['importance_score']
517
-
518
- if query_analysis['query_type'] == 'precedent' and chunk['section_type'] == 'holding':
519
- final_score *= 1.5
520
- elif query_analysis['query_type'] == 'factual' and chunk['section_type'] == 'facts':
521
- final_score *= 1.5
522
-
523
- final_scores.append((chunk, final_score))
524
 
525
- # Sort and return top-k
 
526
  final_scores.sort(key=lambda x: x[1], reverse=True)
 
527
  return final_scores[:top_k]
528
 
529
- def generate_answer_with_reasoning(self, query: str, retrieved_chunks: List[Tuple[Dict[str, Any], float]]) -> Dict[str, Any]:
530
- """Generate answer with legal reasoning"""
531
  if not self.groq_client:
532
  return {'error': 'Groq client not initialized'}
533
 
534
- # Prepare context
535
  context_parts = []
536
- for i, (chunk, score) in enumerate(retrieved_chunks, 1):
537
- entities = ', '.join([e['text'] for e in chunk['entities'][:3]])
538
  context_parts.append(f"""
539
- Document {i} [{chunk['title']}] - Relevance: {score:.2f}
540
- Section: {chunk['section_type']}
541
- Key Entities: {entities}
542
- Content: {chunk['text'][:800]}
543
- """)
544
 
545
  context = "\n---\n".join(context_parts)
546
 
547
- system_prompt = """You are an expert legal analyst. Provide thorough legal analysis using the IRAC method:
548
- 1. ISSUE: Identify the legal issue(s)
549
- 2. RULE: State the applicable legal rules/precedents
550
- 3. APPLICATION: Apply the rules to the facts
551
- 4. CONCLUSION: Provide a clear conclusion
552
-
553
- CRITICAL: Base ALL responses on the provided document excerpts only. Quote directly when making claims.
554
- If information is not in the excerpts, state "This information is not provided in the available documents."
555
- """
556
 
557
  user_prompt = f"""Query: {query}
558
 
559
- Retrieved Legal Documents:
560
- {context}
561
 
562
- Please provide a comprehensive legal analysis using IRAC method. Cite the documents when making claims."""
563
 
564
  try:
565
  response = self.groq_client.chat.completions.create(
@@ -569,7 +295,7 @@ class SessionRAG:
569
  ],
570
  model="llama-3.1-8b-instant",
571
  temperature=0.1,
572
- max_tokens=1000
573
  )
574
 
575
  answer = response.choices[0].message.content
@@ -587,45 +313,28 @@ class SessionRAG:
587
  'title': chunk['title'],
588
  'section': chunk['section_type'],
589
  'relevance_score': float(score),
590
- 'excerpt': chunk['text'][:200] + '...',
591
- 'entities': [e['text'] for e in chunk['entities'][:5]]
592
  }
593
- for chunk, score in retrieved_chunks
594
  ]
595
  }
596
 
597
  except Exception as e:
598
- return {
599
- 'error': f'Error generating answer: {str(e)}',
600
- 'sources': [{'chunk': c['text'][:200], 'score': s} for c, s in retrieved_chunks[:3]]
601
- }
602
-
603
- def process_documents(self, documents: List[Dict[str, str]]) -> Dict[str, Any]:
604
- """Process documents and build indices for this session"""
605
- all_chunks = []
606
-
607
- for doc in documents:
608
- chunks = self.chunk_text_hierarchical(doc['text'], doc.get('title', 'Document'))
609
- all_chunks.extend(chunks)
610
-
611
- self.build_all_indices(all_chunks)
612
-
613
- return {
614
- 'success': True,
615
- 'chunk_count': len(all_chunks),
616
- 'message': f'Processed {len(documents)} documents into {len(all_chunks)} chunks for session {self.session_id}'
617
- }
618
 
619
  def query_documents(self, query: str, top_k: int = 5) -> Dict[str, Any]:
620
- """Main query function - takes query, returns answer with sources"""
621
  if not self.chunks_data:
622
- return {'error': f'No documents indexed for session {self.session_id}. Call process_documents first.'}
623
 
624
- # Analyze query
625
- query_analysis = self.analyze_query(query)
626
 
627
- # Multi-stage retrieval
628
- retrieved_chunks = self.multi_stage_retrieval(query_analysis, top_k)
629
 
630
  if not retrieved_chunks:
631
  return {
@@ -634,59 +343,12 @@ class SessionRAG:
634
  }
635
 
636
  # Generate answer
637
- result = self.generate_answer_with_reasoning(query, retrieved_chunks)
638
  result['query_analysis'] = query_analysis
 
639
 
 
640
  return result
641
 
642
- def search_chunks_simple(self, query: str, top_k: int = 3) -> List[Dict[str, Any]]:
643
- """Simple search function for compatibility"""
644
- if not self.chunks_data:
645
- return []
646
-
647
- query_analysis = self.analyze_query(query)
648
- retrieved_chunks = self.multi_stage_retrieval(query_analysis, top_k)
649
-
650
- results = []
651
- for chunk, score in retrieved_chunks:
652
- results.append({
653
- 'chunk': {
654
- 'id': chunk['id'],
655
- 'text': chunk['text'],
656
- 'title': chunk['title']
657
- },
658
- 'score': score
659
- })
660
-
661
- return results
662
-
663
- def generate_conservative_answer(self, query: str, context_chunks: List[Dict[str, Any]]) -> str:
664
- """Generate conservative answer - for compatibility"""
665
- if not context_chunks:
666
- return "No relevant information found."
667
-
668
- # Convert format
669
- retrieved_chunks = [(chunk['chunk'], chunk['score']) for chunk in context_chunks]
670
- result = self.generate_answer_with_reasoning(query, retrieved_chunks)
671
-
672
- if 'error' in result:
673
- return result['error']
674
-
675
- return result.get('answer', 'Unable to generate answer.')
676
-
677
- # Backward compatibility functions (deprecated - use SessionRAG instead)
678
- def process_documents(documents: List[Dict[str, str]]) -> Dict[str, Any]:
679
- """Deprecated: Use SessionRAG.process_documents() instead"""
680
- raise NotImplementedError("Global functions are deprecated. Use SessionRAG class instead.")
681
-
682
- def query_documents(query: str, top_k: int = 5) -> Dict[str, Any]:
683
- """Deprecated: Use SessionRAG.query_documents() instead"""
684
- raise NotImplementedError("Global functions are deprecated. Use SessionRAG class instead.")
685
-
686
- def search_chunks_simple(query: str, top_k: int = 3) -> List[Dict[str, Any]]:
687
- """Deprecated: Use SessionRAG.search_chunks_simple() instead"""
688
- raise NotImplementedError("Global functions are deprecated. Use SessionRAG class instead.")
689
-
690
- def generate_conservative_answer(query: str, context_chunks: List[Dict[str, Any]]) -> str:
691
- """Deprecated: Use SessionRAG.generate_conservative_answer() instead"""
692
- raise NotImplementedError("Global functions are deprecated. Use SessionRAG class instead.")
 
1
+ # rag_optimized.py - Performance-Optimized RAG System
2
  import torch
3
  import numpy as np
4
  from transformers import AutoTokenizer, AutoModel
 
14
  from collections import defaultdict
15
  import spacy
16
  from rank_bm25 import BM25Okapi
17
+ import asyncio
18
+ import time
19
+ from concurrent.futures import ThreadPoolExecutor
20
+ import logging
21
+
22
+ # Configure logging
23
+ logger = logging.getLogger(__name__)
24
 
25
  # Global model instances (shared across sessions)
26
  _SHARED_MODEL = None
27
  _SHARED_TOKENIZER = None
28
  _SHARED_NLP_MODEL = None
29
  _DEVICE = None
30
+ _THREAD_POOL = None
31
 
32
+ # Legal knowledge base (optimized)
33
  LEGAL_CONCEPTS = {
34
  'liability': ['negligence', 'strict liability', 'vicarious liability', 'product liability'],
35
  'contract': ['breach', 'consideration', 'offer', 'acceptance', 'damages', 'specific performance'],
 
48
 
49
  def initialize_models(model_id: str, groq_api_key: str = None):
50
  """Initialize shared models (call once at startup)"""
51
+ global _SHARED_MODEL, _SHARED_TOKENIZER, _SHARED_NLP_MODEL, _DEVICE, _THREAD_POOL
52
 
53
  try:
54
  nltk.download('punkt', quiet=True)
 
57
  pass
58
 
59
  _DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
60
+ logger.info(f"Using device: {_DEVICE}")
61
 
62
+ logger.info(f"Loading model: {model_id}")
63
  _SHARED_TOKENIZER = AutoTokenizer.from_pretrained(model_id)
64
  _SHARED_MODEL = AutoModel.from_pretrained(model_id).to(_DEVICE)
65
  _SHARED_MODEL.eval()
66
 
67
+ # Initialize thread pool for CPU-bound operations
68
+ _THREAD_POOL = ThreadPoolExecutor(max_workers=4)
69
+
70
  try:
71
  _SHARED_NLP_MODEL = spacy.load("en_core_web_sm")
72
  except:
73
+ logger.warning("SpaCy model not found, using basic NER")
74
  _SHARED_NLP_MODEL = None
75
 
76
+ class OptimizedSessionRAG:
77
+ """High-performance session-specific RAG instance that loads pre-computed embeddings"""
78
 
79
  def __init__(self, session_id: str, groq_api_key: str = None):
80
  self.session_id = session_id
 
83
  # Session-specific indices and data
84
  self.dense_index = None
85
  self.bm25_index = None
 
86
  self.token_to_chunks = None
87
  self.chunks_data = []
88
 
89
+ # Performance tracking
90
+ self.load_time = None
91
+ self.index_build_time = None
92
+
93
  # Verify shared models are initialized
94
  if _SHARED_MODEL is None or _SHARED_TOKENIZER is None:
95
  raise ValueError("Models not initialized. Call initialize_models() first.")
96
 
97
  def load_existing_session_data(self, chunks_from_db: List[Dict[str, Any]]):
98
+ """OPTIMIZED: Load pre-existing chunks with embeddings from database - NO EMBEDDING CREATION"""
99
+ start_time = time.time()
100
+ logger.info(f"Loading existing session data for {self.session_id}: {len(chunks_from_db)} chunks...")
101
 
102
+ # Process chunks from MongoDB format - DIRECT LOADING, NO EMBEDDING COMPUTATION
103
+ self.chunks_data = self._process_db_chunks_fast(chunks_from_db)
104
 
105
+ # Rebuild indices from existing embeddings ONLY
106
+ self._rebuild_indices_from_precomputed_embeddings()
107
 
108
+ self.load_time = time.time() - start_time
109
+ logger.info(f"Session {self.session_id} loaded in {self.load_time:.2f}s with PRE-COMPUTED embeddings!")
110
 
111
+ def _process_db_chunks_fast(self, chunks_from_db: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
112
+ """FAST: Convert MongoDB chunk format to internal format without any computation"""
113
+ processed_chunks = []
114
+
115
+ for chunk in chunks_from_db:
116
+ # Convert embedding from list to numpy array if needed - NO COMPUTATION
117
+ embedding = chunk.get('embedding')
118
+ if embedding is None:
119
+ raise ValueError(f"Missing embedding for chunk {chunk.get('chunk_id', 'unknown')}")
120
+
121
+ if isinstance(embedding, list):
122
+ embedding = np.array(embedding, dtype=np.float32)
123
+
124
+ processed_chunk = {
125
+ 'id': chunk.get('chunk_id', chunk.get('id')),
126
+ 'text': chunk.get('content', chunk.get('text', '')),
127
+ 'title': chunk.get('title', 'Document'),
128
+ 'section_type': chunk.get('section_type', 'general'),
129
+ 'importance_score': chunk.get('importance_score', 1.0),
130
+ 'entities': chunk.get('entities', []),
131
+ 'embedding': embedding # PRE-COMPUTED, NO CREATION
132
+ }
133
+ processed_chunks.append(processed_chunk)
134
+
135
+ return processed_chunks
136
+
137
+ def _rebuild_indices_from_precomputed_embeddings(self):
138
+ """OPTIMIZED: Rebuild search indices using ONLY pre-computed embeddings from database"""
139
  if not self.chunks_data:
140
  raise ValueError("No chunks data available")
141
 
142
+ start_time = time.time()
143
+ logger.info(f"Rebuilding indices from {len(self.chunks_data)} pre-computed embeddings...")
144
 
145
+ # 1. Build FAISS index from existing embeddings - NO EMBEDDING COMPUTATION
146
  embeddings = []
147
  for chunk in self.chunks_data:
148
+ if chunk['embedding'] is None:
149
  raise ValueError(f"Missing embedding for chunk {chunk.get('id', 'unknown')}")
150
+ embeddings.append(chunk['embedding'])
151
+
152
+ # Stack embeddings efficiently
153
+ embeddings_matrix = np.vstack(embeddings).astype('float32')
154
+ logger.info(f"Built embeddings matrix: {embeddings_matrix.shape}")
155
 
156
+ # Build FAISS index
 
157
  self.dense_index = faiss.IndexFlatIP(embeddings_matrix.shape[1])
158
+ self.dense_index.add(embeddings_matrix)
159
 
160
+ # 2. Build BM25 index efficiently
161
  tokenized_corpus = [chunk['text'].lower().split() for chunk in self.chunks_data]
162
  self.bm25_index = BM25Okapi(tokenized_corpus)
163
 
164
+ # 3. Build token-to-chunk mapping efficiently
165
  self.token_to_chunks = defaultdict(set)
166
  for i, chunk in enumerate(self.chunks_data):
167
  tokens = chunk['text'].lower().split()
168
  for token in tokens:
169
  self.token_to_chunks[token].add(i)
170
 
171
+ self.index_build_time = time.time() - start_time
172
+ logger.info(f"All indices rebuilt in {self.index_build_time:.2f}s from pre-computed embeddings!")
173
 
174
+ def create_embedding(self, text: str) -> np.ndarray:
175
+ """Create embedding for query (ONLY used for new queries, not document loading)"""
176
+ inputs = _SHARED_TOKENIZER(text, padding=True, truncation=True,
177
+ max_length=512, return_tensors='pt').to(_DEVICE)
178
 
179
+ with torch.no_grad():
180
+ outputs = _SHARED_MODEL(**inputs)
181
+ attention_mask = inputs['attention_mask']
182
+ token_embeddings = outputs.last_hidden_state
183
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
184
+ embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
185
+
186
+ # Normalize embeddings
187
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
188
+
189
+ return embeddings.cpu().numpy()[0].astype('float32')
190
 
191
+ def analyze_query_fast(self, query: str) -> Dict[str, Any]:
192
+ """FAST query analysis - minimal processing"""
193
  query_lower = query.lower()
194
 
195
+ # Quick query type classification
196
  query_type = 'general'
197
  for qtype, patterns in QUERY_PATTERNS.items():
198
  if any(pattern in query_lower for pattern in patterns):
199
  query_type = qtype
200
  break
201
 
202
+ # Extract key concepts quickly
203
  key_concepts = []
204
  for concept_category, concepts in LEGAL_CONCEPTS.items():
205
  for concept in concepts:
206
  if concept in query_lower:
207
  key_concepts.append(concept)
208
 
209
+ # Simple query expansion
210
  expanded_queries = [query]
211
  if key_concepts:
212
+ expanded_queries.append(f"{query} {' '.join(key_concepts[:2])}")
213
 
214
  return {
215
  'original_query': query,
216
  'query_type': query_type,
 
217
  'key_concepts': key_concepts,
218
+ 'expanded_queries': expanded_queries[:2] # Limit to 2 for speed
219
  }
220
 
221
+ def fast_retrieval(self, query_analysis: Dict[str, Any], top_k: int = 10) -> List[Tuple[Dict[str, Any], float]]:
222
+ """OPTIMIZED: Fast multi-stage retrieval with minimal overhead"""
223
  candidates = {}
224
 
225
+ # Stage 1: Dense retrieval with primary query only
226
+ query = query_analysis['original_query']
227
+ query_emb = self.create_embedding(query)
228
+ scores, indices = self.dense_index.search(
229
+ query_emb.reshape(1, -1),
230
+ min(top_k * 2, len(self.chunks_data))
231
+ )
232
 
233
+ for idx, score in zip(indices[0], scores[0]):
234
+ if idx < len(self.chunks_data):
235
+ chunk = self.chunks_data[idx]
236
+ chunk_id = chunk['id']
237
+ candidates[chunk_id] = {
238
+ 'chunk': chunk,
239
+ 'score': float(score) * chunk['importance_score']
240
+ }
241
+
242
+ # Stage 2: BM25 boost for top candidates
243
+ if len(candidates) < top_k:
244
+ query_tokens = query.lower().split()
245
+ bm25_scores = self.bm25_index.get_scores(query_tokens)
246
+ top_bm25_indices = np.argsort(bm25_scores)[-top_k:][::-1]
247
 
248
+ for idx in top_bm25_indices:
249
  if idx < len(self.chunks_data):
250
+ chunk = self.chunks_data[idx]
251
  chunk_id = chunk['id']
252
  if chunk_id not in candidates:
253
+ candidates[chunk_id] = {
254
+ 'chunk': chunk,
255
+ 'score': float(bm25_scores[idx]) * 0.3 # Lower weight for BM25
256
+ }
257
  else:
258
+ candidates[chunk_id]['score'] += float(bm25_scores[idx]) * 0.2
259
 
260
+ # Convert to list and sort
261
+ final_scores = [(data['chunk'], data['score']) for data in candidates.values()]
262
  final_scores.sort(key=lambda x: x[1], reverse=True)
263
+
264
  return final_scores[:top_k]
265
 
266
+ def generate_fast_answer(self, query: str, retrieved_chunks: List[Tuple[Dict[str, Any], float]]) -> Dict[str, Any]:
267
+ """Generate answer with minimal overhead"""
268
  if not self.groq_client:
269
  return {'error': 'Groq client not initialized'}
270
 
271
+ # Prepare context efficiently
272
  context_parts = []
273
+ for i, (chunk, score) in enumerate(retrieved_chunks[:3], 1): # Limit to top 3 for speed
 
274
  context_parts.append(f"""
275
+ Document {i} - Relevance: {score:.2f}
276
+ {chunk['text'][:600]}
277
+ """)
278
 
279
  context = "\n---\n".join(context_parts)
280
 
281
+ system_prompt = """You are a legal AI assistant. Provide concise, accurate answers based ONLY on the provided documents. If information isn't in the documents, state that clearly."""
282
 
283
  user_prompt = f"""Query: {query}
284
 
285
+ Documents:
286
+ {context}
287
 
288
+ Provide a clear, concise answer based on the documents."""
289
 
290
  try:
291
  response = self.groq_client.chat.completions.create(
 
295
  ],
296
  model="llama-3.1-8b-instant",
297
  temperature=0.1,
298
+ max_tokens=500 # Limit for speed
299
  )
300
 
301
  answer = response.choices[0].message.content
 
313
  'title': chunk['title'],
314
  'section': chunk['section_type'],
315
  'relevance_score': float(score),
316
+ 'text_preview': chunk['text'][:200] + '...',
317
+ 'entities': [e['text'] for e in chunk['entities'][:3]]
318
  }
319
+ for chunk, score in retrieved_chunks[:5]
320
  ]
321
  }
322
 
323
  except Exception as e:
324
+ return {'error': f'Error generating answer: {str(e)}'}
325
 
326
  def query_documents(self, query: str, top_k: int = 5) -> Dict[str, Any]:
327
+ """OPTIMIZED: Main query function with minimal processing time"""
328
  if not self.chunks_data:
329
+ return {'error': f'No documents indexed for session {self.session_id}'}
330
 
331
+ start_time = time.time()
 
332
 
333
+ # Fast query analysis
334
+ query_analysis = self.analyze_query_fast(query)
335
+
336
+ # Fast retrieval
337
+ retrieved_chunks = self.fast_retrieval(query_analysis, top_k)
338
 
339
  if not retrieved_chunks:
340
  return {
 
343
  }
344
 
345
  # Generate answer
346
+ result = self.generate_fast_answer(query, retrieved_chunks)
347
  result['query_analysis'] = query_analysis
348
+ result['processing_time'] = time.time() - start_time
349
 
350
+ logger.info(f"Query processed in {result['processing_time']:.2f}s")
351
  return result
352
 
353
+ # For backward compatibility - replace SessionRAG with OptimizedSessionRAG
354
+ SessionRAG = OptimizedSessionRAG
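
A minimal usage sketch of the optimized flow, assuming rag.py is importable and that chunks stored in the database already carry pre-computed embeddings in the shape read by _process_db_chunks_fast; the model id, API key, session id, and example chunk values below are placeholders, not part of the committed code.

  # usage_sketch.py - illustrative only; model id, key, and stored chunk shape are assumed
  from rag import initialize_models, OptimizedSessionRAG

  # Load the shared embedding model, tokenizer, and thread pool once per process
  initialize_models("sentence-transformers/all-MiniLM-L6-v2", groq_api_key="YOUR_GROQ_KEY")

  # Chunks as they might come back from the database, each with a pre-computed embedding
  chunks_from_db = [
      {
          "chunk_id": "abc123",
          "content": "The court held that the defendant was negligent ...",
          "title": "Smith v. Jones",
          "section_type": "holding",
          "importance_score": 2.0,
          "entities": [],
          "embedding": [0.01] * 384,  # placeholder vector; must match the model's dimension
      },
  ]

  rag = OptimizedSessionRAG(session_id="session-1", groq_api_key="YOUR_GROQ_KEY")
  rag.load_existing_session_data(chunks_from_db)   # builds FAISS and BM25 indices from stored vectors
  result = rag.query_documents("What did the court hold on negligence?", top_k=5)
  print(result.get("answer"), result.get("processing_time"))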