import torch
from transformers import AutoTokenizer
import sys
import os

# Add current dir to path to find DLM_emb_model
sys.path.append(os.getcwd())

try:
    from DLM_emb_model import MolEmbDLM
except ImportError:
    print("Could not import MolEmbDLM. Make sure you are running from the ApexOracle directory.")
    sys.exit(1)

# Use the local model path where we applied the fix
model_path = "/data2/tianang/projects/mdlm/huggingface/huggingface_model"

print(f"Loading model from {model_path}...")
try:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = MolEmbDLM.from_pretrained(model_path)
except Exception as e:
    print(f"Failed to load model: {e}")
    # Fall back to the local directory if the primary path fails.
    # Load the tokenizer here too, otherwise it would be undefined below.
    try:
        tokenizer = AutoTokenizer.from_pretrained(".")
        model = MolEmbDLM.from_pretrained(".")
    except Exception as e2:
        print(f"Failed to load from local directory: {e2}")
        sys.exit(1)

model.eval()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Two different SELFIES strings
selfies_list = [
    "[C][C][O]",      # Ethanol
    "[C][C][=O][O]",  # Acetic acid
]

# Preprocessing from example.py: insert a space between tokens via seq.replace('][', '] [')
processed_selfies = [s.replace('][', '] [') for s in selfies_list]
print(f"Processed SELFIES: {processed_selfies}")

# Tokenize with padding=True to create a batch (essential for exercising the bug fix);
# example.py used padding=False because it processed a single sequence.
print("Tokenizing inputs...")
inputs = tokenizer(
    processed_selfies,
    padding=True,
    truncation=True,
    return_tensors="pt",
)
print(f"Input IDs:\n{inputs['input_ids']}")
print(f"Attention Mask:\n{inputs['attention_mask']}")

inputs = {k: v.to(device) for k, v in inputs.items() if k in ["input_ids", "attention_mask"]}

print("Running model...")
with torch.no_grad():
    embeddings = model(**inputs)

print(f"Embeddings shape: {embeddings.shape}")

# Compare the embeddings of the two molecules element-wise across all positions
emb1 = embeddings[0]
emb2 = embeddings[1]

# Sum of absolute differences; a (near-)zero value would mean the batch
# collapsed to identical outputs, i.e., the original bug is still present.
diff = torch.abs(emb1 - emb2).sum().item()
print(f"Difference between embeddings (sum of abs diff): {diff}")

if diff < 1e-6:
    print("ISSUE: Embeddings are identical.")
else:
    print("SUCCESS: Embeddings are different.")
    print(f"Emb1 mean: {emb1.mean().item()}")
    print(f"Emb2 mean: {emb2.mean().item()}")
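
# Optional extra check: cosine similarity between mask-aware mean-pooled embeddings.
# This is a sketch that assumes `embeddings` has shape [batch, seq_len, hidden] and
# that padded positions should be excluded via the attention mask; skip or adapt it
# if the model already returns pooled per-molecule vectors.
mask = inputs["attention_mask"].unsqueeze(-1).float()      # [batch, seq_len, 1]
pooled = (embeddings * mask).sum(dim=1) / mask.sum(dim=1)  # mean over real tokens only
cos = torch.nn.functional.cosine_similarity(pooled[0], pooled[1], dim=0).item()
print(f"Cosine similarity of mean-pooled embeddings: {cos:.4f}")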