Spaces:
Runtime error
Runtime error
| """ | |
| Law RAG Chatbot API using FastAPI, Langchain, Groq, and ChromaDB | |
| """ | |
import logging
import os
import time
from typing import List, Dict, Any, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pydantic import Field

from config import *
from rag_system import RAGSystem
from session_manager import SessionManager
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app.
# API_TITLE / API_VERSION / API_DESCRIPTION are not defined in this file;
# presumably they are provided by `from config import *`.
app = FastAPI(
    title=API_TITLE,
    version=API_VERSION,
    description=API_DESCRIPTION,
)

# Add CORS middleware.
# Every origin is allowed, so credentialed requests are disabled: a wildcard
# origin must not be combined with allow_credentials=True (see the FastAPI
# CORS documentation). The endpoints visible in this file identify sessions
# through an explicit `session_id` value, not through cookies.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Pydantic models for API requests/responses
class ChatRequest(BaseModel):
    """Request body for the chat endpoint.

    Validation happens before the handler runs: an empty question or a
    non-positive context length is rejected instead of being forwarded to
    the RAG system. Whitespace-only questions are still accepted.
    """

    # The user's question; must contain at least one character.
    question: str = Field(..., min_length=1)
    # Forwarded to the RAG system as `context_length`; must be >= 1.
    context_length: int = Field(5, ge=1)  # Increased default context length
    # Session to continue; the chat handler creates a new one when omitted.
    session_id: Optional[str] = None
class ChatResponse(BaseModel):
    """Response body produced by the chat endpoint."""

    # Answer text taken from the RAG system's response.
    answer: str
    # Source entries taken from the RAG system's response.
    sources: List[Dict[str, Any]]
    # Confidence value taken from the RAG system's response.
    confidence: float
    # Wall-clock seconds measured by the chat handler.
    processing_time: float
    # The question echoed back from the request.
    question: str
    # Session the exchange was stored under.
    session_id: str
    # Number of history entries the chat handler counted for the session.
    chat_history_count: int
class SessionCreateRequest(BaseModel):
    """Request body for creating a chat session; every field is optional."""

    # Free-form user description, passed to the session manager as-is.
    user_info: Optional[str] = None
    # Arbitrary key/value data, passed to the session manager as-is.
    metadata: Optional[Dict[str, Any]] = None
class SessionResponse(BaseModel):
    """Response body describing a newly created session.

    `user_info` and `metadata` accept None because SessionCreateRequest
    allows both to be omitted and create_session() echoes back whatever the
    session manager stored; requiring them here would turn such a session
    into a validation failure.
    """

    # Identifier returned by the session manager.
    session_id: str
    # Creation time as an ISO-8601 string (built with `.isoformat()`).
    created_at: str
    # User description stored with the session, if any.
    user_info: Optional[str] = None
    # Metadata stored with the session, if any.
    metadata: Optional[Dict[str, Any]] = None
class HealthResponse(BaseModel):
    """Response body shared by the root and health-check endpoints."""

    # Overall status label, e.g. "healthy".
    status: str
    # Human-readable summary.
    message: str
    # Component name -> state label (e.g. "running", "not_initialized").
    components: Dict[str, str]
# Global instances. Both start as None and are assigned by startup_event(),
# which is why the request handlers check them for truthiness first.
rag_system: Optional[RAGSystem] = None
session_manager: Optional[SessionManager] = None
async def startup_event():
    """Create the module-level RAG system and session manager.

    The RAG system is constructed and its `initialize()` coroutine awaited
    first; the session manager is constructed afterwards. Any failure is
    logged and re-raised unchanged.

    NOTE(review): no startup-hook registration is visible in this view;
    confirm that this coroutine is wired to the app.
    """
    global rag_system, session_manager

    try:
        logger.info("Initializing RAG system...")
        rag_system = RAGSystem()
        await rag_system.initialize()
        logger.info("RAG system initialized successfully")

        logger.info("Initializing session manager...")
        session_manager = SessionManager()
        logger.info("Session manager initialized successfully")
    except Exception as init_error:
        logger.error(f"Failed to initialize systems: {init_error}")
        raise
