"""Protein pocket / ligand pair datasets, cached in LMDB for fast random access."""
import os
import pickle
import sys
from time import time

import lmdb
import torch
from torch.utils.data import Dataset
from torch_geometric.transforms import Compose
from tqdm.auto import tqdm

from core.datasets.utils import PDBProtein, parse_sdf_file, ATOM_FAMILIES_ID
from core.datasets.pl_data import ProteinLigandData, torchify_dict
import core.utils.transforms as trans


class DBReader:
    """Lazy, read-only reader over a processed LMDB file.

    The connection is opened on first access rather than in __init__, so each
    process (e.g. a DataLoader worker) ends up with its own LMDB handle.
    """

    def __init__(self, path) -> None:
        self.path = path
        self.db = None
        self.keys = None

    def _connect_db(self):
        """
        Establish read-only database connection
        """
        assert self.db is None, 'A connection has already been opened.'
        self.db = lmdb.open(
            self.path,
            map_size=10*(1024*1024*1024),   # 10GB
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        with self.db.begin() as txn:
            self.keys = list(txn.cursor().iternext(values=False))

    def _close_db(self):
        self.db.close()
        self.db = None
        self.keys = None

    def __del__(self):
        if self.db is not None:
            self._close_db()

    def __len__(self):
        if self.db is None:
            self._connect_db()
        return len(self.keys)

    def __getitem__(self, idx):
        if self.db is None:
            self._connect_db()
        key = self.keys[idx]
        data = pickle.loads(self.db.begin().get(key))
        data = ProteinLigandData(**data)
        data.id = idx
        assert data.protein_pos.size(0) > 0
        return data


class PocketLigandPairDataset(Dataset):
    """Pocket/ligand pairs listed in `index.pkl`, processed into LMDB on first use."""

    def __init__(self, raw_path, transform=None, version='final'):
        super().__init__()
        self.raw_path = raw_path.rstrip('/')
        self.index_path = os.path.join(self.raw_path, 'index.pkl')
        self.processed_path = os.path.join(os.path.dirname(self.raw_path),
                                           os.path.basename(self.raw_path) + f'_processed_{version}.lmdb')
        self.transform = transform
        self.reader = DBReader(self.processed_path)

        if not os.path.exists(self.processed_path):
            print(f'{self.processed_path} does not exist, begin processing data')
            self._process()

    def _process(self):
        db = lmdb.open(
            self.processed_path,
            map_size=10*(1024*1024*1024),   # 10GB
            create=True,
            subdir=False,
            readonly=False,  # writable
        )
        with open(self.index_path, 'rb') as f:
            index = pickle.load(f)

        num_skipped = 0
        with db.begin(write=True, buffers=True) as txn:
            for i, (pocket_fn, ligand_fn, *_) in enumerate(tqdm(index)):
                if pocket_fn is None:
                    continue
                try:
                    # data_prefix = '/data/work/jiaqi/binding_affinity'
                    data_prefix = self.raw_path
                    pocket_dict = PDBProtein(os.path.join(data_prefix, pocket_fn)).to_dict_atom()
                    ligand_dict = parse_sdf_file(os.path.join(data_prefix, ligand_fn))
                    data = ProteinLigandData.from_protein_ligand_dicts(
                        protein_dict=torchify_dict(pocket_dict),
                        ligand_dict=torchify_dict(ligand_dict),
                    )
                    data.protein_filename = pocket_fn
                    data.ligand_filename = ligand_fn
                    data = data.to_dict()  # avoid torch_geometric version issue
                    txn.put(
                        key=str(i).encode(),
                        value=pickle.dumps(data)
                    )
                except Exception:  # skip malformed entries instead of aborting the build
                    num_skipped += 1
                    print('Skipping (%d) %s' % (num_skipped, ligand_fn))
                    continue
        db.close()

    def __len__(self):
        return len(self.reader)

    def __getitem__(self, idx):
        data = self.reader[idx]
        if self.transform is not None:
            data = self.transform(data)
        return data
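
# --- Usage sketch (illustrative, not part of the pipeline) -----------------
# A minimal way to batch this dataset with torch_geometric's DataLoader. The
# `follow_batch` keys are an assumption based on the attribute names used in
# this file; adjust them to whatever your model actually scatters over.
def _example_dataloader(raw_path, batch_size=4):
    from torch_geometric.loader import DataLoader  # PyG >= 2.0
    dataset = PocketLigandPairDataset(raw_path)
    # Protein and ligand atoms form two point sets of different sizes per
    # sample, so each needs its own batch-index vector when collated.
    return DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        follow_batch=['protein_element', 'ligand_element'],
    )
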

class PocketLigandGeneratedPairDataset(Dataset):
    """Pairs built from generated ligand poses stored in a `*_docked_pose_checked.pt` results file."""

    def __init__(self, raw_path, transform=None, version='4-decompdiff'):
        super().__init__()
        self.raw_path = raw_path.rstrip('/')
        self.generated_path = os.path.join('/sharefs/share/sbdd_data/all_results',
                                           f'{version}_docked_pose_checked.pt')
        self.processed_path = os.path.join(os.path.dirname(self.raw_path),
                                           os.path.basename(self.raw_path) + f'_processed_{version}.lmdb')
        self.transform = transform
        self.reader = DBReader(self.processed_path)

        if not os.path.exists(self.processed_path):
            print(f'{self.processed_path} does not exist, begin processing data')
            self._process()

    def _process(self):
        db = lmdb.open(
            self.processed_path,
            map_size=10*(1024*1024*1024),   # 10GB
            create=True,
            subdir=False,
            readonly=False,  # writable
        )
        with open(self.generated_path, 'rb') as f:
            results = torch.load(f)

        num_skipped = 0
        with db.begin(write=True, buffers=True) as txn:
            idx = -1
            for i, res in tqdm(enumerate(results), total=len(results)):
                if isinstance(res, dict):
                    res = [res]
                for r in res:
                    idx += 1
                    mol = r['mol']
                    ligand_fn = r['ligand_filename']
                    # derive the pocket file from the ligand filename
                    pocket_fn = os.path.join(
                        os.path.dirname(ligand_fn),
                        os.path.basename(ligand_fn)[:-4] + '_pocket10.pdb'
                    )
                    try:
                        data_prefix = self.raw_path
                        pocket_dict = PDBProtein(os.path.join(data_prefix, pocket_fn)).to_dict_atom()
                        ligand_dict = parse_sdf_file(mol)
                        # ligand_dict = parse_sdf_file(os.path.join(data_prefix, ligand_fn))
                        data = ProteinLigandData.from_protein_ligand_dicts(
                            protein_dict=torchify_dict(pocket_dict),
                            ligand_dict=torchify_dict(ligand_dict),
                        )
                        data.protein_filename = pocket_fn
                        data.ligand_filename = ligand_fn
                        data = data.to_dict()  # avoid torch_geometric version issue
                        txn.put(
                            key=str(idx).encode(),
                            value=pickle.dumps(data)
                        )
                    except Exception as e:
                        num_skipped += 1
                        print('Skipping (%d) %s' % (num_skipped, ligand_fn), e)
                        continue
        db.close()

    def __len__(self):
        return len(self.reader)

    def __getitem__(self, idx):
        data = self.reader[idx]
        if self.transform is not None:
            data = self.transform(data)
        return data
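
# --- Usage sketch (illustrative, not part of the pipeline) -----------------
# Quick sanity check on a `*_docked_pose_checked.pt` results file before
# building the LMDB above. Judging from the processing loop, the file holds
# a sequence whose entries are dicts (or lists of dicts) carrying at least
# 'mol' and 'ligand_filename'; anything beyond that is an assumption.
def _peek_generated_results(generated_path, n=3):
    results = torch.load(generated_path)
    for res in results[:n]:  # assumes a list-like container
        for r in (res if isinstance(res, list) else [res]):
            print(r['ligand_filename'], type(r['mol']).__name__)
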

class PocketLigandPairDatasetFromComplex(Dataset):
    """Like PocketLigandPairDataset, but clips the pocket out of the full
    complex: only residues within `radius` angstrom of the ligand are kept."""

    def __init__(self, raw_path, transform=None, version='final', radius=10.0):
        super().__init__()
        self.raw_path = raw_path.rstrip('/')
        self.index_path = os.path.join(self.raw_path, 'index.pkl')
        base_name = os.path.basename(self.raw_path)
        if 'pocket' in base_name:
            self.processed_path = os.path.join(os.path.dirname(self.raw_path),
                                               os.path.basename(self.raw_path) + f'_processed_{version}.lmdb')
        else:
            self.processed_path = os.path.join(os.path.dirname(self.raw_path),
                                               os.path.basename(self.raw_path) + f'_pocket{radius}_processed_{version}.lmdb')
        self.transform = transform
        self.reader = DBReader(self.processed_path)
        self.radius = radius

        if not os.path.exists(self.processed_path):
            print(f'{self.processed_path} does not exist, begin processing data')
            self._process()

    def _process(self):
        db = lmdb.open(
            self.processed_path,
            map_size=10*(1024*1024*1024),   # 10GB
            create=True,
            subdir=False,
            readonly=False,  # writable
            max_readers=256,
        )
        with open(self.index_path, 'rb') as f:
            index = pickle.load(f)

        print('Processing data...', 'index', self.index_path, index[0])
        num_skipped = 0
        with db.begin(write=True, buffers=True) as txn:
            for i, (pocket_fn, ligand_fn, *_) in enumerate(tqdm(index)):
                if pocket_fn is None:
                    continue
                try:
                    data_prefix = self.raw_path
                    # clip the pocket: keep residues within `radius` of the ligand
                    ligand_dict = parse_sdf_file(os.path.join(data_prefix, ligand_fn))
                    protein = PDBProtein(os.path.join(data_prefix, pocket_fn))
                    selected = protein.query_residues_ligand(ligand_dict, self.radius)
                    pdb_block_pocket = protein.residues_to_pdb_block(selected)
                    pocket_dict = PDBProtein(pdb_block_pocket).to_dict_atom()
                    # pocket_dict = PDBProtein(os.path.join(data_prefix, pocket_fn)).to_dict_atom()
                    # ligand_dict = parse_sdf_file(os.path.join(data_prefix, ligand_fn))
                    data = ProteinLigandData.from_protein_ligand_dicts(
                        protein_dict=torchify_dict(pocket_dict),
                        ligand_dict=torchify_dict(ligand_dict),
                    )
                    data.protein_filename = pocket_fn
                    data.ligand_filename = ligand_fn
                    data = data.to_dict()  # avoid torch_geometric version issue
                    txn.put(
                        key=str(i).encode(),
                        value=pickle.dumps(data)
                    )
                except Exception as e:
                    num_skipped += 1
                    print('Skipping (%d) %s' % (num_skipped, ligand_fn), e)
                    with open('skipped.txt', 'a') as f:
                        f.write('Skip %s due to %s\n' % (ligand_fn, e))
                    continue
        db.close()

    def __len__(self):
        return len(self.reader)

    def __getitem__(self, idx):
        data = self.reader[idx]
        if self.transform is not None:
            data = self.transform(data)
        return data
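
# --- Usage sketch (illustrative, not part of the pipeline) -----------------
# The pocket clipping performed in `PocketLigandPairDatasetFromComplex._process`,
# factored out as a standalone helper. All calls mirror the loop above; the
# function name and return convention are hypothetical, not an existing API.
def _clip_pocket(protein_path, ligand_path, radius=10.0):
    ligand_dict = parse_sdf_file(ligand_path)
    protein = PDBProtein(protein_path)
    # select residues with any atom within `radius` of any ligand atom,
    # then round-trip through a PDB block to get a clean pocket structure
    selected = protein.query_residues_ligand(ligand_dict, radius)
    pdb_block_pocket = protein.residues_to_pdb_block(selected)
    return PDBProtein(pdb_block_pocket).to_dict_atom(), ligand_dict
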

class PocketLigandPairDatasetFeaturized(Dataset):
    def __init__(self, raw_path, ligand_atom_mode, version='simple'):
        """
        in simple version, only these features are saved for better IO:
        protein_pos, protein_atom_feature, protein_element,
        ligand_pos, ligand_atom_feature_full, ligand_element
        """
        self.raw_path = raw_path
        self.ligand_atom_mode = ligand_atom_mode
        self.version = version
        if version == 'simple':
            self.features_to_save = [
                'protein_pos', 'protein_atom_feature', 'protein_element',
                'ligand_pos', 'ligand_atom_feature_full', 'ligand_element',
                'protein_filename', 'ligand_filename',
            ]
        else:
            raise NotImplementedError
        self.transformed_path = os.path.join(
            os.path.dirname(self.raw_path), os.path.basename(self.raw_path) +
            f'_{ligand_atom_mode}_transformed_{version}.pt'
        )
        if not os.path.exists(self.transformed_path):
            print(f'{self.transformed_path} does not exist, begin transforming data')
            self._transform()
        else:
            print(f'reading data from {self.transformed_path}...')
            tic = time()
            tr_data = torch.load(self.transformed_path)
            toc = time()
            print(f'{toc - tic:.2f}s elapsed')
            self.train_data, self.test_data = tr_data['train'], tr_data['test']
            self.protein_atom_feature_dim = tr_data['protein_atom_feature_dim']
            self.ligand_atom_feature_dim = tr_data['ligand_atom_feature_dim']

    def _transform(self):
        raw_dataset = PocketLigandPairDataset(self.raw_path, None, 'final')
        split_path = os.path.join(
            os.path.dirname(self.raw_path), 'crossdocked_pocket10_pose_split.pt',
        )
        split = torch.load(split_path)
        train_ids, test_ids = split['train'], split['test']
        print(f'train_size: {len(train_ids)}, test_size: {len(test_ids)}')

        protein_featurizer = trans.FeaturizeProteinAtom()
        ligand_featurizer = trans.FeaturizeLigandAtom(self.ligand_atom_mode)
        transform_list = [
            protein_featurizer,
            ligand_featurizer,
            # trans.FeaturizeLigandBond(),
        ]
        transform = Compose(transform_list)
        self.protein_atom_feature_dim = protein_featurizer.feature_dim
        self.ligand_atom_feature_dim = ligand_featurizer.feature_dim

        def _transform_subset(ids):
            data_list = []
            for idx in tqdm(ids):
                data = raw_dataset[idx]
                data = transform(data)
                # keep only the features needed at train time for better IO
                tr_data = {}
                for k in self.features_to_save:
                    tr_data[k] = getattr(data, k)
                tr_data['id'] = idx
                tr_data = ProteinLigandData(**tr_data)
                data_list.append(tr_data)
            return data_list

        self.train_data = _transform_subset(train_ids)
        print(f'train_size: {len(self.train_data)}, {sys.getsizeof(self.train_data)}')
        self.test_data = _transform_subset(test_ids)
        print(f'test_size: {len(self.test_data)}, {sys.getsizeof(self.test_data)}')
        torch.save({
            'train': self.train_data, 'test': self.test_data,
            'protein_atom_feature_dim': self.protein_atom_feature_dim,
            'ligand_atom_feature_dim': self.ligand_atom_feature_dim,
        }, self.transformed_path)


if __name__ == '__main__':
    # original dataset
    dataset = PocketLigandPairDataset('./data/crossdocked_v1.1_rmsd1.0_pocket10')
    print(len(dataset), dataset[0])

    ############################################################
    # test DecompDiffDataset
    # dataset = PocketLigandGeneratedPairDataset('/sharefs/share/sbdd_data/crossdocked_pocket10')
    # print(len(dataset), dataset[0])

    ############################################################
    # test featurized dataset (GPU accelerated)
    # path = '/sharefs/share/sbdd_data/crossdocked_v1.1_rmsd1.0_pocket10'
    # ligand_atom_mode = 'add_aromatic'
    # dataset = PocketLigandPairDatasetFeaturized(path, ligand_atom_mode)
    # train_data, test_data = dataset.train_data, dataset.test_data
    # print(f'train_size: {len(train_data)}, {sys.getsizeof(train_data)}')
    # print(f'test_size: {len(test_data)}, {sys.getsizeof(test_data)}')
    # print(test_data[0], sys.getsizeof(test_data[0]))

    ############################################################
    # test featurization
    # find all atom types
    # atom_types = {(1, False): 0}
    # dataset = PocketLigandPairDataset(path, transform)
    # for i in tqdm(range(len(dataset))):
    #     data = dataset[i]
    #     element_list = data.ligand_element
    #     hybridization_list = data.ligand_hybridization
    #     aromatic_list = [v[trans.AROMATIC_FEAT_MAP_IDX] for v in data.ligand_atom_feature]
    #     types = [(e, a) for e, h, a in zip(element_list, hybridization_list, aromatic_list)]
    #     for t in types:
    #         t = (t[0].item(), bool(t[1].item()))
    #         if t not in atom_types:
    #             atom_types[t] = 0
    #         atom_types[t] += 1
    # idx = 0
    # for k in sorted(atom_types.keys()):
    #     print(f'{k}: {idx},  # {atom_types[k]}')
    #     idx += 1

    ############################################################
    # count atom types
    # type_counter, aromatic_counter, full_counter = {}, {}, {}
    # for i, data in enumerate(tqdm(dataset)):
    #     element_list = data.ligand_element
    #     hybridization_list = data.ligand_hybridization
    #     aromatic_list = [v[trans.AROMATIC_FEAT_MAP_IDX] for v in data.ligand_atom_feature]
    #     flag = False
    #     for atom_type in element_list:
    #         atom_type = int(atom_type.item())
    #         if atom_type not in type_counter:
    #             type_counter[atom_type] = 0
    #         type_counter[atom_type] += 1
    #     for e, a in zip(element_list, aromatic_list):
    #         e = int(e.item())
    #         a = bool(a.item())
    #         key = (e, a)
    #         if key not in aromatic_counter:
    #             aromatic_counter[key] = 0
    #         aromatic_counter[key] += 1
    #         if key not in trans.MAP_ATOM_TYPE_AROMATIC_TO_INDEX:
    #             flag = True
    #     for e, h, a in zip(element_list, hybridization_list, aromatic_list):
    #         e = int(e.item())
    #         a = bool(a.item())
    #         key = (e, h, a)
    #         if key not in full_counter:
    #             full_counter[key] = 0
    #         full_counter[key] += 1
    # print('type_counter', type_counter)
    # print('aromatic_counter', aromatic_counter)
    # print('full_counter', full_counter)
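
    ############################################################
    # split usage sketch: `_transform` above reads a pose split file holding
    # 'train'/'test' id lists; the same file can drive plain torch Subsets
    # over the raw dataset (the path below mirrors the one used above and is
    # an assumption):
    # from torch.utils.data import Subset
    # split = torch.load('./data/crossdocked_pocket10_pose_split.pt')
    # train_set = Subset(dataset, indices=split['train'])
    # test_set = Subset(dataset, indices=split['test'])
    # print(f'train: {len(train_set)}, test: {len(test_set)}')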