# MolCRAFT: core/datasets/pl_pair_dataset.py
# import os
# import pickle
# import lmdb
# from torch.utils.data import Dataset
# from tqdm.auto import tqdm
# import sys
# from time import time
# import torch
# from torch_geometric.transforms import Compose
# from core.datasets.utils import PDBProtein, parse_sdf_file, ATOM_FAMILIES_ID
# from core.datasets.pl_data import ProteinLigandData, torchify_dict
# import core.utils.transforms as trans

# class DBReader:
# def __init__(self, path) -> None:
# self.path = path
# self.db = None
# self.keys = None
# def _connect_db(self):
# """
# Establish read-only database connection
# """
# assert self.db is None, 'A connection has already been opened.'
#         self.db = lmdb.open(
#             self.path,
#             map_size=10*(1024*1024*1024),  # 10GB
#             create=False,
#             subdir=False,
#             readonly=True,
#             lock=False,       # no file locking needed: all access is read-only
#             readahead=False,  # better for random reads over a large DB
#             meminit=False,    # skip zero-initialising new buffers
#         )
# with self.db.begin() as txn:
# self.keys = list(txn.cursor().iternext(values=False))
# def _close_db(self):
# self.db.close()
# self.db = None
# self.keys = None
# def __del__(self):
# if self.db is not None:
# self._close_db()
# def __len__(self):
# if self.db is None:
# self._connect_db()
# return len(self.keys)
# def __getitem__(self, idx):
# if self.db is None:
# self._connect_db()
# key = self.keys[idx]
# data = pickle.loads(self.db.begin().get(key))
# data = ProteinLigandData(**data)
# data.id = idx
# assert data.protein_pos.size(0) > 0
# return data
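# A minimal usage sketch for DBReader, kept commented out like the rest of this
# module; the LMDB path is hypothetical. The connection is opened lazily on
# first access, so a DBReader can be handed to DataLoader workers before any
# LMDB handle exists:
#
#     reader = DBReader('./data/crossdocked_v1.1_rmsd1.0_pocket10_processed_final.lmdb')
#     print(len(reader))       # connects, then counts keys
#     sample = reader[0]       # a ProteinLigandData with .id == 0
#     print(sample.protein_pos.shape)
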
# class PocketLigandPairDataset(Dataset):
# def __init__(self, raw_path, transform=None, version='final'):
# super().__init__()
# self.raw_path = raw_path.rstrip('/')
# self.index_path = os.path.join(self.raw_path, 'index.pkl')
# self.processed_path = os.path.join(os.path.dirname(self.raw_path),
# os.path.basename(self.raw_path) + f'_processed_{version}.lmdb')
# self.transform = transform
# self.reader = DBReader(self.processed_path)
# if not os.path.exists(self.processed_path):
# print(f'{self.processed_path} does not exist, begin processing data')
# self._process()
# def _process(self):
# db = lmdb.open(
# self.processed_path,
# map_size=10*(1024*1024*1024), # 10GB
# create=True,
# subdir=False,
# readonly=False, # Writable
# )
# with open(self.index_path, 'rb') as f:
# index = pickle.load(f)
# num_skipped = 0
# with db.begin(write=True, buffers=True) as txn:
# for i, (pocket_fn, ligand_fn, *_) in enumerate(tqdm(index)):
# if pocket_fn is None: continue
# try:
# # data_prefix = '/data/work/jiaqi/binding_affinity'
# data_prefix = self.raw_path
# pocket_dict = PDBProtein(os.path.join(data_prefix, pocket_fn)).to_dict_atom()
# ligand_dict = parse_sdf_file(os.path.join(data_prefix, ligand_fn))
# data = ProteinLigandData.from_protein_ligand_dicts(
# protein_dict=torchify_dict(pocket_dict),
# ligand_dict=torchify_dict(ligand_dict),
# )
# data.protein_filename = pocket_fn
# data.ligand_filename = ligand_fn
# data = data.to_dict() # avoid torch_geometric version issue
# txn.put(
# key=str(i).encode(),
# value=pickle.dumps(data)
# )
#                 except Exception as e:
#                     num_skipped += 1
#                     print('Skipping (%d) %s' % (num_skipped, ligand_fn), e)
#                     continue
# db.close()
# def __len__(self):
# return len(self.reader)
# def __getitem__(self, idx):
# data = self.reader[idx]
# if self.transform is not None:
# data = self.transform(data)
# return data
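# End-to-end usage sketch (hypothetical path, commented out like the rest of
# this module). The first construction builds the LMDB from index.pkl and the
# referenced .pdb/.sdf files; later constructions only read it:
#
#     transform = Compose([
#         trans.FeaturizeProteinAtom(),
#         trans.FeaturizeLigandAtom('add_aromatic'),
#     ])
#     dataset = PocketLigandPairDataset(
#         './data/crossdocked_v1.1_rmsd1.0_pocket10', transform=transform)
#     print(len(dataset), dataset[0])
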
# class PocketLigandGeneratedPairDataset(Dataset):
# def __init__(self, raw_path, transform=None, version='4-decompdiff'):
# super().__init__()
# self.raw_path = raw_path.rstrip('/')
#         # NOTE: site-specific path; adjust for your environment
#         self.generated_path = os.path.join('/sharefs/share/sbdd_data/all_results', f'{version}_docked_pose_checked.pt')
# self.processed_path = os.path.join(os.path.dirname(self.raw_path),
# os.path.basename(self.raw_path) + f'_processed_{version}.lmdb')
# self.transform = transform
# self.reader = DBReader(self.processed_path)
# if not os.path.exists(self.processed_path):
# print(f'{self.processed_path} does not exist, begin processing data')
# self._process()
# def _process(self):
# db = lmdb.open(
# self.processed_path,
# map_size=10*(1024*1024*1024), # 10GB
# create=True,
# subdir=False,
# readonly=False, # Writable
# )
# with open(self.generated_path, 'rb') as f:
# results = torch.load(f)
# num_skipped = 0
# with db.begin(write=True, buffers=True) as txn:
# idx = -1
# for i, res in tqdm(enumerate(results), total=len(results)):
# if isinstance(res, dict):
# res = [res]
# for r in res:
# idx += 1
# mol = r["mol"]
# ligand_fn = r["ligand_filename"]
# pocket_fn = os.path.join(
# os.path.dirname(ligand_fn),
# os.path.basename(ligand_fn)[:-4] + '_pocket10.pdb'
# )
#                     if pocket_fn is None: continue  # vestigial: pocket_fn is built just above and is never None
# try:
# data_prefix = self.raw_path
# pocket_dict = PDBProtein(os.path.join(data_prefix, pocket_fn)).to_dict_atom()
# ligand_dict = parse_sdf_file(mol)
# # ligand_dict = parse_sdf_file(os.path.join(data_prefix, ligand_fn))
# data = ProteinLigandData.from_protein_ligand_dicts(
# protein_dict=torchify_dict(pocket_dict),
# ligand_dict=torchify_dict(ligand_dict),
# )
# data.protein_filename = pocket_fn
# data.ligand_filename = ligand_fn
# data = data.to_dict() # avoid torch_geometric version issue
# txn.put(
# key=str(idx).encode(),
# value=pickle.dumps(data)
# )
# except Exception as e:
# num_skipped += 1
# print('Skipping (%d) %s' % (num_skipped, ligand_fn, ), e)
# continue
# db.close()
# def __len__(self):
# return len(self.reader)
# def __getitem__(self, idx):
# data = self.reader[idx]
# if self.transform is not None:
# data = self.transform(data)
# return data
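# Shape of the generated-pose input this class expects, inferred from the loop
# above: `{version}_docked_pose_checked.pt` holds a list whose entries are
# either a single dict or a list of dicts, each carrying at least an RDKit mol
# and the ligand path that the `_pocket10.pdb` name is derived from:
#
#     results = [
#         {'mol': mol_0, 'ligand_filename': 'ABC/xyz.sdf'},
#         [{'mol': mol_1a, 'ligand_filename': '...'}, ...],
#     ]
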
# class PocketLigandPairDatasetFromComplex(Dataset):
# def __init__(self, raw_path, transform=None, version='final', radius=10.0):
# super().__init__()
# self.raw_path = raw_path.rstrip('/')
# self.index_path = os.path.join(self.raw_path, 'index.pkl')
# base_name = os.path.basename(self.raw_path)
# if 'pocket' in base_name:
# self.processed_path = os.path.join(os.path.dirname(self.raw_path),
# os.path.basename(self.raw_path) + f'_processed_{version}.lmdb')
# else:
# self.processed_path = os.path.join(os.path.dirname(self.raw_path),
# os.path.basename(self.raw_path) + f'_pocket{radius}_processed_{version}.lmdb')
# self.transform = transform
# self.reader = DBReader(self.processed_path)
# self.radius = radius
# if not os.path.exists(self.processed_path):
# print(f'{self.processed_path} does not exist, begin processing data')
# self._process()
# def _process(self):
# db = lmdb.open(
# self.processed_path,
#             map_size=10*(1024*1024*1024), # 10GB
# create=True,
# subdir=False,
# readonly=False, # Writable
# max_readers=256,
# )
# with open(self.index_path, 'rb') as f:
# index = pickle.load(f)
# print('Processing data...', 'index', self.index_path, index[0])
# num_skipped = 0
# with db.begin(write=True, buffers=True) as txn:
# for i, (pocket_fn, ligand_fn, *_) in enumerate(tqdm(index)):
# if pocket_fn is None: continue
# try:
# data_prefix = self.raw_path
# # clip pocket
# ligand_dict = parse_sdf_file(os.path.join(data_prefix, ligand_fn))
# protein = PDBProtein(os.path.join(data_prefix, pocket_fn))
# selected = protein.query_residues_ligand(ligand_dict, self.radius)
# pdb_block_pocket = protein.residues_to_pdb_block(selected)
# pocket_dict = PDBProtein(pdb_block_pocket).to_dict_atom()
# # pocket_dict = PDBProtein(os.path.join(data_prefix, pocket_fn)).to_dict_atom()
# # ligand_dict = parse_sdf_file(os.path.join(data_prefix, ligand_fn))
# data = ProteinLigandData.from_protein_ligand_dicts(
# protein_dict=torchify_dict(pocket_dict),
# ligand_dict=torchify_dict(ligand_dict),
# )
# data.protein_filename = pocket_fn
# data.ligand_filename = ligand_fn
# data = data.to_dict() # avoid torch_geometric version issue
# txn.put(
# key=str(i).encode(),
# value=pickle.dumps(data)
# )
# except Exception as e:
# num_skipped += 1
# print('Skipping (%d) %s' % (num_skipped, ligand_fn, ), e)
# with open('skipped.txt', 'a') as f:
# f.write('Skip %s due to %s\n' % (ligand_fn, e))
# continue
# db.close()
# def __len__(self):
# return len(self.reader)
# def __getitem__(self, idx):
# data = self.reader[idx]
# if self.transform is not None:
# data = self.transform(data)
# return data
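# The pocket-clipping step can be exercised on its own; a sketch with
# hypothetical file names, using the same helpers as the loop above:
#
#     ligand_dict = parse_sdf_file('complex/ligand.sdf')
#     protein = PDBProtein('complex/receptor.pdb')
#     selected = protein.query_residues_ligand(ligand_dict, 10.0)
#     pocket_block = protein.residues_to_pdb_block(selected)
#     pocket_dict = PDBProtein(pocket_block).to_dict_atom()
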
# class PocketLigandPairDatasetFeaturized(Dataset):
# def __init__(self, raw_path, ligand_atom_mode, version='simple'):
#         """
#         In the 'simple' version, only these features are saved, to keep IO light:
#             protein_pos, protein_atom_feature, protein_element,
#             ligand_pos, ligand_atom_feature_full, ligand_element
#         """
# self.raw_path = raw_path
# self.ligand_atom_mode = ligand_atom_mode
# self.version = version
# if version == 'simple':
# self.features_to_save = [
# 'protein_pos', 'protein_atom_feature', 'protein_element',
# 'ligand_pos', 'ligand_atom_feature_full', 'ligand_element',
# 'protein_filename', 'ligand_filename',
# ]
# else:
# raise NotImplementedError
# self.transformed_path = os.path.join(
# os.path.dirname(self.raw_path), os.path.basename(self.raw_path) +
# f'_{ligand_atom_mode}_transformed_{version}.pt'
# )
# if not os.path.exists(self.transformed_path):
# print(f'{self.transformed_path} does not exist, begin transforming data')
# self._transform()
# else:
# print(f'reading data from {self.transformed_path}...')
# tic = time()
# tr_data = torch.load(self.transformed_path)
# toc = time()
#             print(f'{toc - tic:.1f}s elapsed')
# self.train_data, self.test_data = tr_data['train'], tr_data['test']
# self.protein_atom_feature_dim = tr_data['protein_atom_feature_dim']
# self.ligand_atom_feature_dim = tr_data['ligand_atom_feature_dim']
# def _transform(self):
# raw_dataset = PocketLigandPairDataset(self.raw_path, None, 'final')
# split_path = os.path.join(
# os.path.dirname(self.raw_path), 'crossdocked_pocket10_pose_split.pt',
# )
# split = torch.load(split_path)
# train_ids, test_ids = split['train'], split['test']
# print(f'train_size: {len(train_ids)}, test_size: {len(test_ids)}')
# protein_featurizer = trans.FeaturizeProteinAtom()
# ligand_featurizer = trans.FeaturizeLigandAtom(self.ligand_atom_mode)
# transform_list = [
# protein_featurizer,
# ligand_featurizer,
# # trans.FeaturizeLigandBond(),
# ]
# transform = Compose(transform_list)
# self.protein_atom_feature_dim = protein_featurizer.feature_dim
# self.ligand_atom_feature_dim = ligand_featurizer.feature_dim
# def _transform_subset(ids):
# data_list = []
# for idx in tqdm(ids):
# data = raw_dataset[idx]
# data = transform(data)
# tr_data = {}
# for k in self.features_to_save:
# tr_data[k] = getattr(data, k)
# tr_data['id'] = idx
# tr_data = ProteinLigandData(**tr_data)
# data_list.append(tr_data)
# return data_list
# self.train_data = _transform_subset(train_ids)
#         # sys.getsizeof is shallow: it reports the list object only, not its items
#         print(f'train_size: {len(self.train_data)}, {sys.getsizeof(self.train_data)}')
#         self.test_data = _transform_subset(test_ids)
#         print(f'test_size: {len(self.test_data)}, {sys.getsizeof(self.test_data)}')
# torch.save({
# 'train': self.train_data, 'test': self.test_data,
# 'protein_atom_feature_dim': self.protein_atom_feature_dim,
# 'ligand_atom_feature_dim': self.ligand_atom_feature_dim,
# }, self.transformed_path)
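# Read-path sketch for the featurized cache (hypothetical path): once the .pt
# file exists, construction reduces to a single torch.load plus attribute
# lookups:
#
#     dataset = PocketLigandPairDatasetFeaturized(
#         './data/crossdocked_v1.1_rmsd1.0_pocket10', 'add_aromatic')
#     print(dataset.protein_atom_feature_dim, dataset.ligand_atom_feature_dim)
#     print(len(dataset.train_data), len(dataset.test_data))
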
# if __name__ == '__main__':
# # original dataset
# dataset = PocketLigandPairDataset('./data/crossdocked_v1.1_rmsd1.0_pocket10')
# print(len(dataset), dataset[0])
# ############################################################
# # test DecompDiffDataset
# # dataset = PocketLigandGeneratedPairDataset('/sharefs/share/sbdd_data/crossdocked_pocket10')
# # print(len(dataset), dataset[0])
# ############################################################
# # test featurized dataset (GPU accelerated)
# # path = '/sharefs/share/sbdd_data/crossdocked_v1.1_rmsd1.0_pocket10'
# # ligand_atom_mode = 'add_aromatic'
# # dataset = PocketLigandPairDatasetFeaturized(path, ligand_atom_mode)
# # train_data, test_data = dataset.train_data, dataset.test_data
# # print(f'train_size: {len(train_data)}, {sys.getsizeof(train_data)}')
# # print(f'test_size: {len(test_data)}, {sys.getsizeof(test_data)}')
# # print(test_data[0], sys.getsizeof(test_data[0]))
# ############################################################
# # test featurization
# # find all atom types
# # atom_types = {(1, False): 0}
# # dataset = PocketLigandPairDataset(path, transform)
# # for i in tqdm(range(len(dataset))):
# # data = dataset[i]
# # element_list = data.ligand_element
# # hybridization_list = data.ligand_hybridization
# # aromatic_list = [v[trans.AROMATIC_FEAT_MAP_IDX] for v in data.ligand_atom_feature]
# # types = [(e, a) for e, h, a in zip(element_list, hybridization_list, aromatic_list)]
# # for t in types:
# # t = (t[0].item(), bool(t[1].item()))
# # if t not in atom_types:
# # atom_types[t] = 0
# # atom_types[t] += 1
# # idx = 0
# # for k in sorted(atom_types.keys()):
# # print(f'{k}: {idx}, # {atom_types[k]}')
# # idx += 1
# ############################################################
# # count atom types
# # type_counter, aromatic_counter, full_counter = {}, {}, {}
# # for i, data in enumerate(tqdm(dataset)):
# # element_list = data.ligand_element
# # hybridization_list = data.ligand_hybridization
# # aromatic_list = [v[trans.AROMATIC_FEAT_MAP_IDX] for v in data.ligand_atom_feature]
# # flag = False
# # for atom_type in element_list:
# # atom_type = int(atom_type.item())
# # if atom_type not in type_counter:
# # type_counter[atom_type] = 0
# # type_counter[atom_type] += 1
# # for e, a in zip(element_list, aromatic_list):
# # e = int(e.item())
# # a = bool(a.item())
# # key = (e, a)
# # if key not in aromatic_counter:
# # aromatic_counter[key] = 0
# # aromatic_counter[key] += 1
# # if key not in trans.MAP_ATOM_TYPE_AROMATIC_TO_INDEX:
# # flag = True
# # for e, h, a in zip(element_list, hybridization_list, aromatic_list):
# # e = int(e.item())
# # a = bool(a.item())
# # key = (e, h, a)
# # if key not in full_counter:
# # full_counter[key] = 0
# # full_counter[key] += 1
# # print('type_counter', type_counter)
# # print('aromatic_counter', aromatic_counter)
# # print('full_counter', full_counter)
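# The counting scratch above can be written more compactly with
# collections.Counter; a sketch assuming the same dataset fields:
#
#     from collections import Counter
#     aromatic_counter = Counter(
#         (int(e.item()), bool(v[trans.AROMATIC_FEAT_MAP_IDX].item()))
#         for data in dataset
#         for e, v in zip(data.ligand_element, data.ligand_atom_feature)
#     )
#     print(aromatic_counter)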