reteropred

Model Details

  • Model Type: BART (Bidirectional and Auto-Regressive Transformers)
  • Task: Molecular Retrosynthesis (Product SMILES → Reactant SMILES)
  • Language: SMILES (Simplified Molecular Input Line Entry System)
  • Architecture: BartForConditionalGeneration
  • Parameters: 44.8M
  • License: mpl-2.0

Model Description

This model is designed for computer-aided retrosynthesis analysis. It predicts reactant sets given a product molecule represented as a SMILES string. The model uses a sequence-to-sequence transformer architecture (BART) customized for chemical language processing.

Key features include:

  • Custom Tokenization: Uses a Byte-Pair Encoding (BPE) tokenizer trained specifically on SMILES syntax with a vocabulary size of 1,000 tokens.
  • Data Augmentation: Implements SMILES randomization during training to improve robustness against different molecular representations.
  • Canonicalization: Utilizes RDKit for canonical SMILES conversion during preprocessing and evaluation to ensure chemical validity.
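The canonicalization and augmentation steps above can be sketched with RDKit. This is a minimal illustration, not the model's actual training pipeline; the function names are placeholders, and only the `Chem.MolToSmiles` options mentioned in this card (`canonical`, `doRandom`) are used:

```python
from rdkit import Chem

def canonicalize(smiles: str):
    """Convert a SMILES string to its RDKit canonical form, or None if invalid."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None  # chemically invalid input
    return Chem.MolToSmiles(mol, canonical=True)

def randomize(smiles: str):
    """Produce a random, non-canonical SMILES for the same molecule,
    as used for training-time data augmentation."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, doRandom=True, canonical=False)

# Example: aspirin. The randomized string differs, but both map back
# to the same canonical SMILES.
smiles = "CC(=O)Oc1ccccc1C(=O)O"
print(canonicalize(smiles))
print(randomize(smiles))
```

Randomized SMILES denote the same molecule, so canonicalizing a randomized string recovers the original canonical form; this is what makes the augmentation label-preserving.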

Intended Uses

Direct Use:

  • Predicting reactants for a given product molecule in organic chemistry research.
  • Assisting chemists in planning synthesis routes.
  • Educational purposes for computational chemistry.

Out-of-Scope Use:

  • Not for Clinical Use: This model is not validated for pharmaceutical manufacturing or clinical applications.
  • Not for Hazardous Materials: Should not be used to plan synthesis of regulated or dangerous substances without expert oversight.
  • No Guarantee of Validity: The model outputs SMILES strings that should be validated chemically (e.g., via RDKit) before use.

Training Data

The model was trained on a combination of public reaction datasets and template rules:

  1. USPTO Dataset: Curated patent reactions containing reactant-product pairs.
  2. Preprocessing:
    • Canonicalization: All SMILES were canonicalized using RDKit (Chem.MolToSmiles).
    • Cleaning: Atom maps were stripped, and explicit hydrogens were removed.
    • Filtering: Identity mappings (where product == reactant) were removed.
    • Augmentation: Training inputs were randomized using Chem.MolToSmiles(mol, doRandom=True) to prevent overfitting to specific SMILES representations.
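The cleaning and filtering steps can be sketched as follows. This assumes atom maps appear in the usual bracket notation (e.g. `[CH3:1]`) and strips them with a regex; the actual pipeline may instead use RDKit's `SetAtomMapNum`, and the explicit-hydrogen removal step (e.g. RDKit's `RemoveHs`) is omitted here:

```python
import re

def strip_atom_maps(smiles: str) -> str:
    """Remove USPTO-style atom-map numbers, e.g. [CH3:1] -> [CH3]."""
    return re.sub(r":\d+\]", "]", smiles)

def build_pairs(raw_pairs):
    """Clean (product, reactants) pairs and drop identity mappings,
    which carry no retrosynthetic signal."""
    cleaned = []
    for product, reactants in raw_pairs:
        product = strip_atom_maps(product)
        reactants = strip_atom_maps(reactants)
        if product == reactants:
            continue  # identity mapping: product equals reactants
        cleaned.append((product, reactants))
    return cleaned

pairs = [
    ("[CH3:1][OH:2]", "[CH3:1][OH:2]"),                 # identity, dropped
    ("[CH3:1][O:2][CH3:3]", "[CH3:1][OH:2].[CH3:3]I"),  # kept
]
print(build_pairs(pairs))
# -> [('[CH3][O][CH3]', '[CH3][OH].[CH3]I')]
```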

Training Procedure

Hyperparameters

Hyperparameter           Value
-----------------------  ------------------------
Base Architecture        BART (custom config)
Hidden Size (d_model)    512
Encoder/Decoder Layers   6
Attention Heads          8
Vocabulary Size          1,000
Max Sequence Length      128
Batch Size               192
Epochs                   10
Learning Rate            1e-4
Optimizer                AdamW (via Transformers)
Precision                FP16
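As a sanity check, the 44.8M parameter figure is consistent with a standard BART of these dimensions, assuming a feed-forward size of 2048, learned positional embeddings with BART's +2 offset, and tied input/output embeddings (all three are assumptions, since the card does not state them):

```python
# Back-of-the-envelope parameter count for a BART with the table's dimensions.
# Assumed (not stated in the card): FFN dim = 2048, BART-style positional
# embeddings of size max_pos + 2, tied token embedding / LM head.
d_model, vocab, max_pos, ffn, n_layers = 512, 1000, 128, 2048, 6

embeddings = vocab * d_model                 # shared token embedding (tied lm_head)
positions = 2 * (max_pos + 2) * d_model      # encoder + decoder positional tables
emb_layernorm = 2 * 2 * d_model              # layernorm_embedding (weight+bias) x2

attn = 4 * (d_model * d_model + d_model)     # q, k, v, out projections with bias
ffn_params = (d_model * ffn + ffn) + (ffn * d_model + d_model)
ln = 2 * d_model                             # one layer norm (weight + bias)

encoder_layer = attn + ffn_params + 2 * ln       # self-attn + FFN
decoder_layer = 2 * attn + ffn_params + 3 * ln   # self-attn + cross-attn + FFN

total = (embeddings + positions + emb_layernorm
         + n_layers * encoder_layer + n_layers * decoder_layer)
print(f"{total / 1e6:.1f}M")  # -> 44.8M
```

Under these assumptions the count comes out to 44,785,664 parameters, matching the reported 44.8M.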

Framework & Libraries

  • Deep Learning: PyTorch, Hugging Face transformers
  • Cheminformatics: RDKit
  • Data Processing: Pandas, Scikit-learn, Tokenizers

Evaluation

The model was evaluated on a held-out validation set and external test sets (Enamine Real, ChEMBL, ZINC). Accuracy is measured using Exact Match (EM) based on canonical SMILES comparison.

Metrics

  • validation_accuracy: Top-1, Top-3, and Top-5 exact-match accuracy on the validation set.

Performance by Reaction Class

The heatmap below illustrates how the model performs across different USPTO reaction classes, distinguishing between correct predictions, invalid SMILES generation, and reactant mismatches.

[Figure: class_performance_matrix]

Performance breakdown by USPTO reaction class.

Token-Level Analysis

To understand the model's "chemical vocabulary" performance, the following confusion matrix shows the most frequent SMILES tokens and how accurately the model predicts them.

[Figure: token_confusion_matrix]

Confusion matrix for the top 15 most frequent chemical tokens.

Prediction Examples

Below are representative examples of the model's retrosynthetic predictions compared against the ground truth.

[Figure: prediction_examples_grid]

Visual grid comparing the Target Product, the True Reactants, and the Model's Predicted Reactants.

Note: Evaluation involves generating 5 beam search sequences and checking if the canonicalized ground truth matches any of the top-k predictions.
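That top-k check can be expressed as a small helper. The canonicalization function is passed in (in practice RDKit, e.g. via the canonical `Chem.MolToSmiles` form), so the matching logic is shown independent of any cheminformatics backend; the function name is a placeholder, not part of this repository:

```python
def top_k_exact_match(predictions, ground_truth, canonicalize):
    """Return True if the canonicalized ground truth matches any of the
    top-k predicted reactant sets; invalid predictions count as misses."""
    target = canonicalize(ground_truth)
    if target is None:
        return False
    for pred in predictions:
        try:
            candidate = canonicalize(pred)
        except Exception:
            continue  # unparseable SMILES: a miss, keep scanning the beams
        if candidate == target:
            return True
    return False

# With RDKit one would pass a SMILES canonicalizer; here a trivial
# whitespace normalizer stands in to show the control flow only.
normalize = lambda s: s.strip() if s else None
print(top_k_exact_match(["CCO ", "CCN"], "CCO", normalize))  # -> True
```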

How to Use

Load with Transformers

from transformers import BartForConditionalGeneration, PreTrainedTokenizerFast
import torch
from rdkit import Chem

# Load model and tokenizer
model = BartForConditionalGeneration.from_pretrained("surya/bart-retrosynth") # Replace with your HF repo
tokenizer = PreTrainedTokenizerFast.from_pretrained("surya/bart-retrosynth")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

def predict_reactants(product_smiles: str, num_beams: int = 5):
    # Canonicalize input
    mol = Chem.MolFromSmiles(product_smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {product_smiles}")
    canon_smiles = Chem.MolToSmiles(mol, canonical=True)

    # Tokenize
    inputs = tokenizer(canon_smiles, return_tensors="pt", max_length=128, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Generate candidate reactant sets with beam search
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=128,
            num_beams=num_beams,
            num_return_sequences=num_beams,  # return all beams for top-k matching
            early_stopping=True,
        )

    # Decode each beam to a SMILES string
    predictions = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return predictions

# Example
product = "CCO" # Ethanol
reactants = predict_reactants(product)
print(f"Predicted Reactants: {reactants}")

Limitations & Bias

  • Stereochemistry: Parts of preprocessing canonicalized SMILES with isomericSmiles=False, so stereochemical information may have been discarded from the training data; stereochemical accuracy of predictions may therefore be limited.
  • Validity: Not all generated SMILES strings are guaranteed to be chemically valid. Post-processing validation is required.
  • Length Constraint: Molecules requiring SMILES representations longer than 128 tokens will be truncated.
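Per the validity caveat above, generated candidates are typically filtered through RDKit before downstream use. A minimal sketch (the function name is illustrative, not part of this repository):

```python
from rdkit import Chem

def keep_valid(candidates):
    """Filter generated SMILES, keeping only parseable (chemically valid)
    strings, returned in canonical form."""
    valid = []
    for smi in candidates:
        mol = Chem.MolFromSmiles(smi)
        if mol is not None:
            valid.append(Chem.MolToSmiles(mol))
    return valid

# "C1CC" has an unclosed ring and is dropped; the others are kept.
print(keep_valid(["CCO", "C1CC", "c1ccccc1"]))
```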