Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 3.0KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import numpy as np
  2. import torch
  class AvgMeter:
      """Keep a running, count-weighted average of a metric.

      Each ``update(val, count)`` adds ``val * count`` to the running total and
      ``count`` to the number of samples. ``avg`` is always ``sum / count``
      over everything seen since the last ``reset()``.
      """

      def __init__(self, name="Metric"):
          # Label printed by __repr__.
          self.name = name
          self.reset()

      def reset(self):
          """Zero the running total, the sample count and the average."""
          self.avg = self.sum = self.count = 0

      def update(self, val, count=1):
          """Accumulate ``val`` weighted by ``count`` and refresh ``avg``.

          Raises ZeroDivisionError if the accumulated count is still zero after
          this call (e.g. the first update passes ``count=0``).
          """
          self.count += count
          self.sum += val * count
          self.avg = self.sum / self.count

      def __repr__(self):
          # Format as "<name>: <avg to 4 decimal places>".
          return "{}: {:.4f}".format(self.name, self.avg)
  def print_lr(optimizer):
      """Print the learning rate of every parameter group of ``optimizer``.

      Prints one line per group, formatted as ``<group label> <lr>``. The label
      is the group's ``'name'`` entry when present. PyTorch does not add a
      ``'name'`` key to parameter groups by itself, so groups created without
      one are labelled by their index instead of raising ``KeyError``.
      """
      for index, param_group in enumerate(optimizer.param_groups):
          print(param_group.get('name', index), param_group['lr'])
  class CheckpointSaving:
      """Save a model's ``state_dict`` whenever validation accuracy hits a new best.

      The first call always saves. Later calls save only when ``val_acc`` is
      strictly greater than the best value recorded so far.

      Args:
          path: destination passed to ``torch.save``.
          verbose: if true, report each save through ``trace_func``.
          trace_func: callable used for the report (``print`` by default).
      """

      def __init__(self, path='checkpoint.pt', verbose=True, trace_func=print):
          self.path = path
          self.verbose = verbose
          self.trace_func = trace_func
          # Best accuracy seen so far; None until the first call.
          self.best_score = None
          # Accuracy of the last saved checkpoint. It is used only in the log
          # message, so the first save reports "0.000000 --> ...".
          self.val_acc_max = 0

      def __call__(self, val_acc, model):
          """Save ``model`` if ``val_acc`` is the first or a strictly better score."""
          first_call = self.best_score is None
          if not first_call and not val_acc > self.best_score:
              return
          self.best_score = val_acc
          self.save_checkpoint(val_acc, model)

      def save_checkpoint(self, val_acc, model):
          """Write ``model.state_dict()`` to ``self.path`` and remember ``val_acc``."""
          if self.verbose:
              message = ('Validation accuracy increased '
                         f'({self.val_acc_max:.6f} --> {val_acc:.6f}). Model saved ...')
              self.trace_func(message)
          torch.save(model.state_dict(), self.path)
          self.val_acc_max = val_acc
  class EarlyStopping:
      """Track validation loss and flag when it has stopped improving.

      Call the instance once per evaluation with the latest validation loss.
      A call counts as an improvement when ``-val_loss`` is at least ``delta``
      above the best score seen so far; the first non-NaN loss always counts.
      After ``patience`` consecutive non-improving calls, ``early_stop``
      becomes ``True``. A NaN loss never counts as an improvement.

      Args:
          patience: number of consecutive non-improving calls tolerated.
          verbose: if true, report every improvement through ``trace_func``.
          delta: minimum decrease in loss that counts as an improvement.
          path: stored on the instance; nothing in this class writes to it.
          trace_func: callable used for all messages (``print`` by default).
      """

      def __init__(self, patience=10, verbose=False, delta=0.000001, path='checkpoint.pt', trace_func=print):
          self.patience = patience
          self.verbose = verbose
          self.counter = 0  # consecutive calls without improvement
          self.best_score = None  # highest -val_loss seen so far
          self.early_stop = False
          # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
          self.val_loss_min = np.inf
          self.delta = delta
          self.path = path
          self.trace_func = trace_func

      def __call__(self, val_loss, model):
          """Record one validation loss and update the stopping state.

          Args:
              val_loss: validation loss for this evaluation (lower is better).
              model: unused; accepted so existing call sites keep working.
          """
          score = -val_loss
          # NaN is the only value that compares unequal to itself. Without this
          # guard a NaN loss fails the "no improvement" comparison and is
          # recorded as an improvement, resetting the patience counter.
          is_nan = score != score
          if not is_nan and (self.best_score is None or score >= self.best_score + self.delta):
              self.best_score = score
              if self.verbose:
                  self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).')
              self.val_loss_min = val_loss
              self.counter = 0
          else:
              self.counter += 1
              # Reported regardless of ``verbose``, as in the original code.
              self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
              if self.counter >= self.patience:
                  self.early_stop = True