You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 3.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. import numpy as np
  2. import scipy.sparse as sp
  3. import torch
  4. import networkx as nx
  5. from sklearn.model_selection import train_test_split
  6. from sklearn.metrics import f1_score, roc_auc_score
  7. import pickle
  8. def encode_onehot_torch(labels):
  9. num_classes = int(labels.max() + 1)
  10. y = torch.eye(num_classes)
  11. return y[labels]
  12. def encode_onehot(labels):
  13. classes = set(labels)
  14. classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
  15. enumerate(classes)}
  16. labels_onehot = np.array(list(map(classes_dict.get, labels)),
  17. dtype=np.int32)
  18. return labels_onehot
  19. def fill_features(features):
  20. empty_indices = np.argwhere(features != features)
  21. for index in empty_indices:
  22. features[index[0], index[1]] = np.nanmean(features[:,index[1]])
  23. return features
  24. def load_data_medical(dataset_addr, train_ratio, test_ratio=0.2):
  25. with open(dataset_addr, 'rb') as f: # Python 3: open(..., 'rb')
  26. adj, features, labels = pickle.load(f)
  27. n_node = adj.shape[1]
  28. adj = nx.adjacency_matrix(nx.from_numpy_array(adj))
  29. adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
  30. adj_D = normalize(adj + sp.eye(adj.shape[0]))
  31. adj_W = normalize(adj + sp.eye(adj.shape[0]))
  32. adj= dict()
  33. adj['D'] = adj_D
  34. features = features.astype(np.float)
  35. features = fill_features(features)
  36. idx_train, idx_test = train_test_split(range(n_node), test_size=test_ratio, random_state=42, stratify=labels)
  37. idx_train, idx_val = train_test_split(idx_train, train_size=train_ratio/(1-test_ratio), random_state=42, stratify=labels[idx_train])
  38. adj['W'] = []
  39. for nc in range(labels.max() + 1):
  40. nc_idx = np.where(labels[idx_train] == nc)[0]
  41. nc_idx = np.array(idx_train)[nc_idx]
  42. adj['W'].append(adj_W[np.ix_(nc_idx,nc_idx)])
  43. features = torch.FloatTensor(features)
  44. for key,val in adj.items():
  45. if key == 'D':
  46. adj[key] = torch.FloatTensor(np.array(adj[key].todense()))#sparse_mx_to_torch_sparse_tensor(adj[i])
  47. else:
  48. for i in range(len(val)):
  49. adj[key][i] = torch.FloatTensor(np.array(adj[key][i].todense()))
  50. labels = torch.LongTensor(labels)
  51. idx_train = torch.LongTensor(idx_train)
  52. idx_val = torch.LongTensor(idx_val)
  53. idx_test = torch.LongTensor(idx_test)
  54. return adj, features, labels, idx_train, idx_val, idx_test
  55. def normalize(mx):
  56. """Row-normalize sparse matrix"""
  57. rowsum = np.array(mx.sum(1))
  58. r_inv = np.power(rowsum, -1).flatten()
  59. r_inv[np.isinf(r_inv)] = 0.
  60. r_mat_inv = sp.diags(r_inv)
  61. mx = r_mat_inv.dot(mx)
  62. return mx
  63. def accuracy(output, labels):
  64. preds = output.max(1)[1].type_as(labels)
  65. correct = preds.eq(labels).double()
  66. correct = correct.sum()
  67. return correct / len(labels)
  68. def class_f1(output, labels, type='micro', pos_label=None):
  69. preds = output.max(1)[1].type_as(labels)
  70. if pos_label is None:
  71. return f1_score(labels.cpu().numpy(), preds.cpu(), average=type)
  72. else:
  73. return f1_score(labels.cpu().numpy(), preds.cpu(), average=type, pos_label=pos_label)
  74. def auc_score(output, labels):
  75. return roc_auc_score(labels.cpu().numpy(), output[:,1].detach().cpu().numpy())
  76. def sparse_mx_to_torch_sparse_tensor(sparse_mx):
  77. """Convert a scipy sparse matrix to a torch sparse tensor."""
  78. sparse_mx = sparse_mx.tocoo().astype(np.float32)
  79. indices = torch.from_numpy(
  80. np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
  81. values = torch.from_numpy(sparse_mx.data)
  82. shape = torch.Size(sparse_mx.shape)
  83. return torch.sparse.FloatTensor(indices, values, shape)