"""Session-scoped hybrid RAG over pre-computed legal document embeddings."""

import logging
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict, Any, Tuple, Optional

import faiss
import nltk
import numpy as np
import spacy
import torch
from groq import Groq
from rank_bm25 import BM25Okapi
from transformers import AutoTokenizer, AutoModel

logger = logging.getLogger(__name__)

# Module-level singletons, loaded once via initialize_models() and shared
# across all sessions so each OptimizedSessionRAG instance stays lightweight.
_SHARED_MODEL = None
_SHARED_TOKENIZER = None
_SHARED_NLP_MODEL = None
_DEVICE = None
_THREAD_POOL = None

LEGAL_CONCEPTS = {
    'liability': ['negligence', 'strict liability', 'vicarious liability', 'product liability'],
    'contract': ['breach', 'consideration', 'offer', 'acceptance', 'damages', 'specific performance'],
    'criminal': ['mens rea', 'actus reus', 'intent', 'malice', 'premeditation'],
    'procedure': ['jurisdiction', 'standing', 'statute of limitations', 'res judicata'],
    'evidence': ['hearsay', 'relevance', 'privilege', 'burden of proof', 'admissibility'],
    'constitutional': ['due process', 'equal protection', 'free speech', 'search and seizure'],
}

QUERY_PATTERNS = {
    'precedent': ['case', 'precedent', 'ruling', 'held', 'decision'],
    'statute_interpretation': ['statute', 'section', 'interpretation', 'meaning', 'definition'],
    'factual': ['what happened', 'facts', 'circumstances', 'events'],
    'procedure': ['how to', 'procedure', 'process', 'filing', 'requirements'],
}


def initialize_models(model_id: str, groq_api_key: Optional[str] = None):
    """Initialize shared models; call once at startup, before any session is created.

    groq_api_key is unused here: Groq clients are created per session in
    OptimizedSessionRAG.
    """
    global _SHARED_MODEL, _SHARED_TOKENIZER, _SHARED_NLP_MODEL, _DEVICE, _THREAD_POOL

    # Tokenizer data for query analysis; download failures are non-fatal (e.g. offline hosts).
    try:
        nltk.download('punkt', quiet=True)
        nltk.download('stopwords', quiet=True)
    except Exception:
        pass

    _DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logger.info(f"Using device: {_DEVICE}")

    logger.info(f"Loading model: {model_id}")
    _SHARED_TOKENIZER = AutoTokenizer.from_pretrained(model_id)
    _SHARED_MODEL = AutoModel.from_pretrained(model_id).to(_DEVICE)
    _SHARED_MODEL.eval()

    _THREAD_POOL = ThreadPoolExecutor(max_workers=4)

    try:
        _SHARED_NLP_MODEL = spacy.load("en_core_web_sm")
    except OSError:
        logger.warning("spaCy model 'en_core_web_sm' not found; continuing without NER")
        _SHARED_NLP_MODEL = None


class OptimizedSessionRAG:
    """Session-specific RAG instance that serves queries from pre-computed embeddings."""

    def __init__(self, session_id: str, groq_api_key: Optional[str] = None):
        self.session_id = session_id
        self.groq_client = Groq(api_key=groq_api_key) if groq_api_key else None

        # Search indices, built from database chunks in load_existing_session_data().
        self.dense_index = None
        self.bm25_index = None
        self.token_to_chunks = None
        self.chunks_data = []

        # Timing diagnostics.
        self.load_time = None
        self.index_build_time = None

        if _SHARED_MODEL is None or _SHARED_TOKENIZER is None:
            raise ValueError("Models not initialized. Call initialize_models() first.")

    def load_existing_session_data(self, chunks_from_db: List[Dict[str, Any]]):
        """Load chunks whose embeddings were already computed; no embedding happens here."""
        start_time = time.time()
        logger.info(f"Loading existing session data for {self.session_id}: {len(chunks_from_db)} chunks...")

        self.chunks_data = self._process_db_chunks_fast(chunks_from_db)
        self._rebuild_indices_from_precomputed_embeddings()

        self.load_time = time.time() - start_time
        logger.info(f"Session {self.session_id} loaded in {self.load_time:.2f}s from pre-computed embeddings")

    def _process_db_chunks_fast(self, chunks_from_db: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert MongoDB chunk documents to the internal format without recomputing anything."""
        processed_chunks = []

        for chunk in chunks_from_db:
            embedding = chunk.get('embedding')
            if embedding is None:
                raise ValueError(f"Missing embedding for chunk {chunk.get('chunk_id', 'unknown')}")

            # Embeddings may be stored as plain lists; FAISS needs float32 arrays.
            if isinstance(embedding, list):
                embedding = np.array(embedding, dtype=np.float32)

            processed_chunks.append({
                'id': chunk.get('chunk_id', chunk.get('id')),
                'text': chunk.get('content', chunk.get('text', '')),
                'title': chunk.get('title', 'Document'),
                'section_type': chunk.get('section_type', 'general'),
                'importance_score': chunk.get('importance_score', 1.0),
                'entities': chunk.get('entities', []),
                'embedding': embedding,
            })

        return processed_chunks

    def _rebuild_indices_from_precomputed_embeddings(self):
        """Rebuild dense, BM25, and token-lookup indices from pre-computed embeddings."""
        if not self.chunks_data:
            raise ValueError("No chunks data available")

        start_time = time.time()
        logger.info(f"Rebuilding indices from {len(self.chunks_data)} pre-computed embeddings...")

        embeddings = []
        for chunk in self.chunks_data:
            if chunk['embedding'] is None:
                raise ValueError(f"Missing embedding for chunk {chunk.get('id', 'unknown')}")
            embeddings.append(chunk['embedding'])

        embeddings_matrix = np.vstack(embeddings).astype('float32')
        logger.info(f"Built embeddings matrix: {embeddings_matrix.shape}")

        # Inner-product index: equivalent to cosine similarity when the stored
        # embeddings are L2-normalized, as create_embedding() produces them.
        self.dense_index = faiss.IndexFlatIP(embeddings_matrix.shape[1])
        self.dense_index.add(embeddings_matrix)

        # Lexical index for keyword matching.
        tokenized_corpus = [chunk['text'].lower().split() for chunk in self.chunks_data]
        self.bm25_index = BM25Okapi(tokenized_corpus)

        # Inverted index: token -> set of chunk positions.
        self.token_to_chunks = defaultdict(set)
        for i, chunk in enumerate(self.chunks_data):
            for token in chunk['text'].lower().split():
                self.token_to_chunks[token].add(i)

        self.index_build_time = time.time() - start_time
        logger.info(f"All indices rebuilt in {self.index_build_time:.2f}s")

    def create_embedding(self, text: str) -> np.ndarray:
        """Embed a query string (documents arrive pre-embedded from the database)."""
        inputs = _SHARED_TOKENIZER(text, padding=True, truncation=True,
                                   max_length=512, return_tensors='pt').to(_DEVICE)

        with torch.no_grad():
            outputs = _SHARED_MODEL(**inputs)
            # Mean pooling over non-padding tokens only.
            attention_mask = inputs['attention_mask']
            token_embeddings = outputs.last_hidden_state
            input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

        # L2-normalize so inner-product search behaves like cosine similarity.
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings.cpu().numpy()[0].astype('float32')

    def analyze_query_fast(self, query: str) -> Dict[str, Any]:
        """Lightweight query analysis: classify the query and spot known legal concepts."""
        query_lower = query.lower()

        # Classify by the first pattern group that matches.
        query_type = 'general'
        for qtype, patterns in QUERY_PATTERNS.items():
            if any(pattern in query_lower for pattern in patterns):
                query_type = qtype
                break

        # Collect any known legal concepts mentioned in the query.
        key_concepts = []
        for concepts in LEGAL_CONCEPTS.values():
            for concept in concepts:
                if concept in query_lower:
                    key_concepts.append(concept)

        # Optionally expand the query with up to two detected concepts.
        expanded_queries = [query]
        if key_concepts:
            expanded_queries.append(f"{query} {' '.join(key_concepts[:2])}")

        return {
            'original_query': query,
            'query_type': query_type,
            'key_concepts': key_concepts,
            'expanded_queries': expanded_queries[:2],
        }

    def fast_retrieval(self, query_analysis: Dict[str, Any], top_k: int = 10) -> List[Tuple[Dict[str, Any], float]]:
        """Hybrid retrieval: dense search first, BM25 as a fallback when recall is short."""
        candidates = {}

        # Stage 1: dense search, weighted by per-chunk importance.
        query = query_analysis['original_query']
        query_emb = self.create_embedding(query)
        scores, indices = self.dense_index.search(
            query_emb.reshape(1, -1),
            min(top_k * 2, len(self.chunks_data))
        )

        for idx, score in zip(indices[0], scores[0]):
            # FAISS pads results with -1 when fewer than k vectors exist.
            if 0 <= idx < len(self.chunks_data):
                chunk = self.chunks_data[idx]
                candidates[chunk['id']] = {
                    'chunk': chunk,
                    'score': float(score) * chunk['importance_score']
                }

        # Stage 2: BM25 fallback when dense search returned too few candidates.
        # The 0.3 / 0.2 weights down-rank lexical-only matches relative to dense hits.
        if len(candidates) < top_k:
            query_tokens = query.lower().split()
            bm25_scores = self.bm25_index.get_scores(query_tokens)
            top_bm25_indices = np.argsort(bm25_scores)[-top_k:][::-1]

            for idx in top_bm25_indices:
                chunk = self.chunks_data[idx]
                chunk_id = chunk['id']
                if chunk_id not in candidates:
                    candidates[chunk_id] = {
                        'chunk': chunk,
                        'score': float(bm25_scores[idx]) * 0.3
                    }
                else:
                    candidates[chunk_id]['score'] += float(bm25_scores[idx]) * 0.2

        final_scores = [(data['chunk'], data['score']) for data in candidates.values()]
        final_scores.sort(key=lambda x: x[1], reverse=True)

        return final_scores[:top_k]

    def generate_fast_answer(self, query: str, retrieved_chunks: List[Tuple[Dict[str, Any], float]]) -> Dict[str, Any]:
        """Generate an answer from the top retrieved chunks via the Groq API."""
        if not self.groq_client:
            return {'error': 'Groq client not initialized'}

        # Build a compact context from the top three chunks, truncated to keep the prompt small.
        context_parts = []
        for i, (chunk, score) in enumerate(retrieved_chunks[:3], 1):
            context_parts.append(f"""
Document {i} - Relevance: {score:.2f}
{chunk['text'][:600]}
""")

        context = "\n---\n".join(context_parts)

        system_prompt = """You are a legal AI assistant. Provide concise, accurate answers based ONLY on the provided documents. If information isn't in the documents, state that clearly."""

        user_prompt = f"""Query: {query}

Documents:
{context}

Provide a clear, concise answer based on the documents."""

        try:
            response = self.groq_client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ],
                model="llama-3.1-8b-instant",
                temperature=0.1,
                max_tokens=500
            )

            answer = response.choices[0].message.content

            # Rough confidence heuristic: mean retrieval score of the chunks used, clamped to [0, 100].
            avg_score = sum(score for _, score in retrieved_chunks[:3]) / min(3, len(retrieved_chunks))
            confidence = max(0.0, min(avg_score * 100, 100))

            return {
                'answer': answer,
                'confidence': confidence,
                'sources': [
                    {
                        'chunk_id': chunk['id'],
                        'title': chunk['title'],
                        'section': chunk['section_type'],
                        'relevance_score': float(score),
                        'text_preview': chunk['text'][:200] + '...',
                        # Entities may be stored as dicts ({'text': ...}) or plain strings.
                        'entities': [e['text'] if isinstance(e, dict) else e for e in chunk['entities'][:3]]
                    }
                    for chunk, score in retrieved_chunks[:5]
                ]
            }

        except Exception as e:
            return {'error': f'Error generating answer: {str(e)}'}

    def query_documents(self, query: str, top_k: int = 5) -> Dict[str, Any]:
        """Answer a query end to end: analyze, retrieve, then generate."""
        if not self.chunks_data:
            return {'error': f'No documents indexed for session {self.session_id}'}

        start_time = time.time()

        query_analysis = self.analyze_query_fast(query)
        retrieved_chunks = self.fast_retrieval(query_analysis, top_k)

        if not retrieved_chunks:
            return {
                'error': 'No relevant documents found',
                'query_analysis': query_analysis
            }

        result = self.generate_fast_answer(query, retrieved_chunks)
        result['query_analysis'] = query_analysis
        result['processing_time'] = time.time() - start_time

        logger.info(f"Query processed in {result['processing_time']:.2f}s")
        return result


# Backwards-compatible alias for callers that still import SessionRAG.
SessionRAG = OptimizedSessionRAG
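

# --- Illustrative usage sketch (a minimal example, not part of the pipeline) ---
# Assumptions: the model id below is a placeholder (any Hugging Face encoder
# compatible with AutoModel works), the chunk record is synthetic, and its
# embedding is computed on the fly only because no database exists in this
# sketch; in production the embedding would already be stored with the chunk.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    initialize_models("sentence-transformers/all-MiniLM-L6-v2")  # placeholder model id

    rag = OptimizedSessionRAG(session_id="demo")  # no Groq key: retrieval only
    demo_chunk = {
        'chunk_id': 'c1',
        'content': 'The defendant was found liable for negligence after breaching a duty of care.',
        'title': 'Demo Opinion',
        'section_type': 'holding',
        'importance_score': 1.0,
        'entities': [],
    }
    # Embed here for the demo; real chunks arrive with 'embedding' pre-computed.
    demo_chunk['embedding'] = rag.create_embedding(demo_chunk['content'])
    rag.load_existing_session_data([demo_chunk])

    # generate_fast_answer() needs a Groq key, so exercise retrieval directly.
    analysis = rag.analyze_query_fast("Which case discusses negligence liability?")
    for chunk, score in rag.fast_retrieval(analysis, top_k=3):
        print(f"{score:.3f}  {chunk['title']}: {chunk['text'][:60]}")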