Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_main.py 3.4KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import random
  2. import numpy as np
  3. import pandas as pd
  4. import torch
  5. from tqdm import tqdm
  6. from data_loaders import make_dfs, build_loaders
  7. from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca, \
  8. save_embedding, save_loss
  9. from learner import batch_constructor
  10. from model import FakeNewsModel, calculate_loss
  11. def test(config, test_loader, trial_number=None):
  12. if trial_number:
  13. try:
  14. checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt')
  15. except:
  16. checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt')
  17. else:
  18. checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt', map_location=torch.device(config.device))
  19. try:
  20. parameters = checkpoint['parameters']
  21. config.assign_hyperparameters(parameters)
  22. except:
  23. pass
  24. model = FakeNewsModel(config).to(config.device)
  25. try:
  26. model.load_state_dict(checkpoint['model_state_dict'])
  27. except:
  28. model.load_state_dict(checkpoint)
  29. model.eval()
  30. torch.manual_seed(27)
  31. random.seed(27)
  32. np.random.seed(27)
  33. image_features = []
  34. text_features = []
  35. multimodal_features = []
  36. concat_features = []
  37. targets = []
  38. predictions = []
  39. scores = []
  40. ids = []
  41. losses = []
  42. tqdm_object = tqdm(test_loader, total=len(test_loader))
  43. for i, batch in enumerate(tqdm_object):
  44. batch = batch_constructor(config, batch)
  45. with torch.no_grad():
  46. output, score = model(batch)
  47. loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])
  48. predictions.append(output.detach())
  49. scores.append(score.detach())
  50. targets.append(batch['label'].detach())
  51. ids.append(batch['id'].detach())
  52. image_features.append(model.image_embeddings.detach())
  53. text_features.append(model.text_embeddings.detach())
  54. multimodal_features.append(model.multimodal_embeddings.detach())
  55. concat_features.append(model.classifier.embeddings.detach())
  56. losses.append((loss.detach(), c_loss.detach(), s_loss.detach()))
  57. s = ''
  58. s += report_per_class(targets, predictions) + '\n'
  59. s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/new_fpr_tpr.csv') + '\n'
  60. with open(config.output_path + '/new_results.txt', 'w') as f:
  61. f.write(s)
  62. roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png")
  63. precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png")
  64. save_embedding(image_features, fname=str(config.output_path) + '/new_image_features.tsv')
  65. save_embedding(text_features, fname=str(config.output_path) + '/new_text_features.tsv')
  66. save_embedding(multimodal_features, fname=str(config.output_path) + '/new_multimodal_features_.tsv')
  67. save_embedding(concat_features, fname=str(config.output_path) + '/new_concat_features.tsv')
  68. config_parameters = str(config)
  69. with open(config.output_path + '/new_parameters.txt', 'w') as f:
  70. f.write(config_parameters)
  71. save_loss(ids, predictions, targets, losses, str(config.output_path) + '/new_text_label.csv')
  72. def test_main(config, trial_number=None):
  73. train_df, test_df, validation_df = make_dfs(config, )
  74. test_loader = build_loaders(config, test_df, mode="test")
  75. test(config, test_loader, trial_number)