Official implementation of the Fake News Revealer paper

test_main.py 3.2KB

import random

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from data_loaders import make_dfs, build_loaders
from evaluation import *
from learner import batch_constructor
from model import FakeNewsModel, calculate_loss


def test(config, test_loader, trial_number=None):
    # Load the trial-specific checkpoint if it exists, otherwise fall back to the default one.
    try:
        checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt')
    except FileNotFoundError:
        checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt')

    # Restore the hyperparameters stored alongside the weights, if any.
    try:
        parameters = checkpoint['parameters']
        config.assign_hyperparameters(parameters)
    except KeyError:
        pass

    model = FakeNewsModel(config).to(config.device)

    # Older checkpoints store the raw state dict instead of a dictionary of fields.
    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except KeyError:
        model.load_state_dict(checkpoint)
    model.eval()

    # Fix the seeds so the evaluation run is reproducible.
    torch.manual_seed(27)
    random.seed(27)
    np.random.seed(27)

    image_features = []
    text_features = []
    multimodal_features = []
    concat_features = []
    targets = []
    predictions = []
    scores = []
    ids = []
    losses = []
    similarities = []

    tqdm_object = tqdm(test_loader, total=len(test_loader))
    for i, batch in enumerate(tqdm_object):
        batch = batch_constructor(config, batch)
        with torch.no_grad():
            output, score = model(batch)
            loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])

        # Collect predictions, labels, losses and intermediate embeddings for reporting.
        predictions.append(output.detach())
        scores.append(score.detach())
        targets.append(batch['label'].detach())
        ids.append(batch['id'].detach())
        image_features.append(model.image_embeddings.detach())
        text_features.append(model.text_embeddings.detach())
        multimodal_features.append(model.multimodal_embeddings.detach())
        concat_features.append(model.classifier.embeddings.detach())
        similarities.append(model.similarity.detach())
        losses.append((loss.detach(), c_loss.detach(), s_loss.detach()))

    # Write the per-class report and the overall metrics.
    s = ''
    s += report_per_class(targets, predictions) + '\n'
    s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/fpr_tpr.csv') + '\n'
    with open(str(config.output_path) + '/results.txt', 'w') as f:
        f.write(s)

    # Save the evaluation plots and the embedding dumps.
    roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png")
    precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png")

    save_embedding(image_features, fname=str(config.output_path) + '/image_features.tsv')
    save_embedding(text_features, fname=str(config.output_path) + '/text_features.tsv')
    save_embedding(multimodal_features, fname=str(config.output_path) + '/multimodal_features_.tsv')
    save_embedding(concat_features, fname=str(config.output_path) + '/concat_features.tsv')
    save_2D_embedding(similarities, fname=str(config.output_path))

    # Persist the configuration used for this run.
    config_parameters = str(config)
    with open(str(config.output_path) + '/parameters.txt', 'w') as f:
        f.write(config_parameters)

    save_loss(ids, predictions, targets, losses, str(config.output_path) + '/text_label.csv')


def test_main(config, trial_number=None):
    train_df, test_df, validation_df = make_dfs(config)
    test_loader = build_loaders(config, test_df, mode="test")
    test(config, test_loader, trial_number)
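
For reference, a minimal sketch of how test_main could be invoked. The Config class, its import path, and the attribute names set below are assumptions made for illustration; the actual configuration class shipped with the repository may expose these differently.

    # Illustrative driver only. `from config import Config` and the attributes
    # assigned below are hypothetical; use the repository's own config object.
    import torch
    from config import Config          # hypothetical import path
    from test_main import test_main

    config = Config()
    config.output_path = './output'     # directory containing checkpoint*.pt
    config.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Evaluate the checkpoint saved for a specific hyperparameter-search trial;
    # test() falls back to checkpoint.pt if no trial-specific file is found.
    test_main(config, trial_number=3)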