Browse Source

negative sampling by choosing from other user's items

RNN
mohamad maheri 2 years ago
parent
commit
de111c73dc
3 changed files with 39 additions and 15 deletions
  1. 1
    1
      main.py
  2. 19
    7
      sampler.py
  3. 19
    7
      utils.py

+ 1
- 1
main.py View File

@@ -24,7 +24,7 @@ def get_params():
args.add_argument("-m", "--margin", default=1, type=float)
args.add_argument("-p", "--dropout_p", default=0.5, type=float)

args.add_argument("-gpu", "--device", default=0, type=int)
args.add_argument("-gpu", "--device", default=1, type=int)


args = args.parse_args()

+ 19
- 7
sampler.py View File

@@ -7,11 +7,19 @@ from collections import defaultdict, Counter
from multiprocessing import Process, Queue


def random_neq(l, r, s):
t = np.random.randint(l, r)
while t in s:
t = np.random.randint(l, r)
return t
def random_neq(l, r, s, user_train,usernum):
# t = np.random.randint(l, r)
# while t in s:
# t = np.random.randint(l, r)
# return t
user = np.random.randint(1, usernum + 1)
candid_item = user_train[user][np.random.randint(0,len(user_train[user]))]

while candid_item in s:
user = np.random.randint(1, usernum + 1)
candid_item = user_train[user][np.random.randint(0, len(user_train[user]))]
return candid_item


def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED):
def sample():
@@ -37,16 +45,20 @@ def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, resu
for i in reversed(user_train[user][min(0, nxt_idx - 1 - maxlen) : nxt_idx - 1]):
seq[idx] = i
pos[idx] = nxt
if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts)
if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts, user_train,usernum)
nxt = i
idx -= 1
if idx == -1: break

# for i in range(len(neg)):
# neg[i] = random_neq(1, itemnum + 1, ts, user_train,usernum)

curr_rel = user
support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], []
for idx in range(maxlen-1):
support_triples.append([seq[idx],curr_rel,pos[idx]])
support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
# support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
support_negative_triples.append([seq[-1], curr_rel, neg[idx]])
query_triples.append([seq[-1],curr_rel,pos[-1]])
negative_triples.append([seq[-1],curr_rel,neg[-1]])


+ 19
- 7
utils.py View File

@@ -5,12 +5,22 @@ import random
import numpy as np
from collections import defaultdict, Counter
from multiprocessing import Process, Queue

# sampler for batch generation
def random_neq(l, r, s):
t = np.random.randint(l, r)
while t in s:
t = np.random.randint(l, r)
return t
def random_neq(l, r, s,user_train):
# t = np.random.randint(l, r)
# while t in s:
# t = np.random.randint(l, r)
# return t

user = random.choice(list(user_train.keys()))
item = random.choice(user_train[user])

while item in s:
user = random.choice(list(user_train.keys()))
item = random.choice(user_train[user])
return item


def trans_to_cuda(variable):
if torch.cuda.is_available():
@@ -97,6 +107,7 @@ class DataLoader(object):
self.itemnum = itemnum


def next_one_on_eval(self):
if self.curr_tri_idx == self.num_tris:
return "EOT", "EOT"
@@ -116,7 +127,7 @@ class DataLoader(object):
seq[idx] = i
if idx > 0:
pos[idx - 1] = i
if i != 0: neg[idx - 1] = random_neq(1, self.itemnum + 1, ts)
if i != 0: neg[idx - 1] = random_neq(1, self.itemnum + 1, ts,self.train)
idx -= 1
if idx == -1: break

@@ -124,7 +135,8 @@ class DataLoader(object):
support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], []
for idx in range(self.maxlen-1):
support_triples.append([seq[idx],curr_rel,pos[idx]])
support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
# support_negative_triples.append([seq[idx],curr_rel,neg[idx]])
support_negative_triples.append([seq[-1],curr_rel,neg[idx]])

rated = ts
rated.add(0)

Loading…
Cancel
Save