@@ -24,7 +24,7 @@ def get_params(): | |||
args.add_argument("-m", "--margin", default=1, type=float) | |||
args.add_argument("-p", "--dropout_p", default=0.5, type=float) | |||
args.add_argument("-gpu", "--device", default=0, type=int) | |||
args.add_argument("-gpu", "--device", default=1, type=int) | |||
args = args.parse_args() |
@@ -7,11 +7,19 @@ from collections import defaultdict, Counter | |||
from multiprocessing import Process, Queue | |||
def random_neq(l, r, s): | |||
t = np.random.randint(l, r) | |||
while t in s: | |||
t = np.random.randint(l, r) | |||
return t | |||
def random_neq(l, r, s, user_train,usernum): | |||
# t = np.random.randint(l, r) | |||
# while t in s: | |||
# t = np.random.randint(l, r) | |||
# return t | |||
user = np.random.randint(1, usernum + 1) | |||
candid_item = user_train[user][np.random.randint(0,len(user_train[user]))] | |||
while candid_item in s: | |||
user = np.random.randint(1, usernum + 1) | |||
candid_item = user_train[user][np.random.randint(0, len(user_train[user]))] | |||
return candid_item | |||
def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, result_queue, SEED): | |||
def sample(): | |||
@@ -37,16 +45,20 @@ def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen, resu | |||
for i in reversed(user_train[user][min(0, nxt_idx - 1 - maxlen) : nxt_idx - 1]): | |||
seq[idx] = i | |||
pos[idx] = nxt | |||
if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts) | |||
if nxt != 0: neg[idx] = random_neq(1, itemnum + 1, ts, user_train,usernum) | |||
nxt = i | |||
idx -= 1 | |||
if idx == -1: break | |||
# for i in range(len(neg)): | |||
# neg[i] = random_neq(1, itemnum + 1, ts, user_train,usernum) | |||
curr_rel = user | |||
support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], [] | |||
for idx in range(maxlen-1): | |||
support_triples.append([seq[idx],curr_rel,pos[idx]]) | |||
support_negative_triples.append([seq[idx],curr_rel,neg[idx]]) | |||
# support_negative_triples.append([seq[idx],curr_rel,neg[idx]]) | |||
support_negative_triples.append([seq[-1], curr_rel, neg[idx]]) | |||
query_triples.append([seq[-1],curr_rel,pos[-1]]) | |||
negative_triples.append([seq[-1],curr_rel,neg[-1]]) | |||
@@ -5,12 +5,22 @@ import random | |||
import numpy as np | |||
from collections import defaultdict, Counter | |||
from multiprocessing import Process, Queue | |||
# sampler for batch generation | |||
def random_neq(l, r, s): | |||
t = np.random.randint(l, r) | |||
while t in s: | |||
t = np.random.randint(l, r) | |||
return t | |||
def random_neq(l, r, s,user_train): | |||
# t = np.random.randint(l, r) | |||
# while t in s: | |||
# t = np.random.randint(l, r) | |||
# return t | |||
user = random.choice(list(user_train.keys())) | |||
item = random.choice(user_train[user]) | |||
while item in s: | |||
user = random.choice(list(user_train.keys())) | |||
item = random.choice(user_train[user]) | |||
return item | |||
def trans_to_cuda(variable): | |||
if torch.cuda.is_available(): | |||
@@ -97,6 +107,7 @@ class DataLoader(object): | |||
self.itemnum = itemnum | |||
def next_one_on_eval(self): | |||
if self.curr_tri_idx == self.num_tris: | |||
return "EOT", "EOT" | |||
@@ -116,7 +127,7 @@ class DataLoader(object): | |||
seq[idx] = i | |||
if idx > 0: | |||
pos[idx - 1] = i | |||
if i != 0: neg[idx - 1] = random_neq(1, self.itemnum + 1, ts) | |||
if i != 0: neg[idx - 1] = random_neq(1, self.itemnum + 1, ts,self.train) | |||
idx -= 1 | |||
if idx == -1: break | |||
@@ -124,7 +135,8 @@ class DataLoader(object): | |||
support_triples, support_negative_triples, query_triples, negative_triples = [], [], [], [] | |||
for idx in range(self.maxlen-1): | |||
support_triples.append([seq[idx],curr_rel,pos[idx]]) | |||
support_negative_triples.append([seq[idx],curr_rel,neg[idx]]) | |||
# support_negative_triples.append([seq[idx],curr_rel,neg[idx]]) | |||
support_negative_triples.append([seq[-1],curr_rel,neg[idx]]) | |||
rated = ts | |||
rated.add(0) |