From ccfcb8c6b4a0eff0b48ef8be09058ddcd8a70aaa Mon Sep 17 00:00:00 2001 From: YanjingLiLi Date: Thu, 19 Oct 2023 04:08:28 -0700 Subject: [PATCH 1/5] add protein task --- Geom3D/datasets/__init__.py | 89 ++-- Geom3D/datasets/datasetFOLD_GVP.py | 345 ------------ Geom3D/datasets/dataset_ECMultiple.py | 425 +++++++++++++++ Geom3D/datasets/dataset_ECMultiple_GearNet.py | 294 ++++++++++ Geom3D/datasets/dataset_ECSingle.py | 237 +++++++++ Geom3D/datasets/dataset_ECSingle_GearNet.py | 192 +++++++ Geom3D/datasets/dataset_FOLD.py | 107 ++-- Geom3D/datasets/dataset_FOLD_CDConv.py | 136 ----- Geom3D/datasets/dataset_FOLD_GearNet.py | 5 +- Geom3D/datasets/dataset_GO.py | 454 ++++++++++++++++ Geom3D/datasets/dataset_GO_GearNet.py | 250 +++++++++ Geom3D/datasets/dataset_GVP.py | 155 ++++++ Geom3D/models/CDConv.py | 35 +- Geom3D/models/GVP.py | 85 +++ Geom3D/models/GearNet.py | 128 ++++- Geom3D/models/GearNet_layer.py | 61 ++- Geom3D/models/ProNet/ProNet.py | 11 +- Geom3D/models/__init__.py | 61 +-- examples_3D/config.py | 25 + examples_3D/finetune_ECMultiple.py | 503 ++++++++++++++++++ examples_3D/finetune_ECSingle.py | 407 ++++++++++++++ examples_3D/finetune_FOLD.py | 436 +++++++++++++++ examples_3D/finetune_GO.py | 497 +++++++++++++++++ 23 files changed, 4286 insertions(+), 652 deletions(-) delete mode 100644 Geom3D/datasets/datasetFOLD_GVP.py create mode 100644 Geom3D/datasets/dataset_ECMultiple.py create mode 100644 Geom3D/datasets/dataset_ECMultiple_GearNet.py create mode 100644 Geom3D/datasets/dataset_ECSingle.py create mode 100644 Geom3D/datasets/dataset_ECSingle_GearNet.py delete mode 100644 Geom3D/datasets/dataset_FOLD_CDConv.py create mode 100644 Geom3D/datasets/dataset_GO.py create mode 100644 Geom3D/datasets/dataset_GO_GearNet.py create mode 100644 Geom3D/datasets/dataset_GVP.py create mode 100644 examples_3D/finetune_ECMultiple.py create mode 100644 examples_3D/finetune_ECSingle.py create mode 100644 examples_3D/finetune_FOLD.py create mode 100644 examples_3D/finetune_GO.py diff --git a/Geom3D/datasets/__init__.py b/Geom3D/datasets/__init__.py index 81debee..e8ad773 100644 --- a/Geom3D/datasets/__init__.py +++ b/Geom3D/datasets/__init__.py @@ -1,65 +1,70 @@ from Geom3D.datasets.dataset_utils import graph_data_obj_to_nx_simple, nx_to_graph_data_obj_simple, atom_type_count -from Geom3D.datasets.dataset_GEOM import MoleculeDatasetGEOM -from Geom3D.datasets.dataset_GEOM_Drugs import MoleculeDatasetGEOMDrugs, MoleculeDatasetGEOMDrugsTest -from Geom3D.datasets.dataset_GEOM_QM9 import MoleculeDatasetGEOMQM9, MoleculeDatasetGEOMQM9Test +# from Geom3D.datasets.dataset_GEOM import MoleculeDatasetGEOM +# from Geom3D.datasets.dataset_GEOM_Drugs import MoleculeDatasetGEOMDrugs, MoleculeDatasetGEOMDrugsTest +# from Geom3D.datasets.dataset_GEOM_QM9 import MoleculeDatasetGEOMQM9, MoleculeDatasetGEOMQM9Test -from Geom3D.datasets.dataset_Molecule3D import Molecule3D +# from Geom3D.datasets.dataset_Molecule3D import Molecule3D -from Geom3D.datasets.dataset_PCQM4Mv2 import PCQM4Mv2 -from Geom3D.datasets.dataset_PCQM4Mv2_3D_and_MMFF import PCQM4Mv2_3DandMMFF +# from Geom3D.datasets.dataset_PCQM4Mv2 import PCQM4Mv2 +# from Geom3D.datasets.dataset_PCQM4Mv2_3D_and_MMFF import PCQM4Mv2_3DandMMFF -from Geom3D.datasets.dataset_QM9 import MoleculeDatasetQM9 -from Geom3D.datasets.dataset_QM9_2D import MoleculeDatasetQM92D -from Geom3D.datasets.dataset_QM9_Fingerprints_SMILES import MoleculeDatasetQM9FingerprintsSMILES -from Geom3D.datasets.dataset_QM9_RDKit import MoleculeDatasetQM9RDKit -from 
Geom3D.datasets.dataset_QM9_3D_and_MMFF import MoleculeDatasetQM9_3DandMMFF -from Geom3D.datasets.dataset_QM9_2D_3D_Transformer import MoleculeDatasetQM9_2Dand3DTransformer +# from Geom3D.datasets.dataset_QM9 import MoleculeDatasetQM9 +# from Geom3D.datasets.dataset_QM9_2D import MoleculeDatasetQM92D +# from Geom3D.datasets.dataset_QM9_Fingerprints_SMILES import MoleculeDatasetQM9FingerprintsSMILES +# from Geom3D.datasets.dataset_QM9_RDKit import MoleculeDatasetQM9RDKit +# from Geom3D.datasets.dataset_QM9_3D_and_MMFF import MoleculeDatasetQM9_3DandMMFF +# from Geom3D.datasets.dataset_QM9_2D_3D_Transformer import MoleculeDatasetQM9_2Dand3DTransformer -from Geom3D.datasets.dataset_COLL import DatasetCOLL -from Geom3D.datasets.dataset_COLLRadius import DatasetCOLLRadius -from Geom3D.datasets.dataset_COLLGemNet import DatasetCOLLGemNet +# from Geom3D.datasets.dataset_COLL import DatasetCOLL +# from Geom3D.datasets.dataset_COLLRadius import DatasetCOLLRadius +# from Geom3D.datasets.dataset_COLLGemNet import DatasetCOLLGemNet -from Geom3D.datasets.dataset_MD17 import DatasetMD17 -from Geom3D.datasets.dataset_rMD17 import DatasetrMD17 +# from Geom3D.datasets.dataset_MD17 import DatasetMD17 +# from Geom3D.datasets.dataset_rMD17 import DatasetrMD17 -from Geom3D.datasets.dataset_LBA import DatasetLBA, TransformLBA -from Geom3D.datasets.dataset_LBARadius import DatasetLBARadius +# from Geom3D.datasets.dataset_LBA import DatasetLBA, TransformLBA +# from Geom3D.datasets.dataset_LBARadius import DatasetLBARadius -from Geom3D.datasets.dataset_LEP import DatasetLEP, TransformLEP -from Geom3D.datasets.dataset_LEPRadius import DatasetLEPRadius +# from Geom3D.datasets.dataset_LEP import DatasetLEP, TransformLEP +# from Geom3D.datasets.dataset_LEPRadius import DatasetLEPRadius -from Geom3D.datasets.dataset_OC20 import DatasetOC20, is2re_data_transform, s2ef_data_transform +# from Geom3D.datasets.dataset_OC20 import DatasetOC20, is2re_data_transform, s2ef_data_transform -from Geom3D.datasets.dataset_MoleculeNet_2D import MoleculeNetDataset2D -from Geom3D.datasets.dataset_MoleculeNet_3D import MoleculeNetDataset3D, MoleculeNetDataset2D_SDE3D +# from Geom3D.datasets.dataset_MoleculeNet_2D import MoleculeNetDataset2D +# from Geom3D.datasets.dataset_MoleculeNet_3D import MoleculeNetDataset3D, MoleculeNetDataset2D_SDE3D -from Geom3D.datasets.dataset_QMOF import DatasetQMOF -from Geom3D.datasets.dataset_MatBench import DatasetMatBench +# from Geom3D.datasets.dataset_QMOF import DatasetQMOF +# from Geom3D.datasets.dataset_MatBench import DatasetMatBench -from Geom3D.datasets.dataset_3D import Molecule3DDataset -from Geom3D.datasets.dataset_3D_Radius import MoleculeDataset3DRadius -from Geom3D.datasets.dataset_3D_Remove_Center import MoleculeDataset3DRemoveCenter +# from Geom3D.datasets.dataset_3D import Molecule3DDataset +# from Geom3D.datasets.dataset_3D_Radius import MoleculeDataset3DRadius +# from Geom3D.datasets.dataset_3D_Remove_Center import MoleculeDataset3DRemoveCenter -# For Distance Prediction -from Geom3D.datasets.dataset_3D_Full import MoleculeDataset3DFull +# # For Distance Prediction +# from Geom3D.datasets.dataset_3D_Full import MoleculeDataset3DFull -# For Torsion Prediction -from Geom3D.datasets.dataset_3D_TorsionAngle import MoleculeDataset3DTorsionAngle +# # For Torsion Prediction +# from Geom3D.datasets.dataset_3D_TorsionAngle import MoleculeDataset3DTorsionAngle -from Geom3D.datasets.dataset_OneAtom import MoleculeDatasetOneAtom +# from Geom3D.datasets.dataset_OneAtom import 
MoleculeDatasetOneAtom -# For 2D N-Gram-Path -from Geom3D.datasets.dataset_2D_Dense import MoleculeDataset2DDense +# # For 2D N-Gram-Path +# from Geom3D.datasets.dataset_2D_Dense import MoleculeDataset2DDense # For protein -from Geom3D.datasets.dataset_EC import DatasetEC from Geom3D.datasets.dataset_FOLD import DatasetFOLD -from Geom3D.datasets.datasetFOLD_GVP import DatasetFOLD_GVP from Geom3D.datasets.dataset_FOLD_GearNet import DatasetFOLDGearNet -from Geom3D.datasets.dataset_FOLD_CDConv import DatasetFOLD_CDConv +from Geom3D.datasets.dataset_ECSingle import DatasetECSingle +from Geom3D.datasets.dataset_ECMultiple import DatasetECMultiple +from Geom3D.datasets.dataset_GO import DatasetGO +from Geom3D.datasets.dataset_GVP import DatasetGVP +from Geom3D.datasets.dataset_GO_GearNet import DatasetGOGearNet +from Geom3D.datasets.dataset_ECMultiple_GearNet import DatasetECMultipleGearNet +from Geom3D.datasets.dataset_MSP_GearNet import DatasetMSPGearNet +from Geom3D.datasets.dataset_ECSingle_GearNet import DatasetECSingleGearNet # For 2D SSL -from Geom3D.datasets.dataset_2D_Contextual import MoleculeContextualDataset -from Geom3D.datasets.dataset_2D_GPT import MoleculeDatasetGPT -from Geom3D.datasets.dataset_2D_GraphCL import MoleculeDataset_GraphCL \ No newline at end of file +# from Geom3D.datasets.dataset_2D_Contextual import MoleculeContextualDataset +# from Geom3D.datasets.dataset_2D_GPT import MoleculeDatasetGPT +# from Geom3D.datasets.dataset_2D_GraphCL import MoleculeDataset_GraphCL \ No newline at end of file diff --git a/Geom3D/datasets/datasetFOLD_GVP.py b/Geom3D/datasets/datasetFOLD_GVP.py deleted file mode 100644 index eb3d94f..0000000 --- a/Geom3D/datasets/datasetFOLD_GVP.py +++ /dev/null @@ -1,345 +0,0 @@ -# Credit to https://github.com/divelab/DIG/blob/dig-stable/dig/threedgraph/dataset/ECdataset.py -# Data processing credit to https://github.com/drorlab/gvp-pytorch/blob/main/gvp/atom3d.py -import os.path as osp -import numpy as np -import warnings -from tqdm import tqdm -import pandas as pd - -import torch, random, scipy, math -import torch.nn as nn -import torch.nn.functional as F -from torch.utils.data import IterableDataset -import torch_cluster, torch_geometric, torch_scatter - -from torch_geometric.data import Data -from torch_geometric.data import InMemoryDataset - - -def _normalize(tensor, dim=-1): - ''' - Normalizes a `torch.Tensor` along dimension `dim` without `nan`s. - ''' - return torch.nan_to_num( - torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))) - - -def _rbf(D, D_min=0., D_max=20., D_count=16, device='cpu'): - ''' - From https://github.com/jingraham/neurips19-graph-protein-design - - Returns an RBF embedding of `torch.Tensor` `D` along a new axis=-1. - That is, if `D` has shape [...dims], then the returned tensor will have - shape [...dims, D_count]. 
- ''' - D_mu = torch.linspace(D_min, D_max, D_count, device=device) - D_mu = D_mu.view([1, -1]) - D_sigma = (D_max - D_min) / D_count - D_expand = torch.unsqueeze(D, -1) - - RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2) - return RBF - -def _edge_features(coords, edge_index, D_max=4.5, num_rbf=16, device='cpu'): - - E_vectors = coords[edge_index[0]] - coords[edge_index[1]] - rbf = _rbf(E_vectors.norm(dim=-1), - D_max=D_max, D_count=num_rbf, device=device) - - edge_s = rbf - edge_v = _normalize(E_vectors).unsqueeze(-2) - - edge_s, edge_v = map(torch.nan_to_num, - (edge_s, edge_v)) - - return edge_s, edge_v - - -class BaseTransform: - ''' - Implementation of an ATOM3D Transform which featurizes the atomic - coordinates in an ATOM3D dataframes into `torch_geometric.data.Data` - graphs. This class should not be used directly; instead, use the - task-specific transforms, which all extend BaseTransform. Node - and edge features are as described in the EGNN manuscript. - - Returned graphs have the following attributes: - -x atomic coordinates, shape [n_nodes, 3] - -atoms numeric encoding of atomic identity, shape [n_nodes] - -edge_index edge indices, shape [2, n_edges] - -edge_s edge scalar features, shape [n_edges, 16] - -edge_v edge scalar features, shape [n_edges, 1, 3] - - Subclasses of BaseTransform will produce graphs with additional - attributes for the tasks-specific training labels, in addition - to the above. - - All subclasses of BaseTransform directly inherit the BaseTransform - constructor. - - :param edge_cutoff: distance cutoff to use when drawing edges - :param num_rbf: number of radial bases to encode the distance on each edge - :device: if "cuda", will do preprocessing on the GPU - ''' - def __init__(self, edge_cutoff=4.5, num_rbf=16, device='cpu'): - self.edge_cutoff = edge_cutoff - self.num_rbf = num_rbf - self.device = device - - def __call__(self, df): - ''' - :param df: `pandas.DataFrame` of atomic coordinates - in the ATOM3D format - - :return: `torch_geometric.data.Data` structure graph - ''' - _element_mapping = lambda x: { - 'H' : 0, - 'C' : 1, - 'N' : 2, - 'O' : 3, - 'F' : 4, - 'S' : 5, - 'Cl': 6, 'CL': 6, - 'P' : 7 - }.get(x, 8) - - with torch.no_grad(): - coords = torch.as_tensor(df[['x', 'y', 'z']].to_numpy(), - dtype=torch.float32, device=self.device) - atoms = torch.as_tensor(list(map(_element_mapping, df.element)), - dtype=torch.long, device=self.device) - - edge_index = torch_cluster.radius_graph(coords, r=self.edge_cutoff) - - edge_s, edge_v = _edge_features(coords, edge_index, - D_max=self.edge_cutoff, num_rbf=self.num_rbf, device=self.device) - - return torch_geometric.data.Data(x=coords, atoms=atoms, - edge_index=edge_index, edge_s=edge_s, edge_v=edge_v) - - -class DatasetFOLD_GVP(InMemoryDataset): - def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, split='train'): - self.split = split - self.root = root - - super(DatasetFOLD_GVP, self).__init__( - root, transform, pre_transform, pre_filter) - - self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter - self.data, self.slices = torch.load(self.processed_paths[0]) - - - @property - def processed_dir(self): - name = 'processed' - return osp.join(self.root, name, self.split) - - @property - def raw_file_names(self): - name = self.split + '.txt' - return name - - @property - def processed_file_names(self): - return 'data.pt' - - def _normalize(self,tensor, dim=-1): - ''' - Normalizes a `torch.Tensor` along dimension `dim` without `nan`s. 
- ''' - return torch.nan_to_num( - torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))) - - def get_atom_pos(self, amino_types, atom_names, atom_amino_id, atom_pos): - # atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1 - mask_n = np.char.equal(atom_names, b'N') - mask_ca = np.char.equal(atom_names, b'CA') - mask_c = np.char.equal(atom_names, b'C') - mask_cb = np.char.equal(atom_names, b'CB') - mask_g = np.char.equal(atom_names, b'CG') | np.char.equal(atom_names, b'SG') | np.char.equal(atom_names, b'OG') | np.char.equal(atom_names, b'CG1') | np.char.equal(atom_names, b'OG1') - mask_d = np.char.equal(atom_names, b'CD') | np.char.equal(atom_names, b'SD') | np.char.equal(atom_names, b'CD1') | np.char.equal(atom_names, b'OD1') | np.char.equal(atom_names, b'ND1') - mask_e = np.char.equal(atom_names, b'CE') | np.char.equal(atom_names, b'NE') | np.char.equal(atom_names, b'OE1') - mask_z = np.char.equal(atom_names, b'CZ') | np.char.equal(atom_names, b'NZ') - mask_h = np.char.equal(atom_names, b'NH1') - - pos_n = np.full((len(amino_types),3),np.nan) - pos_n[atom_amino_id[mask_n]] = atom_pos[mask_n] - pos_n = torch.FloatTensor(pos_n) - - pos_ca = np.full((len(amino_types),3),np.nan) - pos_ca[atom_amino_id[mask_ca]] = atom_pos[mask_ca] - pos_ca = torch.FloatTensor(pos_ca) - - pos_c = np.full((len(amino_types),3),np.nan) - pos_c[atom_amino_id[mask_c]] = atom_pos[mask_c] - pos_c = torch.FloatTensor(pos_c) - - # if data only contain pos_ca, we set the position of C and N as the position of CA - pos_n[torch.isnan(pos_n)] = pos_ca[torch.isnan(pos_n)] - pos_c[torch.isnan(pos_c)] = pos_ca[torch.isnan(pos_c)] - - pos_cb = np.full((len(amino_types),3),np.nan) - pos_cb[atom_amino_id[mask_cb]] = atom_pos[mask_cb] - pos_cb = torch.FloatTensor(pos_cb) - - pos_g = np.full((len(amino_types),3),np.nan) - pos_g[atom_amino_id[mask_g]] = atom_pos[mask_g] - pos_g = torch.FloatTensor(pos_g) - - pos_d = np.full((len(amino_types),3),np.nan) - pos_d[atom_amino_id[mask_d]] = atom_pos[mask_d] - pos_d = torch.FloatTensor(pos_d) - - pos_e = np.full((len(amino_types),3),np.nan) - pos_e[atom_amino_id[mask_e]] = atom_pos[mask_e] - pos_e = torch.FloatTensor(pos_e) - - pos_z = np.full((len(amino_types),3),np.nan) - pos_z[atom_amino_id[mask_z]] = atom_pos[mask_z] - pos_z = torch.FloatTensor(pos_z) - - pos_h = np.full((len(amino_types),3),np.nan) - pos_h[atom_amino_id[mask_h]] = atom_pos[mask_h] - pos_h = torch.FloatTensor(pos_h) - - return pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h - - def side_chain_embs(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h): - v1, v2, v3, v4, v5, v6, v7 = pos_ca - pos_n, pos_cb - pos_ca, pos_g - pos_cb, pos_d - pos_g, pos_e - pos_d, pos_z - pos_e, pos_h - pos_z - - # five side chain torsion angles - # We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0. 
- angle1 = torch.unsqueeze(self.compute_diherals(v1, v2, v3),1) - angle2 = torch.unsqueeze(self.compute_diherals(v2, v3, v4),1) - angle3 = torch.unsqueeze(self.compute_diherals(v3, v4, v5),1) - angle4 = torch.unsqueeze(self.compute_diherals(v4, v5, v6),1) - angle5 = torch.unsqueeze(self.compute_diherals(v5, v6, v7),1) - - side_chain_angles = torch.cat((angle1, angle2, angle3, angle4),1) - side_chain_embs = torch.cat((torch.sin(side_chain_angles), torch.cos(side_chain_angles)),1) - - return side_chain_embs - - def bb_embs(self, X): - # X should be a num_residues x 3 x 3, order N, C-alpha, and C atoms of each residue - # N coords: X[:,0,:] - # CA coords: X[:,1,:] - # C coords: X[:,2,:] - # return num_residues x 6 - # From https://github.com/jingraham/neurips19-graph-protein-design - - X = torch.reshape(X, [3 * X.shape[0], 3]) - dX = X[1:] - X[:-1] - U = self._normalize(dX, dim=-1) - u0 = U[:-2] - u1 = U[1:-1] - u2 = U[2:] - - angle = self.compute_diherals(u0, u1, u2) - - # add phi[0], psi[-1], omega[-1] with value 0 - angle = F.pad(angle, [1, 2]) - angle = torch.reshape(angle, [-1, 3]) - angle_features = torch.cat([torch.cos(angle), torch.sin(angle)], 1) - return angle_features - - def compute_diherals(self, v1, v2, v3): - n1 = torch.cross(v1, v2) - n2 = torch.cross(v2, v3) - a = (n1 * n2).sum(dim=-1) - b = torch.nan_to_num((torch.cross(n1, n2) * v2).sum(dim=-1) / v2.norm(dim=1)) - torsion = torch.nan_to_num(torch.atan2(b, a)) - return torsion - - def protein_to_graph(self, pFilePath): - import h5py - h5File = h5py.File(pFilePath, "r") - data = Data() - - amino_types = h5File['amino_types'][()] # size: (n_amino,) - mask = amino_types == -1 - if np.sum(mask) > 0: - amino_types[mask] = 25 # for amino acid types, set the value of -1 to 25 - atom_amino_id = h5File['atom_amino_id'][()] # size: (n_atom,) - atom_names = h5File['atom_names'][()] # size: (n_atom,) - atom_pos = h5File['atom_pos'][()][0] #size: (n_atom,3) - - # atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1 - pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h = self.get_atom_pos(amino_types, atom_names, atom_amino_id, atom_pos) - - # five side chain torsion angles - # We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0. - side_chain_embs = self.side_chain_embs(pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h) - side_chain_embs[torch.isnan(side_chain_embs)] = 0 - - # three backbone torsion angles - bb_embs = self.bb_embs(torch.cat((torch.unsqueeze(pos_n,1), torch.unsqueeze(pos_ca,1), torch.unsqueeze(pos_c,1)),1)) - bb_embs[torch.isnan(bb_embs)] = 0 - - # backbone atoms' positions - C_list_1 = ["C"] * pos_ca.shape[0] - N_list = ["N"] * pos_n.shape[0] - C_list_2 = ["C"] * pos_c.shape[0] - element = C_list_1 + N_list + C_list_2 - - backbone_coords = torch.cat((pos_ca, pos_n, pos_c)) - header = ["x", "y", "z"] - backbone_df = pd.DataFrame(backbone_coords.numpy(), columns=header) - backbone_df["element"] = element - - backbone = self.graph(backbone_df) - data.x = backbone.atoms - data.atoms = backbone.atoms - data.edge_index = backbone.edge_index - data.edge_s = backbone.edge_s - data.edge_v = backbone.edge_v - - h5File.close() - return data - - def process(self): - self.graph = BaseTransform(device="cuda") - print('Beginning Processing ...') - - # Load the file with the list of functions. 
-        classes_ = {}
-        with open(self.root+"/class_map.txt", 'r') as mFile:
-            for line in mFile:
-                lineList = line.rstrip().split('\t')
-                classes_[lineList[0]] = int(lineList[1])
-
-        # Get the file list.
-        fileList_ = []
-        cathegories_ = []
-        with open(self.root+"/"+self.split+".txt", 'r') as mFile:
-            for curLine in mFile:
-                splitLine = curLine.rstrip().split('\t')
-                curClass = classes_[splitLine[-1]]
-                fileList_.append(self.root+"/"+self.split+"/"+splitLine[0])
-                cathegories_.append(curClass)
-
-        # Load the dataset
-        print("Reading the data")
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            data_list = []
-            for fileIter, curFile in tqdm(enumerate(fileList_)):
-                print(curFile)
-                fileName = curFile.split('/')[-1]
-                curProtein = self.protein_to_graph(curFile+".hdf5")
-                curProtein.id = fileName
-                curProtein.y = torch.tensor(cathegories_[fileIter])
-                if not curProtein.x is None:
-                    data_list.append(curProtein)
-        data, slices = self.collate(data_list)
-        torch.save((data, slices), self.processed_paths[0])
-        print('Done!')
-
-
-if __name__ == "__main__":
-    for split in ['training', 'validation', 'test_fold', 'test_superfamily', 'test_family']:
-        print('#### Now processing {} data ####'.format(split))
-        dataset = DatasetFOLD_GVP(root='../../data/FOLD', split=split)
-        print(dataset)
\ No newline at end of file
diff --git a/Geom3D/datasets/dataset_ECMultiple.py b/Geom3D/datasets/dataset_ECMultiple.py
new file mode 100644
index 0000000..6d4c7ee
--- /dev/null
+++ b/Geom3D/datasets/dataset_ECMultiple.py
@@ -0,0 +1,425 @@
+import os.path as osp
+import os
+import numpy as np
+import warnings
+from tqdm import tqdm
+from sklearn.preprocessing import normalize
+import h5py
+
+import torch, math
+import torch.nn.functional as F
+import torch_cluster
+
+from Bio.PDB import PDBParser
+from Bio.PDB.Polypeptide import three_to_one, is_aa
+import sys
+import Bio.PDB
+import Bio.PDB.StructureBuilder
+from Bio.PDB.Residue import Residue
+
+from torch_geometric.data import Data
+from torch_geometric.data import InMemoryDataset
+
+
+class SloppyStructureBuilder(Bio.PDB.StructureBuilder.StructureBuilder):
+    """Cope with the resSeq < 10,000 limitation by just incrementing internally."""
+
+    def __init__(self, verbose=False):
+        Bio.PDB.StructureBuilder.StructureBuilder.__init__(self)
+        self.max_resseq = -1
+        self.verbose = verbose
+
+    def init_residue(self, resname, field, resseq, icode):
+        """Initiate a new Residue object.
+        Arguments:
+            resname: string, e.g. "ASN"
+            field: hetero flag, "W" for waters, "H" for hetero residues, otherwise blank.
+            resseq: int, sequence identifier
+            icode: string, insertion code
+        Returns:
+            None
+        """
+        if field != " ":
+            if field == "H":
+                # The hetero field consists of
+                # H_ + the residue name (e.g. H_FUC)
+                field = "H_" + resname
+        res_id = (field, resseq, icode)
+
+        if resseq > self.max_resseq:
+            self.max_resseq = resseq
+
+        if field == " ":
+            fudged_resseq = False
+            while self.chain.has_id(res_id) or resseq == 0:
+                # There already is a residue with the id (field, resseq, icode).
+                # resseq == 0 catches already wrapped residue numbers which
+                # do not trigger the has_id() test.
+                #
+                # Be sloppy and just increment...
+                # (This code will not leave gaps in resids... I think)
+                #
+                # XXX: shouldn't we also do this for hetero atoms and water??
+                self.max_resseq += 1
+                resseq = self.max_resseq
+                res_id = (field, resseq, icode)  # use max_resseq!
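+                # record that this residue was renumbered so the verbose
+                # branch below can report the reassignment on stderr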
+ fudged_resseq = True + + if fudged_resseq and self.verbose: + sys.stderr.write( + "Residues are wrapping (Residue " + + "('%s', %i, '%s') at line %i)." + % (field, resseq, icode, self.line_counter) + + ".... assigning new resid %d.\n" % self.max_resseq + ) + residue = Residue(res_id, resname, self.segid) + self.chain.add(residue) + self.residue = residue + return None + + +class DatasetECMultiple(InMemoryDataset): + def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, split='train', percent=0.3): + self.split = split + self.root = root + self.percent = percent + + self.letter_to_num = { + 'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9, + 'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8, + 'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19, + 'N': 2, 'Y': 18, 'M': 12, "X":20} + + super(DatasetECMultiple, self).__init__( + root, transform, pre_transform, pre_filter) + + self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter + self.data, self.slices = torch.load(self.processed_paths[0]) + + + @property + def processed_dir(self): + if self.split != "test": + name = 'processed_ECMultiple_{}'.format(self.split) + return osp.join(self.root, name) + else: + name = 'processed_ECMultiple_test_{}'.format(self.percent) + return osp.join(self.root, name) + + @property + def raw_file_names(self): + name = self.split + '.txt' + return name + + @property + def processed_file_names(self): + return 'data.pt' + + def get_side_chain_angle_encoding(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h): + v1, v2, v3, v4, v5, v6, v7 = pos_ca - pos_n, pos_cb - pos_ca, pos_g - pos_cb, pos_d - pos_g, pos_e - pos_d, pos_z - pos_e, pos_h - pos_z + + # five side chain torsion angles + # We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0. + angle1 = torch.unsqueeze(self.diherals_ProNet(v1, v2, v3),1) + angle2 = torch.unsqueeze(self.diherals_ProNet(v2, v3, v4),1) + angle3 = torch.unsqueeze(self.diherals_ProNet(v3, v4, v5),1) + angle4 = torch.unsqueeze(self.diherals_ProNet(v4, v5, v6),1) + angle5 = torch.unsqueeze(self.diherals_ProNet(v5, v6, v7),1) + + side_chain_angles = torch.cat((angle1, angle2, angle3, angle4),1) + side_chain_embs = torch.cat((torch.sin(side_chain_angles), torch.cos(side_chain_angles)),1) + + return side_chain_embs + + def get_backbone_angle_encoding(self, X): + # X should be a num_residues x 3 x 3, order N, C-alpha, and C atoms of each residue + # N coords: X[:,0,:] + # CA coords: X[:,1,:] + # C coords: X[:,2,:] + # return num_residues x 6 + # From https://github.com/jingraham/neurips19-graph-protein-design + + X = torch.reshape(X, [3 * X.shape[0], 3]) + dX = X[1:] - X[:-1] + U = self._normalize(dX, dim=-1) + u0 = U[:-2] + u1 = U[1:-1] + u2 = U[2:] + + angle = self.diherals_ProNet(u0, u1, u2) + + # add phi[0], psi[-1], omega[-1] with value 0 + angle = F.pad(angle, [1, 2]) + angle = torch.reshape(angle, [-1, 3]) + angle_features = torch.cat([torch.cos(angle), torch.sin(angle)], 1) + return angle_features + + def diherals_ProNet(self, v1, v2, v3): + n1 = torch.cross(v1, v2) + n2 = torch.cross(v2, v3) + a = (n1 * n2).sum(dim=-1) + b = torch.nan_to_num((torch.cross(n1, n2) * v2).sum(dim=-1) / v2.norm(dim=1)) + torsion = torch.nan_to_num(torch.atan2(b, a)) + return torsion + + def _normalize(self, tensor, dim=-1): + ''' + Normalizes a `torch.Tensor` along dimension `dim` without `nan`s. 
+ ''' + return torch.nan_to_num( + torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))) + + def three_to_one_standard(self, res): + if not is_aa(res, standard=True): + return "X" + + return three_to_one(res) + + def chain_info(self, chain, name): + """Convert a PDB chain in to coordinates of target atoms from all + AAs + + Args: + chain: a Bio.PDB.Chain object + target_atoms: Target atoms which residues will be resturned. + name: String. Name of the protein. + Returns: + Dictonary containing protein sequence `seq`, 3D coordinates `coord` and name `name`. + + """ + atom_names, atom_amino_id, atom_pos, residue_types = [], [], [], [] + pdb_seq = "" + residue_index = 0 + for residue in chain.get_residues(): + if is_aa(residue) and any(atom.get_name() == "CA" for atom in residue.get_atoms()): + residue_name = self.three_to_one_standard(residue.get_resname()) + pdb_seq += residue_name + residue_types.append(self.letter_to_num[residue_name]) + + for atom in residue.get_atoms(): + atom_names.append(atom.get_name()) + atom_amino_id.append(residue_index) + atom_pos.append(atom.coord) + + residue_index += 1 + + mask_n = np.char.equal(atom_names, 'N') + mask_ca = np.char.equal(atom_names, 'CA') + mask_c = np.char.equal(atom_names, 'C') + mask_cb = np.char.equal(atom_names, 'CB') + mask_g = np.char.equal(atom_names, 'CG') | np.char.equal(atom_names, 'SG') | np.char.equal(atom_names, 'OG') | np.char.equal(atom_names, 'CG1') | np.char.equal(atom_names, 'OG1') + mask_d = np.char.equal(atom_names, 'CD') | np.char.equal(atom_names, 'SD') | np.char.equal(atom_names, 'CD1') | np.char.equal(atom_names, 'OD1') | np.char.equal(atom_names, 'ND1') + mask_e = np.char.equal(atom_names, 'CE') | np.char.equal(atom_names, 'NE') | np.char.equal(atom_names, 'OE1') + mask_z = np.char.equal(atom_names, 'CZ') | np.char.equal(atom_names, 'NZ') + mask_h = np.char.equal(atom_names, 'NH1') + + atom_amino_id = np.array(atom_amino_id) + atom_pos = np.array(atom_pos) + + pos_n = np.full((len(pdb_seq), 3),np.nan) + pos_n[atom_amino_id[mask_n]] = atom_pos[mask_n] + pos_n = torch.FloatTensor(pos_n) + + pos_ca = np.full((len(pdb_seq), 3),np.nan) + pos_ca[atom_amino_id[mask_ca]] = atom_pos[mask_ca] + pos_ca = torch.FloatTensor(pos_ca) + + pos_c = np.full((len(pdb_seq), 3),np.nan) + pos_c[atom_amino_id[mask_c]] = atom_pos[mask_c] + pos_c = torch.FloatTensor(pos_c) + + # if data only contain pos_ca, we set the position of C and N as the position of CA + pos_n[torch.isnan(pos_n)] = pos_ca[torch.isnan(pos_n)] + pos_c[torch.isnan(pos_c)] = pos_ca[torch.isnan(pos_c)] + + pos_cb = np.full((len(pdb_seq), 3),np.nan) + pos_cb[atom_amino_id[mask_cb]] = atom_pos[mask_cb] + pos_cb = torch.FloatTensor(pos_cb) + + pos_g = np.full((len(pdb_seq), 3),np.nan) + pos_g[atom_amino_id[mask_g]] = atom_pos[mask_g] + pos_g = torch.FloatTensor(pos_g) + + pos_d = np.full((len(pdb_seq), 3),np.nan) + pos_d[atom_amino_id[mask_d]] = atom_pos[mask_d] + pos_d = torch.FloatTensor(pos_d) + + pos_e = np.full((len(pdb_seq), 3),np.nan) + pos_e[atom_amino_id[mask_e]] = atom_pos[mask_e] + pos_e = torch.FloatTensor(pos_e) + + pos_z = np.full((len(pdb_seq), 3),np.nan) + pos_z[atom_amino_id[mask_z]] = atom_pos[mask_z] + pos_z = torch.FloatTensor(pos_z) + + pos_h = np.full((len(pdb_seq), 3),np.nan) + pos_h[atom_amino_id[mask_h]] = atom_pos[mask_h] + pos_h = torch.FloatTensor(pos_h) + + chain_struc = { + 'name': name, + 'pos_n': pos_n, + 'pos_ca': pos_ca, + 'pos_c': pos_c, + 'pos_cb': pos_cb, + 'pos_g': pos_g, + 'pos_d': pos_d, + 'pos_e': pos_e, + 'pos_z': pos_z, + 
'pos_h': pos_h, + 'atom_names': atom_names, + 'atom_pos': atom_pos, + 'residue_types': residue_types + } + + if len(pdb_seq) <= 1: + # has no or only 1 AA in the chain + return None + + return chain_struc + + + def extract_protein_data(self, pFilePath): + data = Data() + + pdb_parser = PDBParser( + QUIET=True, + PERMISSIVE=True, + structure_builder=SloppyStructureBuilder(), + ) + + name = os.path.basename(pFilePath).split("_")[0] + + try: + structure = pdb_parser.get_structure(name, pFilePath) + except Exception as e: + print(pFilePath, "raised an error:") + print(e) + return None + + records = [] + chain_ids = [] + + for chain in structure.get_chains(): + if chain.id in chain_ids: # skip duplicated chains + continue + chain_ids.append(chain.id) + record = self.chain_info(chain, "{}-{}".format(name.split("-")[0], chain.id)) + if record is not None: + records.append(record) + + records = [rec for rec in records if rec["name"] in self.data] + + for i in records: + if i["name"] == name: + pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h, atom_names, atom_pos, residue_types = (i[k] for k in ["pos_n", "pos_ca", "pos_c", "pos_cb", "pos_g", "pos_d", "pos_e", "pos_z", "pos_h", "atom_names", "atom_pos", "residue_types"]) + + # calculate side chain torsion angles, up to four + # do encoding + side_chain_angle_encoding = self.get_side_chain_angle_encoding(pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h) + side_chain_angle_encoding[torch.isnan(side_chain_angle_encoding)] = 0 + + # three backbone torsion angles + backbone_angle_encoding = self.get_backbone_angle_encoding(torch.cat((torch.unsqueeze(pos_n,1), torch.unsqueeze(pos_ca,1), torch.unsqueeze(pos_c,1)),1)) + backbone_angle_encoding[torch.isnan(backbone_angle_encoding)] = 0 + + data.seq = torch.LongTensor(residue_types) + data.side_chain_angle_encoding = side_chain_angle_encoding + data.backbone_angle_encoding = backbone_angle_encoding + data.coords_ca = pos_ca + data.coords_n = pos_n + data.coords_c = pos_c + data.x = atom_names + data.atom_pos = torch.tensor(atom_pos) + data.num_nodes = len(pos_ca) + + return data + + def process(self): + print('Beginning Processing ...') + + if self.split != "test": + with open(os.path.join(self.root, f"nrPDB-EC_{self.split}.txt"), 'r') as file: + self.data = set([line.strip() for line in file]) + else: + self.data = set() + with open(os.path.join(self.root, "nrPDB-EC_test.csv"), 'r') as f: + head = True + for line in f: + if head: + head = False + continue + arr = line.rstrip().split(',') + if self.percent == 0.3 and arr[1] == '1': + self.data.add(arr[0]) + elif self.percent == 0.4 and arr[2] == '1': + self.data.add(arr[0]) + elif self.percent == 0.5 and arr[3] == '1': + self.data.add(arr[0]) + elif self.percent == 0.7 and arr[4] == '1': + self.data.add(arr[0]) + elif self.percent == 0.95 and arr[5] == '1': + self.data.add(arr[0]) + else: + pass + + + # 2. 
Parse the structure files and save to json files + structure_file_dir = os.path.join( + self.root, f"{self.split}" + ) + files = os.listdir(structure_file_dir) + + + level_idx = 1 + ec_cnt = 0 + ec_num = {} + ec_annotations = {} + self.labels = {} + + with open(os.path.join(self.root, 'nrPDB-EC_annot.tsv'), 'r') as f: + for idx, line in enumerate(f): + if idx == 1: + arr = line.rstrip().split('\t') + for ec in arr: + ec_annotations[ec] = ec_cnt + ec_num[ec] = 0 + ec_cnt += 1 + + elif idx > 2: + arr = line.rstrip().split('\t') + protein_labels = [] + if len(arr) > level_idx: + protein_ec_list = arr[level_idx] + protein_ec_list = protein_ec_list.split(',') + for ec in protein_ec_list: + if len(ec) > 0: + protein_labels.append(ec_annotations[ec]) + ec_num[ec] += 1 + self.labels[arr[0]] = np.array(protein_labels) + + self.num_class = len(ec_annotations) + + invalid_PDB_file_name_list = ["2UV2-A_16534.pdb", "1ENM-A_16555.pdb", "1DIN-A_7896.pdb"] + + data_list = [] + for i in tqdm(range(len(files))): + if files[i].split("_")[0] in self.data: + if files[i] in invalid_PDB_file_name_list: + print("Skipping invalid file {}...".format(files[i])) + continue + file_name = osp.join(self.root, self.split, files[i]) + protein = self.extract_protein_data(file_name) + label = np.zeros((self.num_class,)).astype(np.float32) + + if len(self.labels[osp.basename(file_name).split("_")[0]]) > 0: + label[self.labels[osp.basename(file_name).split("_")[0]]] = 1.0 + + if protein is not None: + protein.id = files[i] + protein.y = torch.tensor(label).unsqueeze(0) + data_list.append(protein) + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + print('Done!') diff --git a/Geom3D/datasets/dataset_ECMultiple_GearNet.py b/Geom3D/datasets/dataset_ECMultiple_GearNet.py new file mode 100644 index 0000000..f5739cc --- /dev/null +++ b/Geom3D/datasets/dataset_ECMultiple_GearNet.py @@ -0,0 +1,294 @@ +import os.path as osp +import os +import numpy as np +import warnings +from tqdm import tqdm +from sklearn.preprocessing import normalize +import h5py +import itertools +from collections import defaultdict + +import torch, math +import torch.nn.functional as F +import torch_cluster + +from Bio.PDB import PDBParser +from Bio.PDB.Polypeptide import three_to_one, is_aa +import sys +import Bio.PDB +import Bio.PDB.StructureBuilder +from Bio.PDB.Residue import Residue + +from torch_geometric.data import Data +from torch_geometric.data import InMemoryDataset + + +class SloppyStructureBuilder(Bio.PDB.StructureBuilder.StructureBuilder): + """Cope with resSeq < 10,000 limitation by just incrementing internally.""" + + def __init__(self, verbose=False): + Bio.PDB.StructureBuilder.StructureBuilder.__init__(self) + self.max_resseq = -1 + self.verbose = verbose + + def init_residue(self, resname, field, resseq, icode): + """Initiate a new Residue object. + Arguments: + resname: string, e.g. "ASN" + field: hetero flag, "W" for waters, "H" for hetero residues, otherwise blanc. + resseq: int, sequence identifier + icode: string, insertion code + Return: + None + """ + if field != " ": + if field == "H": + # The hetero field consists of + # H_ + the residue name (e.g. 
H_FUC) + field = "H_" + resname + res_id = (field, resseq, icode) + + if resseq > self.max_resseq: + self.max_resseq = resseq + + if field == " ": + fudged_resseq = False + while self.chain.has_id(res_id) or resseq == 0: + # There already is a residue with the id (field, resseq, icode) + # resseq == 0 catches already wrapped residue numbers which + # do not trigger the has_id() test. + # + # Be sloppy and just increment... + # (This code will not leave gaps in resids... I think) + # + # XXX: shouldn't we also do this for hetero atoms and water?? + self.max_resseq += 1 + resseq = self.max_resseq + res_id = (field, resseq, icode) # use max_resseq! + fudged_resseq = True + + if fudged_resseq and self.verbose: + sys.stderr.write( + "Residues are wrapping (Residue " + + "('%s', %i, '%s') at line %i)." + % (field, resseq, icode, self.line_counter) + + ".... assigning new resid %d.\n" % self.max_resseq + ) + residue = Residue(res_id, resname, self.segid) + self.chain.add(residue) + self.residue = residue + return None + + +class DatasetECMultipleGearNet(InMemoryDataset): + def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, split='train', percent=0.3): + self.split = split + self.root = root + self.percent = percent + + self.letter_to_num = { + 'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9, + 'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8, + 'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19, + 'N': 2, 'Y': 18, 'M': 12, "X":20} + + super(DatasetECMultipleGearNet, self).__init__( + root, transform, pre_transform, pre_filter) + + self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter + self.data, self.slices = torch.load(self.processed_paths[0]) + + + @property + def processed_dir(self): + if self.split != "test": + name = 'processed_ECMultiple_GearNet_{}'.format(self.split) + return osp.join(self.root, name) + else: + name = 'processed_ECMultiple_test_GearNet_{}'.format(self.percent) + return osp.join(self.root, name) + + @property + def raw_file_names(self): + name = self.split + '.txt' + return name + + @property + def processed_file_names(self): + return 'data.pt' + + def extract_protein_data(self, pFilePath, graph_construction_model): + from torchdrug import data + + protein = data.Protein.from_pdb(pFilePath) + protein = data.Protein.pack([protein]) + protein = graph_construction_model(protein) + item = {"graph": protein} + + if self.transform: + item = self.transform(item) + + protein = item["graph"] + seq = protein.to_sequence()[0] + residue_type = [] + residue_feature = [] + + for i in seq: + residue_type.append(data.Protein.residue_symbol2id.get(i, 0)) + residue_feature.append(data.feature.onehot(data.Protein.id2residue.get(data.Protein.residue_symbol2id.get(i)), data.feature.residue_vocab, allow_unknown=True)) + return_data = Data() + return_data.edge_list = protein.edge_list + return_data.edge_weight = torch.ones(len(protein.edge_list)) + return_data.num_residue = protein.num_residue + return_data.num_node = protein.num_node + return_data.num_edge = protein.num_edge + return_data.x = residue_type # This is important to hack the code + return_data.node_feature = residue_feature + return_data.num_relation = protein.num_relation + return_data.node_position = protein.node_position + return_data.edge_feature = protein.edge_feature + + return return_data + + def process(self): + print('Beginning Processing ...') + + from torchdrug import transforms, layers + from torchdrug.layers import geometry + + self.transform = 
transforms.ProteinView("residue") + + if self.split != "test": + with open(os.path.join(self.root, f"nrPDB-EC_{self.split}.txt"), 'r') as file: + self.data = set([line.strip() for line in file]) + else: + self.data = set() + with open(os.path.join(self.root, "nrPDB-EC_test.csv"), 'r') as f: + head = True + for line in f: + if head: + head = False + continue + arr = line.rstrip().split(',') + if self.percent == 0.3 and arr[1] == '1': + self.data.add(arr[0]) + elif self.percent == 0.4 and arr[2] == '1': + self.data.add(arr[0]) + elif self.percent == 0.5 and arr[3] == '1': + self.data.add(arr[0]) + elif self.percent == 0.7 and arr[4] == '1': + self.data.add(arr[0]) + elif self.percent == 0.95 and arr[5] == '1': + self.data.add(arr[0]) + else: + pass + + structure_file_dir = os.path.join( + self.root, f"{self.split}" + ) + files = os.listdir(structure_file_dir) + + + level_idx = 1 + ec_cnt = 0 + ec_num = {} + ec_annotations = {} + self.labels = {} + + with open(os.path.join(self.root, 'nrPDB-EC_annot.tsv'), 'r') as f: + for idx, line in enumerate(f): + if idx == 1: + arr = line.rstrip().split('\t') + for ec in arr: + ec_annotations[ec] = ec_cnt + ec_num[ec] = 0 + ec_cnt += 1 + + elif idx > 2: + arr = line.rstrip().split('\t') + protein_labels = [] + if len(arr) > level_idx: + protein_ec_list = arr[level_idx] + protein_ec_list = protein_ec_list.split(',') + for ec in protein_ec_list: + if len(ec) > 0: + protein_labels.append(ec_annotations[ec]) + ec_num[ec] += 1 + self.labels[arr[0]] = np.array(protein_labels) + + self.num_class = len(ec_annotations) + + graph_construction_model = layers.GraphConstruction( + node_layers=[geometry.AlphaCarbonNode()], + edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5), geometry.KNNEdge(k=10, min_distance=5), geometry.SequentialEdge(max_distance=2)], + edge_feature="gearnet") + + data_list = [] + for i in tqdm(range(len(files))): + if files[i].split("_")[0] in self.data and files[i].split("_")[0] not in ["2UV2-A", "1ENM-A", "1DIN-A"]: + file_name = osp.join(self.root, self.split, files[i]) + try: + protein = self.extract_protein_data(file_name, graph_construction_model) + except: + protein = None + label = np.zeros((self.num_class,)).astype(np.float32) + + if len(self.labels[osp.basename(file_name).split("_")[0]]) > 0: + label[self.labels[osp.basename(file_name).split("_")[0]]] = 1.0 + + if protein is not None: + protein.id = files[i] + protein.y = torch.tensor(label).unsqueeze(0) + data_list.append(protein) + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + print('Done!') + + def collate_fn(batch): + num_nodes = [] + num_edges = [] + num_residues = [] + node_positions = [] + y = [] + num_cum_node = 0 + num_cum_edge = 0 + num_cum_residue = 0 + num_graph = 0 + data_dict = defaultdict(list) + + for graph in batch: + num_nodes.append(graph.num_node) + num_edges.append(graph.num_edge) + num_residues.append(graph.num_residue) + node_positions.append(graph.node_position) + y.append(graph.y[0]) + for k, v in graph.items(): + if k in ["num_relation", "num_node", "num_edge", "num_residue", "node_position", "y", "id"]: + continue + elif k in ["edge_list"]: + neo_v = v.clone() + neo_v[:, 0] += num_cum_node + neo_v[:, 1] += num_cum_node + data_dict[k].append(neo_v) + continue + + data_dict[k].append(v) + num_cum_node += graph.num_node + num_cum_edge += graph.num_edge + num_cum_residue += graph.num_residue + num_graph += 1 + + data_dict = {k: torch.cat([torch.tensor(v) for v in lst]) for k, lst in data_dict.items()} 
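+        # rebuild the batch-level bookkeeping: node2graph maps every node to
+        # the index of the graph it came from within this batch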
+ + num_nodes = torch.cat(num_nodes) + num_edges = torch.cat(num_edges) + num_residues = torch.cat(num_residues) + node_positions = torch.cat(node_positions) + node2graph = torch.repeat_interleave(num_nodes) + num_node = torch.sum(num_nodes) + num_edge = torch.sum(num_edges) + + return Data( + num_nodes=num_nodes, num_node=num_node, num_edges=num_edges, num_edge=num_edge, num_residues=num_residues, num_relation=batch[0].num_relation, node_position=node_positions, + node2graph=node2graph, batch_size=len(batch), y=torch.stack(y), **data_dict) diff --git a/Geom3D/datasets/dataset_ECSingle.py b/Geom3D/datasets/dataset_ECSingle.py new file mode 100644 index 0000000..0a1c637 --- /dev/null +++ b/Geom3D/datasets/dataset_ECSingle.py @@ -0,0 +1,237 @@ +import os.path as osp +import numpy as np +import warnings +from tqdm import tqdm +from sklearn.preprocessing import normalize +import h5py + +import torch, math +import torch.nn.functional as F +import torch_cluster + +from torch_geometric.data import Data +from torch_geometric.data import InMemoryDataset + + +class DatasetECSingle(InMemoryDataset): + def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, split='train'): + self.split = split + self.root = root + + super(DatasetECSingle, self).__init__( + root, transform, pre_transform, pre_filter) + + self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def processed_dir(self): + name = 'processed_ECSingle' + return osp.join(self.root, name, self.split) + + @property + def raw_file_names(self): + name = self.split + '.txt' + return name + + @property + def processed_file_names(self): + return 'data.pt' + + def _normalize(self, tensor, dim=-1): + ''' + Normalizes a `torch.Tensor` along dimension `dim` without `nan`s. 
+ ''' + return torch.nan_to_num( + torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))) + + def get_key_atom_pos(self, amino_types, atom_names, atom_amino_id, atom_pos): + # atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1 + mask_n = np.char.equal(atom_names, b'N') + mask_ca = np.char.equal(atom_names, b'CA') + mask_c = np.char.equal(atom_names, b'C') + mask_cb = np.char.equal(atom_names, b'CB') + mask_g = np.char.equal(atom_names, b'CG') | np.char.equal(atom_names, b'SG') | np.char.equal(atom_names, b'OG') | np.char.equal(atom_names, b'CG1') | np.char.equal(atom_names, b'OG1') + mask_d = np.char.equal(atom_names, b'CD') | np.char.equal(atom_names, b'SD') | np.char.equal(atom_names, b'CD1') | np.char.equal(atom_names, b'OD1') | np.char.equal(atom_names, b'ND1') + mask_e = np.char.equal(atom_names, b'CE') | np.char.equal(atom_names, b'NE') | np.char.equal(atom_names, b'OE1') + mask_z = np.char.equal(atom_names, b'CZ') | np.char.equal(atom_names, b'NZ') + mask_h = np.char.equal(atom_names, b'NH1') + + pos_n = np.full((len(amino_types),3),np.nan) + pos_n[atom_amino_id[mask_n]] = atom_pos[mask_n] + pos_n = torch.FloatTensor(pos_n) + + pos_ca = np.full((len(amino_types),3),np.nan) + pos_ca[atom_amino_id[mask_ca]] = atom_pos[mask_ca] + pos_ca = torch.FloatTensor(pos_ca) + + pos_c = np.full((len(amino_types),3),np.nan) + pos_c[atom_amino_id[mask_c]] = atom_pos[mask_c] + pos_c = torch.FloatTensor(pos_c) + + # if data only contain pos_ca, we set the position of C and N as the position of CA + pos_n[torch.isnan(pos_n)] = pos_ca[torch.isnan(pos_n)] + pos_c[torch.isnan(pos_c)] = pos_ca[torch.isnan(pos_c)] + + pos_cb = np.full((len(amino_types),3),np.nan) + pos_cb[atom_amino_id[mask_cb]] = atom_pos[mask_cb] + pos_cb = torch.FloatTensor(pos_cb) + + pos_g = np.full((len(amino_types),3),np.nan) + pos_g[atom_amino_id[mask_g]] = atom_pos[mask_g] + pos_g = torch.FloatTensor(pos_g) + + pos_d = np.full((len(amino_types),3),np.nan) + pos_d[atom_amino_id[mask_d]] = atom_pos[mask_d] + pos_d = torch.FloatTensor(pos_d) + + pos_e = np.full((len(amino_types),3),np.nan) + pos_e[atom_amino_id[mask_e]] = atom_pos[mask_e] + pos_e = torch.FloatTensor(pos_e) + + pos_z = np.full((len(amino_types),3),np.nan) + pos_z[atom_amino_id[mask_z]] = atom_pos[mask_z] + pos_z = torch.FloatTensor(pos_z) + + pos_h = np.full((len(amino_types),3),np.nan) + pos_h[atom_amino_id[mask_h]] = atom_pos[mask_h] + pos_h = torch.FloatTensor(pos_h) + + return pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h + + def get_side_chain_angle_encoding(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h): + v1, v2, v3, v4, v5, v6, v7 = pos_ca - pos_n, pos_cb - pos_ca, pos_g - pos_cb, pos_d - pos_g, pos_e - pos_d, pos_z - pos_e, pos_h - pos_z + + # five side chain torsion angles + # We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0. 
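+        # each angle below is the dihedral spanned by three consecutive
+        # difference vectors along the side chain (angle5 is computed but
+        # dropped from the concatenation, per the note above)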
+ angle1 = torch.unsqueeze(self.diherals_ProNet(v1, v2, v3),1) + angle2 = torch.unsqueeze(self.diherals_ProNet(v2, v3, v4),1) + angle3 = torch.unsqueeze(self.diherals_ProNet(v3, v4, v5),1) + angle4 = torch.unsqueeze(self.diherals_ProNet(v4, v5, v6),1) + angle5 = torch.unsqueeze(self.diherals_ProNet(v5, v6, v7),1) + + side_chain_angles = torch.cat((angle1, angle2, angle3, angle4),1) + side_chain_embs = torch.cat((torch.sin(side_chain_angles), torch.cos(side_chain_angles)),1) + + return side_chain_embs + + def get_backbone_angle_encoding(self, X): + # X should be a num_residues x 3 x 3, order N, C-alpha, and C atoms of each residue + # N coords: X[:,0,:] + # CA coords: X[:,1,:] + # C coords: X[:,2,:] + # return num_residues x 6 + # From https://github.com/jingraham/neurips19-graph-protein-design + + X = torch.reshape(X, [3 * X.shape[0], 3]) + dX = X[1:] - X[:-1] + U = self._normalize(dX, dim=-1) + u0 = U[:-2] + u1 = U[1:-1] + u2 = U[2:] + + angle = self.diherals_ProNet(u0, u1, u2) + + # add phi[0], psi[-1], omega[-1] with value 0 + angle = F.pad(angle, [1, 2]) + angle = torch.reshape(angle, [-1, 3]) + angle_features = torch.cat([torch.cos(angle), torch.sin(angle)], 1) + return angle_features + + def diherals_ProNet(self, v1, v2, v3): + n1 = torch.cross(v1, v2) + n2 = torch.cross(v2, v3) + a = (n1 * n2).sum(dim=-1) + b = torch.nan_to_num((torch.cross(n1, n2) * v2).sum(dim=-1) / v2.norm(dim=1)) + torsion = torch.nan_to_num(torch.atan2(b, a)) + return torsion + + def extract_protein_data(self, pFilePath): + h5File = h5py.File(pFilePath+".hdf5", "r") + data = Data() + + amino_types = h5File['amino_types'][()] # residue or amino acid, size: (n_amino,) + mask = amino_types == -1 + if np.sum(mask) > 0: + amino_types[mask] = 25 # for amino acid types, set the value of -1 to 25 + atom_amino_id = h5File['atom_amino_id'][()] # size: (n_atom,) + atom_names = h5File['atom_names'][()] # size: (n_atom,) + atom_pos = h5File['atom_pos'][()][0] #size: (n_atom,3) + + ##### compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1 ##### + # extract key atom (e.g., backbone) positions + pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h = self.get_key_atom_pos(amino_types, atom_names, atom_amino_id, atom_pos) + + # calculate side chain torsion angles, up to four + # do encoding + side_chain_angle_encoding = self.get_side_chain_angle_encoding(pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h) + side_chain_angle_encoding[torch.isnan(side_chain_angle_encoding)] = 0 + + # three backbone torsion angles + backbone_angle_encoding = self.get_backbone_angle_encoding(torch.cat((torch.unsqueeze(pos_n,1), torch.unsqueeze(pos_ca,1), torch.unsqueeze(pos_c,1)),1)) + backbone_angle_encoding[torch.isnan(backbone_angle_encoding)] = 0 + + data.seq = torch.LongTensor(amino_types) + data.side_chain_angle_encoding = side_chain_angle_encoding + data.backbone_angle_encoding = backbone_angle_encoding + data.coords_ca = pos_ca + data.coords_n = pos_n + data.coords_c = pos_c + data.x = atom_names + data.atom_pos = torch.tensor(atom_pos) + data.num_nodes = len(pos_ca) + + h5File.close() + + return data + + def process(self): + print('Beginning Processing ...') + + # Load the file with the list of functions. + functions_ = [] + with open(self.root+"/unique_functions.txt", 'r') as mFile: + for line in mFile: + functions_.append(line.rstrip()) + + # Get the file list. 
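+        # map the split name to the corresponding list file in the data root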
+ if self.split == "Train": + splitFile = "/training.txt" + elif self.split == "Val": + splitFile = "/validation.txt" + elif self.split == "Test": + splitFile = "/testing.txt" + + proteinNames_ = [] + fileList_ = [] + with open(self.root+splitFile, 'r') as mFile: + for line in mFile: + proteinNames_.append(line.rstrip()) + fileList_.append(self.root+"/data/"+line.rstrip()) + print("current split {}, data size {}".format(self.split, len(fileList_))) + + # Load the functions. + print("Reading protein functions") + protFunct_ = {} + with open(self.root+"/chain_functions.txt", 'r') as mFile: + for line in mFile: + splitLine = line.rstrip().split(',') + if splitLine[0] in proteinNames_: + protFunct_[splitLine[0]] = int(splitLine[1]) + + # Load the dataset + print("Reading the data") + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + data_list = [] + for fileIter, curFile in tqdm(enumerate(fileList_)): + print(curFile) + fileName = curFile.split('/')[-1] + protein = self.extract_protein_data(curFile) + protein.id = fileName + protein.y = torch.tensor(protFunct_[proteinNames_[fileIter]]) + if not protein.seq is None: + data_list.append(protein) + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + print('Done!') diff --git a/Geom3D/datasets/dataset_ECSingle_GearNet.py b/Geom3D/datasets/dataset_ECSingle_GearNet.py new file mode 100644 index 0000000..8f968b3 --- /dev/null +++ b/Geom3D/datasets/dataset_ECSingle_GearNet.py @@ -0,0 +1,192 @@ +import os +import h5py +import torch +import warnings +from tqdm import tqdm +from collections import defaultdict +import os.path as osp +import copy + +import torch.nn.functional as F + +from torch_geometric.data import Data +from torch_geometric.data import InMemoryDataset + + +class DatasetECSingleGearNet(InMemoryDataset): + def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, split='train'): + self.split = split + self.root = root + + super(DatasetECSingleGearNet, self).__init__( + root, transform, pre_transform, pre_filter) + + self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def processed_dir(self): + name = 'processed_GearNet_ECSingle' + return osp.join(self.root, name, self.split) + + @property + def raw_file_names(self): + name = self.split + '.txt' + return name + + @property + def processed_file_names(self): + return 'data.pt' + + def protein_to_graph(self, pFilePath, graph_construction_model): + from torchdrug import data + + h5File = h5py.File(pFilePath, "r") + + node_position = torch.as_tensor(h5File["atom_pos"][(0)]) + num_atom = node_position.shape[0] + atom_type = torch.as_tensor(h5File["atom_types"][()]) + atom_name = h5File["atom_names"][()] + atom_name = torch.as_tensor([data.Protein.atom_name2id.get(name.decode(), -1) for name in atom_name]) + atom2residue = torch.as_tensor(h5File["atom_residue_id"][()]) + residue_type_name = h5File["atom_residue_names"][()] + residue_type = [] + residue_feature = [] + lst_residue = -1 + for i in range(num_atom): + if atom2residue[i] != lst_residue: + residue_type.append(data.Protein.residue2id.get(residue_type_name[i].decode(), 0)) + residue_feature.append(data.feature.onehot(residue_type_name[i].decode(), data.feature.residue_vocab, allow_unknown=True)) + lst_residue = atom2residue[i] + residue_type = torch.as_tensor(residue_type) + residue_feature = torch.as_tensor(residue_feature) + num_residue = 
residue_type.shape[0] + + edge_list = torch.cat([ + torch.as_tensor(h5File["cov_bond_list"][()]), + torch.as_tensor(h5File["cov_bond_list_hb"][()]) + ], dim=0) + bond_type = torch.zeros(edge_list.shape[0], dtype=torch.long) + edge_list = torch.cat([edge_list, bond_type.unsqueeze(-1)], dim=-1) + + protein = data.Protein( + edge_list, atom_type, bond_type, num_node=num_atom, num_residue=num_residue, + node_position=node_position, atom_name=atom_name, + atom2residue=atom2residue, residue_feature=residue_feature, + residue_type=residue_type) + + protein = data.Protein.pack([protein]) + protein = graph_construction_model(protein) + + return_data = Data() + return_data.edge_list = protein.edge_list + return_data.edge_weight = torch.ones(len(protein.edge_list)) + return_data.num_residue = protein.num_residue + return_data.num_node = protein.num_node + return_data.num_edge = protein.num_edge + return_data.x = residue_type # This is important to hack the code + return_data.node_feature = residue_feature + return_data.num_relation = protein.num_relation + return return_data + + def process(self): + print('Beginning Processing ...') + # This requires the installment of TorchDrug + + from torchdrug import transforms, layers + from torchdrug.layers import geometry + + graph_construction_model = layers.GraphConstruction( + node_layers=[geometry.AlphaCarbonNode()], + edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5), geometry.KNNEdge(k=10, min_distance=5), geometry.SequentialEdge(max_distance=2)], + edge_feature="gearnet") + + # Load the file with the list of functions. + functions_ = [] + with open(self.root+"/unique_functions.txt", 'r') as mFile: + for line in mFile: + functions_.append(line.rstrip()) + + # Get the file list. + if self.split == "Train": + splitFile = "/training.txt" + elif self.split == "Val": + splitFile = "/validation.txt" + elif self.split == "Test": + splitFile = "/testing.txt" + + proteinNames_ = [] + fileList_ = [] + with open(self.root+splitFile, 'r') as mFile: + for line in mFile: + proteinNames_.append(line.rstrip()) + fileList_.append(self.root+"/data/"+line.rstrip()) + + # Load the functions. 
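+        # chain_functions.txt maps each chain name to an integer function
+        # label; only chains present in this split's list are kept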
+ print("Reading protein functions") + protFunct_ = {} + with open(self.root+"/chain_functions.txt", 'r') as mFile: + for line in mFile: + splitLine = line.rstrip().split(',') + if splitLine[0] in proteinNames_: + protFunct_[splitLine[0]] = int(splitLine[1]) + + # Load the dataset + print("Reading the data") + print(self.split) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + data_list = [] + for fileIter, curFile in tqdm(enumerate(fileList_)): + #print(curFile) + fileName = curFile.split('/')[-1] + curProtein = self.protein_to_graph(curFile+".hdf5", graph_construction_model=graph_construction_model) + #curProtein.id = fileName + curProtein.y = torch.tensor(protFunct_[proteinNames_[fileIter]]) + if not curProtein.x is None: + data_list.append(curProtein) + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + print('Done!') + + def collate_fn(batch): + num_nodes = [] + num_edges = [] + num_residues = [] + num_cum_node = 0 + num_cum_edge = 0 + num_cum_residue = 0 + num_graph = 0 + data_dict = defaultdict(list) + for graph in batch: + num_nodes.append(graph.num_node) + num_edges.append(graph.num_edge) + num_residues.append(graph.num_residue) + for k, v in graph.items(): + if k in ["num_relation", "num_node", "num_edge", "num_residue"]: + continue + elif k in ["edge_list"]: + neo_v = v.clone() + neo_v[:, 0] += num_cum_node + neo_v[:, 1] += num_cum_node + data_dict[k].append(neo_v) + continue + + data_dict[k].append(v) + num_cum_node += graph.num_node + num_cum_edge += graph.num_edge + num_cum_residue += graph.num_residue + num_graph += 1 + + data_dict = {k: torch.cat(v) for k, v in data_dict.items()} + + num_nodes = torch.cat(num_nodes) + num_edges = torch.cat(num_edges) + num_residues = torch.cat(num_residues) + node2graph = torch.repeat_interleave(num_nodes) + num_node = torch.sum(num_nodes) + num_edge = torch.sum(num_edges) + + return Data( + num_node=num_node, num_nodes=num_nodes, num_edge=num_edge, num_edges=num_edges, num_residues=num_residues, num_relation=batch[0].num_relation, + node2graph=node2graph, batch_size=len(batch), **data_dict) diff --git a/Geom3D/datasets/dataset_FOLD.py b/Geom3D/datasets/dataset_FOLD.py index b1b179a..af8f68c 100644 --- a/Geom3D/datasets/dataset_FOLD.py +++ b/Geom3D/datasets/dataset_FOLD.py @@ -1,10 +1,11 @@ -# Credit to https://github.com/divelab/DIG/blob/dig-stable/dig/threedgraph/dataset/ECdataset.py import os.path as osp import numpy as np import warnings from tqdm import tqdm +from sklearn.preprocessing import normalize +import h5py -import torch +import torch import torch.nn.functional as F from torch_geometric.data import Data @@ -15,7 +16,7 @@ class DatasetFOLD(InMemoryDataset): def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, split='train'): self.split = split self.root = root - + super(DatasetFOLD, self).__init__( root, transform, pre_transform, pre_filter) @@ -24,7 +25,7 @@ def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, sp @property def processed_dir(self): - name = 'processed' + name = 'processed_FOLD' return osp.join(self.root, name, self.split) @property @@ -35,15 +36,15 @@ def raw_file_names(self): @property def processed_file_names(self): return 'data.pt' - - def _normalize(self,tensor, dim=-1): + + def _normalize(self, tensor, dim=-1): ''' Normalizes a `torch.Tensor` along dimension `dim` without `nan`s. 
''' return torch.nan_to_num( torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))) - def get_atom_pos(self, amino_types, atom_names, atom_amino_id, atom_pos): + def get_key_atom_pos(self, amino_types, atom_names, atom_amino_id, atom_pos): # atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1 mask_n = np.char.equal(atom_names, b'N') mask_ca = np.char.equal(atom_names, b'CA') @@ -96,24 +97,24 @@ def get_atom_pos(self, amino_types, atom_names, atom_amino_id, atom_pos): pos_h = torch.FloatTensor(pos_h) return pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h - - def side_chain_embs(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h): + + def get_side_chain_angle_encoding(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h): v1, v2, v3, v4, v5, v6, v7 = pos_ca - pos_n, pos_cb - pos_ca, pos_g - pos_cb, pos_d - pos_g, pos_e - pos_d, pos_z - pos_e, pos_h - pos_z # five side chain torsion angles # We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0. - angle1 = torch.unsqueeze(self.compute_diherals(v1, v2, v3),1) - angle2 = torch.unsqueeze(self.compute_diherals(v2, v3, v4),1) - angle3 = torch.unsqueeze(self.compute_diherals(v3, v4, v5),1) - angle4 = torch.unsqueeze(self.compute_diherals(v4, v5, v6),1) - angle5 = torch.unsqueeze(self.compute_diherals(v5, v6, v7),1) + angle1 = torch.unsqueeze(self.diherals_ProNet(v1, v2, v3),1) + angle2 = torch.unsqueeze(self.diherals_ProNet(v2, v3, v4),1) + angle3 = torch.unsqueeze(self.diherals_ProNet(v3, v4, v5),1) + angle4 = torch.unsqueeze(self.diherals_ProNet(v4, v5, v6),1) + angle5 = torch.unsqueeze(self.diherals_ProNet(v5, v6, v7),1) side_chain_angles = torch.cat((angle1, angle2, angle3, angle4),1) side_chain_embs = torch.cat((torch.sin(side_chain_angles), torch.cos(side_chain_angles)),1) return side_chain_embs - - def bb_embs(self, X): + + def get_backbone_angle_encoding(self, X): # X should be a num_residues x 3 x 3, order N, C-alpha, and C atoms of each residue # N coords: X[:,0,:] # CA coords: X[:,1,:] @@ -128,7 +129,7 @@ def bb_embs(self, X): u1 = U[1:-1] u2 = U[2:] - angle = self.compute_diherals(u0, u1, u2) + angle = self.diherals_ProNet(u0, u1, u2) # add phi[0], psi[-1], omega[-1] with value 0 angle = F.pad(angle, [1, 2]) @@ -136,7 +137,7 @@ def bb_embs(self, X): angle_features = torch.cat([torch.cos(angle), torch.sin(angle)], 1) return angle_features - def compute_diherals(self, v1, v2, v3): + def diherals_ProNet(self, v1, v2, v3): n1 = torch.cross(v1, v2) n2 = torch.cross(v2, v3) a = (n1 * n2).sum(dim=-1) @@ -144,12 +145,11 @@ def compute_diherals(self, v1, v2, v3): torsion = torch.nan_to_num(torch.atan2(b, a)) return torsion - def protein_to_graph(self, pFilePath): - import h5py - h5File = h5py.File(pFilePath, "r") + def extract_protein_data(self, pFilePath): + h5File = h5py.File(pFilePath+".hdf5", "r") data = Data() - - amino_types = h5File['amino_types'][()] # size: (n_amino,) + + amino_types = h5File['amino_types'][()] # residue or amino acid, size: (n_amino,) mask = amino_types == -1 if np.sum(mask) > 0: amino_types[mask] = 25 # for amino acid types, set the value of -1 to 25 @@ -157,31 +157,34 @@ def protein_to_graph(self, pFilePath): atom_names = h5File['atom_names'][()] # size: (n_atom,) atom_pos = h5File['atom_pos'][()][0] #size: (n_atom,3) - # atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1 - 
pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h = self.get_atom_pos(amino_types, atom_names, atom_amino_id, atom_pos) - - # five side chain torsion angles - # We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0. - side_chain_embs = self.side_chain_embs(pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h) - side_chain_embs[torch.isnan(side_chain_embs)] = 0 - data.side_chain_embs = side_chain_embs + ##### compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1 ##### + # extract key atom (e.g., backbone) positions + pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h = self.get_key_atom_pos(amino_types, atom_names, atom_amino_id, atom_pos) + + # calculate side chain torsion angles, up to four + # do encoding + side_chain_angle_encoding = self.get_side_chain_angle_encoding(pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h) + side_chain_angle_encoding[torch.isnan(side_chain_angle_encoding)] = 0 # three backbone torsion angles - bb_embs = self.bb_embs(torch.cat((torch.unsqueeze(pos_n,1), torch.unsqueeze(pos_ca,1), torch.unsqueeze(pos_c,1)),1)) - bb_embs[torch.isnan(bb_embs)] = 0 - data.bb_embs = bb_embs + backbone_angle_encoding = self.get_backbone_angle_encoding(torch.cat((torch.unsqueeze(pos_n,1), torch.unsqueeze(pos_ca,1), torch.unsqueeze(pos_c,1)),1)) + backbone_angle_encoding[torch.isnan(backbone_angle_encoding)] = 0 - data.x = torch.unsqueeze(torch.tensor(amino_types),1) + data.seq = torch.LongTensor(amino_types) + data.side_chain_angle_encoding = side_chain_angle_encoding + data.backbone_angle_encoding = backbone_angle_encoding data.coords_ca = pos_ca data.coords_n = pos_n data.coords_c = pos_c - - assert len(data.x)==len(data.coords_ca)==len(data.coords_n)==len(data.coords_c)==len(data.side_chain_embs)==len(data.bb_embs) + data.x = atom_names + data.pos = torch.tensor(atom_pos) + data.num_nodes = len(pos_ca) h5File.close() + return data - def process(self): + def process(self): print('Beginning Processing ...') # Load the file with the list of functions. @@ -194,12 +197,13 @@ def process(self): # Get the file list. 
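+        # Assumed layout of <split>.txt, inferred from the parsing below:
+        # tab-separated lines whose first column is the HDF5 file name and
+        # whose last column is the fold name, mapped to an integer label
+        # through classes_.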
fileList_ = [] cathegories_ = [] + with open(self.root+"/"+self.split+".txt", 'r') as mFile: for curLine in mFile: - splitLine = curLine.rstrip().split('\t') - curClass = classes_[splitLine[-1]] - fileList_.append(self.root+"/"+self.split+"/"+splitLine[0]) - cathegories_.append(curClass) + splitLine = curLine.rstrip().split('\t') + curClass = classes_[splitLine[-1]] + fileList_.append(self.root+"/"+self.split+"/"+splitLine[0]) + cathegories_.append(curClass) # Load the dataset print("Reading the data") @@ -208,18 +212,11 @@ def process(self): data_list = [] for fileIter, curFile in tqdm(enumerate(fileList_)): fileName = curFile.split('/')[-1] - curProtein = self.protein_to_graph(curFile+".hdf5") - curProtein.id = fileName - curProtein.y = torch.tensor(cathegories_[fileIter]) - if not curProtein.x is None: - data_list.append(curProtein) + protein = self.extract_protein_data(curFile) + protein.id = fileName + protein.y = torch.tensor(cathegories_[fileIter]) + if not protein.seq is None: + data_list.append(protein) data, slices = self.collate(data_list) torch.save((data, slices), self.processed_paths[0]) - print('Done!') - - -if __name__ == "__main__": - for split in ['training', 'validation', 'test_fold', 'test_superfamily', 'test_family']: - print('#### Now processing {} data ####'.format(split)) - dataset = DatasetFOLD(root='../../data/FOLD', split=split) - print(dataset) \ No newline at end of file + print('Done!') \ No newline at end of file diff --git a/Geom3D/datasets/dataset_FOLD_CDConv.py b/Geom3D/datasets/dataset_FOLD_CDConv.py deleted file mode 100644 index db2149b..0000000 --- a/Geom3D/datasets/dataset_FOLD_CDConv.py +++ /dev/null @@ -1,136 +0,0 @@ -#Credit to https://github.com/hehefan/Continuous-Discrete-Convolution/blob/main/datasets.py - -import numpy as np -from sklearn.preprocessing import normalize -import torch -from torch.utils.data import Dataset -from torch_geometric.data import Data -import os - -def orientation(pos): - u = normalize(X=pos[1:,:] - pos[:-1,:], norm='l2', axis=1) - u1 = u[1:,:] - u2 = u[:-1, :] - b = normalize(X=u2 - u1, norm='l2', axis=1) - n = normalize(X=np.cross(u2, u1), norm='l2', axis=1) - o = normalize(X=np.cross(b, n), norm='l2', axis=1) - ori = np.stack([b, n, o], axis=1) - return np.concatenate([np.expand_dims(ori[0], 0), ori, np.expand_dims(ori[-1], 0)], axis=0) - -def fmax(probs, labels): - thresholds = np.arange(0, 1, 0.01) - f_max = 0.0 - - for threshold in thresholds: - precision = 0.0 - recall = 0.0 - precision_cnt = 0 - recall_cnt = 0 - for idx in range(probs.shape[0]): - prob = probs[idx] - label = labels[idx] - pred = (prob > threshold).astype(np.int32) - correct_sum = np.sum(label*pred) - pred_sum = np.sum(pred) - label_sum = np.sum(label) - if pred_sum > 0: - precision += correct_sum/pred_sum - precision_cnt += 1 - if label_sum > 0: - recall += correct_sum/label_sum - recall_cnt += 1 - if recall_cnt > 0: - recall = recall / recall_cnt - else: - recall = 0 - if precision_cnt > 0: - precision = precision / precision_cnt - else: - precision = 0 - f = (2.*precision*recall)/max(precision+recall, 1e-8) - f_max = max(f, f_max) - - return f_max - - -# AA Letter to id -aa = "ACDEFGHIKLMNPQRSTVWYX" -aa_to_id = {} -for i in range(0, 21): - aa_to_id[aa[i]] = i - -class DatasetFOLD_CDConv(Dataset): - - def __init__(self, root='/content/drive/MyDrive/proteinDT/fold', random_seed=0, split='training'): - - self.random_state = np.random.RandomState(random_seed) - self.split = split - - npy_dir = os.path.join(root, 'coordinates', split) - fasta_file = 
os.path.join(root, split+'.fasta') - - # Load the fasta file. - protein_seqs = [] - with open(fasta_file, 'r') as f: - protein_name = '' - for line in f: - if line.startswith('>'): - protein_name = line.rstrip()[1:] - else: - amino_chain = line.rstrip() - amino_ids = [] - for amino in amino_chain: - amino_ids.append(aa_to_id[amino]) - protein_seqs.append((protein_name, np.array(amino_ids))) - - fold_classes = {} - with open(os.path.join(root, 'class_map.txt'), 'r') as f: - for line in f: - arr = line.rstrip().split('\t') - fold_classes[arr[0]] = int(arr[1]) - - protein_folds = {} - with open(os.path.join(root, split+'.txt'), 'r') as f: - for line in f: - arr = line.rstrip().split('\t') - protein_folds[arr[0]] = fold_classes[arr[-1]] - - self.data = [] - self.labels = [] - for protein_name, amino_ids in protein_seqs: - pos = np.load(os.path.join(npy_dir, protein_name+".npy")) - - center = np.sum(a=pos, axis=0, keepdims=True)/pos.shape[0] - pos = pos - center - ori = orientation(pos) - - self.data.append((pos, ori, amino_ids.astype(int))) - - self.labels.append(protein_folds[protein_name]) - - self.num_classes = max(self.labels) + 1 - - def __len__(self): - return len(self.data) - - def __getitem__(self, idx): - - pos, ori, amino = self.data[idx] - label = self.labels[idx] - - if self.split == "training": - pos = pos + self.random_state.normal(0.0, 0.05, pos.shape) - - pos = pos.astype(dtype=np.float32) - ori = ori.astype(dtype=np.float32) - seq = np.expand_dims(a=np.arange(pos.shape[0]), axis=1).astype(dtype=np.float32) - - data = Data(x = torch.from_numpy(amino), # [num_nodes, num_node_features] - edge_index = None, # [2, num_edges] - edge_attr = None, # [num_edges, num_edge_features] - y = label, - ori = torch.from_numpy(ori), # [num_nodes, 3, 3] - seq = torch.from_numpy(seq), # [num_nodes, 1] - pos = torch.from_numpy(pos)) # [num_nodes, num_dimensions] - - return data diff --git a/Geom3D/datasets/dataset_FOLD_GearNet.py b/Geom3D/datasets/dataset_FOLD_GearNet.py index 51ba4f9..4c5f0d1 100644 --- a/Geom3D/datasets/dataset_FOLD_GearNet.py +++ b/Geom3D/datasets/dataset_FOLD_GearNet.py @@ -26,7 +26,7 @@ def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, sp @property def processed_dir(self): - name = 'processed_GearNet' + name = 'processed_FOLD_GearNet' return osp.join(self.root, name, self.split) @property @@ -174,5 +174,4 @@ def collate_fn(batch): return Data( num_node=num_node, num_nodes=num_nodes, num_edge=num_edge, num_edges=num_edges, num_residues=num_residues, num_relation=batch[0].num_relation, - node2graph=node2graph, **data_dict) - \ No newline at end of file + node2graph=node2graph, batch_size=len(batch), **data_dict) diff --git a/Geom3D/datasets/dataset_GO.py b/Geom3D/datasets/dataset_GO.py new file mode 100644 index 0000000..e127a9e --- /dev/null +++ b/Geom3D/datasets/dataset_GO.py @@ -0,0 +1,454 @@ +import os.path as osp +import os +import numpy as np +import warnings +from tqdm import tqdm +from sklearn.preprocessing import normalize +import h5py + +import torch, math +import torch.nn.functional as F +import torch_cluster + +from Bio.PDB import PDBParser +from Bio.PDB.Polypeptide import three_to_one, is_aa +import sys +import Bio.PDB +import Bio.PDB.StructureBuilder +from Bio.PDB.Residue import Residue + +from torch_geometric.data import Data +from torch_geometric.data import InMemoryDataset + + +class SloppyStructureBuilder(Bio.PDB.StructureBuilder.StructureBuilder): + """Cope with resSeq < 10,000 limitation by just incrementing internally.""" + + def 
__init__(self, verbose=False):
+        Bio.PDB.StructureBuilder.StructureBuilder.__init__(self)
+        self.max_resseq = -1
+        self.verbose = verbose
+
+    def init_residue(self, resname, field, resseq, icode):
+        """Initiate a new Residue object.
+        Arguments:
+            resname: string, e.g. "ASN"
+            field: hetero flag, "W" for waters, "H" for hetero residues, otherwise blank.
+            resseq: int, sequence identifier
+            icode: string, insertion code
+        Return:
+            None
+        """
+        if field != " ":
+            if field == "H":
+                # The hetero field consists of
+                # H_ + the residue name (e.g. H_FUC)
+                field = "H_" + resname
+        res_id = (field, resseq, icode)
+
+        if resseq > self.max_resseq:
+            self.max_resseq = resseq
+
+        if field == " ":
+            fudged_resseq = False
+            while self.chain.has_id(res_id) or resseq == 0:
+                # There already is a residue with the id (field, resseq, icode)
+                # resseq == 0 catches already wrapped residue numbers which
+                # do not trigger the has_id() test.
+                #
+                # Be sloppy and just increment...
+                # (This code will not leave gaps in resids... I think)
+                #
+                # XXX: shouldn't we also do this for hetero atoms and water??
+                self.max_resseq += 1
+                resseq = self.max_resseq
+                res_id = (field, resseq, icode)  # use max_resseq!
+                fudged_resseq = True
+
+            if fudged_resseq and self.verbose:
+                sys.stderr.write(
+                    "Residues are wrapping (Residue "
+                    + "('%s', %i, '%s') at line %i)."
+                    % (field, resseq, icode, self.line_counter)
+                    + ".... assigning new resid %d.\n" % self.max_resseq
+                )
+        residue = Residue(res_id, resname, self.segid)
+        self.chain.add(residue)
+        self.residue = residue
+        return None
+
+
+
+class DatasetGO(InMemoryDataset):
+    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, split='train', level="mf", percent=0.3):
+        self.split = split
+        self.root = root
+        self.level = level
+        self.percent = percent
+
+        self.letter_to_num = {'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9,
+                              'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8,
+                              'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19,
+                              'N': 2, 'Y': 18, 'M': 12, "X": 20}
+
+        super(DatasetGO, self).__init__(
+            root, transform, pre_transform, pre_filter)
+
+        self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter
+        self.data, self.slices = torch.load(self.processed_paths[0])
+
+
+    @property
+    def processed_dir(self):
+        name = 'processed_GO_' + self.level
+        if self.split != "test":
+            return osp.join(self.root, name, self.split)
+        else:
+            return osp.join(self.root, name, self.split + "_" + str(self.percent))
+
+    @property
+    def raw_file_names(self):
+        name = self.split + '.txt'
+        return name
+
+    @property
+    def processed_file_names(self):
+        return 'data.pt'
+
+    def get_side_chain_angle_encoding(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h):
+        v1, v2, v3, v4, v5, v6, v7 = pos_ca - pos_n, pos_cb - pos_ca, pos_g - pos_cb, pos_d - pos_g, pos_e - pos_d, pos_z - pos_e, pos_h - pos_z
+
+        # five side chain torsion angles
+        # We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0.
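+        # Each angle_k below is the dihedral defined by three consecutive
+        # difference vectors (v_k, v_{k+1}, v_{k+2}); the four retained
+        # angles are lifted to the circle as (sin, cos) pairs, yielding an
+        # 8-dimensional side-chain encoding per residue.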
+        angle1 = torch.unsqueeze(self.diherals_ProNet(v1, v2, v3),1)
+        angle2 = torch.unsqueeze(self.diherals_ProNet(v2, v3, v4),1)
+        angle3 = torch.unsqueeze(self.diherals_ProNet(v3, v4, v5),1)
+        angle4 = torch.unsqueeze(self.diherals_ProNet(v4, v5, v6),1)
+        angle5 = torch.unsqueeze(self.diherals_ProNet(v5, v6, v7),1)
+
+        side_chain_angles = torch.cat((angle1, angle2, angle3, angle4),1)
+        side_chain_embs = torch.cat((torch.sin(side_chain_angles), torch.cos(side_chain_angles)),1)
+
+        return side_chain_embs
+
+    def get_backbone_angle_encoding(self, X):
+        # X should be a num_residues x 3 x 3, order N, C-alpha, and C atoms of each residue
+        # N coords: X[:,0,:]
+        # CA coords: X[:,1,:]
+        # C coords: X[:,2,:]
+        # return num_residues x 6
+        # From https://github.com/jingraham/neurips19-graph-protein-design
+
+        X = torch.reshape(X, [3 * X.shape[0], 3])
+        dX = X[1:] - X[:-1]
+        U = self._normalize(dX, dim=-1)
+        u0 = U[:-2]
+        u1 = U[1:-1]
+        u2 = U[2:]
+
+        angle = self.diherals_ProNet(u0, u1, u2)
+
+        # add phi[0], psi[-1], omega[-1] with value 0
+        angle = F.pad(angle, [1, 2])
+        angle = torch.reshape(angle, [-1, 3])
+        angle_features = torch.cat([torch.cos(angle), torch.sin(angle)], 1)
+        return angle_features
+
+    def diherals_ProNet(self, v1, v2, v3):
+        n1 = torch.cross(v1, v2)
+        n2 = torch.cross(v2, v3)
+        a = (n1 * n2).sum(dim=-1)
+        b = torch.nan_to_num((torch.cross(n1, n2) * v2).sum(dim=-1) / v2.norm(dim=1))
+        torsion = torch.nan_to_num(torch.atan2(b, a))
+        return torsion
+
+    def _normalize(self, tensor, dim=-1):
+        '''
+        Normalizes a `torch.Tensor` along dimension `dim` without `nan`s.
+        '''
+        return torch.nan_to_num(
+            torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True)))
+
+    def three_to_one_standard(self, res):
+        if not is_aa(res, standard=True):
+            return "X"
+
+        return three_to_one(res)
+
+    def chain_info(self, chain, name):
+        """Convert a PDB chain into per-residue coordinates of its key atoms.
+
+        Args:
+            chain: a Bio.PDB.Chain object.
+            name: String. Name of the protein.
+        Returns:
+            Dictionary with the key-atom positions (N, CA, C, CB and the side-chain atoms), the raw atom names and positions, and the residue types; None if the chain contains at most one AA.
+ + """ + atom_names, atom_amino_id, atom_pos, residue_types = [], [], [], [] + pdb_seq = "" + residue_index = 0 + for residue in chain.get_residues(): + if is_aa(residue) and any(atom.get_name() == "CA" for atom in residue.get_atoms()): + residue_name = self.three_to_one_standard(residue.get_resname()) + pdb_seq += residue_name + residue_types.append(self.letter_to_num[residue_name]) + + for atom in residue.get_atoms(): + atom_names.append(atom.get_name()) + atom_amino_id.append(residue_index) + atom_pos.append(atom.coord) + + residue_index += 1 + + mask_n = np.char.equal(atom_names, 'N') + mask_ca = np.char.equal(atom_names, 'CA') + mask_c = np.char.equal(atom_names, 'C') + mask_cb = np.char.equal(atom_names, 'CB') + mask_g = np.char.equal(atom_names, 'CG') | np.char.equal(atom_names, 'SG') | np.char.equal(atom_names, 'OG') | np.char.equal(atom_names, 'CG1') | np.char.equal(atom_names, 'OG1') + mask_d = np.char.equal(atom_names, 'CD') | np.char.equal(atom_names, 'SD') | np.char.equal(atom_names, 'CD1') | np.char.equal(atom_names, 'OD1') | np.char.equal(atom_names, 'ND1') + mask_e = np.char.equal(atom_names, 'CE') | np.char.equal(atom_names, 'NE') | np.char.equal(atom_names, 'OE1') + mask_z = np.char.equal(atom_names, 'CZ') | np.char.equal(atom_names, 'NZ') + mask_h = np.char.equal(atom_names, 'NH1') + + atom_amino_id = np.array(atom_amino_id) + atom_pos = np.array(atom_pos) + + pos_n = np.full((len(pdb_seq), 3),np.nan) + pos_n[atom_amino_id[mask_n]] = atom_pos[mask_n] + pos_n = torch.FloatTensor(pos_n) + + pos_ca = np.full((len(pdb_seq), 3),np.nan) + pos_ca[atom_amino_id[mask_ca]] = atom_pos[mask_ca] + pos_ca = torch.FloatTensor(pos_ca) + + pos_c = np.full((len(pdb_seq), 3),np.nan) + pos_c[atom_amino_id[mask_c]] = atom_pos[mask_c] + pos_c = torch.FloatTensor(pos_c) + + # if data only contain pos_ca, we set the position of C and N as the position of CA + pos_n[torch.isnan(pos_n)] = pos_ca[torch.isnan(pos_n)] + pos_c[torch.isnan(pos_c)] = pos_ca[torch.isnan(pos_c)] + + pos_cb = np.full((len(pdb_seq), 3),np.nan) + pos_cb[atom_amino_id[mask_cb]] = atom_pos[mask_cb] + pos_cb = torch.FloatTensor(pos_cb) + + pos_g = np.full((len(pdb_seq), 3),np.nan) + pos_g[atom_amino_id[mask_g]] = atom_pos[mask_g] + pos_g = torch.FloatTensor(pos_g) + + pos_d = np.full((len(pdb_seq), 3),np.nan) + pos_d[atom_amino_id[mask_d]] = atom_pos[mask_d] + pos_d = torch.FloatTensor(pos_d) + + pos_e = np.full((len(pdb_seq), 3),np.nan) + pos_e[atom_amino_id[mask_e]] = atom_pos[mask_e] + pos_e = torch.FloatTensor(pos_e) + + pos_z = np.full((len(pdb_seq), 3),np.nan) + pos_z[atom_amino_id[mask_z]] = atom_pos[mask_z] + pos_z = torch.FloatTensor(pos_z) + + pos_h = np.full((len(pdb_seq), 3),np.nan) + pos_h[atom_amino_id[mask_h]] = atom_pos[mask_h] + pos_h = torch.FloatTensor(pos_h) + + chain_struc = { + 'name': name, + 'pos_n': pos_n, + 'pos_ca': pos_ca, + 'pos_c': pos_c, + 'pos_cb': pos_cb, + 'pos_g': pos_g, + 'pos_d': pos_d, + 'pos_e': pos_e, + 'pos_z': pos_z, + 'pos_h': pos_h, + 'atom_names': atom_names, + 'atom_pos': atom_pos, + 'residue_types': residue_types + } + + if len(pdb_seq) <= 1: + # has no or only 1 AA in the chain + return None + + return chain_struc + + + def extract_protein_data(self, pFilePath): + data = Data() + + pdb_parser = PDBParser( + QUIET=True, + PERMISSIVE=True, + structure_builder=SloppyStructureBuilder(), + ) + + name = os.path.basename(pFilePath).split("_")[0] + + try: + structure = pdb_parser.get_structure(name, pFilePath) + except Exception as e: + print(pFilePath, "raised an error:") + print(e) + 
return None + + records = [] + chain_ids = [] + + for chain in structure.get_chains(): + if chain.id in chain_ids: # skip duplicated chains + continue + chain_ids.append(chain.id) + record = self.chain_info(chain, "{}-{}".format(name.split("-")[0], chain.id)) + if record is not None: + records.append(record) + + records = [rec for rec in records if rec["name"] in self.data] + + for i in records: + if i["name"] == name: + pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h, atom_names, atom_pos, residue_types = (i[k] for k in ["pos_n", "pos_ca", "pos_c", "pos_cb", "pos_g", "pos_d", "pos_e", "pos_z", "pos_h", "atom_names", "atom_pos", "residue_types"]) + + # calculate side chain torsion angles, up to four + # do encoding + side_chain_angle_encoding = self.get_side_chain_angle_encoding(pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h) + side_chain_angle_encoding[torch.isnan(side_chain_angle_encoding)] = 0 + + # three backbone torsion angles + backbone_angle_encoding = self.get_backbone_angle_encoding(torch.cat((torch.unsqueeze(pos_n,1), torch.unsqueeze(pos_ca,1), torch.unsqueeze(pos_c,1)),1)) + backbone_angle_encoding[torch.isnan(backbone_angle_encoding)] = 0 + + data.seq = torch.LongTensor(residue_types) + data.side_chain_angle_encoding = side_chain_angle_encoding + data.backbone_angle_encoding = backbone_angle_encoding + data.coords_ca = pos_ca + data.coords_n = pos_n + data.coords_c = pos_c + data.x = atom_names + data.atom_pos = torch.tensor(atom_pos) + data.num_nodes = len(pos_ca) + + return data + + def process(self): + print('Beginning Processing ...') + + if self.split != "test": + with open(os.path.join(self.root, f"nrPDB-GO_{self.split}.txt"), 'r') as file: + self.data = set([line.strip() for line in file]) + else: + self.data = set() + with open(os.path.join(self.root, "nrPDB-GO_test.csv"), 'r') as f: + head = True + for line in f: + if head: + head = False + continue + arr = line.rstrip().split(',') + if self.percent == 0.3 and arr[1] == '1': + self.data.add(arr[0]) + elif self.percent == 0.4 and arr[2] == '1': + self.data.add(arr[0]) + elif self.percent == 0.5 and arr[3] == '1': + self.data.add(arr[0]) + elif self.percent == 0.7 and arr[4] == '1': + self.data.add(arr[0]) + elif self.percent == 0.95 and arr[5] == '1': + self.data.add(arr[0]) + else: + pass + + + # 2. 
Parse the structure files and save to json files + structure_file_dir = osp.join( + self.root, f"{self.split}" + ) + files = os.listdir(structure_file_dir) + + level_idx = 0 + go_cnt = 0 + go_num = {} + go_annotations = {} + self.labels = {} + with open(osp.join(self.root, 'nrPDB-GO_annot.tsv'), 'r') as f: + for idx, line in enumerate(f): + if idx == 1 and self.level == "mf": + level_idx = 1 + arr = line.rstrip().split('\t') + for go in arr: + go_annotations[go] = go_cnt + go_num[go] = 0 + go_cnt += 1 + elif idx == 5 and self.level == "bp": + level_idx = 2 + arr = line.rstrip().split('\t') + for go in arr: + go_annotations[go] = go_cnt + go_num[go] = 0 + go_cnt += 1 + elif idx == 9 and self.level == "cc": + level_idx = 3 + arr = line.rstrip().split('\t') + for go in arr: + go_annotations[go] = go_cnt + go_num[go] = 0 + go_cnt += 1 + elif idx > 12: + arr = line.rstrip().split('\t') + protein_labels = [] + if len(arr) > level_idx: + protein_go_list = arr[level_idx] + protein_go_list = protein_go_list.split(',') + for go in protein_go_list: + if len(go) > 0: + protein_labels.append(go_annotations[go]) + go_num[go] += 1 + self.labels[arr[0]] = np.array(protein_labels) + + self.num_class = len(go_annotations) + + invalid_PDB_file_name_list = ["1X18-E_5719.pdb", "2UV2-A_11517.pdb", "1EIS-A_990.pdb", "4UPV-Q_24858.pdb", "1DIN-A_746.pdb"] + + data_list = [] + for i in tqdm(range(len(files))): + if files[i].split("_")[0] in self.data: + if files[i] in invalid_PDB_file_name_list: + print("Skipping invalid file {}...".format(files[i])) + continue + file_name = osp.join(self.root, self.split, files[i]) + protein = self.extract_protein_data(file_name) + label = np.zeros((self.num_class,)).astype(np.float32) + + if len(self.labels[osp.basename(file_name).split("_")[0]]) > 0: + label[self.labels[osp.basename(file_name).split("_")[0]]] = 1.0 + + if protein is not None: + protein.id = files[i] + protein.y = torch.tensor(label).unsqueeze(0) + data_list.append(protein) + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + print('Done!') + +# if __name__ == "__main__": +# pdb_parser = PDBParser( +# QUIET=True, +# PERMISSIVE=True, +# structure_builder=SloppyStructureBuilder(), +# ) + +# for level in ["mf", "bp", "cc"]: +# for split in ['test', 'train', 'valid']: +# print('#### Now processing {} data ####'.format(split)) +# if split != "test": +# dataset = DatasetGO(root="/lustre07/scratch/liusheng/GearNet/GeneOntology", level=level, split=split) +# else: +# for cutoff in [0.3, 0.4, 0.5, 0.7, 0.95]: +# dataset = DatasetGO(root="/lustre07/scratch/liusheng/GearNet/GeneOntology", level=level, split=split, percent=cutoff) + \ No newline at end of file diff --git a/Geom3D/datasets/dataset_GO_GearNet.py b/Geom3D/datasets/dataset_GO_GearNet.py new file mode 100644 index 0000000..cbb44d3 --- /dev/null +++ b/Geom3D/datasets/dataset_GO_GearNet.py @@ -0,0 +1,250 @@ +import os.path as osp +import os +import numpy as np +import warnings +from tqdm import tqdm +from sklearn.preprocessing import normalize +from collections import defaultdict +import h5py +import itertools +import argparse + +import torch, math +import torch.nn.functional as F +import torch_cluster + +from Bio.PDB import PDBParser +from Bio.PDB.Polypeptide import three_to_one, is_aa +import sys +import Bio.PDB +import Bio.PDB.StructureBuilder +from Bio.PDB.Residue import Residue + +from torch_geometric.data import Data +from torch_geometric.data import InMemoryDataset + + +class DatasetGOGearNet(InMemoryDataset): + 
def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, split='train', level="mf", percent=0.95): + self.split = split + self.root = root + self.level = level + self.percent = percent + + self.letter_to_num = {'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9, + 'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8, + 'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19, + 'N': 2, 'Y': 18, 'M': 12, "X":20} + + super(DatasetGOGearNet, self).__init__( + root, transform, pre_transform, pre_filter) + + self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter + self.data, self.slices = torch.load(self.processed_paths[0]) + + + @property + def processed_dir(self): + name = 'processed_GO_GearNet_' + self.level + if self.split != "test": + return osp.join(self.root, name, self.split) + else: + return osp.join(self.root, name, self.split + "_" + str(self.percent)) + + @property + def raw_file_names(self): + name = self.split + '.txt' + return name + + @property + def processed_file_names(self): + return 'data.pt' + + + def extract_protein_data(self, pFilePath, graph_construction_model): + from torchdrug import data + protein = data.Protein.from_pdb(pFilePath) + protein = data.Protein.pack([protein]) + protein = graph_construction_model(protein) + item = {"graph": protein} + + if self.transform: + item = self.transform(item) + + protein = item["graph"] + seq = protein.to_sequence()[0] + residue_type = [] + residue_feature = [] + + for i in seq: + residue_type.append(data.Protein.residue_symbol2id.get(i, 0)) + residue_feature.append(data.feature.onehot(data.Protein.id2residue.get(data.Protein.residue_symbol2id.get(i)), data.feature.residue_vocab, allow_unknown=True)) + return_data = Data() + return_data.edge_list = protein.edge_list + return_data.edge_weight = torch.ones(len(protein.edge_list)) + return_data.num_residue = protein.num_residue + return_data.num_node = protein.num_node + return_data.num_edge = protein.num_edge + return_data.x = residue_type # This is important to hack the code + return_data.node_feature = residue_feature + return_data.num_relation = protein.num_relation + return_data.node_position = protein.node_position + return_data.edge_feature = protein.edge_feature + + return return_data + + + def collate_fn(batch): + num_nodes = [] + num_edges = [] + num_residues = [] + node_positions = [] + y = [] + num_cum_node = 0 + num_cum_edge = 0 + num_cum_residue = 0 + num_graph = 0 + data_dict = defaultdict(list) + + for graph in batch: + num_nodes.append(graph.num_node) + num_edges.append(graph.num_edge) + num_residues.append(graph.num_residue) + node_positions.append(graph.node_position) + y.append(graph.y[0]) + for k, v in graph.items(): + if k in ["num_relation", "num_node", "num_edge", "num_residue", "node_position", "y", "id"]: + continue + elif k in ["edge_list"]: + neo_v = v.clone() + neo_v[:, 0] += num_cum_node + neo_v[:, 1] += num_cum_node + data_dict[k].append(neo_v) + continue + + data_dict[k].append(v) + num_cum_node += graph.num_node + num_cum_edge += graph.num_edge + num_cum_residue += graph.num_residue + num_graph += 1 + + data_dict = {k: torch.cat([torch.tensor(v) for v in lst]) for k, lst in data_dict.items()} + + num_nodes = torch.cat(num_nodes) + num_edges = torch.cat(num_edges) + num_residues = torch.cat(num_residues) + node_positions = torch.cat(node_positions) + node2graph = torch.repeat_interleave(num_nodes) + num_node = torch.sum(num_nodes) + num_edge = torch.sum(num_edges) + + return Data( + num_nodes=num_nodes, 
num_node=num_node, num_edges=num_edges, num_edge=num_edge, num_residues=num_residues, num_relation=batch[0].num_relation, node_position=node_positions, + node2graph=node2graph, batch_size=len(batch), y=torch.stack(y), **data_dict) + + def process(self): + print('Beginning Processing ...') + + from torchdrug import transforms, layers + from torchdrug.layers import geometry + + self.transform = transforms.ProteinView("residue") + + if self.split != "test": + with open(os.path.join(self.root, f"nrPDB-GO_{self.split}.txt"), 'r') as file: + self.data = set([line.strip() for line in file]) + else: + self.data = set() + with open(os.path.join(self.root, "nrPDB-GO_test.csv"), 'r') as f: + head = True + for line in f: + if head: + head = False + continue + arr = line.rstrip().split(',') + if self.percent == 0.3 and arr[1] == '1': + self.data.add(arr[0]) + elif self.percent == 0.4 and arr[2] == '1': + self.data.add(arr[0]) + elif self.percent == 0.5 and arr[3] == '1': + self.data.add(arr[0]) + elif self.percent == 0.7 and arr[4] == '1': + self.data.add(arr[0]) + elif self.percent == 0.95 and arr[5] == '1': + self.data.add(arr[0]) + else: + pass + + structure_file_dir = osp.join( + self.root, f"{self.split}" + ) + files = os.listdir(structure_file_dir) + + level_idx = 0 + go_cnt = 0 + go_num = {} + go_annotations = {} + self.labels = {} + with open(osp.join(self.root, 'nrPDB-GO_annot.tsv'), 'r') as f: + for idx, line in enumerate(f): + if idx == 1 and self.level == "mf": + level_idx = 1 + arr = line.rstrip().split('\t') + for go in arr: + go_annotations[go] = go_cnt + go_num[go] = 0 + go_cnt += 1 + elif idx == 5 and self.level == "bp": + level_idx = 2 + arr = line.rstrip().split('\t') + for go in arr: + go_annotations[go] = go_cnt + go_num[go] = 0 + go_cnt += 1 + elif idx == 9 and self.level == "cc": + level_idx = 3 + arr = line.rstrip().split('\t') + for go in arr: + go_annotations[go] = go_cnt + go_num[go] = 0 + go_cnt += 1 + elif idx > 12: + arr = line.rstrip().split('\t') + protein_labels = [] + if len(arr) > level_idx: + protein_go_list = arr[level_idx] + protein_go_list = protein_go_list.split(',') + for go in protein_go_list: + if len(go) > 0: + protein_labels.append(go_annotations[go]) + go_num[go] += 1 + self.labels[arr[0]] = np.array(protein_labels) + + self.num_class = len(go_annotations) + + graph_construction_model = layers.GraphConstruction( + node_layers=[geometry.AlphaCarbonNode()], + edge_layers=[geometry.SpatialEdge(radius=10.0, min_distance=5), geometry.KNNEdge(k=10, min_distance=5), geometry.SequentialEdge(max_distance=2)], + edge_feature="gearnet") + + data_list = [] + for i in tqdm(range(len(files))): + if files[i].split("_")[0] in self.data and files[i].split("_")[0] not in ["1X18-E", "2UV2-A", "1EIS-A", "4UPV-Q", "1DIN-A"]: + file_name = osp.join(self.root, self.split, files[i]) + try: + protein = self.extract_protein_data(file_name, graph_construction_model) + except: + protein = None + label = np.zeros((self.num_class,)).astype(np.float32) + + if len(self.labels[osp.basename(file_name).split("_")[0]]) > 0: + label[self.labels[osp.basename(file_name).split("_")[0]]] = 1.0 + + if protein is not None: + protein.id = files[i] + protein.y = torch.tensor(label).unsqueeze(0) + data_list.append(protein) + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + print('Done!') diff --git a/Geom3D/datasets/dataset_GVP.py b/Geom3D/datasets/dataset_GVP.py new file mode 100644 index 0000000..3238344 --- /dev/null +++ 
b/Geom3D/datasets/dataset_GVP.py @@ -0,0 +1,155 @@ + +''' +This is a protein Dataset specifically for GVP. +''' +import numpy as np +import math +import os.path as osp +import torch +from torch_geometric.data import Data, InMemoryDataset +import torch.nn.functional as F +import torch_cluster +from tqdm import tqdm + + +class DatasetGVP(InMemoryDataset): + def __init__( + self, root, dataset, transform=None, pre_transform=None, pre_filter=None, + split='train', num_positional_embeddings=16, top_k=30, num_rbf=16 + ): + self.split = split + self.root = root + self.preprocessed_dataset = dataset + self.num_positional_embeddings = num_positional_embeddings + self.top_k = top_k + self.num_rbf = num_rbf + + super(DatasetGVP, self).__init__(root, transform, pre_transform, pre_filter) + + self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def processed_dir(self): + return osp.join(self.root, "processed_GVP", self.split) + + @property + def processed_file_names(self): + return 'data.pt' + + def positional_embeddings_GVP(self, edge_index, + num_embeddings=None, + period_range=[2, 1000], device=None): + # From https://github.com/jingraham/neurips19-graph-protein-design + num_embeddings = num_embeddings or self.num_positional_embeddings + d = edge_index[0] - edge_index[1] + + frequency = torch.exp( + torch.arange(0, num_embeddings, 2, dtype=torch.float32, device=device) + * -(np.log(10000.0) / num_embeddings) + ) + angles = d.unsqueeze(-1) * frequency + E = torch.cat((torch.cos(angles), torch.sin(angles)), -1) + return E + + def _normalize(self, tensor, dim=-1): + ''' + Normalizes a `torch.Tensor` along dimension `dim` without `nan`s. + ''' + return torch.nan_to_num( + torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))) + + def _rbf(self, D, D_min=0., D_max=20., D_count=16, device=None): + D_mu = torch.linspace(D_min, D_max, D_count, device=device) + D_mu = D_mu.view([1, -1]) + D_sigma = (D_max - D_min) / D_count + D_expand = torch.unsqueeze(D, -1).to(device) + + RBF = torch.exp(-((D_expand - D_mu) / D_sigma) ** 2) + return RBF + + def dihedrals_GVP(self, X, eps=1e-7): + # From https://github.com/jingraham/neurips19-graph-protein-design + + X = torch.reshape(X[:, :3], [3*X.shape[0], 3]) + dX = X[1:] - X[:-1] + U = self._normalize(dX, dim=-1) + u_2 = U[:-2] + u_1 = U[1:-1] + u_0 = U[2:] + + # Backbone normals + n_2 = self._normalize(torch.cross(u_2, u_1), dim=-1) + n_1 = self._normalize(torch.cross(u_1, u_0), dim=-1) + + # Angle between normals + cosD = torch.sum(n_2 * n_1, -1) + cosD = torch.clamp(cosD, -1 + eps, 1 - eps) + D = torch.sign(torch.sum(u_2 * n_1, -1)) * torch.acos(cosD) + + # This scheme will remove phi[0], psi[-1], omega[-1] + D = F.pad(D, [1, 2]) + D = torch.reshape(D, [-1, 3]) + # Lift angle representations to the circle + D_features = torch.cat([torch.cos(D), torch.sin(D)], 1) + return D_features + + def orientations_GVP(self, X): + forward = self._normalize(X[1:] - X[:-1]) + backward = self._normalize(X[:-1] - X[1:]) + forward = F.pad(forward, [0, 0, 0, 1]) + backward = F.pad(backward, [0, 0, 1, 0]) + return torch.cat([forward.unsqueeze(-2), backward.unsqueeze(-2)], -2) + + def sidechains_GVP(self, X): + n, origin, c = X[:, 0], X[:, 1], X[:, 2] + c, n = self._normalize(c - origin), self._normalize(n - origin) + bisector = self._normalize(c + n) + perp = self._normalize(torch.cross(c, n)) + vec = -bisector * math.sqrt(1 / 3) - perp * math.sqrt(2 / 3) + return 
vec + + def process(self): + print('Beginning Processing ...') + data_list = [] + device = 'cpu' + + with torch.no_grad(): + for data in tqdm(self.preprocessed_dataset): + coords = [] + for i in range(len(data.coords_n)): + coords.append([list(data.coords_n[i]), list(data.coords_ca[i]), list(data.coords_c[i])]) + + coords = torch.tensor(coords, dtype=torch.float32) + + mask = torch.isfinite(coords.sum(dim=(1,2))) + coords[~mask] = np.inf + + X_ca = coords[:, 1] + edge_index = torch_cluster.knn_graph(X_ca, k=self.top_k).to(device) + + pos_embeddings = self.positional_embeddings_GVP(edge_index, self.num_positional_embeddings, device=device) + E_vectors = X_ca[edge_index[0]] - X_ca[edge_index[1]] + rbf = self._rbf(E_vectors.norm(dim=-1), D_count=self.num_rbf, device=device) + + dihedrals = self.dihedrals_GVP(coords) + orientations = self.orientations_GVP(X_ca) + sidechains = self.sidechains_GVP(coords) + + node_s = dihedrals + node_v = torch.cat([orientations, sidechains.unsqueeze(-2)], dim=-2) + edge_s = torch.cat([rbf, pos_embeddings], dim=-1) + edge_v = self._normalize(E_vectors).unsqueeze(-2) + + node_s, node_v, edge_s, edge_v = map(torch.nan_to_num, (node_s, node_v, edge_s, edge_v)) + + data.edge_index = edge_index + data.node_s = node_s + data.node_v = node_v + data.edge_s = edge_s + data.edge_v = edge_v + data_list.append(data) + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + return diff --git a/Geom3D/models/CDConv.py b/Geom3D/models/CDConv.py index 8d402c4..58a9fd8 100644 --- a/Geom3D/models/CDConv.py +++ b/Geom3D/models/CDConv.py @@ -2,6 +2,7 @@ import math from typing import Type, Any, Callable, Union, List, Optional +import numpy as np import torch import torch.nn as nn @@ -15,7 +16,7 @@ from torch_geometric.typing import Adj, OptTensor, PairOptTensor, PairTensor from torch_geometric.utils import add_self_loops, remove_self_loops import torch_geometric.transforms as T -from torch_geometric.nn import MLP, fps, global_max_pool, global_mean_pool, radius +from torch_geometric.nn import fps, global_max_pool, global_mean_pool, radius from torch_geometric.nn.pool import avg_pool, max_pool import torch.optim as optim from torch_geometric.loader import DataLoader @@ -299,7 +300,7 @@ def forward(self, x, pos, seq, ori, batch): out = self.output(x) + identity return out -class CDConv(nn.Module): +class CD_Convolution(nn.Module): def __init__(self, geometric_radii: List[float], sequential_kernel_size: float, @@ -316,7 +317,7 @@ def __init__(self, assert (len(geometric_radii) == len(channels)), "Model: 'geometric_radii' and 'channels' should have the same number of elements!" 
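+        # num_embeddings is 26 rather than 21 because the HDF5-based datasets
+        # map unknown residue types (-1) to index 25, so ids 0..25 must be
+        # covered by the embedding table.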
- self.embedding = torch.nn.Embedding(num_embeddings=21, embedding_dim=embedding_dim) + self.embedding = torch.nn.Embedding(num_embeddings=26, embedding_dim=embedding_dim) self.local_mean_pool = AvgPooling() layers = [] @@ -349,9 +350,30 @@ def __init__(self, out_channels=num_classes, batch_norm=batch_norm, dropout=dropout) - - def forward(self, data): - x, pos, seq, ori, batch = (self.embedding(data.x), data.pos, data.seq, data.ori, data.batch) + + def orientation_CDConv(self, pos): + u = torch.nn.functional.normalize(pos[1:,:] - pos[:-1,:], dim=1) + u1 = u[1:,:] + u2 = u[:-1, :] + b = torch.nn.functional.normalize(u2 - u1, dim=1) + n = torch.nn.functional.normalize(torch.cross(u2, u1, dim=1), dim=1) + o = torch.nn.functional.normalize(torch.cross(b, n, dim=1), dim=1) + ori = torch.stack([b, n, o], dim=1) + return torch.cat([ori[0].unsqueeze(0), ori, ori[-1].unsqueeze(0)], dim=0) + + def forward(self, data, split=None): + pos = data.coords_ca + device = pos.device + + if split == "training": + pos = pos + torch.normal(0.0, 0.05, size=pos.shape, device=device, dtype=pos.dtype) + + seq_idx = torch.arange(pos.shape[0], device=device).unsqueeze(1).float() + center = torch.sum(pos, dim=0, keepdim=True)/pos.shape[0] + pos = pos - center + ori = self.orientation_CDConv(pos) + + x, pos, seq, ori, batch = (self.embedding(data.seq), pos, seq_idx, ori, data.batch) for i, layer in enumerate(self.layers): x = layer(x, pos, seq, ori, batch) @@ -359,6 +381,7 @@ def forward(self, data): x = global_mean_pool(x, batch) elif i % 2 == 1: x, pos, seq, ori, batch = self.local_mean_pool(x, pos, seq, ori, batch) + seq = seq.to(device) out = self.classifier(x) diff --git a/Geom3D/models/GVP.py b/Geom3D/models/GVP.py index 6497c8d..520f7c3 100644 --- a/Geom3D/models/GVP.py +++ b/Geom3D/models/GVP.py @@ -6,6 +6,8 @@ import torch.nn.functional as F from torch_geometric.nn import MessagePassing import torch_scatter +import numpy as np + def tuple_sum(*args): ''' @@ -496,3 +498,86 @@ def forward(self, batch, scatter_mean=True, dense=True): if dense: out = self.dense(out).squeeze(-1) return out +class MQAModel(nn.Module): + ''' + GVP-GNN for Model Quality Assessment as described in manuscript. + + Takes in protein structure graphs of type `torch_geometric.data.Data` + or `torch_geometric.data.Batch` and returns a scalar score for + each graph in the batch in a `torch.Tensor` of shape [n_nodes] + + Should be used with `gvp.data.ProteinGraphDataset`, or with generators + of `torch_geometric.data.Batch` objects with the same attributes. 
+
+    :param node_in_dim: node dimensions in input graph, should be
+                        (6, 3) if using original features
+    :param node_h_dim: node dimensions to use in GVP-GNN layers
+    :param edge_in_dim: edge dimensions in input graph, should be
+                        (32, 1) if using original features
+    :param edge_h_dim: edge dimensions to embed to before use
+                       in GVP-GNN layers
+    :param seq_in: if `True`, sequences will also be passed in with
+                   the forward pass; otherwise, sequence information
+                   is assumed to be part of input node embeddings
+    :param num_layers: number of GVP-GNN layers
+    :param drop_rate: rate to use in all dropout layers
+    '''
+    def __init__(
+        self, node_in_dim, node_h_dim, edge_in_dim, edge_h_dim,
+        seq_in=False, num_layers=3, drop_rate=0.1, out_channels=1195):
+
+        super(MQAModel, self).__init__()
+
+        if seq_in:
+            self.W_s = nn.Embedding(20, 20)
+            node_in_dim = (node_in_dim[0] + 20, node_in_dim[1])
+
+        self.W_v = nn.Sequential(
+            LayerNorm(node_in_dim),
+            GVP(node_in_dim, node_h_dim, activations=(None, None))
+        )
+        self.W_e = nn.Sequential(
+            LayerNorm(edge_in_dim),
+            GVP(edge_in_dim, edge_h_dim, activations=(None, None))
+        )
+
+        self.layers = nn.ModuleList(
+            GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate)
+            for _ in range(num_layers))
+
+        ns, _ = node_h_dim
+        self.W_out = nn.Sequential(
+            LayerNorm(node_h_dim),
+            GVP(node_h_dim, (ns, 0)))
+
+        self.dense = nn.Sequential(
+            nn.Linear(ns, 2*ns), nn.ReLU(inplace=True),
+            nn.Dropout(p=drop_rate),
+            nn.Linear(2*ns, out_channels)
+        )
+
+    def forward(self, seq=None, batch=None):
+        '''
+        :param seq: if not `None`, int `torch.Tensor` of shape [num_nodes]
+                    to be embedded and appended to the node embeddings
+        :param batch: `torch_geometric.data.Batch` carrying the node
+                      embeddings (`node_s`, `node_v`), the edge embeddings
+                      (`edge_s`, `edge_v`) and `edge_index`
+        '''
+
+        h_V = (batch.node_s, batch.node_v)  # tuple (s, V) of node embeddings
+        h_E = (batch.edge_s, batch.edge_v)  # tuple (s, V) of edge embeddings
+
+        if seq is not None:
+            seq = self.W_s(seq)
+            h_V = (torch.cat([h_V[0], seq], dim=-1), h_V[1])
+        h_V = self.W_v(h_V)
+        h_E = self.W_e(h_E)
+        for layer in self.layers:
+            h_V = layer(h_V, batch.edge_index, h_E)
+        out = self.W_out(h_V)
+
+        if batch is None: out = out.mean(dim=0, keepdims=True)
+        else: out = torch_scatter.scatter_mean(out, batch.batch, dim=0)
+
+        return self.dense(out).squeeze(-1) + 0.5
\ No newline at end of file
diff --git a/Geom3D/models/GearNet.py b/Geom3D/models/GearNet.py
index 1698742..f16fdcd 100644
--- a/Geom3D/models/GearNet.py
+++ b/Geom3D/models/GearNet.py
@@ -1,4 +1,3 @@
-
 from collections.abc import Sequence
 
 import torch
@@ -6,30 +5,45 @@
 from torch.nn import functional as F
 from torch_scatter import scatter_add, scatter_mean
 
-from .GearNet_layer import GeometricRelationalGraphConv, SpatialLineGraph
+from .GearNet_layer import IEConvLayer, GeometricRelationalGraphConv, SpatialLineGraph, SumReadout, MeanReadout
+
+class GearNetIEConv(nn.Module):
 
-class GearNet(nn.Module):
-    def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, num_angle_bin=None,
-                 short_cut=False, batch_norm=False, activation="relu", concat_hidden=False, readout="sum"):
-        super(GearNet, self).__init__()
+    def __init__(self, input_dim, embedding_dim, hidden_dims, num_relation, edge_input_dim=None,
+                 batch_norm=False, activation="relu", concat_hidden=False, short_cut=True,
+                 readout="sum", dropout=0, num_angle_bin=None, layer_norm=False, use_ieconv=False):
+        super(GearNetIEConv, self).__init__()
 
         if not isinstance(hidden_dims, Sequence):
            hidden_dims =
[hidden_dims]
 
         self.input_dim = input_dim
+        self.embedding_dim = embedding_dim
         self.output_dim = sum(hidden_dims) if concat_hidden else hidden_dims[-1]
-        self.dims = [input_dim] + list(hidden_dims)
+        self.dims = [embedding_dim if embedding_dim > 0 else input_dim] + list(hidden_dims)
         self.edge_dims = [edge_input_dim] + self.dims[:-1]
         self.num_relation = num_relation
         self.num_angle_bin = num_angle_bin
         self.short_cut = short_cut
         self.concat_hidden = concat_hidden
-        self.batch_norm = batch_norm
+        self.layer_norm = layer_norm
+        self.use_ieconv = use_ieconv
+
+        if embedding_dim > 0:
+            self.linear = nn.Linear(input_dim, embedding_dim)
+            self.embedding_batch_norm = nn.BatchNorm1d(embedding_dim)
 
         self.layers = nn.ModuleList()
+        self.ieconvs = nn.ModuleList()
         for i in range(len(self.dims) - 1):
+            # note that these layers are from gearnet.layer instead of torchdrug.layers
             self.layers.append(GeometricRelationalGraphConv(self.dims[i], self.dims[i + 1],
                                num_relation, None, batch_norm, activation))
+            if use_ieconv:
+                self.ieconvs.append(IEConvLayer(self.dims[i], self.dims[i] // 4,
+                                    self.dims[i+1], edge_input_dim=14, kernel_hidden_dim=32))
         if num_angle_bin:
             self.spatial_line_graph = SpatialLineGraph(num_angle_bin)
             self.edge_layers = nn.ModuleList()
@@ -37,42 +51,73 @@ def __init__(self, input_dim, hidden_dims, num_relation, edge_input_dim=None, nu
                 self.edge_layers.append(GeometricRelationalGraphConv(
                     self.edge_dims[i], self.edge_dims[i + 1], num_angle_bin, None, batch_norm, activation))
 
-        if batch_norm:
-            self.batch_norms = nn.ModuleList()
+        if layer_norm:
+            self.layer_norms = nn.ModuleList()
             for i in range(len(self.dims) - 1):
-                self.batch_norms.append(nn.BatchNorm1d(self.dims[i + 1]))
+                self.layer_norms.append(nn.LayerNorm(self.dims[i + 1]))
+
+        self.dropout = nn.Dropout(dropout)
 
         if readout == "sum":
-            self.readout = scatter_add
+            self.readout = SumReadout()
         elif readout == "mean":
-            self.readout = scatter_mean
+            self.readout = MeanReadout()
         else:
             raise ValueError("Unknown readout `%s`" % readout)
+
+    def get_ieconv_edge_feature(self, graph):
+        u = torch.ones_like(graph.node_position)
+        u[1:] = graph.node_position[1:] - graph.node_position[:-1]
+        u = F.normalize(u, dim=-1)
+        b = torch.ones_like(graph.node_position)
+        b[:-1] = u[:-1] - u[1:]
+        b = F.normalize(b, dim=-1)
+        n = torch.ones_like(graph.node_position)
+        n[:-1] = torch.cross(u[:-1], u[1:])
+        n = F.normalize(n, dim=-1)
+
+        local_frame = torch.stack([b, n, torch.cross(b, n)], dim=-1)
+
+        node_in, node_out = graph.edge_list.t()[:2]
+        t = graph.node_position[node_out] - graph.node_position[node_in]
+        t = torch.einsum('ijk, ij->ik', local_frame[node_in], t)
+        r = torch.sum(local_frame[node_in] * local_frame[node_out], dim=1)
+        delta = torch.abs(graph.atom2residue[node_in] - graph.atom2residue[node_out]).float() / 6
+        delta = delta.unsqueeze(-1)
+
+        return torch.cat([
+            t, r, delta,
+            1 - 2 * t.abs(), 1 - 2 * r.abs(), 1 - 2 * delta.abs()
+        ], dim=-1)
 
     def forward(self, graph, input, all_loss=None, metric=None):
         hiddens = []
         layer_input = input
+        if self.embedding_dim > 0:
+            layer_input = self.linear(layer_input)
+            layer_input = self.embedding_batch_norm(layer_input)
 
         if self.num_angle_bin:
             line_graph = self.spatial_line_graph(graph)
-            edge_input = line_graph.node_feature.float()
+            edge_hidden = line_graph.node_feature.float()
+        else:
+            edge_hidden = None
 
         for i in range(len(self.layers)):
-            hidden = self.layers[i](graph, layer_input)
+            # edge message passing
+            if self.num_angle_bin:
+
edge_hidden = self.edge_layers[i](line_graph, edge_hidden) + hidden = self.layers[i](graph, layer_input, edge_hidden) + # ieconv layer + if self.use_ieconv: + ieconv_edge_feature = self.get_ieconv_edge_feature(graph) + hidden = hidden + self.ieconvs[i](graph, layer_input, ieconv_edge_feature) + hidden = self.dropout(hidden) if self.short_cut and hidden.shape == layer_input.shape: hidden = hidden + layer_input - if self.num_angle_bin: - edge_hidden = self.edge_layers[i](line_graph, edge_input) - edge_weight = graph.edge_weight.unsqueeze(-1) - node_out = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2] - update = scatter_add(edge_hidden * edge_weight, node_out, dim=0, - dim_size=graph.num_node * self.num_relation) - update = update.view(graph.num_node, self.num_relation * edge_hidden.shape[1]) - update = self.layers[i].linear(update) - update = self.layers[i].activation(update) - hidden = hidden + update - edge_input = edge_hidden - if self.batch_norm: - hidden = self.batch_norms[i](hidden) + if self.layer_norm: + hidden = self.layer_norms[i](hidden) hiddens.append(hidden) layer_input = hidden @@ -80,8 +125,33 @@ def forward(self, graph, input, all_loss=None, metric=None): node_feature = torch.cat(hiddens, dim=-1) else: node_feature = hiddens[-1] + graph_feature = self.readout(graph, node_feature) + + return { + "graph_feature": graph_feature, + "node_feature": node_feature + } + + +class FusionNetwork(nn.Module): + + def __init__(self, sequence_model, structure_model): + super(FusionNetwork, self).__init__() + self.sequence_model = sequence_model + self.structure_model = structure_model + self.output_dim = sequence_model.output_dim + structure_model.output_dim + + def forward(self, graph, input, all_loss=None, metric=None): + output1 = self.sequence_model(graph, input, all_loss, metric) + node_output1 = output1.get("node_feature", output1.get("residue_feature")) + output2 = self.structure_model(graph, node_output1, all_loss, metric) + node_output2 = output2.get("node_feature", output2.get("residue_feature")) - graph_feature = self.readout(node_feature, graph.node2graph, dim=0) + node_feature = torch.cat([node_output1, node_output2], dim=-1) + graph_feature = torch.cat([ + output1['graph_feature'], + output2['graph_feature'] + ], dim=-1) return { "graph_feature": graph_feature, diff --git a/Geom3D/models/GearNet_layer.py b/Geom3D/models/GearNet_layer.py index d40f931..bcf775b 100644 --- a/Geom3D/models/GearNet_layer.py +++ b/Geom3D/models/GearNet_layer.py @@ -2,8 +2,9 @@ import torch from torch import nn from torch.nn import functional as F -from torch_scatter import scatter_add +from torch_scatter import scatter_add, scatter_mean from torch_geometric.data import Data +from collections.abc import Sequence class MultiLayerPerceptron(nn.Module): @@ -271,4 +272,60 @@ def construct_line_graph(graph): return Data( edge_list=edge_list, edge_weight=edge_weight, num_nodes=num_nodes, num_edges=num_edges, offsets=offsets, - node_feature=node_feature) \ No newline at end of file + node_feature=node_feature) + + +class Readout(nn.Module): + + def __init__(self, type="node"): + super(Readout, self).__init__() + self.type = type + + def get_index2graph(self, graph): + if self.type == "node": + input2graph = graph.node2graph + elif self.type == "edge": + input2graph = graph.edge2graph + elif self.type == "residue": + input2graph = graph.residue2graph + else: + raise ValueError("Unknown input type `%s` for readout functions" % self.type) + return input2graph + + +class SumReadout(Readout): 
+ """Sum readout operator over graphs with variadic sizes.""" + + def forward(self, graph, input): + """ + Perform readout over the graph(s). + + Parameters: + graph (Graph): graph(s) + input (Tensor): node representations + + Returns: + Tensor: graph representations + """ + input2graph = self.get_index2graph(graph) + output = scatter_add(input, input2graph, dim=0, dim_size=graph.batch_size) + return output + + +class MeanReadout(Readout): + """Mean readout operator over graphs with variadic sizes.""" + + def forward(self, graph, input): + """ + Perform readout over the graph(s). + + Parameters: + graph (Graph): graph(s) + input (Tensor): node representations + + Returns: + Tensor: graph representations + """ + input2graph = self.get_index2graph(graph) + output = scatter_mean(input, input2graph, dim=0, dim_size=graph.batch_size) + return output \ No newline at end of file diff --git a/Geom3D/models/ProNet/ProNet.py b/Geom3D/models/ProNet/ProNet.py index f78dcda..0a05e9c 100644 --- a/Geom3D/models/ProNet/ProNet.py +++ b/Geom3D/models/ProNet/ProNet.py @@ -352,7 +352,7 @@ def reset_parameters(self): def pos_emb(self, edge_index, num_pos_emb=16): # From https://github.com/jingraham/neurips19-graph-protein-design d = edge_index[0] - edge_index[1] - + frequency = torch.exp( torch.arange(0, num_pos_emb, 2, dtype=torch.float32, device=edge_index.device) * -(np.log(10000.0) / num_pos_emb) @@ -362,12 +362,12 @@ def pos_emb(self, edge_index, num_pos_emb=16): return E def forward(self, batch_data): + z, pos, batch = batch_data.seq, batch_data.coords_ca, batch_data.batch - z, pos, batch = torch.squeeze(batch_data.x.long()), batch_data.coords_ca, batch_data.batch pos_n = batch_data.coords_n pos_c = batch_data.coords_c - bb_embs = batch_data.bb_embs - side_chain_embs = batch_data.side_chain_embs + bb_embs = batch_data.backbone_angle_encoding + side_chain_embs = batch_data.side_chain_angle_encoding device = z.device @@ -432,6 +432,7 @@ def forward(self, batch_data): feature1 = torch.cat((self.feature1(dist, angle1), self.feature1(dist, angle2), self.feature1(dist, angle3)),1) elif self.level == 'aminoacid': + #print("num_nodes:", num_nodes) refi = (i-1)%num_nodes refj0 = (j-1)%num_nodes @@ -452,12 +453,14 @@ def forward(self, batch_data): feature1 = self.feature1(dist, tau) # Interaction blocks. 
\ No newline at end of file
diff --git a/Geom3D/models/ProNet/ProNet.py b/Geom3D/models/ProNet/ProNet.py
index f78dcda..0a05e9c 100644
--- a/Geom3D/models/ProNet/ProNet.py
+++ b/Geom3D/models/ProNet/ProNet.py
@@ -352,7 +352,7 @@ def reset_parameters(self):
     def pos_emb(self, edge_index, num_pos_emb=16):
         # From https://github.com/jingraham/neurips19-graph-protein-design
         d = edge_index[0] - edge_index[1]
-        
+
         frequency = torch.exp(
             torch.arange(0, num_pos_emb, 2, dtype=torch.float32, device=edge_index.device)
             * -(np.log(10000.0) / num_pos_emb)
@@ -362,12 +362,12 @@ def pos_emb(self, edge_index, num_pos_emb=16):
         return E

     def forward(self, batch_data):
+        z, pos, batch = batch_data.seq, batch_data.coords_ca, batch_data.batch
-        z, pos, batch = torch.squeeze(batch_data.x.long()), batch_data.coords_ca, batch_data.batch
         pos_n = batch_data.coords_n
         pos_c = batch_data.coords_c
-        bb_embs = batch_data.bb_embs
-        side_chain_embs = batch_data.side_chain_embs
+        bb_embs = batch_data.backbone_angle_encoding
+        side_chain_embs = batch_data.side_chain_angle_encoding

         device = z.device

@@ -432,6 +432,7 @@ def forward(self, batch_data):
             feature1 = torch.cat((self.feature1(dist, angle1), self.feature1(dist, angle2), self.feature1(dist, angle3)),1)

         elif self.level == 'aminoacid':
+            #print("num_nodes:", num_nodes)
             refi = (i-1)%num_nodes

             refj0 = (j-1)%num_nodes
@@ -452,12 +453,14 @@ def forward(self, batch_data):
             feature1 = self.feature1(dist, tau)

         # Interaction blocks.
+        idx = 0
         for interaction_block in self.interaction_blocks:
             if self.data_augment_eachlayer:
                 # add gaussian noise to features
                 gaussian_noise = torch.clip(torch.empty(x.shape).to(device).normal_(mean=0.0, std=0.025), min=-0.1, max=0.1)
                 x += gaussian_noise
             x = interaction_block(x, feature0, feature1, pos_emb, edge_index, batch)
+            idx += 1

         y = scatter(x, batch, dim=0)

diff --git a/Geom3D/models/__init__.py b/Geom3D/models/__init__.py
index 4f06a48..141bc5a 100644
--- a/Geom3D/models/__init__.py
+++ b/Geom3D/models/__init__.py
@@ -3,40 +3,41 @@
 import torch.nn.functional as F
 from torch_geometric.nn import GATConv, GCNConv

-from .AutoEncoder import AutoEncoder, VariationalAutoEncoder
+# from .AutoEncoder import AutoEncoder, VariationalAutoEncoder

-from .DimeNet import DimeNet
-from .DimeNetPlusPlus import DimeNetPlusPlus
-from .EGNN import EGNN
-from .PaiNN import PaiNN
-from .SchNet import SchNet
-from .SE3_Transformer import SE3Transformer
-from .SEGNN import SEGNNModel as SEGNN
-from .SphereNet import SphereNet
-from .SphereNet_periodic import SphereNetPeriodic
-from .TFN import TFN
-from .GemNet import GemNet
-from .ClofNet import ClofNet
-from .Graphormer import Graphormer
-from .TransformerM import TransformerM
-from .Equiformer import EquiformerEnergy, EquiformerEnergyForce, EquiformerEnergyPeriodic
+# from .DimeNet import DimeNet
+# from .DimeNetPlusPlus import DimeNetPlusPlus
+# from .EGNN import EGNN
+# from .PaiNN import PaiNN
+# from .SchNet import SchNet
+# from .SE3_Transformer import SE3Transformer
+# from .SEGNN import SEGNNModel as SEGNN
+# from .SphereNet import SphereNet
+# from .SphereNet_periodic import SphereNetPeriodic
+# from .TFN import TFN
+# from .GemNet import GemNet
+# from .ClofNet import ClofNet
+# from .Graphormer import Graphormer
+# from .TransformerM import TransformerM
+# from .Equiformer import EquiformerEnergy, EquiformerEnergyForce, EquiformerEnergyPeriodic

-from .GVP import GVP_GNN
-from .GearNet import GearNet
+from .GVP import GVP_GNN, MQAModel
+from .GearNet import GearNetIEConv
 from .ProNet import ProNet
+from .CDConv import CD_Convolution

-from .BERT import BertForSequenceRegression
+# from .BERT import BertForSequenceRegression

-from .GeoSSL_DDM import GeoSSL_DDM
-from .GeoSSL_PDM import GeoSSL_PDM
+# from .GeoSSL_DDM import GeoSSL_DDM
+# from .GeoSSL_PDM import GeoSSL_PDM

-from .molecule_gnn_model import GNN, GNN_graphpred
-from .molecule_gnn_model_simplified import GNNSimplified
-from .PNA import PNA
-from .ENN import ENN_S2S
-from .DMPNN import DMPNN
-from .GPS import GPSModel
-from .AWARE import AWARE
+# from .molecule_gnn_model import GNN, GNN_graphpred
+# from .molecule_gnn_model_simplified import GNNSimplified
+# from .PNA import PNA
+# from .ENN import ENN_S2S
+# from .DMPNN import DMPNN
+# from .GPS import GPSModel
+# from .AWARE import AWARE

-from .MLP import MLP
-from .CNN import CNN
+# from .MLP import MLP
+# from .CNN import CNN
diff --git a/examples_3D/config.py b/examples_3D/config.py
index 906686a..50a84ab 100644
--- a/examples_3D/config.py
+++ b/examples_3D/config.py
@@ -6,6 +6,7 @@
 # about seed and basic info
 parser.add_argument("--seed", type=int, default=42)
 parser.add_argument("--device", type=int, default=0)
+parser.add_argument("--data_root")

 parser.add_argument(
     "--model_3d",
@@ -102,6 +103,9 @@
 parser.add_argument("--LEP_useh", dest="LEP_droph", action="store_false")
 parser.set_defaults(LEP_droph=False)

+# for GeneOntology
+parser.add_argument("--GO_level", default="mf", choices=["mf", "bp", "cc"])
+
 # for MoleculeNet
parser.add_argument("--moleculenet_num_conformers", type=int, default=10) @@ -116,6 +120,7 @@ parser.add_argument("--decay", type=float, default=0) parser.add_argument("--print_every_epoch", type=int, default=1) parser.add_argument("--loss", type=str, default="mae", choices=["mse", "mae"]) +parser.add_argument("--optimizer", type=str, default="Adam", choices=["Adam", "SGD"]) parser.add_argument("--lr_scheduler", type=str, default="CosineAnnealingLR") parser.add_argument("--lr_decay_factor", type=float, default=0.5) parser.add_argument("--lr_decay_step_size", type=int, default=100) @@ -239,10 +244,30 @@ parser.add_argument("--Equiformer_num_basis", type=int, default=128) parser.add_argument("--Equiformer_hyperparameter", type=int, default=0) +# for GVP +parser.add_argument("--num_positional_embeddings", type=int, default=16) +parser.add_argument("--top_k", type=int, default=30) +parser.add_argument("--num_rbf", type=int, default=16) + # for ProNet parser.add_argument("--ProNet_level", type=str, default="aminoacid", choices=["aminoacid", "backbone", "allatom"]) parser.add_argument("--ProNet_dropout", type=float, default=0.3) +# for CDConv +parser.add_argument("--CDConv_radius", type=float, default=4) +parser.add_argument("--CDConv_kernel_size", type=int, default=21) +parser.add_argument("--CDConv_kernel_channels", type=int, nargs="+", default=[24]) +parser.add_argument("--CDConv_geometric_raddi_coeff", type=int, nargs="+", default=[2, 3, 4, 5]) +parser.add_argument("--CDConv_channels", type=int, nargs="+", default=[256, 512, 1024, 2048]) +parser.add_argument("--CDConv_base_width", type=int, default=64) + +# for GearNet +parser.add_argument("--num_relation", type=int, default=7) +parser.add_argument("--GearNet_readout", type=str, default="sum") +parser.add_argument("--GearNet_dropout", type=float, default=0) +parser.add_argument("--GearNet_edge_input_dim", type=int) +parser.add_argument("--GearNet_num_angle_bin", type=int) + # data augmentation tricks, see appendix E in the paper (https://openreview.net/pdf?id=9X-hgLDLYkQ) parser.add_argument('--mask', action='store_true') parser.add_argument('--noise', action='store_true') diff --git a/examples_3D/finetune_ECMultiple.py b/examples_3D/finetune_ECMultiple.py new file mode 100644 index 0000000..40ce5d9 --- /dev/null +++ b/examples_3D/finetune_ECMultiple.py @@ -0,0 +1,503 @@ +import os +import time + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader as TorchDataLoader +from torch_geometric.loader import DataLoader as PyGDataLoader +from torch_geometric.nn import global_max_pool, global_mean_pool +from tqdm import tqdm + +from config import args +from Geom3D.datasets import DatasetECMultiple, DatasetGVP, DatasetECMultipleGearNet +from Geom3D.models import ProNet, GearNetIEConv, MQAModel, CD_Convolution +import Geom3D.models.GearNet_layer as GearNet_layer + +def fmax(probs, labels): + thresholds = np.arange(0, 1, 0.01) + f_max = 0.0 + + for threshold in thresholds: + precision = 0.0 + recall = 0.0 + precision_cnt = 0 + recall_cnt = 0 + for idx in range(probs.shape[0]): + prob = probs[idx] + label = labels[idx] + pred = (prob > threshold).astype(np.int32) + correct_sum = np.sum(label*pred) + pred_sum = np.sum(pred) + label_sum = np.sum(label) + if pred_sum > 0: + precision += correct_sum/pred_sum + precision_cnt += 1 + if label_sum > 0: + recall += correct_sum/label_sum + recall_cnt += 1 + if recall_cnt > 0: + recall = recall / recall_cnt + 
diff --git a/examples_3D/finetune_ECMultiple.py b/examples_3D/finetune_ECMultiple.py
new file mode 100644
index 0000000..40ce5d9
--- /dev/null
+++ b/examples_3D/finetune_ECMultiple.py
@@ -0,0 +1,503 @@
+import os
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader as TorchDataLoader
+from torch_geometric.loader import DataLoader as PyGDataLoader
+from torch_geometric.nn import global_max_pool, global_mean_pool
+from tqdm import tqdm
+
+from config import args
+from Geom3D.datasets import DatasetECMultiple, DatasetGVP, DatasetECMultipleGearNet
+from Geom3D.models import ProNet, GearNetIEConv, MQAModel, CD_Convolution
+import Geom3D.models.GearNet_layer as GearNet_layer
+
+
+def fmax(probs, labels):
+    thresholds = np.arange(0, 1, 0.01)
+    f_max = 0.0
+
+    for threshold in thresholds:
+        precision = 0.0
+        recall = 0.0
+        precision_cnt = 0
+        recall_cnt = 0
+        for idx in range(probs.shape[0]):
+            prob = probs[idx]
+            label = labels[idx]
+            pred = (prob > threshold).astype(np.int32)
+            correct_sum = np.sum(label*pred)
+            pred_sum = np.sum(pred)
+            label_sum = np.sum(label)
+            if pred_sum > 0:
+                precision += correct_sum/pred_sum
+                precision_cnt += 1
+            if label_sum > 0:
+                recall += correct_sum/label_sum
+                recall_cnt += 1
+        if recall_cnt > 0:
+            recall = recall / recall_cnt
+        else:
+            recall = 0
+        if precision_cnt > 0:
+            precision = precision / precision_cnt
+        else:
+            precision = 0
+        f = (2.*precision*recall)/max(precision+recall, 1e-8)
+        f_max = max(f, f_max)
+
+    return f_max
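+
+# Quick sanity check for fmax (illustrative literals): a single protein whose
+# only true class is predicted at 0.9 gives precision = recall = 1 for all
+# thresholds below 0.9, so Fmax = 1.0.
+#   _probs = np.array([[0.9, 0.1]])
+#   _labels = np.array([[1, 0]])
+#   assert abs(fmax(_probs, _labels) - 1.0) < 1e-6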
+
+
+def model_setup():
+    num_class = 538
+
+    if args.model_3d == "GVP":
+        node_in_dim = (6, 3)
+        node_h_dim = (100, 16)
+        edge_in_dim = (32, 1)
+        edge_h_dim = (32, 1)
+        model = MQAModel(node_in_dim, node_h_dim, edge_in_dim, edge_h_dim, out_channels=num_class)
+        graph_pred_linear = None
+
+    elif args.model_3d == "GearNet":
+        input_dim = 21
+        model = GearNetIEConv(
+            input_dim=input_dim, embedding_dim=512, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=args.num_relation,
+            batch_norm=True, concat_hidden=True, short_cut=True, readout=args.GearNet_readout, layer_norm=True, dropout=args.GearNet_dropout,
+            edge_input_dim=args.GearNet_edge_input_dim, num_angle_bin=args.GearNet_num_angle_bin)
+
+        num_mlp_layer = 3
+        hidden_dims = [model.output_dim] * (num_mlp_layer - 1)
+        graph_pred_linear = GearNet_layer.MultiLayerPerceptron(
+            model.output_dim, hidden_dims + [num_class], batch_norm=True, dropout=0.5)
+
+    elif args.model_3d == "GearNet_IEConv":
+        input_dim = 21
+        model = GearNetIEConv(
+            input_dim=input_dim, embedding_dim=512, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=args.num_relation,
+            batch_norm=True, concat_hidden=True, short_cut=True, readout=args.GearNet_readout, layer_norm=True, dropout=args.GearNet_dropout,
+            edge_input_dim=args.GearNet_edge_input_dim, num_angle_bin=args.GearNet_num_angle_bin, use_ieconv=True)
+
+        num_mlp_layer = 3
+        hidden_dims = [model.output_dim] * (num_mlp_layer - 1)
+        graph_pred_linear = GearNet_layer.MultiLayerPerceptron(
+            model.output_dim, hidden_dims + [num_class], batch_norm=True, dropout=0.5)
+
+    elif args.model_3d == "ProNet":
+        model = ProNet(
+            level=args.ProNet_level,
+            dropout=args.ProNet_dropout,
+            out_channels=num_class,
+            euler_noise=args.euler_noise,
+        )
+        graph_pred_linear = None
+
+    elif args.model_3d == "CDConv":
+        geometric_radii = [x * args.CDConv_radius for x in args.CDConv_geometric_raddi_coeff]
+        model = CD_Convolution(
+            geometric_radii=geometric_radii,
+            sequential_kernel_size=args.CDConv_kernel_size,
+            kernel_channels=args.CDConv_kernel_channels, channels=args.CDConv_channels, base_width=args.CDConv_base_width,
+            num_classes=num_class)
+        graph_pred_linear = None
+
+    else:
+        raise Exception("3D model {} not included.".format(args.model_3d))
+    return model, graph_pred_linear
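+
+# Dimension sketch (assumes concat_hidden=True as configured above): with six
+# 512-d hidden layers, model.output_dim = 6 * 512 = 3072, so graph_pred_linear
+# maps [batch_size, 3072] -> [batch_size, 538] EC logits via a 3-layer MLP.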
+
+
+def load_model(model, graph_pred_linear, model_weight_file):
+    print("Loading from {}".format(model_weight_file))
+    if "MoleculeSDE" in model_weight_file:
+        model_weight = torch.load(model_weight_file)
+        model.load_state_dict(model_weight["model_3D"])
+        if (graph_pred_linear is not None) and ("graph_pred_linear" in model_weight):
+            graph_pred_linear.load_state_dict(model_weight["graph_pred_linear"])
+    else:
+        model_weight = torch.load(model_weight_file)
+        model.load_state_dict(model_weight["model"])
+        if (graph_pred_linear is not None) and ("graph_pred_linear" in model_weight):
+            graph_pred_linear.load_state_dict(model_weight["graph_pred_linear"])
+    return
+
+
+def save_model(save_best):
+    if not args.output_model_dir == "":
+        if save_best:
+            print("save model with optimal loss")
+            output_model_path = os.path.join(args.output_model_dir, "model.pth")
+            saved_model_dict = {}
+            saved_model_dict["model"] = model.state_dict()
+            if graph_pred_linear is not None:
+                saved_model_dict["graph_pred_linear"] = graph_pred_linear.state_dict()
+            torch.save(saved_model_dict, output_model_path)
+        else:
+            print("save model in the last epoch")
+            output_model_path = os.path.join(args.output_model_dir, "model_final.pth")
+            saved_model_dict = {}
+            saved_model_dict["model"] = model.state_dict()
+            if graph_pred_linear is not None:
+                saved_model_dict["graph_pred_linear"] = graph_pred_linear.state_dict()
+            torch.save(saved_model_dict, output_model_path)
+    return
+
+
+def train(epoch, device, loader, optimizer):
+    model.train()
+    if graph_pred_linear is not None:
+        graph_pred_linear.train()
+
+    loss_acc = 0
+    num_iters = len(loader)
+
+    if args.verbose:
+        L = tqdm(loader)
+    else:
+        L = loader
+    for step, batch in enumerate(L):
+        if args.model_3d == "ProNet":
+            if args.mask:
+                # random mask node aatype
+                mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False))
+                batch.x[:, 0][mask_indice] = 25
+            if args.noise:
+                # add gaussian noise to atom coords
+                gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3)
+                batch.coords_ca += gaussian_noise
+                if args.ProNet_level != 'aminoacid':
+                    batch.coords_n += gaussian_noise
+                    batch.coords_c += gaussian_noise
+            if args.deform:
+                # Anisotropic scale
+                deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1)
+                batch.coords_ca *= deform
+                if args.ProNet_level != 'aminoacid':
+                    batch.coords_n *= deform
+                    batch.coords_c *= deform
+
+        batch = batch.to(device)
+
+        if args.model_3d == "GVP":
+            molecule_3D_repr = model(batch=batch)
+        elif args.model_3d in ["GearNet", "GearNet_IEConv"]:
+            molecule_3D_repr = model(batch, batch.node_feature.float())["graph_feature"]
+        elif args.model_3d == "ProNet":
+            molecule_3D_repr = model(batch)
+        elif args.model_3d == "CDConv":
+            molecule_3D_repr = model(batch, split="training")
+
+        if graph_pred_linear is not None:
+            pred = graph_pred_linear(molecule_3D_repr).squeeze(1)
+        else:
+            pred = molecule_3D_repr.squeeze(1)
+
+        y = batch.y
+        # print(y)
+        # y = torch.from_numpy(np.stack(y, axis=0)).to(device)
+        # print(y.shape)
+
+        loss = criterion(pred.sigmoid(), y)
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        loss_acc += loss.cpu().detach().item()
+
+        if args.lr_scheduler in ["CosineAnnealingWarmRestarts"]:
+            lr_scheduler.step(epoch - 1 + step / num_iters)
+
+    loss_acc /= len(loader)
+    if args.lr_scheduler in ["StepLR", "CosineAnnealingLR"]:
+        lr_scheduler.step()
+    elif args.lr_scheduler in ["ReduceLROnPlateau"]:
+        lr_scheduler.step(loss_acc)
+
+    return loss_acc
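+
+# Numerical-stability note (a suggestion, not this script's behavior): since
+# the heads emit logits, nn.BCEWithLogitsLoss with loss = criterion(pred, y)
+# would fold the sigmoid into the loss instead of applying it beforehand.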
+
+
+@torch.no_grad()
+def eval(device, loader):
+    model.eval()
+    if graph_pred_linear is not None:
+        graph_pred_linear.eval()
+    y_true = []
+    y_scores = []
+
+    if args.verbose:
+        L = tqdm(loader)
+    else:
+        L = loader
+    for batch in L:
+        if args.model_3d == "ProNet":
+            if args.mask:
+                # random mask node aatype
+                mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False))
+                batch.x[:, 0][mask_indice] = 25
+            if args.noise:
+                # add gaussian noise to atom coords
+                gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3)
+                batch.coords_ca += gaussian_noise
+                if args.ProNet_level != 'aminoacid':
+                    batch.coords_n += gaussian_noise
+                    batch.coords_c += gaussian_noise
+            if args.deform:
+                # Anisotropic scale
+                deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1)
+                batch.coords_ca *= deform
+                if args.ProNet_level != 'aminoacid':
+                    batch.coords_n *= deform
+                    batch.coords_c *= deform
+
+        batch = batch.to(device)
+
+        if args.model_3d == "GVP":
+            molecule_3D_repr = model(batch=batch)
+        elif args.model_3d in ["GearNet", "GearNet_IEConv"]:
+            molecule_3D_repr = model(batch, batch.node_feature.float())["graph_feature"]
+        elif args.model_3d == "ProNet":
+            molecule_3D_repr = model(batch)
+        elif args.model_3d == "CDConv":
+            molecule_3D_repr = model(batch)
+
+        if graph_pred_linear is not None:
+            pred = graph_pred_linear(molecule_3D_repr).squeeze()
+        else:
+            pred = molecule_3D_repr.squeeze()
+        pred = pred.sigmoid()
+
+        y = batch.y
+
+        y_true.append(y)
+        y_scores.append(pred)
+
+    y_true = torch.cat(y_true, dim=0).cpu().numpy()
+    y_scores = torch.cat(y_scores, dim=0).cpu().numpy()
+
+    return fmax(y_scores, y_true)
+
+
+if __name__ == "__main__":
+    torch.manual_seed(args.seed)
+    np.random.seed(args.seed)
+    device = (
+        torch.device("cuda:" + str(args.device))
+        if torch.cuda.is_available()
+        else torch.device("cpu")
+    )
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.seed)
+
+    data_root = args.data_root
+
+    dataset_class = DatasetECMultiple
+    if args.model_3d == "GearNet":
+        dataset_class = DatasetECMultipleGearNet
+
+    train_dataset = dataset_class(root=data_root, split='train')
+    valid_dataset = dataset_class(root=data_root, split='valid')
+    test_30_dataset = dataset_class(root=data_root, split='test', percent=0.3)
+    test_40_dataset = dataset_class(root=data_root, split='test', percent=0.4)
+    test_50_dataset = dataset_class(root=data_root, split='test', percent=0.5)
+    test_70_dataset = dataset_class(root=data_root, split='test', percent=0.7)
+    test_95_dataset = dataset_class(root=data_root, split='test', percent=0.95)
+
+    if args.model_3d == "GVP":
+        data_root = "../data/ECMultiple_GVP"
+        train_dataset = DatasetGVP(
+            root=data_root, dataset=train_dataset, split='train', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+        valid_dataset = DatasetGVP(
+            root=data_root, dataset=valid_dataset, split='valid', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+        test_30_dataset = DatasetGVP(
+            root=data_root, dataset=test_30_dataset, split='test_0.3', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+        test_40_dataset = DatasetGVP(
+            root=data_root, dataset=test_40_dataset, split='test_0.4', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+        test_50_dataset = DatasetGVP(
+            root=data_root, dataset=test_50_dataset, split='test_0.5', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+        test_70_dataset = DatasetGVP(
+            root=data_root, dataset=test_70_dataset, split='test_0.7', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+        test_95_dataset = DatasetGVP(
+            root=data_root, dataset=test_95_dataset, split='test_0.95', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+
+    criterion = nn.BCELoss()
+
+    DataLoaderClass = PyGDataLoader
+    dataloader_kwargs = {}
+    if args.model_3d in ["GearNet", "GearNet_IEConv"]:
+        dataloader_kwargs["collate_fn"] = DatasetECMultipleGearNet.collate_fn
+        DataLoaderClass = TorchDataLoader
+
+    train_loader = DataLoaderClass(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.num_workers,
+        drop_last=True,
+        **dataloader_kwargs
+    )
+    val_loader = DataLoaderClass(
+        valid_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        drop_last=True,
+        **dataloader_kwargs
+    )
+    test_30_loader = DataLoaderClass(
+        test_30_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        drop_last=True,
+        **dataloader_kwargs
+    )
+    test_40_loader = DataLoaderClass(
+        test_40_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        drop_last=True,
+        **dataloader_kwargs
+    )
+    test_50_loader = DataLoaderClass(
+        test_50_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        drop_last=True,
+        **dataloader_kwargs
+    )
+    test_70_loader = DataLoaderClass(
+        test_70_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        drop_last=True,
+        **dataloader_kwargs
+    )
+    test_95_loader = DataLoaderClass(
+        test_95_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        drop_last=True,
+        **dataloader_kwargs
+    )
+
+    model, graph_pred_linear = model_setup()
+
+    if args.input_model_file != "":
+        load_model(model, graph_pred_linear, args.input_model_file)
+    model.to(device)
+    print(model)
+    if graph_pred_linear is not None:
+        graph_pred_linear.to(device)
+        print(graph_pred_linear)
+
+    # set up optimizer
+    # different learning rate for different part of GNN
+    model_param_group = [{"params": model.parameters(), "lr": args.lr}]
+    if graph_pred_linear is not None:
+        model_param_group.append(
+            {"params": graph_pred_linear.parameters(), "lr": args.lr}
+        )
+    if args.optimizer == "Adam":
+        optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
+    elif args.optimizer == "SGD":
+        optimizer = optim.SGD(model_param_group, lr=args.lr, weight_decay=5e-4, momentum=0.9)
+
+    lr_scheduler = None
+    if args.lr_scheduler == "CosineAnnealingLR":
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, args.epochs
+        )
+        print("Apply lr scheduler CosineAnnealingLR")
+    elif args.lr_scheduler == "CosineAnnealingWarmRestarts":
+        lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
+            optimizer, args.epochs, eta_min=1e-4
+        )
+        print("Apply lr scheduler CosineAnnealingWarmRestarts")
+    elif args.lr_scheduler == "StepLR":
+        lr_scheduler = optim.lr_scheduler.StepLR(
+            optimizer, step_size=args.lr_decay_step_size, gamma=args.lr_decay_factor
+        )
+        print("Apply lr scheduler StepLR")
+    elif args.lr_scheduler == "ReduceLROnPlateau":
+        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer, factor=args.lr_decay_factor, patience=args.lr_decay_patience, min_lr=args.min_lr
+        )
+        print("Apply lr scheduler ReduceLROnPlateau")
+    elif args.lr_scheduler == "StepLRCustomized":
+        print("Will decay with {}, at epochs {}".format(args.lr_decay_factor, args.StepLRCustomized_scheduler))
+        print("Apply lr scheduler StepLR (customized)")
+    else:
+        print("lr scheduler {} is not included.".format(args.lr_scheduler))
+    global_learning_rate = args.lr
+
+    train_acc_list, val_acc_list = [], []
+    test_30_list, test_40_list, test_50_list, test_70_list, test_95_list = [], [], [], [], []
+    best_val_acc, best_val_idx = -1e10, 0
+    for epoch in range(1, args.epochs + 1):
+        start_time = time.time()
+        loss_acc = train(epoch, device, train_loader, optimizer)
+        print("Epoch: {}\nLoss: {}".format(epoch, loss_acc))
+
+        if epoch % args.print_every_epoch == 0:
+            if args.eval_train:
+                train_acc = eval(device, train_loader)
+            else:
+                train_acc = 0
+
+            val_acc = eval(device, val_loader)
+            test_30 = eval(device, test_30_loader)
+            test_40 = eval(device, test_40_loader)
+            test_50 = eval(device, test_50_loader)
+            test_70 = eval(device, test_70_loader)
+            test_95 = eval(device, test_95_loader)
+
+            train_acc_list.append(train_acc)
+            val_acc_list.append(val_acc)
+            test_30_list.append(test_30)
+            test_40_list.append(test_40)
+            test_50_list.append(test_50)
+            test_70_list.append(test_70)
+            test_95_list.append(test_95)
+
+            print(
+                "train: {:.6f}\tval: {:.6f}\ttest_30: {:.6f}\ttest_40: {:.6f}\ttest_50: {:.6f}\ttest_70: {:.6f}\ttest_95: {:.6f}".format(
+                    train_acc, val_acc, test_30, test_40, test_50, test_70, test_95
+                )
+            )
+
+            print(val_acc, best_val_acc)
+            if val_acc > best_val_acc:
+                best_val_acc = val_acc
+                best_val_idx = len(train_acc_list) - 1
+                if not args.output_model_dir == "":
+                    save_model(save_best=True)
+            print(val_acc, best_val_acc)
+
+        if args.lr_scheduler == "StepLRCustomized" and epoch in args.StepLRCustomized_scheduler:
+            print('Changing learning rate, from {} to {}'.format(global_learning_rate, global_learning_rate * args.lr_decay_factor))
+            global_learning_rate *= args.lr_decay_factor
+            for param_group in optimizer.param_groups:
+                param_group['lr'] = global_learning_rate
+        print("Took\t{}\n".format(time.time() - start_time))
+
+    print(
+        "best train: {:.6f}\tval: {:.6f}\ttest_30: {:.6f}\ttest_40: {:.6f}\ttest_50: {:.6f}\ttest_70: {:.6f}\ttest_95: {:.6f}".format(
+            train_acc_list[best_val_idx],
+            val_acc_list[best_val_idx],
+            test_30_list[best_val_idx],
+            test_40_list[best_val_idx],
+            test_50_list[best_val_idx],
+            test_70_list[best_val_idx],
+            test_95_list[best_val_idx]
+        )
+    )
+
+    save_model(save_best=False)
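+
+# Example run (hypothetical paths and hyper-parameters):
+#   python finetune_ECMultiple.py --model_3d ProNet --data_root ../data/ECMultiple \
+#       --batch_size 32 --epochs 100 --lr 5e-4 --lr_scheduler StepLR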
diff --git a/examples_3D/finetune_ECSingle.py b/examples_3D/finetune_ECSingle.py
new file mode 100644
index 0000000..9e0086b
--- /dev/null
+++ b/examples_3D/finetune_ECSingle.py
@@ -0,0 +1,407 @@
+import os
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch_geometric.loader import DataLoader as PyGDataLoader
+from torch_geometric.nn import global_max_pool, global_mean_pool
+from tqdm import tqdm
+from torch.utils.data import DataLoader as TorchDataLoader
+
+from config import args
+from Geom3D.datasets import DatasetECSingle, DatasetGVP, DatasetECSingleGearNet
+from Geom3D.models import ProNet, GearNetIEConv, MQAModel, CD_Convolution
+import Geom3D.models.GearNet_layer as GearNet_layer
+
+
+def model_setup():
+    num_class = 384
+
+    if args.model_3d == "GVP":
+        node_in_dim = (6, 3)
+        node_h_dim = (100, 16)
+        edge_in_dim = (32, 1)
+        edge_h_dim = (32, 1)
+        model = MQAModel(node_in_dim, node_h_dim, edge_in_dim, edge_h_dim, out_channels=num_class)
+        graph_pred_linear = None
+
+    elif args.model_3d == "GearNet":
+        input_dim = 21
+        model = GearNetIEConv(
+            input_dim=input_dim, embedding_dim=512, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=args.num_relation,
+            batch_norm=True, concat_hidden=True, short_cut=True, readout=args.GearNet_readout, layer_norm=True, dropout=args.GearNet_dropout,
+            edge_input_dim=args.GearNet_edge_input_dim, num_angle_bin=args.GearNet_num_angle_bin)
+
+        num_mlp_layer = 3
+        hidden_dims = [model.output_dim] * (num_mlp_layer - 1)
+        graph_pred_linear = GearNet_layer.MultiLayerPerceptron(
+            model.output_dim, hidden_dims + [num_class], batch_norm=True, dropout=0.5)
+
+    elif args.model_3d == "GearNet_IEConv":
+        input_dim = 21
+        model = GearNetIEConv(
+            input_dim=input_dim, embedding_dim=512, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=7,
+            batch_norm=True, concat_hidden=True, short_cut=True, readout="sum", layer_norm=True, dropout=0.2, use_ieconv=True)
+
+        num_mlp_layer = 3
+        hidden_dims = [model.output_dim] * (num_mlp_layer - 1)
+        graph_pred_linear = GearNet_layer.MultiLayerPerceptron(
+            model.output_dim, hidden_dims + [num_class], batch_norm=True, dropout=0.5)
+
+    elif args.model_3d == "ProNet":
+        model = ProNet(
+            level=args.ProNet_level,
+            dropout=args.ProNet_dropout,
+            out_channels=num_class,
+            euler_noise=args.euler_noise,
+        )
+        graph_pred_linear = None
+
+    elif args.model_3d == "CDConv":
+        geometric_radii = [x * args.CDConv_radius for x in args.CDConv_geometric_raddi_coeff]
+        model = CD_Convolution(
+            geometric_radii=geometric_radii,
+            sequential_kernel_size=args.CDConv_kernel_size,
+            kernel_channels=args.CDConv_kernel_channels, channels=args.CDConv_channels, base_width=args.CDConv_base_width,
+            num_classes=num_class)
+        graph_pred_linear = None
+
+    else:
+        raise Exception("3D model {} not included.".format(args.model_3d))
+    return model, graph_pred_linear
+
+
+def load_model(model, graph_pred_linear, model_weight_file):
+    print("Loading from {}".format(model_weight_file))
+    if "MoleculeSDE" in model_weight_file:
+        model_weight = torch.load(model_weight_file)
+        model.load_state_dict(model_weight["model_3D"])
+        if (graph_pred_linear is not None) and ("graph_pred_linear" in model_weight):
+            graph_pred_linear.load_state_dict(model_weight["graph_pred_linear"])
+    else:
+        model_weight = torch.load(model_weight_file)
+        model.load_state_dict(model_weight["model"])
+        if (graph_pred_linear is not None) and ("graph_pred_linear" in model_weight):
+            graph_pred_linear.load_state_dict(model_weight["graph_pred_linear"])
+    return
+
+
+def save_model(save_best):
+    if not args.output_model_dir == "":
+        if save_best:
+            print("save model with optimal loss")
+            output_model_path = os.path.join(args.output_model_dir, "model.pth")
+            saved_model_dict = {}
+            saved_model_dict["model"] = model.state_dict()
+            if graph_pred_linear is not None:
+                saved_model_dict["graph_pred_linear"] = graph_pred_linear.state_dict()
+            torch.save(saved_model_dict, output_model_path)
+        else:
+            print("save model in the last epoch")
+            output_model_path = os.path.join(args.output_model_dir, "model_final.pth")
+            saved_model_dict = {}
+            saved_model_dict["model"] = model.state_dict()
+            if graph_pred_linear is not None:
+                saved_model_dict["graph_pred_linear"] = graph_pred_linear.state_dict()
+            torch.save(saved_model_dict, output_model_path)
+    return
+
+
+def train(epoch, device, loader, optimizer):
+    model.train()
+    if graph_pred_linear is not None:
+        graph_pred_linear.train()
+
+    loss_acc = 0
+    num_iters = len(loader)
+
+    if args.verbose:
+        L = tqdm(loader)
+    else:
+        L = loader
+    for step, batch in enumerate(L):
+        if args.model_3d == "ProNet":
+            if args.mask:
+                # random mask node aatype
+                mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False))
+                batch.x[:, 0][mask_indice] = 25
+            if args.noise:
+                # add gaussian noise to atom coords
+                gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3)
+                batch.coords_ca += gaussian_noise
+                if args.ProNet_level != 'aminoacid':
+                    batch.coords_n += gaussian_noise
+                    batch.coords_c += gaussian_noise
+            if args.deform:
+                # Anisotropic scale
+                deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1)
+                batch.coords_ca *= deform
+                if args.ProNet_level != 'aminoacid':
+                    batch.coords_n *= deform
+                    batch.coords_c *= deform
+
+        batch = batch.to(device)
+
+        if args.model_3d == "GVP":
+            molecule_3D_repr = model(batch=batch)
+        elif args.model_3d in ["GearNet", "GearNet_IEConv"]:
+            molecule_3D_repr = model(batch, batch.node_feature.float())["graph_feature"]
+        elif args.model_3d == "ProNet":
+            molecule_3D_repr = model(batch)
+        elif args.model_3d == "CDConv":
+            molecule_3D_repr = model(batch, split="training")
+        elif args.model_3d == "FrameNetProtein":
+            molecule_3D_repr = model(batch.coords_n, batch.coords_ca, batch.coords_c, batch.seq, batch.batch)
+
+        if graph_pred_linear is not None:
+            pred = graph_pred_linear(molecule_3D_repr).squeeze(1)
+        else:
+            pred = molecule_3D_repr.squeeze(1)
+
+        y = batch.y
+
+        loss = criterion(pred, y)
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        loss_acc += loss.cpu().detach().item()
+
+        if args.lr_scheduler in ["CosineAnnealingWarmRestarts"]:
+            lr_scheduler.step(epoch - 1 + step / num_iters)
+
+    loss_acc /= len(loader)
+    if args.lr_scheduler in ["StepLR", "CosineAnnealingLR"]:
+        lr_scheduler.step()
+    elif args.lr_scheduler in ["ReduceLROnPlateau"]:
+        lr_scheduler.step(loss_acc)
+
+    return loss_acc
+
+
+@torch.no_grad()
+def eval(device, loader):
+    model.eval()
+    if graph_pred_linear is not None:
+        graph_pred_linear.eval()
+    y_true = []
+    y_scores = []
+
+    if args.verbose:
+        L = tqdm(loader)
+    else:
+        L = loader
+    for batch in L:
+        if args.model_3d == "ProNet":
+            if args.mask:
+                # random mask node aatype
+                mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False))
+                batch.x[:, 0][mask_indice] = 25
+            if args.noise:
+                # add gaussian noise to atom coords
+                gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3)
+                batch.coords_ca += gaussian_noise
+                if args.ProNet_level != 'aminoacid':
+                    batch.coords_n += gaussian_noise
+                    batch.coords_c += gaussian_noise
+            if args.deform:
+                # Anisotropic scale
+                deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1)
+                batch.coords_ca *= deform
+                if args.ProNet_level != 'aminoacid':
+                    batch.coords_n *= deform
+                    batch.coords_c *= deform
+
+        batch = batch.to(device)
+
+        if args.model_3d == "GVP":
+            molecule_3D_repr = model(batch=batch)
+        elif args.model_3d in ["GearNet", "GearNet_IEConv"]:
+            molecule_3D_repr = model(batch, batch.node_feature.float())["graph_feature"]
+        elif args.model_3d == "ProNet":
+            molecule_3D_repr = model(batch)
+        elif args.model_3d == "CDConv":
+            molecule_3D_repr = model(batch)
+        elif args.model_3d == "FrameNetProtein":
+            molecule_3D_repr = model(batch.coords_n, batch.coords_ca, batch.coords_c, batch.seq, batch.batch)
+
+        if graph_pred_linear is not None:
+            pred = graph_pred_linear(molecule_3D_repr).squeeze(1)
+        else:
+            pred = molecule_3D_repr.squeeze(1)
+        pred = pred.argmax(dim=-1)
+
+        y = batch.y
+
+        y_true.append(y)
+        y_scores.append(pred)
+
+    y_true = torch.cat(y_true, dim=0).cpu().numpy()
+    y_scores = torch.cat(y_scores, dim=0).cpu().numpy()
+
+    L = len(y_true)
+    acc = sum(y_true == y_scores) * 1. / L
+    return acc
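+
+# Equivalent vectorized form (illustrative): the computation above is top-1
+# accuracy over all evaluated proteins, i.e.
+#   acc = float((y_true == y_scores).mean())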
+
+
+if __name__ == "__main__":
+    torch.manual_seed(args.seed)
+    np.random.seed(args.seed)
+    device = (
+        torch.device("cuda:" + str(args.device))
+        if torch.cuda.is_available()
+        else torch.device("cpu")
+    )
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.seed)
+
+    data_root = args.data_root
+    dataset_class = DatasetECSingle
+    if args.model_3d == "GearNet":
+        dataset_class = DatasetECSingleGearNet
+
+    train_dataset = dataset_class(root=data_root, split='Train')
+    valid_dataset = dataset_class(root=data_root, split='Val')
+    test_dataset = dataset_class(root=data_root, split='Test')
+
+    if args.model_3d == "GVP":
+        data_root = "../data/ECSingle_GVP"
+        train_dataset = DatasetGVP(
+            data_root, train_dataset, split='Train', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+        valid_dataset = DatasetGVP(
+            data_root, valid_dataset, split='Val', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+        test_dataset = DatasetGVP(
+            data_root, test_dataset, split='Test', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+
+    criterion = nn.CrossEntropyLoss()
+
+    DataLoaderClass = PyGDataLoader
+    dataloader_kwargs = {}
+
+    if args.model_3d in ["GearNet", "GearNet_IEConv"]:
+        dataloader_kwargs["collate_fn"] = DatasetECSingleGearNet.collate_fn
+        DataLoaderClass = TorchDataLoader
+
+    train_loader = DataLoaderClass(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.num_workers,
+        **dataloader_kwargs
+    )
+    val_loader = DataLoaderClass(
+        valid_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        **dataloader_kwargs
+    )
+    test_loader = DataLoaderClass(
+        test_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        **dataloader_kwargs
+    )
+
+    model, graph_pred_linear = model_setup()
+
+    if args.input_model_file != "":
+        load_model(model, graph_pred_linear, args.input_model_file)
+    model.to(device)
+    print(model)
+    if graph_pred_linear is not None:
+        graph_pred_linear.to(device)
+        print(graph_pred_linear)
+
+    # set up optimizer
+    # different learning rate for different part of GNN
+    model_param_group = [{"params": model.parameters(), "lr": args.lr}]
+    if graph_pred_linear is not None:
+        model_param_group.append(
+            {"params": graph_pred_linear.parameters(), "lr": args.lr}
+        )
+    if args.optimizer == "Adam":
+        optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
+    elif args.optimizer == "SGD":
+        optimizer = optim.SGD(model_param_group, lr=args.lr, weight_decay=5e-4, momentum=0.9)
+
+    lr_scheduler = None
+    if args.lr_scheduler == "CosineAnnealingLR":
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, args.epochs
+        )
+        print("Apply lr scheduler CosineAnnealingLR")
+    elif args.lr_scheduler == "CosineAnnealingWarmRestarts":
+        lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
+            optimizer, args.epochs, eta_min=1e-4
+        )
+        print("Apply lr scheduler CosineAnnealingWarmRestarts")
+    elif args.lr_scheduler == "StepLR":
+        lr_scheduler = optim.lr_scheduler.StepLR(
+            optimizer, step_size=args.lr_decay_step_size, gamma=args.lr_decay_factor
+        )
+        print("Apply lr scheduler StepLR")
+    elif args.lr_scheduler == "ReduceLROnPlateau":
+        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer, factor=args.lr_decay_factor, patience=args.lr_decay_patience, min_lr=args.min_lr
+        )
print("Apply lr scheduler ReduceLROnPlateau") + elif args.lr_scheduler == "StepLRCustomized": + print("Will decay with {}, at epochs {}".format(args.lr_decay_factor, args.StepLRCustomized_scheduler)) + print("Apply lr scheduler StepLR (customized)") + else: + print("lr scheduler {} is not included.".format(args.lr_scheduler)) + global_learning_rate = args.lr + + train_acc_list, val_acc_list, test_acc_list = [], [], [] + best_val_acc, best_val_idx = -1e10, 0 + for epoch in range(1, args.epochs + 1): + start_time = time.time() + loss_acc = train(epoch, device, train_loader, optimizer) + print("Epoch: {}\nLoss: {}".format(epoch, loss_acc)) + + if epoch % args.print_every_epoch == 0: + if args.eval_train: + train_acc, train_target, train_pred = eval(device, train_loader) + else: + train_acc = 0 + val_acc = eval(device, val_loader) + test_acc = eval(device, test_loader) + + train_acc_list.append(train_acc) + val_acc_list.append(val_acc) + test_acc_list.append(test_acc) + print( + "train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format( + train_acc, val_acc, test_acc + ) + ) + + if val_acc > best_val_acc: + best_val_acc = val_acc + best_val_idx = len(train_acc_list) - 1 + if not args.output_model_dir == "": + save_model(save_best=True) + + if args.lr_scheduler == "StepLRCustomized" and epoch in args.StepLRCustomized_scheduler: + print('ChanGINg learning rate, from {} to {}'.format(global_learning_rate, global_learning_rate * args.lr_decay_factor)), + global_learning_rate *= args.lr_decay_factor + for param_group in optimizer.param_groups: + param_group['lr'] = global_learning_rate + print("Took\t{}\n".format(time.time() - start_time)) + + print( + "best train: {:.6f}\tval: {:.6f}\ttest: {:.6f}".format( + train_acc_list[best_val_idx], + val_acc_list[best_val_idx], + test_acc_list[best_val_idx], + ) + ) + + save_model(save_best=False) diff --git a/examples_3D/finetune_FOLD.py b/examples_3D/finetune_FOLD.py new file mode 100644 index 0000000..c16d8b2 --- /dev/null +++ b/examples_3D/finetune_FOLD.py @@ -0,0 +1,436 @@ +import os +import time + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader as TorchDataLoader +from torch_geometric.loader import DataLoader as PyGDataLoader +from torch_geometric.nn import global_max_pool, global_mean_pool +from tqdm import tqdm + +from config import args +from Geom3D.datasets import DatasetFOLD, DatasetGVP, DatasetFOLDGearNet +from Geom3D.models import ProNet, GearNetIEConv, MQAModel, CD_Convolution +import Geom3D.models.GearNet_layer as GearNet_layer + + +def model_setup(): + num_class = 1195 + graph_pred_linear = None + + if args.model_3d == "GVP": + node_in_dim = (6, 3) + node_h_dim = (100, 16) + edge_in_dim = (32, 1) + edge_h_dim = (32, 1) + model = MQAModel(node_in_dim, node_h_dim, edge_in_dim, edge_h_dim, out_channels=num_class) + + elif args.model_3d == "GearNet": + input_dim = 21 + model = GearNetIEConv( + input_dim=input_dim, embedding_dim=512, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=args.num_relation, + batch_norm=True, concat_hidden=True, short_cut=True, readout=args.GearNet_readout, layer_norm=True, dropout=args.GearNet_dropout, + edge_input_dim=args.GearNet_edge_input_dim, num_angle_bin=args.GearNet_num_angle_bin) + + num_mlp_layer = 3 + hidden_dims = [model.output_dim] * (num_mlp_layer - 1) + graph_pred_linear = GearNet_layer.MultiLayerPerceptron( + model.output_dim, hidden_dims + [num_class], batch_norm=True, dropout=0.5) + + elif 
args.model_3d == "GearNet_IEConv": + input_dim = 21 + model = GearNetIEConv( + input_dim=input_dim, embedding_dim=512, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=7, + batch_norm=True, concat_hidden=True, short_cut=True, readout="sum", layer_norm=True, dropout=0.2, use_ieconv=True) + + num_mlp_layer = 3 + hidden_dims = [model.output_dim] * (num_mlp_layer - 1) + graph_pred_linear = GearNet_layer.MultiLayerPerceptron( + model.output_dim, hidden_dims + [num_class], batch_norm=True, dropout=0.5) + + elif args.model_3d == "ProNet": + model = ProNet( + level=args.ProNet_level, + dropout=args.ProNet_dropout, + out_channels=num_class, + euler_noise=args.euler_noise, + ) + + elif args.model_3d == "CDConv": + geometric_radii = [x * args.CDConv_radius for x in args.CDConv_geometric_raddi_coeff] + model = CD_Convolution( + geometric_radii=geometric_radii, + sequential_kernel_size=args.CDConv_kernel_size, + kernel_channels=args.CDConv_kernel_channels, channels=args.CDConv_channels, base_width=args.CDConv_base_width, + num_classes=num_class) + + else: + raise Exception("3D model {} not included.".format(args.model_3d)) + return model, graph_pred_linear + + +def load_model(model, graph_pred_linear, model_weight_file): + print("Loading from {}".format(model_weight_file)) + if "MoleculeSDE" in model_weight_file: + model_weight = torch.load(model_weight_file) + model.load_state_dict(model_weight["model_3D"]) + if (graph_pred_linear is not None) and ("graph_pred_linear" in model_weight): + graph_pred_linear.load_state_dict(model_weight["graph_pred_linear"]) + + else: + model_weight = torch.load(model_weight_file) + model.load_state_dict(model_weight["model"]) + if (graph_pred_linear is not None) and ("graph_pred_linear" in model_weight): + graph_pred_linear.load_state_dict(model_weight["graph_pred_linear"]) + return + + +def save_model(save_best): + if not args.output_model_dir == "": + if save_best: + print("save model with optimal loss") + output_model_path = os.path.join(args.output_model_dir, "model.pth") + saved_model_dict = {} + saved_model_dict["model"] = model.state_dict() + if graph_pred_linear is not None: + saved_model_dict["graph_pred_linear"] = graph_pred_linear.state_dict() + torch.save(saved_model_dict, output_model_path) + + else: + print("save model in the last epoch") + output_model_path = os.path.join(args.output_model_dir, "model_final.pth") + saved_model_dict = {} + saved_model_dict["model"] = model.state_dict() + if graph_pred_linear is not None: + saved_model_dict["graph_pred_linear"] = graph_pred_linear.state_dict() + torch.save(saved_model_dict, output_model_path) + return + + +def train(epoch, device, loader, optimizer): + model.train() + if graph_pred_linear is not None: + graph_pred_linear.train() + + loss_acc = 0 + num_iters = len(loader) + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for step, batch in enumerate(L): + if args.model_3d == "ProNet": + if args.mask: + # random mask node aatype + mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False)) + batch.x[:, 0][mask_indice] = 25 + if args.noise: + # add gaussian noise to atom coords + gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3) + batch.coords_ca += gaussian_noise + if args.ProNet_level != 'aminoacid': + batch.coords_n += gaussian_noise + batch.coords_c += gaussian_noise + if args.deform: + # Anisotropic scale + deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 
+
+
+def train(epoch, device, loader, optimizer):
+    model.train()
+    if graph_pred_linear is not None:
+        graph_pred_linear.train()
+
+    loss_acc = 0
+    num_iters = len(loader)
+
+    if args.verbose:
+        L = tqdm(loader)
+    else:
+        L = loader
+    for step, batch in enumerate(L):
+        if args.model_3d == "ProNet":
+            if args.mask:
+                # random mask node aatype
+                mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False))
+                batch.x[:, 0][mask_indice] = 25
+            if args.noise:
+                # add gaussian noise to atom coords
+                gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3)
+                batch.coords_ca += gaussian_noise
+                if args.ProNet_level != 'aminoacid':
+                    batch.coords_n += gaussian_noise
+                    batch.coords_c += gaussian_noise
+            if args.deform:
+                # Anisotropic scale
+                deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1)
+                batch.coords_ca *= deform
+                if args.ProNet_level != 'aminoacid':
+                    batch.coords_n *= deform
+                    batch.coords_c *= deform
+
+        batch = batch.to(device)
+
+        if args.model_3d == "GVP":
+            molecule_3D_repr = model(batch=batch)
+        elif args.model_3d in ["GearNet", "GearNet_IEConv"]:
+            molecule_3D_repr = model(batch, batch.node_feature.float())["graph_feature"]
+        elif args.model_3d == "ProNet":
+            molecule_3D_repr = model(batch)
+        elif args.model_3d == "CDConv":
+            molecule_3D_repr = model(batch, split="training")
+        elif args.model_3d == "FrameNetProtein":
+            molecule_3D_repr = model(batch.coords_n, batch.coords_ca, batch.coords_c, batch.seq, batch.batch, split="training")
+
+        if graph_pred_linear is not None:
+            pred = graph_pred_linear(molecule_3D_repr).squeeze(1)
+        else:
+            pred = molecule_3D_repr.squeeze(1)
+
+        y = batch.y
+
+        loss = criterion(pred, y)
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        loss_acc += loss.cpu().detach().item()
+
+        if args.lr_scheduler in ["CosineAnnealingWarmRestarts"]:
+            lr_scheduler.step(epoch - 1 + step / num_iters)
+
+    loss_acc /= len(loader)
+    if args.lr_scheduler in ["StepLR", "CosineAnnealingLR"]:
+        lr_scheduler.step()
+    elif args.lr_scheduler in ["ReduceLROnPlateau"]:
+        lr_scheduler.step(loss_acc)
+
+    return loss_acc
+
+
+@torch.no_grad()
+def eval(device, loader):
+    model.eval()
+    if graph_pred_linear is not None:
+        graph_pred_linear.eval()
+    y_true = []
+    y_scores = []
+
+    if args.verbose:
+        L = tqdm(loader)
+    else:
+        L = loader
+    for batch in L:
+        if args.model_3d == "ProNet":
+            if args.mask:
+                # random mask node aatype
+                mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False))
+                batch.x[:, 0][mask_indice] = 25
+            if args.noise:
+                # add gaussian noise to atom coords
+                gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3)
+                batch.coords_ca += gaussian_noise
+                if args.ProNet_level != 'aminoacid':
+                    batch.coords_n += gaussian_noise
+                    batch.coords_c += gaussian_noise
+            if args.deform:
+                # Anisotropic scale
+                deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1)
+                batch.coords_ca *= deform
+                if args.ProNet_level != 'aminoacid':
+                    batch.coords_n *= deform
+                    batch.coords_c *= deform
+
+        batch = batch.to(device)
+
+        if args.model_3d == "GVP":
+            molecule_3D_repr = model(batch=batch)
+        elif args.model_3d in ["GearNet", "GearNet_IEConv"]:
+            molecule_3D_repr = model(batch, batch.node_feature.float())["graph_feature"]
+        elif args.model_3d == "ProNet":
+            molecule_3D_repr = model(batch)
+        elif args.model_3d == "CDConv":
+            molecule_3D_repr = model(batch)
+        elif args.model_3d == "FrameNetProtein":
+            molecule_3D_repr = model(batch.coords_n, batch.coords_ca, batch.coords_c, batch.seq, batch.batch)
+
+        if graph_pred_linear is not None:
+            pred = graph_pred_linear(molecule_3D_repr).squeeze()
+        else:
+            pred = molecule_3D_repr.squeeze()
+        pred = pred.argmax(dim=-1)
+
+        y = batch.y
+
+        y_true.append(y)
+        y_scores.append(pred)
+
+    y_true = torch.cat(y_true, dim=0).cpu().numpy()
+    y_scores = torch.cat(y_scores, dim=0).cpu().numpy()
+
+    L = len(y_true)
+    acc = sum(y_true == y_scores) * 1. / L
+    return acc
+
+
+if __name__ == "__main__":
+    torch.manual_seed(args.seed)
+    np.random.seed(args.seed)
+    device = (
+        torch.device("cuda:" + str(args.device))
+        if torch.cuda.is_available()
+        else torch.device("cpu")
+    )
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(args.seed)
+
+    data_root = args.data_root
+
+    dataset_class = DatasetFOLD
+    if args.model_3d == "GearNet":
+        dataset_class = DatasetFOLDGearNet
+
+    train_dataset = dataset_class(root=data_root, split='training')
+    valid_dataset = dataset_class(root=data_root, split='validation')
+    test_fold_dataset = dataset_class(root=data_root, split='test_fold')
+    test_superfamily_dataset = dataset_class(root=data_root, split='test_superfamily')
+    test_family_dataset = dataset_class(root=data_root, split='test_family')
+
+    if args.model_3d == "GVP":
+        data_root = "../data/FOLD_GVP"
+        train_dataset = DatasetGVP(
+            root=data_root, dataset=train_dataset, split='training', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+        valid_dataset = DatasetGVP(
+            root=data_root, dataset=valid_dataset, split='validation', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+        test_fold_dataset = DatasetGVP(
+            root=data_root, dataset=test_fold_dataset, split='test_fold', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+        test_superfamily_dataset = DatasetGVP(
+            root=data_root, dataset=test_superfamily_dataset, split='test_superfamily', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+        test_family_dataset = DatasetGVP(
+            root=data_root, dataset=test_family_dataset, split='test_family', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf)
+
+    criterion = nn.CrossEntropyLoss()
+
+    DataLoaderClass = PyGDataLoader
+    dataloader_kwargs = {}
+    if args.model_3d in ["GearNet", "GearNet_IEConv"]:
+        dataloader_kwargs["collate_fn"] = DatasetFOLDGearNet.collate_fn
+        DataLoaderClass = TorchDataLoader
+
+    train_loader = DataLoaderClass(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.num_workers,
+        drop_last=True,
+        **dataloader_kwargs
+    )
+    val_loader = DataLoaderClass(
+        valid_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        drop_last=True,
+        **dataloader_kwargs
+    )
+    test_fold_loader = DataLoaderClass(
+        test_fold_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        drop_last=True,
+        **dataloader_kwargs
+    )
+    test_superfamily_loader = DataLoaderClass(
+        test_superfamily_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        drop_last=True,
+        **dataloader_kwargs
+    )
+    test_family_loader = DataLoaderClass(
+        test_family_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        drop_last=True,
+        **dataloader_kwargs
+    )
+
+    model, graph_pred_linear = model_setup()
+
+    if args.input_model_file != "":
+        load_model(model, graph_pred_linear, args.input_model_file)
+    model.to(device)
+    print(model)
+    if graph_pred_linear is not None:
+        graph_pred_linear.to(device)
+        print(graph_pred_linear)
+
+    # set up optimizer
+    # different learning rate for different part of GNN
+    model_param_group = [{"params": model.parameters(), "lr": args.lr}]
+    if graph_pred_linear is not None:
+        model_param_group.append(
+            {"params": graph_pred_linear.parameters(), "lr": args.lr}
+        )
+    if args.optimizer == "Adam":
+        optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
+    elif args.optimizer == "SGD":
+        optimizer = optim.SGD(model_param_group, lr=args.lr, weight_decay=5e-4, momentum=0.9)
+
+    lr_scheduler = None
+    if args.lr_scheduler == "CosineAnnealingLR":
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, args.epochs
+        )
+        print("Apply lr scheduler CosineAnnealingLR")
+    elif args.lr_scheduler == "CosineAnnealingWarmRestarts":
+        lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
+            optimizer, args.epochs, eta_min=1e-4
+        )
+        print("Apply lr scheduler CosineAnnealingWarmRestarts")
+    elif args.lr_scheduler == "StepLR":
+        lr_scheduler = optim.lr_scheduler.StepLR(
+            optimizer, step_size=args.lr_decay_step_size, gamma=args.lr_decay_factor
+        )
+        print("Apply lr scheduler StepLR")
+    elif args.lr_scheduler == "ReduceLROnPlateau":
+        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer, factor=args.lr_decay_factor, patience=args.lr_decay_patience, min_lr=args.min_lr
+        )
+        print("Apply lr scheduler ReduceLROnPlateau")
+    elif args.lr_scheduler == "StepLRCustomized":
+        print("Will decay with {}, at epochs {}".format(args.lr_decay_factor, args.StepLRCustomized_scheduler))
+        print("Apply lr scheduler StepLR (customized)")
+    else:
+        print("lr scheduler {} is not included.".format(args.lr_scheduler))
+    global_learning_rate = args.lr
+
+    train_acc_list, val_acc_list = [], []
+    test_acc_fold_list, test_acc_superfamily_list, test_acc_family_list = [], [], []
+    best_val_acc, best_val_idx = -1e10, 0
+    for epoch in range(1, args.epochs + 1):
+        start_time = time.time()
+        loss_acc = train(epoch, device, train_loader, optimizer)
+        print("Epoch: {}\nLoss: {}".format(epoch, loss_acc))
+
+        if epoch % args.print_every_epoch == 0:
+            if args.eval_train:
+                train_acc = eval(device, train_loader)
+            else:
+                train_acc = 0
+            val_acc = eval(device, val_loader)
+            test_fold_acc = eval(device, test_fold_loader)
+            test_superfamily_acc = eval(device, test_superfamily_loader)
+            test_family_acc = eval(device, test_family_loader)
+
+            train_acc_list.append(train_acc)
+            val_acc_list.append(val_acc)
+            test_acc_fold_list.append(test_fold_acc)
+            test_acc_superfamily_list.append(test_superfamily_acc)
+            test_acc_family_list.append(test_family_acc)
+            print(
+                "train: {:.6f}\tval: {:.6f}\ttest-fold: {:.6f}\ttest-superfamily: {:.6f}\ttest-family: {:.6f}".format(
+                    train_acc, val_acc, test_fold_acc, test_superfamily_acc, test_family_acc
+                )
+            )
+
+            if val_acc > best_val_acc:
+                best_val_acc = val_acc
+                best_val_idx = len(train_acc_list) - 1
+                if not args.output_model_dir == "":
+                    save_model(save_best=True)
+
+        if args.lr_scheduler == "StepLRCustomized" and epoch in args.StepLRCustomized_scheduler:
+            print('Changing learning rate, from {} to {}'.format(global_learning_rate, global_learning_rate * args.lr_decay_factor))
+            global_learning_rate *= args.lr_decay_factor
+            for param_group in optimizer.param_groups:
+                param_group['lr'] = global_learning_rate
+        print("Took\t{}\n".format(time.time() - start_time))
+
+    print(
+        "best train: {:.6f}\tval: {:.6f}\ttest-fold: {:.6f}\ttest-superfamily: {:.6f}\ttest-family: {:.6f}".format(
+            train_acc_list[best_val_idx],
+            val_acc_list[best_val_idx],
+            test_acc_fold_list[best_val_idx],
+            test_acc_superfamily_list[best_val_idx],
+            test_acc_family_list[best_val_idx],
+        )
+    )
+
+    save_model(save_best=False)
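+
+# Example run (hypothetical paths and hyper-parameters):
+#   python finetune_FOLD.py --model_3d CDConv --data_root ../data/FOLD \
+#       --CDConv_radius 4 --CDConv_kernel_size 21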
diff --git a/examples_3D/finetune_GO.py b/examples_3D/finetune_GO.py
new file mode 100644
index 0000000..ca2b0f4
--- /dev/null
+++ b/examples_3D/finetune_GO.py
@@ -0,0 +1,497 @@
+import os
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader as TorchDataLoader
+from torch_geometric.loader import DataLoader as PyGDataLoader
+from torch_geometric.nn import global_max_pool, global_mean_pool
+from tqdm import tqdm
+
+from config import args
+from Geom3D.datasets import DatasetGO, DatasetGVP, DatasetGOGearNet
+from Geom3D.models import ProNet, GearNetIEConv, MQAModel, CD_Convolution
+import Geom3D.models.GearNet_layer as GearNet_layer
+
+
+def fmax(probs, labels):
+    thresholds = np.arange(0, 1, 0.01)
+    f_max = 0.0
+
+    for threshold in thresholds:
+        precision = 0.0
+        recall = 0.0
+        precision_cnt = 0
+        recall_cnt = 0
+        for idx in range(probs.shape[0]):
+            prob = probs[idx]
+            label = labels[idx]
+            pred = (prob > threshold).astype(np.int32)
+            correct_sum = np.sum(label*pred)
+            pred_sum = np.sum(pred)
+            label_sum = np.sum(label)
+            if pred_sum > 0:
+                precision += correct_sum/pred_sum
+                precision_cnt += 1
+            if label_sum > 0:
+                recall += correct_sum/label_sum
+                recall_cnt += 1
+        if recall_cnt > 0:
+            recall = recall / recall_cnt
+        else:
+            recall = 0
+        if precision_cnt > 0:
+            precision = precision / precision_cnt
+        else:
+            precision = 0
+        f = (2.*precision*recall)/max(precision+recall, 1e-8)
+        f_max = max(f, f_max)
+
+    return f_max
+
+
+def model_setup(device):
+    if args.GO_level == "mf":
+        num_class = 489
+    elif args.GO_level == "bp":
+        num_class = 1943
+    elif args.GO_level == "cc":
+        num_class = 320
+
+    if args.model_3d == "GVP":
+        node_in_dim = (6, 3)
+        node_h_dim = (100, 16)
+        edge_in_dim = (32, 1)
+        edge_h_dim = (32, 1)
+        model = MQAModel(node_in_dim, node_h_dim, edge_in_dim, edge_h_dim, out_channels=num_class)
+        graph_pred_linear = None
+
+    elif args.model_3d == "GearNet":
+        input_dim = 21
+        model = GearNetIEConv(
+            input_dim=input_dim, embedding_dim=512, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=args.num_relation,
+            batch_norm=True, concat_hidden=True, short_cut=True, readout=args.GearNet_readout, layer_norm=True, dropout=args.GearNet_dropout,
+            edge_input_dim=args.GearNet_edge_input_dim, num_angle_bin=args.GearNet_num_angle_bin)
+
+        num_mlp_layer = 3
+        hidden_dims = [model.output_dim] * (num_mlp_layer - 1)
+        graph_pred_linear = GearNet_layer.MultiLayerPerceptron(
+            model.output_dim, hidden_dims + [num_class], batch_norm=True, dropout=0.5)
+
+    elif args.model_3d == "GearNet_IEConv":
+        input_dim = 21
+        model = GearNetIEConv(
+            input_dim=input_dim, embedding_dim=512, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=7,
+            batch_norm=True, concat_hidden=True, short_cut=True, readout="sum", layer_norm=True, dropout=0.2, use_ieconv=True)
+
+        num_mlp_layer = 3
+        hidden_dims = [model.output_dim] * (num_mlp_layer - 1)
+        graph_pred_linear = GearNet_layer.MultiLayerPerceptron(
+            model.output_dim, hidden_dims + [num_class], batch_norm=True, dropout=0.5)
+
+    elif args.model_3d == "ProNet":
+        model = ProNet(
+            level=args.ProNet_level,
+            dropout=args.ProNet_dropout,
+            out_channels=num_class,
+            euler_noise=args.euler_noise,
+        )
+        graph_pred_linear = None
+
+    elif args.model_3d == "CDConv":
+        geometric_radii = [x * args.CDConv_radius for x in args.CDConv_geometric_raddi_coeff]
+        model = CD_Convolution(
+            geometric_radii=geometric_radii,
+            sequential_kernel_size=args.CDConv_kernel_size,
+            kernel_channels=args.CDConv_kernel_channels, channels=args.CDConv_channels, base_width=args.CDConv_base_width,
+            num_classes=num_class)
+        graph_pred_linear = None
+
+    else:
+        raise Exception("3D model {} not included.".format(args.model_3d))
+    return model, graph_pred_linear
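+
+# The class counts above follow the standard GeneOntology term splits used by
+# GearNet/IEConv-style benchmarks (an assumption: mf = 489, bp = 1943,
+# cc = 320 GO terms). For example, with args.GO_level = "mf", model_setup
+# returns a 489-way multi-label classifier.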
channels=args.CDConv_channels, base_width=args.CDConv_base_width, + num_classes=num_class) + graph_pred_linear = None + + else: + raise Exception("3D model {} not included.".format(args.model_3d)) + return model, graph_pred_linear + + +def load_model(model, graph_pred_linear, model_weight_file): + print("Loading from {}".format(model_weight_file)) + if "MoleculeSDE" in model_weight_file: + model_weight = torch.load(model_weight_file) + model.load_state_dict(model_weight["model_3D"]) + if (graph_pred_linear is not None) and ("graph_pred_linear" in model_weight): + graph_pred_linear.load_state_dict(model_weight["graph_pred_linear"]) + + else: + model_weight = torch.load(model_weight_file) + model.load_state_dict(model_weight["model"]) + if (graph_pred_linear is not None) and ("graph_pred_linear" in model_weight): + graph_pred_linear.load_state_dict(model_weight["graph_pred_linear"]) + return + + +def save_model(save_best): + if not args.output_model_dir == "": + if save_best: + print("save model with optimal loss") + output_model_path = os.path.join(args.output_model_dir, "model.pth") + saved_model_dict = {} + saved_model_dict["model"] = model.state_dict() + if graph_pred_linear is not None: + saved_model_dict["graph_pred_linear"] = graph_pred_linear.state_dict() + torch.save(saved_model_dict, output_model_path) + + else: + print("save model in the last epoch") + output_model_path = os.path.join(args.output_model_dir, "model_final.pth") + saved_model_dict = {} + saved_model_dict["model"] = model.state_dict() + if graph_pred_linear is not None: + saved_model_dict["graph_pred_linear"] = graph_pred_linear.state_dict() + torch.save(saved_model_dict, output_model_path) + return + + +def train(epoch, device, loader, optimizer): + model.train() + if graph_pred_linear is not None: + graph_pred_linear.train() + + loss_acc = 0 + num_iters = len(loader) + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for step, batch in enumerate(L): + if args.model_3d == "ProNet": + if args.mask: + # random mask node aatype + mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False)) + batch.x[:, 0][mask_indice] = 25 + if args.noise: + # add gaussian noise to atom coords + gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3) + batch.coords_ca += gaussian_noise + if args.ProNet_level != 'aminoacid': + batch.coords_n += gaussian_noise + batch.coords_c += gaussian_noise + if args.deform: + # Anisotropic scale + deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1) + batch.coords_ca *= deform + if args.ProNet_level != 'aminoacid': + batch.coords_n *= deform + batch.coords_c *= deform + + batch = batch.to(device) + + if args.model_3d == "GVP": + molecule_3D_repr = model(batch=batch) + elif args.model_3d in ["GearNet", "GearNet_IEConv"]: + molecule_3D_repr = model(batch, batch.node_feature.float())["graph_feature"] + elif args.model_3d == "ProNet": + molecule_3D_repr = model(batch) + elif args.model_3d == "CDConv": + molecule_3D_repr = model(batch, split="training") + + if graph_pred_linear is not None: + pred = graph_pred_linear(molecule_3D_repr).squeeze(1) + else: + pred = molecule_3D_repr.squeeze(1) + + y = batch.y + + loss = criterion(pred.sigmoid(), y) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss_acc += loss.cpu().detach().item() + + if args.lr_scheduler in ["CosineAnnealingWarmRestarts"]: + lr_scheduler.step(epoch - 1 + step / num_iters) 
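# --- Aside: two notes on the training step above (an illustrative sketch, not part of this patch) ---
# 1) nn.BCELoss over pred.sigmoid() is correct but numerically less stable than
#    working on raw logits. A minimal equivalent using the fused loss would be:
#
#        criterion = nn.BCEWithLogitsLoss()
#        loss = criterion(pred, y.float())   # pred stays as raw logits; drop .sigmoid()
#
# 2) CosineAnnealingWarmRestarts.step() accepts a fractional epoch, which is why
#    lr_scheduler.step(epoch - 1 + step / num_iters) is called inside the batch
#    loop: the learning rate is annealed smoothly within each epoch rather than
#    once per epoch.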
+ + loss_acc /= len(loader) + if args.lr_scheduler in ["StepLR", "CosineAnnealingLR"]: + lr_scheduler.step() + elif args.lr_scheduler in [ "ReduceLROnPlateau"]: + lr_scheduler.step(loss_acc) + + return loss_acc + + +@torch.no_grad() +def eval(device, loader): + model.eval() + if graph_pred_linear is not None: + graph_pred_linear.eval() + y_true = [] + y_scores = [] + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for batch in L: + if args.model_3d == "ProNet": + if args.mask: + # random mask node aatype + mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False)) + batch.x[:, 0][mask_indice] = 25 + if args.noise: + # add gaussian noise to atom coords + gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3) + batch.coords_ca += gaussian_noise + if args.ProNet_level != 'aminoacid': + batch.coords_n += gaussian_noise + batch.coords_c += gaussian_noise + if args.deform: + # Anisotropic scale + deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1) + batch.coords_ca *= deform + if args.ProNet_level != 'aminoacid': + batch.coords_n *= deform + batch.coords_c *= deform + + batch = batch.to(device) + + if args.model_3d == "GVP": + molecule_3D_repr = model(batch=batch) + elif args.model_3d in ["GearNet", "GearNet_IEConv"]: + molecule_3D_repr = model(batch, batch.node_feature.float())["graph_feature"] + elif args.model_3d == "ProNet": + molecule_3D_repr = model(batch) + elif args.model_3d == "CDConv": + molecule_3D_repr = model(batch) + + if graph_pred_linear is not None: + pred = graph_pred_linear(molecule_3D_repr).squeeze() + else: + pred = molecule_3D_repr.squeeze() + pred = pred.sigmoid() + + y = batch.y + + y_true.append(y) + y_scores.append(pred) + + y_true = torch.cat(y_true, dim=0).cpu().numpy() + y_scores = torch.cat(y_scores, dim=0).cpu().numpy() + + return fmax(y_scores, y_true) + +if __name__ == "__main__": + torch.manual_seed(args.seed) + np.random.seed(args.seed) + device = ( + torch.device("cuda:" + str(args.device)) + if torch.cuda.is_available() + else torch.device("cpu") + ) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + data_root = args.data_root + + dataset_class = DatasetGO + if args.model_3d == "GearNet": + dataset_class = DatasetGOGearNet + + train_dataset = dataset_class(root=data_root, level=args.GO_level, split='train') + valid_dataset = dataset_class(root=data_root, level=args.GO_level, split='valid') + test_30_dataset = dataset_class(root=data_root, level=args.GO_level, split='test', percent=0.3) + test_40_dataset = dataset_class(root=data_root, level=args.GO_level, split='test', percent=0.4) + test_50_dataset = dataset_class(root=data_root, level=args.GO_level, split='test', percent=0.5) + test_70_dataset = dataset_class(root=data_root, level=args.GO_level, split='test', percent=0.7) + test_95_dataset = dataset_class(root=data_root, level=args.GO_level, split='test', percent=0.95) + + if args.model_3d == "GVP": + data_root = "../data/GO_GVP_" + args.GO_level + train_dataset = DatasetGVP( + root=data_root, dataset=train_dataset, split='train', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf) + valid_dataset = DatasetGVP( + root=data_root, dataset=valid_dataset, split='valid', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf) + test_30_dataset = DatasetGVP( + root=data_root, 
dataset=test_30_dataset, split='test_0.3', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf) + test_40_dataset = DatasetGVP( + root=data_root, dataset=test_40_dataset, split='test_0.4', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf) + test_50_dataset = DatasetGVP( + root=data_root, dataset=test_50_dataset, split='test_0.5', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf) + test_70_dataset = DatasetGVP( + root=data_root, dataset=test_70_dataset, split='test_0.7', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf) + test_95_dataset = DatasetGVP( + root=data_root, dataset=test_95_dataset, split='test_0.95', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf) + + criterion = nn.BCELoss() + + DataLoaderClass = PyGDataLoader + dataloader_kwargs = {} + if args.model_3d in ["GearNet", "GearNet_IEConv"]: + dataloader_kwargs["collate_fn"] = DatasetGOGearNet.collate_fn + DataLoaderClass = TorchDataLoader + + train_loader = DataLoaderClass( + train_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.num_workers, + **dataloader_kwargs + ) + val_loader = DataLoaderClass( + valid_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + **dataloader_kwargs + ) + test_30_loader = DataLoaderClass( + test_30_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + **dataloader_kwargs + ) + test_40_loader = DataLoaderClass( + test_40_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + **dataloader_kwargs + ) + test_50_loader = DataLoaderClass( + test_50_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + **dataloader_kwargs + ) + test_70_loader = DataLoaderClass( + test_70_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + **dataloader_kwargs + ) + test_95_loader = DataLoaderClass( + test_95_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + **dataloader_kwargs + ) + + model, graph_pred_linear = model_setup(device) + + if args.input_model_file != "": + load_model(model, graph_pred_linear, args.input_model_file) + model.to(device) + print(model) + if graph_pred_linear is not None: + graph_pred_linear.to(device) + print(graph_pred_linear) + + # set up optimizer + # different learning rate for different part of GNN + model_param_group = [{"params": model.parameters(), "lr": args.lr}] + if graph_pred_linear is not None: + model_param_group.append( + {"params": graph_pred_linear.parameters(), "lr": args.lr} + ) + if args.optimizer == "Adam": + optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay) + elif args.optimizer == "SGD": + optimizer = optim.SGD(model_param_group, lr=args.lr, weight_decay=5e-4, momentum=0.9) + + lr_scheduler = None + if args.lr_scheduler == "CosineAnnealingLR": + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, args.epochs + ) + print("Apply lr scheduler CosineAnnealingLR") + elif args.lr_scheduler == "CosineAnnealingWarmRestarts": + lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, args.epochs, eta_min=1e-4 + ) + print("Apply lr scheduler CosineAnnealingWarmRestarts") + elif args.lr_scheduler == "StepLR": + lr_scheduler = 
optim.lr_scheduler.StepLR( + optimizer, step_size=args.lr_decay_step_size, gamma=args.lr_decay_factor + ) + print("Apply lr scheduler StepLR") + elif args.lr_scheduler == "ReduceLROnPlateau": + lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau( + optimizer, factor=args.lr_decay_factor, patience=args.lr_decay_patience, min_lr=args.min_lr + ) + print("Apply lr scheduler ReduceLROnPlateau") + elif args.lr_scheduler == "StepLRCustomized": + print("Will decay with {}, at epochs {}".format(args.lr_decay_factor, args.StepLRCustomized_scheduler)) + print("Apply lr scheduler StepLR (customized)") + else: + print("lr scheduler {} is not included.".format(args.lr_scheduler)) + global_learning_rate = args.lr + + train_acc_list, val_acc_list = [], [] + test_30_list, test_40_list, test_50_list, test_70_list, test_95_list = [], [], [], [], [] + best_val_acc, best_val_idx = -1e10, 0 + for epoch in range(1, args.epochs + 1): + start_time = time.time() + loss_acc = train(epoch, device, train_loader, optimizer) + print("Epoch: {}\nLoss: {}".format(epoch, loss_acc)) + + if epoch % args.print_every_epoch == 0: + if args.eval_train: + train_acc, train_target, train_pred = eval(device, train_loader) + else: + train_acc = 0 + + val_acc = eval(device, val_loader) + test_30 = eval(device, test_30_loader) + test_40 = eval(device, test_40_loader) + test_50 = eval(device, test_50_loader) + test_70 = eval(device, test_70_loader) + test_95 = eval(device, test_95_loader) + + + train_acc_list.append(train_acc) + val_acc_list.append(val_acc) + test_30_list.append(test_30) + test_40_list.append(test_40) + test_50_list.append(test_50) + test_70_list.append(test_70) + test_95_list.append(test_95) + + print( + "train: {:.6f}\tval: {:.6f}\ttest_30: {:.6f}\ttest_40: {:.6f}\ttest_50: {:.6f}\ttest_70: {:.6f}\ttest_95: {:.6f}".format( + train_acc, val_acc, test_30, test_40, test_50, test_70, test_95 + ) + ) + + print(val_acc, best_val_acc) + if val_acc > best_val_acc: + best_val_acc = val_acc + best_val_idx = len(train_acc_list) - 1 + if not args.output_model_dir == "": + save_model(save_best=True) + print(val_acc, best_val_acc) + + if args.lr_scheduler == "StepLRCustomized" and epoch in args.StepLRCustomized_scheduler: + print('Changing learning rate from {} to {}'.format(global_learning_rate, global_learning_rate * args.lr_decay_factor)) + global_learning_rate *= args.lr_decay_factor + for param_group in optimizer.param_groups: + param_group['lr'] = global_learning_rate + print("Took\t{}\n".format(time.time() - start_time)) + + print( + "best train: {:.6f}\tval: {:.6f}\ttest_30: {:.6f}\ttest_40: {:.6f}\ttest_50: {:.6f}\ttest_70: {:.6f}\ttest_95: {:.6f}".format( + train_acc_list[best_val_idx], + val_acc_list[best_val_idx], + test_30_list[best_val_idx], + test_40_list[best_val_idx], + test_50_list[best_val_idx], + test_70_list[best_val_idx], + test_95_list[best_val_idx] + ) + ) + + save_model(save_best=False) \ No newline at end of file From 56d6db947854d6fc240c6340c06074f480abaf24 Mon Sep 17 00:00:00 2001 From: YanjingLiLi Date: Mon, 23 Oct 2023 16:21:10 -0700 Subject: [PATCH 2/5] modify init --- Geom3D/datasets/__init__.py | 79 ++++++++++++++++++------------------- Geom3D/models/__init__.py | 56 +++++++++++++------------- 2 files changed, 67 insertions(+), 68 deletions(-) diff --git a/Geom3D/datasets/__init__.py b/Geom3D/datasets/__init__.py index e8ad773..72f7203 100644 --- a/Geom3D/datasets/__init__.py +++ b/Geom3D/datasets/__init__.py @@ -1,56 +1,56 @@ from Geom3D.datasets.dataset_utils import graph_data_obj_to_nx_simple, 
nx_to_graph_data_obj_simple, atom_type_count -# from Geom3D.datasets.dataset_GEOM import MoleculeDatasetGEOM -# from Geom3D.datasets.dataset_GEOM_Drugs import MoleculeDatasetGEOMDrugs, MoleculeDatasetGEOMDrugsTest -# from Geom3D.datasets.dataset_GEOM_QM9 import MoleculeDatasetGEOMQM9, MoleculeDatasetGEOMQM9Test +from Geom3D.datasets.dataset_GEOM import MoleculeDatasetGEOM +from Geom3D.datasets.dataset_GEOM_Drugs import MoleculeDatasetGEOMDrugs, MoleculeDatasetGEOMDrugsTest +from Geom3D.datasets.dataset_GEOM_QM9 import MoleculeDatasetGEOMQM9, MoleculeDatasetGEOMQM9Test -# from Geom3D.datasets.dataset_Molecule3D import Molecule3D +from Geom3D.datasets.dataset_Molecule3D import Molecule3D -# from Geom3D.datasets.dataset_PCQM4Mv2 import PCQM4Mv2 -# from Geom3D.datasets.dataset_PCQM4Mv2_3D_and_MMFF import PCQM4Mv2_3DandMMFF +from Geom3D.datasets.dataset_PCQM4Mv2 import PCQM4Mv2 +from Geom3D.datasets.dataset_PCQM4Mv2_3D_and_MMFF import PCQM4Mv2_3DandMMFF -# from Geom3D.datasets.dataset_QM9 import MoleculeDatasetQM9 -# from Geom3D.datasets.dataset_QM9_2D import MoleculeDatasetQM92D -# from Geom3D.datasets.dataset_QM9_Fingerprints_SMILES import MoleculeDatasetQM9FingerprintsSMILES -# from Geom3D.datasets.dataset_QM9_RDKit import MoleculeDatasetQM9RDKit -# from Geom3D.datasets.dataset_QM9_3D_and_MMFF import MoleculeDatasetQM9_3DandMMFF -# from Geom3D.datasets.dataset_QM9_2D_3D_Transformer import MoleculeDatasetQM9_2Dand3DTransformer +from Geom3D.datasets.dataset_QM9 import MoleculeDatasetQM9 +from Geom3D.datasets.dataset_QM9_2D import MoleculeDatasetQM92D +from Geom3D.datasets.dataset_QM9_Fingerprints_SMILES import MoleculeDatasetQM9FingerprintsSMILES +from Geom3D.datasets.dataset_QM9_RDKit import MoleculeDatasetQM9RDKit +from Geom3D.datasets.dataset_QM9_3D_and_MMFF import MoleculeDatasetQM9_3DandMMFF +from Geom3D.datasets.dataset_QM9_2D_3D_Transformer import MoleculeDatasetQM9_2Dand3DTransformer -# from Geom3D.datasets.dataset_COLL import DatasetCOLL -# from Geom3D.datasets.dataset_COLLRadius import DatasetCOLLRadius -# from Geom3D.datasets.dataset_COLLGemNet import DatasetCOLLGemNet +from Geom3D.datasets.dataset_COLL import DatasetCOLL +from Geom3D.datasets.dataset_COLLRadius import DatasetCOLLRadius +from Geom3D.datasets.dataset_COLLGemNet import DatasetCOLLGemNet -# from Geom3D.datasets.dataset_MD17 import DatasetMD17 -# from Geom3D.datasets.dataset_rMD17 import DatasetrMD17 +from Geom3D.datasets.dataset_MD17 import DatasetMD17 +from Geom3D.datasets.dataset_rMD17 import DatasetrMD17 -# from Geom3D.datasets.dataset_LBA import DatasetLBA, TransformLBA -# from Geom3D.datasets.dataset_LBARadius import DatasetLBARadius +from Geom3D.datasets.dataset_LBA import DatasetLBA, TransformLBA +from Geom3D.datasets.dataset_LBARadius import DatasetLBARadius -# from Geom3D.datasets.dataset_LEP import DatasetLEP, TransformLEP -# from Geom3D.datasets.dataset_LEPRadius import DatasetLEPRadius +from Geom3D.datasets.dataset_LEP import DatasetLEP, TransformLEP +from Geom3D.datasets.dataset_LEPRadius import DatasetLEPRadius -# from Geom3D.datasets.dataset_OC20 import DatasetOC20, is2re_data_transform, s2ef_data_transform +from Geom3D.datasets.dataset_OC20 import DatasetOC20, is2re_data_transform, s2ef_data_transform -# from Geom3D.datasets.dataset_MoleculeNet_2D import MoleculeNetDataset2D -# from Geom3D.datasets.dataset_MoleculeNet_3D import MoleculeNetDataset3D, MoleculeNetDataset2D_SDE3D +from Geom3D.datasets.dataset_MoleculeNet_2D import MoleculeNetDataset2D +from Geom3D.datasets.dataset_MoleculeNet_3D import 
MoleculeNetDataset3D, MoleculeNetDataset2D_SDE3D -# from Geom3D.datasets.dataset_QMOF import DatasetQMOF -# from Geom3D.datasets.dataset_MatBench import DatasetMatBench +from Geom3D.datasets.dataset_QMOF import DatasetQMOF +from Geom3D.datasets.dataset_MatBench import DatasetMatBench -# from Geom3D.datasets.dataset_3D import Molecule3DDataset -# from Geom3D.datasets.dataset_3D_Radius import MoleculeDataset3DRadius -# from Geom3D.datasets.dataset_3D_Remove_Center import MoleculeDataset3DRemoveCenter +from Geom3D.datasets.dataset_3D import Molecule3DDataset +from Geom3D.datasets.dataset_3D_Radius import MoleculeDataset3DRadius +from Geom3D.datasets.dataset_3D_Remove_Center import MoleculeDataset3DRemoveCenter -# # For Distance Prediction -# from Geom3D.datasets.dataset_3D_Full import MoleculeDataset3DFull +# For Distance Prediction +from Geom3D.datasets.dataset_3D_Full import MoleculeDataset3DFull -# # For Torsion Prediction -# from Geom3D.datasets.dataset_3D_TorsionAngle import MoleculeDataset3DTorsionAngle +# For Torsion Prediction +from Geom3D.datasets.dataset_3D_TorsionAngle import MoleculeDataset3DTorsionAngle -# from Geom3D.datasets.dataset_OneAtom import MoleculeDatasetOneAtom +from Geom3D.datasets.dataset_OneAtom import MoleculeDatasetOneAtom -# # For 2D N-Gram-Path -# from Geom3D.datasets.dataset_2D_Dense import MoleculeDataset2DDense +# For 2D N-Gram-Path +from Geom3D.datasets.dataset_2D_Dense import MoleculeDataset2DDense # For protein from Geom3D.datasets.dataset_FOLD import DatasetFOLD @@ -61,10 +61,9 @@ from Geom3D.datasets.dataset_GVP import DatasetGVP from Geom3D.datasets.dataset_GO_GearNet import DatasetGOGearNet from Geom3D.datasets.dataset_ECMultiple_GearNet import DatasetECMultipleGearNet -from Geom3D.datasets.dataset_MSP_GearNet import DatasetMSPGearNet from Geom3D.datasets.dataset_ECSingle_GearNet import DatasetECSingleGearNet # For 2D SSL -# from Geom3D.datasets.dataset_2D_Contextual import MoleculeContextualDataset -# from Geom3D.datasets.dataset_2D_GPT import MoleculeDatasetGPT -# from Geom3D.datasets.dataset_2D_GraphCL import MoleculeDataset_GraphCL \ No newline at end of file +from Geom3D.datasets.dataset_2D_Contextual import MoleculeContextualDataset +from Geom3D.datasets.dataset_2D_GPT import MoleculeDatasetGPT +from Geom3D.datasets.dataset_2D_GraphCL import MoleculeDataset_GraphCL \ No newline at end of file diff --git a/Geom3D/models/__init__.py b/Geom3D/models/__init__.py index 141bc5a..57274ec 100644 --- a/Geom3D/models/__init__.py +++ b/Geom3D/models/__init__.py @@ -3,41 +3,41 @@ import torch.nn.functional as F from torch_geometric.nn import GATConv, GCNConv -# from .AutoEncoder import AutoEncoder, VariationalAutoEncoder +from .AutoEncoder import AutoEncoder, VariationalAutoEncoder -# from .DimeNet import DimeNet -# from .DimeNetPlusPlus import DimeNetPlusPlus -# from .EGNN import EGNN -# from .PaiNN import PaiNN -# from .SchNet import SchNet -# from .SE3_Transformer import SE3Transformer -# from .SEGNN import SEGNNModel as SEGNN -# from .SphereNet import SphereNet -# from .SphereNet_periodic import SphereNetPeriodic -# from .TFN import TFN -# from .GemNet import GemNet -# from .ClofNet import ClofNet -# from .Graphormer import Graphormer -# from .TransformerM import TransformerM -# from .Equiformer import EquiformerEnergy, EquiformerEnergyForce, EquiformerEnergyPeriodic +from .DimeNet import DimeNet +from .DimeNetPlusPlus import DimeNetPlusPlus +from .EGNN import EGNN +from .PaiNN import PaiNN +from .SchNet import SchNet +from .SE3_Transformer import 
SE3Transformer +from .SEGNN import SEGNNModel as SEGNN +from .SphereNet import SphereNet +from .SphereNet_periodic import SphereNetPeriodic +from .TFN import TFN +from .GemNet import GemNet +from .ClofNet import ClofNet +from .Graphormer import Graphormer +from .TransformerM import TransformerM +from .Equiformer import EquiformerEnergy, EquiformerEnergyForce, EquiformerEnergyPeriodic from .GVP import GVP_GNN, MQAModel from .GearNet import GearNetIEConv from .ProNet import ProNet from .CDConv import CD_Convolution -# from .BERT import BertForSequenceRegression +from .BERT import BertForSequenceRegression -# from .GeoSSL_DDM import GeoSSL_DDM -# from .GeoSSL_PDM import GeoSSL_PDM +from .GeoSSL_DDM import GeoSSL_DDM +from .GeoSSL_PDM import GeoSSL_PDM -# from .molecule_gnn_model import GNN, GNN_graphpred -# from .molecule_gnn_model_simplified import GNNSimplified -# from .PNA import PNA -# from .ENN import ENN_S2S -# from .DMPNN import DMPNN -# from .GPS import GPSModel -# from .AWARE import AWARE +from .molecule_gnn_model import GNN, GNN_graphpred +from .molecule_gnn_model_simplified import GNNSimplified +from .PNA import PNA +from .ENN import ENN_S2S +from .DMPNN import DMPNN +from .GPS import GPSModel +from .AWARE import AWARE -# from .MLP import MLP -# from .CNN import CNN +from .MLP import MLP +from .CNN import CNN From 2a990269e5c3987dbdceec6f309614e6134cf73d Mon Sep 17 00:00:00 2001 From: YanjingLiLi Date: Tue, 4 Jun 2024 12:29:12 +0800 Subject: [PATCH 3/5] add MSP/PSR --- Geom3D/dataloaders/__init__.py | 4 +- Geom3D/dataloaders/dataloaders_MSP.py | 60 ++++ Geom3D/datasets/__init__.py | 2 + Geom3D/datasets/dataset_MSP.py | 316 +++++++++++++++++ Geom3D/datasets/dataset_PSR.py | 319 +++++++++++++++++ examples_3D/finetune_ECMultiple.py | 2 + examples_3D/finetune_MSP.py | 475 ++++++++++++++++++++++++++ examples_3D/finetune_PSR.py | 442 ++++++++++++++++++++++++ scripts/ECMultiple/submit_CDConv.sh | 35 ++ scripts/ECMultiple/submit_GVP.sh | 35 ++ scripts/ECMultiple/submit_GearNet.sh | 36 ++ scripts/ECMultiple/submit_ProNet.sh | 38 +++ scripts/ECSingle/submit_CDConv.sh | 35 ++ scripts/ECSingle/submit_GVP.sh | 35 ++ scripts/ECSingle/submit_GearNet.sh | 35 ++ scripts/FOLD/submit_CDConv.sh | 36 ++ scripts/FOLD/submit_GVP.sh | 35 ++ scripts/FOLD/submit_GearNet.sh | 35 ++ scripts/FOLD/submit_GearNet_IEConv.sh | 35 ++ scripts/GO/submit_CDConv_bp.sh | 36 ++ scripts/GO/submit_CDConv_cc.sh | 36 ++ scripts/GO/submit_CDConv_mf.sh | 36 ++ scripts/GO/submit_GVP_bp.sh | 36 ++ scripts/GO/submit_GVP_cc.sh | 36 ++ scripts/GO/submit_GVP_mf.sh | 36 ++ scripts/GO/submit_GearNet_bp.sh | 37 ++ scripts/GO/submit_GearNet_cc.sh | 37 ++ scripts/GO/submit_GearNet_mf.sh | 37 ++ scripts/GO/submit_ProNet_bp.sh | 39 +++ scripts/GO/submit_ProNet_cc.sh | 39 +++ scripts/GO/submit_ProNet_mf.sh | 39 +++ scripts/MSP/submit_CDConv.sh | 35 ++ scripts/MSP/submit_GVP.sh | 35 ++ scripts/MSP/submit_ProNet.sh | 38 +++ scripts/PSR/submit_CDConv.sh | 35 ++ scripts/PSR/submit_GVP.sh | 35 ++ scripts/PSR/submit_ProNet.sh | 38 +++ 37 files changed, 2669 insertions(+), 1 deletion(-) create mode 100644 Geom3D/dataloaders/dataloaders_MSP.py create mode 100644 Geom3D/datasets/dataset_MSP.py create mode 100644 Geom3D/datasets/dataset_PSR.py create mode 100644 examples_3D/finetune_MSP.py create mode 100644 examples_3D/finetune_PSR.py create mode 100644 scripts/ECMultiple/submit_CDConv.sh create mode 100644 scripts/ECMultiple/submit_GVP.sh create mode 100644 scripts/ECMultiple/submit_GearNet.sh create mode 100644 scripts/ECMultiple/submit_ProNet.sh 
create mode 100644 scripts/ECSingle/submit_CDConv.sh create mode 100644 scripts/ECSingle/submit_GVP.sh create mode 100644 scripts/ECSingle/submit_GearNet.sh create mode 100644 scripts/FOLD/submit_CDConv.sh create mode 100644 scripts/FOLD/submit_GVP.sh create mode 100644 scripts/FOLD/submit_GearNet.sh create mode 100644 scripts/FOLD/submit_GearNet_IEConv.sh create mode 100644 scripts/GO/submit_CDConv_bp.sh create mode 100644 scripts/GO/submit_CDConv_cc.sh create mode 100644 scripts/GO/submit_CDConv_mf.sh create mode 100644 scripts/GO/submit_GVP_bp.sh create mode 100644 scripts/GO/submit_GVP_cc.sh create mode 100644 scripts/GO/submit_GVP_mf.sh create mode 100644 scripts/GO/submit_GearNet_bp.sh create mode 100644 scripts/GO/submit_GearNet_cc.sh create mode 100644 scripts/GO/submit_GearNet_mf.sh create mode 100644 scripts/GO/submit_ProNet_bp.sh create mode 100644 scripts/GO/submit_ProNet_cc.sh create mode 100644 scripts/GO/submit_ProNet_mf.sh create mode 100644 scripts/MSP/submit_CDConv.sh create mode 100644 scripts/MSP/submit_GVP.sh create mode 100644 scripts/MSP/submit_ProNet.sh create mode 100644 scripts/PSR/submit_CDConv.sh create mode 100644 scripts/PSR/submit_GVP.sh create mode 100644 scripts/PSR/submit_ProNet.sh diff --git a/Geom3D/dataloaders/__init__.py b/Geom3D/dataloaders/__init__.py index e528c8e..e184a77 100644 --- a/Geom3D/dataloaders/__init__.py +++ b/Geom3D/dataloaders/__init__.py @@ -10,4 +10,6 @@ from Geom3D.dataloaders.dataloaders_AtomTuple import AtomTupleExtractor, DataLoaderAtomTuple -from Geom3D.dataloaders.dataloaders_PeriodicCrystal import DataLoaderPeriodicCrystal \ No newline at end of file +from Geom3D.dataloaders.dataloaders_PeriodicCrystal import DataLoaderPeriodicCrystal + +from Geom3D.dataloaders.dataloaders_MSP import DataLoaderMultiPro \ No newline at end of file diff --git a/Geom3D/dataloaders/dataloaders_MSP.py b/Geom3D/dataloaders/dataloaders_MSP.py new file mode 100644 index 0000000..6f08c43 --- /dev/null +++ b/Geom3D/dataloaders/dataloaders_MSP.py @@ -0,0 +1,60 @@ +import torch +from torch.utils.data import DataLoader +from torch_geometric.data import Data +import numpy as np + + +class BatchMultiPro(Data): + def __init__(self, **kwargs): + super(BatchMultiPro, self).__init__(**kwargs) + return + + @staticmethod + def from_data_list(data_list): + batch = BatchMultiPro() + + keys = [set(data.keys) for data in data_list] + keys = list(set.union(*keys)) + + for key in keys: + batch[key] = [] + + batch.batch_protein_1 = [] + batch.batch_protein_2 = [] + + for i, data in enumerate(data_list): + num_nodes_protein_1 = data.num_nodes_1 + num_nodes_protein_2 = data.num_nodes_2 + batch.batch_protein_1.append(torch.full((num_nodes_protein_1,), i, dtype=torch.long)) + batch.batch_protein_2.append(torch.full((num_nodes_protein_2,), i, dtype=torch.long)) + + for key in data.keys: + item = data[key] + batch[key].append(item) + + + for key in keys: + if key not in ["x_1", "x_2", "id"]: + batch[key] = torch.cat(batch[key], dim=data_list[0].__cat_dim__(key, batch[key][0])) + else: + batch[key] = np.array(batch[key]).flatten().tolist() + + batch.batch_protein_1 = torch.cat(batch.batch_protein_1, dim=-1) + batch.batch_protein_2 = torch.cat(batch.batch_protein_2, dim=-1) + + return batch.contiguous() + + @property + def num_graphs(self): + '''Returns the number of graphs in the batch.''' + return self.batch[-1].item() + 1 + + +class DataLoaderMultiPro(DataLoader): + def __init__(self, dataset, batch_size=1, shuffle=True, **kwargs): + super(DataLoaderMultiPro, self).__init__( 
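# DataLoaderMultiPro is a plain torch DataLoader whose collate_fn routes every
# list of paired-protein Data objects through BatchMultiPro.from_data_list,
# which concatenates the shared keys and builds the per-protein batch index
# vectors batch_protein_1 / batch_protein_2. A hedged usage sketch (the root
# path below is illustrative):
#
#     dataset = DatasetMSP(root="data/MSP", split="train")
#     loader = DataLoaderMultiPro(dataset, batch_size=8)
#     batch = next(iter(loader))
#     # batch.batch_protein_1[i] == index of the graph that node i of protein 1 belongs to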
+ dataset, + batch_size, + shuffle, + collate_fn=lambda data_list: BatchMultiPro.from_data_list(data_list), + **kwargs) \ No newline at end of file diff --git a/Geom3D/datasets/__init__.py b/Geom3D/datasets/__init__.py index 72f7203..ccf72b6 100644 --- a/Geom3D/datasets/__init__.py +++ b/Geom3D/datasets/__init__.py @@ -62,6 +62,8 @@ from Geom3D.datasets.dataset_GO_GearNet import DatasetGOGearNet from Geom3D.datasets.dataset_ECMultiple_GearNet import DatasetECMultipleGearNet from Geom3D.datasets.dataset_ECSingle_GearNet import DatasetECSingleGearNet +from Geom3D.datasets.dataset_MSP import DatasetMSP +from Geom3D.datasets.dataset_PSR import DatasetPSR # For 2D SSL from Geom3D.datasets.dataset_2D_Contextual import MoleculeContextualDataset diff --git a/Geom3D/datasets/dataset_MSP.py b/Geom3D/datasets/dataset_MSP.py new file mode 100644 index 0000000..db1da96 --- /dev/null +++ b/Geom3D/datasets/dataset_MSP.py @@ -0,0 +1,316 @@ +import os.path as osp +import numpy as np +import warnings +from tqdm import tqdm +from sklearn.preprocessing import normalize +import h5py +import lmdb +import pickle as pkl +import json +import msgpack +import pandas as pd +import scipy +import io +import gzip +import logging +from Bio.PDB.Polypeptide import three_to_one, is_aa +from tqdm import tqdm + +import torch, math +import torch.nn.functional as F +import torch_cluster + +from torch_geometric.data import Data +from torch_geometric.data import InMemoryDataset + + +class DatasetMSP(InMemoryDataset): + def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, split='train'): + self.split = split + self.root = root + self.device = "cuda" + self.index_columns = ['ensemble', 'subunit', 'structure', 'model', 'chain', 'residue'] + self.letter_to_num = {'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9, + 'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8, + 'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19, + 'N': 2, 'Y': 18, 'M': 12, "X":20} + + super(DatasetMSP, self).__init__( + root, transform, pre_transform, pre_filter) + + self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter + print(self.processed_paths) + self.data, self.slices = torch.load(self.processed_paths[0]) + print(self.slices) + + @property + def processed_dir(self): + name = 'processed_MSP' + return osp.join(self.root, name, self.split) + + @property + def raw_file_names(self): + name = self.split + '.txt' + return name + + @property + def processed_file_names(self): + return 'data.pt' + + def deserialize(self, x, serialization_format): + """ + Deserializes dataset `x` assuming format given by `serialization_format` (pkl, json, msgpack). + """ + if serialization_format == 'pkl': + return pkl.loads(x) + elif serialization_format == 'json': + serialized = json.loads(x) + elif serialization_format == 'msgpack': + serialized = msgpack.unpackb(x) + else: + raise RuntimeError('Invalid serialization format') + + return serialized + + def _normalize(self, tensor, dim=-1): + ''' + Normalizes a `torch.Tensor` along dimension `dim` without `nan`s. 
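Plain division by the norm yields `nan` rows for zero-length vectors; `torch.nan_to_num` maps those entries back to 0 so downstream angle features stay finite.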
+ ''' + return torch.nan_to_num( + torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))) + + def get_side_chain_angle_encoding(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h): + v1, v2, v3, v4, v5, v6, v7 = pos_ca - pos_n, pos_cb - pos_ca, pos_g - pos_cb, pos_d - pos_g, pos_e - pos_d, pos_z - pos_e, pos_h - pos_z + + # five side chain torsion angles + # We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0. + angle1 = torch.unsqueeze(self.diherals_ProNet(v1, v2, v3),1) + angle2 = torch.unsqueeze(self.diherals_ProNet(v2, v3, v4),1) + angle3 = torch.unsqueeze(self.diherals_ProNet(v3, v4, v5),1) + angle4 = torch.unsqueeze(self.diherals_ProNet(v4, v5, v6),1) + angle5 = torch.unsqueeze(self.diherals_ProNet(v5, v6, v7),1) + + side_chain_angles = torch.cat((angle1, angle2, angle3, angle4),1) + side_chain_embs = torch.cat((torch.sin(side_chain_angles), torch.cos(side_chain_angles)),1) + + return side_chain_embs + + def get_backbone_angle_encoding(self, X): + # X should be a num_residues x 3 x 3, order N, C-alpha, and C atoms of each residue + # N coords: X[:,0,:] + # CA coords: X[:,1,:] + # C coords: X[:,2,:] + # return num_residues x 6 + # From https://github.com/jingraham/neurips19-graph-protein-design + + X = torch.reshape(X, [3 * X.shape[0], 3]) + dX = X[1:] - X[:-1] + U = self._normalize(dX, dim=-1) + u0 = U[:-2] + u1 = U[1:-1] + u2 = U[2:] + + angle = self.diherals_ProNet(u0, u1, u2) + + # add phi[0], psi[-1], omega[-1] with value 0 + angle = F.pad(angle, [1, 2]) + angle = torch.reshape(angle, [-1, 3]) + angle_features = torch.cat([torch.cos(angle), torch.sin(angle)], 1) + return angle_features + + def diherals_ProNet(self, v1, v2, v3): + n1 = torch.cross(v1, v2) + n2 = torch.cross(v2, v3) + a = (n1 * n2).sum(dim=-1) + b = torch.nan_to_num((torch.cross(n1, n2) * v2).sum(dim=-1) / v2.norm(dim=1)) + torsion = torch.nan_to_num(torch.atan2(b, a)) + + return torsion + + def _three_to_one(self, residue): + try: + return three_to_one(residue) + except KeyError: + return "X" + + def parse_protein_df(self, protein_df): + atom_names, atom_pos, residue_type, atom_amino_id = [], [], [], [] + + residue_sum = 0 + processed = [] + for _, row in protein_df.iterrows(): + if is_aa(row["resname"]): + if (row["chain"], row["residue"]) not in processed: + processed.append((row["chain"], row["residue"])) + residue_df = protein_df[(protein_df["chain"] == row["chain"]) & (protein_df["residue"] == row["residue"])] + + if residue_df["fullname"].str.strip().isin(["N"]).any() and residue_df["fullname"].str.strip().isin(["CA"]).any() and residue_df["fullname"].str.strip().isin(["C"]).any(): + residue_type.append(self.letter_to_num[self._three_to_one(residue_df.iloc[0]["resname"])]) + for _, subrow in residue_df.iterrows(): + if isinstance(subrow["fullname"].strip(), str): + atom_names.append(subrow["fullname"].strip()) + atom_pos.append([subrow["x"], subrow["y"], subrow["z"]]) + atom_amino_id.append(residue_sum) + residue_sum += 1 + + return atom_names, np.array(atom_pos), residue_type, np.array(atom_amino_id) + + def get_key_atom_pos(self, amino_types, atom_names, atom_amino_id, atom_pos): + # atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1 + mask_n = np.char.equal(atom_names, 'N') + mask_ca = np.char.equal(atom_names, 'CA') + mask_c = np.char.equal(atom_names, 'C') + mask_cb = np.char.equal(atom_names, 'CB') + mask_g = 
np.char.equal(atom_names, 'CG') | np.char.equal(atom_names, 'SG') | np.char.equal(atom_names, 'OG') | np.char.equal(atom_names, 'CG1') | np.char.equal(atom_names, 'OG1') + mask_d = np.char.equal(atom_names, 'CD') | np.char.equal(atom_names, 'SD') | np.char.equal(atom_names, 'CD1') | np.char.equal(atom_names, 'OD1') | np.char.equal(atom_names, 'ND1') + mask_e = np.char.equal(atom_names, 'CE') | np.char.equal(atom_names, 'NE') | np.char.equal(atom_names, 'OE1') + mask_z = np.char.equal(atom_names, 'CZ') | np.char.equal(atom_names, 'NZ') + mask_h = np.char.equal(atom_names, 'NH1') + + pos_n = np.full((len(amino_types),3),np.nan) + pos_n[atom_amino_id[mask_n]] = atom_pos[mask_n] + pos_n = torch.FloatTensor(pos_n) + + pos_ca = np.full((len(amino_types),3),np.nan) + pos_ca[atom_amino_id[mask_ca]] = atom_pos[mask_ca] + pos_ca = torch.FloatTensor(pos_ca) + + pos_c = np.full((len(amino_types),3),np.nan) + pos_c[atom_amino_id[mask_c]] = atom_pos[mask_c] + pos_c = torch.FloatTensor(pos_c) + + # if data only contain pos_ca, we set the position of C and N as the position of CA + pos_n[torch.isnan(pos_n)] = pos_ca[torch.isnan(pos_n)] + pos_c[torch.isnan(pos_c)] = pos_ca[torch.isnan(pos_c)] + + pos_cb = np.full((len(amino_types),3),np.nan) + pos_cb[atom_amino_id[mask_cb]] = atom_pos[mask_cb] + pos_cb = torch.FloatTensor(pos_cb) + + pos_g = np.full((len(amino_types),3),np.nan) + pos_g[atom_amino_id[mask_g]] = atom_pos[mask_g] + pos_g = torch.FloatTensor(pos_g) + + pos_d = np.full((len(amino_types),3),np.nan) + pos_d[atom_amino_id[mask_d]] = atom_pos[mask_d] + pos_d = torch.FloatTensor(pos_d) + + pos_e = np.full((len(amino_types),3),np.nan) + pos_e[atom_amino_id[mask_e]] = atom_pos[mask_e] + pos_e = torch.FloatTensor(pos_e) + + pos_z = np.full((len(amino_types),3),np.nan) + pos_z[atom_amino_id[mask_z]] = atom_pos[mask_z] + pos_z = torch.FloatTensor(pos_z) + + pos_h = np.full((len(amino_types),3),np.nan) + pos_h[atom_amino_id[mask_h]] = atom_pos[mask_h] + pos_h = torch.FloatTensor(pos_h) + + return pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h + + def extract_protein_data(self, original_protein, mutated_protein): + data = Data() + + atom_names_1, atom_pos_1, residue_type_1, atom_amino_id_1 = self.parse_protein_df(original_protein) + pos_n_1, pos_ca_1, pos_c_1, pos_cb_1, pos_g_1, pos_d_1, pos_e_1, pos_z_1, pos_h_1 = self.get_key_atom_pos(residue_type_1, atom_names_1, atom_amino_id_1, atom_pos_1) + atom_names_2, atom_pos_2, residue_type_2, atom_amino_id_2 = self.parse_protein_df(mutated_protein) + pos_n_2, pos_ca_2, pos_c_2, pos_cb_2, pos_g_2, pos_d_2, pos_e_2, pos_z_2, pos_h_2 = self.get_key_atom_pos(residue_type_2, atom_names_2, atom_amino_id_2, atom_pos_2) + + # calculate side chain torsion angles, up to four + # do encoding + side_chain_angle_encoding_1 = self.get_side_chain_angle_encoding(pos_n_1, pos_ca_1, pos_c_1, pos_cb_1, pos_g_1, pos_d_1, pos_e_1, pos_z_1, pos_h_1) + side_chain_angle_encoding_1[torch.isnan(side_chain_angle_encoding_1)] = 0 + side_chain_angle_encoding_2 = self.get_side_chain_angle_encoding(pos_n_2, pos_ca_2, pos_c_2, pos_cb_2, pos_g_2, pos_d_2, pos_e_2, pos_z_2, pos_h_2) + side_chain_angle_encoding_2[torch.isnan(side_chain_angle_encoding_2)] = 0 + + # three backbone torsion angles + backbone_angle_encoding_1 = self.get_backbone_angle_encoding(torch.cat((torch.unsqueeze(pos_n_1,1), torch.unsqueeze(pos_ca_1,1), torch.unsqueeze(pos_c_1,1)),1)) + backbone_angle_encoding_1[torch.isnan(backbone_angle_encoding_1)] = 0 + backbone_angle_encoding_2 = 
self.get_backbone_angle_encoding(torch.cat((torch.unsqueeze(pos_n_2,1), torch.unsqueeze(pos_ca_2,1), torch.unsqueeze(pos_c_2,1)),1)) + backbone_angle_encoding_2[torch.isnan(backbone_angle_encoding_2)] = 0 + + data.seq_1 = torch.LongTensor(residue_type_1) + data.side_chain_angle_encoding_1 = side_chain_angle_encoding_1 + data.backbone_angle_encoding_1 = backbone_angle_encoding_1 + data.coords_ca_1 = pos_ca_1 + data.coords_n_1 = pos_n_1 + data.coords_c_1 = pos_c_1 + data.x_1 = atom_names_1 + data.pos_1 = torch.tensor(atom_pos_1) + data.num_nodes_1 = len(pos_ca_1) + + data.seq_2 = torch.LongTensor(residue_type_2) + data.side_chain_angle_encoding_2 = side_chain_angle_encoding_2 + data.backbone_angle_encoding_2 = backbone_angle_encoding_2 + data.coords_ca_2 = pos_ca_2 + data.coords_n_2 = pos_n_2 + data.coords_c_2 = pos_c_2 + data.x_2 = atom_names_2 + data.pos_2 = torch.tensor(atom_pos_2) + data.num_nodes_2 = len(pos_ca_2) + + return data + + def process(self): + print('Beginning Processing ...') + + data_list = [] + + env = lmdb.open(osp.join(self.root, self.split), max_readers=1, readonly=True, + lock=False, readahead=False, meminit=False) + + with env.begin(write=False) as txn: + self._num_examples = int(txn.get(b'num_examples')) + self._serialization_format = \ + txn.get(b'serialization_format').decode() + self._id_to_idx = self.deserialize( + txn.get(b'id_to_idx'), self._serialization_format) + + self._env = env + + for index in tqdm(range(self._num_examples), desc="all samples"): + print(index) + with self._env.begin(write=False) as txn: + compressed = txn.get(str(index).encode()) + buf = io.BytesIO(compressed) + with gzip.GzipFile(fileobj=buf, mode="rb") as f: + serialized = f.read() + try: + item = self.deserialize(serialized, self._serialization_format) + except: + return None + + # Recover special data types (currently only pandas dataframes). + if 'types' in item.keys(): + for x in item.keys(): + if (self._serialization_format=='json') and (item['types'][x] == str(pd.DataFrame)): + item[x] = pd.DataFrame(**item[x]) + else: + logging.warning('Data types in item %i not defined. 
Will use basic types only.'%index) + + if 'file_path' not in item: + item['file_path'] = str(self.data_file) + if 'id' not in item: + item['id'] = str(index) + + + original_protein = item["original_atoms"] + mutated_protein = item["mutated_atoms"] + + data = self.extract_protein_data(original_protein, mutated_protein) + data.y = int(item["label"]) + + if data.seq_1 is not None and data.seq_2 is not None: + data_list.append(data) + + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + print('Done!') + +# if __name__ == "__main__": +# for split in ["train", "val", "test"]: +# #for split in ['validation']: +# print('#### Now processing {} data ####'.format(split)) +# dataset = DatasetMSP(root="/lustre07/scratch/liusheng/atom3d_data/MSP/split-by-sequence-identity-30/data", split=split) \ No newline at end of file diff --git a/Geom3D/datasets/dataset_PSR.py b/Geom3D/datasets/dataset_PSR.py new file mode 100644 index 0000000..2042bc8 --- /dev/null +++ b/Geom3D/datasets/dataset_PSR.py @@ -0,0 +1,319 @@ +import os.path as osp +import numpy as np +import warnings +from tqdm import tqdm +from sklearn.preprocessing import normalize +import h5py +import lmdb +import pickle as pkl +import json +import msgpack +import pandas as pd +import scipy +import io +import gzip +import logging +from Bio.PDB.Polypeptide import three_to_one, is_aa +from tqdm import tqdm + +import torch, math +import torch.nn.functional as F +import torch_cluster + +from torch_geometric.data import Data +from torch_geometric.data import InMemoryDataset + + +class DatasetPSR(InMemoryDataset): + def __init__(self, root, transform=None, pre_transform=None, pre_filter=None, split='train'): + self.split = split + self.root = root + self.device = "cuda" + self.index_columns = ['ensemble', 'subunit', 'structure', 'model', 'chain', 'residue'] + self.letter_to_num = {'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9, + 'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8, + 'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19, + 'N': 2, 'Y': 18, 'M': 12, "X":20} + + super(DatasetPSR, self).__init__( + root, transform, pre_transform, pre_filter) + + self.transform, self.pre_transform, self.pre_filter = transform, pre_transform, pre_filter + self.data, self.slices = torch.load(self.processed_paths[0]) + + @property + def processed_dir(self): + name = 'processed_PSR_label' + return osp.join(self.root, name, self.split) + + @property + def raw_file_names(self): + name = self.split + '.txt' + return name + + @property + def processed_file_names(self): + return 'data.pt' + + def deserialize(self, x, serialization_format): + """ + Deserializes dataset `x` assuming format given by `serialization_format` (pkl, json, msgpack). + """ + if serialization_format == 'pkl': + return pkl.loads(x) + elif serialization_format == 'json': + serialized = json.loads(x) + elif serialization_format == 'msgpack': + serialized = msgpack.unpackb(x) + else: + raise RuntimeError('Invalid serialization format') + + return serialized + + def _normalize(self, tensor, dim=-1): + ''' + Normalizes a `torch.Tensor` along dimension `dim` without `nan`s. 
+ ''' + return torch.nan_to_num( + torch.div(tensor, torch.norm(tensor, dim=dim, keepdim=True))) + + def get_side_chain_angle_encoding(self, pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h): + v1, v2, v3, v4, v5, v6, v7 = pos_ca - pos_n, pos_cb - pos_ca, pos_g - pos_cb, pos_d - pos_g, pos_e - pos_d, pos_z - pos_e, pos_h - pos_z + + # five side chain torsion angles + # We only consider the first four torsion angles in side chains since only the amino acid arginine has five side chain torsion angles, and the fifth angle is close to 0. + angle1 = torch.unsqueeze(self.diherals_ProNet(v1, v2, v3),1) + angle2 = torch.unsqueeze(self.diherals_ProNet(v2, v3, v4),1) + angle3 = torch.unsqueeze(self.diherals_ProNet(v3, v4, v5),1) + angle4 = torch.unsqueeze(self.diherals_ProNet(v4, v5, v6),1) + angle5 = torch.unsqueeze(self.diherals_ProNet(v5, v6, v7),1) + + side_chain_angles = torch.cat((angle1, angle2, angle3, angle4),1) + side_chain_embs = torch.cat((torch.sin(side_chain_angles), torch.cos(side_chain_angles)),1) + + return side_chain_embs + + def get_backbone_angle_encoding(self, X): + # X should be a num_residues x 3 x 3, order N, C-alpha, and C atoms of each residue + # N coords: X[:,0,:] + # CA coords: X[:,1,:] + # C coords: X[:,2,:] + # return num_residues x 6 + # From https://github.com/jingraham/neurips19-graph-protein-design + + X = torch.reshape(X, [3 * X.shape[0], 3]) + dX = X[1:] - X[:-1] + U = self._normalize(dX, dim=-1) + u0 = U[:-2] + u1 = U[1:-1] + u2 = U[2:] + + angle = self.diherals_ProNet(u0, u1, u2) + + # add phi[0], psi[-1], omega[-1] with value 0 + angle = F.pad(angle, [1, 2]) + angle = torch.reshape(angle, [-1, 3]) + angle_features = torch.cat([torch.cos(angle), torch.sin(angle)], 1) + return angle_features + + def diherals_ProNet(self, v1, v2, v3): + n1 = torch.cross(v1, v2) + n2 = torch.cross(v2, v3) + a = (n1 * n2).sum(dim=-1) + b = torch.nan_to_num((torch.cross(n1, n2) * v2).sum(dim=-1) / v2.norm(dim=1)) + torsion = torch.nan_to_num(torch.atan2(b, a)) + + return torsion + + def _three_to_one(self, residue): + try: + return three_to_one(residue) + except KeyError: + return "X" + + def parse_protein_df(self, protein_df): + atom_names, atom_pos, residue_type, atom_amino_id = [], [], [], [] + all_residues = protein_df["residue"].unique() + + residue_num = 0 + invalid = False + for residue in all_residues: + residue_name = protein_df[protein_df["residue"] == residue]["resname"].iloc[0] + if is_aa(residue_name) or residue_name == "UNK": + residue_df = protein_df[protein_df["residue"] == residue] + if residue_df["fullname"].str.strip().isin(["N"]).any() and residue_df["fullname"].str.strip().isin(["CA"]).any() and residue_df["fullname"].str.strip().isin(["C"]).any(): + residue_id = self.letter_to_num[self._three_to_one(residue_name)] + residue_type.append(residue_id) + for index, row in residue_df.iterrows(): + atom_names.append(row["fullname"].strip()) + if [row["x"], row["y"], row["z"]] == [0., 0., 0.]: + invalid = True + atom_pos.append([row["x"], row["y"], row["z"]]) + atom_amino_id.append(residue_num) + + residue_num += 1 + + if invalid: + return None, None, None, None + + return atom_names, np.array(atom_pos), residue_type, np.array(atom_amino_id) + + def get_key_atom_pos(self, amino_types, atom_names, atom_amino_id, atom_pos): + # atoms to compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1 + mask_n = np.char.equal(atom_names, 'N') + mask_ca = np.char.equal(atom_names, 'CA') + mask_c = np.char.equal(atom_names, 'C') + mask_cb = 
np.char.equal(atom_names, 'CB') + mask_g = np.char.equal(atom_names, 'CG') | np.char.equal(atom_names, 'SG') | np.char.equal(atom_names, 'OG') | np.char.equal(atom_names, 'CG1') | np.char.equal(atom_names, 'OG1') + mask_d = np.char.equal(atom_names, 'CD') | np.char.equal(atom_names, 'SD') | np.char.equal(atom_names, 'CD1') | np.char.equal(atom_names, 'OD1') | np.char.equal(atom_names, 'ND1') + mask_e = np.char.equal(atom_names, 'CE') | np.char.equal(atom_names, 'NE') | np.char.equal(atom_names, 'OE1') + mask_z = np.char.equal(atom_names, 'CZ') | np.char.equal(atom_names, 'NZ') + mask_h = np.char.equal(atom_names, 'NH1') + mask_u = np.char.equal(atom_names, 'U') + + pos_n = np.full((len(amino_types),3),np.nan) + pos_n[atom_amino_id[mask_n]] = atom_pos[mask_n] + pos_n[atom_amino_id[mask_u]] = atom_pos[mask_u] + pos_n = torch.FloatTensor(pos_n) + + pos_ca = np.full((len(amino_types),3),np.nan) + pos_ca[atom_amino_id[mask_ca]] = atom_pos[mask_ca] + pos_ca[atom_amino_id[mask_u]] = atom_pos[mask_u] + pos_ca = torch.FloatTensor(pos_ca) + + pos_c = np.full((len(amino_types),3),np.nan) + pos_c[atom_amino_id[mask_c]] = atom_pos[mask_c] + pos_c[atom_amino_id[mask_u]] = atom_pos[mask_u] + pos_c = torch.FloatTensor(pos_c) + + # if data only contain pos_ca, we set the position of C and N as the position of CA + pos_n[torch.isnan(pos_n)] = pos_ca[torch.isnan(pos_n)] + pos_c[torch.isnan(pos_c)] = pos_ca[torch.isnan(pos_c)] + + pos_cb = np.full((len(amino_types),3),np.nan) + pos_cb[atom_amino_id[mask_cb]] = atom_pos[mask_cb] + pos_cb = torch.FloatTensor(pos_cb) + + pos_g = np.full((len(amino_types),3),np.nan) + pos_g[atom_amino_id[mask_g]] = atom_pos[mask_g] + pos_g = torch.FloatTensor(pos_g) + + pos_d = np.full((len(amino_types),3),np.nan) + pos_d[atom_amino_id[mask_d]] = atom_pos[mask_d] + pos_d = torch.FloatTensor(pos_d) + + pos_e = np.full((len(amino_types),3),np.nan) + pos_e[atom_amino_id[mask_e]] = atom_pos[mask_e] + pos_e = torch.FloatTensor(pos_e) + + pos_z = np.full((len(amino_types),3),np.nan) + pos_z[atom_amino_id[mask_z]] = atom_pos[mask_z] + pos_z = torch.FloatTensor(pos_z) + + pos_h = np.full((len(amino_types),3),np.nan) + pos_h[atom_amino_id[mask_h]] = atom_pos[mask_h] + pos_h = torch.FloatTensor(pos_h) + + return pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h + + def extract_protein_data(self, protein_df): + data = Data() + + atom_names, atom_pos, residue_types, atom_amino_id = self.parse_protein_df(protein_df) + + if atom_names == None: + return None + + ##### compute side chain torsion angles: N, CA, CB, _G/_G1, _D/_D1, _E/_E1, _Z, NH1 ##### + # extract key atom (e.g., backbone) positions + pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h = self.get_key_atom_pos(residue_types, atom_names, atom_amino_id, atom_pos) + + # calculate side chain torsion angles, up to four + # do encoding + side_chain_angle_encoding = self.get_side_chain_angle_encoding(pos_n, pos_ca, pos_c, pos_cb, pos_g, pos_d, pos_e, pos_z, pos_h) + side_chain_angle_encoding[torch.isnan(side_chain_angle_encoding)] = 0 + + # three backbone torsion angles + backbone_angle_encoding = self.get_backbone_angle_encoding(torch.cat((torch.unsqueeze(pos_n,1), torch.unsqueeze(pos_ca,1), torch.unsqueeze(pos_c,1)),1)) + backbone_angle_encoding[torch.isnan(backbone_angle_encoding)] = 0 + + data.seq = torch.LongTensor(residue_types) + data.side_chain_angle_encoding = side_chain_angle_encoding + data.backbone_angle_encoding = backbone_angle_encoding + data.coords_ca = pos_ca + data.coords_n = pos_n + 
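# The Data object assembled here is residue-level: seq holds residue types,
# the two encodings hold sin/cos torsion features, and coords_ca/n/c hold
# backbone coordinates, while x (atom names) and pos keep the full atom table.
# Setting num_nodes to len(pos_ca) makes PyG batch over residues, not atoms.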
data.coords_c = pos_c + data.x = atom_names + data.pos = torch.tensor(atom_pos) + data.num_nodes = len(pos_ca) + + return data + + def process(self): + print('Beginning Processing ...') + + data_list = [] + + env = lmdb.open(osp.join(self.root, self.split), max_readers=1, readonly=True, + lock=False, readahead=False, meminit=False) + + with env.begin(write=False) as txn: + self._num_examples = int(txn.get(b'num_examples')) + self._serialization_format = \ + txn.get(b'serialization_format').decode() + self._id_to_idx = self.deserialize( + txn.get(b'id_to_idx'), self._serialization_format) + + self._env = env + + for index in tqdm(range(self._num_examples), desc="all samples"): + with self._env.begin(write=False) as txn: + compressed = txn.get(str(index).encode()) + buf = io.BytesIO(compressed) + with gzip.GzipFile(fileobj=buf, mode="rb") as f: + serialized = f.read() + try: + item = self.deserialize(serialized, self._serialization_format) + except: + return None + + # Recover special data types (currently only pandas dataframes). + if 'types' in item.keys(): + for x in item.keys(): + if (self._serialization_format=='json') and (item['types'][x] == str(pd.DataFrame)): + item[x] = pd.DataFrame(**item[x]) + else: + logging.warning('Data types in item %i not defined. Will use basic types only.'%index) + + if 'file_path' not in item: + item['file_path'] = str(self.data_file) + if 'id' not in item: + item['id'] = str(index) + + protein = item["atoms"] + + data = self.extract_protein_data(protein) + + if data != None: + for i in range(len(data.coords_ca)): + for j in range(i+1, len(data.coords_ca)): + if torch.equal(data.coords_ca[i], data.coords_ca[j]): + data = None + break + if data == None: + break + + if data != None: + data.y = item["scores"]["gdt_ts"] + data.id = item['id'] + data_list.append(data) + + data, slices = self.collate(data_list) + torch.save((data, slices), self.processed_paths[0]) + print('Done!') + +if __name__ == "__main__": + for split in ["test"]: + #for split in ['validation']: + print('#### Now processing {} data ####'.format(split)) + dataset = DatasetPSR(root="/lustre07/scratch/liusheng/atom3d_data/PSR/split-by-year/data", split=split) \ No newline at end of file diff --git a/examples_3D/finetune_ECMultiple.py b/examples_3D/finetune_ECMultiple.py index 40ce5d9..6b9a12c 100644 --- a/examples_3D/finetune_ECMultiple.py +++ b/examples_3D/finetune_ECMultiple.py @@ -408,6 +408,8 @@ def eval(device, loader): optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay) elif args.optimizer == "SGD": optimizer = optim.SGD(model_param_group, lr=args.lr, weight_decay=5e-4, momentum=0.9) + elif args.optimizer == "AdamW": + optimizer = optim.AdamW(model_param_group, lr=args.lr, weight_decay=0) lr_scheduler = None if args.lr_scheduler == "CosineAnnealingLR": diff --git a/examples_3D/finetune_MSP.py b/examples_3D/finetune_MSP.py new file mode 100644 index 0000000..2d213f3 --- /dev/null +++ b/examples_3D/finetune_MSP.py @@ -0,0 +1,475 @@ +import os +import time +from sklearn.metrics import roc_auc_score + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader as TorchDataLoader +from torch_geometric.loader import DataLoader as PyGDataLoader +from torch_geometric.nn import global_max_pool, global_mean_pool +from tqdm import tqdm +from torch_geometric.data import Data + +from config import args +from Geom3D.datasets import DatasetGVP, DatasetMSP +from Geom3D.models 
import ProNet, MQAModel, GearNetIEConv, CD_Convolution, MLP +# MLP and GearNet_layer.MultiLayerPerceptron are used by model_setup below; imports mirror finetune_GO.py and Geom3D/models/__init__.py +import Geom3D.models.GearNet_layer as GearNet_layer +from dataset_loader import * + + + + +def model_setup(): + num_class = 1 + graph_pred_linear = None + + if args.model_3d == "GVP": + node_in_dim = (6, 3) + node_h_dim = (100, 16) + edge_in_dim = (32, 1) + edge_h_dim = (32, 1) + model = MQAModel(node_in_dim, node_h_dim, edge_in_dim, edge_h_dim) + + ns, _ = node_h_dim + ns *= 2 + drop_rate = 0.1 + graph_pred_linear = nn.Sequential( + nn.Linear(ns, 2*ns), nn.ReLU(inplace=True), + nn.Dropout(p=drop_rate), + nn.Linear(2*ns, num_class), + ) + + elif args.model_3d == "GearNet": + input_dim = 21 + model = GearNetIEConv( + input_dim=input_dim, embedding_dim=512, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=7, + batch_norm=True, concat_hidden=True, short_cut=True, readout="sum", layer_norm=True, dropout=0.2) + + num_mlp_layer = 3 + hidden_dims = [model.output_dim] * (num_mlp_layer - 1) + graph_pred_linear = GearNet_layer.MultiLayerPerceptron( + model.output_dim, hidden_dims + [num_class], batch_norm=True, dropout=0.5) + + elif args.model_3d == "GearNet_IEConv": + input_dim = 21 + model = GearNetIEConv( + input_dim=input_dim, embedding_dim=512, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=7, + batch_norm=True, concat_hidden=True, short_cut=True, readout="sum", layer_norm=True, dropout=0.2, use_ieconv=True) + + num_mlp_layer = 3 + hidden_dims = [model.output_dim] * (num_mlp_layer - 1) + graph_pred_linear = GearNet_layer.MultiLayerPerceptron( + model.output_dim, hidden_dims + [num_class], batch_norm=True, dropout=0.5) + + elif args.model_3d == "ProNet": + model = ProNet( + level=args.ProNet_level, + dropout=args.ProNet_dropout, + out_channels=num_class, + euler_noise=args.euler_noise, + ) + + graph_pred_linear = torch.nn.Sequential() + out_layers = 2 + hidden_channels=128 * 2 + out_channels=1 + dropout_rate = 0 + + for _ in range(out_layers-1): + graph_pred_linear.add_module("linear", nn.Linear(hidden_channels, hidden_channels)) + graph_pred_linear.add_module("relu", nn.ReLU()) + graph_pred_linear.add_module("dropout", nn.Dropout(dropout_rate)) + graph_pred_linear.add_module("output", nn.Linear(hidden_channels, out_channels)) + + elif args.model_3d == "CDConv": + geometric_radii = [x * args.CDConv_radius for x in args.CDConv_geometric_raddi_coeff] + model = CD_Convolution( + geometric_radii=geometric_radii, + sequential_kernel_size=args.CDConv_kernel_size, + kernel_channels=args.CDConv_kernel_channels, channels=args.CDConv_channels, base_width=args.CDConv_base_width, + num_classes=num_class) + + graph_pred_linear = MLP(in_channels=args.CDConv_channels[-1] * 2, + mid_channels=max(args.CDConv_channels[-1], num_class), + out_channels=num_class, + batch_norm=True, + dropout=0.2) + + else: + raise Exception("3D model {} not included.".format(args.model_3d)) + return model, graph_pred_linear + + +def load_model(model, graph_pred_linear, model_weight_file): + geometric_radii = [x * args.CDConv_radius for x in args.CDConv_geometric_raddi_coeff] + original_model = CD_Convolution( + geometric_radii=geometric_radii, + sequential_kernel_size=args.CDConv_kernel_size, + kernel_channels=args.CDConv_kernel_channels, channels=args.CDConv_channels, base_width=args.CDConv_base_width, + num_classes=1195) + + print("Loading from {}".format(model_weight_file)) + if "MoleculeSDE" in model_weight_file: + model_weight = torch.load(model_weight_file) + model.load_state_dict(model_weight["model_3D"]) + if (graph_pred_linear is not None) and ("graph_pred_linear" in model_weight): + 
graph_pred_linear.load_state_dict(model_weight["graph_pred_linear"]) + else: + model_weight = torch.load(model_weight_file) + original_model.load_state_dict(model_weight["model"]) + if (graph_pred_linear is not None) and ("graph_pred_linear" in model_weight): + graph_pred_linear.load_state_dict(model_weight["graph_pred_linear"]) + for name, param in original_model.named_parameters(): + if "classifier" not in name: + model.state_dict()[name].copy_(param) + return + + + +def save_model(save_best): + if not args.output_model_dir == "": + if save_best: + print("save model with optimal loss") + output_model_path = os.path.join(args.output_model_dir, "model.pth") + saved_model_dict = {} + saved_model_dict["model"] = model.state_dict() + if graph_pred_linear is not None: + saved_model_dict["graph_pred_linear"] = graph_pred_linear.state_dict() + torch.save(saved_model_dict, output_model_path) + + else: + print("save model in the last epoch") + output_model_path = os.path.join(args.output_model_dir, "model_final.pth") + saved_model_dict = {} + saved_model_dict["model"] = model.state_dict() + if graph_pred_linear is not None: + saved_model_dict["graph_pred_linear"] = graph_pred_linear.state_dict() + torch.save(saved_model_dict, output_model_path) + return + + +def train(epoch, device, loader, optimizer): + model.train() + if graph_pred_linear is not None: + graph_pred_linear.train() + + loss_acc = 0 + num_iters = len(loader) + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for step, batch in enumerate(L): + batch = batch.to(device) + # sub_batch_1, sub_batch_2 = Data().to(device), Data().to(device) + # sub_batch_1.seq, sub_batch_1.side_chain_angle_encoding, sub_batch_1.backbone_angle_encoding, sub_batch_1.coords_ca, sub_batch_1.coords_n, sub_batch_1.coords_c, sub_batch_1.x, sub_batch_1.pos, sub_batch_1.num_nodes, sub_batch_1.edge_index, sub_batch_1.node_s, sub_batch_1.node_v, sub_batch_1.edge_s, sub_batch_1.edge_v = batch.seq_1, batch.side_chain_angle_encoding_1, batch.backbone_angle_encoding_1, batch.coords_ca_1, batch.coords_n_1, batch.coords_c_1, batch.x_1, batch.pos_1, batch.num_nodes_1, batch.edge_index_1, batch.node_s_1, batch.node_v_1, batch.edge_s_1, batch.edge_v_1 + # sub_batch_2.seq, sub_batch_2.side_chain_angle_encoding, sub_batch_2.backbone_angle_encoding, sub_batch_2.coords_ca, sub_batch_2.coords_n, sub_batch_2.coords_c, sub_batch_2.x, sub_batch_2.pos, sub_batch_2.num_nodes, sub_batch_2.edge_index, sub_batch_2.node_s, sub_batch_2.node_v, sub_batch_2.edge_s, sub_batch_2.edge_v = batch.seq_2, batch.side_chain_angle_encoding_2, batch.backbone_angle_encoding_2, batch.coords_ca_2, batch.coords_n_2, batch.coords_c_2, batch.x_2, batch.pos_2, batch.num_nodes_2, batch.edge_index_2, batch.node_s_2, batch.node_v_2, batch.edge_s_2, batch.edge_v_2 + + if args.model_3d == "ProNet": + if args.mask: + # random mask node aatype + mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False)) + batch.x[:, 0][mask_indice] = 25 + if args.noise: + # add gaussian noise to atom coords + gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3) + batch.coords_ca += gaussian_noise + if args.ProNet_level != 'aminoacid': + batch.coords_n += gaussian_noise + batch.coords_c += gaussian_noise + if args.deform: + # Anisotropic scale + deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1) + batch.coords_ca *= deform + if args.ProNet_level != 'aminoacid': + batch.coords_n 
*= deform + batch.coords_c *= deform + + if args.model_3d == "GVP": + molecule_3D_repr_1 = model(batch.node_s_1, batch.node_v_1, batch.edge_s_1, batch.edge_v_1, batch.edge_index_1, batch.batch_protein_1, get_repr=True) + molecule_3D_repr_2 = model(batch.node_s_2, batch.node_v_2, batch.edge_s_2, batch.edge_v_2, batch.edge_index_2, batch.batch_protein_2, get_repr=True) + molecule_3D_repr = torch.cat((molecule_3D_repr_1, molecule_3D_repr_2), dim=1) + elif args.model_3d in ["GearNet", "GearNet_IEConv"]: + molecule_3D_repr = model(batch, batch.node_feature.float())["graph_feature"] + elif args.model_3d == "ProNet": + molecule_3D_repr_1 = model(batch.seq_1, batch.coords_n_1, batch.coords_ca_1, batch.coords_c_1, batch.side_chain_angle_encoding_1, batch.backbone_angle_encoding_1, batch.batch_protein_1, get_repr=True) + molecule_3D_repr_2 = model(batch.seq_2, batch.coords_n_2, batch.coords_ca_2, batch.coords_c_2, batch.side_chain_angle_encoding_2, batch.backbone_angle_encoding_2, batch.batch_protein_2, get_repr=True) + molecule_3D_repr = torch.cat((molecule_3D_repr_1, molecule_3D_repr_2), dim=1) + elif args.model_3d == "CDConv": + molecule_3D_repr_1 = model(batch.seq_1, batch.coords_ca_1, batch.batch_protein_1, split="training", get_repr=True) + molecule_3D_repr_2 = model(batch.seq_2, batch.coords_ca_2, batch.batch_protein_2, split="training", get_repr=True) + molecule_3D_repr = torch.cat((molecule_3D_repr_1, molecule_3D_repr_2), dim=1) + + if graph_pred_linear is not None: + pred = graph_pred_linear(molecule_3D_repr).squeeze().sigmoid() + else: + pred = molecule_3D_repr.squeeze().sigmoid() + + y = batch.y + + loss = criterion(pred.float(), y.float()) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss_acc += loss.cpu().detach().item() + + if args.lr_scheduler in ["CosineAnnealingWarmRestarts"]: + lr_scheduler.step(epoch - 1 + step / num_iters) + + loss_acc /= len(loader) + if args.lr_scheduler in ["StepLR", "CosineAnnealingLR"]: + lr_scheduler.step() + elif args.lr_scheduler in [ "ReduceLROnPlateau"]: + lr_scheduler.step(loss_acc) + + return loss_acc + + +@torch.no_grad() +def eval(device, loader): + model.eval() + if graph_pred_linear is not None: + graph_pred_linear.eval() + y_true = [] + y_scores = [] + y_scores_raw = [] + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for batch in L: + batch = batch.to(device) + # sub_batch_1, sub_batch_2 = Data().to(device), Data().to(device) + # sub_batch_1.seq, sub_batch_1.side_chain_angle_encoding, sub_batch_1.backbone_angle_encoding, sub_batch_1.coords_ca, sub_batch_1.coords_n, sub_batch_1.coords_c, sub_batch_1.x, sub_batch_1.pos, sub_batch_1.num_nodes, sub_batch_1.edge_index, sub_batch_1.node_s, sub_batch_1.node_v, sub_batch_1.edge_s, sub_batch_1.edge_v = batch.seq_1, batch.side_chain_angle_encoding_1, batch.backbone_angle_encoding_1, batch.coords_ca_1, batch.coords_n_1, batch.coords_c_1, batch.x_1, batch.pos_1, batch.num_nodes_1, batch.edge_index_1, batch.node_s_1, batch.node_v_1, batch.edge_s_1, batch.edge_v_1 + # sub_batch_2.seq, sub_batch_2.side_chain_angle_encoding, sub_batch_2.backbone_angle_encoding, sub_batch_2.coords_ca, sub_batch_2.coords_n, sub_batch_2.coords_c, sub_batch_2.x, sub_batch_2.pos, sub_batch_2.num_nodes, sub_batch_2.edge_index, sub_batch_2.node_s, sub_batch_2.node_v, sub_batch_2.edge_s, sub_batch_2.edge_v = batch.seq_2, batch.side_chain_angle_encoding_2, batch.backbone_angle_encoding_2, batch.coords_ca_2, batch.coords_n_2, batch.coords_c_2, batch.x_2, batch.pos_2, batch.num_nodes_2, 
batch.edge_index_2, batch.node_s_2, batch.node_v_2, batch.edge_s_2, batch.edge_v_2 + + if args.model_3d == "ProNet": + if args.mask: + # random mask node aatype + mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False)) + batch.x[:, 0][mask_indice] = 25 + if args.noise: + # add gaussian noise to atom coords + gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3) + batch.coords_ca += gaussian_noise + if args.ProNet_level != 'aminoacid': + batch.coords_n += gaussian_noise + batch.coords_c += gaussian_noise + if args.deform: + # Anisotropic scale + deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1) + batch.coords_ca *= deform + if args.ProNet_level != 'aminoacid': + batch.coords_n *= deform + batch.coords_c *= deform + + if args.model_3d == "GVP": + molecule_3D_repr_1 = model(batch.node_s_1, batch.node_v_1, batch.edge_s_1, batch.edge_v_1, batch.edge_index_1, batch.batch_protein_1, get_repr=True) + molecule_3D_repr_2 = model(batch.node_s_2, batch.node_v_2, batch.edge_s_2, batch.edge_v_2, batch.edge_index_2, batch.batch_protein_2, get_repr=True) + molecule_3D_repr = torch.cat((molecule_3D_repr_1, molecule_3D_repr_2), dim=1) + elif args.model_3d in ["GearNet", "GearNet_IEConv"]: + molecule_3D_repr = model(batch, batch.node_feature.float())["graph_feature"] + elif args.model_3d == "ProNet": + molecule_3D_repr_1 = model(batch.seq_1, batch.coords_n_1, batch.coords_ca_1, batch.coords_c_1, batch.side_chain_angle_encoding_1, batch.backbone_angle_encoding_1, batch.batch_protein_1, get_repr=True) + molecule_3D_repr_2 = model(batch.seq_2, batch.coords_n_2, batch.coords_ca_2, batch.coords_c_2, batch.side_chain_angle_encoding_2, batch.backbone_angle_encoding_2, batch.batch_protein_2, get_repr=True) + molecule_3D_repr = torch.cat((molecule_3D_repr_1, molecule_3D_repr_2), dim=1) + elif args.model_3d == "CDConv": + molecule_3D_repr_1 = model(batch.seq_1, batch.coords_ca_1, batch.batch_protein_1, get_repr=True) + molecule_3D_repr_2 = model(batch.seq_2, batch.coords_ca_2, batch.batch_protein_2, get_repr=True) + molecule_3D_repr = torch.cat((molecule_3D_repr_1, molecule_3D_repr_2), dim=1) + + if graph_pred_linear is not None: + pred = graph_pred_linear(molecule_3D_repr).squeeze().sigmoid() + else: + pred = molecule_3D_repr.squeeze().sigmoid() + + y = batch.y + + y_scores_raw.append(pred) + pred = (pred >= 0.5).long() + y_true.append(y) + y_scores.append(pred) + + for i in range(len(y_scores)): + if y_scores[i].dim() == 0: + y_scores[i] = y_scores[i].unsqueeze(0) + y_scores_raw[i] = y_scores_raw[i].unsqueeze(0) + + y_true = torch.cat(y_true, dim=0).cpu().numpy() + y_scores = torch.cat(y_scores, dim=0).cpu().numpy() + y_scores_raw = torch.cat(y_scores_raw, dim=0).cpu().numpy() + + L = len(y_true) + acc = sum(y_true == y_scores) * 1. 
/ L
+
+    auroc = roc_auc_score(y_true, y_scores_raw)
+
+    return acc, auroc
+
+if __name__ == "__main__":
+    torch.manual_seed(42)
+    np.random.seed(42)
+    device = (
+        torch.device("cuda:" + str(args.device))
+        if torch.cuda.is_available()
+        else torch.device("cpu")
+    )
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(42)
+
+    data_root = args.data_root
+
+    dataset_class = DatasetMSP
+
+    train_dataset = dataset_class(root=data_root, split='train')
+    valid_dataset = dataset_class(root=data_root, split='val')
+    test_dataset = dataset_class(root=data_root, split='test')
+
+    if args.model_3d == "GVP":
+        data_root = "../data/FOLD_GVP"
+        train_dataset = DatasetGVP(
+            root=data_root, dataset=train_dataset, split='train', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf, multi_protein=True)
+        valid_dataset = DatasetGVP(
+            root=data_root, dataset=valid_dataset, split='val', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf, multi_protein=True)
+        test_dataset = DatasetGVP(
+            root=data_root, dataset=test_dataset, split='test', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf, multi_protein=True)
+
+    criterion = nn.BCELoss()
+
+    DataLoaderClass = DataLoaderMultiPro
+    dataloader_kwargs = {}
+    if args.model_3d in ["GearNet", "GearNet_IEConv"]:
+        dataloader_kwargs["collate_fn"] = DatasetFOLD.collate_fn
+        DataLoaderClass = TorchDataLoader
+
+    train_loader = DataLoaderClass(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.num_workers,
+        **dataloader_kwargs
+    )
+    val_loader = DataLoaderClass(
+        valid_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        **dataloader_kwargs
+    )
+    test_loader = DataLoaderClass(
+        test_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        **dataloader_kwargs
+    )
+
+    model, graph_pred_linear = model_setup()
+
+    if args.input_model_file != "":
+        load_model(model, graph_pred_linear, args.input_model_file)
+    model.to(device)
+    print(model)
+    if graph_pred_linear is not None:
+        graph_pred_linear.to(device)
+        print(graph_pred_linear)
+
+    # set up optimizer
+    # different learning rate for different part of GNN
+    model_param_group = [{"params": model.parameters(), "lr": args.lr}]
+    if graph_pred_linear is not None:
+        model_param_group.append(
+            {"params": graph_pred_linear.parameters(), "lr": args.lr}
+        )
+
+    if args.optimizer == "Adam":
+        optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
+    elif args.optimizer == "SGD":
+        optimizer = optim.SGD(model_param_group, lr=args.lr, weight_decay=5e-4, momentum=0.9)
+
+    lr_scheduler = None
+    if args.lr_scheduler == "CosineAnnealingLR":
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, args.epochs
+        )
+        print("Apply lr scheduler CosineAnnealingLR")
+    elif args.lr_scheduler == "CosineAnnealingWarmRestarts":
+        lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
+            optimizer, args.epochs, eta_min=1e-4
+        )
+        print("Apply lr scheduler CosineAnnealingWarmRestarts")
+    elif args.lr_scheduler == "StepLR":
+        lr_scheduler = optim.lr_scheduler.StepLR(
+            optimizer, step_size=args.lr_decay_step_size, gamma=args.lr_decay_factor
+        )
+        print("Apply lr scheduler StepLR")
+    elif args.lr_scheduler == "ReduceLROnPlateau":
+        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer, factor=args.lr_decay_factor, patience=args.lr_decay_patience, min_lr=args.min_lr
+        )
+        print("Apply lr scheduler ReduceLROnPlateau")
+    elif args.lr_scheduler == "StepLRCustomized":
+        print("Will decay with {}, at epochs {}".format(args.lr_decay_factor, args.StepLRCustomized_scheduler))
+        print("Apply lr scheduler StepLR (customized)")
+    else:
+        print("lr scheduler {} is not included.".format(args.lr_scheduler))
+    global_learning_rate = args.lr
+
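One detail worth noting before the training loop: when CosineAnnealingWarmRestarts is selected, train() above steps the scheduler once per iteration with a fractional epoch, `lr_scheduler.step(epoch - 1 + step / num_iters)`, rather than once per epoch. A self-contained sketch of that pattern (the linear model, T_0=10, and num_iters=100 are illustrative assumptions):

import torch
import torch.optim as optim

model = torch.nn.Linear(4, 1)  # stand-in for the 3D encoder
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, eta_min=1e-4)

num_iters = 100  # batches per epoch
for epoch in range(1, 11):
    for step in range(num_iters):
        # forward / loss.backward() / optimizer.step() elided
        scheduler.step(epoch - 1 + step / num_iters)  # fractional epoch anneals lr within the epoch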
+    train_acc_list, val_acc_list, test_acc_list = [], [], []
+    train_auroc_list, val_auroc_list, test_auroc_list = [], [], []
+    best_val_auroc, best_val_idx = -1e10, 0
+    for epoch in range(1, args.epochs + 1):
+        start_time = time.time()
+        loss_acc = train(epoch, device, train_loader, optimizer)
+        print("Epoch: {}\nLoss: {}".format(epoch, loss_acc))
+
+        if epoch % args.print_every_epoch == 0:
+            if args.eval_train:
+                train_acc, train_auroc = eval(device, train_loader)
+            else:
+                train_acc, train_auroc = 0, 0
+            val_acc, val_auroc = eval(device, val_loader)
+            test_acc, test_auroc = eval(device, test_loader)
+
+            train_acc_list.append(train_acc)
+            val_acc_list.append(val_acc)
+            test_acc_list.append(test_acc)
+            train_auroc_list.append(train_auroc)
+            val_auroc_list.append(val_auroc)
+            test_auroc_list.append(test_auroc)
+            print(
+                "train_acc: {:.6f}\ttrain_auroc: {:.6f}\tval_acc: {:.6f}\tval_auroc: {:.6f}\ttest_acc: {:.6f}\ttest_auroc: {:.6f}".format(
+                    train_acc, train_auroc, val_acc, val_auroc, test_acc, test_auroc
+                )
+            )
+
+            if val_auroc > best_val_auroc:
+                best_val_auroc = val_auroc
+                best_val_idx = len(train_auroc_list) - 1
+                if not args.output_model_dir == "":
+                    save_model(save_best=True)
+
+        if args.lr_scheduler == "StepLRCustomized" and epoch in args.StepLRCustomized_scheduler:
+            print('Changing learning rate, from {} to {}'.format(global_learning_rate, global_learning_rate * args.lr_decay_factor))
+            global_learning_rate *= args.lr_decay_factor
+            for param_group in optimizer.param_groups:
+                param_group['lr'] = global_learning_rate
+        print("Took\t{}\n".format(time.time() - start_time))
+
+    print(
+        "best train_acc: {:.6f}\ttrain_auroc: {:.6f}\tval_acc: {:.6f}\tval_auroc: {:.6f}\ttest_acc: {:.6f}\ttest_auroc: {:.6f}".format(
+            train_acc_list[best_val_idx],
+            train_auroc_list[best_val_idx],
+            val_acc_list[best_val_idx],
+            val_auroc_list[best_val_idx],
+            test_acc_list[best_val_idx],
+            test_auroc_list[best_val_idx],
+        )
+    )
+
+    save_model(save_best=False)
\ No newline at end of file
diff --git a/examples_3D/finetune_PSR.py b/examples_3D/finetune_PSR.py
new file mode 100644
index 0000000..0489671
--- /dev/null
+++ b/examples_3D/finetune_PSR.py
@@ -0,0 +1,442 @@
+import os
+import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader as TorchDataLoader
+from torch_geometric.loader import DataLoader as PyGDataLoader
+from torch_geometric.nn import global_max_pool, global_mean_pool
+from tqdm import tqdm
+
+from config import args
+from Geom3D.datasets import DatasetPSR, DatasetGVP
+from Geom3D.models import ProNet, GearNetIEConv, MQAModel, CD_Convolution
+import Geom3D.models.GearNet_layer as GearNet_layer
+
+
+def model_setup():
+    num_class = 1
+    graph_pred_linear = None
+
+    if args.model_3d == "GVP":
+        node_in_dim = (6, 3)
+        node_h_dim = (100, 16)
+        edge_in_dim = (32, 1)
+        edge_h_dim = (32, 1)
+        model = MQAModel(node_in_dim, node_h_dim, edge_in_dim, edge_h_dim, out_channels=num_class)
+
+    elif args.model_3d == "GearNet":
+        input_dim = 21
+        model = GearNetIEConv(
+            input_dim=input_dim, embedding_dim=512,
hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=7, + batch_norm=True, concat_hidden=True, short_cut=True, readout="sum", layer_norm=True, dropout=0.2) + + num_mlp_layer = 3 + hidden_dims = [model.output_dim] * (num_mlp_layer - 1) + graph_pred_linear = GearNet_layer.MultiLayerPerceptron( + model.output_dim, hidden_dims + [num_class], batch_norm=True, dropout=0.5) + + elif args.model_3d == "GearNet_IEConv": + input_dim = 21 + model = GearNetIEConv( + input_dim=input_dim, embedding_dim=512, hidden_dims=[512, 512, 512, 512, 512, 512], num_relation=7, + batch_norm=True, concat_hidden=True, short_cut=True, readout="sum", layer_norm=True, dropout=0.2, use_ieconv=True) + + num_mlp_layer = 3 + hidden_dims = [model.output_dim] * (num_mlp_layer - 1) + graph_pred_linear = GearNet_layer.MultiLayerPerceptron( + model.output_dim, hidden_dims + [num_class], batch_norm=True, dropout=0.5) + + elif args.model_3d == "ProNet": + model = ProNet( + level=args.ProNet_level, + dropout=args.ProNet_dropout, + out_channels=num_class, + euler_noise=args.euler_noise, + ) + + elif args.model_3d == "CDConv": + geometric_radii = [x * args.CDConv_radius for x in args.CDConv_geometric_raddi_coeff] + model = CD_Convolution( + geometric_radii=geometric_radii, + sequential_kernel_size=args.CDConv_kernel_size, + kernel_channels=args.CDConv_kernel_channels, channels=args.CDConv_channels, base_width=args.CDConv_base_width, + num_classes=num_class) + + elif args.model_3d == "FrameNetProtein": + if args.FrameNetProtein_type == "FrameNetProtein01": + model = FrameNetProtein01( + num_residue_acid=26, + latent_dim=args.emb_dim, + num_class=num_class, + num_radial=args.FrameNetProtein_num_radial, + backbone_cutoff=args.FrameNetProtein_backbone_cutoff, + cutoff=args.FrameNetProtein_cutoff, + rbf_type=args.FrameNetProtein_rbf_type, + rbf_gamma=args.FrameNetProtein_gamma, + num_layer=args.FrameNetProtein_num_layer, + readout=args.FrameNetProtein_readout, + ) + + else: + raise Exception("3D model {} not included.".format(args.model_3d)) + return model, graph_pred_linear + + +def load_model(model, graph_pred_linear, model_weight_file): + print("Loading from {}".format(model_weight_file)) + if "MoleculeSDE" in model_weight_file: + model_weight = torch.load(model_weight_file) + model.load_state_dict(model_weight["model_3D"]) + if (graph_pred_linear is not None) and ("graph_pred_linear" in model_weight): + graph_pred_linear.load_state_dict(model_weight["graph_pred_linear"]) + + else: + model_weight = torch.load(model_weight_file) + model.load_state_dict(model_weight["model"]) + if (graph_pred_linear is not None) and ("graph_pred_linear" in model_weight): + graph_pred_linear.load_state_dict(model_weight["graph_pred_linear"]) + return + + +def save_model(save_best): + if not args.output_model_dir == "": + if save_best: + print("save model with optimal loss") + output_model_path = os.path.join(args.output_model_dir, "model.pth") + saved_model_dict = {} + saved_model_dict["model"] = model.state_dict() + if graph_pred_linear is not None: + saved_model_dict["graph_pred_linear"] = graph_pred_linear.state_dict() + torch.save(saved_model_dict, output_model_path) + + else: + print("save model in the last epoch") + output_model_path = os.path.join(args.output_model_dir, "model_final.pth") + saved_model_dict = {} + saved_model_dict["model"] = model.state_dict() + if graph_pred_linear is not None: + saved_model_dict["graph_pred_linear"] = graph_pred_linear.state_dict() + torch.save(saved_model_dict, output_model_path) + return + + +def 
train(epoch, device, loader, optimizer): + model.train() + if graph_pred_linear is not None: + graph_pred_linear.train() + + loss_acc = 0 + num_iters = len(loader) + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for step, batch in enumerate(L): + if args.model_3d == "ProNet": + if args.mask: + # random mask node aatype + mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False)) + batch.x[:, 0][mask_indice] = 25 + if args.noise: + # add gaussian noise to atom coords + gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3) + batch.coords_ca += gaussian_noise + if args.ProNet_level != 'aminoacid': + batch.coords_n += gaussian_noise + batch.coords_c += gaussian_noise + if args.deform: + # Anisotropic scale + deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1) + batch.coords_ca *= deform + if args.ProNet_level != 'aminoacid': + batch.coords_n *= deform + batch.coords_c *= deform + + batch = batch.to(device) + + if args.model_3d == "GVP": + molecule_3D_repr = model(batch.node_s, batch.node_v, batch.edge_s, batch.edge_v, batch.edge_index, batch.batch) + elif args.model_3d in ["GearNet", "GearNet_IEConv"]: + molecule_3D_repr = model(batch, batch.node_feature.float())["graph_feature"] + elif args.model_3d == "ProNet": + molecule_3D_repr = model(batch.seq, batch.coords_n, batch.coords_ca, batch.coords_c, batch.side_chain_angle_encoding, batch.backbone_angle_encoding, batch.batch) + elif args.model_3d == "CDConv": + molecule_3D_repr = model(batch.seq, batch.coords_ca, batch.batch, split="training") + elif args.model_3d == "FrameNetProtein": + molecule_3D_repr = model(batch.coords_n, batch.coords_ca, batch.coords_c, batch.seq, batch.batch) + + if graph_pred_linear is not None: + pred = graph_pred_linear(molecule_3D_repr).squeeze(1) + else: + pred = molecule_3D_repr.squeeze(1) + + y = batch.y + + loss = criterion(pred, y) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss_acc += loss.cpu().detach().item() + + if args.lr_scheduler in ["CosineAnnealingWarmRestarts"]: + lr_scheduler.step(epoch - 1 + step / num_iters) + + loss_acc /= len(loader) + if args.lr_scheduler in ["StepLR", "CosineAnnealingLR"]: + lr_scheduler.step() + elif args.lr_scheduler in [ "ReduceLROnPlateau"]: + lr_scheduler.step(loss_acc) + + return loss_acc + + +@torch.no_grad() +def eval(device, loader): + model.eval() + if graph_pred_linear is not None: + graph_pred_linear.eval() + y_true = [] + y_scores = [] + + if args.verbose: + L = tqdm(loader) + else: + L = loader + for batch in L: + if args.model_3d == "ProNet": + if args.mask: + # random mask node aatype + mask_indice = torch.tensor(np.random.choice(batch.num_nodes, int(batch.num_nodes * args.mask_aatype), replace=False)) + batch.x[:, 0][mask_indice] = 25 + if args.noise: + # add gaussian noise to atom coords + gaussian_noise = torch.clip(torch.normal(mean=0.0, std=0.1, size=batch.coords_ca.shape), min=-0.3, max=0.3) + batch.coords_ca += gaussian_noise + if args.ProNet_level != 'aminoacid': + batch.coords_n += gaussian_noise + batch.coords_c += gaussian_noise + if args.deform: + # Anisotropic scale + deform = torch.clip(torch.normal(mean=1.0, std=0.1, size=(1, 3)), min=0.9, max=1.1) + batch.coords_ca *= deform + if args.ProNet_level != 'aminoacid': + batch.coords_n *= deform + batch.coords_c *= deform + + batch = batch.to(device) + + if args.model_3d == "GVP": + molecule_3D_repr = model(batch.node_s, 
batch.node_v, batch.edge_s, batch.edge_v, batch.edge_index, batch.batch) + elif args.model_3d in ["GearNet", "GearNet_IEConv"]: + molecule_3D_repr = model(batch, batch.node_feature.float())["graph_feature"] + elif args.model_3d == "ProNet": + molecule_3D_repr = model(batch.seq, batch.coords_n, batch.coords_ca, batch.coords_c, batch.side_chain_angle_encoding, batch.backbone_angle_encoding, batch.batch) + elif args.model_3d == "CDConv": + molecule_3D_repr = model(batch.seq, batch.coords_ca, batch.batch, split="training") + elif args.model_3d == "FrameNetProtein": + molecule_3D_repr = model(batch.coords_n, batch.coords_ca, batch.coords_c, batch.seq, batch.batch) + + if graph_pred_linear is not None: + pred = graph_pred_linear(molecule_3D_repr).squeeze() + else: + pred = molecule_3D_repr.squeeze() + pred = pred.argmax(dim=-1) + + y = batch.y + + y_true.append(y) + y_scores.append(pred) + + y_true = torch.cat(y_true, dim=0).cpu().numpy() + y_scores = torch.cat(y_scores, dim=0).cpu().numpy() + + L = len(y_true) + acc = sum(y_true == y_scores) * 1. / L + return acc + +if __name__ == "__main__": + torch.manual_seed(args.seed) + np.random.seed(args.seed) + device = ( + torch.device("cuda:" + str(args.device)) + if torch.cuda.is_available() + else torch.device("cpu") + ) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + + data_root = args.data_root + + dataset_class = DatasetFOLD + # if args.model_3d == "GearNet": + # dataset_class = DatasetFOLDGearNet + + train_dataset = dataset_class(root=data_root, split='training') + valid_dataset = dataset_class(root=data_root, split='validation') + test_fold_dataset = dataset_class(root=data_root, split='test_fold') + test_superfamily_dataset = dataset_class(root=data_root, split='test_superfamily') + test_family_dataset = dataset_class(root=data_root, split='test_family') + + if args.model_3d == "GVP": + data_root = "../data/FOLD_GVP" + train_dataset = DatasetGVP( + root=data_root, dataset=train_dataset, split='training', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf) + valid_dataset = DatasetGVP( + root=data_root, dataset=valid_dataset, split='validation', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf) + test_fold_dataset = DatasetGVP( + root=data_root, dataset=test_fold_dataset, split='test_fold', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf) + test_superfamily_dataset = DatasetGVP( + root=data_root, dataset=test_superfamily_dataset, split='test_superfamily', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf) + test_family_dataset = DatasetGVP( + root=data_root, dataset=test_family_dataset, split='test_family', num_positional_embeddings=args.num_positional_embeddings, top_k=args.top_k, num_rbf=args.num_rbf) + + criterion = nn.CrossEntropyLoss() + + DataLoaderClass = PyGDataLoader + dataloader_kwargs = {} + if args.model_3d in ["GearNet", "GearNet_IEConv"]: + dataloader_kwargs["collate_fn"] = DatasetFOLD.collate_fn + DataLoaderClass = TorchDataLoader + + train_loader = DataLoaderClass( + train_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=args.num_workers, + **dataloader_kwargs + ) + val_loader = DataLoaderClass( + valid_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.num_workers, + **dataloader_kwargs + ) + test_fold_loader = DataLoaderClass( + test_fold_dataset, + 
batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        **dataloader_kwargs
+    )
+    test_superfamily_loader = DataLoaderClass(
+        test_superfamily_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        **dataloader_kwargs
+    )
+    test_family_loader = DataLoaderClass(
+        test_family_dataset,
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.num_workers,
+        **dataloader_kwargs
+    )
+
+    model, graph_pred_linear = model_setup()
+
+    if args.input_model_file != "":
+        load_model(model, graph_pred_linear, args.input_model_file)
+    model.to(device)
+    print(model)
+    if graph_pred_linear is not None:
+        graph_pred_linear.to(device)
+        print(graph_pred_linear)
+
+    # set up optimizer
+    # different learning rate for different part of GNN
+    model_param_group = [{"params": model.parameters(), "lr": args.lr}]
+    if graph_pred_linear is not None:
+        model_param_group.append(
+            {"params": graph_pred_linear.parameters(), "lr": args.lr}
+        )
+    optimizer = optim.Adam(model_param_group, lr=args.lr, weight_decay=args.decay)
+
+    lr_scheduler = None
+    if args.lr_scheduler == "CosineAnnealingLR":
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+            optimizer, args.epochs
+        )
+        print("Apply lr scheduler CosineAnnealingLR")
+    elif args.lr_scheduler == "CosineAnnealingWarmRestarts":
+        lr_scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
+            optimizer, args.epochs, eta_min=1e-4
+        )
+        print("Apply lr scheduler CosineAnnealingWarmRestarts")
+    elif args.lr_scheduler == "StepLR":
+        lr_scheduler = optim.lr_scheduler.StepLR(
+            optimizer, step_size=args.lr_decay_step_size, gamma=args.lr_decay_factor
+        )
+        print("Apply lr scheduler StepLR")
+    elif args.lr_scheduler == "ReduceLROnPlateau":
+        lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+            optimizer, factor=args.lr_decay_factor, patience=args.lr_decay_patience, min_lr=args.min_lr
+        )
+        print("Apply lr scheduler ReduceLROnPlateau")
+    elif args.lr_scheduler == "StepLRCustomized":
+        print("Will decay with {}, at epochs {}".format(args.lr_decay_factor, args.StepLRCustomized_scheduler))
+        print("Apply lr scheduler StepLR (customized)")
+    else:
+        print("lr scheduler {} is not included.".format(args.lr_scheduler))
+    global_learning_rate = args.lr
+
+    train_acc_list, val_acc_list = [], []
+    test_acc_fold_list, test_acc_superfamily_list, test_acc_family_list = [], [], []
+    best_val_acc, best_val_idx = -1e10, 0
+    for epoch in range(1, args.epochs + 1):
+        start_time = time.time()
+        loss_acc = train(epoch, device, train_loader, optimizer)
+        print("Epoch: {}\nLoss: {}".format(epoch, loss_acc))
+
+        if epoch % args.print_every_epoch == 0:
+            if args.eval_train:
+                train_acc = eval(device, train_loader)
+            else:
+                train_acc = 0
+            val_acc = eval(device, val_loader)
+            test_fold_acc = eval(device, test_fold_loader)
+            test_superfamily_acc = eval(device, test_superfamily_loader)
+            test_family_acc = eval(device, test_family_loader)
+
+            train_acc_list.append(train_acc)
+            val_acc_list.append(val_acc)
+            test_acc_fold_list.append(test_fold_acc)
+            test_acc_superfamily_list.append(test_superfamily_acc)
+            test_acc_family_list.append(test_family_acc)
+            print(
+                "train: {:.6f}\tval: {:.6f}\ttest-fold: {:.6f}\ttest-superfamily: {:.6f}\ttest-family: {:.6f}".format(
+                    train_acc, val_acc, test_fold_acc, test_superfamily_acc, test_family_acc
+                )
+            )
+
+            if val_acc > best_val_acc:
+                best_val_acc = val_acc
+                best_val_idx = len(train_acc_list) - 1
+                if not args.output_model_dir == "":
+                    save_model(save_best=True)
+
+        if args.lr_scheduler == "StepLRCustomized" and epoch in args.StepLRCustomized_scheduler:
+            print('Changing learning rate, from {} to {}'.format(global_learning_rate, global_learning_rate * args.lr_decay_factor))
+            global_learning_rate *= args.lr_decay_factor
+            for param_group in optimizer.param_groups:
+                param_group['lr'] = global_learning_rate
+        print("Took\t{}\n".format(time.time() - start_time))
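The bookkeeping above is the usual best-validation model selection: record each epoch's metrics, checkpoint whenever the validation score improves, and finally report every metric from that epoch. The selection step, restated as a compact hypothetical helper (not in the patch):

def select_by_validation(val_scores, test_scores):
    # Index of the epoch with the best validation score, and the
    # test score that was recorded at that same epoch.
    best_idx = max(range(len(val_scores)), key=val_scores.__getitem__)
    return best_idx, test_scores[best_idx]

# e.g. select_by_validation(val_acc_list, test_acc_fold_list)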
+
+    print(
+        "best train: {:.6f}\tval: {:.6f}\ttest-fold: {:.6f}\ttest-superfamily: {:.6f}\ttest-family: {:.6f}".format(
+            train_acc_list[best_val_idx],
+            val_acc_list[best_val_idx],
+            test_acc_fold_list[best_val_idx],
+            test_acc_superfamily_list[best_val_idx],
+            test_acc_family_list[best_val_idx],
+        )
+    )
+
+    save_model(save_best=False)
\ No newline at end of file
diff --git a/scripts/ECMultiple/submit_CDConv.sh b/scripts/ECMultiple/submit_CDConv.sh
new file mode 100644
index 0000000..a7129c9
--- /dev/null
+++ b/scripts/ECMultiple/submit_CDConv.sh
@@ -0,0 +1,35 @@
+cd ../../examples_3D
+
+export model_3d=CDConv
+
+export lr_list=(5e-4 1e-4 1e-3)
+export seed=42
+export batch_size_list=(8 16 64)
+
+
+export epochs=500
+export dataset=ECMultiple
+export StepLRCustomized_scheduler="60"
+
+
+export lr_scheduler=CosineAnnealingLR
+
+for lr in "${lr_list[@]}"; do
+
+for batch_size in "${batch_size_list[@]}"; do
+
+    export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs"
+    export output_file="$output_model_dir"/result.out
+    echo "$output_model_dir"
+    mkdir -p "$output_model_dir"
+
+    python finetune_ECMultiple.py \
+    --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \
+    --seed="$seed" \
+    --batch_size="$batch_size" --optimizer SGD --CDConv_kernel_channels 32 --CDConv_base_width 16 \
+    --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \
+    --output_model_dir="$output_model_dir" \
+    > "$output_file"
+
+done
+done
diff --git a/scripts/ECMultiple/submit_GVP.sh b/scripts/ECMultiple/submit_GVP.sh
new file mode 100644
index 0000000..33ea6c6
--- /dev/null
+++ b/scripts/ECMultiple/submit_GVP.sh
@@ -0,0 +1,35 @@
+cd ../../examples_3D
+
+export model_3d=GVP
+
+export lr_list=(1e-4 1e-3)
+export seed=42
+export batch_size_list=(8)
+
+
+export epochs=200
+export dataset=ECMultiple
+export StepLRCustomized_scheduler="60"
+
+
+export lr_scheduler=CosineAnnealingLR
+
+for lr in "${lr_list[@]}"; do
+
+for batch_size in "${batch_size_list[@]}"; do
+
+    export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs"
+    export output_file="$output_model_dir"/result.out
+    echo "$output_model_dir"
+    mkdir -p "$output_model_dir"
+
+    python finetune_ECMultiple.py \
+    --model_3d="$model_3d" --epochs="$epochs" \
+    --seed="$seed" \
+    --batch_size="$batch_size" \
+    --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \
+    --output_model_dir="$output_model_dir" \
+    > "$output_file"
+
+done
+done
diff --git a/scripts/ECMultiple/submit_GearNet.sh b/scripts/ECMultiple/submit_GearNet.sh
new file mode 100644
index 0000000..26ac1c5
--- /dev/null
+++ b/scripts/ECMultiple/submit_GearNet.sh
@@ -0,0 +1,36 @@
+cd ../../examples_3D
+
+export model_3d=GearNet
+
+export lr_list=(1e-4 1e-3)
+export seed=42
+export batch_size_list=(8)
+
+
+export epochs=200
+export dataset=ECMultiple
+export StepLRCustomized_scheduler="60"
+
+
+export lr_scheduler=ReduceLROnPlateau
+
+for lr in "${lr_list[@]}"; do
+
+for batch_size in "${batch_size_list[@]}"; do
+
+    export
output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_ECMultiple.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --seed="$seed" \ + --batch_size="$batch_size" --optimizer AdamW \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --lr_decay_factor 0.6 --lr_decay_patience 5 \ + --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/ECMultiple/submit_ProNet.sh b/scripts/ECMultiple/submit_ProNet.sh new file mode 100644 index 0000000..98e65d0 --- /dev/null +++ b/scripts/ECMultiple/submit_ProNet.sh @@ -0,0 +1,38 @@ +cd ../../examples_3D + +export model_3d=ProNet + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 64) +export level_list=(allatom aminoacid backbone) + + +export epochs=300 +export dataset=ECMultiple +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do +for batch_size in "${batch_size_list[@]}"; do +for level in "${level_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs"_"$level" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_ECMultiple.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --ProNet_level="$level" \ + --seed="$seed" \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done +done diff --git a/scripts/ECSingle/submit_CDConv.sh b/scripts/ECSingle/submit_CDConv.sh new file mode 100644 index 0000000..50d0561 --- /dev/null +++ b/scripts/ECSingle/submit_CDConv.sh @@ -0,0 +1,35 @@ +cd ../../examples_3D + +export model_3d=CDConv + +export lr_list=(5e-4 1e-4 1e-3) +export seed=42 +export batch_size_list=(8 16 64) + + +export epochs=400 +export dataset=fold +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_ECSingle.py \ + --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ + --seed="$seed" \ + --batch_size="$batch_size" --optimizer SGD --CDConv_base_width 32 \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/ECSingle/submit_GVP.sh b/scripts/ECSingle/submit_GVP.sh new file mode 100644 index 0000000..8d57d67 --- /dev/null +++ b/scripts/ECSingle/submit_GVP.sh @@ -0,0 +1,35 @@ +cd ../../examples_3D + +export model_3d=GVP + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8) + + +export epochs=300 +export dataset=ECSingle +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p 
"$output_model_dir" + + python finetune_ECSingle.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --seed="$seed" \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/ECSingle/submit_GearNet.sh b/scripts/ECSingle/submit_GearNet.sh new file mode 100644 index 0000000..8f64e42 --- /dev/null +++ b/scripts/ECSingle/submit_GearNet.sh @@ -0,0 +1,35 @@ +cd ../../examples_3D + +export model_3d=GearNet + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8) + + +export epochs=300 +export dataset=ECSingle +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=StepLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_ECSingle.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --seed="$seed" \ + --batch_size="$batch_size" --lr_decay_step_size 50 --optimizer SGD \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/FOLD/submit_CDConv.sh b/scripts/FOLD/submit_CDConv.sh new file mode 100644 index 0000000..7afc448 --- /dev/null +++ b/scripts/FOLD/submit_CDConv.sh @@ -0,0 +1,36 @@ +cd ../../examples_3D + +export model_3d=CDConv + +export lr_list=(5e-4 1e-4 1e-3) +export seed=42 +export batch_size_list=(8 16 64) + + +export epochs=400 +export dataset=fold +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_FOLD.py \ + --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ + --seed="$seed" \ + --batch_size="$batch_size" --optimizer SGD \ + --CDConv_kernel_size 5 \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/FOLD/submit_GVP.sh b/scripts/FOLD/submit_GVP.sh new file mode 100644 index 0000000..4319f43 --- /dev/null +++ b/scripts/FOLD/submit_GVP.sh @@ -0,0 +1,35 @@ +cd ../../examples_3D + +export model_3d=GVP + +export lr_list=(5e-4 1e-4 1e-3) +export seed=42 +export batch_size_list=(64) + + +export epochs=400 +export dataset=fold +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_FOLD.py \ + --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ + --seed="$seed" \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/FOLD/submit_GearNet.sh b/scripts/FOLD/submit_GearNet.sh new 
file mode 100644 index 0000000..869f213 --- /dev/null +++ b/scripts/FOLD/submit_GearNet.sh @@ -0,0 +1,35 @@ +cd ../../examples_3D + +export model_3d=GearNet + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8) + + +export epochs=400 +export dataset=fold +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=StepLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_FOLD.py \ + --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ + --seed="$seed" \ + --batch_size="$batch_size" --lr_decay_step_size 50 --optimizer SGD \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/FOLD/submit_GearNet_IEConv.sh b/scripts/FOLD/submit_GearNet_IEConv.sh new file mode 100644 index 0000000..c5d1ee7 --- /dev/null +++ b/scripts/FOLD/submit_GearNet_IEConv.sh @@ -0,0 +1,35 @@ +cd ../../examples_3D + +export model_3d=GearNet_IEConv + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8) + + +export epochs=200 +export dataset=fold +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=StepLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_FOLD.py \ + --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ + --seed="$seed" \ + --batch_size="$batch_size" --lr_decay_step_size 50 --optimizer SGD \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/GO/submit_CDConv_bp.sh b/scripts/GO/submit_CDConv_bp.sh new file mode 100644 index 0000000..660ffef --- /dev/null +++ b/scripts/GO/submit_CDConv_bp.sh @@ -0,0 +1,36 @@ +cd ../../examples_3D + +export model_3d=CDConv + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 16 64) + + +export epochs=200 +export dataset=GO_mf +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_GO.py \ + --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ + --seed="$seed" \ + --GO_level mf \ + --batch_size="$batch_size" --optimizer SGD --CDConv_base_width 32 \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/GO/submit_CDConv_cc.sh b/scripts/GO/submit_CDConv_cc.sh new file mode 100644 index 0000000..b0e02ca --- /dev/null +++ b/scripts/GO/submit_CDConv_cc.sh @@ -0,0 +1,36 @@ +cd ../../examples_3D + +export model_3d=CDConv + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 16 64) + + +export epochs=200 +export dataset=GO_cc +export 
StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_GO.py \ + --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ + --seed="$seed" \ + --GO_level cc \ + --batch_size="$batch_size" --optimizer SGD --CDConv_base_width 32 \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/GO/submit_CDConv_mf.sh b/scripts/GO/submit_CDConv_mf.sh new file mode 100644 index 0000000..951ebdb --- /dev/null +++ b/scripts/GO/submit_CDConv_mf.sh @@ -0,0 +1,36 @@ +cd ../../examples_3D + +export model_3d=CDConv + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 16 64) + + +export epochs=200 +export dataset=GO_bp +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_GO.py \ + --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ + --seed="$seed" \ + --GO_level bp \ + --batch_size="$batch_size" --optimizer SGD --CDConv_base_width 32 \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/GO/submit_GVP_bp.sh b/scripts/GO/submit_GVP_bp.sh new file mode 100644 index 0000000..2b496e6 --- /dev/null +++ b/scripts/GO/submit_GVP_bp.sh @@ -0,0 +1,36 @@ +cd ../../examples_3D + +export model_3d=GVP + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 16) + + +export epochs=200 +export dataset=GO_bp +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_GO.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --seed="$seed" \ + --GO_level bp \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/GO/submit_GVP_cc.sh b/scripts/GO/submit_GVP_cc.sh new file mode 100644 index 0000000..a987dfa --- /dev/null +++ b/scripts/GO/submit_GVP_cc.sh @@ -0,0 +1,36 @@ +cd ../../examples_3D + +export model_3d=GVP + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 16 32) + + +export epochs=200 +export dataset=GO_cc +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo 
"$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_GO.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --seed="$seed" \ + --GO_level cc \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/GO/submit_GVP_mf.sh b/scripts/GO/submit_GVP_mf.sh new file mode 100644 index 0000000..2c3f050 --- /dev/null +++ b/scripts/GO/submit_GVP_mf.sh @@ -0,0 +1,36 @@ +cd ../../examples_3D + +export model_3d=GVP + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8) + + +export epochs=200 +export dataset=GO_mf +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_GO.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --seed="$seed" \ + --GO_level mf \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/GO/submit_GearNet_bp.sh b/scripts/GO/submit_GearNet_bp.sh new file mode 100644 index 0000000..b4acf33 --- /dev/null +++ b/scripts/GO/submit_GearNet_bp.sh @@ -0,0 +1,37 @@ +cd ../../examples_3D + +export model_3d=GearNet + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(2 8 16) + + +export epochs=200 +export dataset=GO_bp +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=ReduceLROnPlateau + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_GO.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --seed="$seed" \ + --GO_level bp \ + --batch_size="$batch_size" --optimizer AdamW \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --lr_decay_factor 0.6 --lr_decay_patience 5 \ + --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/GO/submit_GearNet_cc.sh b/scripts/GO/submit_GearNet_cc.sh new file mode 100644 index 0000000..e13b8e0 --- /dev/null +++ b/scripts/GO/submit_GearNet_cc.sh @@ -0,0 +1,37 @@ +cd ../../examples_3D + +export model_3d=GearNet + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(2 8 16) + + +export epochs=200 +export dataset=GO_cc +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=ReduceLROnPlateau + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_GO.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --seed="$seed" \ + --GO_level cc \ + --batch_size="$batch_size" --optimizer AdamW \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --lr_decay_factor 0.6 --lr_decay_patience 5 \ + --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > 
"$output_file" + +done +done diff --git a/scripts/GO/submit_GearNet_mf.sh b/scripts/GO/submit_GearNet_mf.sh new file mode 100644 index 0000000..8a53cc3 --- /dev/null +++ b/scripts/GO/submit_GearNet_mf.sh @@ -0,0 +1,37 @@ +cd ../../examples_3D + +export model_3d=GearNet + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8) + + +export epochs=200 +export dataset=GO_mf +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=ReduceLROnPlateau + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_GO.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --seed="$seed" \ + --GO_level mf \ + --batch_size="$batch_size" --optimizer AdamW \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --lr_decay_factor 0.6 --lr_decay_patience 5 \ + --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/GO/submit_ProNet_bp.sh b/scripts/GO/submit_ProNet_bp.sh new file mode 100644 index 0000000..cb7d41d --- /dev/null +++ b/scripts/GO/submit_ProNet_bp.sh @@ -0,0 +1,39 @@ +cd ../../examples_3D + +export model_3d=ProNet + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 64) +export level_list=(allatom aminoacid backbone) + + +export epochs=300 +export dataset=GO_bp +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do +for batch_size in "${batch_size_list[@]}"; do +for level in "${level_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs"_"$level" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_GO.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --ProNet_level="$level" \ + --GO_level bp \ + --seed="$seed" \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done +done diff --git a/scripts/GO/submit_ProNet_cc.sh b/scripts/GO/submit_ProNet_cc.sh new file mode 100644 index 0000000..cf6a912 --- /dev/null +++ b/scripts/GO/submit_ProNet_cc.sh @@ -0,0 +1,39 @@ +cd ../../examples_3D + +export model_3d=ProNet + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 64) +export level_list=(allatom aminoacid backbone) + + +export epochs=300 +export dataset=GO_cc +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do +for batch_size in "${batch_size_list[@]}"; do +for level in "${level_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs"_"$level" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_GO.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --ProNet_level="$level" \ + --GO_level cc \ + --seed="$seed" \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done +done diff --git a/scripts/GO/submit_ProNet_mf.sh b/scripts/GO/submit_ProNet_mf.sh new file mode 100644 
index 0000000..9dcaeb8 --- /dev/null +++ b/scripts/GO/submit_ProNet_mf.sh @@ -0,0 +1,39 @@ +cd ../../examples_3D + +export model_3d=ProNet + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 64) +export level_list=(allatom aminoacid backbone) + + +export epochs=300 +export dataset=GO_mf +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do +for batch_size in "${batch_size_list[@]}"; do +for level in "${level_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs"_"$level" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_GO.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --ProNet_level="$level" \ + --GO_level mf \ + --seed="$seed" \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done +done diff --git a/scripts/MSP/submit_CDConv.sh b/scripts/MSP/submit_CDConv.sh new file mode 100644 index 0000000..32aaefd --- /dev/null +++ b/scripts/MSP/submit_CDConv.sh @@ -0,0 +1,35 @@ +cd ../../examples_3D + +export model_3d=CDConv + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 16 64) + + +export epochs=300 +export dataset=MSP +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_MSP.py \ + --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ + --seed="$seed" \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/MSP/submit_GVP.sh b/scripts/MSP/submit_GVP.sh new file mode 100644 index 0000000..83390bd --- /dev/null +++ b/scripts/MSP/submit_GVP.sh @@ -0,0 +1,35 @@ +cd ../../examples_3D + +export model_3d=GVP + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 16 64) + + +export epochs=300 +export dataset=MSP +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_MSP.py \ + --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ + --seed="$seed" \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/MSP/submit_ProNet.sh b/scripts/MSP/submit_ProNet.sh new file mode 100644 index 0000000..55b3196 --- /dev/null +++ b/scripts/MSP/submit_ProNet.sh @@ -0,0 +1,38 @@ +cd ../../examples_3D + +export model_3d=ProNet + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 16 64) +export level_list=(allatom aminoacid backbone) + + +export epochs=300 +export dataset=MSP +export 
StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do +for batch_size in "${batch_size_list[@]}"; do +for level in "${level_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs"_"$level" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_MSP.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --ProNet_level="$level" \ + --seed="$seed" \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done +done diff --git a/scripts/PSR/submit_CDConv.sh b/scripts/PSR/submit_CDConv.sh new file mode 100644 index 0000000..55891d6 --- /dev/null +++ b/scripts/PSR/submit_CDConv.sh @@ -0,0 +1,35 @@ +cd ../../examples_3D + +export model_3d=CDConv + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 16 64) + + +export epochs=300 +export dataset=PSR +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_PSR.py \ + --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ + --seed="$seed" \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/PSR/submit_GVP.sh b/scripts/PSR/submit_GVP.sh new file mode 100644 index 0000000..4437302 --- /dev/null +++ b/scripts/PSR/submit_GVP.sh @@ -0,0 +1,35 @@ +cd ../../examples_3D + +export model_3d=GVP + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8 16 64) + + +export epochs=300 +export dataset=PSR +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do + +for batch_size in "${batch_size_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs" + export output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_PSR.py \ + --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ + --seed="$seed" \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done diff --git a/scripts/PSR/submit_ProNet.sh b/scripts/PSR/submit_ProNet.sh new file mode 100644 index 0000000..634b963 --- /dev/null +++ b/scripts/PSR/submit_ProNet.sh @@ -0,0 +1,38 @@ +cd ../../examples_3D + +export model_3d=ProNet + +export lr_list=(1e-4 1e-3) +export seed=42 +export batch_size_list=(8) +export level_list=(allatom aminoacid backbone) + + +export epochs=300 +export dataset=PSR +export StepLRCustomized_scheduler="60" + + +export lr_scheduler=CosineAnnealingLR + +for lr in "${lr_list[@]}"; do +for batch_size in "${batch_size_list[@]}"; do +for level in "${level_list[@]}"; do + + export output_model_dir=../output/"$model_3d"/"$dataset"/"$seed"/"$lr"_"$lr_scheduler"_"$batch_size"_"$epochs"_"$level" + export 
output_file="$output_model_dir"/result.out + echo "$output_model_dir" + mkdir -p "$output_model_dir" + + python finetune_PSR.py \ + --model_3d="$model_3d" --epochs="$epochs" \ + --ProNet_level="$level" \ + --seed="$seed" \ + --batch_size="$batch_size" \ + --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ + --output_model_dir="$output_model_dir" \ + > "$output_file" + +done +done +done From 7519f24b104cfcd1b8e15f8a2b7caf8b3357cd9d Mon Sep 17 00:00:00 2001 From: YanjingLiLi Date: Wed, 5 Jun 2024 11:14:48 +0800 Subject: [PATCH 4/5] edit --- scripts/ECSingle/submit_CDConv.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/ECSingle/submit_CDConv.sh b/scripts/ECSingle/submit_CDConv.sh index 50d0561..804209c 100644 --- a/scripts/ECSingle/submit_CDConv.sh +++ b/scripts/ECSingle/submit_CDConv.sh @@ -8,7 +8,7 @@ export batch_size_list=(8 16 64) export epochs=400 -export dataset=fold +export dataset=ECSingle export StepLRCustomized_scheduler="60" From 769d9f81094f5a94e0eb3ca673e43c13dc4f507a Mon Sep 17 00:00:00 2001 From: YanjingLiLi Date: Wed, 5 Jun 2024 11:18:50 +0800 Subject: [PATCH 5/5] edit --- scripts/GO/submit_CDConv_bp.sh | 4 ++-- scripts/GO/submit_CDConv_mf.sh | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/GO/submit_CDConv_bp.sh b/scripts/GO/submit_CDConv_bp.sh index 660ffef..951ebdb 100644 --- a/scripts/GO/submit_CDConv_bp.sh +++ b/scripts/GO/submit_CDConv_bp.sh @@ -8,7 +8,7 @@ export batch_size_list=(8 16 64) export epochs=200 -export dataset=GO_mf +export dataset=GO_bp export StepLRCustomized_scheduler="60" @@ -26,7 +26,7 @@ for batch_size in "${batch_size_list[@]}"; do python finetune_GO.py \ --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ --seed="$seed" \ - --GO_level mf \ + --GO_level bp \ --batch_size="$batch_size" --optimizer SGD --CDConv_base_width 32 \ --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ --output_model_dir="$output_model_dir" \ diff --git a/scripts/GO/submit_CDConv_mf.sh b/scripts/GO/submit_CDConv_mf.sh index 951ebdb..660ffef 100644 --- a/scripts/GO/submit_CDConv_mf.sh +++ b/scripts/GO/submit_CDConv_mf.sh @@ -8,7 +8,7 @@ export batch_size_list=(8 16 64) export epochs=200 -export dataset=GO_bp +export dataset=GO_mf export StepLRCustomized_scheduler="60" @@ -26,7 +26,7 @@ for batch_size in "${batch_size_list[@]}"; do python finetune_GO.py \ --model_3d="$model_3d" --dataset="$dataset" --epochs="$epochs" \ --seed="$seed" \ - --GO_level bp \ + --GO_level mf \ --batch_size="$batch_size" --optimizer SGD --CDConv_base_width 32 \ --lr="$lr" --lr_scheduler="$lr_scheduler" --print_every_epoch=1 \ --output_model_dir="$output_model_dir" \