# PeptiVerse / load.py
# Author: ynuozhang — commit 62e6dc2 ("add inference")
# Joblib
# peptiverse_infer.py
from __future__ import annotations
import csv, re, json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple, Any, List
import numpy as np
import torch
import torch.nn as nn
import joblib
import xgboost as xgb
from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
# -----------------------------
# Manifest
# -----------------------------
@dataclass(frozen=True)
class BestRow:
    """One manifest row: the best model (and decision threshold) per input type for a property."""

    # Normalised property name, e.g. "hemolysis" (see normalize_property_key).
    property_key: str
    # Best model label for amino-acid (WT) input, or None when not available.
    best_wt: Optional[str]
    # Best model label for SMILES input, or None when not available.
    best_smiles: Optional[str]
    # "Classifier" or "Regression" as written in the manifest.
    task_type: str
    # Classification thresholds for the WT / SMILES models (None if absent).
    thr_wt: Optional[float]
    thr_smiles: Optional[float]
def _clean(s: str) -> str:
return (s or "").strip()
def _none_if_dash(s: str) -> Optional[str]:
s = _clean(s)
if s in {"", "-", "—", "NA", "N/A"}:
return None
return s
def _float_or_none(s: str) -> Optional[float]:
s = _clean(s)
if s in {"", "-", "—", "NA", "N/A"}:
return None
return float(s)
def normalize_property_key(name: str) -> str:
    """
    Map a human-readable manifest property name to its folder/lookup key.

    Steps: lower-case, drop any parenthesised part (with surrounding
    whitespace), turn "-" and " " into "_", then apply the special cases:
      - any name containing "permeability" but neither "pampa" nor "caco"
        -> "permeability_penetrance";
      - "halflife" -> "half_life";
      - "non_fouling" -> "nf".
    Everything else is returned as normalised.

    NOTE(review): the parenthetical is removed *before* the pampa/caco check,
    so a name like "Permeability (PAMPA)" maps to "permeability_penetrance".
    Confirm the manifest never spells those variants in parentheses.
    """
    key = re.sub(r"\s*\(.*?\)\s*", "", name.strip().lower())
    key = key.translate(str.maketrans("- ", "__"))
    if "permeability" in key and not ("pampa" in key or "caco" in key):
        return "permeability_penetrance"
    special_cases = {"halflife": "half_life", "non_fouling": "nf"}
    return special_cases.get(key, key)
def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
    """
    Parse the best-model manifest CSV into ``{property_key: BestRow}``.

    Expected layout (trailing empty cells are tolerated)::

        Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
        Hemolysis, SVM, SGB, Classifier, 0.2801, 0.2223,

    Rules:
      - blank rows are skipped; the first non-blank row is the header;
      - short rows are padded with "" up to the header length;
      - rows with an empty "Properties" cell are ignored;
      - the property name is normalised with ``normalize_property_key``; a later
        row with the same normalised key overwrites an earlier one;
      - a missing or blank "Type" cell defaults to "Classifier".

    Args:
        path: path to the manifest file.

    Returns:
        Mapping from normalised property key to its ``BestRow``.

    Raises:
        OSError: the file cannot be opened.
        ValueError: a threshold cell is neither numeric nor a missing marker.
    """
    out: Dict[str, BestRow] = {}
    header: Optional[List[str]] = None
    # "utf-8-sig" drops a leading byte-order mark (common in spreadsheet exports).
    # Without it the BOM sticks to the first header name ("\ufeffProperties"),
    # no row has a "Properties" value and the manifest silently comes back empty.
    with Path(path).open("r", newline="", encoding="utf-8-sig") as f:
        for raw in csv.reader(f):
            if not raw or all(_clean(x) == "" for x in raw):
                continue
            # Drop trailing empty cells (e.g. from a trailing comma) before use.
            while raw and _clean(raw[-1]) == "":
                raw = raw[:-1]
            if header is None:
                header = [h.strip() for h in raw]
                continue
            if len(raw) < len(header):
                raw = raw + [""] * (len(header) - len(raw))
            rec = dict(zip(header, raw))
            prop_raw = _clean(rec.get("Properties", ""))
            if not prop_raw:
                continue
            prop_key = normalize_property_key(prop_raw)
            out[prop_key] = BestRow(
                property_key=prop_key,
                best_wt=_none_if_dash(rec.get("Best_Model_WT", "")),
                best_smiles=_none_if_dash(rec.get("Best_Model_SMILES", "")),
                # Rows are padded to the header width, so a present-but-blank
                # "Type" cell would otherwise bypass the "Classifier" default.
                task_type=_clean(rec.get("Type", "")) or "Classifier",
                thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
                thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
            )
    return out
