File size: 5,050 Bytes
1f0c7b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from rdkit import Chem
from typing import Any, Optional
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities.types import STEP_OUTPUT
# from torch_geometric.data import Data
from torch_scatter import scatter_mean
import numpy as np
# import torch
import os
from tqdm import tqdm
# import pickle as pkl
import json
import matplotlib
# import wandb
# import copy
# import glob
import shutil

from core.evaluation.metrics import CondMolGenMetric
# from core.evaluation.utils import convert_atomcloud_to_mol_smiles, save_mol_list
# from core.evaluation.visualization import visualize, visualize_chain
# from core.utils import transforms as trans
# from core.evaluation.utils import timing

# This file contains the test-time callback that sanitizes generated molecules and writes the valid ones to SDF files.

matplotlib.use("Agg")

import matplotlib.pyplot as plt


# TODO: refactor and move center_pos (and that in train_bfn.py) into utils
def center_pos(protein_pos, ligand_pos, batch_protein, batch_ligand, mode='protein'):
    """Shift protein and ligand coordinates by a per-group protein mean.

    Args:
        protein_pos: Protein coordinates; averaged per group when ``mode='protein'``.
        ligand_pos: Ligand coordinates; shifted by the same offset as the protein.
        batch_protein: Group index for each row of ``protein_pos``.
        batch_ligand: Group index for each row of ``ligand_pos``.
        mode: ``'protein'`` subtracts the mean of ``protein_pos`` grouped by
            ``batch_protein`` (along dim 0); ``'none'`` leaves both inputs as-is.

    Returns:
        A ``(protein_pos, ligand_pos, offset)`` tuple. ``offset`` is ``0.`` for
        ``mode='none'`` and the per-group mean for ``mode='protein'``.

    Raises:
        NotImplementedError: If ``mode`` is neither ``'none'`` nor ``'protein'``.
    """
    if mode not in ('none', 'protein'):
        raise NotImplementedError

    if mode == 'none':
        # Nothing to subtract: hand the inputs back with a zero offset.
        return protein_pos, ligand_pos, 0.

    # One mean row per group; index it back out to align with each input row.
    centroid = scatter_mean(protein_pos, batch_protein, dim=0)
    shifted_protein = protein_pos - centroid[batch_protein]
    shifted_ligand = ligand_pos - centroid[batch_ligand]
    return shifted_protein, shifted_ligand, centroid


# Directory that DockingTestCallback.on_test_epoch_end deletes, recreates and
# fills with one ``<idx>.sdf`` file per molecule that passes the validity checks.
OUT_DIR = './output'
# Not referenced anywhere in the visible part of this file.
# NOTE(review): presumably a cache of the most recent protein file name - confirm or remove.
LAST_PROTEIN_FN = None


class DockingTestCallback(Callback):
    """Lightning callback that gathers test outputs and exports valid molecules.

    Test-step outputs are accumulated across batches. At the end of the test
    epoch each output is sanitized with RDKit, and every molecule that
    sanitizes cleanly and consists of a single fragment is written to
    ``OUT_DIR/<idx>.sdf``.
    """

    def __init__(self, dataset, atom_enc_mode, atom_decoder, atom_type_one_hot, single_bond, docking_config) -> None:
        """Store the configuration later forwarded to ``CondMolGenMetric``.

        Args:
            dataset: Stored on the instance; not used by the methods of this class.
            atom_enc_mode: Forwarded to ``CondMolGenMetric`` in ``setup``.
            atom_decoder: Forwarded to ``CondMolGenMetric`` in ``setup``.
            atom_type_one_hot: Forwarded to ``CondMolGenMetric`` as ``type_one_hot``.
            single_bond: Forwarded to ``CondMolGenMetric`` in ``setup``.
            docking_config: Forwarded to ``CondMolGenMetric`` in ``setup``.
        """
        super().__init__()
        self.dataset = dataset
        self.atom_enc_mode = atom_enc_mode
        self.atom_decoder = atom_decoder
        self.single_bond = single_bond
        self.type_one_hot = atom_type_one_hot
        self.docking_config = docking_config
        # Test-step outputs accumulated across batches; reset in on_test_start.
        self.outputs = []

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Build the ``CondMolGenMetric`` instance from the stored configuration."""
        super().setup(trainer, pl_module, stage)
        # NOTE(review): self.metric is created here but not used by the
        # methods of this class - confirm whether other code relies on it.
        self.metric = CondMolGenMetric(
            atom_decoder=self.atom_decoder,
            atom_enc_mode=self.atom_enc_mode,
            type_one_hot=self.type_one_hot,
            single_bond=self.single_bond,
            docking_config=self.docking_config,
        )

    def on_test_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        dataloader_idx: int = 0,
    ) -> None:
        """Append the items of this batch's test-step output to ``self.outputs``."""
        super().on_test_batch_end(
            trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
        )
        # `outputs` is iterated item by item, so it must be an iterable of
        # per-sample results (each later read via `.mol` / `.ligand_filename`).
        self.outputs.extend(outputs)

    def on_test_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Drop outputs left over from any earlier test run."""
        super().on_test_start(trainer, pl_module)
        self.outputs = []

    @staticmethod
    def _reset_dir(path) -> None:
        """Delete ``path`` if it exists, then recreate it as an empty directory."""
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)

    def on_test_epoch_end(
        self, trainer: Trainer, pl_module: LightningModule
    ) -> None:
        """Sanitize every collected molecule and write the valid ones as SDF.

        Both ``pl_module.cfg.accounting.test_outputs_dir`` and ``OUT_DIR`` are
        wiped and recreated first. A molecule is skipped (with a printed
        message) when RDKit sanitization or SMILES conversion fails, or when
        its SMILES contains ``'.'`` (more than one fragment). File names use
        the index within ``self.outputs``, so skipped molecules leave gaps in
        the numbering.
        """
        super().on_test_epoch_end(trainer, pl_module)

        # NOTE(review): test_outputs_dir is only cleared and recreated here;
        # the SDF files themselves are written under OUT_DIR.
        self._reset_dir(pl_module.cfg.accounting.test_outputs_dir)
        self._reset_dir(OUT_DIR)

        for idx, graph in enumerate(tqdm(self.outputs, total=len(self.outputs), desc="Chem eval")):
            try:
                mol = graph.mol
                mol.SetProp('_Name', graph.ligand_filename)

                Chem.SanitizeMol(mol)
                smiles = Chem.MolToSmiles(mol)
                validity = smiles is not None
                complete = '.' not in smiles
            except Exception:
                # Catch Exception rather than everything, so that Ctrl-C and
                # SystemExit still stop a long evaluation loop.
                print('sanitize failed')
                continue

            if not validity or not complete:
                print('validity', validity, 'complete', complete)
                continue

            # NOTE(review): no docking is run here. If the matching protein is
            # needed later, it presumably sits next to the ligand file as
            # <first 10 characters of the ligand basename>.pdb - confirm.
            out_fn = os.path.join(OUT_DIR, f'{idx}.sdf')
            with Chem.SDWriter(out_fn) as w:
                w.write(mol)