PyTorch implementation of Dynamic Graph Convolutional Network
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

read_data.py 11KB

(Line-number gutter for lines 1–291, collapsed during extraction.)
import argparse
import math
import os

import numpy as np
import scipy.io as sio
import torch
  6. parser = argparse.ArgumentParser()
  7. parser.add_argument('--data_path', type=str, help='path to the root of data directory')
  8. parser.add_argument('--dataset', type=str, help='dataset name')
  9. parser.add_argument('--time_partitioning', type=list, help='time partitioning for train/test/val split')
  10. args = parser.parse_args()
  11. #Settings
  12. edge_life = True
  13. edge_life_window = 10
  14. no_diag = 20
  15. make_symmetric = True
  16. print(args.dataset)
  17. def print_tensor(A, str):
  18. print('------------------------')
  19. print(str)
  20. print(A)
  21. print(torch.sum(A._values()))
  22. print('------------------------')
  23. if args.dataset == 'Bitcoin Alpha':
  24. data = np.loadtxt(args.data_path, delimiter=',')
  25. save_file_location = 'data/Bitcoin_Alpha/'
  26. save_file_name = 'saved_content_bitcoin_alpha.mat'
  27. time_delta = 60*60*24*14 # 2 weeks
  28. time_slices = args.time_partitioning
  29. no_train_samples, no_val_samples, no_test_samples = time_slices[0], time_slices[1], time_slices[2]
  30. if args.dataset == 'Bitcoin OTC':
  31. data = np.loadtxt(args.data_path, delimiter=',')
  32. save_file_location = 'data/Bitcoin_OTC/'
  33. save_file_name = 'saved_content_bitcoin_otc.mat'
  34. time_delta = 60*60*24*14 # 2 weeks
  35. time_slices = args.time_partitioning
  36. no_train_samples, no_val_samples, no_test_samples = time_slices[0], time_slices[1], time_slices[2]
  37. if args.dataset == 'Reddit':
  38. data = np.loadtxt(args.data_path)
  39. save_file_location = 'data/reddit/'
  40. save_file_name = 'saved_content_reddit.mat'
  41. time_delta = 60 * 60 * 24 * 14 # 2 weeks
  42. time_slices = args.time_partitioning
  43. no_train_samples, no_val_samples, no_test_samples = time_slices[0], time_slices[1], time_slices[2]
  44. elif args.dataset == 'Chess':
  45. data = np.loadtxt(args.data_path, delimiter=',', skiprows=1)
  46. #data = readmatrix('./data/chess/out.chess.csv')
  47. save_file_location = '/home/shivmaran/Desktop/Tensor-GCN/data/chess/'
  48. save_file_name = 'saved_content_python_chess.mat'
  49. time_delta = 60*60*24*31 # 31 days
  50. time_slices = args.time_partitioning
  51. no_train_samples, no_val_samples, no_test_samples = time_slices[0], time_slices[1], time_slices[2]
  52. else:
  53. print('Invalid dataset')
  54. #exit
  55. data = torch.tensor(data)
  56. # Create full tensor
  57. if args.dataset == 'Chess':
  58. dates = np.unique(data[:,3])
  59. no_time_slices = len(dates)
  60. else:
  61. no_time_slices = math.floor((max(data[:,3]) - min(data[:,3]))/time_delta)
  62. N = int(max(max(data[:,0]), max(data[:,1])))
  63. T = int(no_train_samples)
  64. TT = int(no_time_slices)
  65. #Create M
  66. M = np.zeros((T,T))
  67. for i in range(no_diag):
  68. A = M[i:, :T-i]
  69. np.fill_diagonal(A, 1)
  70. L = np.sum(M, axis=1)
  71. M = M/L[:,None]
  72. M = torch.tensor(M)
  73. #Create A and A_labels
  74. if not args.dataset == 'Chess':
  75. data = data[data[:,3] < min(data[:,3])+TT*time_delta]
  76. start_time = min(data[:,3]);
  77. tensor_idx = torch.zeros([data.size()[0], 3], dtype=torch.long)
  78. tensor_val = torch.ones([data.size()[0]], dtype=torch.double)
  79. tensor_labels = torch.zeros([data.size()[0]], dtype=torch.double)
  80. for t in range(TT):
  81. if args.dataset == 'Chess':
  82. idx = data[:,3] == dates[t]
  83. else:
  84. end_time = start_time + time_delta
  85. idx = (data[:, 3] >= start_time) & (data[:, 3] < end_time)
  86. start_time = end_time
  87. tensor_idx[idx, 1:3] = (data[idx, 0:2] - 1).type('torch.LongTensor')
  88. tensor_idx[idx, 0] = t
  89. tensor_labels[idx] = data[idx, 2].type('torch.DoubleTensor')
  90. A = torch.sparse.DoubleTensor(tensor_idx.transpose(1,0), tensor_val, torch.Size([TT, N, N])).coalesce()
  91. A_labels = torch.sparse.DoubleTensor(tensor_idx.transpose(1,0), tensor_labels, torch.Size([TT, N, N])).coalesce()
  92. def func_make_symmetric(sparse_tensor, N, TT):
  93. count = 0
  94. tensor_idx = torch.LongTensor([])
  95. tensor_val = torch.DoubleTensor([]).unsqueeze(1)
  96. A_idx = sparse_tensor._indices()
  97. A_val = sparse_tensor._values()
  98. for j in range(TT):
  99. idx = A_idx[0] == j
  100. mat = torch.sparse.DoubleTensor(A_idx[1:3,idx], A_val[idx], torch.Size([N,N]))
  101. mat_t = mat.transpose(1,0)
  102. sym_mat = mat + mat_t
  103. sym_mat = sym_mat/2
  104. count = count + sym_mat._nnz()
  105. vertices = torch.tensor(sym_mat._indices())
  106. time = torch.ones(sym_mat._nnz(), dtype=torch.long)* j
  107. time = time.unsqueeze(0)
  108. full = torch.cat((time,vertices),0)
  109. tensor_idx = torch.cat((tensor_idx,full),1)
  110. tensor_val = torch.cat((tensor_val, sym_mat._values().unsqueeze(1)),0)
  111. tensor_val.squeeze_(1)
  112. A = torch.sparse.DoubleTensor(tensor_idx, tensor_val, torch.Size([TT, N, N])).coalesce()
  113. return A
  114. if make_symmetric:
  115. B = func_make_symmetric(A, N, TT)
  116. else:
  117. B = A
  118. def func_edge_life(A, N, TT):
  119. A_new = A.clone()
  120. A_new._values()[:] = 0
  121. idx = A._indices()[0] == 0
  122. for t in range(TT):
  123. idx = (A._indices()[0] >= max(0, t-edge_life_window+1)) & (A._indices()[0] <= t)
  124. block = torch.sparse.DoubleTensor(A._indices()[0:3,idx], A._values()[idx], torch.Size([TT, N, N]))
  125. block._indices()[0] = t
  126. A_new = A_new + block
  127. return A_new.coalesce()
  128. if edge_life:
  129. B = func_edge_life(B,N,TT)
  130. def func_laplacian_transformation(B, N, TT):
  131. vertices = torch.LongTensor([range(N), range(N)])
  132. tensor_idx = torch.LongTensor([])
  133. tensor_val = torch.DoubleTensor([]).unsqueeze(1)
  134. for j in range(TT):
  135. time = torch.ones(N, dtype=torch.long) * j
  136. time = time.unsqueeze(0)
  137. full = torch.cat((time,vertices),0)
  138. tensor_idx = torch.cat((tensor_idx,full),1)
  139. val = torch.ones(N, dtype=torch.double)
  140. tensor_val = torch.cat((tensor_val, val.unsqueeze(1)),0)
  141. tensor_val.squeeze_(1)
  142. I = torch.sparse.DoubleTensor(tensor_idx, tensor_val , torch.Size([TT,N,N]))
  143. C = B + I
  144. tensor_idx = torch.LongTensor([])
  145. tensor_val = torch.DoubleTensor([]).unsqueeze(1)
  146. for j in range(TT):
  147. idx = C._indices()[0] == j
  148. mat = torch.sparse.DoubleTensor(C._indices()[1:3,idx], C._values()[idx], torch.Size([N,N]))
  149. vec = torch.ones([N,1], dtype=torch.double)
  150. degree = 1/torch.sqrt(torch.sparse.mm(mat, vec))
  151. index = torch.LongTensor(C._indices()[0:3,idx].size())
  152. values = torch.DoubleTensor(C._values()[idx].size())
  153. index[0] = j
  154. index[1:3] = mat._indices()
  155. values = mat._values()
  156. count = 0
  157. for i,j in index[1:3].transpose(1,0):
  158. values[count] = values[count] * degree[i] * degree[j]
  159. count = count + 1
  160. tensor_idx = torch.cat((tensor_idx,index), 1)
  161. tensor_val = torch.cat((tensor_val,values.unsqueeze(1)),0)
  162. tensor_val.squeeze_(1)
  163. C = torch.sparse.DoubleTensor(tensor_idx, tensor_val , torch.Size([TT,N,N]))
  164. return C.coalesce()
  165. C = func_laplacian_transformation(B, N, TT)
  166. Ct = C.clone().coalesce()
  167. if TT < (T + no_val_samples + no_test_samples):
  168. TTT= (T + no_val_samples + no_test_samples)
  169. Ct = torch.sparse.DoubleTensor(Ct._indices(), Ct._values() , torch.Size([TTT,N,N])).coalesce()
  170. else:
  171. TTT = TT
  172. def func_create_sparse(A, N, TTT, T, start, end):
  173. assert (end-start) == T
  174. idx = (A._indices()[0] >= start) & (A._indices()[0] < end)
  175. index = torch.LongTensor(A._indices()[0:3,idx].size())
  176. values = torch.DoubleTensor(A._values()[idx].size())
  177. index[0:3] = A._indices()[0:3,idx]
  178. index[0] = index[0] - start
  179. values = A._values()[idx]
  180. sub = torch.sparse.DoubleTensor(index, values , torch.Size([T,N,N]))
  181. return sub.coalesce()
  182. C_train = func_create_sparse(Ct, N, TTT, T, 0, T)#There is a bug in matlab Bitcoin Alpha
  183. C_val = func_create_sparse(Ct, N, TTT, T, no_val_samples, T+no_val_samples)
  184. C_test = func_create_sparse(Ct, N, TTT, T, no_val_samples+no_test_samples, TTT)
  185. def to_sparse(x):
  186. x_typename = torch.typename(x).split('.')[-1]
  187. sparse_tensortype = getattr(torch.sparse, x_typename)
  188. indices = torch.nonzero(x)
  189. if len(indices.shape) == 0: # if all elements are zeros
  190. return sparse_tensortype(*x.shape)
  191. indices = indices.t()
  192. values = x[tuple(indices[i] for i in range(indices.shape[0]))]
  193. return sparse_tensortype(indices, values, x.size())
  194. dense = torch.randn(3,3)
  195. dense[[0,0,1], [1,2,0]] = 0 # make sparse
  196. def func_MProduct(C, M):
  197. assert C.size()[0] == M.size()[0]
  198. Tr = C.size()[0]
  199. N = C.size()[1]
  200. C_new = torch.sparse.DoubleTensor(C.size())
  201. #C_new = C.clone()
  202. for j in range(Tr):
  203. idx = C._indices()[0] == j
  204. mat = torch.sparse.DoubleTensor(C._indices()[1:3,idx], C._values()[idx], torch.Size([N,N]))
  205. tensor_idx = torch.zeros([3, mat._nnz()], dtype=torch.long)
  206. tensor_val = torch.zeros([mat._nnz()], dtype=torch.double)
  207. tensor_idx[1:3] = mat._indices()[0:2]
  208. indices = torch.nonzero(M[:,j])
  209. assert indices.size()[0] <= no_diag
  210. for i in range(indices.size()[0]):
  211. tensor_idx[0] = indices[i]
  212. tensor_val = M[indices[i], j] * mat._values()
  213. C_new = C_new + torch.sparse.DoubleTensor(tensor_idx, tensor_val , C.size())
  214. C_new.coalesce()
  215. return C_new.coalesce()
  216. Ct_train = func_MProduct(C_train, M)#There is a bug in matlab Bitcoin Alpha
  217. Ct_val = func_MProduct(C_val, M)
  218. Ct_test = func_MProduct(C_test, M)
  219. A_subs = A._indices()
  220. A_vals = A._values()
  221. A_labels_subs = A_labels._indices()
  222. A_labels_vals = A_labels._values()
  223. C_subs = C._indices()
  224. C_vals = C.values()
  225. C_train_subs = C_train._indices()
  226. C_train_vals = C_train.values()
  227. C_val_subs = C_val._indices()
  228. C_val_vals = C_val.values()
  229. C_test_subs = C_test._indices()
  230. C_test_vals = C_test.values()
  231. Ct_train_subs = Ct_train._indices()
  232. Ct_train_vals = Ct_train.values()
  233. Ct_val_subs = Ct_val._indices()
  234. Ct_val_vals = Ct_val.values()
  235. Ct_test_subs = Ct_test._indices()
  236. Ct_test_vals = Ct_test.values()
  237. sio.savemat(save_file_location + save_file_name, {
  238. 'tensor_idx': np.array(tensor_idx),
  239. 'tensor_labels': np.array(tensor_labels),
  240. 'A_labels_subs': np.array(A_labels_subs),
  241. 'A_labels_vals': np.array(A_labels_vals),
  242. 'A_subs': np.array(A_subs),
  243. 'A_vals': np.array(A_vals),
  244. 'C_subs': np.array(C_subs),
  245. 'C_vals': np.array(C_vals),
  246. 'C_train_subs': np.array(C_train_subs),
  247. 'C_train_vals': np.array(C_train_vals),
  248. 'C_val_subs': np.array(C_val_subs),
  249. 'C_val_vals': np.array(C_val_vals),
  250. 'C_test_subs': np.array(C_test_subs),
  251. 'C_test_vals': np.array(C_test_vals),
  252. 'Ct_train_subs': np.array(Ct_train_subs),
  253. 'Ct_train_vals': np.array(Ct_train_vals),
  254. 'Ct_val_subs': np.array(Ct_val_subs),
  255. 'Ct_val_vals': np.array(Ct_val_vals),
  256. 'Ct_test_subs': np.array(Ct_test_subs),
  257. 'Ct_test_vals': np.array(Ct_test_vals),
  258. 'M': np.array(M)
  259. })