# Manifest model label (upper-cased) -> model folder stem under training_classifiers/<prop>/.
MODEL_ALIAS = dict(
    SVM="svm_gpu",
    SVR="svr",
    ENET="enet_gpu",
    CNN="cnn",
    MLP="mlp",
    TRANSFORMER="transformer",
    XGB="xgb",
    XGB_REG="xgb_reg",
    POOLED="pooled",
    UNPOOLED="unpooled",
)


def canon_model(label: Optional[str]) -> Optional[str]:
    """
    Canonicalise a manifest model label.

    Known labels (case-insensitive, surrounding whitespace ignored) map through
    ``MODEL_ALIAS``; unknown labels are returned stripped and lower-cased.
    ``None`` passes through unchanged.
    """
    if label is None:
        return None
    stripped = label.strip()
    try:
        return MODEL_ALIAS[stripped.upper()]
    except KeyError:
        return stripped.lower()
# -----------------------------
# Generic artifact loading
# -----------------------------
def find_best_artifact(model_dir: Path) -> Path:
    """
    Return the best-model artifact in ``model_dir``.

    Priority: ``best_model.json`` (XGBoost), then ``best_model.pt`` (torch),
    then the alphabetically first ``best_model*.joblib``.

    Raises:
        FileNotFoundError: none of the patterns matches.
    """
    search_order = ("best_model.json", "best_model.pt", "best_model*.joblib")
    for matches in (sorted(model_dir.glob(pattern)) for pattern in search_order):
        if matches:
            return matches[0]
    raise FileNotFoundError(f"No best_model artifact found in {model_dir}")
def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
    """
    Load the best-model artifact found in ``model_dir``.

    Returns ``(kind, obj, path)`` where ``kind`` is one of:
      - ``"xgb"``: ``obj`` is an ``xgb.Booster`` loaded from the JSON file;
      - ``"joblib"``: ``obj`` is whatever ``joblib.load`` returns;
      - ``"torch_ckpt"``: ``obj`` is the raw ``torch.load`` result (mapped to
        ``device``); the caller rebuilds the network from it.

    Raises:
        FileNotFoundError: no artifact in ``model_dir`` (from find_best_artifact).
        ValueError: the artifact has an unrecognised suffix.
    """
    art = find_best_artifact(model_dir)
    if art.suffix == ".json":
        booster = xgb.Booster()
        booster.load_model(str(art))
        return "xgb", booster, art
    # SECURITY: joblib.load and torch.load(weights_only=False) unpickle arbitrary
    # objects; only point this at trusted weight directories.
    if art.suffix == ".joblib":
        return "joblib", joblib.load(art), art
    if art.suffix == ".pt":
        ckpt = torch.load(art, map_location=device, weights_only=False)
        return "torch_ckpt", ckpt, art
    raise ValueError(f"Unknown artifact type: {art}")
# -----------------------------
# NN architectures
# -----------------------------
class MaskedMeanPool(nn.Module):
    """Average features over the sequence axis, counting only masked-in positions."""

    def forward(self, X, M):  # X:(B,L,H), M:(B,L)
        weights = M.float().unsqueeze(-1)  # (B,L,1)
        # At least 1 so an all-masked row pools to zeros instead of NaN.
        counts = torch.clamp(weights.sum(dim=1), min=1.0)
        return torch.sum(X * weights, dim=1) / counts
class MLPHead(nn.Module):
    """
    Masked mean-pool over tokens, then a 2-layer MLP down to one value per example.

    forward(X, M) -> tensor of shape (B,): X is (B, L, in_dim), M is the (B, L) mask.
    """

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        # Layer order fixes the checkpoint keys ("net.0.*", "net.3.*").
        layers = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, X, M):
        logits = self.net(self.pool(X, M))
        return logits.squeeze(-1)
