You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 1.4KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import torch
  2. import numpy as np
  3. from sklearn.metrics import f1_score, roc_auc_score
  def accuracy(output, labels):
      """Return the fraction of rows whose highest-scoring column equals ``labels``.

      Args:
          output: tensor of per-class scores; the prediction for each row is
              the index of its maximum along dim 1.
          labels: tensor of target class indices, one per row of ``output``.

      Returns:
          A 0-d float64 tensor: number of correct predictions / ``len(labels)``.
      """
      predicted = torch.max(output, dim=1).indices.type_as(labels)
      n_correct = (predicted == labels).double().sum()
      return n_correct / len(labels)
  def class_f1(output, labels, type='micro', pos_label=1):
      """F1 score of arg-max predictions, computed by sklearn's ``f1_score``.

      Args:
          output: tensor of per-class scores; prediction = arg-max along dim 1.
          labels: tensor of target class indices.
          type: forwarded to ``f1_score`` as ``average`` (e.g. 'micro').
          pos_label: forwarded to ``f1_score`` unchanged.

      Returns:
          Whatever ``f1_score`` returns for these arguments.
      """
      # NOTE: ``type`` shadows the builtin but is part of the public signature.
      _, predicted = output.max(1)
      y_pred = predicted.type_as(labels).cpu().numpy()
      y_true = labels.detach().cpu().numpy()
      return f1_score(y_true, y_pred, average=type, pos_label=pos_label)
  def roc_auc(output, labels):
      """Compute ROC AUC of ``output`` scores against ``labels`` via sklearn.

      Both tensors are moved to the CPU and converted to numpy arrays; the
      scores are detached from the autograd graph first.
      """
      y_true = labels.cpu().numpy()
      y_score = output.detach().cpu().numpy()
      return roc_auc_score(y_true, y_score)
  def loss(output, labels, weights=None):
      """Return the weighted, negated sum of per-row label/output dot products.

      For each row ``i`` this computes ``sum_c labels[i, c] * output[i, c]``,
      multiplies it by ``weights[i]``, and returns the negated total as a 0-d
      tensor. Presumably ``output`` holds log-probabilities and ``labels`` is
      one-hot, which makes this a weighted NLL; TODO confirm against callers.

      Args:
          output: 2-D tensor of per-class scores, one row per sample.
          labels: tensor broadcastable against ``output``; cast to float.
          weights: optional 1-D per-row weights; defaults to all ones.

      Returns:
          A 0-d tensor.
      """
      if weights is None:
          # Allocate on output's device so the default also works for CUDA
          # tensors; a bare torch.ones(...) always lives on the CPU.
          weights = torch.ones(labels.shape[0], device=output.device)
      per_row = (labels.float() * output).sum(1)
      return torch.sum(-weights * per_row, -1)
  def half_normalize(mx):
      """Row-normalize a 2-D tensor: scale row ``i`` by ``1 / mx[i].sum()``.

      This equals ``diag(1 / rowsum) @ mx``. Rows whose sum is zero get a scale
      of 0 instead of inf, so they come out as zeros (for finite entries).

      Args:
          mx: dense 2-D tensor.

      Returns:
          A new tensor of the same shape; ``mx`` is not modified.
      """
      rowsum = mx.sum(1).float()
      r_inv = rowsum.pow(-1).flatten()
      # 1/0 (and 1/-0) gives +/-inf; zero those rows out instead.
      r_inv[torch.isinf(r_inv)] = 0.
      # Scale each row by broadcasting. Building torch.diag(r_inv) would
      # allocate a dense n x n matrix and do an O(n^2 * m) matmul for the
      # same result.
      return r_inv.unsqueeze(1) * mx
  def encode_onehot_torch(labels, num_classes=None):
      """One-hot encode integer class labels by selecting rows of an identity matrix.

      Args:
          labels: integer tensor of class ids, used to index the identity rows.
          num_classes: width of the encoding; inferred as ``max(labels) + 1``
              when omitted.

      Returns:
          Tensor in torch's default floating dtype with shape
          ``labels.shape + (num_classes,)``.
      """
      width = num_classes if num_classes is not None else int(labels.max() + 1)
      identity = torch.eye(width)
      return identity[labels]
  def calculate_imbalance_weight(idx, labels):
      """Per-node weights that down-weight classes over-represented in ``idx``.

      Every node starts with weight 1. For each class ``c`` in
      ``0..max(labels)``, the nodes of class ``c`` that also appear in ``idx``
      get weight ``1 - k_c / len(idx)``, where ``k_c`` is the number of
      distinct class-``c`` nodes found in ``idx``. Nodes outside ``idx`` keep
      weight 1. If ``idx`` or ``labels`` is empty, all-ones weights are
      returned.

      Args:
          idx: 1-D list or tensor of node indices (presumably the training
              split; confirm with callers).
          labels: 1-D integer tensor of class ids, one per node.

      Returns:
          Float tensor of length ``len(labels)``.
      """
      weights = torch.ones(len(labels))
      n_idx = len(idx)
      if n_idx == 0 or len(labels) == 0:
          # Nothing to rebalance; also avoids 0/0 and max() of an empty tensor.
          return weights
      # Hash the selected indices once so each membership test is O(1),
      # instead of scanning all of idx for every node.
      idx_values = idx.reshape(-1).tolist() if torch.is_tensor(idx) else list(idx)
      idx_set = {int(i) for i in idx_values}
      for cls in range(int(labels.max()) + 1):
          members = torch.where(labels == cls)[0].tolist()
          sub_idx = [node for node in members if node in idx_set]
          weights[sub_idx] = 1 - len(sub_idx) / n_idx
      return weights