D-RAG Phase 1 Checkpoints

Pre-trained retriever checkpoints and heuristics for D-RAG: Differentiable Retrieval-Augmented Generation.

📦 Contents

File                                          Size     Description
checkpoints_cwq_subgraph/phase1_best.pt       288 MB   CWQ retriever (27,613 samples, 10 epochs)
checkpoints_webqsp_subgraph/phase1_best.pt    288 MB   WebQSP retriever (2,826 samples, 10 epochs)
data/train_heuristics_cwq.jsonl               111 MB   CWQ heuristics with per-question subgraphs
data/train_heuristics_webqsp_subgraph.jsonl   12 MB    WebQSP heuristics with per-question subgraphs

Total size: ~700 MB
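
Each heuristics file is JSON Lines, with one record per question. A minimal sketch for peeking at the schema (it only prints whatever keys the first record actually contains, so no field names are assumed here):

import json

# Inspect the first record of the CWQ heuristics file
with open('data/train_heuristics_cwq.jsonl') as f:
    first = json.loads(next(f))

print(sorted(first.keys()))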

🚀 Usage

Automatic Download (Recommended)

git clone https://github.com/rhordoan/drag-improved.git
cd drag-improved
./scripts/setup_environment.sh  # Downloads checkpoints automatically

Manual Download

# Using Python script
python scripts/download_checkpoints.py

# Or using huggingface-cli
huggingface-cli download rhordoan/drag-improved-checkpoints \
    checkpoints_cwq_subgraph/phase1_best.pt \
    --local-dir .
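
Alternatively, staying in Python, the same file can be fetched with the huggingface_hub library (a minimal sketch; swap the filename for the WebQSP checkpoint or the heuristics files as needed):

from huggingface_hub import hf_hub_download

# Download a single checkpoint into the current directory
path = hf_hub_download(
    repo_id="rhordoan/drag-improved-checkpoints",
    filename="checkpoints_cwq_subgraph/phase1_best.pt",
    local_dir=".",
)
print(path)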

Load in Python

import torch
from src.model.retriever import DRAGRetriever

# Load checkpoint (on CPU; move the model to GPU after loading if needed)
checkpoint = torch.load('checkpoints_cwq_subgraph/phase1_best.pt', map_location='cpu')

# Initialize retriever
retriever = DRAGRetriever(
    node_dim=256,
    edge_dim=256,
    hidden_dim=256,
    instruction_dim=384,
    relation_dim=256,
    num_reasoning_steps=3
)

# Load weights
retriever.load_state_dict(checkpoint['model_state_dict'])
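
Before scoring facts, switch the retriever to evaluation mode and move it to your device (a short usage sketch; the forward signature of DRAGRetriever is defined in the repository):

device = 'cuda' if torch.cuda.is_available() else 'cpu'
retriever.to(device)
retriever.eval()  # disable dropout for deterministic scoring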

📊 Dataset Details

CWQ (ComplexWebQuestions)

  • Source: rmanluo/RoG-cwq (Hugging Face)
  • Samples: 27,613
  • Training time: ~3.5 minutes on A100
  • Final loss: 0.2616 (BCE: 0.092, Ranking: 0.656)

WebQSP (WebQuestions Semantic Parses)

  • Source: rmanluo/RoG-webqsp (Hugging Face)
  • Samples: 2,826
  • Training time: ~30 seconds on A100
  • Final loss: ~0.25
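
Both source datasets are hosted on the Hugging Face Hub and can be loaded with the datasets library (a minimal sketch, assuming the default train split names):

from datasets import load_dataset

# KGQA source datasets used to build the heuristics above
cwq = load_dataset("rmanluo/RoG-cwq", split="train")
webqsp = load_dataset("rmanluo/RoG-webqsp", split="train")

print(len(cwq), len(webqsp))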

🏗️ Model Architecture

DRAGRetriever (GNN-based fact retriever):

  • Instruction Module: Sentence-BERT encoder
  • Graph Reasoning: 3 layers of instruction-conditioned message passing
  • Instruction Update: Iterative refinement
  • Fact Scorer: Binary selection per edge
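
To make the reasoning loop concrete, below is a heavily simplified sketch of one instruction-conditioned message-passing step. The gating formulation and tensor shapes are illustrative assumptions, not the actual DRAGRetriever implementation:

import torch
import torch.nn as nn

class InstructionConditionedLayer(nn.Module):
    """Illustrative message-passing layer gated by an instruction vector."""
    def __init__(self, node_dim=256, edge_dim=256, instruction_dim=384):
        super().__init__()
        self.msg = nn.Linear(node_dim + edge_dim, node_dim)
        self.gate = nn.Linear(instruction_dim, node_dim)   # instruction -> per-channel gate
        self.update = nn.GRUCell(node_dim, node_dim)

    def forward(self, x, edge_index, edge_attr, instruction):
        src, dst = edge_index                                  # [E], [E]
        m = self.msg(torch.cat([x[src], edge_attr], dim=-1))   # message per edge
        m = m * torch.sigmoid(self.gate(instruction))          # condition on the question
        agg = torch.zeros_like(x).index_add_(0, dst, m)        # sum incoming messages per node
        return self.update(agg, x)                             # GRU-style node update

# Toy graph: 4 nodes, 3 edges, one question instruction
x = torch.randn(4, 256)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
edge_attr = torch.randn(3, 256)
instruction = torch.randn(384)
print(InstructionConditionedLayer()(x, edge_index, edge_attr, instruction).shape)  # (4, 256)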

Training config:

  • Optimizer: AdamW (lr=5e-5, weight_decay=0.001)
  • Loss: ρ × BCE + (1-ρ) × Ranking (ρ=0.7)
  • Batch size: 16
  • Gradient clipping: 1.0
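
As a sketch, the combined objective can be written roughly as below; only the ρ = 0.7 mixing weight comes from the config above, while the margin and the exact ranking formulation are illustrative assumptions:

import torch
import torch.nn.functional as F

def phase1_loss(scores, labels, rho=0.7, margin=0.5):
    """rho * BCE + (1 - rho) * Ranking over per-edge fact logits."""
    bce = F.binary_cross_entropy_with_logits(scores, labels.float())
    # Illustrative pairwise ranking term: gold facts should outscore negatives by a margin
    pos, neg = scores[labels == 1], scores[labels == 0]
    ranking = F.relu(margin - (pos.unsqueeze(1) - neg.unsqueeze(0))).mean()
    return rho * bce + (1 - rho) * ranking

scores = torch.randn(10)                                # one logit per candidate fact
labels = torch.tensor([1, 0, 0, 1, 0, 0, 0, 1, 0, 0])   # toy gold-fact labels
print(phase1_loss(scores, labels))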

📜 Citation

@article{drag2024,
  title={D-RAG: Differentiable Retrieval-Augmented Generation},
  journal={arXiv preprint},
  year={2024}
}

📄 License

MIT License - See repository for details.