class CNNHead(nn.Module):
    """
    1-D convolutional head over token embeddings.

    Stacks ``layers`` blocks of Conv1d -> GELU -> Dropout along the token axis,
    mean-pools the result over masked-in positions, and maps it to one value
    per example.

    Args:
        in_ch: width of the input token features.
        c: number of conv channels in every block.
        k: kernel size. Odd kernels preserve the sequence length. Even kernels
            are also accepted: their extra output positions are trimmed.
        layers: number of conv blocks.
        dropout: dropout probability after each activation.

    forward(X, M): X is (B, L, in_ch); M is a (B, L) mask (nonzero = keep).
    Returns a tensor of shape (B,).
    """

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        blocks = []
        ch = in_ch
        for _ in range(layers):
            # Module order fixes the checkpoint keys (conv.0.*, conv.3.*, ...).
            blocks += [
                nn.Conv1d(ch, c, kernel_size=k, padding=k // 2),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            ch = c
        self.conv = nn.Sequential(*blocks)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        seq_len = X.shape[1]
        Xc = X.transpose(1, 2)  # (B,H,L): Conv1d wants channels before length
        Y = self.conv(Xc).transpose(1, 2)  # (B,L',C)
        # padding=k//2 keeps L for odd k, but an even k yields L+1 per layer.
        # Trim back to L so Y lines up with the (B,L) mask below.
        Y = Y[:, :seq_len, :]
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        pooled = (Y * Mf).sum(dim=1) / denom
        return self.head(pooled).squeeze(-1)
class TransformerHead(nn.Module):
    """
    Linear projection -> Transformer encoder -> masked mean-pool -> scalar.

    forward(X, M): X is (B, L, in_dim); M is a boolean (B, L) mask where True
    marks real tokens. ``~M`` is passed as the key-padding mask and M is used
    for pooling. Returns a tensor of shape (B,).
    """

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        self.enc = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=ff,
                dropout=dropout,
                batch_first=True,
                activation="gelu",
            ),
            num_layers=layers,
        )
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        encoded = self.enc(self.proj(X), src_key_padding_mask=~M)
        weights = M.unsqueeze(-1).float()
        pooled = (encoded * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
if model_name == "mlp":
return int(sd["net.0.weight"].shape[1])
if model_name == "cnn":
return int(sd["conv.0.weight"].shape[1])
if model_name == "transformer":
return int(sd["proj.weight"].shape[1])
raise ValueError(model_name)
def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
    """
    Rebuild an MLP / CNN / Transformer head from a training checkpoint.

    Args:
        model_name: "mlp", "cnn" or "transformer".
        ckpt: dict with "best_params" (hyper-parameters) and "state_dict", and
            optionally "in_dim". When "in_dim" is absent (or None) it is
            inferred from the state_dict.
        device: device the model is moved to.

    Returns:
        The model with weights loaded, on ``device``, in eval mode.

    Raises:
        ValueError: unknown ``model_name``.
        KeyError: a required checkpoint entry or hyper-parameter is missing.
    """
    params = ckpt["best_params"]
    sd = ckpt["state_dict"]
    # Only inspect the state_dict when in_dim was not saved. A plain
    # ckpt.get("in_dim", _infer_in_dim_from_sd(...)) runs the fallback eagerly,
    # so it can raise even when in_dim is present.
    in_dim = ckpt.get("in_dim")
    if in_dim is None:
        in_dim = _infer_in_dim_from_sd(sd, model_name)
    in_dim = int(in_dim)
    dropout = float(params.get("dropout", 0.1))
    if model_name == "mlp":
        model = MLPHead(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
    elif model_name == "cnn":
        model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
                        layers=int(params["layers"]), dropout=dropout)
    elif model_name == "transformer":
        model = TransformerHead(in_dim=in_dim, d_model=int(params["d_model"]), nhead=int(params["nhead"]),
                                layers=int(params["layers"]), ff=int(params["ff"]), dropout=dropout)
    else:
        raise ValueError(f"Unknown NN model_name={model_name}")
    model.load_state_dict(sd)
    model.to(device)
    model.eval()
    return model
# -----------------------------
# Binding affinity models
# -----------------------------
def affinity_to_class(y: float) -> int:
    """
    Bucket an affinity value into a class index.

    0 = High (y >= 9), 1 = Moderate (7 <= y < 9), 2 = Low (y < 7).
    Values that compare false against both cut-offs (e.g. NaN) land in 1.
    """
    if y >= 9.0:
        return 0
    return 2 if y < 7.0 else 1
class CrossAttnPooled(nn.Module):
    """
    Target/binder cross-attention model on pooled (one-vector) embeddings.

    Both inputs are projected to ``hidden`` and treated as length-1 sequences
    in (seq, batch, hidden) layout (``batch_first=False``). Each of the
    ``n_layers`` rounds runs two steps:
      - the target attends to the binder, followed by residual LayerNorm + FFN;
      - the binder attends to the updated target, followed by residual LayerNorm + FFN.
    The two final vectors are concatenated and passed through a shared layer.

    forward(t_vec, b_vec) returns ``(reg, cls_logits)`` with shapes (B,) and (B, 3).
    NOTE(review): class order presumably matches affinity_to_class
    (0=High, 1=Moderate, 2=Low); confirm against training code.

    Attribute names and ModuleDict keys define the state_dict keys that
    saved checkpoints are loaded against; do not rename them.
    """
    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        # Project target (width Ht) and binder (width Hb) to a shared width.
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
            }))
        # Heads on the concatenated [target, binder] representation.
        self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)
    def forward(self, t_vec, b_vec):
        # assumes t_vec is (B, Ht) and b_vec is (B, Hb); unsqueeze(0) adds the length-1 seq axis.
        t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
        b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
        for L in self.layers:
            # Target queries the binder.
            t_attn, _ = L["attn_tb"](t, b, b)
            # NOTE(review): LayerNorm normalises over the last dim only, so these
            # transpose(0,1) pairs appear to be no-ops; kept for parity with training.
            t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
            t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
            # Binder queries the (already updated) target.
            b_attn, _ = L["attn_bt"](b, t, t)
            b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
            b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
        # Drop the length-1 seq axis and join target/binder features.
        z = torch.cat([t[0], b[0]], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)