async def root():
    """Report a per-component status summary.

    Always answers with status "healthy"; the individual component states
    reflect whether the global RAG system and session manager are set and
    whether the RAG system reports itself ready.
    """
    rag_state = "running" if rag_system else "not_initialized"
    session_state = "running" if session_manager else "not_initialized"
    # is_ready() is only consulted when a RAG system instance exists.
    if rag_system and rag_system.is_ready():
        vector_db_state = "connected"
    else:
        vector_db_state = "disconnected"

    component_states = {
        "api": "running",
        "rag_system": rag_state,
        "session_manager": session_state,
        "vector_db": vector_db_state,
    }
    return HealthResponse(
        status="healthy",
        message="Law RAG Chatbot API is running",
        components=component_states,
    )
async def health_check():
    """Strict health check that fails with 503 unless everything is available.

    Raises:
        HTTPException: 503 when the RAG system is missing or not ready, or
            when the session manager is missing (checked in that order).
    """
    # Pick the first failing precondition, if any.
    if not rag_system:
        problem = "RAG system not initialized"
    elif not rag_system.is_ready():
        problem = "RAG system not ready"
    elif not session_manager:
        problem = "Session manager not initialized"
    else:
        problem = None

    if problem is not None:
        raise HTTPException(status_code=503, detail=problem)

    component_states = {
        "api": "running",
        "rag_system": "ready",
        "session_manager": "ready",
        "vector_db": "connected",
        "embeddings": "ready",
        "llm": "ready",
    }
    return HealthResponse(
        status="healthy",
        message="All systems operational",
        components=component_states,
    )
async def create_session(request: SessionCreateRequest):
    """Create a new chat session and return its details.

    Args:
        request: Optional user info and metadata to store with the session.

    Returns:
        SessionResponse built from the session the manager just created.

    Raises:
        HTTPException: 503 if the session manager is not initialized,
            500 if creating or reading back the session fails.
    """
    if not session_manager:
        raise HTTPException(status_code=503, detail="Session manager not initialized")
    try:
        session_id = session_manager.create_session(
            user_info=request.user_info,
            metadata=request.metadata,
        )
        # Read the session back so the response reflects what was stored.
        session = session_manager.get_session(session_id)
        return SessionResponse(
            session_id=session_id,
            created_at=session["created_at"].isoformat(),
            user_info=session["user_info"],
            metadata=session["metadata"],
        )
    except Exception as exc:
        # logger.exception records the traceback, which would otherwise be
        # lost once the error is converted into an HTTP 500 response.
        logger.exception("Error creating session: %s", exc)
        raise HTTPException(
            status_code=500, detail=f"Failed to create session: {str(exc)}"
        ) from exc
async def get_session_info(session_id: str):
    """Return the session manager's statistics for one session.

    Args:
        session_id: Identifier of the session to look up.

    Returns:
        Whatever `session_manager.get_session_stats()` returns for the session.

    Raises:
        HTTPException: 503 if the session manager is not initialized,
            404 if no stats are returned for the session, 500 on any other
            failure.
    """
    if not session_manager:
        raise HTTPException(status_code=503, detail="Session manager not initialized")
    try:
        session_stats = session_manager.get_session_stats(session_id)
        if not session_stats:
            raise HTTPException(status_code=404, detail="Session not found")
        return session_stats
    except HTTPException:
        # Let the 404 above pass through instead of rewrapping it as a 500.
        raise
    except Exception as exc:
        # logger.exception records the traceback, which would otherwise be
        # lost once the error is converted into an HTTP 500 response.
        logger.exception("Error getting session info: %s", exc)
        raise HTTPException(
            status_code=500, detail=f"Failed to get session info: {str(exc)}"
        ) from exc
