import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Gaussian:
    """Diagonal Gaussian variational posterior parameterised by mu and rho,
    with sigma = softplus(rho) so the standard deviation is always positive."""

    def __init__(self, mu, rho, device=None):
        # mu and rho are expected to already live on `device`, in which case
        # `.to(device)` is a no-op that keeps the original Parameter reference.
        self.mu = mu.to(device)
        self.rho = rho.to(device)
        self.device = device
        self.normal = torch.distributions.Normal(
            torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)
        )

    @property
    def sigma(self):
        # Softplus link: sigma = log(1 + exp(rho)).
        return torch.log1p(torch.exp(self.rho))

    def sample(self):
        # Reparameterisation trick: w = mu + sigma * eps with eps ~ N(0, 1),
        # so gradients flow through mu and rho.
        epsilon = self.normal.sample(self.rho.size())
        return self.mu + self.sigma * epsilon

    def log_prob(self, input):
        # Sum of elementwise Gaussian log-densities.
        return (-math.log(math.sqrt(2 * math.pi))
                - torch.log(self.sigma)
                - ((input - self.mu) ** 2) / (2 * self.sigma ** 2)).sum()
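

# Usage sketch (not part of the original listing): one reparameterised draw
# from a toy posterior and its summed log-density; shapes are arbitrary.
def _demo_gaussian_posterior():
    mu0 = torch.zeros(3, 2)
    rho0 = torch.full((3, 2), -3.0)  # sigma = softplus(-3) ~ 0.049
    q = Gaussian(mu0, rho0, device=torch.device("cpu"))
    w = q.sample()        # differentiable in mu0/rho0 if they require grad
    return q.log_prob(w)  # scalar tensor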


class ScaleMixtureGaussian:
    """Two-component scale mixture of zero-mean Gaussians:
    p(w) = pi * N(w; 0, sigma1^2) + (1 - pi) * N(w; 0, sigma2^2)."""

    def __init__(self, pi, sigma1, sigma2, device=None):
        self.pi = pi
        self.device = device
        self.sigma1 = sigma1.to(device)
        self.sigma2 = sigma2.to(device)
        self.gaussian1 = torch.distributions.Normal(torch.tensor(0.0, device=device), self.sigma1)
        self.gaussian2 = torch.distributions.Normal(torch.tensor(0.0, device=device), self.sigma2)

    def log_prob(self, input):
        # Mix in probability space, then take the log. This can underflow far
        # out in both tails; a log-space mixture via torch.logsumexp is the
        # numerically safer alternative.
        prob1 = torch.exp(self.gaussian1.log_prob(input.to(self.device)))
        prob2 = torch.exp(self.gaussian2.log_prob(input.to(self.device)))
        return (torch.log(self.pi * prob1 + (1 - self.pi) * prob2)).sum()
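

# Usage sketch (not part of the original listing): score a small weight
# tensor under the mixture prior with the hyperparameters defined further
# below.
def _demo_scale_mixture():
    prior = ScaleMixtureGaussian(
        pi=0.5,
        sigma1=torch.tensor([1.0]),          # exp(-0)
        sigma2=torch.tensor([math.exp(-6)]),
        device=torch.device("cpu"),
    )
    w = 0.1 * torch.randn(4, 4)
    return prior.log_prob(w)  # scalar tensor: summed mixture log-density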


class LaplacePrior:
    """Laplace prior with log-density log p(w) = -log(2b) - |w - mu| / b."""

    def __init__(self, mu, b, device=None):
        self.device = device
        self.mu = mu.to(device)
        self.b = b.to(device)

    def log_prob(self, x, do_sum=True):
        log_density = -torch.log(2 * self.b) - torch.abs(x - self.mu) / self.b
        return log_density.sum() if do_sum else log_density
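

# Consistency check (not part of the original listing): the hand-rolled
# Laplace log-density should match torch.distributions.Laplace to float error.
def _check_laplace_prior():
    mu, b = torch.tensor([0.0]), torch.tensor([0.1])
    prior = LaplacePrior(mu, b, device=torch.device("cpu"))
    x = torch.randn(5)
    reference = torch.distributions.Laplace(mu, b).log_prob(x).sum()
    return torch.allclose(prior.log_prob(x), reference)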


class IsotropicGaussian:
    """Isotropic Gaussian prior N(mu, sigma^2) with a shared scalar sigma."""

    def __init__(self, mu, sigma, device=None):
        self.device = device
        self.mu = mu.to(device)
        self.sigma = sigma.to(device)
        # Constant and log-determinant terms of the Gaussian log-density.
        self.cte_term = -0.5 * math.log(2 * math.pi)
        self.det_sig_term = -torch.log(self.sigma)

    def log_prob(self, x, do_sum=True):
        dist_term = -0.5 * ((x - self.mu) / self.sigma) ** 2
        log_density = self.cte_term + self.det_sig_term + dist_term
        return log_density.sum() if do_sum else log_density
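

# Consistency check (not part of the original listing): same idea for the
# isotropic Gaussian, against torch.distributions.Normal.
def _check_isotropic_gaussian():
    mu, sigma = torch.tensor([0.0]), torch.tensor([0.1])
    prior = IsotropicGaussian(mu, sigma, device=torch.device("cpu"))
    x = torch.randn(5)
    reference = torch.distributions.Normal(mu, sigma).log_prob(x).sum()
    return torch.allclose(prior.log_prob(x), reference)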


# Scale-mixture prior hyperparameters (-log sigma_1 = 0, -log sigma_2 = 6,
# as in Blundell et al., 2015)
PI = 0.5
SIGMA_1 = torch.FloatTensor([math.exp(-0)])
SIGMA_2 = torch.FloatTensor([math.exp(-6)])

# Laplace and isotropic-Gaussian prior hyperparameters; B serves both as the
# Laplace scale b and as the Gaussian standard deviation.
MU = torch.FloatTensor([0])
B = torch.FloatTensor([0.1])

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class BayesianLinear(nn.Module):
    """Linear layer with a diagonal Gaussian variational posterior over its
    weights and biases, trained with Bayes by Backprop."""

    def __init__(self, in_features, out_features, prior_type='ScaleMixtureGaussian', device=None):
        super().__init__()
        self.device = device
        self.in_features = in_features
        self.out_features = out_features
        # Weight posterior parameters; rho in [-5, -4] gives a small initial
        # sigma (softplus maps it to roughly 0.007-0.018).
        self.weight_mu = nn.Parameter(
            torch.empty(out_features, in_features).uniform_(-0.2, 0.2).to(device)
        )
        self.weight_rho = nn.Parameter(
            torch.empty(out_features, in_features).uniform_(-5, -4).to(device)
        )
        self.weight = Gaussian(self.weight_mu, self.weight_rho, device)
        # Bias posterior parameters.
        self.bias_mu = nn.Parameter(
            torch.empty(out_features).uniform_(-0.2, 0.2).to(device)
        )
        self.bias_rho = nn.Parameter(
            torch.empty(out_features).uniform_(-5, -4).to(device)
        )
        self.bias = Gaussian(self.bias_mu, self.bias_rho, device)
        # Prior distributions over weights and biases.
        if prior_type == 'ScaleMixtureGaussian':
            self.weight_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2, device)
            self.bias_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2, device)
        elif prior_type == 'Laplace':
            self.weight_prior = LaplacePrior(MU, B, device)
            self.bias_prior = LaplacePrior(MU, B, device)
        elif prior_type == 'IsotropicGaussian':
            self.weight_prior = IsotropicGaussian(MU, B, device)
            self.bias_prior = IsotropicGaussian(MU, B, device)
        else:
            raise ValueError(f"Unknown prior_type: {prior_type!r}")
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, input, sample=False, calculate_log_probs=False):
        input = input.to(self.device)
        if self.training or sample:
            # Stochastic pass: draw weights with the reparameterisation trick.
            weight = self.weight.sample()
            bias = self.bias.sample()
        else:
            # Deterministic pass: use the posterior means.
            weight = self.weight.mu
            bias = self.bias.mu
        if self.training or calculate_log_probs:
            # Single-sample Monte Carlo estimates of the ELBO's KL terms.
            self.log_prior = (
                self.weight_prior.log_prob(weight) + self.bias_prior.log_prob(bias)
            )
            self.log_variational_posterior = (
                self.weight.log_prob(weight) + self.bias.log_prob(bias)
            )
        else:
            self.log_prior, self.log_variational_posterior = 0, 0

        return F.linear(input, weight, bias)
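

# Usage sketch (not part of the original listing): one stochastic forward pass
# and the per-layer KL term that Bayes by Backprop adds to the negative ELBO.
def _demo_bayesian_linear():
    layer = BayesianLinear(20, 5, prior_type='ScaleMixtureGaussian', device=DEVICE)
    x = torch.randn(8, 20, device=DEVICE)
    out = layer(x, sample=True, calculate_log_probs=True)   # shape (8, 5)
    kl = layer.log_variational_posterior - layer.log_prior  # MC estimate of KL(q || p)
    return out, kl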