class CrossAttnUnpooled(nn.Module):
    """
    Target/binder cross-attention model on per-token (unpooled) embeddings.

    Both token sequences are projected to ``hidden`` (``batch_first=True``
    layout). Each of the ``n_layers`` rounds runs two steps:
      - the target tokens attend to the binder tokens, followed by residual
        LayerNorm + FFN;
      - the binder tokens attend to the updated target tokens, followed by
        residual LayerNorm + FFN.
    Padding is excluded via key-padding masks. Each side is then mean-pooled
    over its valid tokens; the pooled vectors are concatenated and passed
    through a shared layer.

    forward(T, Mt, B, Mb): T/B are token features; Mt/Mb are boolean masks,
    True = real token (``~mask`` is used as the key-padding mask).
    Returns ``(reg, cls_logits)`` with shapes (batch,) and (batch, 3).

    Attribute names and ModuleDict keys define the state_dict keys that
    saved checkpoints are loaded against; do not rename them.
    """
    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        # Project target (width Ht) and binder (width Hb) tokens to a shared width.
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "n1t": nn.LayerNorm(hidden),
                "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden),
                "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
            }))
        # Heads on the concatenated [pooled target, pooled binder] representation.
        self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)
    def _masked_mean(self, X, M):
        # Mean over the token axis (dim 1) counting only positions where M is set;
        # the clamp keeps an all-masked row at zeros instead of NaN.
        Mf = M.unsqueeze(-1).float()
        denom = Mf.sum(dim=1).clamp(min=1.0)
        return (X * Mf).sum(dim=1) / denom
    def forward(self, T, Mt, B, Mb):
        # assumes T is (batch, Lt, Ht) and B is (batch, Lb, Hb) — TODO confirm with callers.
        T = self.t_proj(T)
        Bx = self.b_proj(B)
        # key_padding_mask expects True at positions to ignore, hence the inversion.
        kp_t = ~Mt
        kp_b = ~Mb
        for L in self.layers:
            # Target tokens query binder tokens.
            T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
            T = L["n1t"](T + T_attn)
            T = L["n2t"](T + L["fft"](T))
            # Binder tokens query the (already updated) target tokens.
            B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
            Bx = L["n1b"](Bx + B_attn)
            Bx = L["n2b"](Bx + L["ffb"](Bx))
        t_pool = self._masked_mean(T, Mt)
        b_pool = self._masked_mean(Bx, Mb)
        z = torch.cat([t_pool, b_pool], dim=-1)
        h = self.shared(z)
        return self.reg(h).squeeze(-1), self.cls(h)
def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
    """
    Rebuild a binding-affinity cross-attention model from its checkpoint.

    Args:
        best_model_pt: path to the ``.pt`` checkpoint, a dict with
            "best_params" and "state_dict". It is unpickled with
            weights_only=False, so only load trusted files.
        pooled_or_unpooled: "pooled" -> CrossAttnPooled, "unpooled" -> CrossAttnUnpooled.
        device: target device.

    Returns:
        The model with weights loaded, on ``device``, in eval mode.

    Raises:
        KeyError: missing checkpoint entries or hyper-parameters.
        ValueError: ``pooled_or_unpooled`` is neither "pooled" nor "unpooled".
    """
    ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
    hparams = ckpt["best_params"]
    state = ckpt["state_dict"]
    arch_kwargs = {
        # Input widths come from the projection weights: Linear.weight is (out, in).
        "Ht": int(state["t_proj.0.weight"].shape[1]),
        "Hb": int(state["b_proj.0.weight"].shape[1]),
        "hidden": int(hparams["hidden_dim"]),
        "n_heads": int(hparams["n_heads"]),
        "n_layers": int(hparams["n_layers"]),
        "dropout": float(hparams["dropout"]),
    }
    if pooled_or_unpooled == "pooled":
        arch = CrossAttnPooled
    elif pooled_or_unpooled == "unpooled":
        arch = CrossAttnUnpooled
    else:
        raise ValueError(pooled_or_unpooled)
    model = arch(**arch_kwargs)
    model.load_state_dict(state)
    return model.to(device).eval()
