- import torch
- from torch import nn
- from torchmetrics.functional import confusion_matrix
- import pytorch_lightning as pl
-
class CSIModel(pl.LightningModule):
    """Binary classifier with a "capture" sequence branch and a "score" branch.

    The capture branch projects per-timestep features, runs them through an
    LSTM, and keeps the last hidden state; the score branch embeds a set of
    feature vectors and mean-pools them. The two representations are
    concatenated and fed to a sigmoid classifier head, trained with BCE.

    config keys read here: capture_input_dim, d_Wa, d_lstm, d_Wr,
    score_input_dim, d_Wu, d_Ws, dropout, lr, weight_decay.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # BCELoss pairs with the explicit Sigmoid at the end of self.cls.
        # NOTE(review): BCEWithLogitsLoss would be more numerically stable,
        # but swapping it in would change the module's output semantics.
        self.criterion = nn.BCELoss()
        # Capture branch: per-timestep affine + tanh + dropout, then an LSTM.
        # The LSTM is the last element, so capture_rnn(x) returns the LSTM's
        # (output, (h_n, c_n)) tuple, not a plain tensor.
        self.capture_rnn = nn.Sequential(
            nn.Linear(config['capture_input_dim'], config['d_Wa']),
            nn.Tanh(),
            nn.Dropout(config['dropout']),
            nn.LSTM(
                input_size=config['d_Wa'],
                hidden_size=config['d_lstm'],
                num_layers=1,
                batch_first=True)
        )
        # Projects the LSTM's final hidden state into the capture embedding.
        self.capture_proj = nn.Sequential(
            nn.Linear(config['d_lstm'], config['d_Wr']),
            nn.Tanh(),
            nn.Dropout(config['dropout'])
        )
        # Score branch: two-layer MLP applied per score vector.
        self.score = nn.Sequential(
            nn.Linear(config['score_input_dim'], config['d_Wu']),
            nn.Tanh(),
            nn.Linear(config['d_Wu'], config['d_Ws']),
            nn.Sigmoid()
        )
        # Classifier head over the concatenated [capture, score] embedding;
        # outputs a probability in (0, 1).
        self.cls = nn.Sequential(
            nn.Linear(config['d_Ws'] + config['d_Wr'], 1),
            nn.Sigmoid()
        )

    def configure_optimizers(self):
        """Build an Adam optimizer with weight decay on a single parameter.

        Only 'score.0.weight' (the first score-branch linear layer) gets
        weight decay; every other parameter is optimized without it.
        """
        all_params = dict(self.named_parameters())
        wd_name = 'score.0.weight'
        wd_params = all_params.pop(wd_name)  # remove so it isn't in both groups
        return torch.optim.Adam(
            [
                {'params': wd_params, 'weight_decay': self.config['weight_decay']},
                {'params': list(all_params.values())},
            ],
            lr=self.config['lr']
        )

    def count_parameters(self):
        """Return the total number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(self, x_capture, x_score):
        """Compute P(label=1) for a batch.

        Args:
            x_capture: capture-branch sequence input; assumed
                (batch, time, capture_input_dim) since the LSTM uses
                batch_first=True — TODO confirm against the dataloader.
            x_score: score-branch input; assumed (batch, n_items,
                score_input_dim), mean-pooled over dim 1 — TODO confirm.

        Returns:
            Tensor of shape (batch, 1) with probabilities in (0, 1).
        """
        hc, (_, _) = self.capture_rnn(x_capture.float())
        hc = self.capture_proj(hc[:, -1])  # last timestep's hidden state
        hs = self.score(x_score.float()).mean(dim=1)
        h = torch.cat([hc, hs], dim=1)
        return self.cls(h)

    def step(self, batch, mode='train'):
        """Shared train/val/test step: loss, accuracy, confusion counts.

        Logs '{mode}_loss', '{mode}_acc', '{mode}_tn/fn/fp/tp' and returns
        the same quantities in a dict ('loss' is the differentiable tensor).
        """
        x_capture, x_score, labels = batch
        labels = labels[:, None].float()  # (batch,) -> (batch, 1) for BCELoss
        # NOTE: these are sigmoid probabilities, not raw logits.
        logits = self.forward(x_capture, x_score)
        loss = self.criterion(logits, labels)

        preds = logits.clone()
        preds[preds >= 0.5] = 1
        preds[preds < 0.5] = 0
        acc = (preds == labels).sum() / labels.shape[0]
        # torchmetrics' confusion matrix is [[tn, fp], [fn, tp]]
        # (rows = true class, cols = predicted class), so the flattened
        # order is tn, fp, fn, tp. The previous code unpacked it as
        # tn, fn, fp, tp, swapping the logged FN and FP counts.
        tn, fp, fn, tp = confusion_matrix(logits, labels.int(), num_classes=2, threshold=0.5).flatten()

        self.log(f'{mode}_loss', loss.item())
        self.log(f'{mode}_acc', acc.item())
        self.log(f'{mode}_tn', tn.item())
        self.log(f'{mode}_fn', fn.item())
        self.log(f'{mode}_fp', fp.item())
        self.log(f'{mode}_tp', tp.item())
        return {
            'loss': loss,
            'acc': acc,
            'tn': tn,
            'fn': fn,
            'fp': fp,
            'tp': tp
        }

    def training_step(self, batch, batch_idx):
        return self.step(batch)

    def test_step(self, batch, batch_idx):
        return self.step(batch, mode='test')

    def validation_step(self, batch, batch_idx):
        return self.step(batch, mode='val')
-