async def get_chat_history(session_id: str, limit: int = 10):
    """Return chat history entries for a session.

    Args:
        session_id: Identifier of the session whose history is requested.
        limit: Passed to the session manager as the history limit.

    Returns:
        Dict with the session id, the history entries and `total`, which is
        the number of entries returned (not necessarily the full history).

    Raises:
        HTTPException: 503 if the session manager is not initialized,
            404 if the session does not exist, 500 on any other failure.
    """
    if not session_manager:
        raise HTTPException(status_code=503, detail="Session manager not initialized")
    try:
        # Same existence check the chat endpoint uses: get_session() is
        # treated as falsy for an unknown session.
        if not session_manager.get_session(session_id):
            raise HTTPException(status_code=404, detail="Session not found")
        history = session_manager.get_chat_history(session_id, limit)
        return {
            "session_id": session_id,
            "history": history,
            "total": len(history),
        }
    except HTTPException:
        # Let the 404 above pass through instead of rewrapping it as a 500.
        raise
    except Exception as exc:
        # logger.exception records the traceback, which would otherwise be
        # lost once the error is converted into an HTTP 500 response.
        logger.exception("Error getting chat history: %s", exc)
        raise HTTPException(
            status_code=500, detail=f"Failed to get chat history: {str(exc)}"
        ) from exc
async def chat(request: ChatRequest):
    """Answer a legal question through the RAG system and record the exchange.

    A new session is created when the request carries no `session_id`;
    otherwise the supplied session must already exist.

    Args:
        request: The question, the context length and an optional session id.

    Returns:
        ChatResponse with the answer, sources and confidence from the RAG
        system, the elapsed processing time, the echoed question, the session
        id and a history count.

    Raises:
        HTTPException: 503 if the RAG system or session manager is
            unavailable, 404 if the session cannot be found, 500 on any
            other failure.
    """
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    if not rag_system.is_ready():
        raise HTTPException(status_code=503, detail="RAG system not ready")
    if not session_manager:
        raise HTTPException(status_code=503, detail="Session manager not initialized")
    try:
        start_time = time.time()

        # Handle session: work on a local id rather than mutating the
        # incoming request model.
        session_id = request.session_id
        if not session_id:
            # Create new session if none provided
            session_id = session_manager.create_session()
            logger.info(f"Created new session: {session_id}")

        # Verify session exists
        if not session_manager.get_session(session_id):
            raise HTTPException(status_code=404, detail="Session not found")

        # Get response from RAG system
        response = await rag_system.get_response(
            question=request.question,
            context_length=request.context_length,
        )
        # Measured before the exchange is stored, so storage time is excluded.
        processing_time = time.time() - start_time

        # Store chat response in session
        session_manager.store_chat_response(
            session_id=session_id,
            question=request.question,
            answer=response["answer"],
            sources=response["sources"],
            confidence=response["confidence"],
            processing_time=processing_time,
        )

        # Get chat history count.
        # NOTE(review): with limit=1 the fetched list presumably holds at most
        # one entry, so this count would never exceed 1. Reporting a true
        # total needs a count API on SessionManager; confirm intended value.
        chat_history = session_manager.get_chat_history(session_id, limit=1)
        chat_history_count = len(chat_history)

        return ChatResponse(
            answer=response["answer"],
            sources=response["sources"],
            confidence=response["confidence"],
            processing_time=processing_time,
            question=request.question,
            session_id=session_id,
            chat_history_count=chat_history_count,
        )
    except HTTPException:
        # Let the 404 above pass through instead of rewrapping it as a 500.
        raise
    except Exception as exc:
        # logger.exception records the traceback, which would otherwise be
        # lost once the error is converted into an HTTP 500 response.
        logger.exception("Error processing chat request: %s", exc)
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(exc)}"
        ) from exc
