
mlp.py

import torch
from torch import nn, optim

from evaluation import Evaluation
class EarlyStopper:
    """Signals a stop when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = float('inf')

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            # New best loss: remember it and reset the patience counter.
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            # Loss worsened by more than min_delta: count a strike.
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
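

# Usage sketch for EarlyStopper (illustrative only; train_mlp below keeps its
# early-stop call commented out). The variable names here are assumptions:
#
#   stopper = EarlyStopper(patience=3, min_delta=0.05)
#   for epoch in range(num_epochs):
#       avg_val_loss = ...  # compute the epoch's validation loss
#       if stopper.early_stop(avg_val_loss):
#           break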
class MLP(nn.Module):
    """One-hidden-layer perceptron. Hardsigmoid keeps outputs in [0, 1],
    matching the nn.BCELoss used in train_mlp."""

    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, output_dim),
            nn.Hardsigmoid(),
        )

    def forward(self, x):
        return self.mlp(x)
def train_mlp(model, train_loader, val_loader, num_epochs):
    mlp_loss_fn = nn.BCELoss()
    mlp_optimizer = optim.Adadelta(model.parameters(), lr=0.01)
    # scheduler = lr_scheduler.ReduceLROnPlateau(mlp_optimizer, mode='min', factor=0.8, patience=5, verbose=True)
    train_accuracies = []
    val_accuracies = []
    train_loss = []
    val_loss = []
    early_stopper = EarlyStopper(patience=3, min_delta=0.05)
    for epoch in range(num_epochs):
        # Training
        model.train()
        total_train_loss = 0.0
        train_correct = 0
        train_total_samples = 0
        for batch_idx, (data, target) in enumerate(train_loader):
            mlp_optimizer.zero_grad()
            mlp_output = model(data)
            mlp_loss = mlp_loss_fn(mlp_output, target)
            mlp_loss.backward()
            mlp_optimizer.step()
            total_train_loss += mlp_loss.item()
            # Calculate accuracy: threshold the [0, 1] outputs at 0.5.
            train_predictions = torch.round(mlp_output)
            train_correct += (train_predictions == target).sum().item()
            train_total_samples += target.size(0)
            # if batch_idx % 200 == 0:
            #     after_lr = mlp_optimizer.param_groups[0]["lr"]
            #     print('Epoch [{}/{}], Batch [{}/{}], Total Loss: {:.4f}, Learning Rate: {:.8f},'.format(
            #         epoch + 1, num_epochs, batch_idx + 1, len(train_loader), mlp_loss.item(), after_lr))
        avg_train_loss = total_train_loss / len(train_loader)
        train_loss.append(avg_train_loss)

        # Validation
        model.eval()
        total_val_loss = 0.0
        correct = 0
        total_samples = 0
        with torch.no_grad():
            for val_batch_idx, (data, val_target) in enumerate(val_loader):
                val_mlp_output = model(data)
                val_mlp_loss = mlp_loss_fn(val_mlp_output, val_target)
                total_val_loss += val_mlp_loss.item()
                # Calculate accuracy
                val_predictions = torch.round(val_mlp_output)
                correct += (val_predictions == val_target).sum().item()
                total_samples += val_target.size(0)
        avg_val_loss = total_val_loss / len(val_loader)
        val_loss.append(avg_val_loss)
        train_accuracy = train_correct / train_total_samples
        train_accuracies.append(train_accuracy)
        val_accuracy = correct / total_samples
        val_accuracies.append(val_accuracy)
        print(
            'Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}, Train Accuracy: {:.4f}, Val Accuracy: {:.4f}'.format(
                epoch + 1, num_epochs, avg_train_loss, avg_val_loss, train_accuracy,
                val_accuracy))
        # if early_stopper.early_stop(avg_val_loss):
        #     break
        before_lr = mlp_optimizer.param_groups[0]["lr"]
        # scheduler.step(avg_val_loss)
        after_lr = mlp_optimizer.param_groups[0]["lr"]
        if before_lr != after_lr:
            print("Epoch %d: Adadelta lr %.8f -> %.8f" % (epoch, before_lr, after_lr))
    Evaluation.plot_train_val_accuracy(train_accuracies, val_accuracies, epoch + 1)
    Evaluation.plot_train_val_loss(train_loss, val_loss, epoch + 1)
def test_mlp(model, test_loader):
    model.eval()
    # Returns after the first batch, so the test loader is expected to serve
    # the whole test set as a single batch.
    for i, (data, labels) in enumerate(test_loader):
        with torch.no_grad():
            mlp_output = model(data)
        return Evaluation.evaluate(labels, mlp_output)
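

# Example wiring (a minimal sketch; the feature count, synthetic data, and
# batch sizes below are assumptions, not part of this module). nn.BCELoss
# expects float targets in {0, 1} with the same shape as the model output.
#
#   from torch.utils.data import DataLoader, TensorDataset
#
#   X = torch.randn(1000, 64)                   # 1000 samples, 64 features (assumed)
#   y = torch.randint(0, 2, (1000, 1)).float()  # binary labels
#   loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)
#
#   model = MLP(input_dim=64, output_dim=1)
#   train_mlp(model, loader, loader, num_epochs=10)
#   test_mlp(model, DataLoader(TensorDataset(X, y), batch_size=len(X)))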