Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

test_main.py 3.4KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import random
  2. import numpy as np
  3. import pandas as pd
  4. import torch
  5. from tqdm import tqdm
  6. from data_loaders import make_dfs, build_loaders
  7. from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca, \
  8. save_embedding, save_loss
  9. from learner import batch_constructor
  10. from model import FakeNewsModel, calculate_loss
  11. def test(config, test_loader, trial_number=None):
  12. try:
  13. checkpoint = torch.load(str(config.output_path) + '/checkpoint_' + str(trial_number) + '.pt')
  14. except:
  15. checkpoint = torch.load(str(config.output_path) + '/checkpoint.pt')
  16. try:
  17. parameters = checkpoint['parameters']
  18. config.assign_hyperparameters(parameters)
  19. except:
  20. pass
  21. model = FakeNewsModel(config).to(config.device)
  22. try:
  23. model.load_state_dict(checkpoint['model_state_dict'])
  24. except:
  25. model.load_state_dict(checkpoint)
  26. model.eval()
  27. torch.manual_seed(27)
  28. random.seed(27)
  29. np.random.seed(27)
  30. image_features = []
  31. text_features = []
  32. multimodal_features = []
  33. concat_features = []
  34. targets = []
  35. predictions = []
  36. scores = []
  37. ids = []
  38. losses = []
  39. similarities = []
  40. tqdm_object = tqdm(test_loader, total=len(test_loader))
  41. for i, batch in enumerate(tqdm_object):
  42. batch = batch_constructor(config, batch)
  43. with torch.no_grad():
  44. output, score = model(batch)
  45. loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])
  46. predictions.append(output.detach())
  47. scores.append(score.detach())
  48. targets.append(batch['label'].detach())
  49. ids.append(batch['id'].detach())
  50. image_features.append(model.image_embeddings.detach())
  51. text_features.append(model.text_embeddings.detach())
  52. multimodal_features.append(model.multimodal_embeddings.detach())
  53. concat_features.append(model.classifier.embeddings.detach())
  54. similarities.append(model.similarity.detach())
  55. losses.append((loss.detach(), c_loss.detach(), s_loss.detach()))
  56. s = ''
  57. s += report_per_class(targets, predictions) + '\n'
  58. s += metrics(targets, predictions, scores, file_path=str(config.output_path) + '/new_fpr_tpr.csv') + '\n'
  59. with open(config.output_path + '/new_results.txt', 'w') as f:
  60. f.write(s)
  61. roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png")
  62. precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png")
  63. save_embedding(image_features, fname=str(config.output_path) + '/new_image_features.tsv')
  64. save_embedding(text_features, fname=str(config.output_path) + '/new_text_features.tsv')
  65. save_embedding(multimodal_features, fname=str(config.output_path) + '/new_multimodal_features_.tsv')
  66. save_embedding(concat_features, fname=str(config.output_path) + '/new_concat_features.tsv')
  67. save_embedding(similarities, fname=str(config.output_path) + '/similarities.tsv')
  68. config_parameters = str(config)
  69. with open(config.output_path + '/new_parameters.txt', 'w') as f:
  70. f.write(config_parameters)
  71. save_loss(ids, predictions, targets, losses, str(config.output_path) + '/new_text_label.csv')
  72. def test_main(config, trial_number=None):
  73. train_df, test_df, validation_df = make_dfs(config, )
  74. test_loader = build_loaders(config, test_df, mode="test")
  75. test(config, test_loader, trial_number)