PyTorch implementation of Dynamic Graph Convolutional Network
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data_loader.py 5.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import scipy.io as sio
  2. import numpy as np
  3. import torch
  4. class DataLoader:
  5. """Data Loader class"""
  6. def __init__(self, S_train, S_val, S_test, data_dir, mat_file_path):
  7. self.S_train = S_train
  8. self.S_val = S_val
  9. self.S_test = S_test
  10. self.data_root_dir = data_dir
  11. self.mat_file_path = mat_file_path
  12. self.saved_content = sio.loadmat(self.data_root_dir + self.mat_file_path)
  13. self.T = np.max(self.saved_content["A_labels_subs"][0, :]) + 1
  14. print("Number of total graphs: {}".format(self.T))
  15. self.N = max(np.max(self.saved_content["A_labels_subs"][1, :]),
  16. np.max(self.saved_content["A_labels_subs"][2, :])) + 1
  17. print("Number of nodes in each graph: {}".format(self.N))
  18. self.A_size = torch.Size([self.T, self.N, self.N]) # size of the adjacency matrix
  19. self.C_size = torch.Size([self.T, self.N, self.N]) # similar to the adjacency matrix, C tensor has TxNxN shape
  20. # labels of the edges
  21. self.A_labels = torch.sparse.FloatTensor(torch.tensor(self.saved_content["A_labels_subs"], dtype=torch.long),
  22. torch.squeeze(torch.tensor(self.saved_content["A_labels_vals"])),
  23. self.A_size).coalesce()
  24. # Laplacian Transformed of adjacency matrix
  25. self.C = torch.sparse.FloatTensor(torch.tensor(self.saved_content["C_subs"], dtype=torch.long),
  26. torch.squeeze(torch.tensor(self.saved_content["C_vals"])),
  27. self.C_size).coalesce()
  28. # adjacency matrix
  29. self.A = torch.sparse.FloatTensor(self.A_labels._indices(),
  30. torch.ones(self.A_labels._values().shape),
  31. self.A_size).coalesce()
  32. # create node features
  33. self.X = self.create_node_features()
  34. def split_data(self):
  35. C_train = []
  36. for j in range(self.S_train):
  37. idx = self.C._indices()[0] == j
  38. C_train.append(torch.sparse.FloatTensor(self.C._indices()[1:3, idx],
  39. self.C._values()[idx]))
  40. C_val = []
  41. for j in range(self.S_train, self.S_train + self.S_val):
  42. idx = self.C._indices()[0] == j
  43. C_val.append(torch.sparse.FloatTensor(self.C._indices()[1:3, idx],
  44. self.C._values()[idx]))
  45. C_test = []
  46. for j in range(self.S_train+self.S_test, self.S_train + self.S_val + self.S_test):
  47. idx = self.C._indices()[0] == j
  48. C_test.append(torch.sparse.FloatTensor(self.C._indices()[1:3, idx],
  49. self.C._values()[idx]))
  50. C = {'C_train': C_train,
  51. 'C_val': C_val,
  52. 'C_test': C_test}
  53. X_train = self.X[0:self.S_train].double()
  54. X_val = self.X[self.S_train:self.S_train + self.S_val].double()
  55. X_test = self.X[self.S_train + self.S_val:].double()
  56. data = {'X_train' : X_train,
  57. 'X_val': X_val,
  58. 'X_test': X_test}
  59. return data, C
  60. def get_edges_and_labels(self):
  61. # training
  62. subs_train = self.A_labels._indices()[0] < self.S_train
  63. edges_train = self.A_labels._indices()[:, subs_train]
  64. labels_train = torch.sign(self.A_labels._values()[subs_train])
  65. target_train = (labels_train != -1).long() # element = 0 if class = -1; and 1 if class is 0 or +1
  66. # validation
  67. subs_val = (self.A_labels._indices()[0] >= self.S_train) & (self.A_labels._indices()[0] < self.S_train + self.S_val)
  68. edges_val = self.A_labels._indices()[:, subs_val]
  69. edges_val[0] -= self.S_train
  70. labels_val = torch.sign(self.A_labels._values()[subs_val])
  71. target_val = (labels_val != -1).long()
  72. # Testing
  73. subs_test = (self.A_labels._indices()[0] >= self.S_train + self.S_val)
  74. edges_test = self.A_labels._indices()[:, subs_test]
  75. edges_test[0] -= (self.S_train + self.S_val)
  76. labels_test = torch.sign(self.A_labels._values()[subs_test])
  77. target_test = (labels_test != -1).long()
  78. targets = {'target_train': target_train,
  79. 'target_val': target_val,
  80. 'target_test': target_test}
  81. edges = {'edges_train': edges_train,
  82. 'edges_val': edges_val,
  83. 'edges_test': edges_test}
  84. return targets, edges
  85. def create_node_features(self):
  86. X = torch.zeros(self.A.shape[0], self.A.shape[1], 2)
  87. X[:, :, 0] = torch.sparse.sum(self.A, 1).to_dense() # number of outgoing edges
  88. X[:, :, 1] = torch.sparse.sum(self.A, 2).to_dense() # number of in coming edges
  89. return X
  90. def load_data(self):
  91. print("Loading the data...")
  92. data, C = self.split_data()
  93. targets, edges = self.get_edges_and_labels()
  94. print("======================")
  95. return data, C, targets, edges