# used in the pdbbind dataset
import sys
from io import StringIO

import numpy as np
import torch
from rdkit import Chem
from rdkit.Chem.rdchem import BondType, HybridizationType
from torch_scatter import scatter

ATOM_FAMILIES = [
    'Acceptor', 'Donor', 'Aromatic', 'Hydrophobe', 'LumpedHydrophobe',
    'NegIonizable', 'PosIonizable', 'ZnBinder',
]
ATOM_FAMILIES_ID = {s: i for i, s in enumerate(ATOM_FAMILIES)}
ATOM_FEATS = {
    'AtomicNumber': 1,
    'Aromatic': 1,
    'Degree': 6,
    'NumHs': 6,
    'Hybridization': len(HybridizationType.values),
}
BOND_TYPES = {t: i for i, t in enumerate(BondType.names.values())}
BOND_NAMES = {i: t for i, t in enumerate(BondType.names.keys())}
# Hybridization enum member -> integer index (enumeration order of
# HybridizationType.names). Built once instead of once per atom.
HYBRID_TYPES = {t: i for i, t in enumerate(HybridizationType.names.values())}
KMAP = {'Ki': 1, 'Kd': 2, 'IC50': 3}


def _center_of_mass(atomic_numbers, positions):
    """Return the mass-weighted mean of ``positions`` as a float32 array.

    Args:
        atomic_numbers: iterable of integer atomic numbers, one per atom.
        positions: iterable of ``(x, y, z)`` coordinates, aligned with
            ``atomic_numbers``.

    Returns:
        ``np.ndarray`` of shape ``(3,)`` and dtype ``float32``. With zero
        atoms the total mass is 0 and the result is NaN.
    """
    ptable = Chem.GetPeriodicTable()
    # The accumulator is kept in float32 on purpose so the result matches the
    # historical per-atom accumulation exactly.
    accum_pos = np.array([0.0, 0.0, 0.0], dtype=np.float32)
    accum_mass = 0.0
    for atomic_number, xyz in zip(atomic_numbers, positions):
        atomic_weight = ptable.GetAtomicWeight(atomic_number)
        accum_pos += np.array(xyz, dtype=np.float64) * atomic_weight
        accum_mass += atomic_weight
    return np.array(accum_pos / accum_mass, dtype=np.float32)


def _sorted_bond_arrays(row, col, edge_type, num_atoms):
    """Pack directed edges into int64 arrays sorted by (source, target).

    Args:
        row: list of source atom indices (one entry per directed edge).
        col: list of target atom indices, aligned with ``row``.
        edge_type: list of integer bond-type codes, aligned with ``row``.
        num_atoms: number of atoms; used to build the sort key
            ``source * num_atoms + target``.

    Returns:
        ``(edge_index, edge_type)`` where ``edge_index`` has shape
        ``(2, num_edges)`` and ``edge_type`` has shape ``(num_edges,)``.
    """
    # np.long / np.int were removed from NumPy; int64 is spelled explicitly.
    edge_index = np.array([row, col], dtype=np.int64)
    edge_type = np.array(edge_type, dtype=np.int64)
    perm = (edge_index[0] * num_atoms + edge_index[1]).argsort()
    return edge_index[:, perm], edge_type[perm]


def get_ligand_atom_features(rdmol):
    """Build the per-atom integer feature matrix of an RDKit molecule.

    Columns, in order: atomic number, aromatic flag (0/1), degree, number of
    bonded atoms whose atomic number is 1, hybridization index (see
    ``HYBRID_TYPES``). Hybridization is stored as one categorical index
    rather than separate sp/sp2/sp3 flags.

    Args:
        rdmol: RDKit molecule.

    Returns:
        ``np.ndarray`` of shape ``(num_atoms, 5)`` and dtype ``int64``.
    """
    num_atoms = rdmol.GetNumAtoms()
    atomic_number = []
    aromatic = []
    hybrid = []
    degree = []
    for atom_idx in range(num_atoms):
        atom = rdmol.GetAtomWithIdx(atom_idx)
        atomic_number.append(atom.GetAtomicNum())
        aromatic.append(1 if atom.GetIsAromatic() else 0)
        hybrid.append(HYBRID_TYPES[atom.GetHybridization()])
        degree.append(atom.GetDegree())
    node_type = torch.tensor(atomic_number, dtype=torch.long)

    # Every bond contributes two directed edges so that the scatter below
    # counts neighbours symmetrically.
    row, col = [], []
    for bond in rdmol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
    row = torch.tensor(row, dtype=torch.long)
    col = torch.tensor(col, dtype=torch.long)

    # For each atom, sum the "is hydrogen" indicator over its neighbours.
    hs = (node_type == 1).to(torch.float)
    num_hs = scatter(hs[row], col, dim_size=num_atoms).numpy()

    # need to change ATOM_FEATS accordingly
    feat_mat = np.array(
        [atomic_number, aromatic, degree, num_hs, hybrid], dtype=np.int64
    ).transpose()
    return feat_mat


