MolCRAFT / core / utils / transforms_prop.py
Atomu2014's picture
demo init commit
1f0c7b9
import torch
import torch.nn.functional as F
from torch_geometric.nn import knn_graph
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter_add
from core.datasets.pl_data import ProteinLigandData
from core.datasets.protein_ligand import ATOM_FEATS
class FeaturizeProteinAtom(object):
    """Transform that builds a per-atom feature matrix for the protein.

    Each feature row is the concatenation of:
      * an element indicator over ``atomic_numbers`` (6 columns),
      * a one-hot amino-acid type (``max_num_aa`` = 20 columns),
      * a single backbone flag column.
    """

    def __init__(self):
        super().__init__()
        # Atomic numbers recognised for protein atoms: H, C, N, O, S, Se.
        self.atomic_numbers = torch.tensor([1, 6, 7, 8, 16, 34], dtype=torch.long)
        self.max_num_aa = 20

    @property
    def feature_dim(self):
        """Width of one feature row: elements + amino-acid classes + backbone flag."""
        return len(self.atomic_numbers) + self.max_num_aa + 1

    def __call__(self, data: ProteinLigandData):
        """Attach ``protein_atom_feature`` to ``data`` and return the same object.

        Reads ``data.protein_element``, ``data.protein_atom_to_aa_type`` and
        ``data.protein_is_backbone``.
        """
        # Broadcast (N_atoms, 1) against (1, N_elements) -> boolean indicator matrix.
        # An element that is not listed in ``atomic_numbers`` yields an all-False row.
        element_onehot = data.protein_element.view(-1, 1) == self.atomic_numbers.view(1, -1)
        residue_onehot = F.one_hot(data.protein_atom_to_aa_type, num_classes=self.max_num_aa)
        backbone_flag = data.protein_is_backbone.view(-1, 1).to(torch.long)

        parts = [element_onehot, residue_onehot, backbone_flag]
        data.protein_atom_feature = torch.cat(parts, dim=-1)
        return data
class FeaturizeLigandAtom(object):
    """Transform that builds the full per-atom feature matrix for the ligand.

    The result is an element indicator over ``atomic_numbers`` followed by the
    expanded columns of ``data.ligand_atom_feature``, one group of columns per
    entry of ``ATOM_FEATS``.
    """

    def __init__(self):
        super().__init__()
        # Atomic numbers recognised for ligand atoms: H, C, N, O, F, P, S, Cl.
        self.atomic_numbers = torch.tensor([1, 6, 7, 8, 9, 15, 16, 17], dtype=torch.long)

    @property
    def num_properties(self):
        """Total number of columns produced from ``ATOM_FEATS`` (sum of its values)."""
        return sum(ATOM_FEATS.values())

    @property
    def feature_dim(self):
        """Width of one feature row: element indicator plus property columns."""
        return len(self.atomic_numbers) + self.num_properties

    def __call__(self, data: ProteinLigandData):
        """Attach ``ligand_atom_feature_full`` to ``data`` and return the same object.

        Reads ``data.ligand_element`` and ``data.ligand_atom_feature``; column
        ``i`` of the latter is interpreted according to the ``i``-th entry of
        ``ATOM_FEATS``.
        """
        # Broadcast (N_atoms, 1) against (1, N_elements) -> boolean indicator matrix.
        element_onehot = data.ligand_element.view(-1, 1) == self.atomic_numbers.view(1, -1)

        columns = []
        for col, (name, width) in enumerate(ATOM_FEATS.items()):
            column = data.ligand_atom_feature[:, col:col + 1]
            if width > 1:
                # Multi-class feature: compare against 0..width-1 to one-hot encode it.
                column = column == torch.arange(width).view(1, -1)
            elif name == 'AtomicNumber':
                # Single-column atomic number is rescaled instead of being expanded.
                column = column / 100.
            columns.append(column)

        properties = torch.cat(columns, dim=-1)
        data.ligand_atom_feature_full = torch.cat([element_onehot, properties], dim=-1)
        return data
