
models.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import sys
from torch_geometric.loader import DataLoader  # moved here from torch_geometric.data in PyG >= 2.0
import time

PROJ_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, PROJ_DIR)

from drug.models import GCN, TransformerGNN
from drug.datasets import DDInteractionDataset, MoleculeDataset
from model.utils import get_FP_by_negative_index
from config import DRUG_MODEL_HYPERPARAMETERS
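
# NOTE (assumption): DRUG_MODEL_HYPERPARAMETERS is a dict expected to provide at
# least "model_dense_neurons" and "batch_size"; "model_edge_dim" is filled in by
# Connector below from the dataset's edge attributes.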


class Connector(nn.Module):
    def __init__(self, gpu_id=None):
        super(Connector, self).__init__()  # must run before any submodule is assigned
        self.gpu_id = gpu_id
        self.moleculeDataset = MoleculeDataset(root="drug/data/")
        drug_model_params = dict(DRUG_MODEL_HYPERPARAMETERS)  # copy so the shared config is not mutated
        drug_model_params["model_edge_dim"] = self.moleculeDataset[0].edge_attr.shape[1]
        self.transformerGNN = TransformerGNN(feature_size=self.moleculeDataset[0].x.shape[1],
                                             model_params=drug_model_params)
        self.gcn = GCN(self.moleculeDataset[0].x.shape[1], self.moleculeDataset[0].x.shape[1] * 2)
        if self.gpu_id is not None:  # keep the models on CPU unless a GPU id was given
            self.transformerGNN = self.transformerGNN.cuda(self.gpu_id)
            self.gcn = self.gcn.cuda(self.gpu_id)
    def forward(self, drug1_idx, drug2_idx, drug1_fp, drug2_fp, drug1_dti, drug2_dti, cell_feat):
        drug1_idx = torch.flatten(drug1_idx).type(torch.long)
        drug2_idx = torch.flatten(drug2_idx).type(torch.long)
        # Embed each distinct drug in the batch only once.
        drug_index = torch.unique(torch.cat((drug1_idx, drug2_idx)))
        # start = time.time()
        subset = torch.utils.data.Subset(self.moleculeDataset, drug_index)
        feat_d = int(DRUG_MODEL_HYPERPARAMETERS["model_dense_neurons"] / 2)
        all_drug_feat = torch.empty((0, feat_d), dtype=torch.float32)
        # GCN:
        # all_drug_feat = torch.empty((0, self.moleculeDataset[0].x.shape[1] * 2), dtype=torch.float32)
        if self.gpu_id is not None:
            all_drug_feat = all_drug_feat.cuda(self.gpu_id)
        # shuffle must be False: features are appended in loader order and later
        # looked up positionally via value_to_index, which mirrors drug_index order.
        train_loader = DataLoader(subset, batch_size=DRUG_MODEL_HYPERPARAMETERS["batch_size"], shuffle=False)
        for _, batch in enumerate(train_loader):
            if self.gpu_id is not None:
                batch = batch.cuda(self.gpu_id)
            x = batch.x
            edge_index = batch.edge_index
            edge_attr = batch.edge_attr
            # Pass the node features and the connectivity through the encoder.
            # GCN:
            # drug_feat = self.gcn(x.float(), edge_index)
            drug_feat = self.transformerGNN(x.float(),
                                            edge_attr.float(),
                                            edge_index,
                                            batch.batch)
            all_drug_feat = torch.cat((all_drug_feat, drug_feat), 0)
        # print("molecule loop end time:", time.time() - start)
        # Map each original drug id back to its row in all_drug_feat.
        value_to_index = {value.item(): index for index, value in enumerate(drug_index)}
        drug1_idx = torch.tensor([value_to_index[value.item()] for value in drug1_idx])
        drug1_feat = all_drug_feat[drug1_idx]
        drug2_idx = torch.tensor([value_to_index[value.item()] for value in drug2_idx])
        drug2_feat = all_drug_feat[drug2_idx]
        feat = torch.cat([drug1_feat, drug2_feat, drug1_fp, drug2_fp, drug1_dti, drug2_dti, cell_feat], 1)
        return feat
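
# NOTE (assumption): the feature vector returned above is
# 2 * (model_dense_neurons // 2) columns for the two drug embeddings plus the
# widths of the fingerprint, DTI, and cell-line tensors supplied by the caller;
# MLP's input_size must equal this total and, given num_heads=8 below, must
# also be a multiple of 8.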


class MLP(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, gpu_id=None):
        super(MLP, self).__init__()
        # embed_dim must be divisible by num_heads, so input_size has to be a multiple of 8
        self.attention = nn.MultiheadAttention(embed_dim=input_size, num_heads=8)
        self.linear1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.batch_norm1 = nn.BatchNorm1d(hidden_size)
        self.linear2 = nn.Linear(hidden_size, hidden_size // 2)
        self.batch_norm2 = nn.BatchNorm1d(hidden_size // 2)
        self.output = nn.Linear(hidden_size // 2, 1)
        self.connector = Connector(gpu_id)
    def forward(self, drug1_idx, drug2_idx, drug1_fp, drug2_fp, drug1_dti, drug2_dti, cell_feat):
        feat = self.connector(drug1_idx, drug2_idx, drug1_fp, drug2_fp, drug1_dti, drug2_dti, cell_feat)
        # nn.MultiheadAttention defaults to (seq_len, batch, embed_dim) input, so
        # unsqueeze turns [batch, input_size] into a length-1 sequence of shape
        # [1, batch, input_size]; each sample therefore attends only to itself.
        feat = feat.unsqueeze(0)
        attn_output, attn_output_weights = self.attention(feat, feat, feat)
        attn_output = attn_output.squeeze(0)  # back to [batch, input_size]
        out = self.linear1(attn_output)
        out = self.relu(out)
        out = self.batch_norm1(out)
        out = self.linear2(out)
        out = self.relu(out)
        out = self.batch_norm2(out)
        out = self.output(out)
        return out
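

# Minimal smoke-test sketch (assumptions: "drug/data/" contains a prepared
# MoleculeDataset with at least 10 molecules, and FP_DIM / DTI_DIM / CELL_DIM
# are hypothetical placeholder widths; substitute the real dimensions from
# your data pipeline).
if __name__ == "__main__":
    FP_DIM, DTI_DIM, CELL_DIM = 256, 256, 952  # hypothetical widths
    feat_d = int(DRUG_MODEL_HYPERPARAMETERS["model_dense_neurons"] / 2)
    input_size = 2 * feat_d + 2 * FP_DIM + 2 * DTI_DIM + CELL_DIM
    assert input_size % 8 == 0, "input_size must be divisible by the 8 attention heads"
    model = MLP(input_size=input_size, hidden_size=512)
    n = 4
    score = model(torch.randint(0, 10, (n, 1)), torch.randint(0, 10, (n, 1)),
                  torch.rand(n, FP_DIM), torch.rand(n, FP_DIM),
                  torch.rand(n, DTI_DIM), torch.rand(n, DTI_DIM),
                  torch.rand(n, CELL_DIM))
    print(score.shape)  # expected: torch.Size([4, 1])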