Browse Source

could skip sequence (p=0.5)

RNN
mohamad maheri 2 years ago
parent
commit
703b6e570b
2 changed files with 59 additions and 14 deletions
  1. 18
    0
      models.py
  2. 41
    14
      sampler.py

+ 18
- 0
models.py View File

@@ -76,6 +76,23 @@ def bpr_max_loss(p_scores, n_values,device):
loss = loss.sum()
return loss

def bpr_max_loss_regularized(p_scores, n_values, device, l=0.0001):
    """Return a BPR-max style ranking loss plus a softmax-weighted L2 penalty.

    Column ``j`` of ``p_scores`` is compared against the ``ratio`` consecutive
    columns ``j * ratio .. (j + 1) * ratio - 1`` of ``n_values``, where
    ``ratio = n_values.shape[1] // p_scores.shape[1]``.

    Args:
        p_scores: 2-D tensor of scores (presumably positive-item scores,
            one row per sample -- confirm against the caller).
        n_values: 2-D tensor with the same number of rows whose column count
            is an integer multiple of ``p_scores``' column count (presumably
            negative-item scores -- confirm against the caller).
        device: Unused. Kept so existing call sites keep working; the result
            lives on the device of the input tensors.
        l: Weight of the regularisation term.

    Returns:
        A scalar tensor: ``sum_rows(-log(sum_cols(softmax(n) * sigmoid(p - n))))
        + l * sum(softmax(n) * n ** 2)``.
    """
    ratio = n_values.shape[1] // p_scores.shape[1]

    # Softmax over each row of n_values; used as the weight of every column in
    # both the ranking term and the penalty term.
    weights = F.softmax(n_values, dim=1)

    # Repeat every p_scores column `ratio` times, side by side, so the result
    # lines up column-for-column with n_values. A single call replaces the
    # per-column torch.cat loop, which recopied the growing tensor each time.
    tiled_p = torch.repeat_interleave(p_scores, ratio, dim=1)

    pairwise = torch.sigmoid(tiled_p - n_values)
    ranking_loss = (-torch.log((weights * pairwise).sum(dim=1))).sum()

    # Penalise large n_values, weighted by their softmax probability.
    penalty = (weights * n_values ** 2).sum(dim=1).sum()

    return ranking_loss + l * penalty

def top_loss(p_scores, n_values,device):
ratio = int(n_values.shape[1] / p_scores.shape[1])
temp_pvalues = torch.tensor([],device=device)
@@ -107,6 +124,7 @@ class MetaTL(nn.Module):

self.embedding_learner = EmbeddingLearner()
# self.loss_func = nn.MarginRankingLoss(self.margin)
# self.loss_func = bpr_max_loss
self.loss_func = bpr_loss

self.rel_q_sharing = dict()

+ 41
- 14
sampler.py View File

@@ -12,8 +12,8 @@ def random_neq(l, r, s, user_train,usernum):
# while t in s:
# t = np.random.randint(l, r)
# return t
user = np.random.randint(1, usernum + 1)
candid_item = user_train[user][np.random.randint(0,len(user_train[user]))]
user = np.random.choice(1, usernum + 1)
candid_item = user_train[user][np.random.randint(0, len(user_train[user]))]

while candid_item in s:
user = np.random.randint(1, usernum + 1)
@@ -21,14 +21,23 @@ def random_neq(l, r, s, user_train,usernum):
return candid_item


def random_negetive_batch(l, r, s, user_train, usernum, batch_users):
    """Sample an item from a random batch user's history that is not in ``s``.

    Repeatedly picks a user from ``batch_users`` and then a uniformly random
    position in ``user_train[user]``, until the item at that position is not
    contained in ``s``.

    Args:
        l: Unused.
        r: Unused.
        s: Container of items to reject (tested with ``in``).
        user_train: Mapping from user to that user's indexable item sequence.
        usernum: Unused.
        batch_users: Sequence of users to draw from.

    Returns:
        The first sampled item that is not in ``s``.

    Note:
        The loop does not terminate if every item of every user in
        ``batch_users`` is contained in ``s``.
    """
    while True:
        # Keep the draw order (user first, then position) so the sequence of
        # values consumed from NumPy's global RNG is unchanged.
        picked_user = np.random.choice(batch_users)
        history = user_train[picked_user]
        item = history[np.random.randint(0, len(history))]
        if item not in s:
            return item


def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED,number_of_neg):
def sample():

if random.random()<=1:
user = np.random.randint(1, usernum + 1)
while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)
def sample(user,batch_users):
if random.random()<=0.5:
# user = np.random.randint(1, usernum + 1)
# while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)

seq = np.zeros([maxlen], dtype=np.int32)
pos = np.zeros([maxlen], dtype=np.int32)
neg = np.zeros([maxlen*number_of_neg], dtype=np.int32)
@@ -51,7 +60,8 @@ def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, resu
if idx == -1: break

for i in range(len(neg)):
neg[i] = random_neq(1, itemnum + 1, ts, user_train,usernum)
# neg[i] = random_neq(1, itemnum + 1, ts, user_train,usernum)
neg[i] = random_negetive_batch(1, itemnum + 1, ts, user_train, usernum, batch_users = batch_users)

curr_rel = user
support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], []
@@ -72,12 +82,13 @@ def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, resu
return support_triples, support_negative_triples, query_triples, negative_triples, curr_rel

else:
user = np.random.randint(1, usernum + 1)
while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)
# print("bug happened in sample_function_mixed")
# user = np.random.randint(1, usernum + 1)
# while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)

seq = np.zeros([maxlen], dtype=np.int32)
pos = np.zeros([maxlen], dtype=np.int32)
neg = np.zeros([maxlen], dtype=np.int32)
neg = np.zeros([maxlen*number_of_neg], dtype=np.int32)

list_idx = random.sample([i for i in range(len(user_train[user]))], maxlen + 1)
list_item = [user_train[user][i] for i in sorted(list_idx)]
@@ -89,16 +100,25 @@ def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, resu
for i in reversed(list_item[:-1]):
seq[idx] = i
pos[idx] = nxt
if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts)
# if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts)
nxt = i
idx -= 1
if idx == -1: break

curr_rel = user
support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], []

for i in range(len(neg)):
# neg[i] = random_neq(1, itemnum + 1, ts, user_train,usernum)
neg[i] = random_negetive_batch(1, itemnum + 1, ts, user_train, usernum, batch_users = batch_users)

for j in range(number_of_neg):
for idx in range(maxlen-1):
support_negative_triples.append([seq[idx], curr_rel, neg[j*maxlen + idx]])

for idx in range(maxlen-1):
support_triples.append([seq[idx],curr_rel,pos[idx]])
support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
# support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
query_triples.append([seq[-1],curr_rel,pos[-1]])
negative_triples.append([seq[-1],curr_rel,neg[-1]])

@@ -108,8 +128,15 @@ def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, resu
while True:
one_batch = []

users = []
for i in range(batch_size):
user = np.random.randint(1, usernum + 1)
while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)
users.append(user)

for i in range(batch_size):
one_batch.append(sample())
one_batch.append(sample(user = users[i], batch_users = users))

support, support_negative, query, negative, curr_rel = zip(*one_batch)


Loading…
Cancel
Save