Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

config.py 2.8KB

Lines 1–73
  1. import torch
  2. from data.config import Config
  3. from data.twitter.data_loader import TwitterDatasetLoader
  4. class TwitterConfig(Config):
  5. name = 'twitter'
  6. DatasetLoader = TwitterDatasetLoader
  7. data_path = '../../../../../media/external_3TB/3TB/ghorbanpoor/twitter/'
  8. # data_path = '/home/faeze/PycharmProjects/fake_news_detection/data/twitter/'
  9. output_path = '../../../../../media/external_10TB/10TB/ghorbanpoor/'
  10. # output_path = ''
  11. train_image_path = data_path + 'images_train/'
  12. validation_image_path = data_path + 'images_test/'
  13. test_image_path = data_path + 'images_test/'
  14. train_text_path = data_path + 'twitter_train_translated.csv'
  15. validation_text_path = data_path + 'twitter_test_translated.csv'
  16. test_text_path = data_path + 'twitter_test_translated.csv'
  17. batch_size = 128
  18. epochs = 100
  19. num_workers = 2
  20. head_lr = 0.0254
  21. image_encoder_lr = 0.0005
  22. text_encoder_lr = 2.0e-05
  23. attention_lr = 1e-3
  24. classification_lr = 0.0034
  25. dropout = 0.5
  26. hidden_size = 128
  27. projection_size = 64
  28. head_weight_decay = 0.05
  29. attention_weight_decay = 0.05
  30. classification_weight_decay = 0.05
  31. device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
  32. image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
  33. image_embedding = 768
  34. text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
  35. # text_encoder_model = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
  36. text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
  37. # text_tokenizer = "/home/faeze/PycharmProjects/new_fake_news_detectioin/bert/bert-base-uncased"
  38. text_embedding = 768
  39. max_length = 32
  40. pretrained = True
  41. trainable = False
  42. temperature = 1.0
  43. classes = ['real', 'fake']
  44. class_weights = [1, 1]
  45. wanted_accuracy = 0.76
  46. def optuna(self, trial):
  47. self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1)
  48. self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3)
  49. self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3)
  50. self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1)
  51. self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1)
  52. # self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1)
  53. self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1)
  54. # self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])
  55. # self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ])
  56. # self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ])