import matplotlib.pyplot as plt
import torch
import torch.optim as optim

from models import Fully
from utils import (accuracy, calculate_imbalance_weight, class_f1,
                   encode_onehot_torch, half_normalize, loss, roc_auc)
-
-
def run_gkd(adj, features, labels, idx_train, idx_val, idx_test, idx_connected,
            params_teacher, params_student, params_lpa, isolated_test, sample_weight, use_cuda=True, show_stats=False):
    """Graph Knowledge Distillation: teacher MLP -> label propagation -> student MLP.

    Trains a fully-connected *teacher* on the labelled nodes, propagates the
    best teacher predictions over the graph with label propagation (LPA), and
    then trains a *student* of the same architecture on the propagated soft
    labels.

    Parameters
    ----------
    adj : torch.Tensor
        Adjacency matrix used by the LPA smoothing step ``adj.mm(labels_lpa)``
        (presumably normalized -- TODO confirm against the caller).
    features : torch.Tensor
        Node feature matrix, shape (num_nodes, num_features).
    labels : torch.LongTensor
        Integer class label per node.
    idx_train, idx_val, idx_test : torch.LongTensor
        Index tensors for the train / validation / test splits.
    idx_connected : iterable of int
        Nodes connected to the graph; together with ``idx_train`` they form
        the student's training set (``idx_lpa``).
    params_teacher, params_student : dict
        Hyper-parameters: 'hidden', 'dropout', 'lr', 'weight_decay', 'epochs';
        the teacher additionally uses 'best_metric' (a key of the stats dict,
        e.g. 'acc_val') and 'burn_out' (epochs to skip before tracking the
        best output).
    params_lpa : dict
        'alpha' (restart weight) and 'epochs' (propagation iterations).
    isolated_test : bool
        If True, drop validation/test nodes from the student's training set.
    sample_weight : torch.Tensor
        Per-node loss weights for the teacher phase; recomputed from the LPA
        labels for the student phase.
    use_cuda : bool, optional
        Move models and tensors to the GPU (default True).
    show_stats : bool, optional
        Plot all collected metric curves with matplotlib (default False).

    Returns
    -------
    dict
        metric_name -> list of per-epoch values for the *student* phase
        (the dict is cleared between the teacher and the student runs).
    """

    def train_teacher(epoch):
        """One teacher optimization step; returns detached full-graph output."""
        model_teacher.train()
        optimizer_teacher.zero_grad()
        output = model_teacher(features)
        loss_train = loss(output[idx_train], labels_one_hot[idx_train], sample_weight[idx_train])
        acc_train = accuracy(output[idx_train], labels[idx_train])
        stats['f1macro_train'].append(class_f1(output[idx_train], labels[idx_train], type='macro'))
        if 'auc_train' in stats:
            stats['auc_train'].append(roc_auc(output[idx_train, 1], labels[idx_train]))
        loss_train.backward()
        optimizer_teacher.step()

        # Re-run in eval mode (dropout off) for the validation metrics and for
        # the output handed to the LPA step.
        model_teacher.eval()
        output = model_teacher(features)

        loss_val = loss(output[idx_val], labels_one_hot[idx_val], sample_weight[idx_val])
        acc_val = accuracy(output[idx_val], labels[idx_val])
        if 'auc_val' in stats:
            stats['auc_val'].append(roc_auc(output[idx_val, 1], labels[idx_val]))
        stats['f1macro_val'].append(class_f1(output[idx_val], labels[idx_val], type='macro'))
        stats['loss_train'].append(loss_train.item())
        stats['acc_train'].append(acc_train.item())
        stats['loss_val'].append(loss_val.item())
        stats['acc_val'].append(acc_val.item())

        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(loss_train.item()),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()))

        return output.detach()

    def test_teacher():
        """Record teacher metrics on the test split."""
        model_teacher.eval()
        output = model_teacher(features)
        loss_test = loss(output[idx_test], labels_one_hot[idx_test], sample_weight[idx_test])
        acc_test = accuracy(output[idx_test], labels[idx_test])
        stats['f1macro_test'].append(class_f1(output[idx_test], labels[idx_test], type='macro'))
        stats['loss_test'].append(loss_test.item())
        stats['acc_test'].append(acc_test.item())
        if 'auc_test' in stats:
            stats['auc_test'].append(roc_auc(output[idx_test, 1], labels[idx_test]))

        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))

    def train_student(epoch):
        """One student optimization step on the LPA-propagated soft labels."""
        model_student.train()
        optimizer_student.zero_grad()
        output = model_student(features)
        # Loss is computed on idx_lpa against the propagated labels; the
        # "train" metrics below are still measured on idx_train against the
        # ground-truth labels.
        loss_train = loss(output[idx_lpa], labels_lpa[idx_lpa], sample_weight[idx_lpa])
        acc_train = accuracy(output[idx_lpa], labels[idx_lpa])
        stats['f1macro_train'].append(class_f1(output[idx_train], labels[idx_train], type='macro'))
        if 'auc_train' in stats:
            stats['auc_train'].append(roc_auc(output[idx_train, 1], labels[idx_train]))
        # NOTE(review): 'f1binary_*' keys are never created in `stats` below,
        # so these guards are currently dead; kept in case the metrics list
        # is extended.
        if 'f1binary_train' in stats:
            stats['f1binary_train'].append(class_f1(output[idx_train], labels[idx_train], type='binary'))
        loss_train.backward()
        optimizer_student.step()

        model_student.eval()
        output = model_student(features)
        loss_val = loss(output[idx_val], labels_one_hot[idx_val], sample_weight[idx_val])
        acc_val = accuracy(output[idx_val], labels[idx_val])
        if 'auc_val' in stats:
            stats['auc_val'].append(roc_auc(output[idx_val, 1], labels[idx_val]))
        stats['f1macro_val'].append(class_f1(output[idx_val], labels[idx_val], type='macro'))
        if 'f1binary_val' in stats:
            stats['f1binary_val'].append(class_f1(output[idx_val], labels[idx_val], type='binary'))
        stats['loss_train'].append(loss_train.item())
        stats['acc_train'].append(acc_train.item())
        stats['loss_val'].append(loss_val.item())
        stats['acc_val'].append(acc_val.item())

        print('Epoch: {:04d}'.format(epoch + 1),
              'loss_train: {:.4f}'.format(loss_train.item()),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()))

    def test_student():
        """Record student metrics on the test split."""
        model_student.eval()
        output = model_student(features)
        loss_test = loss(output[idx_test], labels_one_hot[idx_test], sample_weight[idx_test])
        acc_test = accuracy(output[idx_test], labels[idx_test])
        stats['f1macro_test'].append(class_f1(output[idx_test], labels[idx_test], type='macro'))
        if 'f1binary_test' in stats:
            stats['f1binary_test'].append(class_f1(output[idx_test], labels[idx_test], type='binary'))
        stats['loss_test'].append(loss_test.item())
        stats['acc_test'].append(acc_test.item())
        if 'auc_test' in stats:
            stats['auc_test'].append(roc_auc(output[idx_test, 1], labels[idx_test]))

        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))

    # metric_name -> per-epoch history, for every split.
    states = ['train', 'val', 'test']
    metrics = ['loss', 'acc', 'f1macro', 'auc']
    stats = {m + '_' + s: [] for s in states for m in metrics}

    labels_one_hot = encode_onehot_torch(labels)
    # Student training set: connected nodes plus the labelled train nodes.
    idx_lpa = list(idx_connected) + list(idx_train)
    if isolated_test:
        # Keep evaluation honest: never train the student on val/test nodes.
        idx_lpa = [x for x in idx_lpa if (x not in idx_test) and (x not in idx_val)]
    idx_lpa = torch.LongTensor(idx_lpa)
    model_teacher = Fully(nfeat=features.shape[1],
                          nhid=params_teacher['hidden'],
                          nclass=int(labels_one_hot.shape[1]),
                          dropout=params_teacher['dropout'])
    model_student = Fully(nfeat=features.shape[1],
                          nhid=params_student['hidden'],
                          nclass=int(labels_one_hot.shape[1]),
                          dropout=params_student['dropout'])

    optimizer_teacher = optim.Adam(model_teacher.parameters(),
                                   lr=params_teacher['lr'],
                                   weight_decay=params_teacher['weight_decay'])
    optimizer_student = optim.Adam(model_student.parameters(),
                                   lr=params_student['lr'],
                                   weight_decay=params_student['weight_decay'])

    if use_cuda:
        model_student.cuda()
        model_teacher.cuda()
        features = features.cuda()
        sample_weight = sample_weight.cuda()
        adj = adj.cuda()
        labels = labels.cuda()
        labels_one_hot = labels_one_hot.cuda()
        idx_train = idx_train.cuda()
        idx_val = idx_val.cuda()
        idx_test = idx_test.cuda()
        idx_lpa = idx_lpa.cuda()

    # --- Phase 1: train the teacher, tracking its best full-graph output. ---
    labels_lpa = labels_one_hot
    best_metric_val = 0
    best_output = None
    for epoch in range(params_teacher['epochs']):
        output_fc = train_teacher(epoch)
        test_teacher()
        # NOTE(review): assumes 'best_metric' is higher-is-better (acc/f1/auc);
        # a loss key here would track the *worst* epoch.
        if stats[params_teacher['best_metric']][-1] >= best_metric_val and epoch > params_teacher['burn_out']:
            best_metric_val = stats[params_teacher['best_metric']][-1]
            best_output = output_fc
    if best_output is None:
        # Bug fix: `best_output` was previously unbound (NameError below) when
        # no epoch satisfied `epoch > burn_out`; fall back to the last output.
        best_output = output_fc

    # --- Phase 2: propagate the teacher's predictions over the graph. ---
    alpha = params_lpa['alpha']
    # Teacher outputs are assumed to be log-probabilities (TODO confirm that
    # models.Fully ends in log_softmax); exp() recovers probabilities.
    labels_start = torch.exp(best_output)
    labels_lpa = labels_start
    # Clamp the known training labels. Note labels_lpa aliases labels_start
    # here, so the restart term below also uses the clamped distribution
    # (behavior kept as in the original implementation).
    labels_lpa[idx_train, :] = labels_one_hot[idx_train, :].float()
    for _ in range(params_lpa['epochs']):
        labels_lpa = (1 - alpha) * adj.mm(labels_lpa) + alpha * labels_start
        labels_lpa[idx_train, :] = labels_one_hot[idx_train, :].float()
        labels_lpa = half_normalize(labels_lpa)

    # Re-weight samples for the student from the propagated label distribution.
    sample_weight = calculate_imbalance_weight(idx_lpa, labels_lpa.argmax(1))
    if use_cuda:
        sample_weight = sample_weight.cuda()

    # --- Phase 3: train the student; reset the metric histories first. ---
    for k in stats:
        stats[k] = []

    for epoch in range(params_student['epochs']):
        train_student(epoch)
        test_student()

    if show_stats:
        # Bug fix: a stray plt.figure() before subplots() leaked an extra
        # empty figure window.
        fig, axs = plt.subplots(len(states), len(metrics), figsize=(5 * len(metrics), 3 * len(states)))
        for i, s in enumerate(states):
            for j, m in enumerate(metrics):
                axs[i, j].plot(stats[m + '_' + s])
                axs[i, j].set_title(m + ' ' + s)
        plt.show()

    return stats