|
|
@@ -7,7 +7,7 @@ from tqdm import tqdm |
|
|
|
|
|
|
|
from data_loaders import make_dfs, build_loaders |
|
|
|
from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca, \ |
|
|
|
save_embedding |
|
|
|
save_embedding, save_loss |
|
|
|
from learner import batch_constructor |
|
|
|
from model import FakeNewsModel, calculate_loss |
|
|
|
|
|
|
@@ -75,18 +75,17 @@ def test(config, test_loader, trial_number=None): |
|
|
|
roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png") |
|
|
|
precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png") |
|
|
|
|
|
|
|
save_embedding(config, image_features, fname=str(config.output_path) + '/new_image_features.tsv') |
|
|
|
save_embedding(config, text_features, fname=str(config.output_path) + '/new_text_features.tsv') |
|
|
|
save_embedding(config, multimodal_features, fname=str(config.output_path) + '/new_multimodal_features_.tsv') |
|
|
|
save_embedding(config, concat_features, fname=str(config.output_path) + '/new_concat_features.tsv') |
|
|
|
save_embedding(image_features, fname=str(config.output_path) + '/new_image_features.tsv') |
|
|
|
save_embedding(text_features, fname=str(config.output_path) + '/new_text_features.tsv') |
|
|
|
save_embedding(multimodal_features, fname=str(config.output_path) + '/new_multimodal_features_.tsv') |
|
|
|
save_embedding(concat_features, fname=str(config.output_path) + '/new_concat_features.tsv') |
|
|
|
|
|
|
|
config_parameters = str(config) |
|
|
|
with open(config.output_path + '/new_parameters.txt', 'w') as f: |
|
|
|
f.write(config_parameters) |
|
|
|
print(config) |
|
|
|
|
|
|
|
pd.DataFrame({'id': ids, 'predicted_label': predictions, 'real_label': targets, 'losses': losses}).to_csv( |
|
|
|
str(config.output_path) + '/new_text_label.csv') |
|
|
|
save_loss(ids, predictions, targets, losses, str(config.output_path) + '/new_text_label.csv') |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_main(config, trial_number=None): |