async def search(query: str, limit: int = 5, session_id: Optional[str] = None):
    """Search for relevant legal documents, optionally recording the query.

    Args:
        query: Search text forwarded to the RAG system.
        limit: Forwarded to the RAG system as the result limit.
        session_id: When given (and a session manager exists), the query and
            its result count are stored against this session.

    Returns:
        Dict with the query, the results, their count and the session id.

    Raises:
        HTTPException: 503 if the RAG system is not initialized, 500 if the
            search or the query bookkeeping fails.
    """
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    try:
        results = await rag_system.search_documents(query, limit)
        # Store search query if session provided
        if session_id and session_manager:
            session_manager.store_search_query(session_id, query, len(results))
        return {
            "query": query,
            "results": results,
            "total": len(results),
            "session_id": session_id,
        }
    except Exception as exc:
        # logger.exception records the traceback, which would otherwise be
        # lost once the error is converted into an HTTP 500 response.
        logger.exception("Error in search: %s", exc)
        raise HTTPException(
            status_code=500, detail=f"Search failed: {str(exc)}"
        ) from exc
async def get_stats():
    """Return RAG system statistics merged with session-manager status.

    Returns:
        The mapping from `rag_system.get_stats()` with session-manager keys
        added; on a key clash the session-manager value wins.

    Raises:
        HTTPException: 503 if the RAG system is not initialized, 500 if
            collecting the statistics fails.
    """
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    try:
        rag_stats = await rag_system.get_stats()
        # Add session statistics
        if session_manager:
            # Get total sessions count (this would need to be implemented in session manager)
            session_stats = {
                "session_manager": "active",
                "total_sessions": "available",  # Could implement actual count
            }
        else:
            session_stats = {"session_manager": "not_initialized"}
        return {**rag_stats, **session_stats}
    except Exception as exc:
        # logger.exception records the traceback, which would otherwise be
        # lost once the error is converted into an HTTP 500 response.
        logger.exception("Error getting stats: %s", exc)
        raise HTTPException(
            status_code=500, detail=f"Failed to get stats: {str(exc)}"
        ) from exc
async def reindex():
    """Reindex the vector database by awaiting `rag_system.reindex()`.

    NOTE(review): no authentication or authorization check is visible on
    this operation; confirm access is restricted elsewhere.

    Returns:
        Dict with a success message once reindexing has finished.

    Raises:
        HTTPException: 503 if the RAG system is not initialized, 500 if
            reindexing fails.
    """
    if not rag_system:
        raise HTTPException(status_code=503, detail="RAG system not initialized")
    try:
        await rag_system.reindex()
        return {"message": "Reindexing completed successfully"}
    except Exception as exc:
        # logger.exception records the traceback, which would otherwise be
        # lost once the error is converted into an HTTP 500 response.
        logger.exception("Error in reindexing: %s", exc)
        raise HTTPException(
            status_code=500, detail=f"Reindexing failed: {str(exc)}"
        ) from exc
async def delete_session(session_id: str):
    """Delete a session and all its data via the session manager.

    Args:
        session_id: Identifier of the session to delete.

    Returns:
        Dict with a confirmation message naming the session.

    Raises:
        HTTPException: 503 if the session manager is not initialized,
            500 if the deletion fails.
    """
    if not session_manager:
        raise HTTPException(status_code=503, detail="Session manager not initialized")
    try:
        session_manager.delete_session(session_id)
        return {"message": f"Session {session_id} deleted successfully"}
    except Exception as exc:
        # logger.exception records the traceback, which would otherwise be
        # lost once the error is converted into an HTTP 500 response.
        logger.exception("Error deleting session: %s", exc)
        raise HTTPException(
            status_code=500, detail=f"Failed to delete session: {str(exc)}"
        ) from exc
if __name__ == "__main__":
    # Auto-reload is a development convenience. It stays enabled by default
    # to preserve the existing behaviour, but can be switched off (e.g. in
    # production) by setting UVICORN_RELOAD to 0, false, no or off.
    reload_enabled = os.getenv("UVICORN_RELOAD", "true").strip().lower() not in {
        "0",
        "false",
        "no",
        "off",
    }
    # The "app:app" import string assumes this module is importable as `app`.
    uvicorn.run(
        "app:app",
        host=HOST,
        port=PORT,
        reload=reload_enabled,
        log_level="info",
    )