Browse Source

Distribute NS to all member of the sequence (except last item of seq in support-set)

RNN
mohamad maheri 2 years ago
parent
commit
6af0db43ad
2 changed files with 13 additions and 6 deletions
  1. 6
    3
      sampler.py
  2. 7
    3
      utils.py

+ 6
- 3
sampler.py View File

@@ -31,7 +31,7 @@ def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, resu
seq = np.zeros([maxlen], dtype=np.int32)
pos = np.zeros([maxlen], dtype=np.int32)
neg = np.zeros([maxlen*10], dtype=np.int32)
neg = np.zeros([maxlen*30], dtype=np.int32)

if len(user_train[user]) < maxlen:
nxt_idx = len(user_train[user]) - 1
@@ -60,8 +60,11 @@ def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, resu
# support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
# support_negative_triples.append([seq[-1], curr_rel, neg[idx]])

for idx in range(maxlen*10 - 1):
support_negative_triples.append([seq[-1], curr_rel, neg[idx]])
# for idx in range(maxlen*30 - 1):
# support_negative_triples.append([seq[-1], curr_rel, neg[idx]])
for j in range(30):
for idx in range(maxlen-1):
support_negative_triples.append([seq[idx], curr_rel, neg[j*maxlen + idx]])

query_triples.append([seq[-1],curr_rel,pos[-1]])
negative_triples.append([seq[-1],curr_rel,neg[-1]])

+ 7
- 3
utils.py View File

@@ -118,7 +118,8 @@ class DataLoader(object):
seq = np.zeros([self.maxlen], dtype=np.int32)
pos = np.zeros([self.maxlen - 1], dtype=np.int32)
neg = np.zeros([self.maxlen*10 - 1], dtype=np.int32)
# neg = np.zeros([self.maxlen*30 - 1], dtype=np.int32)
neg = np.zeros([self.maxlen * 30], dtype=np.int32)
idx = self.maxlen - 1

@@ -141,8 +142,11 @@ class DataLoader(object):
# support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
# support_negative_triples.append([seq[-1],curr_rel,neg[idx]])

for idx in range(len(neg)):
support_negative_triples.append([seq[-1],curr_rel,neg[idx]])
# for idx in range(len(neg)):
# support_negative_triples.append([seq[-1],curr_rel,neg[idx]])
for j in range(30):
for idx in range(self.maxlen-1):
support_negative_triples.append([seq[idx], curr_rel, neg[j * self.maxlen + idx]])

rated = ts
rated.add(0)

Loading…
Cancel
Save