D-RAG Phase 1 Checkpoints
Pre-trained retriever checkpoints and heuristics for D-RAG: Differentiable Retrieval-Augmented Generation.
📦 Contents
| File | Size | Description |
|---|---|---|
| checkpoints_cwq_subgraph/phase1_best.pt | 288 MB | CWQ retriever (27,613 samples, 10 epochs) |
| checkpoints_webqsp_subgraph/phase1_best.pt | 288 MB | WebQSP retriever (2,826 samples, 10 epochs) |
| data/train_heuristics_cwq.jsonl | 111 MB | CWQ heuristics with per-question subgraphs |
| data/train_heuristics_webqsp_subgraph.jsonl | 12 MB | WebQSP heuristics with per-question subgraphs |
Total size: ~700 MB
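The heuristics files are JSON Lines, so each non-empty line should parse as one JSON value. Their field names are not documented here, so it can be useful to look at the first record before writing any loader. A minimal sketch using only the standard library (the helper name `peek_jsonl` is hypothetical, not part of the repository):

```python
import json
from pathlib import Path


def peek_jsonl(path, n=1):
    """Return the first n records of a JSON Lines file as Python objects."""
    records = []
    with Path(path).open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            records.append(json.loads(line))
            if len(records) >= n:
                break  # stop early so large files are not read in full
    return records
```

For example, `peek_jsonl("data/train_heuristics_webqsp_subgraph.jsonl")[0]` returns the first record; if it is a dict, calling `.keys()` on it lists the available fields.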
🚀 Usage
Automatic Download (Recommended)
git clone https://github.com/rhordoan/drag-improved.git
cd drag-improved
./scripts/setup_environment.sh # Downloads checkpoints automatically
Manual Download
# Using Python script
python scripts/download_checkpoints.py
# Or using huggingface-cli
huggingface-cli download rhordoan/drag-improved-checkpoints \
checkpoints_cwq_subgraph/phase1_best.pt \
--local-dir .
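The same download can also be scripted from Python. The sketch below assumes the `huggingface_hub` package is installed and uses its `hf_hub_download` function; the constant and function names (`CHECKPOINT_FILES`, `REPO_ID`, `download_files`) are hypothetical and are not taken from `scripts/download_checkpoints.py`.

```python
# Files listed in the Contents table of this README.
CHECKPOINT_FILES = [
    "checkpoints_cwq_subgraph/phase1_best.pt",
    "checkpoints_webqsp_subgraph/phase1_best.pt",
    "data/train_heuristics_cwq.jsonl",
    "data/train_heuristics_webqsp_subgraph.jsonl",
]

# Repository id taken from the huggingface-cli command above.
REPO_ID = "rhordoan/drag-improved-checkpoints"


def download_files(filenames=CHECKPOINT_FILES, local_dir="."):
    """Download each file into local_dir and return the local paths."""
    # Imported lazily so this module can be imported without the dependency.
    from huggingface_hub import hf_hub_download

    return [
        hf_hub_download(repo_id=REPO_ID, filename=name, local_dir=local_dir)
        for name in filenames
    ]
```

Calling `download_files()` fetches all four files (about 700 MB); passing a shorter list, such as `download_files(CHECKPOINT_FILES[:1])`, fetches only the CWQ checkpoint.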
Load in Python
import torch
from src.model.retriever import DRAGRetriever
# Load checkpoint onto the CPU so it also works on machines without a GPU
checkpoint = torch.load('checkpoints_cwq_subgraph/phase1_best.pt', map_location='cpu')
# Initialize retriever
retriever = DRAGRetriever(
    node_dim=256,
    edge_dim=256,
    hidden_dim=256,
    instruction_dim=384,
    relation_dim=256,
    num_reasoning_steps=3,
)
# Load weights
retriever.load_state_dict(checkpoint['model_state_dict'])
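Before or after loading the weights, it can help to confirm what the checkpoint actually contains. A minimal sketch, assuming (as the loading snippet implies) that `checkpoint['model_state_dict']` maps parameter names to tensors; the helper name `summarize_state_dict` is hypothetical and relies only on each value exposing `numel()`:

```python
def summarize_state_dict(state_dict):
    """Return (number of entries, total element count) for a state dict."""
    total = 0
    for name, tensor in state_dict.items():
        total += tensor.numel()  # number of scalar elements in this tensor
    return len(state_dict), total
```

For example, `summarize_state_dict(checkpoint['model_state_dict'])` gives the entry count and total number of elements. The total can be compared against `sum(p.numel() for p in retriever.parameters())`, keeping in mind that a state dict also includes buffers, which `parameters()` does not.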
📊 Dataset Details
CWQ (ComplexWebQuestions)
- Source: rmanluo/RoG-cwq (Hugging Face)
- Samples: 27,613
- Training time: ~3.5 minutes on A100
- Final loss: 0.2616 (BCE: 0.092, Ranking: 0.656)
WebQSP (WebQuestions Semantic Parses)
- Source: rmanluo/RoG-webqsp (Hugging Face)
- Samples: 2,826
- Training time: ~30 seconds on A100
- Final loss: ~0.25
🏗️ Model Architecture
DRAGRetriever (GNN-based fact retriever):
- Instruction Module: Sentence-BERT encoder
- Graph Reasoning: 3 layers of instruction-conditioned message passing
- Instruction Update: Iterative refinement
- Fact Scorer: Binary selection per edge
Training config:
- Optimizer: AdamW (lr=5e-5, weight_decay=0.001)
- Loss: ρ × BCE + (1-ρ) × Ranking (ρ=0.7)
- Batch size: 16
- Gradient clipping: 1.0
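The loss weighting can be checked against the CWQ numbers reported under Dataset Details. The sketch below only reproduces the arithmetic of the formula ρ × BCE + (1-ρ) × Ranking; the function name `combined_loss` is hypothetical, and how the BCE and ranking terms themselves are computed is not shown here.

```python
def combined_loss(bce, ranking, rho=0.7):
    """Weighted sum rho * BCE + (1 - rho) * Ranking."""
    return rho * bce + (1.0 - rho) * ranking


# Component values reported for CWQ in the Dataset Details section.
cwq_total = combined_loss(bce=0.092, ranking=0.656)
print(round(cwq_total, 4))  # 0.2612
```

The result, 0.2612, is close to the reported final loss of 0.2616; a gap of that size is consistent with the two components having been rounded to three decimals.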
📜 Citation
@article{drag2024,
title={D-RAG: Differentiable Retrieval-Augmented Generation},
journal={arXiv preprint},
year={2024}
}
🔗 Links
- GitHub: https://github.com/rhordoan/drag-improved
- Paper: D-RAG Paper
- Datasets: RoG-CWQ | RoG-WebQSP
📄 License
MIT License - See repository for details.