Browse Source

bug in evaluation fixed

master
Faeze 3 months ago
parent
commit
9ebc0db415
3 changed files with 21 additions and 10 deletions
  1. 1
    1
      data/weibo/config.py
  2. 13
    1
      evaluation.py
  3. 7
    8
      test_main.py

+ 1
- 1
data/weibo/config.py View File

@@ -33,7 +33,7 @@ class WeiboConfig(Config):
projection_size = 64
dropout = 0.5

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
image_embedding = 768

+ 13
- 1
evaluation.py View File

@@ -235,7 +235,19 @@ def plot_tsne(config, x, y, fname='tsne.png'):
plt.show()


def save_embedding(config, x, fname='embedding.tsv'):
def save_loss(ids, predictions, targets, l, path):
ids = [i.cpu().numpy() for i in ids]
predictions = [i.cpu().numpy() for i in predictions]
targets = [i.cpu().numpy() for i in targets]
losses = [i[0].cpu().numpy() for i in l]
classifier_losses = [i[1].cpu().numpy() for i in l]
similarity_losses = [i[2].cpu().numpy() for i in l]

pd.DataFrame({'id': ids, 'predicted_label': predictions, 'real_label': targets, 'losses': losses,
'classifier_losses': classifier_losses, 'similarity_losses': similarity_losses}).to_csv(path)


def save_embedding(x, fname='embedding.tsv'):
x = [i.cpu().numpy() for i in x]

x = np.concatenate(x, axis=0)

+ 7
- 8
test_main.py View File

@@ -7,7 +7,7 @@ from tqdm import tqdm

from data_loaders import make_dfs, build_loaders
from evaluation import metrics, report_per_class, roc_auc_plot, precision_recall_plot, plot_tsne, plot_pca, \
save_embedding
save_embedding, save_loss
from learner import batch_constructor
from model import FakeNewsModel, calculate_loss

@@ -75,18 +75,17 @@ def test(config, test_loader, trial_number=None):
roc_auc_plot(targets, scores, fname=str(config.output_path) + "/roc.png")
precision_recall_plot(targets, scores, fname=str(config.output_path) + "/pr.png")

save_embedding(config, image_features, fname=str(config.output_path) + '/new_image_features.tsv')
save_embedding(config, text_features, fname=str(config.output_path) + '/new_text_features.tsv')
save_embedding(config, multimodal_features, fname=str(config.output_path) + '/new_multimodal_features_.tsv')
save_embedding(config, concat_features, fname=str(config.output_path) + '/new_concat_features.tsv')
save_embedding(image_features, fname=str(config.output_path) + '/new_image_features.tsv')
save_embedding(text_features, fname=str(config.output_path) + '/new_text_features.tsv')
save_embedding(multimodal_features, fname=str(config.output_path) + '/new_multimodal_features_.tsv')
save_embedding(concat_features, fname=str(config.output_path) + '/new_concat_features.tsv')

config_parameters = str(config)
with open(config.output_path + '/new_parameters.txt', 'w') as f:
f.write(config_parameters)
print(config)

pd.DataFrame({'id': ids, 'predicted_label': predictions, 'real_label': targets, 'losses': losses}).to_csv(
str(config.output_path) + '/new_text_label.csv')
save_loss(ids, predictions, targets, losses, str(config.output_path) + '/new_text_label.csv')


def test_main(config, trial_number=None):

Loading…
Cancel
Save