"""Data transforms that featurize protein-ligand complexes for graph models."""
import torch
import torch.nn.functional as F
from torch_geometric.nn import knn_graph
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter_add

from core.datasets.pl_data import ProteinLigandData
from core.datasets.protein_ligand import ATOM_FEATS


class FeaturizeProteinAtom(object):
    """Build per-atom protein features and store them on the data object.

    The feature vector concatenates, in order: a one-hot over the supported
    atomic numbers, a one-hot over ``max_num_aa`` amino-acid types, and a
    single backbone flag.
    """

    def __init__(self):
        super().__init__()
        self.atomic_numbers = torch.LongTensor([1, 6, 7, 8, 16, 34])  # H, C, N, O, S, Se
        self.max_num_aa = 20

    @property
    def feature_dim(self):
        """Width of the feature vector written to ``protein_atom_feature``."""
        return self.atomic_numbers.size(0) + self.max_num_aa + 1

    def __call__(self, data: ProteinLigandData):
        """Set ``data.protein_atom_feature`` and return ``data``.

        Reads ``data.protein_element``, ``data.protein_atom_to_aa_type`` and
        ``data.protein_is_backbone``. An element that is not listed in
        ``self.atomic_numbers`` yields an all-False element row.
        """
        # Keep the lookup table on the same device as the input so the
        # broadcast comparison also works for non-CPU tensors.
        known_elements = self.atomic_numbers.to(data.protein_element.device)
        element = data.protein_element.view(-1, 1) == known_elements.view(1, -1)  # (N_atoms, N_elements)
        amino_acid = F.one_hot(data.protein_atom_to_aa_type, num_classes=self.max_num_aa)
        is_backbone = data.protein_is_backbone.view(-1, 1).long()
        data.protein_atom_feature = torch.cat([element, amino_acid, is_backbone], dim=-1)
        return data


class FeaturizeLigandAtom(object):
    """Build per-atom ligand features and store them on the data object.

    The feature vector concatenates a one-hot over the supported atomic
    numbers with the properties described by ``ATOM_FEATS``. ``ATOM_FEATS``
    maps a property name to its width; the i-th entry describes column ``i``
    of ``data.ligand_atom_feature``.
    """

    def __init__(self):
        super().__init__()
        self.atomic_numbers = torch.LongTensor([1, 6, 7, 8, 9, 15, 16, 17])  # H, C, N, O, F, P, S, Cl

    @property
    def num_properties(self):
        """Total width of the properties listed in ``ATOM_FEATS``."""
        return sum(ATOM_FEATS.values())

    @property
    def feature_dim(self):
        """Width of the feature vector written to ``ligand_atom_feature_full``."""
        return self.atomic_numbers.size(0) + self.num_properties

    def __call__(self, data: ProteinLigandData):
        """Set ``data.ligand_atom_feature_full`` and return ``data``.

        Reads ``data.ligand_element`` and ``data.ligand_atom_feature``.
        """
        known_elements = self.atomic_numbers.to(data.ligand_element.device)
        element = data.ligand_element.view(-1, 1) == known_elements.view(1, -1)  # (N_atoms, N_elements)

        atom_feature = []
        for i, (k, v) in enumerate(ATOM_FEATS.items()):
            feat = data.ligand_atom_feature[:, i:i + 1]
            if v > 1:
                # Properties wider than one column are expanded into a
                # one-hot over the class indices 0 .. v-1.
                classes = torch.arange(v, device=feat.device).view(1, -1)
                feat = feat == classes
            elif k == 'AtomicNumber':
                # The raw atomic number is kept as one scaled scalar column.
                feat = feat / 100.
            atom_feature.append(feat)

        atom_feature = torch.cat(atom_feature, dim=-1)
        data.ligand_atom_feature_full = torch.cat([element, atom_feature], dim=-1)
        return data


