A PyTorch implementation of the paper "CSI: a hybrid deep neural network for fake news detection"
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

model.py 3.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import torch
  2. from torch import nn
  3. from torchmetrics.functional import confusion_matrix
  4. import pytorch_lightning as pl
  5. class CSIModel(pl.LightningModule):
  6. def __init__(self, config):
  7. super().__init__()
  8. self.config = config
  9. self.criterion = nn.BCELoss()
  10. self.capture_rnn = nn.Sequential(
  11. nn.Linear(config['capture_input_dim'], config['d_Wa']),
  12. nn.Tanh(),
  13. nn.Dropout(config['dropout']),
  14. nn.LSTM(
  15. input_size=config['d_Wa'],
  16. hidden_size=config['d_lstm'],
  17. num_layers=1,
  18. batch_first=True)
  19. )
  20. self.capture_proj = nn.Sequential(
  21. nn.Linear(config['d_lstm'], config['d_Wr']),
  22. nn.Tanh(),
  23. nn.Dropout(config['dropout'])
  24. )
  25. self.score = nn.Sequential(
  26. nn.Linear(config['score_input_dim'], config['d_Wu']),
  27. nn.Tanh(),
  28. nn.Linear(config['d_Wu'], config['d_Ws']),
  29. nn.Sigmoid()
  30. )
  31. self.cls = nn.Sequential(
  32. nn.Linear(config['d_Ws'] + config['d_Wr'], 1),
  33. nn.Sigmoid()
  34. )
  35. def configure_optimizers(self):
  36. all_params = dict(self.named_parameters())
  37. wd_name = 'score.0.weight'
  38. wd_params = all_params[wd_name]
  39. del all_params[wd_name]
  40. return torch.optim.Adam(
  41. [
  42. {'params': wd_params, 'weight_decay': self.config['weight_decay']},
  43. {'params': list(all_params.values())},
  44. ],
  45. lr=self.config['lr']
  46. )
  47. def count_parameters(self):
  48. return sum(p.numel() for p in self.parameters() if p.requires_grad)
  49. def forward(self, x_capture, x_score):
  50. hc, (_, _) = self.capture_rnn(x_capture.float())
  51. hc = self.capture_proj(hc[:, -1])
  52. hs = self.score(x_score.float()).mean(dim=1)
  53. h = torch.cat([hc, hs], dim=1)
  54. return self.cls(h)
  55. def step(self, batch, mode='train'):
  56. x_capture, x_score, labels = batch
  57. labels = labels[:, None].float()
  58. logits = self.forward(x_capture, x_score)
  59. loss = self.criterion(logits, labels)
  60. preds = logits.clone()
  61. preds[preds >=0.5] = 1
  62. preds[preds < 0.5] = 0
  63. acc = (preds == labels).sum() / labels.shape[0]
  64. tn, fn, fp, tp = confusion_matrix(logits, labels.int(), num_classes=2, threshold=0.5).flatten()
  65. self.log(f'{mode}_loss', loss.item())
  66. self.log(f'{mode}_acc', acc.item())
  67. self.log(f'{mode}_tn', tn.item())
  68. self.log(f'{mode}_fn', fn.item())
  69. self.log(f'{mode}_fp', fp.item())
  70. self.log(f'{mode}_tp', tp.item())
  71. return {
  72. 'loss':loss,
  73. 'acc':acc,
  74. 'tn':tn,
  75. 'fn':fn,
  76. 'fp':fp,
  77. 'tp':tp
  78. }
  79. def training_step(self, batch, batch_idx):
  80. return self.step(batch)
  81. def test_step(self, batch, batch_idx):
  82. return self.step(batch, mode='test')
  83. def validation_step(self, batch, batch_idx):
  84. return self.step(batch, mode='val')