# Document Classifier: XGBoost + TF-IDF
A lightweight, high-performance document classification model trained on the RVL-CDIP Small dataset.
It classifies scanned, OCR-processed documents into their categories using TF-IDF features (word and character n-grams) combined with handcrafted numeric heuristic features, all fed into an XGBoost classifier.
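To make the two TF-IDF views concrete, here is a minimal pure-Python sketch of the n-grams each vectorizer would count. It assumes scikit-learn's `TfidfVectorizer` defaults (lowercasing, token pattern `(?u)\b\w\w+\b`) and assumes the character vectorizer uses `analyzer="char"` rather than `"char_wb"`. The bundled `word_vectorizer` and `char_vectorizer` may be configured differently, and the helper names `word_ngrams` and `char_ngrams` are hypothetical. TF-IDF then weights each n-gram by its count in the document, discounted by how common it is across the training corpus.

```python
import re

# Assumption: mirrors scikit-learn TfidfVectorizer defaults; the bundled
# vectorizers may use different settings.
TOKEN_PATTERN = re.compile(r"(?u)\b\w\w+\b")
MULTI_WHITESPACE = re.compile(r"\s\s+")


def word_ngrams(text: str, min_n: int = 1, max_n: int = 2) -> list[str]:
    """Word n-grams of length min_n..max_n over regex tokens (2+ word chars)."""
    tokens = TOKEN_PATTERN.findall(text.lower())
    grams = []
    for n in range(min_n, max_n + 1):
        for i in range(len(tokens) - n + 1):
            grams.append(" ".join(tokens[i:i + n]))
    return grams


def char_ngrams(text: str, min_n: int = 3, max_n: int = 5) -> list[str]:
    """Character n-grams of length min_n..max_n.

    Runs of two or more whitespace characters are collapsed to one space first.
    """
    normalized = MULTI_WHITESPACE.sub(" ", text.lower())
    grams = []
    for n in range(min_n, max_n + 1):
        for i in range(len(normalized) - n + 1):
            grams.append(normalized[i:i + n])
    return grams


print(word_ngrams("Invoice No. 12345"))
# ['invoice', 'no', '12345', 'invoice no', 'no 12345']
print(char_ngrams("Total", 3, 3))
# ['tot', 'ota', 'tal']
```

Character n-grams are robust to OCR noise: a misread letter breaks one word token but only a few of the overlapping character n-grams.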
## Model Architecture
| Component | Details |
|---|---|
| Classifier | XGBoost (XGBClassifier) |
| Text features | TF-IDF word n-grams (1-2), char n-grams (3-5) |
| Numeric features | char_count, digit_count, uppercase_count, currency_count, line_count |
| Scaler | StandardScaler (on numeric features) |
| Training rounds | Up to 400 estimators, early stopping after 30 rounds without validation improvement |
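The five numeric heuristics need no library at all. This sketch copies the formulas used in the Quick Start; note that `line_count` counts newline characters, so a single-line document scores 0. The helper name `numeric_features` is hypothetical, and the raw counts must still pass through the bundled `StandardScaler` (which computes `(x - mean_) / scale_` per column) before they reach the model.

```python
def numeric_features(text: str) -> list[int]:
    """Raw heuristic counts in the order
    [char_count, digit_count, uppercase_count, currency_count, line_count]."""
    return [
        len(text),                          # char_count
        sum(c.isdigit() for c in text),     # digit_count
        sum(c.isupper() for c in text),     # uppercase_count
        text.count("$") + text.count("£"),  # currency_count: only $ and £ are counted
        text.count("\n"),                   # line_count: number of newline characters
    ]


print(numeric_features("Total: $5\nA"))
# [11, 1, 2, 1, 1]
```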
## Files

| File | Description |
|---|---|
| `document_classifier_xgb.pkl` | Serialised model bundle (joblib) containing the model, vectorizers, and scaler |
| `predict_document.py` | Ready-to-use inference script |
| `train_model.py` | Full training script |
| `training_curve.png` | Train vs validation log-loss curve |
| `feature_importance.png` | Top-20 feature importances |
## Quick Start
```python
import joblib
import numpy as np
from scipy.sparse import csr_matrix, hstack

# Load the serialised bundle (model + vectorizers + scaler)
bundle = joblib.load("document_classifier_xgb.pkl")
model = bundle["model"]
word_vectorizer = bundle["word_vectorizer"]
char_vectorizer = bundle["char_vectorizer"]
scaler = bundle["scaler"]


def predict(text: str) -> int:
    """Return the predicted class index for a single document's text."""
    # Sparse TF-IDF features: word n-grams and char n-grams
    word_feat = word_vectorizer.transform([text])
    char_feat = char_vectorizer.transform([text])

    # Numeric heuristics; the column order must match the one used in training
    num_feat = scaler.transform(np.array([[
        len(text),                          # char_count
        sum(c.isdigit() for c in text),     # digit_count
        sum(c.isupper() for c in text),     # uppercase_count
        text.count("$") + text.count("£"),  # currency_count
        text.count("\n"),                   # line_count
    ]], dtype=float))

    # Concatenate word | char | numeric into one sparse CSR row
    features = hstack([word_feat, char_feat, csr_matrix(num_feat)], format="csr")
    return int(model.predict(features)[0])


label = predict("Invoice No. 12345 Total: $499.99 Date: 01/01/2024")
print("Predicted label:", label)
```
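Because the model is trained with the `multi:softprob` objective, `XGBClassifier.predict_proba` returns one probability per class, which can be more informative than a single label for noisy scans. The helper below, `top_k_labels`, is a hypothetical addition, not part of the bundle; mapping class indices back to category names depends on how labels were encoded in `train_model.py`, which this card does not document.

```python
def top_k_labels(probabilities, k: int = 3) -> list[tuple[int, float]]:
    """Return the k (class_index, probability) pairs with the highest probability.

    Works on a plain list or a 1-D NumPy array such as one row of predict_proba.
    """
    ranked = sorted(enumerate(probabilities), key=lambda pair: pair[1], reverse=True)
    return [(int(label), float(p)) for label, p in ranked[:k]]


# With the Quick Start objects, build `features` as inside predict(), then:
#   proba = model.predict_proba(features)[0]
#   print(top_k_labels(proba, k=3))
print(top_k_labels([0.1, 0.6, 0.3], k=2))
# [(1, 0.6), (2, 0.3)]
```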
## Training Details
- Dataset: RVL-CDIP Small (train / val / test split)
- Objective: `multi:softprob` (multi-class log loss)
- Hardware: CPU
- Framework: XGBoost 2.x, scikit-learn, joblib
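The `multi:softprob` objective turns the per-class raw scores (margins) into probabilities with a softmax, and the multi-class log loss tracked during training is the mean negative log-probability assigned to the true class. The pure-Python sketch below shows both for illustration only; XGBoost computes them internally, and the `eps` clipping to avoid `log(0)` is a common convention assumed here.

```python
import math


def softmax(margins: list[float]) -> list[float]:
    """Convert raw per-class scores into probabilities that sum to 1."""
    peak = max(margins)  # subtract the max for numerical stability
    exps = [math.exp(m - peak) for m in margins]
    total = sum(exps)
    return [e / total for e in exps]


def multiclass_log_loss(prob_rows, true_labels, eps: float = 1e-15) -> float:
    """Mean negative log-probability of the true class over all samples."""
    losses = []
    for probs, label in zip(prob_rows, true_labels):
        p = min(max(probs[label], eps), 1 - eps)  # clip to avoid log(0)
        losses.append(-math.log(p))
    return sum(losses) / len(losses)


probs = softmax([2.0, 0.0, 0.0])
print(probs)
print(multiclass_log_loss([probs], [0]))
```

A confident correct prediction drives the loss toward 0, while a confident wrong one is penalised heavily, which is why the validation log-loss curve is a natural signal for early stopping.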
## License
MIT