class FeaturizeLigandBond(object):
    """One-hot encode ligand bond types."""

    def __init__(self):
        super().__init__()

    def __call__(self, data: ProteinLigandData):
        """Set ``data.ligand_bond_feature`` and return ``data``."""
        # Bond types are shifted down by one, so values 1..4 map to one-hot
        # indices 0..3; F.one_hot rejects anything outside that range.
        data.ligand_bond_feature = F.one_hot(data.ligand_bond_type - 1, num_classes=4)
        return data


class LigandCountNeighbors(object):
    """Attach per-atom neighbor counts and summed bond orders to the ligand."""

    @staticmethod
    def count_neighbors(edge_index, symmetry, valence=None, num_nodes=None):
        """Sum ``valence`` over the edges leaving each node.

        Args:
            edge_index: edge tensor; row 0 is used as the scatter index and
                ``edge_index.size(1)`` as the number of edges.
            symmetry: must equal ``True``; only symmetrical edges are
                supported.
            valence: optional per-edge weight. Defaults to 1 for every edge,
                which turns the sum into a plain edge count.
            num_nodes: number of nodes in the output. Inferred from
                ``edge_index`` when omitted.

        Returns:
            A long tensor of length ``num_nodes``.

        Raises:
            AssertionError: if ``symmetry`` does not equal ``True``.
        """
        assert symmetry == True, 'Only support symmetrical edges.'  # noqa: E712

        num_edges = edge_index.size(1)
        if num_nodes is None:
            num_nodes = maybe_num_nodes(edge_index)
        if valence is None:
            valence = torch.ones([num_edges], device=edge_index.device)
        valence = valence.view(num_edges)

        return scatter_add(valence, index=edge_index[0], dim=0, dim_size=num_nodes).long()

    def __init__(self):
        super().__init__()

    def __call__(self, data):
        """Set ``ligand_num_neighbors`` and ``ligand_atom_valence`` on ``data``."""
        num_atoms = data.ligand_element.size(0)
        data.ligand_num_neighbors = self.count_neighbors(
            data.ligand_bond_index,
            symmetry=True,
            num_nodes=num_atoms,
        )
        # Weighting each edge by its bond type gives the summed bond order.
        data.ligand_atom_valence = self.count_neighbors(
            data.ligand_bond_index,
            symmetry=True,
            valence=data.ligand_bond_type,
            num_nodes=num_atoms,
        )
        return data


class EdgeConnection(object):
    """Connect protein and ligand atoms into a single graph.

    Args:
        kind: connection strategy; only ``'knn'`` has an effect.
        k: number of neighbors passed to ``knn_graph``.
    """

    def __init__(self, kind, k):
        super(EdgeConnection, self).__init__()
        self.kind = kind
        self.k = k

    def __call__(self, data):
        """Set ``data.edge_index`` when ``kind == 'knn'`` and return ``data``.

        Protein positions come first in the concatenated coordinates,
        followed by ligand positions. For any other ``kind`` the data object
        is returned without ``edge_index`` being set.
        """
        pos = torch.cat([data.protein_pos, data.ligand_pos], dim=0)
        if self.kind == 'knn':
            data.edge_index = knn_graph(pos, k=self.k, flow='target_to_source')
        return data


def convert_to_single_emb(x, offset=128):
    """Shift every feature column of ``x`` into its own index range.

    Column ``j`` is increased by ``1 + j * offset``. A 1-D ``x`` is treated
    as a single feature, so every entry is increased by 1.

    Args:
        x: integer tensor, either 1-D or with the features along dimension 1.
        offset: size of the index range reserved for each feature column.

    Returns:
        A new tensor with the per-column offsets added.
    """
    feature_num = x.size(1) if x.dim() > 1 else 1
    # Build the offsets on x's device; a CPU-only offset tensor makes the
    # addition fail for inputs that live on another device.
    feature_offset = 1 + torch.arange(
        0, feature_num * offset, offset, dtype=torch.long, device=x.device
    )
    return x + feature_offset