Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

config.py 2.2KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. import torch
  2. from data.config import Config
  3. from data.twitter.data_loader import TwitterDatasetLoader
  class TwitterConfig(Config):
      """Configuration values for the 'twitter' dataset.

      All settings are plain class attributes layered on top of ``Config``.
      ``optuna`` overrides a subset of them on an instance with values
      sampled from a hyper-parameter search trial.
      """

      name = 'twitter'
      DatasetLoader = TwitterDatasetLoader

      # Dataset locations; every image/text path is derived from ``data_path``.
      data_path = '/twitter/'
      output_path = ''
      train_image_path = data_path + 'images_train/'
      validation_image_path = data_path + 'images_validation/'
      test_image_path = data_path + 'images_test/'
      train_text_path = data_path + 'twitter_train_translated.csv'
      validation_text_path = data_path + 'twitter_validation_translated.csv'
      test_text_path = data_path + 'twitter_test_translated.csv'

      # Training loop settings.
      batch_size = 128
      epochs = 100
      num_workers = 2

      # Learning rates, one per component.
      head_lr = 1e-03
      image_encoder_lr = 1e-04
      text_encoder_lr = 1e-04
      attention_lr = 1e-3
      classification_lr = 1e-03

      # Weight-decay values, one per component.
      head_weight_decay = 0.001
      attention_weight_decay = 0.001
      classification_weight_decay = 0.001

      # Use the first CUDA device when one is available, otherwise the CPU.
      device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

      # Image encoder.
      image_model_name = 'vit-base-patch16-224'
      image_embedding = 768

      # Text encoder and tokenizer.
      text_encoder_model = "bert-base-uncased"
      text_tokenizer = "bert-base-uncased"
      text_embedding = 768
      max_length = 32

      pretrained = True
      trainable = False
      temperature = 1.0

      # Label names and their per-class weights.
      classes = ['real', 'fake']
      class_weights = [1, 1]

      # NOTE(review): presumably a target accuracy threshold; how it is used
      # is not visible in this file -- confirm against the consumer.
      wanted_accuracy = 0.76

      def optuna(self, trial):
          """Set tunable hyper-parameters on this instance from ``trial``.

          Each listed attribute is assigned a value sampled on a log scale
          between its lower and upper bound, using the attribute name as the
          trial parameter name.

          ``trial.suggest_float(..., log=True)`` is used in place of
          ``trial.suggest_loguniform``, which Optuna has deprecated; the
          sampled distribution and parameter names are unchanged.

          ``attention_lr``, ``attention_weight_decay``, ``projection_size``,
          ``hidden_size`` and ``dropout`` are not part of the search here.

          Args:
              trial: Object providing ``suggest_float(name, low, high, log=...)``,
                  e.g. an ``optuna.trial.Trial``.
          """
          # (attribute / trial parameter name, lower bound, upper bound)
          log_scale_ranges = (
              ('head_lr', 1e-5, 1e-1),
              ('image_encoder_lr', 1e-6, 1e-3),
              ('text_encoder_lr', 1e-6, 1e-3),
              ('classification_lr', 1e-5, 1e-1),
              ('head_weight_decay', 1e-5, 1e-1),
              ('classification_weight_decay', 1e-5, 1e-1),
          )
          # Suggestions are requested in the order listed above.
          for param, low, high in log_scale_ranges:
              setattr(self, param, trial.suggest_float(param, low, high, log=True))