| @@ -0,0 +1,140 @@ | |||
| import math | |||
| import matplotlib.pyplot as plt | |||
| import numpy as np | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.nn.functional as F | |||
| class Gaussian(object): | |||
| def __init__(self, mu, rho, device=None): | |||
| super().__init__() | |||
| self.mu = mu.to(device) | |||
| self.rho = rho.to(device) | |||
| self.device = device | |||
| self.normal = torch.distributions.Normal(torch.tensor(0.0, device=device), torch.tensor(1.0, device=device)) | |||
| @property | |||
| def sigma(self): | |||
| return torch.log1p(torch.exp(self.rho)) | |||
| def sample(self): | |||
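| # Reparameterization trick: w = mu + sigma * eps with eps ~ N(0, 1), so samples stay differentiable w.r.t. mu and rho | |||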
| epsilon = self.normal.sample(self.rho.size()) | |||
| return self.mu + self.sigma * epsilon | |||
| def log_prob(self, input): | |||
| return (-math.log(math.sqrt(2 * math.pi)) | |||
| - torch.log(self.sigma) | |||
| - ((input - self.mu) ** 2) / (2 * self.sigma ** 2)).sum() | |||
| class ScaleMixtureGaussian(object): | |||
| def __init__(self, pi, sigma1, sigma2, device=None): | |||
| super().__init__() | |||
| self.pi = pi | |||
| self.device = device | |||
| self.sigma1 = sigma1.to(device) | |||
| self.sigma2 = sigma2.to(device) | |||
| self.gaussian1 = torch.distributions.Normal(torch.tensor(0.0, device=device), self.sigma1) | |||
| self.gaussian2 = torch.distributions.Normal(torch.tensor(0.0, device=device), self.sigma2) | |||
| def log_prob(self, input): | |||
| prob1 = torch.exp(self.gaussian1.log_prob(input.to(self.device))) | |||
| prob2 = torch.exp(self.gaussian2.log_prob(input.to(self.device))) | |||
| return (torch.log(self.pi * prob1 + (1 - self.pi) * prob2)).sum() | |||
| class LaplacePrior(object): | |||
| def __init__(self, mu, b, device=None): | |||
| super().__init__() | |||
| self.device = device | |||
| self.mu = mu.to(device) | |||
| self.b = b.to(device) | |||
| def log_prob(self, x, do_sum=True): | |||
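| # Laplace log-density: log p(x) = -log(2b) - abs(x - mu) / b | |||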
| if do_sum: | |||
| return (-torch.log(2 * self.b) - torch.abs(x - self.mu) / self.b).sum() | |||
| else: | |||
| return -torch.log(2 * self.b) - torch.abs(x - self.mu) / self.b | |||
| class IsotropicGaussian(object): | |||
| def __init__(self, mu, sigma, device=None): | |||
| self.device = device | |||
| self.mu = mu.to(device) | |||
| self.sigma = sigma.to(device) | |||
| self.cte_term = -0.5 * torch.log(torch.tensor(2 * np.pi)) | |||
| self.det_sig_term = -torch.log(self.sigma) | |||
| def log_prob(self, x, do_sum=True): | |||
| dist_term = -0.5 * ((x - self.mu) / self.sigma) ** 2 | |||
| if do_sum: | |||
| return (self.cte_term + self.det_sig_term + dist_term).sum() | |||
| else: | |||
| return self.cte_term + self.det_sig_term + dist_term | |||
| # ScaleMixture prior hyperparameters | |||
| PI = 0.5 | |||
| SIGMA_1 = torch.FloatTensor([math.exp(-0)]) | |||
| SIGMA_2 = torch.FloatTensor([math.exp(-6)]) | |||
| # Laplace and IsotropicGaussian priors hyperparameters | |||
| MU = torch.FloatTensor([0]) | |||
| B = torch.FloatTensor([0.1]) | |||
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |||
| class BayesianLinear(nn.Module): | |||
| def __init__(self, in_features, out_features, prior_type='ScaleMixtureGaussian', device=None): | |||
| super().__init__() | |||
| self.device = device | |||
| self.in_features = in_features | |||
| self.out_features = out_features | |||
| # Weight parameters | |||
| self.weight_mu = nn.Parameter( | |||
| torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2).to(device) | |||
| ) | |||
| self.weight_rho = nn.Parameter( | |||
| torch.Tensor(out_features, in_features).uniform_(-5, -4).to(device) | |||
| ) | |||
| self.weight = Gaussian(self.weight_mu, self.weight_rho, device) | |||
| # Bias parameters | |||
| self.bias_mu = nn.Parameter( | |||
| torch.Tensor(out_features).uniform_(-0.2, 0.2).to(device) | |||
| ) | |||
| self.bias_rho = nn.Parameter( | |||
| torch.Tensor(out_features).uniform_(-5, -4).to(device) | |||
| ) | |||
| self.bias = Gaussian(self.bias_mu, self.bias_rho, device) | |||
| # Prior distributions | |||
| if prior_type == 'ScaleMixtureGaussian': | |||
| self.weight_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2, device) | |||
| self.bias_prior = ScaleMixtureGaussian(PI, SIGMA_1, SIGMA_2, device) | |||
| elif prior_type == 'Laplace': | |||
| self.weight_prior = LaplacePrior(MU, B, device) | |||
| self.bias_prior = LaplacePrior(MU, B, device) | |||
| elif prior_type == 'IsotropicGaussian': | |||
| self.weight_prior = IsotropicGaussian(MU, B, device) | |||
| self.bias_prior = IsotropicGaussian(MU, B, device) | |||
| else: | |||
| raise ValueError(f"Unknown prior_type: {prior_type}") | |||
| self.log_prior = 0 | |||
| self.log_variational_posterior = 0 | |||
| def forward(self, input, sample=False, calculate_log_probs=False): | |||
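| # In training mode (or when sample=True) weights and biases are drawn from the variational posterior; otherwise the posterior means are used | |||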
| input = input.to(self.device) | |||
| if self.training or sample: | |||
| weight = self.weight.sample() | |||
| bias = self.bias.sample() | |||
| else: | |||
| weight = self.weight.mu | |||
| bias = self.bias.mu | |||
| if self.training or calculate_log_probs: | |||
| self.log_prior = ( | |||
| self.weight_prior.log_prob(weight) + self.bias_prior.log_prob(bias) | |||
| ) | |||
| self.log_variational_posterior = ( | |||
| self.weight.log_prob(weight) + self.bias.log_prob(bias) | |||
| ) | |||
| else: | |||
| self.log_prior, self.log_variational_posterior = 0, 0 | |||
| return F.linear(input, weight, bias) | |||
| @@ -0,0 +1,122 @@ | |||
| # -*- Encoding:UTF-8 -*- | |||
| import numpy as np | |||
| import sys | |||
| class DataSet(object): | |||
| def __init__(self, fileName): | |||
| self.data, self.shape = self.getData(fileName) | |||
| self.train, self.test = self.getTrainTest() | |||
| self.trainDict = self.getTrainDict() | |||
| def getData(self, fileName): | |||
| if fileName == 'ml-1m' or fileName == 'ml-100k': | |||
| # print(f"Loading {fileName} data set...") | |||
| if fileName == 'ml-1m': | |||
| filePath = './Data/ml-1m/ratings.dat' | |||
| separator = '::' | |||
| else: | |||
| filePath = './Data/ml-100k/u.data' | |||
| separator = '\t' | |||
| data = [] | |||
| u = 0 | |||
| i = 0 | |||
| maxr = 0.0 | |||
| with open(filePath, 'r') as f: | |||
| for line in f: | |||
| line = line.strip() | |||
| if line: | |||
| lines = line.split(separator) | |||
| user = int(lines[0]) | |||
| movie = int(lines[1]) | |||
| score = float(lines[2]) | |||
| time = int(lines[3]) | |||
| data.append((user, movie, score, time)) | |||
| if user > u: | |||
| u = user | |||
| if movie > i: | |||
| i = movie | |||
| if score > maxr: | |||
| maxr = score | |||
| self.maxRate = maxr | |||
| # print("Loading Success!\n" | |||
| # "Data Info:\n" | |||
| # "\tUser Num: {}\n" | |||
| # "\tItem Num: {}\n" | |||
| # "\tData Size: {}".format(u, i, len(data))) | |||
| return data, [u, i] | |||
| else: | |||
| print("Current data set is not support!") | |||
| sys.exit() | |||
| def getTrainTest(self): | |||
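| # Leave-one-out split: interactions are sorted by (user, timestamp) and each user's most recent interaction is held out for the test set | |||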
| data = self.data | |||
| data = sorted(data, key=lambda x: (x[0], x[3])) | |||
| train = [] | |||
| test = [] | |||
| for i in range(len(data)-1): | |||
| user = data[i][0]-1 | |||
| item = data[i][1]-1 | |||
| rate = data[i][2] | |||
| if data[i][0] != data[i+1][0]: | |||
| test.append((user, item, rate)) | |||
| else: | |||
| train.append((user, item, rate)) | |||
| test.append((data[-1][0]-1, data[-1][1]-1, data[-1][2])) | |||
| return train, test | |||
| def getTrainDict(self): | |||
| dataDict = {} | |||
| for i in self.train: | |||
| dataDict[(i[0], i[1])] = i[2] | |||
| return dataDict | |||
| def getEmbedding(self): | |||
| train_matrix = np.zeros([self.shape[0], self.shape[1]], dtype=np.float32) | |||
| for i in self.train: | |||
| user = i[0] | |||
| movie = i[1] | |||
| rating = i[2] | |||
| train_matrix[user][movie] = rating | |||
| return np.array(train_matrix) | |||
| def getInstances(self, data, negNum): | |||
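| # For every observed (user, item, rating) triple, additionally draw negNum items the user has not rated and label them 0.0 | |||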
| user = [] | |||
| item = [] | |||
| rate = [] | |||
| for i in data: | |||
| user.append(i[0]) | |||
| item.append(i[1]) | |||
| rate.append(i[2]) | |||
| for t in range(negNum): | |||
| j = np.random.randint(self.shape[1]) | |||
| while (i[0], j) in self.trainDict: | |||
| j = np.random.randint(self.shape[1]) | |||
| user.append(i[0]) | |||
| item.append(j) | |||
| rate.append(0.0) | |||
| return np.array(user), np.array(item), np.array(rate) | |||
| def getTestNeg(self, testData, negNum): | |||
| user = [] | |||
| item = [] | |||
| for s in testData: | |||
| tmp_user = [] | |||
| tmp_item = [] | |||
| u = s[0] | |||
| i = s[1] | |||
| tmp_user.append(u) | |||
| tmp_item.append(i) | |||
| neglist = set() | |||
| neglist.add(i) | |||
| for t in range(negNum): | |||
| j = np.random.randint(self.shape[1]) | |||
| while (u, j) in self.trainDict or j in neglist: | |||
| j = np.random.randint(self.shape[1]) | |||
| neglist.add(j) | |||
| tmp_user.append(u) | |||
| tmp_item.append(j) | |||
| user.append(tmp_user) | |||
| item.append(tmp_item) | |||
| return [np.array(user), np.array(item)] | |||
| @@ -0,0 +1,76 @@ | |||
| ## Project Overview: An Ensemble Bayesian Neural Network for Recommendation Systems | |||
| This project implements a recommendation system built on **Bayesian Neural Networks (BNNs)** and **ensemble learning**. The core idea is a collaborative filtering model that not only provides accurate recommendations but also quantifies the uncertainty of its predictions. This is achieved by training multiple, diverse BNN models and aggregating their predictions with a meta-learning component. | |||
| The system is designed to work with standard recommendation datasets like MovieLens and is evaluated on ranking-based metrics such as Hit Ratio (HR) and Normalized Discounted Cumulative Gain (NDCG). | |||
| --- | |||
| ## Architectural Breakdown | |||
| The project is logically structured into three main Python files, each with a distinct responsibility. | |||
| ### **File 1: `BNN.py` — The Bayesian Building Blocks** | |||
| This script defines the fundamental components required to construct Bayesian Neural Networks. Instead of learning fixed-point weights like standard neural networks, BNNs learn probability distributions over their weights. This file provides the necessary classes to manage these distributions. | |||
| **Key Components:** | |||
| * **Distribution Classes:** | |||
| * `Gaussian`: Implements a Gaussian distribution for the weights and biases of the BNN layers. It uses the reparameterization trick (`μ + σ * ε`) for efficient sampling during training. The standard deviation `σ` is derived from a learnable parameter `ρ` to ensure it remains positive. | |||
| * `ScaleMixtureGaussian`: Defines a more flexible prior distribution for weights, constructed as a mixture of two Gaussian distributions with different variances. This allows the model to distinguish between highly important weights and those that can be pruned, effectively encouraging sparsity. | |||
| * `LaplacePrior` & `IsotropicGaussian`: Provide alternative, simpler prior distributions (Laplace and standard Gaussian, respectively) that can be used for regularization and experimentation. | |||
| * **Core BNN Layer:** | |||
| * `BayesianLinear`: This is the cornerstone module, acting as a drop-in replacement for a standard `torch.nn.Linear` layer. It maintains learnable parameters (`mu` and `rho`) for the distributions of its weights and biases. During training, it calculates the **log-prior** (how well the sampled weights fit the prior distribution) and the **log-variational-posterior** (how probable the sampled weights are under their learned distribution). These two values are essential for computing the KL-divergence, a key component of the BNN's loss function. | |||
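| A minimal usage sketch (assuming `BNN.py` is importable as a module; the shapes here are purely illustrative): | |||
| ```python | |||
| import torch | |||
| from BNN import BayesianLinear, DEVICE | |||
| layer = BayesianLinear(in_features=128, out_features=64, prior_type='ScaleMixtureGaussian', device=DEVICE) | |||
| x = torch.randn(32, 128)                                      # a batch of 32 inputs | |||
| out = layer(x, sample=True, calculate_log_probs=True)         # one Monte Carlo weight sample | |||
| kl_term = layer.log_variational_posterior - layer.log_prior   # per-layer complexity cost | |||
| ``` | |||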
| ### **File 2: `DataSet.py` — Data Handling and Preprocessing** | |||
| This script is dedicated to loading, parsing, and preparing the dataset for training and evaluation. It handles the specifics of the MovieLens datasets and transforms the raw data into a format suitable for the PyTorch model. | |||
| **Key Class: `DataSet`** | |||
| * **Initialization & Data Loading:** The constructor, via the `getData` method, loads the specified MovieLens dataset (`ml-1m` or `ml-100k`) from file, parsing user IDs, item IDs, and ratings. | |||
| * **Train-Test Splitting:** The `getTrainTest` method implements a temporal split. For each user, their last interaction is held out for the test set, while all preceding interactions form the training set. This is a standard and realistic evaluation protocol in recommendation systems. | |||
| * **Negative Sampling:** | |||
| * `getInstances`: For each positive user-item interaction in the training set, this method samples a specified number of "negative" items (items the user has not interacted with). This is crucial for training the model to discriminate between relevant and irrelevant items. | |||
| * `getTestNeg`: Prepares the test set for ranking evaluation. For each user's true positive item in the test set, it samples 99 negative items, creating a list of 100 items to rank. | |||
| * **Embedding Matrix:** The `getEmbedding` method constructs a full user-item interaction matrix, which is used to initialize the input embeddings for the model. | |||
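| A short usage sketch of this class (assuming the MovieLens files sit under `./Data/` as hard-coded in `getData`): | |||
| ```python | |||
| from DataSet import DataSet | |||
| ds = DataSet('ml-100k')                               # reads ./Data/ml-100k/u.data | |||
| num_users, num_items = ds.shape | |||
| u, i, r = ds.getInstances(ds.train, negNum=5)         # each positive plus 5 sampled negatives | |||
| test_users, test_items = ds.getTestNeg(ds.test, 99)   # per user: held-out positive + 99 negatives | |||
| ``` | |||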
| ### **File 3: `main.py` — Model Architecture, Training, and Evaluation** | |||
| This is the main driver script that assembles the components from the other files into a complete system, defines the training loop, and manages the ensemble logic. | |||
| **Key Classes:** | |||
| * **`Model`:** This class defines the architecture of a single BNN-based recommendation model. | |||
| * **Initialization:** It sets up separate user and item processing streams. Each stream consists of a stack of `BayesianLinear` layers. It also initializes an attention mechanism. | |||
| * **Embeddings:** It uses the user-item interaction matrix from `DataSet.py` as a non-trainable input embedding. | |||
| * **Forward Pass:** A user and an item are passed through their respective BNN towers. The resulting latent representations are then combined element-wise and fed into a `MultiheadAttention` layer to capture complex interactions. A final MLP block (`interaction_layer`) produces the predicted interaction probability. | |||
| * **Loss Calculation:** Includes helper methods (`log_prior`, `log_variational_posterior`, `sample_elbo`) to compute the total model loss, which is a combination of the standard Binary Cross-Entropy (BCE) loss and the KL-divergence term (the "complexity cost") from the Bayesian layers; a schematic sketch of this loss follows the list below. | |||
| * **`SuperModel`:** This class defines the meta-learner for the ensemble. | |||
| * **Architecture:** It is a small neural network that takes the concatenated prediction scores from all individual models in the ensemble as input. | |||
| * **Function:** It learns to weigh the predictions from the different base models to produce a final, more accurate prediction, effectively learning the strengths and weaknesses of each ensemble member. | |||
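| Schematically (a sketch only; variable names follow `run_epoch` and `SuperModel.forward` in `main.py`, where `kl_coef` is a small constant and `K` is the ensemble size), a base-model training step and the meta-learner's aggregation look like: | |||
| ```python | |||
| # One training step of a base Model | |||
| y_hat = model(user_tensor, item_tensor)             # forward pass with sampled weights | |||
| bce = nn.BCELoss()(y_hat, rate_tensor)              # data-fit term | |||
| kl = model.sample_elbo(user_tensor, item_tensor, rate_tensor, 5, num_batches)  # complexity cost | |||
| loss = bce + kl_coef * kl                           # ELBO-style objective minimized by Adam | |||
| # Aggregation inside SuperModel.forward | |||
| scores = torch.stack([m(user_tensor, item_tensor) for m in ensemble_models], dim=1)  # (batch, K) | |||
| y_ens = combiner(scores).squeeze(-1)                # small MLP with sigmoid -> final probability | |||
| ``` | |||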
| **Execution Flow:** | |||
| 1. **Ensemble Initialization:** The `main` function begins by creating a diverse ensemble of `Model` instances. Diversity is achieved by randomly assigning different network architectures (layer depths/widths) and prior distributions (`ScaleMixtureGaussian`, `Laplace`, etc.) to each model. Each model is also trained on a different **bootstrap sample** of the training data. | |||
| 2. **Individual Model Training:** Each model in the ensemble is trained independently using the `run_epoch` function. | |||
| * The loss function is the **Evidence Lower Bound (ELBO)**, which balances the BCE loss (fitting the data) with the KL divergence (regularizing the model complexity). | |||
| * After each epoch, the model is evaluated using the `evaluate` function, which calculates HR@10 and NDCG@10. The best-performing checkpoint for each model is saved. | |||
| 3. **Super Model Training:** After all base models are trained, the saved best checkpoints are loaded. A `SuperModel` is then instantiated and trained via the `train_super_model` function. It learns to combine the predictions of the frozen base models. | |||
| 4. **Final Evaluation:** The entire ensemble, aggregated by the trained `SuperModel`, is evaluated on the test set using the `ensemble_eval` function to report the final HR and NDCG scores. | |||
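| With `rank` denoting the 1-based position of the held-out item among its 100 candidates, the per-user metrics computed in `evaluate` and `ensemble_eval` reduce to | |||
| $$\mathrm{HR@}K = \mathbb{1}[\mathrm{rank} \le K], \qquad \mathrm{NDCG@}K = \frac{\log 2}{\log(\mathrm{rank} + 1)}\,\mathbb{1}[\mathrm{rank} \le K],$$ | |||
| both averaged over all test users, so NDCG equals 1 when the positive item is ranked first and decays logarithmically with its position. | |||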
| --- | |||
| ## Summary of Key Features | |||
| * **Uncertainty Quantification:** The use of `BayesianLinear` layers allows the model to capture uncertainty in its weights, leading to more robust predictions. | |||
| * **Ensemble Diversity:** The system actively promotes diversity through architectural heterogeneity and data bootstrapping (bagging), which is key to a successful ensemble. | |||
| * **Advanced Interaction Modeling:** A `MultiheadAttention` mechanism is used to effectively model the complex, non-linear interactions between user and item latent features. | |||
| * **Meta-Learning for Aggregation:** Instead of simple averaging, a dedicated `SuperModel` learns the optimal way to combine predictions from the ensemble members. | |||
| * **Principled Loss Function:** The training relies on optimizing the ELBO, a standard and theoretically grounded objective for variational inference in Bayesian models. | |||
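| For reference, a typical invocation (assuming the three scripts share a directory and the MovieLens files are under `./Data/`) is `python main.py -dataName ml-100k -negNum 5 -batchSize 256 -ensembleSize 10 -priorType ScaleMixtureGaussian`; every flag corresponds to an option defined in `main.py`'s argument parser. | |||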
| @@ -0,0 +1,412 @@ | |||
| # -*- Encoding:UTF-8 -*- | |||
| import torch | |||
| import torch.nn as nn | |||
| import torch.optim as optim | |||
| import numpy as np | |||
| import argparse | |||
| import os | |||
| import heapq | |||
| import math | |||
| import random | |||
| from DataSet import DataSet | |||
| from BNN import * | |||
| # Set device to GPU if available | |||
| DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |||
| print(DEVICE) | |||
| # Seed for reproducibility | |||
| seed = 0 | |||
| random.seed(seed) | |||
| np.random.seed(seed) | |||
| torch.manual_seed(seed) | |||
| torch.cuda.manual_seed(seed) | |||
| torch.backends.cudnn.deterministic = True | |||
| torch.backends.cudnn.benchmark = False | |||
| class Model(nn.Module): | |||
| def __init__(self, args): | |||
| super(Model, self).__init__() | |||
| self.dataName = args.dataName | |||
| self.dataSet = DataSet(self.dataName) | |||
| self.shape = self.dataSet.shape | |||
| self.maxRate = self.dataSet.maxRate | |||
| self.sample_size = 2 | |||
| self.train_data = self.bootstrap_sample(self.sample_size) | |||
| self.test_data = self.dataSet.test | |||
| self.negNum = args.negNum | |||
| self.testNeg = self.dataSet.getTestNeg(self.test_data, 99) | |||
| self.userLayer = args.userLayer | |||
| self.itemLayer = args.itemLayer | |||
| self.user_item_embedding = nn.Parameter( | |||
| torch.tensor(self.dataSet.getEmbedding(), dtype=torch.float32).to(DEVICE), | |||
| requires_grad=False)  # fixed (non-trainable) interaction-matrix embeddings, as described in the README | |||
| self.item_user_embedding = self.user_item_embedding.t() | |||
| # Define User Layer | |||
| self.user_layers = nn.ModuleList() | |||
| input_size = self.shape[1] | |||
| for size in self.userLayer: | |||
| self.user_layers.append(BayesianLinear(input_size, size, prior_type=args.priorType, device=DEVICE)) | |||
| input_size = size | |||
| # Define Item Layer | |||
| self.item_layers = nn.ModuleList() | |||
| input_size = self.shape[0] | |||
| for size in self.itemLayer: | |||
| self.item_layers.append(BayesianLinear(input_size, size, prior_type=args.priorType, device=DEVICE)) | |||
| input_size = size | |||
| self.interaction_layer = nn.Sequential( | |||
| nn.Linear(input_size, 64), | |||
| nn.ReLU(), | |||
| nn.Linear(64, 32), | |||
| nn.ReLU(), | |||
| nn.Linear(32, 1) | |||
| ) | |||
| self.attention = nn.MultiheadAttention(embed_dim=input_size, num_heads=4, dropout=0.3, batch_first=True) | |||
| def forward(self, user, item): | |||
| user_input = self.user_item_embedding[user] | |||
| item_input = self.item_user_embedding[item] | |||
| for layer in self.user_layers: | |||
| user_input = torch.relu(layer(user_input)) | |||
| for layer in self.item_layers: | |||
| item_input = torch.relu(layer(item_input)) | |||
| user_att = user_input.unsqueeze(1) # Shape: (batch_size, 1, embed_dim) | |||
| item_att = item_input.unsqueeze(1) # Shape: (batch_size, 1, embed_dim) | |||
| combined = user_att * item_att # Shape: (batch_size, 1, embed_dim) | |||
| att_output, att_weights = self.attention( | |||
| query=combined, | |||
| key=combined, | |||
| value=combined | |||
| ) | |||
| att_output = att_output.mean(dim=1) | |||
| interaction_input = att_output | |||
| y_hat = self.interaction_layer(interaction_input) | |||
| y_hat = torch.sigmoid(y_hat) | |||
| return torch.clamp(y_hat.squeeze(), min=1e-6, max=1.0) | |||
| def log_prior(self, type): | |||
| if type == "user": | |||
| return sum(layer.log_prior for layer in self.user_layers) | |||
| else: | |||
| return sum(layer.log_prior for layer in self.item_layers) | |||
| def log_variational_posterior(self, type): | |||
| if type == "user": | |||
| return sum(layer.log_variational_posterior for layer in self.user_layers) | |||
| else: | |||
| return sum(layer.log_variational_posterior for layer in self.item_layers) | |||
| def sample_elbo(self, user_tensor, item_tensor, target, num_samples, num_batches): | |||
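| # Monte Carlo estimate of the complexity cost: mean of log q(w) - log p(w) over num_samples weight draws, summed over the user and item towers | |||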
| outputs = torch.zeros(num_samples, user_tensor.size(0), device=DEVICE) | |||
| user_log_priors = torch.zeros(num_samples, device=DEVICE) | |||
| user_log_variational_posteriors = torch.zeros(num_samples, device=DEVICE) | |||
| item_log_priors = torch.zeros(num_samples, device=DEVICE) | |||
| item_log_variational_posteriors = torch.zeros(num_samples, device=DEVICE) | |||
| for i in range(num_samples): | |||
| outputs[i] = self(user_tensor, item_tensor) | |||
| user_log_priors[i] = self.log_prior(type="user") | |||
| user_log_variational_posteriors[i] = self.log_variational_posterior(type="user") | |||
| item_log_priors[i] = self.log_prior(type="item") | |||
| item_log_variational_posteriors[i] = self.log_variational_posterior(type="item") | |||
| user_log_prior = user_log_priors.mean() | |||
| user_log_variational_posterior = user_log_variational_posteriors.mean() | |||
| item_log_prior = item_log_priors.mean() | |||
| item_log_variational_posterior = item_log_variational_posteriors.mean() | |||
| item_loss = (item_log_variational_posterior - item_log_prior) | |||
| user_loss = (user_log_variational_posterior - user_log_prior) | |||
| return user_loss + item_loss | |||
| def bootstrap_sample(self, sample_size): | |||
| """ | |||
| Generate a bootstrapped training set of len(train) // sample_size interactions, sampled with replacement. | |||
| """ | |||
| indices = np.random.choice(len(self.dataSet.train), size=len(self.dataSet.train) // sample_size, | |||
| replace=True) | |||
| sampled_train = [self.dataSet.train[i] for i in indices] | |||
| return sampled_train | |||
| class SuperModel(nn.Module): | |||
| """ | |||
| A super model that combines predictions from multiple ensemble models using a neural network. | |||
| """ | |||
| def __init__(self, ensemble_models, input_size): | |||
| super(SuperModel, self).__init__() | |||
| self.ensemble_models = ensemble_models | |||
| self.combiner = nn.Sequential( | |||
| nn.Linear(input_size, input_size // 2), | |||
| nn.ReLU(), | |||
| nn.Linear(input_size // 2, 1), | |||
| nn.Sigmoid() # To ensure the output is a probability | |||
| ) | |||
| def forward(self, user, item): | |||
| """ | |||
| Forward pass of the super model. | |||
| Combines predictions from ensemble models using a neural network. | |||
| """ | |||
| ensemble_predictions = [] | |||
| with torch.no_grad(): # Ensure no gradients are computed for ensemble models | |||
| for model in self.ensemble_models: | |||
| model.eval() # Set individual models to evaluation mode | |||
| predictions = model(user, item) | |||
| ensemble_predictions.append(predictions) | |||
| # Stack predictions to create input for the combiner network | |||
| stacked_predictions = torch.stack(ensemble_predictions, dim=1) # Shape: (batch_size, num_ensemble_models) | |||
| combined_predictions = self.combiner(stacked_predictions).squeeze(-1) # Shape: (batch_size,) | |||
| return combined_predictions | |||
| def run_epoch(model, optimizer, criterion, args): | |||
| model.train() | |||
| train_u, train_i, train_r = model.dataSet.getInstances(model.train_data, args.negNum) | |||
| train_len = len(train_u) | |||
| shuffled_idx = np.random.permutation(np.arange(train_len)) | |||
| train_u, train_i, train_r = train_u[shuffled_idx], train_i[shuffled_idx], train_r[shuffled_idx] | |||
| num_batches = (train_len + args.batchSize - 1) // args.batchSize | |||
| BCE_losses, kls = [], [] | |||
| for i in range(num_batches): | |||
| min_idx = i * args.batchSize | |||
| max_idx = min(train_len, (i + 1) * args.batchSize) | |||
| user_tensor = torch.tensor(train_u[min_idx:max_idx], dtype=torch.long).to(DEVICE) | |||
| item_tensor = torch.tensor(train_i[min_idx:max_idx], dtype=torch.long).to(DEVICE) | |||
| rate_tensor = torch.tensor(train_r[min_idx:max_idx], dtype=torch.float32).to(DEVICE) | |||
| rate_tensor = (rate_tensor - rate_tensor.min()) / (rate_tensor.max() - rate_tensor.min()) | |||
| optimizer.zero_grad() | |||
| y_hat = model(user_tensor, item_tensor) | |||
| bce_loss = criterion(y_hat, rate_tensor) | |||
| BCE_losses.append(bce_loss.item()) | |||
| kl_coef = 4.42322e-08  # small weight on the complexity cost | |||
| kl = model.sample_elbo(user_tensor, item_tensor, rate_tensor, 5, num_batches) | |||
| loss = bce_loss + kl_coef * kl | |||
| loss.backward() | |||
| optimizer.step() | |||
| kls.append(kl.item()) | |||
| if i % 10 == 0: | |||
| print(f'\rBatch {i}/{num_batches}: KL = {np.mean(kls[-10:]):.4f}, BCE = {np.mean(BCE_losses[-10:]):.4f}', end='') | |||
| print(f"\nMean BCE Loss: {np.mean(BCE_losses):.4f}") | |||
| print(f"Mean KL Divergence: {np.mean(kls):.4f}") | |||
| return np.mean(kls) | |||
| def evaluate(model, topK): | |||
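| # Leave-one-out ranking evaluation: score each held-out positive against its 99 sampled negatives and report HR@topK and NDCG@topK | |||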
| model.eval() | |||
| hr, NDCG = [], [] | |||
| with torch.no_grad(): | |||
| for i in range(len(model.testNeg[0])): | |||
| user_tensor = torch.tensor(model.testNeg[0][i], dtype=torch.long).to(DEVICE) | |||
| item_tensor = torch.tensor(model.testNeg[1][i], dtype=torch.long).to(DEVICE) | |||
| predict = model(user_tensor, item_tensor) | |||
| target = int(model.testNeg[1][i][0])  # position 0 holds the held-out positive item | |||
| item_score_dict = {int(item): predict[j].item() for j, item in enumerate(model.testNeg[1][i])} | |||
| ranklist = heapq.nlargest(topK, item_score_dict, key=item_score_dict.get) | |||
| hr.append(1 if target in ranklist else 0) | |||
| NDCG.append(math.log(2) / math.log(ranklist.index(target) + 2) if target in ranklist else 0) | |||
| return np.mean(hr), np.mean(NDCG) | |||
| def main(): | |||
| parser = argparse.ArgumentParser(description="Options") | |||
| parser.add_argument('-dataName', action='store', dest='dataName', default='ml-100k') | |||
| parser.add_argument('-negNum', action='store', dest='negNum', default=5, type=int) | |||
| parser.add_argument('-userLayer', action='store', dest='userLayer', default=[512, 64, 64], type=int, nargs='+') | |||
| parser.add_argument('-itemLayer', action='store', dest='itemLayer', default=[1024, 64, 64], type=int, nargs='+') | |||
| parser.add_argument('-lr', action='store', dest='lr', default=0.0001, type=float) | |||
| parser.add_argument('-maxEpochs', action='store', dest='maxEpochs', default=50, type=int) | |||
| parser.add_argument('-batchSize', action='store', dest='batchSize', default=256, type=int) | |||
| parser.add_argument('-earlyStop', action='store', dest='earlyStop', default=5, type=int) | |||
| parser.add_argument('-checkPoint', action='store', dest='checkPoint', default='./checkPoint/') | |||
| parser.add_argument('-topK', action='store', dest='topK', default=10, type=int) | |||
| parser.add_argument('-loadModel', action='store_true', dest='loadModel', help="Load a saved model") | |||
| parser.add_argument('-ensembleSize', action='store', dest='ensembleSize', default=10, type=int) | |||
| parser.add_argument('-maxEpochN', action='store', dest='maxEpochN', default=30, type=int) | |||
| parser.add_argument('-priorType', action='store', dest='priorType', default='ScaleMixtureGaussian', | |||
| choices=['ScaleMixtureGaussian', 'Laplace', 'IsotropicGaussian']) | |||
| args = parser.parse_args() | |||
| if not os.path.exists(args.checkPoint): | |||
| os.mkdir(args.checkPoint) | |||
| ensemble_models = [] | |||
| optimizers = [] | |||
| ensemble_args = [] | |||
| network_layers = [[512, 64, 64], [512, 64], [1024, 64, 64], [512, 256, 64], [1024, 256, 256, 64]] | |||
| prior_types = ['ScaleMixtureGaussian', 'Laplace', 'IsotropicGaussian'] | |||
| for ensemble_idx in range(args.ensembleSize): | |||
| args_copy = argparse.Namespace(**vars(args)) | |||
| args_copy.userLayer = random.choice(network_layers) | |||
| args_copy.itemLayer = random.choice(network_layers) | |||
| args_copy.priorType = random.choice(prior_types) | |||
| ensemble_model = Model(args_copy).to(DEVICE) | |||
| ensemble_args.append(args_copy) | |||
| optimizer = optim.Adam(ensemble_model.parameters(), lr=args.lr) | |||
| ensemble_models.append(ensemble_model) | |||
| optimizers.append(optimizer) | |||
| criterion = nn.BCELoss() | |||
| for ensemble_idx in range(args.ensembleSize): | |||
| best_hr = -1 | |||
| best_NDCG = -1 | |||
| best_epoch = -1 | |||
| print(f'Ensemble Model Number {ensemble_idx}') | |||
| print("Start Training!") | |||
| print(ensemble_args[ensemble_idx]) | |||
| classifier = ensemble_models[ensemble_idx] | |||
| optimizer = optimizers[ensemble_idx] | |||
| for epoch in range(args.maxEpochs): | |||
| print("=" * 20 + "Epoch " + str(epoch) + "=" * 20) | |||
| run_epoch(classifier, optimizer, criterion, args) | |||
| print('=' * 50) | |||
| print("Start Evaluation!") | |||
| hr, NDCG = evaluate(classifier, args.topK) | |||
| print("Epoch ", epoch, "HR: {}, NDCG: {}".format(hr, NDCG)) | |||
| if hr > best_hr or NDCG > best_NDCG: | |||
| best_hr = hr | |||
| best_NDCG = NDCG | |||
| best_epoch = epoch | |||
| torch.save(classifier.state_dict(), os.path.join(args.checkPoint, f'model{ensemble_idx}.pth')) | |||
| if epoch - best_epoch > args.earlyStop: | |||
| print("Normal Early stop!") | |||
| break | |||
| print("=" * 20 + "Epoch " + str(epoch) + " End" + "=" * 20) | |||
| print("Best hr: {}, NDCG: {}, At Epoch {}".format(best_hr, best_NDCG, best_epoch)) | |||
| print("Training complete!\n") | |||
| for ensemble_idx in range(args.ensembleSize): | |||
| model_path = os.path.join(args.checkPoint, f'model{ensemble_idx}.pth') | |||
| if os.path.exists(model_path): | |||
| print("Loading saved model from", model_path) | |||
| ensemble_models[ensemble_idx].load_state_dict(torch.load(model_path)) | |||
| else: | |||
| print("No saved model found at", model_path) | |||
| super_model = SuperModel(ensemble_models, input_size=len(ensemble_models)).to(DEVICE) | |||
| super_optimizer = optim.Adam(super_model.parameters(), lr=args.lr) | |||
| for epoch in range(args.maxEpochN): | |||
| train_super_model(super_model, super_optimizer, ensemble_models[0].dataSet.train, args) | |||
| print("\nStart Testing") | |||
| total_hr, total_NDCG = ensemble_eval(ensemble_models, super_model, args.topK) | |||
| print("total hr: {}, total NDCG: {}".format(total_hr, total_NDCG)) | |||
| def train_super_model(super_model, optimizer, train_data, args): | |||
| super_model.train() | |||
| criterion = nn.BCELoss() | |||
| train_u, train_i, train_r = super_model.ensemble_models[0].dataSet.getInstances(train_data, args.negNum) | |||
| train_len = len(train_u) | |||
| shuffled_idx = np.random.permutation(np.arange(train_len)) | |||
| train_u = train_u[shuffled_idx] | |||
| train_i = train_i[shuffled_idx] | |||
| train_r = train_r[shuffled_idx] | |||
| num_batches = (train_len + args.batchSize - 1) // args.batchSize  # ceil division avoids an empty final batch | |||
| losses = [] | |||
| for i in range(num_batches): | |||
| min_idx = i * args.batchSize | |||
| max_idx = min(train_len, (i + 1) * args.batchSize) | |||
| user_tensor = torch.tensor(train_u[min_idx:max_idx], dtype=torch.long).to(DEVICE) | |||
| item_tensor = torch.tensor(train_i[min_idx:max_idx], dtype=torch.long).to(DEVICE) | |||
| rate_tensor = torch.tensor(train_r[min_idx:max_idx], dtype=torch.float32).to(DEVICE) | |||
| rate_tensor = (rate_tensor - rate_tensor.min()) / (rate_tensor.max() - rate_tensor.min()) | |||
| optimizer.zero_grad() | |||
| y_hat = super_model(user_tensor, item_tensor) | |||
| loss = criterion(y_hat, rate_tensor) | |||
| loss.backward() | |||
| optimizer.step() | |||
| losses.append(loss.item()) | |||
| if i % 10 == 0: | |||
| print(f'\rBatch {i}/{num_batches}: loss = {np.mean(losses[-10:]):.4f}', end='') | |||
| print("\nMean loss for super model in this epoch is: {}".format(np.mean(losses))) | |||
| def ensemble_eval(ensemble_models, superModel, topK): | |||
| def getHitRatio(ranklist, targetItem): | |||
| return 1 if targetItem in ranklist else 0 | |||
| def getNDCG(ranklist, targetItem): | |||
| for i, item in enumerate(ranklist): | |||
| if item == targetItem: | |||
| return math.log(2) / math.log(i + 2) | |||
| return 0 | |||
| hr = [] | |||
| NDCG = [] | |||
| testUser = ensemble_models[0].testNeg[0] | |||
| testItem = ensemble_models[0].testNeg[1] | |||
| with torch.no_grad(): | |||
| for i in range(len(testUser)): | |||
| target = testItem[i][0] | |||
| user_tensor = torch.tensor(testUser[i], dtype=torch.long).to(DEVICE) | |||
| item_tensor = torch.tensor(testItem[i], dtype=torch.long).to(DEVICE) | |||
| # ensemble_predicts = [] | |||
| # for model in ensemble_models: | |||
| # predict = model(user_tensor, item_tensor) | |||
| # ensemble_predicts.append(predict) | |||
| total_predict = superModel(user_tensor, item_tensor) | |||
| # print(total_predict) | |||
| item_score_dict = {item: total_predict[j].item() for j, item in enumerate(testItem[i])} | |||
| ranklist = heapq.nlargest(topK, item_score_dict, key=item_score_dict.get) | |||
| tmp_hr = getHitRatio(ranklist, target) | |||
| tmp_NDCG = getNDCG(ranklist, target) | |||
| hr.append(tmp_hr) | |||
| NDCG.append(tmp_NDCG) | |||
| return np.mean(hr), np.mean(NDCG) | |||
| if __name__ == '__main__': | |||
| main() | |||