Bayesian Deep Ensemble Collaborative Filtering
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

BNN.py 5.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import math
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4. import torch
  5. import torch.nn as nn
  6. import torch.nn.functional as F
  7. class Gaussian(object):
  8. def __init__(self, mu, rho, device=None):
  9. super().__init__()
  10. self.mu = mu.to(device)
  11. self.rho = rho.to(device)
  12. self.device = device
  13. self.normal = torch.distributions.Normal(torch.tensor(0.0, device=device), torch.tensor(1.0, device=device))
  14. @property
  15. def sigma(self):
  16. return torch.log1p(torch.exp(self.rho))
  17. def sample(self):
  18. epsilon = self.normal.sample(self.rho.size())
  19. return self.mu + self.sigma * epsilon
  20. def log_prob(self, input):
  21. return (-math.log(math.sqrt(2 * math.pi))
  22. - torch.log(self.sigma)
  23. - ((input - self.mu) ** 2) / (2 * self.sigma ** 2)).sum()
  24. class ScaleMixtureGaussian(object):
  25. def __init__(self, pi, sigma1, sigma2, device=None):
  26. super().__init__()
  27. self.pi = pi
  28. self.device = device
  29. self.sigma1 = sigma1.to(device)
  30. self.sigma2 = sigma2.to(device)
  31. self.gaussian1 = torch.distributions.Normal(torch.tensor(0.0, device=device), self.sigma1)
  32. self.gaussian2 = torch.distributions.Normal(torch.tensor(0.0, device=device), self.sigma2)
  33. def log_prob(self, input):
  34. prob1 = torch.exp(self.gaussian1.log_prob(input.to(self.device)))
  35. prob2 = torch.exp(self.gaussian2.log_prob(input.to(self.device)))
  36. return (torch.log(self.pi * prob1 + (1 - self.pi) * prob2)).sum()
  37. class LaplacePrior(object):
  38. def __init__(self, mu, b, device=None):
  39. super().__init__()
  40. self.device = device
  41. self.mu = mu.to(device)
  42. self.b = b.to(device)
  43. def log_prob(self, x, do_sum=True):
  44. if do_sum:
  45. return (-torch.log(2 * self.b) - torch.abs(x - self.mu) / self.b).sum()
  46. else:
  47. return -torch.log(2 * self.b) - torch.abs(x - self.mu) / self.b
  48. class IsotropicGaussian(object):
  49. def __init__(self, mu, sigma, device=None):
  50. self.device = device
  51. self.mu = mu.to(device)
  52. self.sigma = sigma.to(device)
  53. self.cte_term = -0.5 * torch.log(torch.tensor(2 * np.pi))
  54. self.det_sig_term = -torch.log(self.sigma)
  55. def log_prob(self, x, do_sum=True):
  56. dist_term = -0.5 * ((x - self.mu) / self.sigma) ** 2
  57. if do_sum:
  58. return (self.cte_term + self.det_sig_term + dist_term).sum()
  59. else:
  60. return self.cte_term + self.det_sig_term + dist_term
  61. # ScaleMixture prior hyperparameters
  62. PI = 0.5
  63. SIGMA_1 = torch.FloatTensor([math.exp(-0)])
  64. SIGMA_2 = torch.FloatTensor([math.exp(-6)])
  65. # Laplace and IsotropicGaussian priors hyperparameters
  66. MU = torch.FloatTensor([0])
  67. B = torch.FloatTensor([0.1])
  68. DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  69. class BayesianLinear(nn.Module):
  70. def __init__(self, in_features, out_features, prior_type='ScaleMixtureGaussian', device=None):
  71. super().__init__()
  72. self.device = device
  73. self.in_features = in_features
  74. self.out_features = out_features
  75. # Weight parameters
  76. self.weight_mu = nn.Parameter(
  77. torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2).to(device)
  78. )
  79. self.weight_rho = nn.Parameter(
  80. torch.Tensor(out_features, in_features).uniform_(-5, -4).to(device)
  81. )
  82. self.weight = Gaussian(self.weight_mu, self.weight_rho, device)
  83. # Bias parameters
  84. self.bias_mu = nn.Parameter(
  85. torch.Tensor(out_features).uniform_(-0.2, 0.2).to(device)
  86. )
  87. self.bias_rho = nn.Parameter(
  88. torch.Tensor(out_features).uniform_(-5, -4).to(device)
  89. )
  90. self.bias = Gaussian(self.bias_mu, self.bias_rho, device)
  91. # Prior distributions
  92. if prior_type == 'ScaleMixtureGaussian':
  93. self.weight_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2, device)
  94. self.bias_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2, device)
  95. elif prior_type == 'Laplace':
  96. self.weight_prior = LaplacePrior(MU, B, device)
  97. self.bias_prior = LaplacePrior(MU, B, device)
  98. elif prior_type == 'IsotropicGaussian':
  99. self.weight_prior = IsotropicGaussian(MU, B, device)
  100. self.bias_prior = IsotropicGaussian(MU, B, device)
  101. self.log_prior = 0
  102. self.log_variational_posterior = 0
  103. def forward(self, input, sample=False, calculate_log_probs=False):
  104. input = input.to(self.device)
  105. if self.training or sample:
  106. weight = self.weight.sample()
  107. bias = self.bias.sample()
  108. else:
  109. weight = self.weight.mu
  110. bias = self.bias.mu
  111. if self.training or calculate_log_probs:
  112. self.log_prior = (
  113. self.weight_prior.log_prob(weight) + self.bias_prior.log_prob(bias)
  114. )
  115. self.log_variational_posterior = (
  116. self.weight.log_prob(weight) + self.bias.log_prob(bias)
  117. )
  118. else:
  119. self.log_prior, self.log_variational_posterior = 0, 0
  120. return F.linear(input, weight, bias)