You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

mlp.py 4.4KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. from torch import nn
  2. from torch import optim, no_grad
  3. import torch
  4. from evaluation import Evaluation
  class EarlyStopper:
      """Decide when training should stop because validation loss has stalled.

      A new lowest loss resets the strike counter. A loss that is worse than the
      best seen so far by more than ``min_delta`` adds one strike. Once
      ``patience`` strikes have accumulated since the last new best,
      ``early_stop`` returns True. Losses that are not a new best but are within
      ``min_delta`` of it neither reset nor add strikes.
      """

      def __init__(self, patience=1, min_delta=0):
          self.patience = patience          # strikes allowed before stopping
          self.min_delta = min_delta        # tolerance above the best loss
          self.counter = 0                  # strikes since the last new best
          self.min_validation_loss = float('inf')

      def early_stop(self, validation_loss):
          """Record ``validation_loss`` and return True if training should stop."""
          if validation_loss < self.min_validation_loss:
              # New best: remember it and clear any accumulated strikes.
              self.min_validation_loss = validation_loss
              self.counter = 0
              return False

          # Phrased as a positive ">" test so a NaN loss counts as "not worse"
          # (every comparison with NaN is False) and adds no strike.
          worsened = validation_loss > (self.min_validation_loss + self.min_delta)
          if not worsened:
              return False

          self.counter += 1
          return self.counter >= self.patience
  class MLP(nn.Module):
      """Two-layer perceptron: Linear -> ReLU -> Linear -> Sigmoid.

      The final Sigmoid keeps every output in [0, 1], which is the input range
      ``nn.BCELoss`` requires.
      """

      def __init__(self, input_dim, output_dim, hidden_dim=128):
          """Build the network.

          Args:
              input_dim: size of the last dimension of the input tensor.
              output_dim: number of sigmoid output units.
              hidden_dim: width of the hidden layer. Defaults to 128, the value
                  that used to be hard-coded, so existing checkpoints still load.
          """
          super().__init__()
          # Keep the attribute name and the layer order unchanged: state_dict
          # keys ("mlp.0.weight", "mlp.2.bias", ...) are derived from them.
          self.mlp = nn.Sequential(
              nn.Linear(input_dim, hidden_dim),
              nn.ReLU(inplace=True),
              nn.Linear(hidden_dim, output_dim),
              nn.Sigmoid(),
          )

      def forward(self, x):
          """Map ``x`` of shape ``(..., input_dim)`` to shape ``(..., output_dim)``."""
          return self.mlp(x)
  def _run_epoch(model, loader, loss_fn, optimizer=None):
      """Run ``model`` once over every ``(data, target)`` batch of ``loader``.

      When ``optimizer`` is given, the model is put in train mode and updated
      after each batch. Otherwise it is put in eval mode and run with autograd
      disabled, so no parameters change.

      Returns:
          ``(mean_loss, accuracy)``.
          - ``mean_loss`` weights each batch's loss by its batch size, so a short
            final batch is not over-weighted. This assumes ``loss_fn`` averages
            over the batch, as ``nn.BCELoss()``'s default reduction does.
          - ``accuracy`` is the fraction of target elements equal to the
            0.5-thresholded output.

      Raises:
          ValueError: if ``loader`` yields no batches.
      """
      training = optimizer is not None
      model.train(training)  # train(False) is equivalent to eval()
      loss_sum = 0.0
      correct = 0
      n_samples = 0
      n_elements = 0
      with torch.set_grad_enabled(training):
          for data, target in loader:
              if training:
                  optimizer.zero_grad()
              output = model(data)
              loss = loss_fn(output, target)
              if training:
                  loss.backward()
                  optimizer.step()

              batch_size = target.size(0)
              loss_sum += loss.item() * batch_size
              n_samples += batch_size
              predictions = torch.round(output.detach())
              correct += (predictions == target).sum().item()
              # Count every compared element, not only rows, so accuracy stays
              # within [0, 1] when the model has more than one output unit.
              n_elements += target.numel()
      if n_samples == 0:
          raise ValueError("data loader yielded no batches")
      return loss_sum / n_samples, correct / n_elements


  def train_mlp(model, train_loader, val_loader, num_epochs, *, lr=0.01,
                early_stopping=False):
      """Train ``model`` with BCE loss and Adadelta, then plot the learning curves.

      Each epoch does one training pass over ``train_loader`` and one validation
      pass over ``val_loader``, then prints the mean losses and accuracies. After
      the last epoch, the per-epoch histories are handed to
      ``Evaluation.plot_train_val_accuracy`` and
      ``Evaluation.plot_train_val_loss``, together with the number of epochs
      actually run.

      Args:
          model: module whose outputs lie in [0, 1], as ``nn.BCELoss`` requires
              (for example an ``MLP``).
          train_loader: iterable of ``(data, target)`` batches used for updates.
          val_loader: iterable of ``(data, target)`` batches used for validation.
          num_epochs: maximum number of epochs to run.
          lr: Adadelta learning rate (keyword-only; default 0.01, the value
              that used to be hard-coded).
          early_stopping: if True, stop early after validation loss has been
              more than 0.05 above its best value on 3 epochs since the last
              new best (see ``EarlyStopper``).

      Raises:
          ValueError: if either loader yields no batches.
      """
      loss_fn = nn.BCELoss()
      optimizer = optim.Adadelta(model.parameters(), lr=lr)
      early_stopper = (EarlyStopper(patience=3, min_delta=0.05)
                       if early_stopping else None)

      train_accuracies, val_accuracies = [], []
      train_losses, val_losses = [], []
      for epoch in range(num_epochs):
          train_loss, train_accuracy = _run_epoch(
              model, train_loader, loss_fn, optimizer)
          val_loss, val_accuracy = _run_epoch(model, val_loader, loss_fn)

          train_losses.append(train_loss)
          val_losses.append(val_loss)
          train_accuracies.append(train_accuracy)
          val_accuracies.append(val_accuracy)
          print(
              'Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}, Train Accuracy: {:.4f}, Val Accuracy: {:.4f}'.format(
                  epoch + 1, num_epochs, train_loss, val_loss, train_accuracy,
                  val_accuracy))

          if early_stopper is not None and early_stopper.early_stop(val_loss):
              break

      epochs_run = len(train_losses)
      if epochs_run == 0:
          # num_epochs <= 0: nothing was trained, so there is nothing to plot.
          return
      Evaluation.plot_train_val_accuracy(train_accuracies, val_accuracies, epochs_run)
      Evaluation.plot_train_val_loss(train_losses, val_losses, epochs_run)
  def test_mlp(model, test_loader):
      """Evaluate ``model`` on the whole of ``test_loader``.

      Model outputs and labels from every batch are concatenated along dim 0 and
      passed to ``Evaluation.evaluate(labels, outputs)`` in a single call, whose
      result is returned. Previously only one batch was evaluated.

      Assumes each batch's labels are tensors that ``torch.cat`` can join, which
      is true for DataLoader's default collation; confirm for custom loaders.

      NOTE(review): the ``test_`` prefix means pytest will try to collect this
      function if it is imported into a test module. Consider renaming it
      behind an alias.

      Raises:
          ValueError: if ``test_loader`` yields no batches.
      """
      model.eval()
      outputs = []
      targets = []
      with torch.no_grad():
          for data, labels in test_loader:
              outputs.append(model(data))
              targets.append(labels)
      if not outputs:
          raise ValueError("test_loader yielded no batches")
      return Evaluation.evaluate(torch.cat(targets), torch.cat(outputs))