py_mgcn, commit fb645a5fc9 (master, Mahsa)

ReadMe.md

MGCN: Semi-supervised Classification in Multi-layer Graphs with Graph Convolutional Networks
====

This repository contains the PyTorch code for node embedding in multi-layer networks with attributes, described in
Ghorbani et al., "MGCN: Semi-supervised Classification in Multi-layer Graphs with Graph Convolutional Networks" [1].



Usage
------------
The main file is "train.py". It runs the model on the "infra" dataset with the default parameters:
```
python train.py
```
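
"train.py" also accepts a few command-line flags defined with argparse in "code/train.py" (see the listing below). For example, to change the number of training epochs and the dropout rate:
```
python train.py --epochs 50 --dropout 0.5
```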


Input Data
------------
You should create a folder inside the "data" directory whose name matches "dataset_str" in the code. For example, if your dataset_str is "infra", place the files under "data/infra/". The folder contains the following files (a minimal loading sketch follows this list):
1) The adjacency list of within-layer edges for every layer. Node numbering should start from 0. The name of the files containing the within-layer adjacency lists should follow the pattern "dataset_name.adj#NumberOfLayer".
For example, if your dataset name is "infra" with three layers, you should have "infra.adj0", "infra.adj1" and "infra.adj2" files, each of which contains an adjacency list like the following:
```
0 1
0 2
0 3
0 4
```
2) The adjacency list of between-layer edges for every pair of layers (if available). The name of the files containing the between-layer adjacency lists should follow the pattern "dataset_name.bet#NumberOfLayer1_#NumberOfLayer2".
For example, if your dataset name is "infra" with three layers, you should have "infra.bet0_1", "infra.bet0_2" and "infra.bet1_2" (one file per pair of layers is enough).
3) Features and labels of every node in every layer. If features aren't available, an identity matrix should be used. The name of the files containing the features and labels should follow the pattern "dataset_name.feat#NumberOfLayer".
For example, if your dataset name is "infra" with three layers, you should have "infra.feat0", "infra.feat1" and "infra.feat2". Every line of a file contains the node number, the node features, and the node label. For example, for the "infra" dataset, if the first layer has 5 nodes belonging to 2 classes, it should look like the following (the first line is only a column legend, not part of the file):
```
NodeNumber Features Label
0 1 0 0 0 0 1
1 0 1 0 0 0 1
2 0 0 1 0 0 2
3 0 0 0 1 0 2
4 0 0 0 0 1 2
```
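
For reference, here is a minimal sketch (not part of the repository) of how these files can be read with NumPy/SciPy. It mirrors what load_in_adj, load_bet_adj and load_features_labels in "code/utils.py" do, and assumes the hypothetical path "data/infra/infra"; the actual loader additionally symmetrizes and normalizes the adjacency matrices and remaps node ids through an index map.
```
import numpy as np
import scipy.sparse as sp

prefix = "data/infra/infra"   # hypothetical path; point it at your own dataset

# Within-layer edges of layer 0 ("infra.adj0"): one edge (two node ids) per line
edges = np.genfromtxt(prefix + ".adj0", dtype=np.int32)
n_nodes = edges.max() + 1     # node numbering starts from 0
adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                    shape=(n_nodes, n_nodes), dtype=np.float32)

# Between-layer edges of layers 0 and 1 ("infra.bet0_1"):
# first column = node id in layer 0, second column = node id in layer 1
bet_edges = np.genfromtxt(prefix + ".bet0_1", dtype=np.int32)

# Features and labels of layer 0 ("infra.feat0"):
# first column = node id, middle columns = features, last column = label
rows = np.genfromtxt(prefix + ".feat0", dtype=str)
node_ids = rows[:, 0].astype(np.int32)
features = rows[:, 1:-1].astype(np.float32)
labels = rows[:, -1]
```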

Parameters
------------
Parameters of the code are all given as lists so that different combinations of configurations can be tested; every combination is tried (a small sketch follows this list).
Parameters are as follows:
- dataset_str: Dataset name.
- adj_weights: A list for assigning different weights to the input layers. If you leave it empty, all layers get equal weights. [[1,1,1]] contains one config with equal weights.
- wlambdas: The weight of the reconstruction loss.
- hidden_structures: A list containing different structures for the hidden spaces of every layer. For example, [[[32],[32],[32]]] contains one config in which the embedding dimension of every layer is 32.
- lrs: A list containing the learning rate(s).
- test_sizes: A list containing the test size(s).
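
Because every parameter is a list, the training script launches one run for every combination of the supplied values (see the nested loops in "code/train.py"). The following is an illustrative sketch of that grid, with made-up values; it is not the repository's exact code:
```
from itertools import product

# Illustrative values; the defaults in code/train.py are wlambdas=[10],
# lrs=[0.01], test_sizes=[0.2] and hidden_structures=[[[32], [32], [32]]].
wlambdas = [1, 10]
lrs = [0.01, 0.001]
hidden_structures = [[[16], [16], [16]], [[32], [32], [32]]]

# One model is trained and logged for every combination of config values.
for wlambda, lr, hidden_structure in product(wlambdas, lrs, hidden_structures):
    print("lambda:", wlambda, "lr:", lr, "hidden:", hidden_structure)
```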

Note
------------
Thanks to Thomas Kipf. The code is based on "Graph Convolutional Networks in PyTorch" [2].

Bug Report
------------
If you find a bug, please send a bug report to [email protected], including (if necessary) the input file and the parameters that caused the bug.
You can also send me any comments or suggestions about the program.

References
------------
[1] [Ghorbani et al., MGCN: Semi-supervised Classification in Multi-layer Graphs with Graph Convolutional Networks, 2019](https://arxiv.org/pdf/1811.08800)

[2] [Kipf & Welling, Semi-Supervised Classification with Graph Convolutional Networks, 2016](https://arxiv.org/abs/1609.02907)

Cite
------------
Please cite our paper if you use this code in your own work:

```
@article{ghorbani2018multi,
  title={Multi-layered Graph Embedding with Graph Convolution Networks},
  author={Ghorbani, Mahsa and Baghshah, Mahdieh Soleymani and Rabiee, Hamid R},
  journal={arXiv preprint arXiv:1811.08800},
  year={2018}
}
```


code/layers.py

import math

import torch

from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module


class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """
    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.spmm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'


class CrossLayer(Module):
    """
    Cross-layer layer: scores links between the nodes of two different layers
    from their embeddings (optionally through a learned weight matrix).
    """
    def __init__(self, L1_dim, L2_dim, bias=True, bet_weight=True):
        super(CrossLayer, self).__init__()
        self.L1_dim = L1_dim
        self.L2_dim = L2_dim
        self.bet_weight = bet_weight
        self.weight = Parameter(torch.FloatTensor(L1_dim, L2_dim))
        if bias:
            self.bias = Parameter(torch.FloatTensor(L2_dim))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, L1_features, L2_features):
        if self.bet_weight:
            temp = torch.mm(L1_features, self.weight)
            output = torch.mm(temp, torch.t(L2_features))
            if self.bias is not None:
                output = output + self.bias
        else:
            output = torch.mm(L1_features, torch.t(L2_features))
        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.L1_dim) + ' -> ' \
               + str(self.L2_dim) + ')'

code/models.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import GraphConvolution, CrossLayer
from utils import between_index


class GCN(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN, self).__init__()
        self.nclass = nclass

        layers = []
        layers.append(GraphConvolution(nfeat, nhid[0]))
        for i in range(len(nhid)-1):
            layers.append(GraphConvolution(nhid[i], nhid[i+1]))
        if nclass > 1:
            layers.append(GraphConvolution(nhid[-1], nclass))
        self.gc = nn.ModuleList(layers)
        self.dropout = dropout

    def forward(self, x, adj):
        end_layer = len(self.gc)-1 if self.nclass > 1 else len(self.gc)
        for i in range(end_layer):
            x = F.relu(self.gc[i](x, adj))
            x = F.dropout(x, self.dropout, training=self.training)
        if self.nclass > 1:
            classifier = self.gc[-1](x, adj)
            return F.log_softmax(classifier, dim=1), x
        else:
            return None, x


class CCN(nn.Module):
    def __init__(self, n_inputs, inputs_nfeat, inputs_nhid, inputs_nclass, dropout):
        super(CCN, self).__init__()

        # Features of network
        self.n_inputs = n_inputs
        self.inputs_nfeat = inputs_nfeat
        self.inputs_nhid = inputs_nhid
        self.inputs_nclass = inputs_nclass
        self.dropout = dropout

        # Every single layer
        temp_in_layer = []
        temp_bet_layer = []
        for i in range(n_inputs):
            temp_in_layer.append(GCN(inputs_nfeat[i], inputs_nhid[i], inputs_nclass[i], dropout))
        self.in_layer = nn.ModuleList(temp_in_layer)

        # Between layers
        for i in range(n_inputs):
            for j in range(i+1, n_inputs):
                temp_bet_layer.append(CrossLayer(inputs_nhid[i][-1], inputs_nhid[j][-1], bias=True, bet_weight=False))
        self.bet_layer = nn.ModuleList(temp_bet_layer)

    def forward(self, xs, adjs):
        classes_labels = []
        bet_layers_output = []
        in_layers_output = []
        features = []
        # Single layer forward to find features and classes
        for i in range(self.n_inputs):
            class_labels, feature = self.in_layer[i](xs[i], adjs[i])
            classes_labels.append(class_labels)
            features.append(feature)

            temp = torch.mm(feature, torch.t(feature))
            in_layers_output.append(temp)

        # Find between layers with CCN
        for i in range(self.n_inputs):
            for j in range(i+1, self.n_inputs):
                temp = self.bet_layer[between_index(self.n_inputs, i, j)](features[i], features[j])
                bet_layers_output.append(temp)

        return classes_labels, bet_layers_output, in_layers_output


class AggCCN(nn.Module):
    def __init__(self, n_inputs, inputs_nfeat, inputs_nhid, inputs_nclass, dropout):
        super(AggCCN, self).__init__()

        # Features of network
        self.n_inputs = n_inputs
        self.inputs_nfeat = inputs_nfeat
        self.inputs_nhid = inputs_nhid
        self.inputs_nclass = inputs_nclass
        self.dropout = dropout

        # Every single layer
        temp_in_layer = []
        temp_bet_layer = []
        for i in range(n_inputs):
            temp_in_layer.append(GCN(inputs_nfeat[i], inputs_nhid[i], inputs_nclass[i], dropout))
        self.in_layer = nn.ModuleList(temp_in_layer)

    def forward(self, xs, adjs):
        classes_labels = []
        bet_layers_output = []
        in_layers_output = []
        features = []
        # Single layer forward to find features and classes
        for i in range(self.n_inputs):
            class_labels, feature = self.in_layer[i](xs[i], adjs[i])
            classes_labels.append(class_labels)
            features.append(feature)

            temp = torch.mm(feature, torch.t(feature))
            in_layers_output.append(temp)

        return classes_labels, in_layers_output

code/train.py

from __future__ import division
from __future__ import print_function

import time
import argparse

import torch
import torch.nn.functional as F
import torch.optim as optim

from sklearn.metrics import f1_score

from utils import load_data, class_accuracy, class_f1, layer_accuracy, layer_f1, dict_to_writer, trains_vals_tests_split
from models import CCN
from tensorboardX import SummaryWriter

# f1 name should be f1_micro or f1_macro
# class_metrics = {'loss': F.nll_loss, 'acc': class_accuracy, 'f1_micro': class_f1, 'f1_macro': class_f1,
# 'all_f1_micro': f1_score, 'all_f1_macro': f1_score}
class_metrics = {'loss': F.nll_loss, 'all_f1_micro': f1_score, 'all_f1_macro': f1_score}
layer_metrics = {'loss': torch.nn.BCEWithLogitsLoss()}


# Preparing the output of classifier for every layer separately (class numbers should be different over layers)
def prepare_output_labels(outputs, labels, idx):
    tmp_labels = []
    tmp_outputs = []
    max_label = 0
    for i in range(len(labels)):
        tmp_labels.extend(labels[i][idx[i]] + max_label)
        tmp_outputs.extend((outputs[i].max(1)[1].type_as(labels[i]) + max_label)[idx[i]])
        max_label = labels[i].max()

    return tmp_outputs, tmp_labels


# Calculate classification metrics for every layer
def model_stat_in_layer_class(in_classes_pred, in_labels, idx, metrics):
    stats = {}
    for i in range(len(in_classes_pred)):
        if not in_classes_pred[i] is None:
            for metr_name, metr_func in metrics.items():
                if not metr_name in stats:
                    stats[metr_name] = []
                if metr_name[:2] == 'f1':
                    try:
                        stats[metr_name].append(metr_func(in_classes_pred[i][idx[i]], in_labels[i][idx[i]], metr_name.split('f1_')[-1]))
                    except:
                        print("Exception in F1 for class" + str(i) + " for classification")
                        stats[metr_name].append(0)
                elif metr_name[:3] == 'all':
                    continue
                else:
                    stats[metr_name].append(metr_func(in_classes_pred[i][idx[i]], in_labels[i][idx[i]]))
        else:
            for metr_name, metr_func in metrics.items():
                if not metr_name in stats:
                    stats[metr_name] = []
                stats[metr_name].append(0)

    tmp_in_classes_pred, tmp_labels = prepare_output_labels(in_classes_pred, in_labels, idx)
    for metr_name, metr_func in metrics.items():
        if metr_name[:3] == 'all':
            stats[metr_name] = metr_func(tmp_in_classes_pred, tmp_labels, average=metr_name.split('f1_')[-1])

    return stats


# Calculate reconstruction metrics for within and between layer edges (based on the input)
def model_stat_struc(preds, reals, metrics, pos_weights, norms, idx=None):
    stats = {}
    for i in range(len(preds)):
        if reals[i] is None:
            continue
        if 'sparse' in reals[i].type():
            curr_real = reals[i].to_dense()
        else:
            curr_real = reals[i]
        if idx is None:
            final_preds = preds[i]
            final_real = curr_real
        else:
            final_preds = preds[i][idx[i][0], idx[i][1]]
            final_real = curr_real[idx[i][0], idx[i][1]]
        for metr_name, metr_func in metrics.items():
            if not metr_name in stats:
                stats[metr_name] = []
            if metr_name == 'f1':
                try:
                    stats[metr_name].append(metr_func(final_preds, final_real, type=metr_name.split('f1_')[-1]))
                except:
                    print("Exception in F1 for layer" + str(i) + " for structure")
                    stats[metr_name].append(0)
            elif metr_name == 'loss':
                if args.cuda:
                    pw = pos_weights[i] * torch.ones(final_real.size()).cuda()
                else:
                    pw = pos_weights[i] * torch.ones(final_real.size())
                stats[metr_name].append(norms[i]*torch.nn.BCEWithLogitsLoss(pos_weight=pw)(final_preds, final_real))
            else:
                try:
                    stats[metr_name].append(metr_func(final_preds, final_real))
                except:
                    print("Odd problem")

    return stats


def train(epoch):
    t = time.time()
    model.train()
    optimizer.zero_grad()

    # Run model
    t1 = time.time()
    classes_pred, bet_layers_pred, in_layers_pred = model(features, adjs)
    # Final statistics
    stats = dict()
    stats['in_class'] = model_stat_in_layer_class(classes_pred, labels, idx_trains, class_metrics)
    # Reconstructing all the within and between layer edges without train and test split (for using part of the edges
    # for reconstruction, idx is needed)
    stats['in_struc'] = model_stat_struc(in_layers_pred, adjs_orig, layer_metrics, adjs_pos_weights, adjs_norms)
    stats['bet_struc'] = model_stat_struc(bet_layers_pred, bet_adjs_orig, layer_metrics, bet_pos_weights, bet_norms)
    # Write to writer
    for name, stat in stats.items():
        dict_to_writer(stat, writer, epoch, name, 'Train')
    # Calculate loss
    loss_train = 0
    for name, stat in stats.items():
        # Weighted loss
        if name == 'in_struc':
            loss_train += wlambda * sum([a * b for a, b in zip(stat['loss'], adj_weight)])
        elif name == 'bet_struc':
            loss_train += wlambda * sum([a * b for a, b in zip(stat['loss'], adj_weight)])
        else:
            loss_train += sum([a * b for a, b in zip(stat['loss'], adj_weight)])

    # Gradients
    loss_train.backward()
    optimizer.step()

    # Validation if needed
    if not args.fastmode:
        model.eval()
        classes_pred, bet_layers_pred, in_layers_pred = model(features, adjs)
        stats = dict()
        stats['in_class'] = model_stat_in_layer_class(classes_pred, labels, idx_vals, class_metrics)

        stats['in_struc'] = model_stat_struc(in_layers_pred, adjs_orig, layer_metrics, adjs_pos_weights, adjs_norms)
        stats['bet_struc'] = model_stat_struc(bet_layers_pred, bet_adjs_orig, layer_metrics, bet_pos_weights, bet_norms)

        for name, stat in stats.items():
            dict_to_writer(stat, writer, epoch, name, 'Validation')
        loss_val = sum([sum(stat['loss']) for stat in stats.values()])

    print('Epoch: {:04d}'.format(epoch+1),
          'loss_train: {:.4f}'.format(loss_train.item()),
          'time: {:.4f}s'.format(time.time() - t))


def test():
    # Test
    model.eval()
    classes_pred, bet_layers_pred, in_layers_pred = model(features, adjs)
    stats = dict()
    stats['in_class'] = model_stat_in_layer_class(classes_pred, labels, idx_tests, class_metrics)

    for name, stat in stats.items():
        dict_to_writer(stat, writer, epoch, name, 'Test')

    loss_test = sum([sum(stat['loss']) for stat in stats.values()])
    print("Test set results:",
          "loss= {:.4f}".format(loss_test.item()))


if __name__ == '__main__':
    # Training settings
    parser = argparse.ArgumentParser()
    parser.add_argument('--no-cuda', action='store_true', default=True,
                        help='Disables CUDA training.')
    parser.add_argument('--fastmode', action='store_true', default=False,
                        help='Validate during training pass.')
    parser.add_argument('--seed', type=int, default=42, help='Random seed.')
    parser.add_argument('--epochs', type=int, default=20,
                        help='Number of epochs to train.')
    parser.add_argument('--weight_decay', type=float, default=5e-4,
                        help='Weight decay (L2 loss on parameters).')
    parser.add_argument('--dropout', type=float, default=0.5,
                        help='Dropout rate (1 - keep probability).')

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    dataset_str = "infra"

    # parameter
    # All combination of parameters will be tested
    # adj_weights contains a list of different configs, for example [[1,1,1]] contains one config with equal weights
    # for a three-layered input
    adj_weights = []
    # parameter, weight of link reconstruction loss function
    wlambdas = [10]
    # parameter, hidden dimension for every layer should be defined
    hidden_structures = [[[32], [32], [32]]]
    # hidden_structures = [[[16], [16], [16]],[[32], [32], [32]],[[64], [64], [64]],[[128], [128], [128]],[[256], [256], [256]]]
    # Learning rate
    lrs = [0.01]
    # test size
    test_sizes = [0.2]

    # Load data
    adjs, adjs_orig, adjs_sizes, adjs_pos_weights, adjs_norms, bet_pos_weights, bet_norms, bet_adjs, bet_adjs_orig, bet_adjs_sizes, \
        features, features_sizes, labels, labels_nclass = load_data(path="../data/" + dataset_str + "/", dataset=dataset_str)
    # Number of layers
    n_inputs = len(adjs)
    # Weights of layers
    adj_weights.append(n_inputs * [1])

    if args.cuda:
        for i in range(n_inputs):
            labels[i] = labels[i].cuda()

            adjs_pos_weights[i] = adjs_pos_weights[i].cuda()
            adjs[i] = adjs[i].cuda()
            adjs_orig[i] = adjs_orig[i].cuda()
            features[i] = features[i].cuda()

        for i in range(len(bet_adjs)):
            if not bet_adjs[i] is None:
                bet_adjs[i] = bet_adjs[i].cuda()
                bet_adjs_orig[i] = bet_adjs_orig[i].cuda()
                bet_pos_weights[i] = bet_pos_weights[i].cuda()
    # Number of runs
    for run in range(10):
        for ts in test_sizes:
            idx_trains, idx_vals, idx_tests = trains_vals_tests_split(n_inputs, [s[0] for s in adjs_sizes], val_size=0.1,
                                                                      test_size=ts, random_state=int(run + ts*100))
            if args.cuda:
                for i in range(n_inputs):
                    idx_trains[i] = idx_trains[i].cuda()
                    idx_vals[i] = idx_vals[i].cuda()
                    idx_tests[i] = idx_tests[i].cuda()
            # Train model
            t_total = time.time()
            for wlambda in wlambdas:
                for adj_weight in adj_weights:
                    for lr in lrs:
                        for hidden_structure in hidden_structures:
                            temp_weight = ['{:.2f}'.format(x) for x in adj_weight]
                            w_str = '-'.join(temp_weight)
                            h_str = '-'.join([','.join(map(str, temp)) for temp in hidden_structure])
                            writer = SummaryWriter('log/' + dataset_str + '_run-' + str(run) + '/' + '_lambda-' +
                                                   str(wlambda) + "_adj_w-" + w_str + "_LR-" + str(lr) +
                                                   '_hStruc-' + h_str + '_test_size-' + str(ts))
                            model = CCN(n_inputs=n_inputs,
                                        inputs_nfeat=features_sizes,
                                        inputs_nhid=hidden_structure,
                                        inputs_nclass=labels_nclass,
                                        dropout=args.dropout)
                            optimizer = optim.Adam(model.parameters(),
                                                   lr=lr, weight_decay=args.weight_decay)
                            if args.cuda:
                                model.cuda()
                            for epoch in range(args.epochs):
                                train(epoch)
                            # Testing
                            test()
                            print("Optimization Finished!")
                            print("Total time elapsed: {:.4f}s".format(time.time() - t_total))



code/utils.py

import numpy as np
import scipy.sparse as sp
import torch
from sklearn.metrics import accuracy_score, f1_score

import random
from os import listdir
from os.path import isfile, join
from sklearn.model_selection import train_test_split


def encode_onehot(labels):
    classes = set(labels)
    classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
                    enumerate(classes)}
    labels_onehot = np.array(list(map(classes_dict.get, labels)),
                             dtype=np.int32)
    return labels_onehot


def load_bet_adj(layer_num1, layer_num2, idx_map_l1, idx_map_l2, path="../data/cora/", dataset="cora"):
    temp = "{}{}.bet" + str(layer_num1) + "_" + str(layer_num2)
    # print("bet file")
    # print(temp)
    if not isfile(temp.format(path, dataset)):
        # No between-layer file for this pair of layers; callers expect two values.
        return None, None
    edges_unordered = np.genfromtxt(temp.format(path, dataset), dtype=np.int32)
    # idx1 = np.array(np.unique(edges_unordered[:,0]), dtype=np.int32)
    # idx2 = np.array(np.unique(edges_unordered[:, 1]), dtype=np.int32)
    # idx_map_l1 = {j: i for i, j in enumerate(idx1)}
    # idx_map_l2 = {j: i for i, j in enumerate(idx2)}
    N1 = len(list(idx_map_l1))
    N2 = len(list(idx_map_l2))
    edges = np.array(list(map(idx_map_l1.get, edges_unordered[:, 0])) + list(map(idx_map_l2.get, edges_unordered[:, 1])),
                     dtype=np.int32).reshape(edges_unordered.shape, order='F')
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(N1, N2),
                        dtype=np.float32)
    adj_orig = sparse_mx_to_torch_sparse_tensor(adj)

    adj = normalize(adj)
    adj = sparse_mx_to_torch_sparse_tensor(adj)

    return adj, adj_orig


def load_in_adj(layer_num, idx_map=None, path="../data/cora/", dataset="cora"):
    temp = "{}{}.adj" + str(layer_num)
    edges_unordered = np.genfromtxt(temp.format(path, dataset),
                                    dtype=np.int32)
    if idx_map is None:
        idx = np.array(np.unique(edges_unordered.flatten()), dtype=np.int32)
        N = len(list(idx))
        idx_map = {j: i for i, j in enumerate(idx)}
    else:
        N = len(list(idx_map))
    edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),
                     dtype=np.int32).reshape(edges_unordered.shape)
    adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                        shape=(N, N),
                        dtype=np.float32)
    # print("Edges")
    # print(np.count_nonzero(adj.toarray()))
    # build symmetric adjacency matrix
    adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)

    adj_orig = sparse_mx_to_torch_sparse_tensor(adj)

    adj = normalize(adj + sp.eye(adj.shape[0]))

    adj = sparse_mx_to_torch_sparse_tensor(adj)

    return adj, adj_orig


def load_features_labels(layer_num, path, dataset, N=-1):
    print('Loading {} dataset...'.format(dataset))
    temp = "{}{}.feat" + str(layer_num)
    idx_features_labels = np.genfromtxt(temp.format(path, dataset),
                                        dtype=np.dtype(str))
    temp = idx_features_labels[:, 1:-1]
    if temp.size == 0:
        features = sp.csr_matrix(np.identity(N), dtype=np.float32)
    else:
        features = sp.csr_matrix(temp, dtype=np.float32)
    labels = encode_onehot(idx_features_labels[:, -1])

    # build graph
    idx = np.array(idx_features_labels[:, 0], dtype=np.int32)
    idx_map = {j: i for i, j in enumerate(idx)}

    features = normalize(features)

    features = sp.csr_matrix(features)
    features = sparse_mx_to_torch_sparse_tensor(features)
    # features = torch.FloatTensor(np.array(features.todense()))
    labels = torch.LongTensor(np.where(labels)[1])

    return features, labels, idx_map


def train_val_test_split(N, val_size=0.2, test_size=0.2, random_state=1):
    idx_train_temp, idx_test = train_test_split(range(N), test_size=test_size, random_state=random_state)
    if val_size == 0:
        idx_train = idx_train_temp
        idx_val = []  # no validation split requested
    else:
        idx_train, idx_val = train_test_split(idx_train_temp, test_size=val_size, random_state=random_state)
    idx_train = torch.LongTensor(idx_train)
    idx_val = torch.LongTensor(idx_val)
    idx_test = torch.LongTensor(idx_test)

    return idx_train, idx_val, idx_test


def trains_vals_tests_split(n_layers, labels_sizes, val_size, test_size, random_state):
    idx_trains = []
    idx_vals = []
    idx_tests = []
    for i in range(n_layers):
        idx_train, idx_val, idx_test = train_val_test_split(labels_sizes[i], val_size, test_size, random_state)
        idx_trains.append(idx_train)
        idx_vals.append(idx_val)
        idx_tests.append(idx_test)

    return idx_trains, idx_vals, idx_tests


def load_data(path="../data/cora/", dataset="cora"):
    layers = 0
    adjs = []
    adjs_orig = []
    adjs_sizes = []
    adjs_pos_weights = []
    adjs_norms = []

    bet_adjs = []
    bet_adjs_orig = []
    bet_adjs_sizes = []
    bet_pos_weights = []
    bet_norms = []

    features = []
    features_sizes = []
    labels = []
    labels_nclass = []

    idx_maps = []
    for f in listdir(path):
        if isfile(join(path, f)):
            if 'adj' in f:
                layers += 1
    for i in range(layers):
        feature, label, idx_map = load_features_labels(i, path, dataset)
        adj, adj_orig = load_in_adj(i, idx_map, path, dataset)

        pos_weight = float(adj.shape[0] * adj.shape[1] - adj.to_dense().sum()) / adj.to_dense().sum()
        norm = adj.shape[0] * adj.shape[1] / float((adj.shape[0] * adj.shape[1] - adj.to_dense().sum()) * 2)

        idx_maps.append(idx_map)

        adjs.append(adj)
        adjs_orig.append(adj_orig)
        adjs_sizes.append(tuple(adj.size()))
        adjs_pos_weights.append(pos_weight)
        adjs_norms.append(norm)
        features.append(feature)
        features_sizes.append(feature.shape[1])
        labels.append(label)
        labels_nclass.append((label.max().item() + 1))

    for i in range(layers):
        for j in range(i+1, layers):
            bet_adj, bet_adj_orig = load_bet_adj(i, j, idx_maps[i], idx_maps[j], path, dataset)
            bet_adjs.append(bet_adj)
            bet_adjs_orig.append(bet_adj_orig)
            bet_adjs_sizes.append(tuple(bet_adj.size()) if not bet_adj is None else tuple((0, 0)))
            if not bet_adj is None:
                pos_weight = float(
                    bet_adj.shape[0] * bet_adj.shape[1] - bet_adj.to_dense().sum()) / bet_adj.to_dense().sum()
                norm = bet_adj.shape[0] * bet_adj.shape[1] / float((bet_adj.shape[0] * bet_adj.shape[1] - bet_adj.to_dense().sum()) * 2)
            else:
                norm = None
                pos_weight = None
            bet_pos_weights.append(pos_weight)
            bet_norms.append(norm)

    return adjs, adjs_orig, adjs_sizes, adjs_pos_weights, adjs_norms, bet_pos_weights, bet_norms, bet_adjs, bet_adjs_orig, \
        bet_adjs_sizes, features, features_sizes, labels, labels_nclass


def normalize(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx


def class_accuracy(output, labels, type=None):
    preds = output.max(1)[1].type_as(labels)
    return accuracy_score(labels.data, preds)


def class_f1(output, labels, type='micro'):
    preds = output.max(1)[1].type_as(labels)
    return f1_score(labels.data, preds, average=type)


def layer_accuracy(output, real, type=None):
    preds = output.data.clone()
    true = real.data.clone()
    preds[output < 0.5] = 0
    preds[output >= 0.5] = 1
    true[real > 0] = 1
    return accuracy_score(true, preds)


def layer_f1(output, real, type='micro'):
    preds = output.data.clone()
    true = real.data.clone()
    preds[output < 0.5] = 0
    preds[output >= 0.5] = 1
    true[real > 0] = 1
    return f1_score(true, preds, average=type)


def writer_data(values, writer, epoch, type, name):
    try:
        for i, j in enumerate(values):
            # my_dic[name+str(i)] = j
            writer.add_scalar(type + "/" + name + str(i), j, epoch)
    except TypeError as te:
        writer.add_scalar(type + "/" + name, values, epoch)


# type = in_class, in_struc, bet_struc
def dict_to_writer(stats, writer, epoch, type, train_test_val):
    for key, value in stats.items():
        # type_str = train_test_val + "/" + type + '/Loss' if key == 'loss' else train_test_val + type + '/Stats'
        type_str = train_test_val + "/" + type
        writer_data(value, writer, epoch, type_str, key)


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)


def sparse_to_tuple(sparse_mx):
    if not sp.isspmatrix_coo(sparse_mx):
        sparse_mx = sparse_mx.tocoo()
    coords = np.vstack((sparse_mx.row, sparse_mx.col)).transpose()
    values = sparse_mx.data
    shape = sparse_mx.shape
    return coords, values, shape


def between_index(n_inputs, i, j):
    return int(i * n_inputs - (i * (i + 1) / 2) + (j - i - 1))


def gather_edges(pos_edges, neg_edges):
    all_edges = [[], []]
    all_edges[0].extend([idx_i[0] for idx_i in pos_edges])
    all_edges[1].extend([idx_i[1] for idx_i in pos_edges])
    all_edges[0].extend([idx_i[0] for idx_i in neg_edges])
    all_edges[1].extend([idx_i[1] for idx_i in neg_edges])
    return all_edges


