Official implementation of the Fake News Revealer paper
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

learner.py 8.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. import gc
  2. import itertools
  3. import random
  4. import numpy as np
  5. import optuna
  6. import pandas as pd
  7. import torch
  8. from evaluation import multiclass_acc
  9. from model import FakeNewsModel, calculate_loss
  10. from utils import AvgMeter, print_lr, EarlyStopping, CheckpointSaving
  11. def batch_constructor(config, batch):
  12. b = {}
  13. for key, value in batch.items():
  14. if key != 'text':
  15. b[key] = value.to(config.device)
  16. else:
  17. b[key] = value
  18. return b
  19. def train_epoch(config, model, train_loader, optimizer, scalar):
  20. loss_meter = AvgMeter('train')
  21. c_loss_meter = AvgMeter('train')
  22. s_loss_meter = AvgMeter('train')
  23. # tqdm_object = tqdm(train_loader, total=len(train_loader))
  24. targets = []
  25. predictions = []
  26. for index, batch in enumerate(train_loader):
  27. batch = batch_constructor(config, batch)
  28. optimizer.zero_grad(set_to_none=True)
  29. with torch.cuda.amp.autocast():
  30. output, score = model(batch)
  31. loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])
  32. scalar.scale(loss).backward()
  33. torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
  34. if (index + 1) % 2:
  35. scalar.step(optimizer)
  36. # loss.backward()
  37. # optimizer.step()
  38. scalar.update()
  39. count = batch["id"].size(0)
  40. loss_meter.update(loss.detach(), count)
  41. c_loss_meter.update(c_loss.detach(), count)
  42. s_loss_meter.update(s_loss.detach(), count)
  43. # tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
  44. prediction = output.detach()
  45. predictions.append(prediction)
  46. target = batch['label'].detach()
  47. targets.append(target)
  48. losses = (loss_meter, s_loss_meter, c_loss_meter)
  49. return losses, targets, predictions
  50. def validation_epoch(config, model, validation_loader):
  51. loss_meter = AvgMeter('validation')
  52. c_loss_meter = AvgMeter('validation')
  53. s_loss_meter = AvgMeter('validation')
  54. targets = []
  55. predictions = []
  56. # tqdm_object = tqdm(validation_loader, total=len(validation_loader))
  57. for batch in validation_loader:
  58. batch = batch_constructor(config, batch)
  59. with torch.no_grad():
  60. output, score = model(batch)
  61. loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])
  62. count = batch["id"].size(0)
  63. loss_meter.update(loss.detach(), count)
  64. c_loss_meter.update(c_loss.detach(), count)
  65. s_loss_meter.update(s_loss.detach(), count)
  66. # tqdm_object.set_postfix(validation_loss=loss_meter.avg)
  67. prediction = output.detach()
  68. predictions.append(prediction)
  69. target = batch['label'].detach()
  70. targets.append(target)
  71. losses = (loss_meter, s_loss_meter, c_loss_meter)
  72. return losses, targets, predictions
  73. def supervised_train(config, train_loader, validation_loader, trial=None):
  74. torch.cuda.empty_cache()
  75. checkpoint_path2 = checkpoint_path = str(config.output_path) + '/checkpoint.pt'
  76. if trial:
  77. checkpoint_path2 = str(config.output_path) + '/checkpoint_' + str(trial.number) + '.pt'
  78. torch.manual_seed(27)
  79. random.seed(27)
  80. np.random.seed(27)
  81. torch.autograd.set_detect_anomaly(False)
  82. torch.autograd.profiler.profile(False)
  83. torch.autograd.profiler.emit_nvtx(False)
  84. scalar = torch.cuda.amp.GradScaler()
  85. model = FakeNewsModel(config).to(config.device)
  86. params = [
  87. {"params": model.image_encoder.parameters(), "lr": config.image_encoder_lr, "name": 'image_encoder'},
  88. {"params": model.text_encoder.parameters(), "lr": config.text_encoder_lr, "name": 'text_encoder'},
  89. {"params": itertools.chain(model.image_projection.parameters(), model.text_projection.parameters()),
  90. "lr": config.head_lr, "weight_decay": config.head_weight_decay, 'name': 'projection'},
  91. {"params": model.classifier.parameters(), "lr": config.classification_lr,
  92. "weight_decay": config.classification_weight_decay,
  93. 'name': 'classifier'}
  94. ]
  95. optimizer = torch.optim.AdamW(params, amsgrad=True)
  96. lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.factor,
  97. patience=config.patience // 5, verbose=True)
  98. early_stopping = EarlyStopping(patience=config.patience, delta=config.delta, path=checkpoint_path, verbose=True)
  99. checkpoint_saving = CheckpointSaving(path=checkpoint_path, verbose=True)
  100. train_losses, train_accuracies = [], []
  101. validation_losses, validation_accuracies = [], []
  102. validation_accuracy, validation_loss = 0, 1
  103. for epoch in range(config.epochs):
  104. print(f"Epoch: {epoch + 1}")
  105. gc.collect()
  106. model.train()
  107. train_loss, train_truth, train_pred = train_epoch(config, model, train_loader, optimizer, scalar)
  108. model.eval()
  109. with torch.no_grad():
  110. validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader)
  111. train_accuracy = multiclass_acc(train_truth, train_pred)
  112. validation_accuracy = multiclass_acc(validation_truth, validation_pred)
  113. print_lr(optimizer)
  114. print('Training Loss:', train_loss[0], 'Training Accuracy:', train_accuracy)
  115. print('Validation Loss', validation_loss, 'Validation Accuracy:', validation_accuracy)
  116. if lr_scheduler:
  117. lr_scheduler.step(validation_loss[0].avg)
  118. if early_stopping:
  119. early_stopping(validation_loss[0].avg, model)
  120. if early_stopping.early_stop:
  121. print("Early stopping")
  122. break
  123. if checkpoint_saving:
  124. checkpoint_saving(validation_accuracy, model)
  125. train_accuracies.append(train_accuracy)
  126. train_losses.append(train_loss)
  127. validation_accuracies.append(validation_accuracy)
  128. validation_losses.append(validation_loss)
  129. if trial:
  130. trial.report(validation_accuracy, epoch)
  131. if trial.should_prune():
  132. print('trial pruned')
  133. raise optuna.exceptions.TrialPruned()
  134. print()
  135. if checkpoint_saving:
  136. model = FakeNewsModel(config).to(config.device)
  137. model.load_state_dict(torch.load(checkpoint_path))
  138. model.eval()
  139. with torch.no_grad():
  140. validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader)
  141. validation_accuracy = multiclass_acc(validation_pred, validation_truth)
  142. if trial and validation_accuracy >= config.wanted_accuracy:
  143. loss_accuracy = pd.DataFrame(
  144. {'train_loss': train_losses, 'train_accuracy': train_accuracies, 'validation_loss': validation_losses,
  145. 'validation_accuracy': validation_accuracies})
  146. torch.save({'model_state_dict': model.state_dict(),
  147. 'parameters': str(config),
  148. 'optimizer_state_dict': optimizer.state_dict(),
  149. 'loss_accuracy': loss_accuracy}, checkpoint_path2)
  150. if not checkpoint_saving:
  151. loss_accuracy = pd.DataFrame(
  152. {'train_loss': train_losses, 'train_accuracy': train_accuracies, 'validation_loss': validation_losses,
  153. 'validation_accuracy': validation_accuracies})
  154. torch.save(model.state_dict(), checkpoint_path)
  155. if trial and validation_accuracy >= config.wanted_accuracy:
  156. torch.save({'model_state_dict': model.state_dict(),
  157. 'parameters': str(config),
  158. 'optimizer_state_dict': optimizer.state_dict(),
  159. 'loss_accuracy': loss_accuracy}, checkpoint_path2)
  160. try:
  161. del train_loss
  162. del train_truth
  163. del train_pred
  164. del validation_loss
  165. del validation_truth
  166. del validation_pred
  167. del train_losses
  168. del train_accuracies
  169. del validation_losses
  170. del validation_accuracies
  171. del loss_accuracy
  172. del scalar
  173. del model
  174. del params
  175. except:
  176. print('Error in deleting caches')
  177. pass
  178. return validation_accuracy