Browse Source

could skip sequence (p=0.5)

RNN
mohamad maheri 2 years ago
parent
commit
703b6e570b
2 changed files with 59 additions and 14 deletions
  1. 18
    0
      models.py
  2. 41
    14
      sampler.py

+ 18
- 0
models.py View File

loss = loss.sum() loss = loss.sum()
return loss return loss


def bpr_max_loss_regularized(p_scores, n_values,device,l=0.0001):
s = F.softmax(n_values,dim=1)
ratio = int(n_values.shape[1] / p_scores.shape[1])
temp_pvalues = torch.tensor([],device=device)
for i in range(p_scores.shape[1]):
temp_pvalues = torch.cat((temp_pvalues,p_scores[:,i,None].expand(-1,ratio)),dim=1)

d = torch.sigmoid(torch.sub(temp_pvalues,n_values))
t = torch.mul(s,d)
loss = -1 * torch.log(t.sum(dim=1))
loss = loss.sum()

loss2 = torch.mul(s,n_values**2)
loss2 = loss2.sum(dim=1)
loss2 = loss2.sum()
return loss + l*loss2

def top_loss(p_scores, n_values,device): def top_loss(p_scores, n_values,device):
ratio = int(n_values.shape[1] / p_scores.shape[1]) ratio = int(n_values.shape[1] / p_scores.shape[1])
temp_pvalues = torch.tensor([],device=device) temp_pvalues = torch.tensor([],device=device)


self.embedding_learner = EmbeddingLearner() self.embedding_learner = EmbeddingLearner()
# self.loss_func = nn.MarginRankingLoss(self.margin) # self.loss_func = nn.MarginRankingLoss(self.margin)
# self.loss_func = bpr_max_loss
self.loss_func = bpr_loss self.loss_func = bpr_loss


self.rel_q_sharing = dict() self.rel_q_sharing = dict()

+ 41
- 14
sampler.py View File

# while t in s: # while t in s:
# t = np.random.randint(l, r) # t = np.random.randint(l, r)
# return t # return t
user = np.random.randint(1, usernum + 1)
candid_item = user_train[user][np.random.randint(0,len(user_train[user]))]
user = np.random.choice(1, usernum + 1)
candid_item = user_train[user][np.random.randint(0, len(user_train[user]))]


while candid_item in s: while candid_item in s:
user = np.random.randint(1, usernum + 1) user = np.random.randint(1, usernum + 1)
return candid_item return candid_item




def random_negetive_batch(l, r, s, user_train,usernum, batch_users):
user = np.random.choice(batch_users)
candid_item = user_train[user][np.random.randint(0, len(user_train[user]))]

while candid_item in s:
user = np.random.choice(batch_users)
candid_item = user_train[user][np.random.randint(0, len(user_train[user]))]
return candid_item


def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED,number_of_neg): def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED,number_of_neg):
def sample():


if random.random()<=1:
user = np.random.randint(1, usernum + 1)
while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)
def sample(user,batch_users):
if random.random()<=0.5:
# user = np.random.randint(1, usernum + 1)
# while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)


seq = np.zeros([maxlen], dtype=np.int32) seq = np.zeros([maxlen], dtype=np.int32)
pos = np.zeros([maxlen], dtype=np.int32) pos = np.zeros([maxlen], dtype=np.int32)
neg = np.zeros([maxlen*number_of_neg], dtype=np.int32) neg = np.zeros([maxlen*number_of_neg], dtype=np.int32)
if idx == -1: break if idx == -1: break


for i in range(len(neg)): for i in range(len(neg)):
neg[i] = random_neq(1, itemnum + 1, ts, user_train,usernum)
# neg[i] = random_neq(1, itemnum + 1, ts, user_train,usernum)
neg[i] = random_negetive_batch(1, itemnum + 1, ts, user_train, usernum, batch_users = batch_users)


curr_rel = user curr_rel = user
support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], [] support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], []
return support_triples, support_negative_triples, query_triples, negative_triples, curr_rel return support_triples, support_negative_triples, query_triples, negative_triples, curr_rel


else: else:
user = np.random.randint(1, usernum + 1)
while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)
# print("bug happened in sample_function_mixed")
# user = np.random.randint(1, usernum + 1)
# while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)


seq = np.zeros([maxlen], dtype=np.int32) seq = np.zeros([maxlen], dtype=np.int32)
pos = np.zeros([maxlen], dtype=np.int32) pos = np.zeros([maxlen], dtype=np.int32)
neg = np.zeros([maxlen], dtype=np.int32)
neg = np.zeros([maxlen*number_of_neg], dtype=np.int32)


list_idx = random.sample([i for i in range(len(user_train[user]))], maxlen + 1) list_idx = random.sample([i for i in range(len(user_train[user]))], maxlen + 1)
list_item = [user_train[user][i] for i in sorted(list_idx)] list_item = [user_train[user][i] for i in sorted(list_idx)]
for i in reversed(list_item[:-1]): for i in reversed(list_item[:-1]):
seq[idx] = i seq[idx] = i
pos[idx] = nxt pos[idx] = nxt
if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts)
# if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts)
nxt = i nxt = i
idx -= 1 idx -= 1
if idx == -1: break if idx == -1: break


curr_rel = user curr_rel = user
support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], [] support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], []

for i in range(len(neg)):
# neg[i] = random_neq(1, itemnum + 1, ts, user_train,usernum)
neg[i] = random_negetive_batch(1, itemnum + 1, ts, user_train, usernum, batch_users = batch_users)

for j in range(number_of_neg):
for idx in range(maxlen-1):
support_negative_triples.append([seq[idx], curr_rel, neg[j*maxlen + idx]])

for idx in range(maxlen-1): for idx in range(maxlen-1):
support_triples.append([seq[idx],curr_rel,pos[idx]]) support_triples.append([seq[idx],curr_rel,pos[idx]])
support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
# support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
query_triples.append([seq[-1],curr_rel,pos[-1]]) query_triples.append([seq[-1],curr_rel,pos[-1]])
negative_triples.append([seq[-1],curr_rel,neg[-1]]) negative_triples.append([seq[-1],curr_rel,neg[-1]])


while True: while True:
one_batch = [] one_batch = []

users = []
for i in range(batch_size):
user = np.random.randint(1, usernum + 1)
while len(user_train[user]) <= 1: user = np.random.randint(1, usernum + 1)
users.append(user)

for i in range(batch_size): for i in range(batch_size):
one_batch.append(sample())
one_batch.append(sample(user = users[i], batch_users = users))


support, support_negative, query, negative, curr_rel = zip(*one_batch) support, support_negative, query, negative, curr_rel = zip(*one_batch)



Loading…
Cancel
Save