📄 Document Classifier: XGBoost + TF-IDF

A lightweight document classification model trained on the RVL-CDIP Small dataset.

It classifies the OCR-extracted text of scanned documents into document categories. The features are TF-IDF vectors over word and character n-grams plus a few handcrafted numeric heuristics, and an XGBoost classifier makes the final prediction.
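To show what the two TF-IDF views count, here is a minimal pure-Python sketch of the tokens each one produces: word unigrams and bigrams, and character 3- to 5-grams. It assumes scikit-learn-style defaults (lowercasing, the `(?u)\b\w\w+\b` word pattern, whitespace runs collapsed for character n-grams, and a plain `char` rather than `char_wb` analyzer). The actual vectorizer settings in `train_model.py` are not documented here and may differ, and the sketch leaves out the IDF weighting entirely.

```python
import re

# Assumption: mirrors scikit-learn TfidfVectorizer defaults (lowercase=True,
# token_pattern=r"(?u)\b\w\w+\b"); the model's real settings may differ.
_TOKEN_RE = re.compile(r"(?u)\b\w\w+\b")


def word_ngrams(text: str, min_n: int = 1, max_n: int = 2) -> list[str]:
    """Word n-grams (all unigrams, then all bigrams by default), no IDF weighting."""
    tokens = _TOKEN_RE.findall(text.lower())  # single-character tokens are dropped
    grams = []
    for n in range(min_n, max_n + 1):
        for i in range(len(tokens) - n + 1):
            grams.append(" ".join(tokens[i:i + n]))
    return grams


def char_ngrams(text: str, min_n: int = 3, max_n: int = 5) -> list[str]:
    """Character n-grams over the lowercased text, runs of whitespace collapsed."""
    text = re.sub(r"\s\s+", " ", text.lower())
    grams = []
    for n in range(min_n, max_n + 1):
        for i in range(len(text) - n + 1):
            grams.append(text[i:i + n])
    return grams


print(word_ngrams("Total Due: 499"))  # ['total', 'due', '499', 'total due', 'due 499']
print(char_ngrams("Total", 3, 4))     # ['tot', 'ota', 'tal', 'tota', 'otal']
```

Character n-grams are robust to OCR noise such as split or misread words, which is a common reason to pair them with word n-grams.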


🏗️ Model Architecture

| Component | Details |
|---|---|
| Classifier | XGBoost (`XGBClassifier`) |
| Text features | TF-IDF on word n-grams (1–2) and character n-grams (3–5) |
| Numeric features | `char_count`, `digit_count`, `uppercase_count`, `currency_count`, `line_count` |
| Scaler | `StandardScaler` (numeric features only) |
| Boosting rounds | Up to 400 estimators, with early stopping after 30 rounds without validation improvement |
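The scaler standardises each numeric column with statistics from the training split: z = (x - mean) / std, where scikit-learn's `StandardScaler` uses the population standard deviation (ddof=0) and leaves zero-variance columns unscaled. The pure-Python sketch below reproduces that arithmetic on made-up numbers for illustration only; the fitted means and scales the model actually uses live inside the `.pkl` bundle.

```python
import math


def fit_standardizer(rows: list[list[float]]) -> tuple[list[float], list[float]]:
    """Per-column mean and population std (ddof=0), as StandardScaler computes them."""
    n = len(rows)
    cols = list(zip(*rows))
    means = [sum(c) / n for c in cols]
    scales = []
    for c, m in zip(cols, means):
        std = math.sqrt(sum((x - m) ** 2 for x in c) / n)
        scales.append(std if std > 0 else 1.0)  # constant columns are left unscaled
    return means, scales


def transform(rows: list[list[float]], means: list[float], scales: list[float]) -> list[list[float]]:
    """Apply z = (x - mean) / scale column by column."""
    return [[(x - m) / s for x, m, s in zip(row, means, scales)] for row in rows]


# Two toy documents: [char_count, digit_count] (made-up values, not training data)
train = [[100.0, 10.0], [300.0, 30.0]]
means, scales = fit_standardizer(train)
print(means, scales)                    # [200.0, 20.0] [100.0, 10.0]
print(transform(train, means, scales))  # [[-1.0, -1.0], [1.0, 1.0]]
```

Scaling matters less for tree models than for linear ones, but it keeps the raw counts (which can run into the thousands for `char_count`) on the same footing as the TF-IDF weights in the stacked matrix.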

📦 Files

| File | Description |
|---|---|
| `document_classifier_xgb.pkl` | Serialised model bundle (joblib) containing the model, both vectorizers and the scaler |
| `predict_document.py` | Ready-to-use inference script |
| `train_model.py` | Full training script |
| `training_curve.png` | Train vs validation log-loss curve |
| `feature_importance.png` | Top-20 feature importances |
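Because the `.pkl` bundle is a dictionary that must supply every component the inference code reads, a small guard can fail fast on an incomplete or mismatched file. `check_bundle` and `REQUIRED_KEYS` below are hypothetical names, not part of `predict_document.py`; the keys are the four used in the Quick Start, and the real bundle may hold additional entries.

```python
from typing import Any, Mapping

# Hypothetical: the keys read by the Quick Start code; the bundle may contain more.
REQUIRED_KEYS = ("model", "word_vectorizer", "char_vectorizer", "scaler")


def check_bundle(bundle: Mapping[str, Any]) -> None:
    """Raise KeyError naming every required component missing from a loaded bundle."""
    missing = [k for k in REQUIRED_KEYS if k not in bundle]
    if missing:
        raise KeyError(f"model bundle is missing: {', '.join(missing)}")


# Usage, e.g. right after joblib.load("document_classifier_xgb.pkl"):
check_bundle({"model": object(), "word_vectorizer": object(),
              "char_vectorizer": object(), "scaler": object()})  # passes silently
```

Note that joblib files are pickles, so loading one can execute arbitrary code; only load bundles from a source you trust.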

🚀 Quick Start

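```python
import joblib
import numpy as np
from scipy.sparse import csr_matrix, hstack

# Load the model bundle (a dict with the model, both vectorizers and the scaler).
bundle = joblib.load("document_classifier_xgb.pkl")
model           = bundle["model"]
word_vectorizer = bundle["word_vectorizer"]
char_vectorizer = bundle["char_vectorizer"]
scaler          = bundle["scaler"]


def predict(text: str) -> int:
    """Return the predicted class index for one document's OCR text."""
    word_feat = word_vectorizer.transform([text])  # sparse word 1-2-gram TF-IDF
    char_feat = char_vectorizer.transform([text])  # sparse char 3-5-gram TF-IDF
    # Handcrafted numeric features, standardised with the training-time scaler.
    # Their column order must match the order used during training.
    num_feat = scaler.transform(np.array([[
        len(text),                          # char_count
        sum(c.isdigit() for c in text),     # digit_count
        sum(c.isupper() for c in text),     # uppercase_count
        text.count("$") + text.count("£"),  # currency_count
        text.count("\n"),                   # line_count
    ]], dtype=float))
    # Stack [word | char | numeric] into one CSR row for XGBoost.
    features = hstack([word_feat, char_feat, csr_matrix(num_feat)], format="csr")
    return int(model.predict(features)[0])


label = predict("Invoice No. 12345  Total: $499.99  Date: 01/01/2024")
print("Predicted label:", label)  # an integer class index
```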
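For batch jobs it helps to keep the five heuristic features in one function so that inference and any re-training compute them identically. `numeric_features` is a hypothetical refactor of the inline list in `predict`; it keeps the same column order and the same currency symbols ($ and £), which are assumed, not confirmed, to match `train_model.py`.

```python
def numeric_features(text: str) -> list[int]:
    """Return [char_count, digit_count, uppercase_count, currency_count, line_count]."""
    return [
        len(text),                          # char_count
        sum(c.isdigit() for c in text),     # digit_count
        sum(c.isupper() for c in text),     # uppercase_count
        text.count("$") + text.count("£"),  # currency_count
        text.count("\n"),                   # line_count
    ]


sample = "Invoice No. 12345  Total: $499.99  Date: 01/01/2024"
print(numeric_features(sample))  # [51, 18, 4, 1, 0]
```

For several documents at once, the rows can be passed to the bundled scaler in one call, e.g. `scaler.transform([numeric_features(t) for t in texts])`, and the vectorizers likewise accept the whole list `texts`.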

📊 Training Details

  • Dataset: RVL-CDIP Small (train / val / test split)
  • Objective: multi:softprob (per-class probabilities; evaluated with multi-class log loss)
  • Hardware: CPU
  • Framework: XGBoost 2.x, scikit-learn, joblib
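With `multi:softprob`, the model outputs one probability per class for each document, and the log loss tracked in `training_curve.png` is the mean negative natural log of the probability assigned to the true class. The sketch below computes that metric on made-up probabilities and takes the per-row argmax, which is how a single class index is usually derived from softprob output; the `eps` clipping is an illustrative choice, not XGBoost's exact implementation.

```python
import math


def multiclass_log_loss(probs: list[list[float]], labels: list[int]) -> float:
    """Mean of -ln(p[true class]) over documents; each row of probs should sum to 1."""
    eps = 1e-15  # illustrative clipping to avoid log(0)
    return -sum(math.log(max(row[y], eps)) for row, y in zip(probs, labels)) / len(labels)


def predicted_labels(probs: list[list[float]]) -> list[int]:
    """Argmax per row: the class with the highest probability (first one on ties)."""
    return [max(range(len(row)), key=row.__getitem__) for row in probs]


# Made-up 3-class probabilities for two documents
probs = [[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]
labels = [0, 2]
print(round(multiclass_log_loss(probs, labels), 4))  # 0.2899
print(predicted_labels(probs))                       # [0, 2]
```

Both documents are classified correctly here, yet the loss is not zero: log loss also rewards confidence, which is why the training curve can keep falling after accuracy has plateaued.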

📝 License

MIT