# used for fixing some errors in sdf file
def parse_sdf_file_text(path):
    """Parse the first molecule of an SDF/MOL file directly from its text.

    The counts line (4th line) is read with fixed-width columns, atom lines
    are split on whitespace, and bond lines are read with fixed-width
    columns; RDKit is used only for the periodic table.

    Args:
        path: path of the SDF/MOL file.

    Returns:
        dict with keys ``'element'`` (int64 atomic numbers), ``'pos'``
        (float32 coordinates, one row per atom), ``'bond_index'`` (int64,
        shape ``(2, 2 * num_bonds)``, both directions of every bond),
        ``'bond_type'`` (int64 ``BOND_TYPES`` codes aligned with
        ``'bond_index'``) and ``'center_of_mass'`` (float32, shape ``(3,)``).

    Raises:
        KeyError: if a bond line carries a bond-type code other than
            1, 2, 3, 4 or 8.
    """
    with open(path, 'r') as f:
        sdf = f.read().splitlines()

    num_atoms, num_bonds = int(sdf[3][0:3]), int(sdf[3][3:6])
    ptable = Chem.GetPeriodicTable()

    element, pos = [], []
    for atom_line in (line.split() for line in sdf[4:4 + num_atoms]):
        x, y, z = map(float, atom_line[:3])
        symb = atom_line[3]
        # capitalize() normalises symbols such as "CL" to "Cl".
        element.append(ptable.GetAtomicNumber(symb.capitalize()))
        pos.append([x, y, z])

    center_of_mass = _center_of_mass(element, pos)

    element = np.array(element, dtype=np.int64)
    pos = np.array(pos, dtype=np.float32)

    # Bond-type code found in the bond line -> index in BOND_TYPES.
    bond_type_map = {
        1: BOND_TYPES[BondType.SINGLE],
        2: BOND_TYPES[BondType.DOUBLE],
        3: BOND_TYPES[BondType.TRIPLE],
        4: BOND_TYPES[BondType.AROMATIC],
        8: BOND_TYPES[BondType.UNSPECIFIED],
    }

    row, col, edge_type = [], [], []
    bond_start = 4 + num_atoms
    for bond_line in sdf[bond_start:bond_start + num_bonds]:
        # Atom indices in the file are 1-based.
        start, end = int(bond_line[0:3]) - 1, int(bond_line[3:6]) - 1
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [bond_type_map[int(bond_line[6:9])]]

    edge_index, edge_type = _sorted_bond_arrays(row, col, edge_type, num_atoms)

    data = {
        'element': element,
        'pos': pos,
        'bond_index': edge_index,
        'bond_type': edge_type,
        'center_of_mass': center_of_mass,
    }
    return data


