|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch_geometric.nn import knn_graph |
|
|
from torch_geometric.utils.num_nodes import maybe_num_nodes |
|
|
from torch_scatter import scatter_add |
|
|
|
|
|
from core.datasets.pl_data import ProteinLigandData |
|
|
from core.datasets.protein_ligand import ATOM_FEATS |
|
|
|
|
|
|
|
|
class FeaturizeProteinAtom(object):
    """Transform that attaches a per-atom feature matrix for the protein.

    Each row of the matrix is the concatenation of three blocks:

    * a one-hot match of ``data.protein_element`` against
      ``self.atomic_numbers`` (all zeros when the element is not listed),
    * a one-hot encoding of ``data.protein_atom_to_aa_type`` with
      ``self.max_num_aa`` classes,
    * a single column holding ``data.protein_is_backbone`` cast to long.
    """

    def __init__(self):
        super().__init__()
        # Element vocabulary expressed as atomic numbers (H, C, N, O, S, Se).
        self.atomic_numbers = torch.LongTensor([1, 6, 7, 8, 16, 34])
        # Number of classes used for the amino-acid one-hot block.
        self.max_num_aa = 20

    @property
    def feature_dim(self):
        """Width of the feature matrix written by ``__call__``."""
        n_elements = self.atomic_numbers.size(0)
        # The trailing +1 accounts for the backbone flag column.
        return n_elements + self.max_num_aa + 1

    def __call__(self, data: ProteinLigandData):
        """Write ``data.protein_atom_feature`` and return the same ``data``."""
        element_column = data.protein_element.view(-1, 1)
        vocabulary_row = self.atomic_numbers.view(1, -1)

        blocks = [
            # Broadcasting (N, 1) against (1, E) yields an (N, E) boolean mask.
            element_column == vocabulary_row,
            F.one_hot(
                data.protein_atom_to_aa_type,
                num_classes=self.max_num_aa,
            ),
            data.protein_is_backbone.view(-1, 1).long(),
        ]

        data.protein_atom_feature = torch.cat(blocks, dim=-1)
        return data
|
|
|
|
|
|
|
|
class FeaturizeLigandAtom(object):
    """Transform that attaches an expanded per-atom feature matrix for the ligand.

    The output is the one-hot match of ``data.ligand_element`` against
    ``self.atomic_numbers`` followed by one encoded block per entry of
    ``ATOM_FEATS``, taken column by column from ``data.ligand_atom_feature``.

    NOTE(review): ``ATOM_FEATS`` is used here as a mapping from property name
    to the width of its encoding -- confirm against its definition.
    """

    def __init__(self):
        super().__init__()
        # Element vocabulary as atomic numbers (H, C, N, O, F, P, S, Cl).
        self.atomic_numbers = torch.LongTensor([1, 6, 7, 8, 9, 15, 16, 17])

    @property
    def num_properties(self):
        """Sum of all values in ``ATOM_FEATS``."""
        return sum(ATOM_FEATS.values())

    @property
    def feature_dim(self):
        """Element vocabulary size plus ``num_properties``."""
        return self.atomic_numbers.size(0) + self.num_properties

    @staticmethod
    def _encode_property(name, width, column):
        """Encode one ``(N, 1)`` property column.

        Properties wider than one are compared against ``0 .. width - 1``;
        the ``'AtomicNumber'`` column is divided by 100; anything else is
        passed through untouched.
        """
        if width > 1:
            categories = torch.arange(width, dtype=torch.long).view(1, -1)
            return column == categories
        if name == 'AtomicNumber':
            return column / 100.
        return column

    def __call__(self, data: ProteinLigandData):
        """Write ``data.ligand_atom_feature_full`` and return the same ``data``."""
        element = torch.eq(
            data.ligand_element.view(-1, 1),
            self.atomic_numbers.view(1, -1),
        )

        # Column ``idx`` of ligand_atom_feature belongs to the idx-th
        # ATOM_FEATS entry; slicing with idx:idx + 1 keeps the column axis.
        encoded = [
            self._encode_property(
                name, width, data.ligand_atom_feature[:, idx:idx + 1]
            )
            for idx, (name, width) in enumerate(ATOM_FEATS.items())
        ]

        properties = torch.cat(encoded, dim=-1)
        data.ligand_atom_feature_full = torch.cat([element, properties], dim=-1)
        return data
