@@ -100,7 +100,10 @@ class MetaTL(nn.Module): | |||
y = torch.Tensor([1]).to(self.device) | |||
self.zero_grad() | |||
loss = self.loss_func(p_score, n_score, y) | |||
sorted,indecies = torch.sort(n_score, descending=True,dim=1) | |||
n_values = sorted[:,0:p_score.shape[1]] | |||
loss = self.loss_func(p_score, n_values, y) | |||
loss.backward(retain_graph=True) | |||
grad_meta = rel.grad |
@@ -31,7 +31,7 @@ def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, resu | |||
seq = np.zeros([maxlen], dtype=np.int32) | |||
pos = np.zeros([maxlen], dtype=np.int32) | |||
neg = np.zeros([maxlen], dtype=np.int32) | |||
neg = np.zeros([maxlen*10], dtype=np.int32) | |||
if len(user_train[user]) < maxlen: | |||
nxt_idx = len(user_train[user]) - 1 | |||
@@ -45,20 +45,24 @@ def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, resu | |||
for i in reversed(user_train[user][min(0, nxt_idx - 1 - maxlen) : nxt_idx - 1]): | |||
seq[idx] = i | |||
pos[idx] = nxt | |||
if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts, user_train,usernum) | |||
# if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts, user_train,usernum) | |||
nxt = i | |||
idx -= 1 | |||
if idx == -1: break | |||
# for i in range(len(neg)): | |||
# neg[i] = random_neq(1, itemnum + 1, ts, user_train,usernum) | |||
for i in range(len(neg)): | |||
neg[i] = random_neq(1, itemnum + 1, ts, user_train,usernum) | |||
curr_rel = user | |||
support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], [] | |||
for idx in range(maxlen-1): | |||
support_triples.append([seq[idx],curr_rel,pos[idx]]) | |||
# support_negative_triples.append([seq[idx],curr_rel,neg[idx]]) | |||
# support_negative_triples.append([seq[-1], curr_rel, neg[idx]]) | |||
for idx in range(maxlen*10 - 1): | |||
support_negative_triples.append([seq[-1], curr_rel, neg[idx]]) | |||
query_triples.append([seq[-1],curr_rel,pos[-1]]) | |||
negative_triples.append([seq[-1],curr_rel,neg[-1]]) | |||
@@ -118,7 +118,7 @@ class DataLoader(object): | |||
seq = np.zeros([self.maxlen], dtype=np.int32) | |||
pos = np.zeros([self.maxlen - 1], dtype=np.int32) | |||
neg = np.zeros([self.maxlen - 1], dtype=np.int32) | |||
neg = np.zeros([self.maxlen*10 - 1], dtype=np.int32) | |||
idx = self.maxlen - 1 | |||
@@ -127,15 +127,21 @@ class DataLoader(object): | |||
seq[idx] = i | |||
if idx > 0: | |||
pos[idx - 1] = i | |||
if i != 0: neg[idx - 1] = random_neq(1, self.itemnum + 1, ts,self.train) | |||
# if i != 0: neg[idx - 1] = random_neq(1, self.itemnum + 1, ts,self.train) | |||
idx -= 1 | |||
if idx == -1: break | |||
for i in range(len(neg)): | |||
neg[i] = random_neq(1, self.itemnum + 1, ts,self.train) | |||
curr_rel = u | |||
support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], [] | |||
for idx in range(self.maxlen-1): | |||
support_triples.append([seq[idx],curr_rel,pos[idx]]) | |||
# support_negative_triples.append([seq[idx],curr_rel,neg[idx]]) | |||
# support_negative_triples.append([seq[-1],curr_rel,neg[idx]]) | |||
for idx in range(len(neg)): | |||
support_negative_triples.append([seq[-1],curr_rel,neg[idx]]) | |||
rated = ts |