You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

models.py 4.7KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import os
  5. import sys
  6. import pandas as pd
  7. import time
  8. PROJ_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')))
  9. sys.path.insert(0, PROJ_DIR)
  10. from drug.models import GCN
  11. from drug.datasets import DDInteractionDataset
  12. from model.utils import get_FP_by_negative_index, get_FP_by_negative_indices
  13. from const import Drug2FP_FILE
  14. class Connector(nn.Module):
  15. def __init__(self, gpu_id=None):
  16. self.gpu_id = gpu_id
  17. super(Connector, self).__init__()
  18. # self.ddiDataset = DDInteractionDataset(gpu_id = gpu_id)
  19. self.gcn = None
  20. self.drug2FP_df = pd.read_csv(Drug2FP_FILE)
  21. #Cell line features
  22. # np.load('cell_feat.npy')
  23. def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
  24. if self.gcn == None:
  25. # print("here is for initializing the GCN. num_features: ", subgraph.num_features)
  26. self.gcn = GCN(subgraph.num_features, subgraph.num_features // 2)
  27. # print("this is subgraph in connector model forward: --------------")
  28. # print(subgraph)
  29. # graph.get().x --> DDInteractionDataset
  30. # subgraph = graph.get() --> Data
  31. x = subgraph.x
  32. edge_index = subgraph.edge_index
  33. x = self.gcn(x, edge_index)
  34. # print("node local indices:")
  35. node_indices = edge_index.flatten().unique()
  36. # print(node_indices)
  37. # print("-----------------------")
  38. # print("node global indices:")
  39. node_indices = subgraph.n_id
  40. if self.gpu_id is not None:
  41. node_indices = node_indices.cuda(self.gpu_id)
  42. # print(node_indices)
  43. drug1_idx = torch.flatten(drug1_idx)
  44. drug2_idx = torch.flatten(drug2_idx)
  45. #drug1_feat = x[drug1_idx]
  46. #drug2_feat = x[drug2_idx]
  47. drug1_feat = torch.empty((len(drug1_idx), len(x[0])))
  48. drug2_feat = torch.empty((len(drug2_idx), len(x[0])))
  49. print("x shape: ", x.size())
  50. print("node_indices: ", node_indices.size())
  51. start_time = time.time()
  52. # for index, element in enumerate(drug1_idx):
  53. # x_element = element
  54. # if element >= 0:
  55. # x_element = (node_indices == element).nonzero().squeeze()
  56. # drug1_feat[index] = (x[x_element])
  57. # for index, element in enumerate(drug2_idx):
  58. # x_element = element
  59. # if element >= 0:
  60. # x_element = (node_indices == element).nonzero().squeeze()
  61. # drug2_feat[index] = (x[x_element])
  62. mask_positive = (drug1_idx >= 0)
  63. x_elements_positive = (node_indices.unsqueeze(-1) == drug1_idx[mask_positive]).nonzero(as_tuple=True)[0]
  64. drug1_feat[mask_positive] = x[x_elements_positive]
  65. mask_negative = ~mask_positive
  66. drug1_feat[mask_negative] = get_FP_by_negative_indices(drug1_idx[mask_negative], self.drug2FP_df)
  67. mask_positive = (drug2_idx >= 0)
  68. x_elements_positive = (node_indices.unsqueeze(-1) == drug2_idx[mask_positive]).nonzero(as_tuple=True)[0]
  69. drug2_feat[mask_positive] = x[x_elements_positive]
  70. if self.gpu_id is not None:
  71. drug1_feat = drug1_feat.cuda(self.gpu_id)
  72. drug2_feat = drug2_feat.cuda(self.gpu_id)
  73. print("first: ", time.time() - start_time)
  74. start_time = time.time()
  75. for i, x in enumerate(drug1_idx):
  76. if x < 0:
  77. drug1_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df)
  78. for i, x in enumerate(drug2_idx):
  79. if x < 0:
  80. drug2_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df)
  81. print("second: ", time.time() - start_time)
  82. feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
  83. return feat
  84. class MLP(nn.Module):
  85. def __init__(self, input_size: int, hidden_size: int, gpu_id=None):
  86. super(MLP, self).__init__()
  87. self.layers = nn.Sequential(
  88. nn.Linear(input_size, hidden_size),
  89. nn.ReLU(),
  90. nn.BatchNorm1d(hidden_size),
  91. nn.Linear(hidden_size, hidden_size // 2),
  92. nn.ReLU(),
  93. nn.BatchNorm1d(hidden_size // 2),
  94. nn.Linear(hidden_size // 2, 1)
  95. )
  96. self.connector = Connector(gpu_id)
  97. # prev input: self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor, subgraph: related subgraph for the batch
  98. def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
  99. start = time.time()
  100. feat = self.connector(drug1_idx, drug2_idx, cell_feat, subgraph)
  101. print("Connector forward time: ", time.time() - start)
  102. out = self.layers(feat)
  103. return out
  104. # other PRODeepSyn models have been deleted for now