import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Gaussian(object):
    """Diagonal Gaussian variational posterior with a softplus-parameterized scale."""

    def __init__(self, mu, rho, device=None):
        super().__init__()
        self.mu = mu.to(device)
        self.rho = rho.to(device)
        self.device = device
        self.normal = torch.distributions.Normal(
            torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)
        )

    @property
    def sigma(self):
        # sigma = log(1 + exp(rho)) keeps the standard deviation positive.
        return torch.log1p(torch.exp(self.rho))

    def sample(self):
        # Reparameterization trick: sample eps ~ N(0, 1), then shift and scale.
        epsilon = self.normal.sample(self.rho.size())
        return self.mu + self.sigma * epsilon

    def log_prob(self, input):
        return (
            -math.log(math.sqrt(2 * math.pi))
            - torch.log(self.sigma)
            - ((input - self.mu) ** 2) / (2 * self.sigma ** 2)
        ).sum()


class ScaleMixtureGaussian(object):
    """Two-component scale-mixture-of-Gaussians prior (Blundell et al., 2015)."""

    def __init__(self, pi, sigma1, sigma2, device=None):
        super().__init__()
        self.pi = pi
        self.device = device
        self.sigma1 = sigma1.to(device)
        self.sigma2 = sigma2.to(device)
        self.gaussian1 = torch.distributions.Normal(torch.tensor(0.0, device=device), self.sigma1)
        self.gaussian2 = torch.distributions.Normal(torch.tensor(0.0, device=device), self.sigma2)

    def log_prob(self, input):
        # Note: exponentiating log-probs can underflow for very unlikely values;
        # kept as-is to match the standard Bayes-by-Backprop formulation.
        prob1 = torch.exp(self.gaussian1.log_prob(input.to(self.device)))
        prob2 = torch.exp(self.gaussian2.log_prob(input.to(self.device)))
        return (torch.log(self.pi * prob1 + (1 - self.pi) * prob2)).sum()


class LaplacePrior(object):
    """Laplace prior with location mu and scale b."""

    def __init__(self, mu, b, device=None):
        super().__init__()
        self.device = device
        self.mu = mu.to(device)
        self.b = b.to(device)

    def log_prob(self, x, do_sum=True):
        log_density = -torch.log(2 * self.b) - torch.abs(x - self.mu) / self.b
        if do_sum:
            return log_density.sum()
        else:
            return log_density


class IsotropicGaussian(object):
    """Isotropic Gaussian prior with fixed mean and standard deviation."""

    def __init__(self, mu, sigma, device=None):
        self.device = device
        self.mu = mu.to(device)
        self.sigma = sigma.to(device)
        self.cte_term = -0.5 * torch.log(torch.tensor(2 * np.pi))
        self.det_sig_term = -torch.log(self.sigma)

    def log_prob(self, x, do_sum=True):
        dist_term = -0.5 * ((x - self.mu) / self.sigma) ** 2
        log_density = self.cte_term + self.det_sig_term + dist_term
        if do_sum:
            return log_density.sum()
        else:
            return log_density


# ScaleMixtureGaussian prior hyperparameters
PI = 0.5
SIGMA_1 = torch.FloatTensor([math.exp(-0)])
SIGMA_2 = torch.FloatTensor([math.exp(-6)])

# Laplace and IsotropicGaussian prior hyperparameters
MU = torch.FloatTensor([0])
B = torch.FloatTensor([0.1])

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class BayesianLinear(nn.Module):
    """Linear layer with a Gaussian variational posterior over weights and biases."""

    def __init__(self, in_features, out_features, prior_type='ScaleMixtureGaussian', device=None):
        super().__init__()
        self.device = device
        self.in_features = in_features
        self.out_features = out_features
        # Weight parameters
        self.weight_mu = nn.Parameter(
            torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2).to(device)
        )
        self.weight_rho = nn.Parameter(
            torch.Tensor(out_features, in_features).uniform_(-5, -4).to(device)
        )
        self.weight = Gaussian(self.weight_mu, self.weight_rho, device)
        # Bias parameters
        self.bias_mu = nn.Parameter(
            torch.Tensor(out_features).uniform_(-0.2, 0.2).to(device)
        )
        self.bias_rho = nn.Parameter(
            torch.Tensor(out_features).uniform_(-5, -4).to(device)
        )
        self.bias = Gaussian(self.bias_mu, self.bias_rho, device)
        # Prior distributions
        if prior_type == 'ScaleMixtureGaussian':
            self.weight_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2, device)
            self.bias_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2, device)
        elif prior_type == 'Laplace':
            self.weight_prior = LaplacePrior(MU, B, device)
            self.bias_prior = LaplacePrior(MU, B, device)
        elif prior_type == 'IsotropicGaussian':
            self.weight_prior = IsotropicGaussian(MU, B, device)
            self.bias_prior = IsotropicGaussian(MU, B, device)
        else:
            raise ValueError(f"Unknown prior_type: {prior_type}")
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, input, sample=False, calculate_log_probs=False):
        input = input.to(self.device)
        if self.training or sample:
            # Draw fresh weights and biases via the reparameterization trick.
            weight = self.weight.sample()
            bias = self.bias.sample()
        else:
            # At evaluation time (unless sampling is requested) use the posterior means.
            weight = self.weight.mu
            bias = self.bias.mu
        if self.training or calculate_log_probs:
            self.log_prior = (
                self.weight_prior.log_prob(weight)
                + self.bias_prior.log_prob(bias)
            )
            self.log_variational_posterior = (
                self.weight.log_prob(weight) + self.bias.log_prob(bias)
            )
        else:
            self.log_prior, self.log_variational_posterior = 0, 0
        return F.linear(input, weight, bias)
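

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module above): one plausible way
# to train BayesianLinear layers with Bayes by Backprop, averaging the
# variational free energy over a few posterior samples. The names BayesianMLP
# and elbo_loss, the architecture, and the hyperparameters below are
# assumptions for this sketch, not definitions from the code above.
# ---------------------------------------------------------------------------

class BayesianMLP(nn.Module):
    """Hypothetical two-layer classifier built from BayesianLinear layers."""

    def __init__(self, in_features, hidden, out_features, device=DEVICE):
        super().__init__()
        self.l1 = BayesianLinear(in_features, hidden, device=device)
        self.l2 = BayesianLinear(hidden, out_features, device=device)

    def forward(self, x, sample=False):
        x = F.relu(self.l1(x, sample=sample))
        return F.log_softmax(self.l2(x, sample=sample), dim=-1)

    def log_prior(self):
        return self.l1.log_prior + self.l2.log_prior

    def log_variational_posterior(self):
        return self.l1.log_variational_posterior + self.l2.log_variational_posterior


def elbo_loss(model, x, y, num_samples=2, num_batches=1):
    """Monte Carlo estimate of the negative ELBO for one classification batch.

    Loss = (log q(w) - log p(w)) / num_batches + NLL, each term averaged over
    num_samples posterior draws; the KL term is spread across minibatches.
    """
    log_prior, log_post, nll = 0.0, 0.0, 0.0
    for _ in range(num_samples):
        outputs = model(x, sample=True)  # forward pass also caches log-probs
        log_prior = log_prior + model.log_prior()
        log_post = log_post + model.log_variational_posterior()
        nll = nll + F.nll_loss(outputs, y, reduction='sum')
    kl = (log_post - log_prior) / (num_samples * num_batches)
    return kl + nll / num_samples


if __name__ == '__main__':
    # Smoke test on random data (shapes and counts are arbitrary choices).
    model = BayesianMLP(20, 64, 3).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    x = torch.randn(32, 20, device=DEVICE)
    y = torch.randint(0, 3, (32,), device=DEVICE)
    optimizer.zero_grad()
    loss = elbo_loss(model, x, y, num_samples=2, num_batches=100)
    loss.backward()
    optimizer.step()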