You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

models.py 2.4KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from torch_geometric.nn import GCNConv
  5. # a simple base GCN model
  6. # class GCN(torch.nn.Module):
  7. # def __init__(self, in_channels, out_channels):
  8. # super().__init__()
  9. # torch.manual_seed(1234)
  10. # self.conv = GCNConv(in_channels, out_channels, add_self_loops=False)
  11. # def forward(self, x, edge_index, edge_weight=None):
  12. # x = F.dropout(x, p=0.5, training=self.training)
  13. # x = self.conv(x, edge_index, edge_weight).relu()
  14. # return x
  15. # base from this notebook: https://colab.research.google.com/drive/1LJir3T6M6Omc2Vn2GV2cDW_GV2YfI53_?usp=sharing#scrollTo=jNsToorfSgS0
  16. class GCN(torch.nn.Module):
  17. def __init__(self, num_features, hidden_channels): # num_features = dataset.num_features
  18. super(GCN, self).__init__()
  19. torch.manual_seed(42)
  20. # Initialize the layers
  21. self.conv1 = GCNConv(num_features, hidden_channels)
  22. self.conv2 = GCNConv(hidden_channels, num_features)
  23. def forward(self, x, edge_index):
  24. # First Message Passing Layer (Transformation)
  25. x = self.conv1(x, edge_index)
  26. x = x.relu()
  27. x = F.dropout(x, p=0.5, training=self.training)
  28. # Second Message Passing Layer
  29. x = self.conv2(x, edge_index)
  30. x = x.relu()
  31. return x
  32. class MLP(nn.Module):
  33. def __init__(self, input_size: int, hidden_size: int):
  34. super(MLP, self).__init__()
  35. self.layers = nn.Sequential(
  36. nn.Linear(input_size, hidden_size),
  37. nn.ReLU(),
  38. nn.BatchNorm1d(hidden_size),
  39. nn.Linear(hidden_size, hidden_size // 2),
  40. nn.ReLU(),
  41. nn.BatchNorm1d(hidden_size // 2),
  42. nn.Linear(hidden_size // 2, 1)
  43. )
  44. def forward(self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor):
  45. feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
  46. out = self.layers(feat)
  47. return out
  48. # model = GCN(dataset.num_features, dataset.num_classes)
  49. # model.train()
  50. # optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  51. # print("Training on CPU.")
  52. # for epoch in range(1, 6):
  53. # optimizer.zero_grad()
  54. # out = model(data.x, data.edge_index, data.edge_attr)
  55. # loss = F.cross_entropy(out, data.y)
  56. # loss.backward()
  57. # optimizer.step()
  58. # print(f"Epoch: {epoch}, Loss: {loss}")