Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

learner.py 8.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. import gc
  2. import itertools
  3. import random
  4. import numpy as np
  5. import optuna
  6. import pandas as pd
  7. import torch
  8. from evaluation import multiclass_acc
  9. from model import FakeNewsModel, calculate_loss
  10. from utils import AvgMeter, print_lr, EarlyStopping, CheckpointSaving
  11. def batch_constructor(config, batch):
  12. b = {}
  13. for k, v in batch.items():
  14. if k != 'text':
  15. b[k] = v.to(config.device)
  16. return b
  17. def train_epoch(config, model, train_loader, optimizer, scalar):
  18. loss_meter = AvgMeter('train')
  19. c_loss_meter = AvgMeter('train')
  20. s_loss_meter = AvgMeter('train')
  21. # tqdm_object = tqdm(train_loader, total=len(train_loader))
  22. targets = []
  23. predictions = []
  24. for index, batch in enumerate(train_loader):
  25. batch = batch_constructor(config, batch)
  26. optimizer.zero_grad(set_to_none=True)
  27. with torch.cuda.amp.autocast():
  28. output, score = model(batch)
  29. loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])
  30. scalar.scale(loss).backward()
  31. torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
  32. if (index + 1) % 2:
  33. scalar.step(optimizer)
  34. # loss.backward()
  35. # optimizer.step()
  36. scalar.update()
  37. count = batch["id"].size(0)
  38. loss_meter.update(loss.detach(), count)
  39. c_loss_meter.update(c_loss.detach(), count)
  40. s_loss_meter.update(s_loss.detach(), count)
  41. # tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
  42. prediction = output.detach()
  43. predictions.append(prediction)
  44. target = batch['label'].detach()
  45. targets.append(target)
  46. losses = (loss_meter, s_loss_meter, c_loss_meter)
  47. return losses, targets, predictions
  48. def validation_epoch(config, model, validation_loader):
  49. loss_meter = AvgMeter('validation')
  50. c_loss_meter = AvgMeter('validation')
  51. s_loss_meter = AvgMeter('validation')
  52. targets = []
  53. predictions = []
  54. # tqdm_object = tqdm(validation_loader, total=len(validation_loader))
  55. for batch in validation_loader:
  56. batch = batch_constructor(config, batch)
  57. with torch.no_grad():
  58. output, score = model(batch)
  59. loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])
  60. count = batch["id"].size(0)
  61. loss_meter.update(loss.detach(), count)
  62. c_loss_meter.update(c_loss.detach(), count)
  63. s_loss_meter.update(s_loss.detach(), count)
  64. # tqdm_object.set_postfix(validation_loss=loss_meter.avg)
  65. prediction = output.detach()
  66. predictions.append(prediction)
  67. target = batch['label'].detach()
  68. targets.append(target)
  69. losses = (loss_meter, s_loss_meter, c_loss_meter)
  70. return losses, targets, predictions
  71. def supervised_train(config, train_loader, validation_loader, trial=None):
  72. torch.cuda.empty_cache()
  73. checkpoint_path2 = checkpoint_path = str(config.output_path) + '/checkpoint.pt'
  74. if trial:
  75. checkpoint_path2 = str(config.output_path) + '/checkpoint_' + str(trial.number) + '.pt'
  76. torch.manual_seed(27)
  77. random.seed(27)
  78. np.random.seed(27)
  79. torch.autograd.set_detect_anomaly(False)
  80. torch.autograd.profiler.profile(False)
  81. torch.autograd.profiler.emit_nvtx(False)
  82. scalar = torch.cuda.amp.GradScaler()
  83. model = FakeNewsModel(config).to(config.device)
  84. params = [
  85. {"params": model.image_encoder.parameters(), "lr": config.image_encoder_lr, "name": 'image_encoder'},
  86. {"params": model.text_encoder.parameters(), "lr": config.text_encoder_lr, "name": 'text_encoder'},
  87. {"params": itertools.chain(model.image_projection.parameters(), model.text_projection.parameters()),
  88. "lr": config.head_lr, "weight_decay": config.head_weight_decay, 'name': 'projection'},
  89. {"params": model.classifier.parameters(), "lr": config.classification_lr,
  90. "weight_decay": config.classification_weight_decay,
  91. 'name': 'classifier'}
  92. ]
  93. optimizer = torch.optim.AdamW(params, amsgrad=True)
  94. lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.factor,
  95. patience=config.patience // 5, verbose=True)
  96. early_stopping = EarlyStopping(patience=config.patience, delta=config.delta, path=checkpoint_path, verbose=True)
  97. checkpoint_saving = CheckpointSaving(path=checkpoint_path, verbose=True)
  98. train_losses, train_accuracies = [], []
  99. validation_losses, validation_accuracies = [], []
  100. validation_accuracy, validation_loss = 0, 1
  101. for epoch in range(config.epochs):
  102. print(f"Epoch: {epoch + 1}")
  103. gc.collect()
  104. model.train()
  105. train_loss, train_truth, train_pred = train_epoch(config, model, train_loader, optimizer, scalar)
  106. model.eval()
  107. with torch.no_grad():
  108. validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader)
  109. train_accuracy = multiclass_acc(train_truth, train_pred)
  110. validation_accuracy = multiclass_acc(validation_truth, validation_pred)
  111. print_lr(optimizer)
  112. print('Training Loss:', train_loss[0], 'Training Accuracy:', train_accuracy)
  113. print('Validation Loss', validation_loss, 'Validation Accuracy:', validation_accuracy)
  114. if lr_scheduler:
  115. lr_scheduler.step(validation_loss[0].avg)
  116. if early_stopping:
  117. early_stopping(validation_loss[0].avg, model)
  118. if early_stopping.early_stop:
  119. print("Early stopping")
  120. break
  121. if checkpoint_saving:
  122. checkpoint_saving(validation_accuracy, model)
  123. train_accuracies.append(train_accuracy)
  124. train_losses.append(train_loss)
  125. validation_accuracies.append(validation_accuracy)
  126. validation_losses.append(validation_loss)
  127. if trial:
  128. trial.report(validation_accuracy, epoch)
  129. if trial.should_prune():
  130. print('trial pruned')
  131. raise optuna.exceptions.TrialPruned()
  132. print()
  133. if checkpoint_saving:
  134. model = FakeNewsModel(config).to(config.device)
  135. model.load_state_dict(torch.load(checkpoint_path))
  136. model.eval()
  137. with torch.no_grad():
  138. validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader)
  139. validation_accuracy = multiclass_acc(validation_pred, validation_truth)
  140. if trial and validation_accuracy >= config.wanted_accuracy:
  141. loss_accuracy = pd.DataFrame(
  142. {'train_loss': train_losses, 'train_accuracy': train_accuracies, 'validation_loss': validation_losses,
  143. 'validation_accuracy': validation_accuracies})
  144. torch.save({'model_state_dict': model.state_dict(),
  145. 'parameters': str(config),
  146. 'optimizer_state_dict': optimizer.state_dict(),
  147. 'loss_accuracy': loss_accuracy}, checkpoint_path2)
  148. if not checkpoint_saving:
  149. loss_accuracy = pd.DataFrame(
  150. {'train_loss': train_losses, 'train_accuracy': train_accuracies, 'validation_loss': validation_losses,
  151. 'validation_accuracy': validation_accuracies})
  152. torch.save(model.state_dict(), checkpoint_path)
  153. if trial and validation_accuracy >= config.wanted_accuracy:
  154. torch.save({'model_state_dict': model.state_dict(),
  155. 'parameters': str(config),
  156. 'optimizer_state_dict': optimizer.state_dict(),
  157. 'loss_accuracy': loss_accuracy}, checkpoint_path2)
  158. try:
  159. del train_loss
  160. del train_truth
  161. del train_pred
  162. del validation_loss
  163. del validation_truth
  164. del validation_pred
  165. del train_losses
  166. del train_accuracies
  167. del validation_losses
  168. del validation_accuracies
  169. del loss_accuracy
  170. del scalar
  171. del model
  172. del params
  173. except:
  174. print('Error in deleting caches')
  175. pass
  176. return validation_accuracy