import os
import pickle

import lmdb
import torch
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from scipy import stats

from core.datasets.utils import PDBProtein
from core.datasets.protein_ligand import parse_sdf_file_mol
from core.datasets.pl_data import ProteinLigandData, torchify_dict


class PDBBindDataset(Dataset):
    """PDBBind protein-ligand pairs, preprocessed once into an LMDB file.
    The read-only connection is opened lazily on first access."""

    def __init__(self, raw_path, transform=None, emb_path=None, heavy_only=False):
        super().__init__()
        self.raw_path = raw_path.rstrip('/')
        self.index_path = os.path.join(self.raw_path, 'index.pkl')
        self.processed_path = os.path.join(
            self.raw_path, os.path.basename(self.raw_path) + '_processed.lmdb')
        self.emb_path = emb_path
        self.transform = transform
        self.heavy_only = heavy_only
        self.db = None
        self.keys = None

        if not os.path.exists(self.processed_path):
            self._process()
        print('Load dataset from ', self.processed_path)
        if self.emb_path is not None:
            print('Load embedding from ', self.emb_path)
            self.emb = torch.load(self.emb_path)

    def _connect_db(self):
        """Establish a read-only database connection."""
        assert self.db is None, 'A connection has already been opened.'
        self.db = lmdb.open(
            self.processed_path,
            map_size=10 * (1024 * 1024 * 1024),  # 10 GB
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        with self.db.begin() as txn:
            self.keys = list(txn.cursor().iternext(values=False))

    def _close_db(self):
        self.db.close()
        self.db = None
        self.keys = None

    def _process(self):
        """Parse each (pocket, ligand) pair listed in index.pkl and serialize
        the resulting ProteinLigandData objects into the LMDB file."""
        db = lmdb.open(
            self.processed_path,
            map_size=10 * (1024 * 1024 * 1024),  # 10 GB
            create=True,
            subdir=False,
            readonly=False,  # writable
        )
        with open(self.index_path, 'rb') as f:
            index = pickle.load(f)
        # index = parse_pdbbind_index_file(self.index_path)

        num_skipped = 0
        with db.begin(write=True, buffers=True) as txn:
            for i, (pocket_fn, ligand_fn, resolution, pka, kind) in enumerate(tqdm(index)):
                # try:
                #     pdb_path = os.path.join(self.raw_path, 'refined-set', pdb_idx)
                #     pocket_fn = os.path.join(pdb_path, f'{pdb_idx}_pocket.pdb')
                #     ligand_fn = os.path.join(pdb_path, f'{pdb_idx}_ligand.sdf')
                pocket_dict = PDBProtein(pocket_fn).to_dict_atom()
                ligand_dict = parse_sdf_file_mol(ligand_fn, heavy_only=self.heavy_only)
                data = ProteinLigandData.from_protein_ligand_dicts(
                    protein_dict=torchify_dict(pocket_dict),
                    ligand_dict=torchify_dict(ligand_dict),
                )
                data.protein_filename = pocket_fn
                data.ligand_filename = ligand_fn
                data.y = torch.tensor(float(pka))  # affinity (pKa) label
                data.kind = torch.tensor(kind)     # measurement kind from the index
                txn.put(
                    key=f'{i:05d}'.encode(),
                    value=pickle.dumps(data),
                )
                # except Exception:
                #     num_skipped += 1
                #     print('Skipping (%d) %s' % (num_skipped, ligand_fn))
                #     continue
        db.close()  # flush and release the writable environment before reopening read-only
        print('num_skipped: ', num_skipped)

    def __len__(self):
        if self.db is None:
            self._connect_db()
        return len(self.keys)

    def __getitem__(self, idx):
        if self.db is None:
            self._connect_db()
        key = self.keys[idx]
        data = pickle.loads(self.db.begin().get(key))
        data.id = idx
        assert data.protein_pos.size(0) > 0
        if self.transform is not None:
            data = self.transform(data)
        # Add features extracted by molopt.
        if self.emb_path is not None:
            emb = self.emb[idx]
            data.nll = torch.cat([emb['kl_pos'][1:], emb['kl_v'][1:]]).view(1, -1)
            data.nll_all = torch.cat([emb['kl_pos'], emb['kl_v']]).view(1, -1)
            data.pred_ligand_v = torch.softmax(emb['pred_ligand_v'], dim=-1)
            data.final_h = emb['final_h']
            # data.final_ligand_h = emb['final_ligand_h']
            data.pred_v_entropy = torch.from_numpy(
                stats.entropy(torch.softmax(emb['pred_ligand_v'], dim=-1).numpy(), axis=-1)
            ).view(-1, 1)
        return data


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('path', type=str)
    args = parser.parse_args()
    PDBBindDataset(args.path)
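
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module). Assumes a
# hypothetical directory `./data/pdbbind` containing an `index.pkl` whose
# entries are (pocket_fn, ligand_fn, resolution, pka, kind) tuples, with the
# referenced pocket PDB and ligand SDF files readable on disk:
#
#   dataset = PDBBindDataset('./data/pdbbind')
#   print(len(dataset))            # opens the LMDB connection lazily
#   sample = dataset[0]
#   print(sample.y, sample.kind)   # affinity label and measurement kind
# ---------------------------------------------------------------------------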