class FeaturizeLigandBond(object):
    """Transform that one-hot encodes ligand bond types."""

    def __init__(self):
        super().__init__()

    def __call__(self, data: ProteinLigandData):
        """Attach ``ligand_bond_feature`` to ``data`` and return the same object.

        ``data.ligand_bond_type`` is shifted down by one before encoding, so a
        bond type ``t`` sets column ``t - 1`` of the 4-column one-hot output.
        """
        zero_based_type = data.ligand_bond_type - 1
        data.ligand_bond_feature = F.one_hot(zero_based_type, num_classes=4)
        return data
class LigandCountNeighbors(object):
    """Transform that stores per-atom neighbor counts and summed bond types."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def count_neighbors(edge_index, symmetry, valence=None, num_nodes=None):
        """Sum a per-edge weight onto the source node of every edge.

        Args:
            edge_index: edge tensor; row 0 is used as the scatter index and
                ``size(1)`` as the number of edges.
            symmetry: must compare equal to ``True``; anything else fails the
                assertion below.
            valence: optional per-edge weights. When omitted every edge
                contributes 1, so the result is a plain edge count per node.
            num_nodes: length of the output; inferred from ``edge_index`` via
                ``maybe_num_nodes`` when omitted.

        Returns:
            A long tensor with one accumulated value per node.
        """
        assert symmetry == True, 'Only support symmetrical edges.'

        if num_nodes is None:
            num_nodes = maybe_num_nodes(edge_index)

        num_edges = edge_index.size(1)
        if valence is None:
            weights = torch.ones([num_edges], device=edge_index.device)
        else:
            weights = valence
        weights = weights.view(num_edges)

        totals = scatter_add(weights, index=edge_index[0], dim=0, dim_size=num_nodes)
        return totals.long()

    def __call__(self, data):
        """Attach ``ligand_num_neighbors`` and ``ligand_atom_valence`` to ``data``.

        Both are computed over ``data.ligand_bond_index`` with one entry per
        element of ``data.ligand_element``; the valence additionally weights
        each edge by ``data.ligand_bond_type``.
        """
        bond_index = data.ligand_bond_index
        n_atoms = data.ligand_element.size(0)

        # Unweighted: number of bond edges leaving each atom.
        data.ligand_num_neighbors = self.count_neighbors(
            bond_index, symmetry=True, num_nodes=n_atoms
        )
        # Weighted by bond type: sum of bond types over each atom's edges.
        data.ligand_atom_valence = self.count_neighbors(
            bond_index, symmetry=True, valence=data.ligand_bond_type, num_nodes=n_atoms
        )
        return data
class EdgeConnection(object):
    """Transform that builds a joint protein + ligand edge index.

    Args:
        kind: connection scheme; only ``'knn'`` triggers any edge construction.
        k: number of neighbors passed to ``knn_graph``.
    """

    def __init__(self, kind, k):
        super().__init__()
        self.kind = kind
        self.k = k

    def __call__(self, data):
        """Attach ``edge_index`` to ``data`` (for ``kind == 'knn'``) and return it.

        Positions are stacked protein-first, then ligand. For any other
        ``kind`` the data object is returned without an ``edge_index`` being set.
        """
        all_pos = torch.cat([data.protein_pos, data.ligand_pos], dim=0)
        if self.kind != 'knn':
            return data
        data.edge_index = knn_graph(all_pos, k=self.k, flow='target_to_source')
        return data
def convert_to_single_emb(x, offset=128):
    """Shift every feature column of ``x`` into its own disjoint index range.

    Column ``j`` is increased by ``1 + j * offset``. A 1-D input is treated as
    a single feature column and is simply increased by 1. Presumably this lets
    all columns index one shared embedding table (inferred from the function
    name; confirm against the caller).

    Args:
        x: integer tensor of shape ``(N,)`` or ``(N, num_features)``.
        offset: size of the index range reserved for each feature column.

    Returns:
        A new tensor with the same shape as ``x``; ``x`` itself is not modified.
    """
    feature_num = x.size(1) if x.dim() > 1 else 1
    # Build the offsets on x's device: a CPU-only offset tensor would raise a
    # device-mismatch error as soon as x lives on another device.
    feature_offset = 1 + torch.arange(
        0, feature_num * offset, offset, dtype=torch.long, device=x.device
    )
    return x + feature_offset