You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

models.py 2.9KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import os
  5. import sys
  6. PROJ_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')))
  7. sys.path.insert(0, PROJ_DIR)
  8. from drug.models import GCN
  9. from drug.datasets import DDInteractionDataset
  10. from model.utils import get_FP_by_negative_index
  11. class Connector(nn.Module):
  12. def __init__(self, gpu_id=None):
  13. self.gpu_id = gpu_id
  14. super(Connector, self).__init__()
  15. # self.ddiDataset = DDInteractionDataset(gpu_id = gpu_id)
  16. self.gcn = None
  17. #Cell line features
  18. # np.load('cell_feat.npy')
  19. def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
  20. if self.gcn == None:
  21. print("here is for initializing the GCN. num_features: ", subgraph.num_features)
  22. self.gcn = GCN(subgraph.num_features, subgraph.num_features // 2)
  23. print("this is subgraph: --------------")
  24. print(subgraph)
  25. # graph.get().x --> DDInteractionDataset
  26. # subgraph = graph.get() --> Data
  27. x = subgraph.x
  28. edge_index = subgraph.edge_index
  29. x = self.gcn(x, edge_index)
  30. drug1_idx = torch.flatten(drug1_idx)
  31. drug2_idx = torch.flatten(drug2_idx)
  32. #drug1_feat = x[drug1_idx]
  33. #drug2_feat = x[drug2_idx]
  34. drug1_feat = torch.empty((len(drug1_idx), len(x[0])))
  35. drug2_feat = torch.empty((len(drug2_idx), len(x[0])))
  36. for index, element in enumerate(drug1_idx):
  37. drug1_feat[index] = (x[element])
  38. for index, element in enumerate(drug2_idx):
  39. drug2_feat[index] = (x[element])
  40. if self.gpu_id is not None:
  41. drug1_feat = drug1_feat.cuda(self.gpu_id)
  42. drug2_feat = drug2_feat.cuda(self.gpu_id)
  43. for i, x in enumerate(drug1_idx):
  44. if x < 0:
  45. drug1_feat[i] = get_FP_by_negative_index(x)
  46. for i, x in enumerate(drug2_idx):
  47. if x < 0:
  48. drug2_feat[i] = get_FP_by_negative_index(x)
  49. feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
  50. return feat
  51. class MLP(nn.Module):
  52. def __init__(self, input_size: int, hidden_size: int, gpu_id=None):
  53. super(MLP, self).__init__()
  54. self.layers = nn.Sequential(
  55. nn.Linear(input_size, hidden_size),
  56. nn.ReLU(),
  57. nn.BatchNorm1d(hidden_size),
  58. nn.Linear(hidden_size, hidden_size // 2),
  59. nn.ReLU(),
  60. nn.BatchNorm1d(hidden_size // 2),
  61. nn.Linear(hidden_size // 2, 1)
  62. )
  63. self.connector = Connector(gpu_id)
  64. # prev input: self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor, subgraph: related subgraph for the batch
  65. def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
  66. feat = self.connector(drug1_idx, drug2_idx, cell_feat, subgraph)
  67. out = self.layers(feat)
  68. return out
  69. # other PRODeepSyn models have been deleted for now