Official implementation of the Fake News Revealer paper

learner.py

import gc
import itertools
import random

import numpy as np
import optuna
import pandas as pd
import torch

from evaluation import multiclass_acc
from model import FakeNewsModel, calculate_loss
from utils import AvgMeter, print_lr, EarlyStopping, CheckpointSaving


def batch_constructor(config, batch):
    """Move every tensor in a dataloader batch to the target device; the raw
    'text' field is left out because a list of strings cannot be moved."""
    b = {}
    for k, v in batch.items():
        if k != 'text':
            b[k] = v.to(config.device)
    return b
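
# A sketch of the batch layout batch_constructor assumes. Only 'id', 'text',
# and 'label' are actually referenced in this file; the other keys below are
# illustrative guesses:
#   batch = {'id': LongTensor[B], 'image': FloatTensor[B, 3, H, W],
#            'input_ids': LongTensor[B, L], 'attention_mask': LongTensor[B, L],
#            'label': LongTensor[B], 'text': list of B raw strings}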


def train_epoch(config, model, train_loader, optimizer, scalar):
    """Run one mixed-precision training epoch; returns the running loss meter
    plus per-batch targets and predictions for accuracy computation."""
    loss_meter = AvgMeter('train')
    targets = []
    predictions = []
    for batch in train_loader:
        batch = batch_constructor(config, batch)
        optimizer.zero_grad(set_to_none=True)
        # Forward pass under autocast so eligible ops run in half precision.
        with torch.cuda.amp.autocast():
            output, score = model(batch)
            loss = calculate_loss(model, score, batch['label'])
        scalar.scale(loss).backward()
        # Unscale before clipping so the norm is measured on the real gradients.
        scalar.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
        scalar.step(optimizer)  # skipped internally if the gradients overflowed
        scalar.update()

        count = batch['id'].size(0)
        loss_meter.update(loss.detach(), count)

        predictions.append(output.detach())
        targets.append(batch['label'].detach())
    return loss_meter, targets, predictions
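
# For reference: the canonical torch.cuda.amp recipe if one wanted to
# accumulate gradients over `accum` micro-batches before each optimizer step
# (a sketch following the PyTorch AMP examples; compute_loss, accum, and
# max_norm are placeholders):
#
#   for index, batch in enumerate(loader):
#       with torch.cuda.amp.autocast():
#           loss = compute_loss(batch) / accum
#       scaler.scale(loss).backward()
#       if (index + 1) % accum == 0:
#           scaler.unscale_(optimizer)        # unscale before clipping
#           torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
#           scaler.step(optimizer)
#           scaler.update()
#           optimizer.zero_grad(set_to_none=True)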


def validation_epoch(config, model, validation_loader):
    """Run one evaluation pass over the validation set without gradients."""
    loss_meter = AvgMeter('validation')
    targets = []
    predictions = []
    for batch in validation_loader:
        batch = batch_constructor(config, batch)
        with torch.no_grad():
            output, score = model(batch)
            loss = calculate_loss(model, score, batch['label'])
        count = batch['id'].size(0)
        loss_meter.update(loss.detach(), count)

        predictions.append(output.detach())
        targets.append(batch['label'].detach())
    return loss_meter, targets, predictions
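
# multiclass_acc comes from evaluation.py and consumes the two lists returned
# above. A minimal sketch of the assumed behaviour (the real implementation
# may differ):
#
#   def multiclass_acc_sketch(targets, predictions):
#       y_true = torch.cat(targets)
#       y_pred = torch.cat(predictions)
#       if y_pred.dim() > 1:              # logits [N, C] -> class indices [N]
#           y_pred = y_pred.argmax(dim=1)
#       return (y_true == y_pred).float().mean().item()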


def supervised_train(config, train_loader, validation_loader, trial=None):
    """Train the model end to end and return the final validation accuracy.
    When `trial` is given, per-epoch accuracy is reported to Optuna so that
    unpromising trials can be pruned."""
    torch.cuda.empty_cache()

    # Base checkpoint path, plus a per-trial copy when running under Optuna.
    checkpoint_path2 = checkpoint_path = str(config.output_path) + '/checkpoint.pt'
    if trial:
        checkpoint_path2 = str(config.output_path) + '/checkpoint_' + str(trial.number) + '.pt'

    # Fix the seeds so runs are reproducible.
    torch.manual_seed(27)
    random.seed(27)
    np.random.seed(27)
    torch.autograd.set_detect_anomaly(False)
    torch.autograd.profiler.profile(False)
    torch.autograd.profiler.emit_nvtx(False)

    scalar = torch.cuda.amp.GradScaler()
    model = FakeNewsModel(config).to(config.device)

    # Each component gets its own learning rate; the projection heads and the
    # classifier also get their own weight decay.
    params = [
        {"params": model.image_encoder.parameters(), "lr": config.image_encoder_lr, "name": 'image_encoder'},
        {"params": model.text_encoder.parameters(), "lr": config.text_encoder_lr, "name": 'text_encoder'},
        {"params": itertools.chain(model.image_projection.parameters(), model.text_projection.parameters()),
         "lr": config.head_lr, "weight_decay": config.head_weight_decay, "name": 'projection'},
        {"params": model.classifier.parameters(), "lr": config.classification_lr,
         "weight_decay": config.classification_weight_decay, "name": 'classifier'},
    ]
    optimizer = torch.optim.AdamW(params, amsgrad=True)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=config.factor,
                                                              patience=config.patience // 5, verbose=True)
    early_stopping = EarlyStopping(patience=config.patience, delta=config.delta, path=checkpoint_path, verbose=True)
    checkpoint_saving = CheckpointSaving(path=checkpoint_path, verbose=True)
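
    # Assumed behaviour of the utils.py helpers wired up above (inferred from
    # how they are called in this file, not from their implementations):
    #   early_stopping(val_loss, model) saves the model on improvements larger
    #     than `delta` and sets .early_stop = True after `patience` epochs
    #     without improvement.
    #   checkpoint_saving(val_accuracy, model) saves model.state_dict() to
    #     `path` whenever the validation accuracy improves.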

    train_losses, train_accuracies = [], []
    validation_losses, validation_accuracies = [], []
    validation_accuracy, validation_loss = 0, 1

    for epoch in range(config.epochs):
        print(f"Epoch: {epoch + 1}")
        gc.collect()

        model.train()
        train_loss, train_truth, train_pred = train_epoch(config, model, train_loader, optimizer, scalar)

        model.eval()
        with torch.no_grad():
            validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader)

        train_accuracy = multiclass_acc(train_truth, train_pred)
        validation_accuracy = multiclass_acc(validation_truth, validation_pred)
        print_lr(optimizer)
        print('Training Loss:', train_loss, 'Training Accuracy:', train_accuracy)
        print('Validation Loss:', validation_loss, 'Validation Accuracy:', validation_accuracy)

        # Reduce the learning rate when the validation loss plateaus.
        if lr_scheduler:
            lr_scheduler.step(validation_loss.avg)

        if early_stopping:
            early_stopping(validation_loss.avg, model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        if checkpoint_saving:
            checkpoint_saving(validation_accuracy, model)

        train_accuracies.append(train_accuracy)
        train_losses.append(train_loss)
        validation_accuracies.append(validation_accuracy)
        validation_losses.append(validation_loss)

        if trial:
            # Report this epoch's accuracy so the study's pruner can act on it.
            trial.report(validation_accuracy, epoch)
            if trial.should_prune():
                print('trial pruned')
                raise optuna.exceptions.TrialPruned()
        print()

    if checkpoint_saving:
        # Reload the best checkpoint and re-score it on the validation set.
        model = FakeNewsModel(config).to(config.device)
        model.load_state_dict(torch.load(checkpoint_path))
        model.eval()
        with torch.no_grad():
            validation_loss, validation_truth, validation_pred = validation_epoch(config, model, validation_loader)
        validation_accuracy = multiclass_acc(validation_truth, validation_pred)

        if trial and validation_accuracy >= config.wanted_accuracy:
            loss_accuracy = pd.DataFrame(
                {'train_loss': train_losses, 'train_accuracy': train_accuracies,
                 'validation_loss': validation_losses, 'validation_accuracy': validation_accuracies})
            torch.save({'model_state_dict': model.state_dict(),
                        'parameters': str(config),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss_accuracy': loss_accuracy}, checkpoint_path2)

    if not checkpoint_saving:
        loss_accuracy = pd.DataFrame(
            {'train_loss': train_losses, 'train_accuracy': train_accuracies,
             'validation_loss': validation_losses, 'validation_accuracy': validation_accuracies})
        torch.save(model.state_dict(), checkpoint_path)
        if trial and validation_accuracy >= config.wanted_accuracy:
            torch.save({'model_state_dict': model.state_dict(),
                        'parameters': str(config),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss_accuracy': loss_accuracy}, checkpoint_path2)
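
    # A hypothetical example of reloading one of the trial checkpoints saved
    # above (mirrors the dict layout passed to torch.save):
    #
    #   ckpt = torch.load(checkpoint_path2, map_location=config.device)
    #   model = FakeNewsModel(config).to(config.device)
    #   model.load_state_dict(ckpt['model_state_dict'])
    #   history = ckpt['loss_accuracy']   # pandas DataFrame of per-epoch metrics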

    # Drop the large objects before returning so back-to-back trials do not
    # accumulate memory. loss_accuracy is deleted last because it is only
    # defined on some of the paths above.
    try:
        del train_loss, train_truth, train_pred
        del validation_loss, validation_truth, validation_pred
        del train_losses, train_accuracies
        del validation_losses, validation_accuracies
        del scalar, model, params
        del loss_accuracy
    except NameError:
        print('Error in deleting caches')

    return validation_accuracy
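

# --- Usage sketch ------------------------------------------------------------
# A hypothetical driver showing how supervised_train plugs into an Optuna
# study; `Config` and `make_loaders` are placeholders for however the project
# builds its configuration and data loaders:
#
#   def objective(trial):
#       config = Config()
#       config.head_lr = trial.suggest_float('head_lr', 1e-5, 1e-2, log=True)
#       train_loader, validation_loader = make_loaders(config)
#       # Per-epoch accuracies are reported inside supervised_train, so the
#       # study's pruner can stop unpromising trials early.
#       return supervised_train(config, train_loader, validation_loader, trial=trial)
#
#   study = optuna.create_study(direction='maximize',
#                               pruner=optuna.pruners.MedianPruner())
#   study.optimize(objective, n_trials=20)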