import os
import sys
import json
import argparse
import random
from typing import Optional

import torch
import pandas as pd
import selfies as sf
from tqdm import tqdm
from rdkit import Chem, RDLogger
from rdkit.Chem import AllChem

# Silence RDKit's per-molecule parse warnings; invalid molecules are handled explicitly below.
RDLogger.DisableLog('rdApp.*')

# Make the local modules importable when the script is run from another directory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from FastChemTokenizerHF import FastChemTokenizerSelfies
from ChemQ3MTP import ChemQ3MTPForCausalLM


def selfies_to_smiles(selfies_str: str) -> Optional[str]:
    """Convert a SELFIES string to SMILES, handling tokenizer artifacts."""
    try:
        # The tokenizer emits space-separated tokens; join them before decoding.
        clean_selfies = selfies_str.replace(" ", "")
        return sf.decoder(clean_selfies)
    except Exception:
        return None
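
# Example (illustrative): the tokenizer emits space-separated SELFIES tokens,
# which sf.decoder may not parse until the spaces are removed, e.g.
#   selfies_to_smiles("[C] [C] [O]")  ->  "CCO"  (ethanol)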


def is_valid_smiles(smiles: str) -> bool:
    """
    Check whether a SMILES string represents a valid molecule.

    A molecule must parse, contain at least 3 heavy (non-hydrogen) atoms,
    and be a single connected fragment (no '.' separators).
    """
    if not isinstance(smiles, str) or len(smiles.strip()) == 0:
        return False

    smiles = smiles.strip()

    # Reject multi-fragment (disconnected) molecules up front.
    if '.' in smiles:
        return False

    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return False

        # Require a minimum of 3 heavy atoms to exclude trivial fragments.
        if mol.GetNumHeavyAtoms() < 3:
            return False

        return True
    except Exception:
        return False
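
# Quick sanity checks (illustrative):
#   is_valid_smiles("CCO")   -> True   (3 heavy atoms, single fragment)
#   is_valid_smiles("CC")    -> False  (fewer than 3 heavy atoms)
#   is_valid_smiles("CC.O")  -> False  (disconnected fragments)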


def passes_durrant_lab_filter(smiles: str) -> bool:
    """
    Apply the Durrant-lab filter to remove improbable substructures.

    Rejects molecules that are disconnected, too small, contain metals or
    elements heavier than calcium, or match any problematic SMARTS pattern.
    Returns True if the molecule passes the filter (is acceptable), False otherwise.
    """
    if not smiles or not isinstance(smiles, str) or len(smiles.strip()) == 0:
        return False

    try:
        mol = Chem.MolFromSmiles(smiles.strip())
        if mol is None:
            return False

        if mol.GetNumHeavyAtoms() < 3:
            return False

        # Reject disconnected molecules (more robust than the '.' string check).
        fragments = Chem.rdmolops.GetMolFrags(mol, asMols=False)
        if len(fragments) > 1:
            return False

        # SMARTS patterns for improbable substructures.
        problematic_patterns = [
            "C=[N-]",
            "[N-]C=[N+]",
            "[nH+]c[n-]",
            "[#7+]~[#7+]",
            "[#7-]~[#7-]",
            "[!#7]~[#7+]~[#7-]~[!#7]",
            "[#5]",  # boron
            "O=[PH](=O)([#8])([#8])",
            "N=c1cc[#7]c[#7]1",
            "[$([NX2H1]),$([NX3H2])]=C[$([OH]),$([O-])]",
        ]

        # Reject metals Na, Mg, K, Ca and any element heavier than Ca (Z > 20).
        # (Assumed intent: the original `> 20 and not in` condition made the
        # exclusion set a no-op, since all four atomic numbers are <= 20.)
        metal_exclusions = {11, 12, 19, 20}
        for atom in mol.GetAtoms():
            atomic_num = atom.GetAtomicNum()
            if atomic_num > 20 or atomic_num in metal_exclusions:
                return False

        for pattern in problematic_patterns:
            try:
                patt_mol = Chem.MolFromSmarts(pattern)
                if patt_mol is not None and mol.GetSubstructMatches(patt_mol):
                    return False
            except Exception:
                # Skip patterns that fail to compile or match.
                continue

        return True

    except Exception:
        return False
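
# Optional micro-optimization (a sketch, not wired in above): the filter
# re-compiles each SMARTS pattern for every molecule it checks. Caching the
# compiled queries makes large-batch filtering cheaper; this helper is a
# possible drop-in replacement for `Chem.MolFromSmarts(pattern)` in the loop.
from functools import lru_cache


@lru_cache(maxsize=None)
def _compiled_smarts(pattern: str):
    """Compile a SMARTS pattern once and reuse it across filter calls."""
    return Chem.MolFromSmarts(pattern)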


def get_sa_label_and_confidence(selfies_str: str) -> tuple[str, float]:
    """Get an SA label (Easy/Hard) and confidence from the model's SA classifier."""
    try:
        from ChemQ3MTP.rl_utils import get_sa_classifier

        classifier = get_sa_classifier()
        if classifier is None:
            return "Unknown", 0.0

        # The classifier behaves like a HF text-classification pipeline,
        # returning [{"label": ..., "score": ...}] per input.
        result = classifier(selfies_str, truncation=True, max_length=128)[0]
        return result["label"], result["score"]
    except Exception:
        return "Unknown", 0.0


def get_morgan_fingerprint_from_smiles(smiles: str, radius: int = 2, n_bits: int = 2048):
    """Return a Morgan fingerprint bit vector for a SMILES string, or None if it fails to parse."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)


def tanimoto_sim(fp1, fp2) -> float:
    """Tanimoto similarity between two fingerprint bit vectors."""
    from rdkit.DataStructs import TanimotoSimilarity
    return TanimotoSimilarity(fp1, fp2)
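
# Usage example (illustrative):
#   fp_a = get_morgan_fingerprint_from_smiles("CCO")  # ethanol
#   fp_b = get_morgan_fingerprint_from_smiles("CCN")  # ethylamine
#   if fp_a is not None and fp_b is not None:
#       sim = tanimoto_sim(fp_a, fp_b)  # float in [0.0, 1.0]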


def evaluate_model(
    model_path: str,
    train_data_path: str = "../data/chunk_5.csv",
    n_samples: int = 1000,
    seed: int = 42,
    max_gen_len: int = 32
):
    torch.manual_seed(seed)
    random.seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🚀 Evaluating model at: {model_path}")
    print(f"   Device: {device} | Samples: {n_samples} | Seed: {seed}\n")

    # Load the tokenizer and model checkpoint.
    tokenizer = FastChemTokenizerSelfies.from_pretrained("../selftok_core")
    model = ChemQ3MTPForCausalLM.from_pretrained(model_path)
    model.to(device)
    model.eval()

    # Load the training set for the novelty computation; SELFIES are
    # normalized to their space-free form to match the cleaned generations.
    print("📂 Loading and normalizing training set for novelty...")
    train_df = pd.read_csv(train_data_path)
    train_selfies_clean = set()
    for s in train_df["SELFIES"].dropna().astype(str):
        train_selfies_clean.add(s.replace(" ", ""))
    print(f"   Training set size: {len(train_selfies_clean)} unique (space-free) SELFIES\n")

    # Generate molecules in batches, preferring the model's MTP-aware sampler
    # and falling back to standard autoregressive generation.
    print("GenerationStrategy: Using MTP-aware generation...")
    all_selfies_raw = []
    batch_size = 32
    num_batches = (n_samples + batch_size - 1) // batch_size

    def generate_standard(input_ids):
        """Standard sampling, used when MTP generation is unavailable or fails."""
        gen_tokens = model.generate(
            input_ids,
            max_length=max_gen_len,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=1.0,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
        return [tokenizer.decode(seq, skip_special_tokens=True) for seq in gen_tokens]

    with torch.no_grad():
        for _ in tqdm(range(num_batches), desc="Generating"):
            current_batch_size = min(batch_size, n_samples - len(all_selfies_raw))
            if current_batch_size <= 0:
                break

            # Every sequence starts from a lone BOS token.
            input_ids = torch.full(
                (current_batch_size, 1),
                tokenizer.bos_token_id,
                dtype=torch.long,
                device=device
            )

            if hasattr(model, 'generate_with_logprobs'):
                try:
                    outputs = model.generate_with_logprobs(
                        input_ids=input_ids,
                        max_new_tokens=25,  # fixed budget in the MTP path; the fallback uses max_gen_len
                        temperature=1.0,
                        top_k=50,
                        top_p=0.95,
                        do_sample=True,
                        return_probs=True,
                        tokenizer=tokenizer
                    )
                    batch_selfies = outputs[0]
                except Exception as e:
                    print(f"⚠️ MTP generation failed: {e}. Falling back.")
                    batch_selfies = generate_standard(input_ids)
            else:
                batch_selfies = generate_standard(input_ids)

            all_selfies_raw.extend(batch_selfies)
            if len(all_selfies_raw) >= n_samples:
                break

    all_selfies_raw = all_selfies_raw[:n_samples]
    print(f"\n✅ Generated {len(all_selfies_raw)} raw SELFIES strings.\n")

    # Convert generated SELFIES to SMILES, keeping only molecules that are
    # valid and pass the Durrant-lab filter.
    valid_records = []
    print("🧪 Processing SELFIES and converting to SMILES...")
    for raw_selfies in tqdm(all_selfies_raw, desc="Converting"):
        clean_selfies = raw_selfies.replace(" ", "")
        smiles = selfies_to_smiles(clean_selfies)

        if smiles is not None and is_valid_smiles(smiles) and passes_durrant_lab_filter(smiles):
            valid_records.append({
                "raw_selfies": raw_selfies,
                "selfies_clean": clean_selfies,
                "smiles": smiles.strip()
            })

    # Print a few examples for quick inspection.
    if valid_records:
        print("\n🔍 DEBUG: Sample generated molecules")
        print("-" * 70)
        for i in range(min(5, len(valid_records))):
            example = valid_records[i]
            print(f"Example {i+1}:")
            print(f"  Raw SELFIES : {example['raw_selfies'][:80]}{'...' if len(example['raw_selfies']) > 80 else ''}")
            print(f"  SMILES      : {example['smiles']}")

            label, confidence = get_sa_label_and_confidence(example['raw_selfies'])
            print(f"  SA Label    : {label} (confidence: {confidence:.3f})")

            # Sanity-check the SA classifier once on two reference inputs.
            if i == 0:
                simple_label, simple_conf = get_sa_label_and_confidence('[C]')
                benzene_label, benzene_conf = get_sa_label_and_confidence('[c] [c] [c] [c] [c] [c] [Ring1] [=Branch1]')
                print(f"  🧪 SA Test - Simple molecule: {simple_label} ({simple_conf:.3f})")
                print(f"  🧪 SA Test - Benzene: {benzene_label} ({benzene_conf:.3f})")

            mol = Chem.MolFromSmiles(example['smiles'])
            if mol:
                print(f"  Atoms       : {mol.GetNumAtoms()}")
                print(f"  Bonds       : {mol.GetNumBonds()}")
            print()
        print("-" * 70)

        # Label synthesizability for up to the first 100 valid molecules.
        sa_labels = []
        for r in valid_records[:100]:
            label, _ = get_sa_label_and_confidence(r["raw_selfies"])
            sa_labels.append(label)
        n_labeled = len(sa_labels)

        easy_count = sa_labels.count("Easy")
        hard_count = sa_labels.count("Hard")
        unknown_count = sa_labels.count("Unknown")

        print(f"🔍 SA Label Analysis (first {n_labeled} molecules):")
        print(f"   Easy to synthesize: {easy_count}/{n_labeled} ({easy_count / max(1, n_labeled) * 100:.1f}%)")
        print(f"   Hard to synthesize: {hard_count}/{n_labeled} ({hard_count / max(1, n_labeled) * 100:.1f}%)")
        if unknown_count > 0:
            print(f"   Unknown/Failed: {unknown_count}/{n_labeled} ({unknown_count / max(1, n_labeled) * 100:.1f}%)")
    else:
        print("\n⚠️ WARNING: No valid molecules generated in sample!")

    # Core metrics: validity, uniqueness, and novelty against the training set.
    validity = len(valid_records) / n_samples

    # Deduplicate on the space-free SELFIES string.
    unique_valid = list({r["selfies_clean"]: r for r in valid_records}.values())
    uniqueness = len(unique_valid) / len(valid_records) if valid_records else 0.0

    novel_count = sum(1 for r in unique_valid if r["selfies_clean"] not in train_selfies_clean)
    novelty = novel_count / len(unique_valid) if unique_valid else 0.0

    # SA labels over all unique valid molecules (the first-100 pass above is
    # for quick inspection only; these counts feed the summary and JSON).
    sa_labels_all = []
    for r in unique_valid:
        label, _ = get_sa_label_and_confidence(r["raw_selfies"])
        sa_labels_all.append(label)

    easy_total = sa_labels_all.count("Easy")
    hard_total = sa_labels_all.count("Hard")
    unknown_total = sa_labels_all.count("Unknown")
    total_labeled = len(sa_labels_all)

    # Internal diversity: 1 - mean pairwise Tanimoto similarity over Morgan
    # fingerprints. O(n^2) in the number of unique molecules.
    if len(unique_valid) >= 2:
        fps = []
        for r in unique_valid:
            fp = get_morgan_fingerprint_from_smiles(r["smiles"])
            if fp is not None:
                fps.append(fp)
        if len(fps) >= 2:
            total_sim, count = 0.0, 0
            for i in range(len(fps)):
                for j in range(i + 1, len(fps)):
                    total_sim += tanimoto_sim(fps[i], fps[j])
                    count += 1
            internal_diversity = 1.0 - (total_sim / count)
        else:
            internal_diversity = 0.0
    else:
        internal_diversity = 0.0

    print("\n" + "=" * 55)
    print("📊 MOLECULAR GENERATION EVALUATION SUMMARY")
    print("=" * 55)
    print(f"Model Path        : {model_path}")
    print(f"Generation Mode   : {'MTP-aware' if hasattr(model, 'generate_with_logprobs') else 'Standard'}")
    print(f"Samples Generated : {n_samples}")
    print("-" * 55)
    print(f"Validity          : {validity:.4f} ({len(valid_records)}/{n_samples})")
    print(f"Uniqueness        : {uniqueness:.4f} (unique valid)")
    print(f"Novelty (vs train): {novelty:.4f} (space-free SELFIES)")
    print(f"Synthesis Labels  : Easy: {easy_total}/{total_labeled} ({easy_total / max(1, total_labeled) * 100:.1f}%) | Hard: {hard_total}/{total_labeled} ({hard_total / max(1, total_labeled) * 100:.1f}%)")
    if unknown_total > 0:
        print(f"                    Unknown: {unknown_total}/{total_labeled} ({unknown_total / max(1, total_labeled) * 100:.1f}%)")
    print(f"Internal Diversity: {internal_diversity:.4f} (1 - avg Tanimoto)")
    print("=" * 55)

    # Collect results and save them alongside the model checkpoint.
    results = {
        "model_path": model_path,
        "generation_mode": "MTP-aware" if hasattr(model, 'generate_with_logprobs') else "standard",
        "n_samples": n_samples,
        "validity": validity,
        "uniqueness": uniqueness,
        "novelty": novelty,
        "sa_easy_count": easy_total,
        "sa_hard_count": hard_total,
        "sa_easy_percentage": easy_total / max(1, total_labeled) * 100,
        "sa_hard_percentage": hard_total / max(1, total_labeled) * 100,
        "internal_diversity": internal_diversity,
        "valid_molecules_count": len(valid_records)
    }

    if unknown_total > 0:
        results["sa_unknown_count"] = unknown_total
        results["sa_unknown_percentage"] = unknown_total / max(1, total_labeled) * 100

    output_json = os.path.join(model_path, "evaluation_summary.json")
    with open(output_json, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\n💾 Results saved to: {output_json}")

    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate a molecular generative model with MTP-aware generation")
    parser.add_argument("--model_path", type=str, required=True, help="Path to model checkpoint")
    parser.add_argument("--n_samples", type=int, default=1000, help="Number of molecules to generate")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--train_data", type=str, default="../data/chunk_5.csv", help="Training data CSV")
    parser.add_argument("--max_gen_len", type=int, default=32, help="Max total generated sequence length")

    args = parser.parse_args()
    evaluate_model(
        model_path=args.model_path,
        train_data_path=args.train_data,
        n_samples=args.n_samples,
        seed=args.seed,
        max_gen_len=args.max_gen_len
    )
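
# Example invocation (script name and paths are illustrative):
#   python evaluate_mtp.py --model_path ./checkpoints/chemq3-mtp --n_samples 500 --seed 0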