from rdkit import Chem
from typing import Any, Optional
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.types import STEP_OUTPUT
# from torch_geometric.data import Data
from torch_scatter import scatter_mean
import numpy as np
# import torch
import os
from tqdm import tqdm
# import pickle as pkl
import json
import matplotlib
# import wandb
# import copy
# import glob
import shutil

from core.evaluation.metrics import CondMolGenMetric
# from core.evaluation.utils import convert_atomcloud_to_mol_smiles, save_mol_list
# from core.evaluation.visualization import visualize, visualize_chain
# from core.utils import transforms as trans
# from core.evaluation.utils import timing

# this file contains the model which we used to visualize the

# The backend is selected before pyplot is imported.
matplotlib.use("Agg")

import matplotlib.pyplot as plt


# TODO: refactor and move center_pos (and that in train_bfn.py) into utils
def center_pos(protein_pos, ligand_pos, batch_protein, batch_ligand, mode='protein'):
    """Shift protein and ligand coordinates by a per-sample offset.

    Args:
        protein_pos: Protein coordinates; indexed along dim 0 by
            ``batch_protein`` (presumably a (N, 3) tensor -- confirm with caller).
        ligand_pos: Ligand coordinates; indexed along dim 0 by ``batch_ligand``.
        batch_protein: Group index of every protein row, used both for the
            ``scatter_mean`` reduction and to broadcast the offset back.
        batch_ligand: Group index of every ligand row, used to broadcast the
            offset onto the ligand coordinates.
        mode: ``'none'`` leaves both inputs untouched and reports an offset of
            ``0.``; ``'protein'`` subtracts the per-group mean of
            ``protein_pos`` from both protein and ligand coordinates.

    Returns:
        Tuple ``(protein_pos, ligand_pos, offset)``.

    Raises:
        NotImplementedError: If ``mode`` is neither ``'none'`` nor ``'protein'``.
    """
    if mode == 'none':
        offset = 0.
    elif mode == 'protein':
        # One mean position per group index, reduced along dim 0.
        offset = scatter_mean(protein_pos, batch_protein, dim=0)
        protein_pos = protein_pos - offset[batch_protein]
        ligand_pos = ligand_pos - offset[batch_ligand]
    else:
        raise NotImplementedError(f"unsupported centering mode: {mode!r}")
    return protein_pos, ligand_pos, offset


# Directory that receives one SDF file per accepted molecule.
OUT_DIR = './output'
LAST_PROTEIN_FN = None


class DockingTestCallback(Callback):
    """Lightning callback that gathers test outputs and dumps valid molecules.

    Test-step outputs are accumulated in ``self.outputs``. At the end of the
    test epoch every collected item is sanitized with RDKit; items that fail
    sanitization or whose SMILES contains ``'.'`` are skipped, the rest are
    written to ``OUT_DIR/<idx>.sdf``.
    """

    def __init__(self, dataset, atom_enc_mode, atom_decoder, atom_type_one_hot,
                 single_bond, docking_config) -> None:
        """Store the configuration later forwarded to ``CondMolGenMetric``."""
        super().__init__()
        self.dataset = dataset
        self.atom_enc_mode = atom_enc_mode
        self.atom_decoder = atom_decoder
        self.single_bond = single_bond
        self.type_one_hot = atom_type_one_hot
        self.docking_config = docking_config
        self.outputs = []

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Build the ``CondMolGenMetric`` instance from the stored settings."""
        super().setup(trainer, pl_module, stage)
        self.metric = CondMolGenMetric(
            atom_decoder=self.atom_decoder,
            atom_enc_mode=self.atom_enc_mode,
            type_one_hot=self.type_one_hot,
            single_bond=self.single_bond,
            docking_config=self.docking_config,
        )

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Append the items produced by this test step to ``self.outputs``."""
        super().on_test_batch_end(
            trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
        )
        self.outputs.extend(outputs)

    def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Discard outputs left over from any previous test run."""
        super().on_test_start(trainer, pl_module)
        self.outputs = []

    def on_test_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        """Reset the output directories and write every valid molecule as SDF.

        WARNING: both ``pl_module.cfg.accounting.test_outputs_dir`` and
        ``OUT_DIR`` are deleted recursively before being recreated.
        """
        super().on_test_epoch_end(trainer, pl_module)

        path = pl_module.cfg.accounting.test_outputs_dir
        # Start from empty directories so stale results never mix with new ones.
        for directory in (path, OUT_DIR):
            if os.path.exists(directory):
                shutil.rmtree(directory)
            os.makedirs(directory, exist_ok=True)

        for idx, graph in enumerate(tqdm(self.outputs, total=len(self.outputs), desc="Chem eval")):
            try:
                mol = graph.mol
                ligand_filename = graph.ligand_filename
                mol.SetProp('_Name', ligand_filename)
                Chem.SanitizeMol(mol)
                smiles = Chem.MolToSmiles(mol)
                validity = smiles is not None
                # A '.' in SMILES separates disconnected fragments.
                complete = '.' not in smiles
            except Exception:
                # Deliberately best-effort: a molecule that cannot be processed
                # is skipped. `Exception` (rather than a bare `except`) keeps
                # KeyboardInterrupt / SystemExit able to stop the run.
                print('sanitize failed')
                continue

            if not validity or not complete:
                print('validity', validity, 'complete', complete)
                continue

            # Disabled filters, kept for reference:
            # mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True)
            # if len(mol_frags) > 1:
            #     print('molecule incomplete')
            #     continue
            # _info = mol.GetRingInfo()
            # _sizes = [len(r) for r in _info.AtomRings()]
            # if len(_sizes) and max(_sizes) > 9:
            #     print('contain a large ring')
            #     continue

            # protein_fn = glob.glob(os.path.dirname(ligand_filename) + '/*.pdb')
            # protein_fn = [f for f in protein_fn if 'tmp' not in f]
            ligand_dir = os.path.dirname(ligand_filename)
            ligand_fn = os.path.basename(ligand_filename)
            # The protein file name is built from the first 10 characters of
            # the ligand file name plus '.pdb', in the same directory.
            protein_fn = os.path.join(ligand_dir, ligand_fn[:10] + '.pdb')

            # print(json.dumps(chem_results, indent=4, cls=NpEncoder))
            out_fn = os.path.join(OUT_DIR, f'{idx}.sdf')
            with Chem.SDWriter(out_fn) as w:
                w.write(mol)