Official implementation of the Fake News Revealer paper
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

config.py 3.5KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. import torch
  2. import ast
  class Config:
      """Container for the paths and hyperparameters of a training run.

      Every setting is a class-level default that can be overridden on an
      instance. ``__str__`` serialises the non-callable, non-dunder attributes
      as ``name<TAB>value`` lines and ``assign_hyperparameters`` reads that
      format back, so the two methods are expected to round-trip.
      """

      name = ''
      DatasetLoader = None
      debug = False

      # Dataset / output locations (empty by default).
      data_path = ''
      output_path = ''
      train_image_path = ''
      validation_image_path = ''
      test_image_path = ''
      train_text_path = ''
      validation_text_path = ''
      test_text_path = ''
      train_text_embedding_path = ''
      validation_text_embedding_path = ''
      train_image_embedding_path = ''
      validation_image_embedding_path = ''

      batch_size = 256
      epochs = 50
      num_workers = 2

      # Learning rates (several are re-sampled by ``optuna``).
      head_lr = 1e-03
      image_encoder_lr = 1e-04
      text_encoder_lr = 1e-04
      attention_lr = 1e-3
      classification_lr = 1e-03

      max_grad_norm = 5.0

      head_weight_decay = 0.001
      attention_weight_decay = 0.001
      classification_weight_decay = 0.001

      patience = 30
      delta = 0.0000001
      factor = 0.8

      # NOTE(review): the model/tokenizer locations below are machine-specific
      # relative paths; they must be adapted to the local environment.
      image_model_name = '../../../../../media/external_10TB/10TB/ghorbanpoor/vit-base-patch16-224'
      image_embedding = 768
      num_img_region = 64  # 16 #TODO
      text_encoder_model = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
      text_tokenizer = "../../../../../media/external_10TB/10TB/ghorbanpoor/bert-base-uncased"
      text_embedding = 768
      max_length = 32

      pretrained = True  # for both image encoder and text encoder
      trainable = False  # for both image encoder and text encoder
      temperature = 1.0

      # image size
      size = 224

      num_projection_layers = 1
      projection_size = 256
      dropout = 0.3
      hidden_size = 256
      num_region = 64  # 16 #TODO
      # NOTE(review): evaluated once, when the class is created. It is NOT
      # refreshed when ``projection_size`` or ``num_region`` is changed later
      # (e.g. by ``optuna``) -- confirm whether consumers rely on that.
      region_size = projection_size // num_region

      classes = ['real', 'fake']
      class_num = 2
      loss_weight = 1
      class_weights = [1, 1]

      writer = None

      has_unlabeled_data = False
      step = 0
      # NOTE(review): T1 / T2 / af look like schedule parameters for the
      # unlabeled-data loss weight -- meaning not visible here, confirm.
      T1 = 10
      T2 = 150
      af = 3

      wanted_accuracy = 0.76

      def optuna(self, trial):
          """Overwrite tunable hyperparameters with values sampled from ``trial``.

          Args:
              trial: object exposing ``suggest_loguniform(name, low, high)`` and
                  ``suggest_categorical(name, choices)`` (an Optuna trial).

          Note:
              ``suggest_loguniform`` is deprecated in Optuna >= 3.0 in favour of
              ``suggest_float(..., log=True)``; the call is kept unchanged so
              the installed Optuna version keeps working.
          """
          self.head_lr = trial.suggest_loguniform('head_lr', 1e-5, 1e-1)
          self.image_encoder_lr = trial.suggest_loguniform('image_encoder_lr', 1e-6, 1e-3)
          self.text_encoder_lr = trial.suggest_loguniform('text_encoder_lr', 1e-6, 1e-3)
          self.classification_lr = trial.suggest_loguniform('classification_lr', 1e-5, 1e-1)
          self.attention_lr = trial.suggest_loguniform('attention_lr', 1e-5, 1e-1)

          self.attention_weight_decay = trial.suggest_loguniform('attention_weight_decay', 1e-5, 1e-1)
          self.head_weight_decay = trial.suggest_loguniform('head_weight_decay', 1e-5, 1e-1)
          self.classification_weight_decay = trial.suggest_loguniform('classification_weight_decay', 1e-5, 1e-1)

          self.projection_size = trial.suggest_categorical('projection_size', [256, 128, 64])
          self.hidden_size = trial.suggest_categorical('hidden_size', [256, 128, 64, ])
          # The trial parameter is registered as 'drop_out', not 'dropout'.
          self.dropout = trial.suggest_categorical('drop_out', [0.1, 0.3, 0.5, ])

      def __str__(self):
          """Return one ``name<TAB>value`` line per non-callable, non-dunder attribute."""
          members = [
              attr for attr in dir(self)
              if not callable(getattr(self, attr)) and not attr.startswith("__")
          ]
          return ''.join(
              member + '\t' + str(getattr(self, member)) + '\n'
              for member in members
          )

      @staticmethod
      def _parse_value(current, raw):
          """Convert the text ``raw`` to the type of the existing value ``current``.

          Raises:
              ValueError, SyntaxError, TypeError: if ``raw`` cannot be converted.
          """
          if isinstance(current, bool):
              # bool('False') is True, so booleans need explicit parsing.
              token = raw.strip().lower()
              if token in ('true', '1'):
                  return True
              if token in ('false', '0', ''):
                  return False
              raise ValueError('not a boolean: ' + repr(raw))
          if current is None or isinstance(current, (list, set, dict, tuple)):
              # Containers (and attributes that are currently None) are read as
              # Python literals; literal_eval never executes code.
              return ast.literal_eval(raw.strip())
          return type(current)(raw)

      def assign_hyperparameters(self, s):
          """Set attributes from text in the ``name<TAB>value`` format of ``__str__``.

          Each value is converted to the type of the attribute's current value.
          Blank lines are skipped. Unknown names, lines without a value, and
          values that cannot be converted are reported on stdout and leave the
          configuration untouched; no exception propagates for them.

          Args:
              s: newline-separated ``name<TAB>value`` lines.
          """
          for line in s.split('\n'):
              if not line.strip():
                  continue  # e.g. the empty piece after the final newline
              fields = line.split('\t')
              name = fields[0]
              if not hasattr(self, name):
                  print(name, 'doesnot exist')
                  continue
              if len(fields) < 2:
                  print(name, 'has no value')
                  continue
              try:
                  value = self._parse_value(getattr(self, name), fields[1])
              except (ValueError, SyntaxError, TypeError) as exc:
                  print(name, 'could not be parsed:', exc)
                  continue
              setattr(self, name, value)