Browse Source

repair bug of skipping in sequence

RNN
mohamad maheri 2 years ago
parent
commit
dbbec9facc
2 changed files with 28 additions and 12 deletions
  1. 23
    11
      sampler.py
  2. 5
    1
      utils.py

+ 23
- 11
sampler.py View File

@@ -12,14 +12,26 @@ def random_neq(l, r, s, user_train,usernum):
# while t in s:
# t = np.random.randint(l, r)
# return t
user = np.random.choice(1, usernum + 1)
candid_item = user_train[user][np.random.randint(0, len(user_train[user]))]

while candid_item in s:
user = np.random.randint(1, usernum + 1)
candid_item = user_train[user][np.random.randint(0, len(user_train[user]))]
return candid_item

# user = np.random.choice(1, usernum + 1)

# user = random.randint(1,usernum+1)
# while len(user_train[user])<3:
# user = random.randint(1, usernum + 1)
# candid_item = user_train[user][random.randint(0, len(user_train[user])-1)]
#
# while candid_item in s:
# while len(user_train[user]) < 3:
# user = random.randint(1, usernum + 1)
# candid_item = user_train[user][random.randint(0, len(user_train[user])-1)]
# return candid_item

user = random.choice(list(user_train.keys()))
item = random.choice(user_train[user])

while item in s:
user = random.choice(list(user_train.keys()))
item = random.choice(user_train[user])
return item

def random_negetive_batch(l, r, s, user_train,usernum, batch_users):
user = np.random.choice(batch_users)
@@ -40,7 +52,7 @@ def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, resu

seq = np.zeros([maxlen], dtype=np.int32)
pos = np.zeros([maxlen], dtype=np.int32)
neg = np.zeros([maxlen*number_of_neg], dtype=np.int32)
neg = np.zeros([(maxlen-1)*number_of_neg + 1], dtype=np.int32)

if len(user_train[user]) < maxlen:
nxt_idx = len(user_train[user]) - 1
@@ -51,7 +63,7 @@ def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, resu
idx = maxlen - 1

ts = set(user_train[user])
for i in reversed(user_train[user][min(0, nxt_idx - 1 - maxlen) : nxt_idx - 1]):
for i in reversed(user_train[user][(nxt_idx - maxlen) : nxt_idx ]):
seq[idx] = i
pos[idx] = nxt
# if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts, user_train,usernum)
@@ -74,7 +86,7 @@ def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, resu
# support_negative_triples.append([seq[-1], curr_rel, neg[idx]])
for j in range(number_of_neg):
for idx in range(maxlen-1):
support_negative_triples.append([seq[idx], curr_rel, neg[j*maxlen + idx]])
support_negative_triples.append([seq[idx], curr_rel, neg[j*(maxlen-1) + idx]])

query_triples.append([seq[-1],curr_rel,pos[-1]])
negative_triples.append([seq[-1],curr_rel,neg[-1]])

+ 5
- 1
utils.py View File

@@ -148,9 +148,13 @@ class DataLoader(object):

# for idx in range(len(neg)):
# support_negative_triples.append([seq[-1],curr_rel,neg[idx]])
# print("injam",self.maxlen,list(range(self.maxlen-1)))
# print("====")
for j in range(self.number_of_neg):
for idx in range(self.maxlen-1):
support_negative_triples.append([seq[idx], curr_rel, neg[j * self.maxlen + idx]])
# print(j * self.maxlen + idx)
support_negative_triples.append([seq[idx], curr_rel, neg[j * (self.maxlen-1) + idx]])
# print("=end=\n\n")

rated = ts
rated.add(0)

Loading…
Cancel
Save