import logging
from typing import Dict, List, Any, Optional, Tuple
from dataclasses import dataclass, field
import time
import os
import sys
# Make the project root importable when this module is run directly
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.embedding_system import EmbeddingSystem, RerankResult
from src.vector_store import QdrantVectorStore, SearchResult
from src.groq_client import LLMSystem
from src.document_processor import DocumentChunk
from src.utilites import load_yaml_config
# Prefer the project's custom logger; fall back to the stdlib logger if unavailable
try:
    from logger.custom_logger import CustomLoggerTracker
custom_log = CustomLoggerTracker()
logger = custom_log.get_logger("rag_engine")
except ImportError:
logger = logging.getLogger("rag_engine")
@dataclass
class Citation:
    """Provenance details for one retrieved chunk that supports an answer."""
    source_file: str
page_number: Optional[int] = None
worksheet_name: Optional[str] = None
cell_range: Optional[str] = None
section_title: Optional[str] = None
text_snippet: str = ""
confidence: float = 0.0
chunk_id: str = ""
@dataclass
class RAGResponse:
    """Full result of a RAG query: the answer plus citations, timings, and status flags."""
    answer: str
confidence_score: float
citations: List[Citation] = field(default_factory=list)
context_chunks: List[DocumentChunk] = field(default_factory=list)
processing_time: float = 0.0
retrieval_time: float = 0.0
generation_time: float = 0.0
rerank_time: float = 0.0
total_chunks_retrieved: int = 0
total_chunks_reranked: int = 0
model_used: str = ""
success: bool = True
error_message: Optional[str] = None
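# Illustrative shape of a populated response (all values below are made up,
# not produced by a real run):
#   RAGResponse(answer="Q1 revenue was $1.2M.", confidence_score=0.81,
#               citations=[Citation(source_file="report.pdf", page_number=3,
#                                   text_snippet="Q1 revenue totaled...", confidence=0.88)],
#               retrieval_time=0.4, rerank_time=0.2, generation_time=1.1,
#               model_used="<llm-model-name>")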
class RAGEngine:
    """End-to-end RAG pipeline: embed the query, retrieve, rerank, build context, generate, and cite."""

    def __init__(self, config: Dict[str, Any]):
self.config = config
# Initialize components
self.embedding_system = EmbeddingSystem(config)
self.vector_store = QdrantVectorStore(config)
self.llm_system = LLMSystem(config)
# RAG parameters
self.max_context_chunks = config.get('max_context_chunks', 5)
self.similarity_threshold = config.get('similarity_threshold', 0.7)
self.rerank_top_k = config.get('rerank_top_k', 20)
self.final_top_k = config.get('final_top_k', 5)
self.max_context_length = config.get('max_context_length', 4000)
logger.info(f"RAG engine initialized with max_context_chunks={self.max_context_chunks}")
    def answer_question(self, question: str, filters: Optional[Dict[str, Any]] = None) -> RAGResponse:
        """Answer a question end-to-end, returning the answer with citations and a timing breakdown."""
        start_time = time.time()
try:
logger.info(f"Processing question: {question[:100]}...")
# Step 1: Generate query embedding
query_embedding = self.embedding_system.generate_query_embedding(question)
if not query_embedding:
return RAGResponse(
answer="I apologize, but I'm unable to process your question due to an embedding generation error.",
confidence_score=0.0,
success=False,
error_message="Failed to generate query embedding")
# Step 2: Retrieve relevant chunks
retrieval_start = time.time()
search_results = self.vector_store.similarity_search(
query_embedding=query_embedding,
k=self.rerank_top_k,
filters=filters)
retrieval_time = time.time() - retrieval_start
if not search_results:
return RAGResponse(
answer="I couldn't find any relevant information to answer your question. Please try rephrasing your question or check if the relevant documents have been uploaded.",
confidence_score=0.0,
retrieval_time=retrieval_time,
processing_time=time.time() - start_time,
success=True)
logger.info(f"Retrieved {len(search_results)} chunks from vector store in {retrieval_time:.2f}s")
# Step 3: Rerank results
rerank_start = time.time()
reranked_chunks = self._rerank_chunks(question, search_results)
rerank_time = time.time() - rerank_start
# Step 4: Select top chunks and build context
context_chunks = reranked_chunks[:self.final_top_k]
context_text = self._build_context(context_chunks)
logger.info(f"Built context from top {len(context_chunks)} chunks")
# Step 5: Generate answer
generation_start = time.time()
answer = self.llm_system.answer_question(question, context_text)
generation_time = time.time() - generation_start
if not answer:
return RAGResponse(
answer="I apologize, but I was unable to generate an answer to your question.",
confidence_score=0.0,
retrieval_time=retrieval_time,
generation_time=generation_time,
rerank_time=rerank_time,
processing_time=time.time() - start_time,
success=False,
error_message="LLM failed to generate an answer")
logger.info(f"Generated answer in {generation_time:.2f}s")
# Step 6: Extract citations
citations = self._extract_citations(context_chunks)
logger.info(f"Extracted {len(citations)} citations")
# Step 7: Calculate confidence score
confidence_score = self._calculate_confidence_score(search_results, answer)
logger.info(f"Calculated confidence score: {confidence_score:.2f}")
total_time = time.time() - start_time
response = RAGResponse(
answer=answer,
confidence_score=confidence_score,
citations=citations,
context_chunks=[result.chunk for result in context_chunks],
processing_time=total_time,
retrieval_time=retrieval_time,
generation_time=generation_time,
rerank_time=rerank_time,
total_chunks_retrieved=len(search_results),
total_chunks_reranked=len(reranked_chunks),
model_used=self.llm_system.default_model,
success=True)
logger.info(f"Question answered successfully in {total_time:.2f}s")
return response
except Exception as e:
error_msg = f"RAG processing failed: {str(e)}"
logger.error(error_msg)
return RAGResponse(
answer="I apologize, but I encountered an error while processing your question. Please try again.",
confidence_score=0.0,
processing_time=time.time() - start_time,
success=False,
error_message=error_msg
)
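    # Typical call (the filter key below is an assumption for illustration; real
    # keys depend on what QdrantVectorStore indexes):
    #   response = engine.answer_question("What is the total Q1 revenue?",
    #                                     filters={"document_id": "report.pdf"})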
    def get_relevant_context(self, question: str, k: int = 5,
                             filters: Optional[Dict[str, Any]] = None) -> List[DocumentChunk]:
        """Retrieve and rerank relevant chunks for a question without generating an answer."""
        try:
query_embedding = self.embedding_system.generate_query_embedding(question)
if not query_embedding:
return []
# Retrieve and rerank
search_results = self.vector_store.similarity_search(
query_embedding=query_embedding,
k=min(k * 2, self.rerank_top_k),
filters=filters)
if not search_results:
return []
# Rerank and return top k
reranked_chunks = self._rerank_chunks(question, search_results)
return [result.chunk for result in reranked_chunks[:k]]
except Exception as e:
logger.error(f"Failed to get relevant context: {e}")
return []
    def _rerank_chunks(self, question: str, search_results: List[SearchResult]) -> List[SearchResult]:
        """Rerank retrieved chunks via the embedding system's reranker; keep the original order on failure."""
        try:
if len(search_results) <= 1:
return search_results
# Extract documents for reranking
documents = [result.chunk.content for result in search_results]
# Perform reranking
rerank_results = self.embedding_system.rerank_results(
query=question,
documents=documents,
top_k=len(documents)
)
# Map rerank results back to search results
reranked_search_results = []
for rerank_result in rerank_results:
# Find corresponding search result
original_index = rerank_result.index
if 0 <= original_index < len(search_results):
search_result = search_results[original_index]
search_result.rerank_score = rerank_result.score
reranked_search_results.append(search_result)
logger.debug(f"Reranked {len(search_results)} chunks")
return reranked_search_results
except Exception as e:
logger.warning(f"Reranking failed, using original order: {e}")
return search_results
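    # Illustrative mapping: if the reranker returns its best hit as
    # RerankResult(index=2, score=0.91), the third retrieved chunk moves to the
    # front of the returned list with rerank_score=0.91 attached.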
def _build_context(self, search_results: List[SearchResult]) -> str:
"""
Build context text from search results.
Args:
search_results: List of search results
Returns:
Formatted context text
"""
context_parts = []
current_length = 0
for i, result in enumerate(search_results):
chunk = result.chunk
# Create context entry with citation info
citation_info = self._format_citation_info(chunk)
content = f"[Source {i+1}: {citation_info}]\n{chunk.content}\n"
# Check if adding this chunk would exceed max context length
if current_length + len(content) > self.max_context_length:
                # Try to fit a truncated version; the extra 20 chars cover the ellipsis and newlines
                remaining_space = self.max_context_length - current_length - len(f"[Source {i+1}: {citation_info}]\n") - 20
if remaining_space > 100: # Only add if we have reasonable space
truncated_content = chunk.content[:remaining_space] + "..."
content = f"[Source {i+1}: {citation_info}]\n{truncated_content}\n"
context_parts.append(content)
break
context_parts.append(content)
current_length += len(content)
return "\n".join(context_parts)
def _format_citation_info(self, chunk: DocumentChunk) -> str:
"""
Format citation information for a chunk.
Args:
chunk: Document chunk
Returns:
Formatted citation string
"""
parts = []
# Add document ID or filename if available
if hasattr(chunk.metadata, 'document_id'):
parts.append(f"Doc: {chunk.metadata.document_id}")
# Add page number for PDFs
if chunk.metadata.page_number:
parts.append(f"Page {chunk.metadata.page_number}")
# Add worksheet info for Excel
if chunk.metadata.worksheet_name:
parts.append(f"Sheet: {chunk.metadata.worksheet_name}")
if chunk.metadata.cell_range:
parts.append(f"Range: {chunk.metadata.cell_range}")
# Add section title if available
if chunk.metadata.section_title:
parts.append(f"Section: {chunk.metadata.section_title}")
return ", ".join(parts) if parts else "Unknown source"
def _extract_citations(self, search_results: List[SearchResult]) -> List[Citation]:
citations = []
for result in search_results:
chunk = result.chunk
# Create citation
citation = Citation(
source_file=getattr(chunk.metadata, 'document_id', 'Unknown'),
page_number=chunk.metadata.page_number,
worksheet_name=chunk.metadata.worksheet_name,
cell_range=chunk.metadata.cell_range,
section_title=chunk.metadata.section_title,
                text_snippet=(chunk.content[:200] + "...") if len(chunk.content) > 200 else chunk.content,
confidence=result.similarity_score,
chunk_id=chunk.metadata.chunk_id
)
citations.append(citation)
return citations
    def _calculate_confidence_score(self, search_results: List[SearchResult], answer: str) -> float:
        """Heuristic confidence in [0, 1] mixing similarity, source count, and answer length."""
        if not search_results:
return 0.0
# Base confidence on similarity scores
similarity_scores = [result.similarity_score for result in search_results]
avg_similarity = sum(similarity_scores) / len(similarity_scores)
# Adjust based on number of sources
source_factor = min(len(search_results) / self.final_top_k, 1.0)
# Adjust based on answer length (very short answers might be less reliable)
length_factor = min(len(answer) / 100, 1.0) if answer else 0.0
# Combine factors
confidence = (avg_similarity * 0.6 + source_factor * 0.2 + length_factor * 0.2)
return min(max(confidence, 0.0), 1.0) # Clamp to [0, 1]
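    # Worked example (illustrative numbers): five chunks scoring
    # [0.82, 0.78, 0.75, 0.71, 0.66] give avg_similarity = 0.744; with
    # final_top_k = 5, source_factor = 1.0; an answer of 100+ characters gives
    # length_factor = 1.0. Confidence = 0.744*0.6 + 1.0*0.2 + 1.0*0.2 ≈ 0.85.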
def health_check(self) -> Dict[str, bool]:
return {
"vector_store": self.vector_store.health_check(),
"llm_system": self.llm_system.client.health_check(),
"embedding_system": True # Silicon Flow doesn't have a direct health check
}
def get_stats(self) -> Dict[str, Any]:
try:
vector_stats = self.vector_store.get_collection_info()
embedding_stats = self.embedding_system.get_cache_stats()
return {
"vector_store": vector_stats.__dict__ if vector_stats else {},
"embedding_cache": embedding_stats,
"config": {
"max_context_chunks": self.max_context_chunks,
"similarity_threshold": self.similarity_threshold,
"rerank_top_k": self.rerank_top_k,
"final_top_k": self.final_top_k
}
}
except Exception as e:
logger.error(f"Failed to get RAG stats: {e}")
return {"error": str(e)}
if __name__ == "__main__":
from src.utilites import validate_api_keys
validation_results = validate_api_keys()
if not validation_results['valid']:
logger.error("Missing required API keys. Please set them in the environment variables.")
else:
logger.info("All required API keys are present.")
        # Example usage
config = load_yaml_config("src/config.yaml")
rag_engine = RAGEngine(config)
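        # Minimal end-to-end sketch (assumes documents were already ingested into
        # the vector store; the question and any output are illustrative):
        response = rag_engine.answer_question("What was the total Q1 budget?")
        logger.info(f"Answer (confidence {response.confidence_score:.2f}): {response.answer}")
        for citation in response.citations:
            logger.info(f"  Cited: {citation.source_file} -> {citation.text_snippet[:60]}")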