File size: 5,050 Bytes
1f0c7b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
from rdkit import Chem
from typing import Any, Optional
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.types import STEP_OUTPUT
# from torch_geometric.data import Data
from torch_scatter import scatter_mean
import numpy as np
# import torch
import os
from tqdm import tqdm
# import pickle as pkl
import json
import matplotlib
# import wandb
# import copy
# import glob
import shutil
from core.evaluation.metrics import CondMolGenMetric
# from core.evaluation.utils import convert_atomcloud_to_mol_smiles, save_mol_list
# from core.evaluation.visualization import visualize, visualize_chain
# from core.utils import transforms as trans
# from core.evaluation.utils import timing
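# This file contains the test-time callback used to check and export generated molecules.
# Select the non-interactive Agg backend (must happen before pyplot is imported).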
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# TODO: refactor and move center_pos (and that in train_bfn.py) into utils
def center_pos(protein_pos, ligand_pos, batch_protein, batch_ligand, mode='protein'):
    """Shift protein and ligand coordinates into a common reference frame.

    Args:
        protein_pos: Protein atom positions, one row per atom
            (presumably ``(num_protein_atoms, 3)`` — TODO confirm).
        ligand_pos: Ligand atom positions, one row per atom.
        batch_protein: Graph/batch index of every protein atom.
        batch_ligand: Graph/batch index of every ligand atom.
        mode: ``'none'`` leaves the positions untouched and uses an offset of
            ``0.``; ``'protein'`` subtracts, for every batch index, the mean
            protein position of that index from both its protein and ligand
            atoms.

    Returns:
        ``(protein_pos, ligand_pos, offset)``. ``offset`` is ``0.`` for
        ``'none'`` and the per-batch-index protein means for ``'protein'``;
        callers can add ``offset`` back to undo the shift.

    Raises:
        NotImplementedError: if ``mode`` is not one of the values above.
    """
    if mode == 'none':
        offset = 0.
    elif mode == 'protein':
        # One mean per batch index; indexing by the batch vectors broadcasts
        # each atom's own graph centroid back onto it.
        offset = scatter_mean(protein_pos, batch_protein, dim=0)
        protein_pos = protein_pos - offset[batch_protein]
        ligand_pos = ligand_pos - offset[batch_ligand]
    else:
        raise NotImplementedError(f"unsupported centering mode: {mode!r}")
    return protein_pos, ligand_pos, offset
# Working-directory-relative folder that receives one SDF file per accepted molecule.
# It is deleted and recreated at the end of every test epoch.
OUT_DIR: str = "./output"
# Module-level placeholder; not read or written anywhere in this file's visible code.
LAST_PROTEIN_FN = None
class DockingTestCallback(Callback):
    """Collect generated molecules during testing and write them out as SDF files.

    ``on_test_batch_end`` accumulates whatever ``test_step`` returns into
    ``self.outputs``. ``on_test_epoch_end`` then sanitizes each collected
    molecule with RDKit and skips those that fail. It also skips those whose
    SMILES contains a ``'.'``, i.e. more than one disconnected fragment. Every
    remaining molecule is written to ``OUT_DIR/<index>.sdf``.

    Each collected item is expected to expose ``mol`` (an RDKit ``Mol``) and
    ``ligand_filename`` attributes. This is inferred from how they are used
    below.

    Args:
        dataset: Stored as ``self.dataset``; not read by this class.
        atom_enc_mode: Forwarded to ``CondMolGenMetric``.
        atom_decoder: Forwarded to ``CondMolGenMetric``.
        atom_type_one_hot: Forwarded to ``CondMolGenMetric`` as ``type_one_hot``.
        single_bond: Forwarded to ``CondMolGenMetric``.
        docking_config: Forwarded to ``CondMolGenMetric``.
    """

    def __init__(self, dataset, atom_enc_mode, atom_decoder, atom_type_one_hot, single_bond, docking_config) -> None:
        super().__init__()
        self.dataset = dataset
        self.atom_enc_mode = atom_enc_mode
        self.atom_decoder = atom_decoder
        self.single_bond = single_bond
        self.type_one_hot = atom_type_one_hot
        self.docking_config = docking_config
        # Items gathered from test batches; reset in on_test_start.
        self.outputs = []

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Build the ``CondMolGenMetric`` helper from the stored configuration.

        NOTE(review): ``self.metric`` is not used by any method of this class,
        so no docking/metric evaluation happens here. Confirm this is intended.
        """
        super().setup(trainer, pl_module, stage)
        self.metric = CondMolGenMetric(
            atom_decoder=self.atom_decoder,
            atom_enc_mode=self.atom_enc_mode,
            type_one_hot=self.type_one_hot,
            single_bond=self.single_bond,
            docking_config=self.docking_config,
        )

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Append the items returned by ``test_step`` for this batch to ``self.outputs``."""
        super().on_test_batch_end(
            trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
        )
        # A test_step that returns nothing yields None; extend(None) would raise.
        if outputs is None:
            return
        # Assumes test_step returns an iterable of per-molecule items — TODO confirm.
        self.outputs.extend(outputs)

    def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Discard items left over from any previous test run."""
        super().on_test_start(trainer, pl_module)
        self.outputs = []

    @staticmethod
    def _reset_dir(path) -> None:
        """Delete ``path`` if it exists, then recreate it as an empty directory."""
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)

    def on_test_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        """Sanitize every collected molecule and write the valid, connected ones to SDF.

        Both ``pl_module.cfg.accounting.test_outputs_dir`` and ``OUT_DIR`` are
        wiped first, so results from an earlier run never mix with this one.
        A molecule that cannot be processed is reported and skipped; it never
        aborts the whole evaluation.
        """
        super().on_test_epoch_end(trainer, pl_module)
        # NOTE(review): test_outputs_dir is recreated but nothing below writes
        # to it; all files go to OUT_DIR. Confirm this is intended.
        self._reset_dir(pl_module.cfg.accounting.test_outputs_dir)
        self._reset_dir(OUT_DIR)

        for idx, graph in enumerate(tqdm(self.outputs, desc="Chem eval")):
            try:
                mol = graph.mol
                ligand_filename = graph.ligand_filename
                mol.SetProp('_Name', ligand_filename)
                Chem.SanitizeMol(mol)
                smiles = Chem.MolToSmiles(mol)
                validity = smiles is not None
                # '.' separates disconnected fragments in SMILES.
                complete = '.' not in smiles
            except Exception as err:
                # Best-effort: report and skip this molecule, keep evaluating
                # the rest. Unlike a bare except, this lets Ctrl-C through.
                print('sanitize failed', idx, repr(err))
                continue
            if not validity or not complete:
                print('validity', validity, 'complete', complete)
                continue

            out_fn = os.path.join(OUT_DIR, f'{idx}.sdf')
            with Chem.SDWriter(out_fn) as w:
                w.write(mol)
|