# -----------------------------
# Embedding generation
# -----------------------------
def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor:
"""
Pytorch patch
"""
if hasattr(torch, "isin"):
return torch.isin(ids, test_ids)
# Fallback: compare against each special id
# (B,L,1) == (1,1,K) -> (B,L,K)
return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1)
class SMILESEmbedder:
    """
    PeptideCLM RoFormer embeddings for SMILES strings.

    - pooled(): mean over tokens where attention_mask==1 AND token_id not in SPECIAL_IDS
    - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
      plus an all-True mask of length Li (the tokens are already filtered).

    With ``use_cache`` enabled, results are memoised per stripped input string
    in unbounded dicts.
    """

    def __init__(
        self,
        device: torch.device,
        vocab_path: str,
        splits_path: str,
        clm_name: str = "aaronfeller/PeptideCLM-23M-all",
        max_len: int = 512,
        use_cache: bool = True,
    ):
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache
        self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
        # Keep only the RoFormer encoder of the masked-LM checkpoint.
        self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval()
        self.special_ids = self._get_special_ids(self.tokenizer)
        if self.special_ids:
            self.special_ids_t = torch.tensor(self.special_ids, device=device, dtype=torch.long)
        else:
            self.special_ids_t = None
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        """Sorted, de-duplicated ids of the tokenizer's pad/cls/sep/bos/eos/mask tokens that are set."""
        attr_names = (
            "pad_token_id",
            "cls_token_id",
            "sep_token_id",
            "bos_token_id",
            "eos_token_id",
            "mask_token_id",
        )
        found = set()
        for attr in attr_names:
            token_id = getattr(tokenizer, attr, None)
            if token_id is not None:
                found.add(int(token_id))
        return sorted(found)

    def _tokenize(self, smiles_list: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenise a batch (padded, truncated to max_len) and move every tensor to the device."""
        encoded = self.tokenizer(
            smiles_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_len,
        )
        for name in list(encoded.keys()):
            encoded[name] = encoded[name].to(self.device)
        if "attention_mask" not in encoded:
            encoded["attention_mask"] = torch.ones_like(encoded["input_ids"], dtype=torch.long, device=self.device)
        return encoded

    def _hidden_and_valid(self, s: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode one string; return hidden states (1,L,H) and a (1,L) bool mask of non-pad, non-special tokens."""
        encoded = self._tokenize([s])
        input_ids = encoded["input_ids"]  # (1,L)
        output = self.model(input_ids=input_ids, attention_mask=encoded["attention_mask"])
        valid = encoded["attention_mask"].bool()
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(input_ids, self.special_ids_t))
        return output.last_hidden_state, valid

    @torch.no_grad()
    def pooled(self, smiles: str) -> torch.Tensor:
        """Return a (1,H) mean of the valid-token embeddings of ``smiles``."""
        key = smiles.strip()
        if self.use_cache and key in self._cache_pooled:
            return self._cache_pooled[key]
        hidden, valid = self._hidden_and_valid(key)
        weights = valid.unsqueeze(-1).float()
        # Tiny floor avoids 0/0 when every token was filtered out.
        vec = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)
        if self.use_cache:
            self._cache_pooled[key] = vec
        return vec

    @torch.no_grad()
    def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            X: (1, Li, H) float32 on device
            M: (1, Li) bool on device
        where Li excludes padding + special tokens.
        """
        key = smiles.strip()
        if self.use_cache and key in self._cache_unpooled:
            return self._cache_unpooled[key]
        hidden, valid = self._hidden_and_valid(key)
        X = hidden[:, valid[0], :]  # (1,Li,H)
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
        if self.use_cache:
            self._cache_unpooled[key] = (X, M)
        return X, M
class WTEmbedder:
    """
    ESM2 embeddings for amino-acid sequences.

    - pooled(): mean over tokens where attention_mask==1 AND token_id not in {CLS, EOS, PAD,...}
    - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
      plus an all-True mask of length Li (the tokens are already filtered).

    With ``use_cache`` enabled, results are memoised per stripped input
    sequence in unbounded dicts.
    """

    def __init__(
        self,
        device: torch.device,
        esm_name: str = "facebook/esm2_t33_650M_UR50D",
        max_len: int = 1022,
        use_cache: bool = True,
    ):
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache
        self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
        self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval()
        self.special_ids = self._get_special_ids(self.tokenizer)
        if self.special_ids:
            self.special_ids_t = torch.tensor(self.special_ids, device=device, dtype=torch.long)
        else:
            self.special_ids_t = None
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        """Sorted, de-duplicated ids of the tokenizer's pad/cls/sep/bos/eos/mask tokens that are set."""
        attr_names = (
            "pad_token_id",
            "cls_token_id",
            "sep_token_id",
            "bos_token_id",
            "eos_token_id",
            "mask_token_id",
        )
        found = set()
        for attr in attr_names:
            token_id = getattr(tokenizer, attr, None)
            if token_id is not None:
                found.add(int(token_id))
        return sorted(found)

    def _tokenize(self, seq_list: List[str]) -> Dict[str, torch.Tensor]:
        """Tokenise a batch (padded, truncated to max_len) into a plain dict of device tensors."""
        encoded = self.tokenizer(
            seq_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_len,
        )
        on_device: Dict[str, torch.Tensor] = {}
        for name, tensor in encoded.items():
            on_device[name] = tensor.to(self.device)
        if "attention_mask" not in on_device:
            on_device["attention_mask"] = torch.ones_like(on_device["input_ids"], dtype=torch.long, device=self.device)
        return on_device

    def _hidden_and_valid(self, s: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode one sequence; return hidden states (1,L,H) and a (1,L) bool mask of non-pad, non-special tokens."""
        encoded = self._tokenize([s])
        input_ids = encoded["input_ids"]  # (1,L)
        output = self.model(**encoded)
        valid = encoded["attention_mask"].bool()
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(input_ids, self.special_ids_t))
        return output.last_hidden_state, valid

    @torch.no_grad()
    def pooled(self, seq: str) -> torch.Tensor:
        """Return a (1,H) mean of the residue-token embeddings of ``seq``."""
        key = seq.strip()
        if self.use_cache and key in self._cache_pooled:
            return self._cache_pooled[key]
        hidden, valid = self._hidden_and_valid(key)
        weights = valid.unsqueeze(-1).float()
        # Tiny floor avoids 0/0 when every token was filtered out.
        vec = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)
        if self.use_cache:
            self._cache_pooled[key] = vec
        return vec

    @torch.no_grad()
    def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            X: (1, Li, H) float32 on device
            M: (1, Li) bool on device
        where Li excludes padding + special tokens.
        """
        key = seq.strip()
        if self.use_cache and key in self._cache_unpooled:
            return self._cache_unpooled[key]
        hidden, valid = self._hidden_and_valid(key)
        X = hidden[:, valid[0], :]  # (1,Li,H)
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
        if self.use_cache:
            self._cache_unpooled[key] = (X, M)
        return X, M
# -----------------------------
# Predictor
# -----------------------------
class PeptiVersePredictor:
"""
- loads best models from training_classifiers/
- computes embeddings as needed (pooled/unpooled)
- supports: xgb, joblib(ENET/SVM/SVR), NN(mlp/cnn/transformer), binding pooled/unpooled.
"""
def __init__(
self,
manifest_path: str | Path,
classifier_weight_root: str | Path,
esm_name="facebook/esm2_t33_650M_UR50D",
clm_name="aaronfeller/PeptideCLM-23M-all",
smiles_vocab="tokenizer/new_vocab.txt",
smiles_splits="tokenizer/new_splits.txt",
device: Optional[str] = None,
):
self.root = Path(classifier_weight_root)
self.training_root = self.root / "training_classifiers"
self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
self.manifest = read_best_manifest_csv(manifest_path)
self.wt_embedder = WTEmbedder(self.device)
self.smiles_embedder = SMILESEmbedder(self.device, clm_name=clm_name,
vocab_path=str(self.root / smiles_vocab),
splits_path=str(self.root / smiles_splits))
self.models: Dict[Tuple[str, str], Any] = {}
self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {}
self._load_all_best_models()
def _resolve_dir(self, prop_key: str, model_name: str, mode: str) -> Path:
"""
Usual layout: training_classifiers/<prop>/<model>_<mode>/
Fallbacks:
- training_classifiers/<prop>/<model>/
- training_classifiers/<prop>/<model>_wt
"""
base = self.training_root / prop_key
candidates = [
base / f"{model_name}_{mode}",
base / model_name,
]
if mode == "wt":
candidates += [base / f"{model_name}_wt"]
if mode == "smiles":
candidates += [base / f"{model_name}_smiles"]
for d in candidates:
if d.exists():
return d
raise FileNotFoundError(f"Cannot find model directory for {prop_key} {model_name} {mode}. Tried: {candidates}")
def _load_all_best_models(self):
for prop_key, row in self.manifest.items():
for mode, label, thr in [
("wt", row.best_wt, row.thr_wt),
("smiles", row.best_smiles, row.thr_smiles),
]:
m = canon_model(label)
if m is None:
continue
# ---- binding affinity special ----
if prop_key == "binding_affinity":
# label is pooled/unpooled; mode chooses folder wt_wt_* vs wt_smiles_*
pooled_or_unpooled = m # "pooled" or "unpooled"
folder = f"wt_{mode}_{pooled_or_unpooled}" # wt_wt_pooled / wt_smiles_unpooled etc.
model_dir = self.training_root / "binding_affinity" / folder
art = find_best_artifact(model_dir)
if art.suffix != ".pt":
raise RuntimeError(f"Binding model expected best_model.pt, got {art}")
model = load_binding_model(art, pooled_or_unpooled=pooled_or_unpooled, device=self.device)
self.models[(prop_key, mode)] = model
self.meta[(prop_key, mode)] = {
"task_type": "Regression",
"threshold": None,
"artifact": str(art),
"model_name": pooled_or_unpooled,
}
continue
model_dir = self._resolve_dir(prop_key, m, mode)
kind, obj, art = load_artifact(model_dir, self.device)
if kind in {"xgb", "joblib"}:
self.models[(prop_key, mode)] = obj
else:
# rebuild NN architecture
self.models[(prop_key, mode)] = build_torch_model_from_ckpt(m, obj, self.device)
self.meta[(prop_key, mode)] = {
"task_type": row.task_type,
"threshold": thr,
"artifact": str(art),
"model_name": m,
"kind": kind,
}
def _get_features_for_model(self, prop_key: str, mode: str, input_str: str):
"""
Returns either:
- pooled np array shape (1,H) for xgb/joblib
- unpooled torch tensors (X,M) for NN
"""
model = self.models[(prop_key, mode)]
meta = self.meta[(prop_key, mode)]
kind = meta.get("kind", None)
model_name = meta.get("model_name", "")
if prop_key == "binding_affinity":
raise RuntimeError("Use predict_binding_affinity().")
# If torch NN: needs unpooled
if kind == "torch_ckpt":
if mode == "wt":
X, M = self.wt_embedder.unpooled(input_str)
else:
X, M = self.smiles_embedder.unpooled(input_str)
return X, M
# Otherwise pooled vectors for xgb/joblib
if mode == "wt":
v = self.wt_embedder.pooled(input_str) # (1,H)
else:
v = self.smiles_embedder.pooled(input_str) # (1,H)
feats = v.detach().cpu().numpy().astype(np.float32)
feats = np.nan_to_num(feats, nan=0.0)
feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
return feats
def predict_property(self, prop_key: str, mode: str, input_str: str) -> Dict[str, Any]:
"""
mode: "wt" for AA sequence input, "smiles" for SMILES input
Returns dict with score + label if classifier threshold exists.
"""
if (prop_key, mode) not in self.models:
raise KeyError(f"No model loaded for ({prop_key}, {mode}). Check manifest and folders.")
meta = self.meta[(prop_key, mode)]
model = self.models[(prop_key, mode)]
task_type = meta["task_type"].lower()
thr = meta.get("threshold", None)
kind = meta.get("kind", None)
if prop_key == "binding_affinity":
raise RuntimeError("Use predict_binding_affinity().")
# NN path (logits / regression)
if kind == "torch_ckpt":
X, M = self._get_features_for_model(prop_key, mode, input_str)
with torch.no_grad():
y = model(X, M).squeeze().float().cpu().item()
if task_type == "classifier":
prob = float(1.0 / (1.0 + np.exp(-y))) # sigmoid(logit)
out = {"property": prop_key, "mode": mode, "score": prob}
if thr is not None:
out["label"] = int(prob >= float(thr))
out["threshold"] = float(thr)
return out
else:
return {"property": prop_key, "mode": mode, "score": float(y)}
# xgb path
if kind == "xgb":
feats = self._get_features_for_model(prop_key, mode, input_str) # (1,H)
dmat = xgb.DMatrix(feats)
pred = float(model.predict(dmat)[0])
out = {"property": prop_key, "mode": mode, "score": pred}
if task_type == "classifier" and thr is not None:
out["label"] = int(pred >= float(thr))
out["threshold"] = float(thr)
return out
# joblib path (svm/enet/svr)
if kind == "joblib":
feats = self._get_features_for_model(prop_key, mode, input_str) # (1,H)
# classifier vs regressor behavior differs by estimator
if task_type == "classifier":
if hasattr(model, "predict_proba"):
pred = float(model.predict_proba(feats)[:, 1][0])
else:
if hasattr(model, "decision_function"):
logit = float(model.decision_function(feats)[0])
pred = float(1.0 / (1.0 + np.exp(-logit)))
else:
pred = float(model.predict(feats)[0])
out = {"property": prop_key, "mode": mode, "score": pred}
if thr is not None:
out["label"] = int(pred >= float(thr))
out["threshold"] = float(thr)
return out
else:
pred = float(model.predict(feats)[0])
return {"property": prop_key, "mode": mode, "score": pred}
raise RuntimeError(f"Unknown model kind={kind}")
def predict_binding_affinity(self, mode: str, target_seq: str, binder_str: str) -> Dict[str, Any]:
"""
mode: "wt" (binder is AA sequence) -> wt_wt_(pooled|unpooled)
"smiles" (binder is SMILES) -> wt_smiles_(pooled|unpooled)
"""
prop_key = "binding_affinity"
if (prop_key, mode) not in self.models:
raise KeyError(f"No binding model loaded for ({prop_key}, {mode}).")
model = self.models[(prop_key, mode)]
pooled_or_unpooled = self.meta[(prop_key, mode)]["model_name"] # pooled/unpooled
# target is always WT sequence (ESM)
if pooled_or_unpooled == "pooled":
t_vec = self.wt_embedder.pooled(target_seq) # (1,Ht)
if mode == "wt":
b_vec = self.wt_embedder.pooled(binder_str) # (1,Hb)
else:
b_vec = self.smiles_embedder.pooled(binder_str) # (1,Hb)
with torch.no_grad():
reg, logits = model(t_vec, b_vec)
affinity = float(reg.squeeze().cpu().item())
cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
cls_thr = affinity_to_class(affinity)
else:
T, Mt = self.wt_embedder.unpooled(target_seq)
if mode == "wt":
B, Mb = self.wt_embedder.unpooled(binder_str)
else:
B, Mb = self.smiles_embedder.unpooled(binder_str)
with torch.no_grad():
reg, logits = model(T, Mt, B, Mb)
affinity = float(reg.squeeze().cpu().item())
cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
cls_thr = affinity_to_class(affinity)
names = {0: "High (≥9)", 1: "Moderate (7–9)", 2: "Low (<7)"}
return {
"property": "binding_affinity",
"mode": mode,
"affinity": affinity,
"class_by_threshold": names[cls_thr],
"class_by_logits": names[cls_logit],
"binding_model": pooled_or_unpooled,
}
# -----------------------------
# Minimal usage
# -----------------------------
if __name__ == "__main__":
    # Example: build the predictor from a manifest plus a weights root, then
    # print one property prediction and one binding-affinity prediction.
    # NOTE(review): classifier_weight_root is an absolute cluster path; point it
    # at a local copy of the weights before running elsewhere.
    predictor = PeptiVersePredictor(
        manifest_path="best_models.txt",
        classifier_weight_root="/vast/projects/pranam/lab/yz927/projects/Classifier_Weight"
    )
    print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ"))
    # "..." are placeholders: substitute a real target sequence and binder.
    # predict_binding_affinity raises KeyError if the manifest lists no WT binding model.
    print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="..."))
    # Test Embedding #
    # The triple-quoted string below is never executed. It preserves a
    # standalone embedding smoke test (machine-specific paths) for reference.
    """
    device = torch.device("cuda:0")
    wt = WTEmbedder(device)
    sm = SMILESEmbedder(device,
        vocab_path="/home/enol/PeptideGym/Data_split/tokenizer/new_vocab.txt",
        splits_path="/home/enol/PeptideGym/Data_split/tokenizer/new_splits.txt"
    )
    p = wt.pooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,1280)
    X, M = wt.unpooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,Li,1280), (1,Li)
    p2 = sm.pooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,H_smiles)
    X2, M2 = sm.unpooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,Li,H_smiles), (1,Li)
    """