import os
from huggingface_hub import login
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import gradio as gr
print("🚀 Starting Legal LLaMA RAG Space...")
# ---------------------------
# Step 1: Login to Hugging Face
# ---------------------------
token = os.getenv("HF_TOKEN", "").strip()
if not token:
    raise ValueError("❌ HF_TOKEN missing. Add it in: Settings → Secrets → HF_TOKEN")
login(token=token)
print("🔐 Logged into Hugging Face Hub!")
# ---------------------------
# Step 2: Load Sample Caselaw Data
# ---------------------------
print("📚 Loading a small local sample instead of the full dataset (avoids OOM)...")
ds = [
    {
        "text": "This is a sample case about contract disputes under the Indian Contract Act.",
        "case_name": "Sample Contract Case",
        "court": "Supreme Court of India",
        "date": "2020",
    },
    {
        "text": "This case discusses negligence under tort law.",
        "case_name": "Negligence Case",
        "court": "Delhi High Court",
        "date": "2018",
    },
]
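# Each record carries the fields the metadata step below expects: text, case_name, court, date.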
print("✅ Dataset loaded!")
# ---------------------------
# Step 3: Preprocess & Chunk
# ---------------------------
def chunk_text(text, size=500):
    """Split text into fixed-width character chunks."""
    return [text[i:i + size] for i in range(0, len(text), size)]
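# e.g. chunk_text("abcdef", size=4) -> ["abcd", "ef"]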
documents = []
metadatas = []
for item in ds:
    chunks = chunk_text(item["text"])
    for chunk in chunks:
        documents.append(chunk)
        metadatas.append({
            "case_name": item.get("case_name", "N/A"),
            "court": item.get("court", "N/A"),
            "date": item.get("date", "N/A"),
        })
print(f"🧩 Total chunks created: {len(documents)}")
# ---------------------------
# Step 4: Create Embeddings & FAISS Index
# ---------------------------
print("⚙️ Creating embeddings and FAISS index...")
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embeddings = embed_model.encode(documents, show_progress_bar=True)
embeddings = np.array(embeddings).astype("float32")
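# FAISS expects float32 input; sentence-transformers already returns float32, but the cast
# makes it explicit. IndexFlatL2 below does exact brute-force L2 search, which is fine at
# this tiny scale; for a large corpus an approximate index (e.g. faiss.IndexIVFFlat or
# faiss.IndexHNSWFlat) would be the usual swap.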
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
print("✅ FAISS index ready!")
# ---------------------------
# Step 5: Retrieval Function
# ---------------------------
def answer_query(query, k=3):
    """Embed the query, retrieve the k nearest chunks, and format an answer with citations."""
    query_vec = embed_model.encode([query]).astype("float32")
    k = min(k, index.ntotal)  # FAISS pads results with index -1 when k exceeds ntotal, which would mis-index documents
    D, I = index.search(query_vec, k)  # D: squared L2 distances, I: chunk indices
    retrieved_chunks = [documents[i] for i in I[0]]
    retrieved_citations = [metadatas[i] for i in I[0]]
    answer = "Based on retrieved legal cases:\n"
    for i, chunk in enumerate(retrieved_chunks):
        answer += (
            f"\n- Case: {retrieved_citations[i]['case_name']} "
            f"({retrieved_citations[i]['court']} - {retrieved_citations[i]['date']})\n"
            + chunk[:300] + "...\n"
        )
    return {
        "input": query,
        "output": answer,
        "citations": retrieved_citations,
    }
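# Quick manual check (hypothetical query, independent of the Gradio UI):
#   answer_query("What counts as negligence under tort law?")["citations"]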
# ---------------------------
# Step 6: Gradio Interface
# ---------------------------
demo = gr.Interface(
    fn=answer_query,
    inputs="text",
    outputs="json",
    title="Legal LLaMA RAG Demo",
    description="Ask a legal question and get evidence-backed case citations.",
)
print("✅ Demo ready – launching now")
if __name__ == "__main__":
    demo.launch()