Bayesian Deep Ensemble Collaborative Filtering

main.py

# -*- coding: UTF-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import argparse
import os
import heapq
import math
import random

from DataSet import DataSet
from BNN import BayesianLinear
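
# NOTE: the following interfaces are assumed from how they are used below
# (they live in the sibling DataSet / BNN modules, not shown here):
#   - DataSet(name): .shape, .maxRate, .train, .test, .getEmbedding(),
#     .getInstances(data, negNum), .getTestNeg(test, negNum)
#   - BayesianLinear(in_features, out_features, prior_type, device): a
#     Bayes-by-Backprop style layer that presumably resamples its weights on
#     each forward pass and exposes .log_prior / .log_variational_posterior.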

# Set device to GPU if available.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(DEVICE)

# Seed all RNGs for reproducibility.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


class Model(nn.Module):
    def __init__(self, args):
        super(Model, self).__init__()
        self.dataName = args.dataName
        self.dataSet = DataSet(self.dataName)
        self.shape = self.dataSet.shape
        self.maxRate = self.dataSet.maxRate
        self.sample_size = 2
        self.train_data = self.bootstrap_sample(self.sample_size)
        self.test_data = self.dataSet.test
        self.negNum = args.negNum
        self.testNeg = self.dataSet.getTestNeg(self.test_data, 99)
        self.userLayer = args.userLayer
        self.itemLayer = args.itemLayer
        # The full rating matrix doubles as a trainable embedding table:
        # rows index users, columns index items. The item view is obtained by
        # transposing inside forward().
        self.user_item_embedding = nn.Parameter(
            torch.tensor(self.dataSet.getEmbedding(), dtype=torch.float32).to(DEVICE))
        # User tower: one BayesianLinear per width in args.userLayer.
        self.user_layers = nn.ModuleList()
        input_size = self.shape[1]
        for size in self.userLayer:
            self.user_layers.append(BayesianLinear(input_size, size, prior_type=args.priorType, device=DEVICE))
            input_size = size
        # Item tower: one BayesianLinear per width in args.itemLayer. Both
        # towers must end at the same width for the element-wise fusion below.
        self.item_layers = nn.ModuleList()
        input_size = self.shape[0]
        for size in self.itemLayer:
            self.item_layers.append(BayesianLinear(input_size, size, prior_type=args.priorType, device=DEVICE))
            input_size = size
        # Deterministic interaction head on top of the fused representation.
        self.interaction_layer = nn.Sequential(
            nn.Linear(input_size, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1)
        )
        # embed_dim must be divisible by num_heads; every layer config ends in 64.
        self.attention = nn.MultiheadAttention(embed_dim=input_size, num_heads=4, dropout=0.3, batch_first=True)

    def forward(self, user, item):
        user_input = self.user_item_embedding[user]
        # Take the transpose here (not at init) so the item tower always reads
        # a fresh autograd view of the current embedding parameter.
        item_input = self.user_item_embedding.t()[item]
        for layer in self.user_layers:
            user_input = torch.relu(layer(user_input))
        for layer in self.item_layers:
            item_input = torch.relu(layer(item_input))
        user_att = user_input.unsqueeze(1)  # Shape: (batch_size, 1, embed_dim)
        item_att = item_input.unsqueeze(1)  # Shape: (batch_size, 1, embed_dim)
        # Element-wise fusion keeps the sequence length at 1.
        combined = user_att * item_att      # Shape: (batch_size, 1, embed_dim)
        att_output, att_weights = self.attention(
            query=combined,
            key=combined,
            value=combined
        )
        att_output = att_output.mean(dim=1)  # Shape: (batch_size, embed_dim)
        y_hat = self.interaction_layer(att_output)
        y_hat = torch.sigmoid(y_hat)
        # Keep predictions strictly inside (0, 1) so BCELoss never sees log(0).
        return torch.clamp(y_hat.squeeze(), min=1e-6, max=1.0 - 1e-6)
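
    # Note: user and item each contribute one fused token, so the sequence
    # seen by self-attention has length 1 and its softmax weights are
    # trivially 1. Swapping the product for
    # torch.cat([user_att, item_att], dim=1) would instead give the attention
    # a genuine 2-token sequence to mix, at the cost of changing the model.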

    def log_prior(self, tower):
        """Sum of log-prior terms over one tower ('user' or 'item')."""
        if tower == "user":
            return sum(layer.log_prior for layer in self.user_layers)
        return sum(layer.log_prior for layer in self.item_layers)

    def log_variational_posterior(self, tower):
        """Sum of log-variational-posterior terms over one tower."""
        if tower == "user":
            return sum(layer.log_variational_posterior for layer in self.user_layers)
        return sum(layer.log_variational_posterior for layer in self.item_layers)

    def sample_elbo(self, user_tensor, item_tensor, target, num_samples, num_batches):
        """Monte Carlo estimate of the KL part of the ELBO.

        `target` and `num_batches` are kept for interface compatibility but
        are unused here.
        """
        user_log_priors = torch.zeros(num_samples, device=DEVICE)
        user_log_variational_posteriors = torch.zeros(num_samples, device=DEVICE)
        item_log_priors = torch.zeros(num_samples, device=DEVICE)
        item_log_variational_posteriors = torch.zeros(num_samples, device=DEVICE)
        for i in range(num_samples):
            # The forward pass draws a fresh weight sample; its output is not
            # needed beyond triggering that sampling.
            self(user_tensor, item_tensor)
            user_log_priors[i] = self.log_prior(tower="user")
            user_log_variational_posteriors[i] = self.log_variational_posterior(tower="user")
            item_log_priors[i] = self.log_prior(tower="item")
            item_log_variational_posteriors[i] = self.log_variational_posterior(tower="item")
        user_loss = user_log_variational_posteriors.mean() - user_log_priors.mean()
        item_loss = item_log_variational_posteriors.mean() - item_log_priors.mean()
        return user_loss + item_loss
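
    # A minimal sketch of the estimate above: with S weight samples w_s ~ q(w),
    #     KL(q || p) ≈ (1/S) * sum_s [ log q(w_s) - log p(w_s) ],
    # the Monte Carlo KL term of Bayes by Backprop (Blundell et al., 2015),
    # computed per tower and summed.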

    def bootstrap_sample(self, sample_size):
        """
        Generate a bootstrapped dataset by sampling with replacement, so each
        ensemble member trains on its own len(train) / sample_size draw.
        """
        indices = np.random.choice(len(self.dataSet.train),
                                   size=len(self.dataSet.train) // sample_size,
                                   replace=True)
        return [self.dataSet.train[i] for i in indices]
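
# With sample_size = 2, each ensemble member draws n/2 indices with
# replacement from n training interactions, so in expectation it sees about
# 1 - (1 - 1/n)^(n/2) ≈ 39% of the distinct interactions: enough overlap to
# learn the task, enough variation to decorrelate the members.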


class SuperModel(nn.Module):
    """
    A super model that combines predictions from multiple ensemble models
    using a neural network (a learned stacking combiner).
    """

    def __init__(self, ensemble_models, input_size):
        super(SuperModel, self).__init__()
        self.ensemble_models = ensemble_models
        self.combiner = nn.Sequential(
            nn.Linear(input_size, input_size // 2),
            nn.ReLU(),
            nn.Linear(input_size // 2, 1),
            nn.Sigmoid()  # Ensure the output is a probability.
        )

    def forward(self, user, item):
        """
        Forward pass of the super model: run every ensemble member, stack
        their predictions, and combine them with the combiner network.
        """
        ensemble_predictions = []
        with torch.no_grad():  # No gradients flow into the ensemble members.
            for model in self.ensemble_models:
                model.eval()  # Individual members stay in evaluation mode.
                predictions = model(user, item)
                ensemble_predictions.append(predictions)
        # Stack member outputs as input for the combiner network.
        stacked_predictions = torch.stack(ensemble_predictions, dim=1)  # (batch_size, num_ensemble_models)
        combined_predictions = self.combiner(stacked_predictions).squeeze(-1)  # (batch_size,)
        return combined_predictions
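
# Usage sketch: for M trained members and a batch of B (user, item) pairs,
# the stacked member outputs form a (B, M) matrix that the combiner maps to
# a (B,) vector of probabilities, e.g.
#     super_model = SuperModel(ensemble_models, input_size=len(ensemble_models))
#     scores = super_model(user_tensor, item_tensor)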


def run_epoch(model, optimizer, criterion, args):
    model.train()
    train_u, train_i, train_r = model.dataSet.getInstances(model.train_data, args.negNum)
    train_len = len(train_u)
    shuffled_idx = np.random.permutation(np.arange(train_len))
    train_u, train_i, train_r = train_u[shuffled_idx], train_i[shuffled_idx], train_r[shuffled_idx]
    num_batches = (train_len + args.batchSize - 1) // args.batchSize
    BCE_losses, kls = [], []
    for i in range(num_batches):
        min_idx = i * args.batchSize
        max_idx = min(train_len, (i + 1) * args.batchSize)
        user_tensor = torch.tensor(train_u[min_idx:max_idx], dtype=torch.long).to(DEVICE)
        item_tensor = torch.tensor(train_i[min_idx:max_idx], dtype=torch.long).to(DEVICE)
        rate_tensor = torch.tensor(train_r[min_idx:max_idx], dtype=torch.float32).to(DEVICE)
        # Min-max normalise ratings to [0, 1]; guard against a constant batch.
        denom = (rate_tensor.max() - rate_tensor.min()).clamp_min(1e-8)
        rate_tensor = (rate_tensor - rate_tensor.min()) / denom
        optimizer.zero_grad()
        y_hat = model(user_tensor, item_tensor)
        loss = criterion(y_hat, rate_tensor)
        BCE_losses.append(loss.item())
        kl = model.sample_elbo(user_tensor, item_tensor, rate_tensor, 5, num_batches)
        kls.append(kl.item())
        # Down-weight the KL term so it does not swamp the data likelihood.
        kl_coef = 4.42322e-08
        loss = loss + kl_coef * kl
        loss.backward()
        optimizer.step()
        if i % 10 == 0:
            print(f'\rBatch {i}/{num_batches}: KL = {np.mean(kls[-10:]):.4f}, BCE = {np.mean(BCE_losses[-10:]):.4f}', end='')
    print(f"\nMean BCE Loss: {np.mean(BCE_losses):.4f}")
    print(f"Mean KL Divergence: {np.mean(kls):.4f}")
    return np.mean(kls)


def evaluate(model, topK):
    model.eval()
    hr, NDCG = [], []
    with torch.no_grad():
        for i in range(len(model.testNeg[0])):
            # testNeg holds, per test user, the positive item followed by 99
            # sampled negatives; convert to tensors before the forward pass.
            user_tensor = torch.tensor(model.testNeg[0][i], dtype=torch.long).to(DEVICE)
            item_tensor = torch.tensor(model.testNeg[1][i], dtype=torch.long).to(DEVICE)
            predict = model(user_tensor, item_tensor)
            item_score_dict = {item.item(): predict[j].item() for j, item in enumerate(item_tensor)}
            ranklist = heapq.nlargest(topK, item_score_dict, key=item_score_dict.get)
            target = item_tensor[0].item()
            hr.append(1 if target in ranklist else 0)
            NDCG.append(math.log(2) / math.log(ranklist.index(target) + 2) if target in ranklist else 0)
    return np.mean(hr), np.mean(NDCG)
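
# Metrics above, for a ranked list of 1 held-out positive + 99 negatives:
#     HR@K   = 1 if the positive item appears in the top-K, else 0
#     NDCG@K = log(2) / log(rank + 2), with rank the 0-based top-K position
# each averaged over all test users.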


def main():
    parser = argparse.ArgumentParser(description="Options")
    parser.add_argument('-dataName', action='store', dest='dataName', default='ml-100k')
    parser.add_argument('-negNum', action='store', dest='negNum', default=5, type=int)
    parser.add_argument('-userLayer', action='store', dest='userLayer', default=[512, 64, 64], type=int, nargs='+')
    parser.add_argument('-itemLayer', action='store', dest='itemLayer', default=[1024, 64, 64], type=int, nargs='+')
    parser.add_argument('-lr', action='store', dest='lr', default=0.0001, type=float)
    parser.add_argument('-maxEpochs', action='store', dest='maxEpochs', default=50, type=int)
    parser.add_argument('-batchSize', action='store', dest='batchSize', default=256, type=int)
    parser.add_argument('-earlyStop', action='store', dest='earlyStop', default=5, type=int)
    parser.add_argument('-checkPoint', action='store', dest='checkPoint', default='./checkPoint/')
    parser.add_argument('-topK', action='store', dest='topK', default=10, type=int)
    parser.add_argument('-loadModel', action='store_true', dest='loadModel', help="Load a saved model")
    parser.add_argument('-ensembleSize', action='store', dest='ensembleSize', default=10, type=int)
    parser.add_argument('-maxEpochN', action='store', dest='maxEpochN', default=30, type=int)
    parser.add_argument('-priorType', action='store', dest='priorType', default='ScaleMixtureGaussian',
                        choices=['ScaleMixtureGaussian', 'Laplace', 'IsotropicGaussian'])
    args = parser.parse_args()

    os.makedirs(args.checkPoint, exist_ok=True)

    # Build a heterogeneous ensemble: each member gets randomly chosen tower
    # widths and a randomly chosen prior, plus its own bootstrap sample.
    ensemble_models = []
    optimizers = []
    ensemble_args = []
    network_layers = [[512, 64, 64], [512, 64], [1024, 64, 64], [512, 256, 64], [1024, 256, 256, 64]]
    prior_types = ['ScaleMixtureGaussian', 'Laplace', 'IsotropicGaussian']
    for ensemble_idx in range(args.ensembleSize):
        args_copy = argparse.Namespace(**vars(args))
        args_copy.userLayer = random.choice(network_layers)
        args_copy.itemLayer = random.choice(network_layers)
        args_copy.priorType = random.choice(prior_types)
        ensemble_model = Model(args_copy).to(DEVICE)
        ensemble_args.append(args_copy)
        optimizer = optim.Adam(ensemble_model.parameters(), lr=args.lr)
        ensemble_models.append(ensemble_model)
        optimizers.append(optimizer)
    criterion = nn.BCELoss()

    # Train each member independently, with early stopping on HR/NDCG.
    for ensemble_idx in range(args.ensembleSize):
        best_hr = -1
        best_NDCG = -1
        best_epoch = -1
        print(f'Ensemble Model Number {ensemble_idx}')
        print("Start Training!")
        print(ensemble_args[ensemble_idx])
        classifier = ensemble_models[ensemble_idx]
        optimizer = optimizers[ensemble_idx]
        for epoch in range(args.maxEpochs):
            print("=" * 20 + "Epoch " + str(epoch) + "=" * 20)
            run_epoch(classifier, optimizer, criterion, args)
            print('=' * 50)
            print("Start Evaluation!")
            hr, NDCG = evaluate(classifier, args.topK)
            print("Epoch ", epoch, "HR: {}, NDCG: {}".format(hr, NDCG))
            if hr > best_hr or NDCG > best_NDCG:
                best_hr = hr
                best_NDCG = NDCG
                best_epoch = epoch
                torch.save(classifier.state_dict(), os.path.join(args.checkPoint, f'model{ensemble_idx}.pth'))
            if epoch - best_epoch > args.earlyStop:
                print("Normal Early stop!")
                break
            print("=" * 20 + "Epoch " + str(epoch) + " End" + "=" * 20)
        print("Best hr: {}, NDCG: {}, At Epoch {}".format(best_hr, best_NDCG, best_epoch))
        print("Training complete!\n")

    # Reload each member's best checkpoint before fitting the combiner.
    for ensemble_idx in range(args.ensembleSize):
        model_path = os.path.join(args.checkPoint, f'model{ensemble_idx}.pth')
        if os.path.exists(model_path):
            print("Loading saved model from", model_path)
            ensemble_models[ensemble_idx].load_state_dict(torch.load(model_path))
        else:
            print("No saved model found at", model_path)

    # Fit the stacking combiner; create its optimizer once so Adam's state
    # persists across epochs.
    super_model = SuperModel(ensemble_models, input_size=len(ensemble_models)).to(DEVICE)
    super_optimizer = optim.Adam(super_model.parameters(), lr=args.lr)
    for epoch in range(args.maxEpochN):
        train_super_model(super_model, super_optimizer, ensemble_models[0].dataSet.train, args)

    print("\nStart Testing")
    total_hr, total_NDCG = ensemble_eval(ensemble_models, super_model, args.topK)
    print("total hr: {}, total NDCG: {}".format(total_hr, total_NDCG))


def train_super_model(super_model, optimizer, train_data, args):
    super_model.train()
    criterion = nn.BCELoss()
    train_u, train_i, train_r = super_model.ensemble_models[0].dataSet.getInstances(train_data, args.negNum)
    train_len = len(train_u)
    shuffled_idx = np.random.permutation(np.arange(train_len))
    train_u = train_u[shuffled_idx]
    train_i = train_i[shuffled_idx]
    train_r = train_r[shuffled_idx]
    num_batches = (train_len + args.batchSize - 1) // args.batchSize
    losses = []
    for i in range(num_batches):
        min_idx = i * args.batchSize
        max_idx = min(train_len, (i + 1) * args.batchSize)
        user_tensor = torch.tensor(train_u[min_idx:max_idx], dtype=torch.long).to(DEVICE)
        item_tensor = torch.tensor(train_i[min_idx:max_idx], dtype=torch.long).to(DEVICE)
        rate_tensor = torch.tensor(train_r[min_idx:max_idx], dtype=torch.float32).to(DEVICE)
        # Same min-max normalisation (and constant-batch guard) as run_epoch.
        denom = (rate_tensor.max() - rate_tensor.min()).clamp_min(1e-8)
        rate_tensor = (rate_tensor - rate_tensor.min()) / denom
        optimizer.zero_grad()
        y_hat = super_model(user_tensor, item_tensor)
        loss = criterion(y_hat, rate_tensor)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if i % 10 == 0:
            print(f'\rBatch {i}/{num_batches}: loss = {np.mean(losses[-10:]):.4f}', end='')
    print("\nMean loss for super model in this epoch is: {}".format(np.mean(losses)))


def ensemble_eval(ensemble_models, superModel, topK):
    def getHitRatio(ranklist, targetItem):
        return 1 if targetItem in ranklist else 0

    def getNDCG(ranklist, targetItem):
        for i, item in enumerate(ranklist):
            if item == targetItem:
                return math.log(2) / math.log(i + 2)
        return 0

    hr = []
    NDCG = []
    # All members share the same test negatives, so any member's testNeg works.
    testUser = ensemble_models[0].testNeg[0]
    testItem = ensemble_models[0].testNeg[1]
    superModel.eval()
    with torch.no_grad():
        for i in range(len(testUser)):
            target = testItem[i][0]
            user_tensor = torch.tensor(testUser[i], dtype=torch.long).to(DEVICE)
            item_tensor = torch.tensor(testItem[i], dtype=torch.long).to(DEVICE)
            total_predict = superModel(user_tensor, item_tensor)
            item_score_dict = {item: total_predict[j].item() for j, item in enumerate(testItem[i])}
            ranklist = heapq.nlargest(topK, item_score_dict, key=item_score_dict.get)
            hr.append(getHitRatio(ranklist, target))
            NDCG.append(getNDCG(ranklist, target))
    return np.mean(hr), np.mean(NDCG)


if __name__ == '__main__':
    main()