Adapted to the MovieLens dataset.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

embeddings_TaNP.py 14KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.init as init
  4. import torch.nn.functional as F
  5. class Item(torch.nn.Module):
  6. def __init__(self, config):
  7. super(Item, self).__init__()
  8. self.feature_dim = config['if_dim']
  9. self.first_embedding_dim = config['first_embedding_dim']
  10. self.second_embedding_dim = config['second_embedding_dim']
  11. self.first_embedding_layer = torch.nn.Linear(
  12. in_features=self.feature_dim,
  13. out_features=self.first_embedding_dim,
  14. bias=True
  15. )
  16. self.second_embedding_layer = torch.nn.Linear(
  17. in_features=self.first_embedding_dim,
  18. out_features=self.second_embedding_dim,
  19. bias=True
  20. )
  21. def forward(self, x, vars=None):
  22. first_hidden = self.first_embedding_layer(x)
  23. first_hidden = F.relu(first_hidden)
  24. sec_hidden = self.second_embedding_layer(first_hidden)
  25. return F.relu(sec_hidden)
  26. class Movie_item(torch.nn.Module):
  27. def __init__(self, config):
  28. super(Movie_item, self).__init__()
  29. self.num_rate = config['num_rate']
  30. self.num_genre = config['num_genre']
  31. self.num_director = config['num_director']
  32. self.num_actor = config['num_actor']
  33. self.embedding_dim = config['embedding_dim']
  34. # change for Movie
  35. self.feature_dim = 4 * self.embedding_dim
  36. self.embedding_rate = torch.nn.Embedding(
  37. num_embeddings=self.num_rate,
  38. embedding_dim=self.embedding_dim
  39. )
  40. self.embedding_genre = torch.nn.Linear(
  41. in_features=self.num_genre,
  42. out_features=self.embedding_dim,
  43. bias=False
  44. )
  45. self.embedding_director = torch.nn.Linear(
  46. in_features=self.num_director,
  47. out_features=self.embedding_dim,
  48. bias=False
  49. )
  50. self.embedding_actor = torch.nn.Linear(
  51. in_features=self.num_actor,
  52. out_features=self.embedding_dim,
  53. bias=False
  54. )
  55. def forward(self, rate_idx, genre_idx, director_idx, actors_idx, vars=None):
  56. rate_emb = self.embedding_rate(rate_idx)
  57. genre_emb = self.embedding_genre(genre_idx.float()) / torch.sum(genre_idx.float(), 1).view(-1, 1)
  58. director_emb = self.embedding_director(director_idx.float()) / torch.sum(director_idx.float(), 1).view(-1, 1)
  59. actors_emb = self.embedding_actor(actors_idx.float()) / torch.sum(actors_idx.float(), 1).view(-1, 1)
  60. return torch.cat((rate_emb, genre_emb, director_emb, actors_emb), 1)
  61. class User(torch.nn.Module):
  62. def __init__(self, config):
  63. super(User, self).__init__()
  64. self.feature_dim = config['uf_dim']
  65. self.first_embedding_dim = config['first_embedding_dim']
  66. self.second_embedding_dim = config['second_embedding_dim']
  67. self.first_embedding_layer = torch.nn.Linear(
  68. in_features=self.feature_dim,
  69. out_features=self.first_embedding_dim,
  70. bias=True
  71. )
  72. self.second_embedding_layer = torch.nn.Linear(
  73. in_features=self.first_embedding_dim,
  74. out_features=self.second_embedding_dim,
  75. bias=True
  76. )
  77. def forward(self, x, vars=None):
  78. first_hidden = self.first_embedding_layer(x)
  79. first_hidden = F.relu(first_hidden)
  80. sec_hidden = self.second_embedding_layer(first_hidden)
  81. return F.relu(sec_hidden)
  82. class Movie_user(torch.nn.Module):
  83. def __init__(self, config):
  84. super(Movie_user, self).__init__()
  85. self.num_gender = config['num_gender']
  86. self.num_age = config['num_age']
  87. self.num_occupation = config['num_occupation']
  88. self.num_zipcode = config['num_zipcode']
  89. self.embedding_dim = config['embedding_dim']
  90. self.embedding_gender = torch.nn.Embedding(
  91. num_embeddings=self.num_gender,
  92. embedding_dim=self.embedding_dim
  93. )
  94. self.embedding_age = torch.nn.Embedding(
  95. num_embeddings=self.num_age,
  96. embedding_dim=self.embedding_dim
  97. )
  98. self.embedding_occupation = torch.nn.Embedding(
  99. num_embeddings=self.num_occupation,
  100. embedding_dim=self.embedding_dim
  101. )
  102. self.embedding_area = torch.nn.Embedding(
  103. num_embeddings=self.num_zipcode,
  104. embedding_dim=self.embedding_dim
  105. )
  106. def forward(self, gender_idx, age_idx, occupation_idx, area_idx):
  107. gender_emb = self.embedding_gender(gender_idx)
  108. age_emb = self.embedding_age(age_idx)
  109. occupation_emb = self.embedding_occupation(occupation_idx)
  110. area_emb = self.embedding_area(area_idx)
  111. return torch.cat((gender_emb, age_emb, occupation_emb, area_emb), 1)
  112. class Encoder(nn.Module):
  113. #Maps an (x_i, y_i) pair to a representation r_i.
  114. # Add the dropout into encoder ---03.31
  115. def __init__(self, x_dim, y_dim, h1_dim, h2_dim, z1_dim, dropout_rate):
  116. super(Encoder, self).__init__()
  117. self.x_dim = x_dim
  118. self.y_dim = y_dim
  119. self.h1_dim = h1_dim
  120. self.h2_dim = h2_dim
  121. self.z1_dim = z1_dim
  122. self.dropout_rate = dropout_rate
  123. layers = [nn.Linear(self.x_dim + self.y_dim, self.h1_dim),
  124. torch.nn.Dropout(self.dropout_rate),
  125. nn.ReLU(inplace=True),
  126. nn.Linear(self.h1_dim, self.h2_dim),
  127. torch.nn.Dropout(self.dropout_rate),
  128. nn.ReLU(inplace=True),
  129. nn.Linear(self.h2_dim, self.z1_dim)]
  130. self.input_to_hidden = nn.Sequential(*layers)
  131. def forward(self, x, y):
  132. y = y.view(-1, 1)
  133. input_pairs = torch.cat((x, y), dim=1)
  134. return self.input_to_hidden(input_pairs)
  135. class MuSigmaEncoder(nn.Module):
  136. def __init__(self, z1_dim, z2_dim, z_dim):
  137. super(MuSigmaEncoder, self).__init__()
  138. self.z1_dim = z1_dim
  139. self.z2_dim = z2_dim
  140. self.z_dim = z_dim
  141. self.z_to_hidden = nn.Linear(self.z1_dim, self.z2_dim)
  142. self.hidden_to_mu = nn.Linear(self.z2_dim, z_dim)
  143. self.hidden_to_logsigma = nn.Linear(self.z2_dim, z_dim)
  144. def forward(self, z_input):
  145. hidden = torch.relu(self.z_to_hidden(z_input))
  146. mu = self.hidden_to_mu(hidden)
  147. log_sigma = self.hidden_to_logsigma(hidden)
  148. std = torch.exp(0.5 * log_sigma)
  149. eps = torch.randn_like(std)
  150. z = eps.mul(std).add_(mu)
  151. return mu, log_sigma, z
  152. class TaskEncoder(nn.Module):
  153. def __init__(self, x_dim, y_dim, h1_dim, h2_dim, final_dim, dropout_rate):
  154. super(TaskEncoder, self).__init__()
  155. self.x_dim = x_dim
  156. self.y_dim = y_dim
  157. self.h1_dim = h1_dim
  158. self.h2_dim = h2_dim
  159. self.final_dim = final_dim
  160. self.dropout_rate = dropout_rate
  161. layers = [nn.Linear(self.x_dim + self.y_dim, self.h1_dim),
  162. torch.nn.Dropout(self.dropout_rate),
  163. nn.ReLU(inplace=True),
  164. nn.Linear(self.h1_dim, self.h2_dim),
  165. torch.nn.Dropout(self.dropout_rate),
  166. nn.ReLU(inplace=True),
  167. nn.Linear(self.h2_dim, self.final_dim)]
  168. self.input_to_hidden = nn.Sequential(*layers)
  169. def forward(self, x, y):
  170. y = y.view(-1, 1)
  171. input_pairs = torch.cat((x, y), dim=1)
  172. return self.input_to_hidden(input_pairs)
  173. class MemoryUnit(nn.Module):
  174. # clusters_k is k keys
  175. def __init__(self, clusters_k, emb_size, temperature):
  176. super(MemoryUnit, self).__init__()
  177. self.clusters_k = clusters_k
  178. self.embed_size = emb_size
  179. self.temperature = temperature
  180. self.array = nn.Parameter(init.xavier_uniform_(torch.FloatTensor(self.clusters_k, self.embed_size)))
  181. def forward(self, task_embed):
  182. res = torch.norm(task_embed-self.array, p=2, dim=1, keepdim=True)
  183. res = torch.pow((res / self.temperature) + 1, (self.temperature + 1) / -2)
  184. # 1*k
  185. C = torch.transpose(res / res.sum(), 0, 1)
  186. # 1*k, k*d, 1*d
  187. value = torch.mm(C, self.array)
  188. # simple add operation
  189. new_task_embed = value + task_embed
  190. # calculate target distribution
  191. return C, new_task_embed
  192. class Decoder(nn.Module):
  193. """
  194. Maps target input x_target and z, r to predictions y_target.
  195. """
  196. def __init__(self, x_dim, z_dim, task_dim, h1_dim, h2_dim, h3_dim, y_dim, dropout_rate):
  197. super(Decoder, self).__init__()
  198. self.x_dim = x_dim
  199. self.z_dim = z_dim
  200. self.task_dim = task_dim
  201. self.h1_dim = h1_dim
  202. self.h2_dim = h2_dim
  203. self.h3_dim = h3_dim
  204. self.y_dim = y_dim
  205. self.dropout_rate = dropout_rate
  206. self.dropout = nn.Dropout(self.dropout_rate)
  207. self.hidden_layer_1 = nn.Linear(self.x_dim + self.z_dim, self.h1_dim)
  208. self.hidden_layer_2 = nn.Linear(self.h1_dim, self.h2_dim)
  209. self.hidden_layer_3 = nn.Linear(self.h2_dim, self.h3_dim)
  210. self.film_layer_1_beta = nn.Linear(self.task_dim, self.h1_dim, bias=False)
  211. self.film_layer_1_gamma = nn.Linear(self.task_dim, self.h1_dim, bias=False)
  212. self.film_layer_2_beta = nn.Linear(self.task_dim, self.h2_dim, bias=False)
  213. self.film_layer_2_gamma = nn.Linear(self.task_dim, self.h2_dim, bias=False)
  214. self.film_layer_3_beta = nn.Linear(self.task_dim, self.h3_dim, bias=False)
  215. self.film_layer_3_gamma = nn.Linear(self.task_dim, self.h3_dim, bias=False)
  216. self.final_projection = nn.Linear(self.h3_dim, self.y_dim)
  217. def forward(self, x, z, task):
  218. interaction_size, _ = x.size()
  219. z = z.unsqueeze(0).repeat(interaction_size, 1)
  220. # Input is concatenation of z with every row of x
  221. inputs = torch.cat((x, z), dim=1)
  222. hidden_1 = self.hidden_layer_1(inputs)
  223. beta_1 = torch.tanh(self.film_layer_1_beta(task))
  224. gamma_1 = torch.tanh(self.film_layer_1_gamma(task))
  225. hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
  226. hidden_1 = self.dropout(hidden_1)
  227. hidden_2 = F.relu(hidden_1)
  228. hidden_2 = self.hidden_layer_2(hidden_2)
  229. beta_2 = torch.tanh(self.film_layer_2_beta(task))
  230. gamma_2 = torch.tanh(self.film_layer_2_gamma(task))
  231. hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
  232. hidden_2 = self.dropout(hidden_2)
  233. hidden_3 = F.relu(hidden_2)
  234. hidden_3 = self.hidden_layer_3(hidden_3)
  235. beta_3 = torch.tanh(self.film_layer_3_beta(task))
  236. gamma_3 = torch.tanh(self.film_layer_3_gamma(task))
  237. hidden_final = torch.mul(hidden_3, gamma_3) + beta_3
  238. hidden_final = self.dropout(hidden_final)
  239. hidden_final = F.relu(hidden_final)
  240. y_pred = self.final_projection(hidden_final)
  241. return y_pred
  242. class Gating_Decoder(nn.Module):
  243. def __init__(self, x_dim, z_dim, task_dim, h1_dim, h2_dim, h3_dim, y_dim, dropout_rate):
  244. super(Gating_Decoder, self).__init__()
  245. self.x_dim = x_dim
  246. self.z_dim = z_dim
  247. self.task_dim = task_dim
  248. self.h1_dim = h1_dim
  249. self.h2_dim = h2_dim
  250. self.h3_dim = h3_dim
  251. self.y_dim = y_dim
  252. self.dropout_rate = dropout_rate
  253. self.dropout = nn.Dropout(self.dropout_rate)
  254. self.hidden_layer_1 = nn.Linear(self.x_dim + self.z_dim, self.h1_dim)
  255. self.hidden_layer_2 = nn.Linear(self.h1_dim, self.h2_dim)
  256. self.hidden_layer_3 = nn.Linear(self.h2_dim, self.h3_dim)
  257. self.film_layer_1_beta = nn.Linear(self.task_dim, self.h1_dim, bias=False)
  258. self.film_layer_1_gamma = nn.Linear(self.task_dim, self.h1_dim, bias=False)
  259. self.film_layer_1_eta = nn.Linear(self.task_dim, self.h1_dim, bias=False)
  260. self.film_layer_1_delta = nn.Linear(self.task_dim, self.h1_dim, bias=False)
  261. self.film_layer_2_beta = nn.Linear(self.task_dim, self.h2_dim, bias=False)
  262. self.film_layer_2_gamma = nn.Linear(self.task_dim, self.h2_dim, bias=False)
  263. self.film_layer_2_eta = nn.Linear(self.task_dim, self.h2_dim, bias=False)
  264. self.film_layer_2_delta = nn.Linear(self.task_dim, self.h2_dim, bias=False)
  265. self.film_layer_3_beta = nn.Linear(self.task_dim, self.h3_dim, bias=False)
  266. self.film_layer_3_gamma = nn.Linear(self.task_dim, self.h3_dim, bias=False)
  267. self.film_layer_3_eta = nn.Linear(self.task_dim, self.h3_dim, bias=False)
  268. self.film_layer_3_delta = nn.Linear(self.task_dim, self.h3_dim, bias=False)
  269. self.final_projection = nn.Linear(self.h3_dim, self.y_dim)
  270. def forward(self, x, z, task):
  271. interaction_size, _ = x.size()
  272. z = z.unsqueeze(0).repeat(interaction_size, 1)
  273. # Input is concatenation of z with every row of x
  274. inputs = torch.cat((x, z), dim=1)
  275. hidden_1 = self.hidden_layer_1(inputs)
  276. beta_1 = torch.tanh(self.film_layer_1_beta(task))
  277. gamma_1 = torch.tanh(self.film_layer_1_gamma(task))
  278. eta_1 = torch.tanh(self.film_layer_1_eta(task))
  279. delta_1 = torch.sigmoid(self.film_layer_1_delta(task))
  280. gamma_1 = gamma_1 * delta_1 + eta_1 * (1-delta_1)
  281. beta_1 = beta_1 * delta_1 + eta_1 * (1-delta_1)
  282. hidden_1 = torch.mul(hidden_1, gamma_1) + beta_1
  283. hidden_1 = self.dropout(hidden_1)
  284. hidden_2 = F.relu(hidden_1)
  285. hidden_2 = self.hidden_layer_2(hidden_2)
  286. beta_2 = torch.tanh(self.film_layer_2_beta(task))
  287. gamma_2 = torch.tanh(self.film_layer_2_gamma(task))
  288. eta_2 = torch.tanh(self.film_layer_2_eta(task))
  289. delta_2 = torch.sigmoid(self.film_layer_2_delta(task))
  290. gamma_2 = gamma_2 * delta_2 + eta_2 * (1 - delta_2)
  291. beta_2 = beta_2 * delta_2 + eta_2 * (1 - delta_2)
  292. hidden_2 = torch.mul(hidden_2, gamma_2) + beta_2
  293. hidden_2 = self.dropout(hidden_2)
  294. hidden_3 = F.relu(hidden_2)
  295. hidden_3 = self.hidden_layer_3(hidden_3)
  296. beta_3 = torch.tanh(self.film_layer_3_beta(task))
  297. gamma_3 = torch.tanh(self.film_layer_3_gamma(task))
  298. eta_3 = torch.tanh(self.film_layer_3_eta(task))
  299. delta_3 = torch.sigmoid(self.film_layer_3_delta(task))
  300. gamma_3 = gamma_3 * delta_3 + eta_3 * (1 - delta_3)
  301. beta_3 = beta_3 * delta_3 + eta_3 * (1 - delta_3)
  302. hidden_final = torch.mul(hidden_3, gamma_3) + beta_3
  303. hidden_final = self.dropout(hidden_final)
  304. hidden_final = F.relu(hidden_final)
  305. y_pred = self.final_projection(hidden_final)
  306. return y_pred