# used for preparing the dataset
def read_mol(sdf_fileName, mol2_fileName, verbose=False):
    """Load a ligand from its SDF file, falling back to its MOL2 file.

    Each candidate file is read without sanitization and then put through
    ``SanitizeMol``, ``RemoveHs`` and ``MolToSmiles``; the first candidate
    that survives all three steps is returned.

    Args:
        sdf_fileName: path of the SDF/MOL file, tried first.
        mol2_fileName: path of the MOL2 file, tried if the SDF one fails.
        verbose: if true, print whatever was written to ``sys.stderr`` while
            loading.

    Returns:
        ``(mol, problem, ligand_path)``. ``problem`` is ``False`` and
        ``ligand_path`` is the file that worked when loading succeeded;
        otherwise ``problem`` is ``True``, ``ligand_path`` is ``None`` and
        ``mol`` is whatever the last attempt left behind (possibly ``None``).
    """
    # Presumably routes RDKit's log output through Python's sys.stderr so the
    # temporary buffer below can capture it.
    Chem.WrapLogs()
    stderr = sys.stderr
    sio = sys.stderr = StringIO()
    try:
        mol, problem, ligand_path = None, True, None
        candidates = (
            (Chem.MolFromMolFile, sdf_fileName),
            (Chem.MolFromMol2File, mol2_fileName),
        )
        for loader, candidate in candidates:
            mol = loader(candidate, sanitize=False)
            try:
                Chem.SanitizeMol(mol)
                mol = Chem.RemoveHs(mol)
                # Result is discarded; the call only checks that the molecule
                # can be serialised.
                Chem.MolToSmiles(mol)
            except Exception:
                # Deliberate best-effort: any failure means "try the next
                # file" and is reported through the `problem` flag.
                continue
            problem = False
            ligand_path = candidate
            break

        if verbose:
            print(sio.getvalue())
    finally:
        # Always hand the real stderr back, even if a loader raised.
        sys.stderr = stderr
    return mol, problem, ligand_path


def parse_sdf_file_mol(path, heavy_only=True, mol=None):
    """Extract elements, coordinates, bonds and atom features via RDKit.

    Args:
        path: path ending in ``.sdf`` or ``.mol2``; only read when ``mol``
            is ``None``.
        heavy_only: when the molecule is loaded from ``path``, remove
            hydrogens after sanitization.
        mol: optional, already-loaded RDKit molecule. When given, ``path``
            and ``heavy_only`` are not used.

    Returns:
        dict with keys ``'element'`` (int64 atomic numbers), ``'pos'``
        (float32 coordinates of the molecule's conformer), ``'bond_index'``
        (int64, shape ``(2, 2 * num_bonds)``), ``'bond_type'`` (int64
        ``BOND_TYPES`` codes), ``'center_of_mass'`` (float32, shape ``(3,)``)
        and ``'atom_feature'`` (output of ``get_ligand_atom_features``).

    Raises:
        ValueError: if ``path`` has an unsupported extension or RDKit cannot
            parse the file.
    """
    if mol is None:
        if path.endswith('.sdf'):
            mol = Chem.MolFromMolFile(path, sanitize=False)
        elif path.endswith('.mol2'):
            mol = Chem.MolFromMol2File(path, sanitize=False)
        else:
            raise ValueError(
                f"Unsupported ligand file extension (expected .sdf or .mol2): {path!r}"
            )
        if mol is None:
            raise ValueError(f"RDKit could not parse ligand file: {path!r}")
        # NOTE(review): sanitization / hydrogen removal are applied only to
        # molecules loaded here; a caller-supplied `mol` is used as-is
        # (e.g. the one returned by read_mol) - confirm this is intended.
        Chem.SanitizeMol(mol)
        if heavy_only:
            mol = Chem.RemoveHs(mol)

    # Atom features come from get_ligand_atom_features; ATOM_FAMILIES /
    # ATOM_FAMILIES_ID belong to an alternative pharmacophore-family encoding
    # that is not used here.
    feat_mat = get_ligand_atom_features(mol)

    num_atoms = mol.GetNumAtoms()
    pos = mol.GetConformer().GetPositions()
    element = [mol.GetAtomWithIdx(atom_idx).GetAtomicNum() for atom_idx in range(num_atoms)]

    center_of_mass = _center_of_mass(element, pos)

    element = np.array(element, dtype=np.int64)
    pos = np.array(pos, dtype=np.float32)

    # Store both directions of every bond.
    row, col, edge_type = [], [], []
    for bond in mol.GetBonds():
        start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        row += [start, end]
        col += [end, start]
        edge_type += 2 * [BOND_TYPES[bond.GetBondType()]]

    edge_index, edge_type = _sorted_bond_arrays(row, col, edge_type, num_atoms)

    data = {
        'element': element,
        'pos': pos,
        'bond_index': edge_index,
        'bond_type': edge_type,
        'center_of_mass': center_of_mass,
        'atom_feature': feat_mat,
    }
    return data