|
|
|
|
|
|
|
|
class FeaturizeLigandBond(object):
    """Transform that one-hot encodes ligand bond types into four classes."""

    def __init__(self):
        super().__init__()

    def __call__(self, data: ProteinLigandData):
        """Write ``data.ligand_bond_feature`` and return the same ``data``."""
        # Bond types are shifted down by one before encoding, so the stored
        # values are presumably 1-based (1..4) -- verify against the dataset.
        zero_based_type = data.ligand_bond_type - 1
        data.ligand_bond_feature = F.one_hot(zero_based_type, num_classes=4)
        return data
|
|
|
|
|
|
|
|
class LigandCountNeighbors(object):
    """Transform that stores per-atom neighbor counts and summed bond types."""

    @staticmethod
    def count_neighbors(edge_index, symmetry, valence=None, num_nodes=None):
        """Sum a per-edge weight onto the node in the first row of ``edge_index``.

        Args:
            edge_index: edge tensor; row 0 is used as the scatter index and
                ``size(1)`` as the number of edges.
            symmetry: must compare equal to ``True``.
            valence: optional per-edge weights; every edge weighs 1 when omitted.
            num_nodes: output length; derived with ``maybe_num_nodes`` when omitted.

        Returns:
            The scattered sums cast to long.

        Raises:
            AssertionError: if ``symmetry`` does not compare equal to ``True``.
        """
        # Equality (not identity) is kept deliberately to match existing callers.
        assert symmetry == True, 'Only support symmetrical edges.'

        if num_nodes is None:
            num_nodes = maybe_num_nodes(edge_index)

        num_edges = edge_index.size(1)
        if valence is None:
            weights = torch.ones([num_edges], device=edge_index.device)
        else:
            weights = valence
        weights = weights.view(num_edges)

        totals = scatter_add(
            weights,
            index=edge_index[0],
            dim=0,
            dim_size=num_nodes,
        )
        return totals.long()

    def __init__(self):
        super().__init__()

    def __call__(self, data):
        """Write ``ligand_num_neighbors`` and ``ligand_atom_valence`` on ``data``."""
        # Unweighted pass: each edge contributes 1 to its source atom.
        data.ligand_num_neighbors = self.count_neighbors(
            data.ligand_bond_index,
            symmetry=True,
            num_nodes=data.ligand_element.size(0),
        )
        # Weighted pass: each edge contributes its bond type value.
        data.ligand_atom_valence = self.count_neighbors(
            data.ligand_bond_index,
            symmetry=True,
            valence=data.ligand_bond_type,
            num_nodes=data.ligand_element.size(0),
        )
        return data
|
|
|
|
|
|
|
|
class EdgeConnection(object):
    """Transform that adds an ``edge_index`` over protein and ligand positions.

    Only ``kind == 'knn'`` produces edges; for any other ``kind`` the data
    object is returned without an ``edge_index`` being written.
    """

    def __init__(self, kind, k):
        super(EdgeConnection, self).__init__()
        self.kind, self.k = kind, k

    def __call__(self, data):
        """Return ``data``, with ``edge_index`` set when ``kind`` is ``'knn'``."""
        # Protein atoms come first, ligand atoms after, in one position tensor.
        all_positions = torch.cat([data.protein_pos, data.ligand_pos], dim=0)

        if self.kind != 'knn':
            return data

        data.edge_index = knn_graph(
            all_positions,
            k=self.k,
            flow='target_to_source',
        )
        return data
|
|
|
|
|
|
|
|
def convert_to_single_emb(x, offset=128):
    """Shift each feature column of ``x`` by a column-specific constant.

    Column ``j`` is increased by ``1 + j * offset``. A tensor with fewer than
    two dimensions is treated as a single column and shifted by 1.
    Presumably this lets several categorical columns share one embedding
    table without colliding -- confirm with the consuming model.

    Args:
        x: integer tensor; ``x.size(1)`` is taken as the number of feature
            columns when ``x`` has more than one dimension.
        offset: stride between the ranges assigned to consecutive columns.

    Returns:
        A new tensor ``x + feature_offset``; ``x`` itself is not modified.
    """
    feature_num = x.size(1) if x.dim() > 1 else 1
    # Build the offsets on x's device so the addition also works for tensors
    # that do not live on the CPU.
    feature_offset = 1 + torch.arange(
        0,
        feature_num * offset,
        offset,
        dtype=torch.long,
        device=x.device,
    )
    return x + feature_offset
|
|
|