Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 3.2KB

2 years ago
2 years ago
2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. import numpy as np
  2. import torch
  3. import os
  class AvgMeter:
      """Running, count-weighted average of a scalar metric.

      Args:
          name: label shown when the meter is printed via ``repr``.
      """

      def __init__(self, name="Metric"):
          self.name = name
          self.reset()

      def reset(self):
          """Zero the running average, sum and observation count."""
          self.avg = self.sum = self.count = 0

      def update(self, val, count=1):
          """Accumulate ``val`` weighted by ``count`` and refresh ``avg``.

          Args:
              val: value to add to the running statistics.
              count: weight of ``val`` (number of observations it stands for).
          """
          self.count += count
          self.sum += val * count
          self.avg = self.sum / self.count

      def __repr__(self):
          return f"{self.name}: {self.avg:.4f}"
  def print_lr(optimizer):
      """Print the current learning rate of every parameter group.

      Args:
          optimizer: object exposing a ``param_groups`` list of dicts that
              each contain an ``'lr'`` key (e.g. a ``torch.optim`` optimizer).

      Each printed line is ``<label> <lr>``. The label is the group's
      ``'name'`` entry when the caller supplied one; groups without a name
      (the default for torch optimizers) fall back to their index instead
      of raising ``KeyError``.
      """
      for idx, param_group in enumerate(optimizer.param_groups):
          print(param_group.get('name', idx), param_group['lr'])
  class CheckpointSaving:
      """Save a model's ``state_dict`` whenever validation accuracy improves.

      Args:
          path: file the checkpoint is written to; overwritten on every save.
          verbose: if true, announce each save through ``trace_func``.
          trace_func: callable that receives the progress message.
      """

      def __init__(self, path='checkpoint.pt', verbose=True, trace_func=print):
          self.best_score = None
          # Accuracy of the most recently saved checkpoint, shown in the log line.
          self.val_acc_max = 0
          self.path = path
          self.verbose = verbose
          self.trace_func = trace_func

      def __call__(self, val_acc, model):
          """Checkpoint ``model`` when ``val_acc`` beats every previous value.

          The first call always saves; afterwards only a strict increase
          triggers a save (ties are ignored).
          """
          improved = self.best_score is None or val_acc > self.best_score
          if improved:
              self.best_score = val_acc
              self.save_checkpoint(val_acc, model)

      def save_checkpoint(self, val_acc, model):
          """Write ``model.state_dict()`` to ``self.path`` and remember ``val_acc``."""
          if self.verbose:
              message = (f'Validation accuracy increased ({self.val_acc_max:.6f} --> '
                         f'{val_acc:.6f}). Model saved ...')
              self.trace_func(message)
          torch.save(model.state_dict(), self.path)
          self.val_acc_max = val_acc
  class EarlyStopping:
      """Signal when validation loss has stopped improving.

      Call the instance once per evaluation with the validation loss. After
      ``patience`` consecutive calls without a decrease of more than ``delta``,
      ``early_stop`` is set to ``True``.

      Args:
          patience: number of non-improving calls tolerated before stopping.
          verbose: if true, report every improvement through ``trace_func``.
          delta: margin a new loss must beat the best loss by to count as an
              improvement.
          path: stored on the instance but not used by this class.
          trace_func: callable that receives progress messages.
      """

      def __init__(self, patience=10, verbose=False, delta=0.000001, path='checkpoint.pt', trace_func=print):
          self.patience = patience
          self.verbose = verbose
          self.counter = 0
          self.best_score = None
          self.early_stop = False
          # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
          self.val_loss_min = np.inf
          self.delta = delta
          self.path = path
          self.trace_func = trace_func

      def __call__(self, val_loss, model):
          """Update the stopping state with a new validation loss.

          Args:
              val_loss: validation loss for this evaluation (lower is better).
              model: accepted for interface compatibility; not used.
          """
          score = -val_loss  # negate so that a higher score is better
          if self.best_score is None:
              self._record_improvement(score, val_loss)
          # ``score != score`` is only true for NaN. A NaN compares False with
          # everything, so without this check it would be treated as an
          # improvement, become the best score and reset patience.
          elif score != score or score < self.best_score + self.delta:
              self.counter += 1
              # Reported regardless of ``verbose``.
              self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
              if self.counter >= self.patience:
                  self.early_stop = True
          else:
              self._record_improvement(score, val_loss)
              self.counter = 0

      def _record_improvement(self, score, val_loss):
          """Store ``score`` as the best so far and log the new loss minimum."""
          self.best_score = score
          if self.verbose:
              self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).')
          self.val_loss_min = val_loss
  def make_directory(path):
      """Create the single directory ``path`` and report the outcome on stdout.

      Best-effort: failures (already exists, missing parent, permission
      denied, ...) are reported, not raised.

      Args:
          path: directory to create. Its parent must already exist, because
              ``os.mkdir`` does not create intermediate directories.

      Returns:
          ``True`` if the directory was created, ``False`` otherwise.
      """
      try:
          os.mkdir(path)
      except OSError as err:
          # Name the directory and the OS reason; the old bare message did not
          # say which path failed or why.
          print(f"Creation of the directory {path} failed: {err}")
          return False
      print(f"Successfully created the directory {path}")
      return True