
autoencoder.py

import torch
import torch.nn as nn
import torch.optim as optim


class Autoencoder(nn.Module):
    """
    Autoencoder neural network model for feature learning.

    Parameters:
    - input_dim (int): Dimensionality of the input features.
    - latent_dim (int): Dimensionality of the latent space.
    """

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # Encoder architecture: input -> 256 -> latent
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, latent_dim),
            nn.ReLU(inplace=True),
        )
        # Decoder architecture: latent -> 256 -> input (no output
        # activation, so reconstructions are unbounded real values)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, input_dim),
        )

    def forward(self, x):
        """
        Forward pass of the autoencoder.

        Parameters:
        - x (torch.Tensor): Input tensor.

        Returns:
        - decoded (torch.Tensor): Decoded output tensor.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


def trainAutoencoder(model, train_loader, val_loader, num_epochs, name):
    """
    Train the autoencoder model.

    Parameters:
    - model (Autoencoder): The autoencoder model to be trained.
    - train_loader (DataLoader): DataLoader for the training dataset.
    - val_loader (DataLoader): DataLoader for the validation dataset.
    - num_epochs (int): Number of training epochs.
    - name (str): Suffix for the filename the trained model is saved under.

    Returns:
    - None
    """
    loss_fn = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-8)
    train_loss = []
    val_final_loss = []
    for epoch in range(num_epochs):
        # Training
        model.train()
        total_train_loss = 0.0
        for batch_idx, data in enumerate(train_loader):
            data = data[0]  # DataLoader yields (tensor,) tuples, e.g. from a TensorDataset
            output = model(data)
            loss = loss_fn(output, data)  # reconstruction loss against the input itself
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # .item() extracts a Python float so the computation graph
            # is not kept alive across batches
            total_train_loss += loss.item()
        avg_train_loss = total_train_loss / len(train_loader)
        train_loss.append(avg_train_loss)

        # Validation
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for val_batch_idx, val_data in enumerate(val_loader):
                val_data = val_data[0]
                val_output = model(val_data)
                val_loss = loss_fn(val_output, val_data)
                total_val_loss += val_loss.item()
        avg_val_loss = total_val_loss / len(val_loader)
        val_final_loss.append(avg_val_loss)
        print('Epoch [{}/{}], Train Loss: {:.4f}, Val Loss: {:.4f}'.format(
            epoch + 1, num_epochs, avg_train_loss, avg_val_loss))
        # No LR scheduler is attached, so the learning rate is constant
        # across epochs and there is no rate change to log.

    # Save the trained model once training is complete
    torch.save(model.state_dict(), "autoencoder" + name + ".pth")
    # Inspect the learned weights of the first encoder layer
    print(model.encoder[0].weight.detach().numpy())
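

# --- Usage sketch (illustrative, not part of the original module) ---
# Shows one way to drive Autoencoder and trainAutoencoder end to end.
# The dimensions, batch size, epoch count, and "_demo" name suffix below
# are assumptions, and random tensors stand in for a real dataset.
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    input_dim, latent_dim = 64, 8           # assumed feature/latent sizes
    train_x = torch.randn(512, input_dim)   # placeholder training data
    val_x = torch.randn(128, input_dim)     # placeholder validation data

    # TensorDataset yields (tensor,) tuples, matching the data[0] unpacking
    # inside trainAutoencoder.
    train_loader = DataLoader(TensorDataset(train_x), batch_size=32, shuffle=True)
    val_loader = DataLoader(TensorDataset(val_x), batch_size=32)

    model = Autoencoder(input_dim, latent_dim)
    trainAutoencoder(model, train_loader, val_loader, num_epochs=5, name="_demo")

    # After training, the encoder alone maps inputs to latent features.
    with torch.no_grad():
        latent = model.encoder(val_x)
    print(latent.shape)  # expected: torch.Size([128, 8])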