Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_main.py 4.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. import random
  2. import numpy as np
  3. import torch
  4. from tqdm import tqdm
  5. from data_loaders import make_dfs, build_loaders
  6. from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca
  7. from learner import batch_constructor
  8. from model import FakeNewsModel
  9. def test(config, test_loader, trial_number=None):
  10. if trial_number:
  11. try:
  12. checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt')
  13. except:
  14. checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt')
  15. else:
  16. checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device))
  17. try:
  18. parameters = checkpoint['parameters']
  19. config.assign_hyperparameters(parameters)
  20. except:
  21. pass
  22. model = FakeNewsModel(config).to(config.device)
  23. try:
  24. model.load_state_dict(checkpoint['model_state_dict'])
  25. except:
  26. model.load_state_dict(checkpoint)
  27. model.eval()
  28. torch.manual_seed(27)
  29. random.seed(27)
  30. np.random.seed(27)
  31. image_features = []
  32. text_features = []
  33. multimodal_features = []
  34. concat_features = []
  35. targets = []
  36. predictions = []
  37. scores = []
  38. tqdm_object = tqdm(test_loader, total=len(test_loader))
  39. for i, batch in enumerate(tqdm_object):
  40. batch = batch_constructor(config, batch)
  41. with torch.no_grad():
  42. output, score = model(batch)
  43. prediction = output.detach()
  44. predictions.append(prediction)
  45. score = score.detach()
  46. scores.append(score)
  47. target = batch['label'].detach()
  48. targets.append(target)
  49. image_feature = model.image_embeddings.detach()
  50. image_features.append(image_feature)
  51. text_feature = model.text_embeddings.detach()
  52. text_features.append(text_feature)
  53. multimodal_feature = model.multimodal_embeddings.detach()
  54. multimodal_features.append(multimodal_feature)
  55. concat_feature = model.classifier.embeddings.detach()
  56. concat_features.append(concat_feature)
  57. # config.writer.add_graph(model, input_to_model=batch, verbose=True)
  58. s = ''
  59. s += report_per_class(targets, predictions) + '\n'
  60. s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/fpr_tpr.csv') + '\n'
  61. with open(config.output_path + '/results.txt', 'w') as f:
  62. f.write(s)
  63. roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png")
  64. precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png")
  65. # saving_in_tensorboard(config, image_features, targets, 'image_features')
  66. plot_tsne(config, image_features, targets, fname=str(config.output_path) + '/image_features_tsne.png')
  67. plot_pca(config, image_features, targets, fname=str(config.output_path) + '/image_features_pca.png')
  68. # saving_in_tensorboard(config, text_features, targets, 'text_features')
  69. plot_tsne(config, text_features, targets, fname=str(config.output_path) + '/text_features_tsne.png')
  70. plot_pca(config, text_features, targets, fname=str(config.output_path) + '/text_features_pca.png')
  71. #
  72. # saving_in_tensorboard(config, multimodal_features, targets, 'multimodal_features')
  73. plot_tsne(config, multimodal_features, targets, fname=str(config.output_path) + '/multimodal_features_tsne.png')
  74. plot_pca(config, multimodal_features, targets, fname=str(config.output_path) + '/multimodal_features_pca.png')
  75. # saving_in_tensorboard(config, concat_features, targets, 'concat_features')
  76. plot_tsne(config, concat_features, targets, fname=str(config.output_path) + '/concat_features_tsne.png')
  77. plot_pca(config, concat_features, targets, fname=str(config.output_path) + '/concat_features_pca.png')
  78. config_parameters = str(config)
  79. with open(config.output_path + '/parameters.txt', 'w') as f:
  80. f.write(config_parameters)
  81. print(config)
  82. def test_main(config, trial_number=None):
  83. train_df, test_df, validation_df = make_dfs(config, )
  84. test_loader = build_loaders(config, test_df, mode="test")
  85. test(config, test_loader, trial_number)