import torch
from transformers import AutoTokenizer
import sys
import os
# Add current dir to path to find DLM_emb_model
sys.path.append(os.getcwd())
try:
    from DLM_emb_model import MolEmbDLM
except ImportError:
    print("Could not import MolEmbDLM. Make sure you are running from the ApexOracle directory.")
    exit(1)
# Use local model path where we applied the fix
model_path = "/data2/tianang/projects/mdlm/huggingface/huggingface_model"
print(f"Loading model from {model_path}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = MolEmbDLM.from_pretrained(model_path)
except Exception as e:
    print(f"Failed to load model: {e}")
    # Fall back to the current directory if the primary path fails
    try:
        tokenizer = AutoTokenizer.from_pretrained(".")
        model = MolEmbDLM.from_pretrained(".")
    except Exception as e2:
        print(f"Failed to load from local directory: {e2}")
        exit(1)
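# Put the model in inference mode and move it to the GPU if one is available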
model.eval()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Two different SELFIES
selfies_list = [
    "[C][C][O]",     # Ethanol
    "[C][C][=O][O]"  # Acetic Acid
]
# Preprocessing from example.py: seq.replace('][', '] [')
processed_selfies = [s.replace('][', '] [') for s in selfies_list]
print(f"Processed SELFIES: {processed_selfies}")
# Tokenize with padding=True to create a batch (essential for testing the bug fix);
# example.py used padding=False because it tokenized only a single sequence.
print("Tokenizing inputs...")
inputs = tokenizer(
    processed_selfies,
    padding=True,
    truncation=True,
    return_tensors="pt"
)
print(f"Input IDs:\n{inputs['input_ids']}")
print(f"Attention Mask:\n{inputs['attention_mask']}")
inputs = {k: v.to(device) for k, v in inputs.items() if k in ["input_ids", "attention_mask"]}
print("Running model...")
with torch.no_grad():
    embeddings = model(**inputs)
print(f"Embeddings shape: {embeddings.shape}")
# Compare the embeddings of the two molecules
# (embeddings[i] is the model output for the i-th molecule in the batch)
emb1 = embeddings[0]
emb2 = embeddings[1]
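# Optional sketch (an assumption, not part of the original script): if the model returns
# per-token embeddings of shape (batch, seq_len, hidden), a padding-aware comparison
# could mean-pool over real tokens instead of comparing the raw outputs, e.g.:
#   mask = inputs["attention_mask"].unsqueeze(-1).float()
#   pooled = (embeddings * mask).sum(dim=1) / mask.sum(dim=1)
#   emb1, emb2 = pooled[0], pooled[1]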
# Calculate difference
diff = torch.abs(emb1 - emb2).sum().item()
print(f"Difference between embeddings (sum of abs diff): {diff}")
if diff < 1e-6:
    print("ISSUE: Embeddings are identical.")
else:
    print("SUCCESS: Embeddings are different.")
print(f"Emb1 mean: {emb1.mean().item()}")
print(f"Emb2 mean: {emb2.mean().item()}")