import random

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

from data_loaders import make_dfs, build_loaders
from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca, \
    save_embedding
from learner import batch_constructor
from model import FakeNewsModel, calculate_loss


def test(config, test_loader, trial_number=None):
    """Evaluate a trained FakeNewsModel checkpoint on the test set.

    Loads a checkpoint (the per-trial ``checkpoint_<trial>.pt`` when
    ``trial_number`` is given, falling back to the default
    ``checkpoint.pt``), optionally restores the hyperparameters stored in
    it, runs inference over ``test_loader``, and writes metric reports,
    ROC/PR plots, embedding dumps, and per-sample predictions under
    ``config.output_path``.

    Args:
        config: experiment configuration; must provide ``output_path``,
            ``device``, and ``assign_hyperparameters``.
        test_loader: iterable of test batches consumed via
            ``batch_constructor``.
        trial_number: optional hyperparameter-search trial id selecting
            a per-trial checkpoint file.
    """
    device = torch.device(config.device)
    output_path = str(config.output_path)

    # `is not None` so trial 0 is not silently treated as "no trial".
    if trial_number is not None:
        try:
            # map_location so GPU-saved checkpoints load on CPU-only hosts.
            checkpoint = torch.load(output_path + '/checkpoint_' + str(trial_number) + '.pt',
                                    map_location=device)
        except FileNotFoundError:
            # Per-trial checkpoint missing: fall back to the default one.
            checkpoint = torch.load(output_path + '/checkpoint.pt', map_location=device)
    else:
        checkpoint = torch.load(output_path + '/checkpoint.pt', map_location=device)

    # Older checkpoints may be raw state dicts without a 'parameters'
    # entry; in that case keep the hyperparameters already in `config`.
    try:
        config.assign_hyperparameters(checkpoint['parameters'])
    except (KeyError, TypeError):
        pass

    model = FakeNewsModel(config).to(config.device)
    # Support both wrapped checkpoints ({'model_state_dict': ...}) and
    # checkpoints that are the state dict itself.
    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except (KeyError, TypeError):
        model.load_state_dict(checkpoint)

    model.eval()

    # Fixed seeds so the evaluation pass is reproducible.
    torch.manual_seed(27)
    random.seed(27)
    np.random.seed(27)

    image_features = []
    text_features = []
    multimodal_features = []
    concat_features = []

    targets = []
    predictions = []
    scores = []
    ids = []
    losses = []

    tqdm_object = tqdm(test_loader, total=len(test_loader))
    for batch in tqdm_object:
        batch = batch_constructor(config, batch)
        with torch.no_grad():
            output, score = model(batch)
            loss, c_loss, s_loss = calculate_loss(model, score, batch['label'])

        predictions.append(output.detach())
        scores.append(score.detach())
        targets.append(batch['label'].detach())
        ids.append(batch['id'].detach())
        # Per-modality embeddings are exposed as attributes by the model's
        # forward pass; collect them for the t-SNE/PCA/embedding dumps.
        image_features.append(model.image_embeddings.detach())
        text_features.append(model.text_embeddings.detach())
        multimodal_features.append(model.multimodal_embeddings.detach())
        concat_features.append(model.classifier.embeddings.detach())
        losses.append((loss.detach(), c_loss.detach(), s_loss.detach()))

    # Text report: per-class breakdown plus aggregate metrics (the latter
    # also writes the FPR/TPR curve data to CSV).
    s = ''
    s += report_per_class(targets, predictions) + '\n'
    s += metrics(targets, predictions, scores, file_path=output_path + '/new_fpr_tpr.csv') + '\n'
    with open(output_path + '/new_results.txt', 'w') as f:
        f.write(s)

    roc_auc_plot(targets, scores, fname=output_path + "/roc.png")
    precision_recall_plot(targets, scores, fname=output_path + "/pr.png")

    save_embedding(config, image_features, fname=output_path + '/new_image_features.tsv')
    save_embedding(config, text_features, fname=output_path + '/new_text_features.tsv')
    save_embedding(config, multimodal_features, fname=output_path + '/new_multimodal_features_.tsv')
    save_embedding(config, concat_features, fname=output_path + '/new_concat_features.tsv')

    # Persist the (possibly checkpoint-restored) configuration for the run.
    config_parameters = str(config)
    with open(output_path + '/new_parameters.txt', 'w') as f:
        f.write(config_parameters)
    print(config)

    pd.DataFrame({'id': ids, 'predicted_label': predictions, 'real_label': targets, 'losses': losses}).to_csv(
        output_path + '/new_text_label.csv')
-
-
def test_main(config, trial_number=None):
    """Build the test DataLoader and run checkpoint evaluation on it.

    Args:
        config: experiment configuration passed through to data loading
            and to ``test``.
        trial_number: optional trial id forwarded to ``test`` for
            per-trial checkpoint selection.
    """
    # Only the test split is used here; train/validation are discarded.
    _, test_df, _ = make_dfs(config)
    test_loader = build_loaders(config, test_df, mode="test")
    test(config, test_loader, trial_number)