123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- import sys
- import copy
- import torch
- import random
- import numpy as np
- from collections import defaultdict, Counter
- from multiprocessing import Process, Queue
-
-
def random_neq(l, r, s, user_train, usernum):
    """Draw a negative item by sampling from users' interaction histories.

    A user id is drawn uniformly from ``1..usernum`` and then an item is drawn
    uniformly from that user's history in ``user_train``; this repeats until
    the item is not in ``s``. The resulting distribution is therefore not
    uniform over items: an item's chance grows with the number of histories
    it appears in.

    Args:
        l: unused; kept for call compatibility with the uniform
            ``np.random.randint(l, r)`` sampler this function replaced.
        r: unused; see ``l``.
        s: items to exclude (tested with ``in``; a ``set`` is fastest).
        user_train: mapping user id -> sequence of item ids, indexable for
            every id in ``1..usernum``.
        usernum: largest user id to draw.

    Returns:
        An item taken from some user's history that is not in ``s``.

    Note:
        Loops forever if every item of every history is in ``s`` or every
        history is empty.
    """
    while True:
        user = np.random.randint(1, usernum + 1)
        history = user_train[user]
        if len(history) == 0:
            # Nothing to draw from this user; previously randint(0, 0) raised
            # ValueError here. Re-draw a user instead.
            continue
        candid_item = history[np.random.randint(0, len(history))]
        if candid_item not in s:
            return candid_item
-
-
def sample_function_mixed(user_train, usernum, itemnum, batch_size, maxlen,
                          result_queue, SEED, num_neg=5):
    """Produce batches of per-user training triples forever (worker process body).

    Each sample draws a random user with at least two items in ``user_train``
    and a target position in that user's history. It takes up to ``maxlen``
    items immediately before the target as a left-zero-padded window ``seq``
    and sets ``pos[k]`` to the item that follows ``seq[k]``. The user id is
    the relation of every ``[head, relation, tail]`` triple:

    * support triples: ``[seq[k], user, pos[k]]`` for ``k < maxlen - 1``;
    * support negatives: ``num_neg`` blocks of ``[seq[k], user, neg]`` with
      tails drawn by ``random_neq``;
    * query triple: ``[seq[-1], user, pos[-1]]``;
    * query negative: ``[seq[-1], user, <sampled negative>]``.

    Only this sequential-window sampling path exists; an earlier alternative
    branch was unreachable (guarded by ``random.random() <= 1``).

    Args:
        user_train: user id -> list of item ids, for ids ``1..usernum``
            (presumably in interaction order -- TODO confirm with the loader).
        usernum: largest user id; users are drawn uniformly from ``1..usernum``.
        itemnum: largest item id; forwarded to ``random_neq``.
        batch_size: samples per batch.
        maxlen: length of the item window.
        result_queue: receives one
            ``([support, support_negative, query, negative], curr_rel)`` tuple
            per batch; each element is a ``batch_size``-tuple holding one entry
            per sample.
        SEED: seed for ``np.random`` in this process.
        num_neg: negative support triples generated per positive support
            triple (default 5, the previously hard-coded value).

    Raises:
        ValueError: if ``num_neg < 1``.

    Never returns; the owner is expected to terminate the process.
    """
    if num_neg < 1:
        raise ValueError("num_neg must be at least 1")

    def sample():
        # At least two items are needed to form one (input, target) pair.
        user = np.random.randint(1, usernum + 1)
        while len(user_train[user]) <= 1:
            user = np.random.randint(1, usernum + 1)
        history = user_train[user]

        seq = np.zeros([maxlen], dtype=np.int32)
        pos = np.zeros([maxlen], dtype=np.int32)
        neg = np.zeros([maxlen * num_neg], dtype=np.int32)

        # Histories of at most maxlen items predict their last item. With
        # exactly maxlen items, randint(maxlen, maxlen) below used to raise.
        if len(history) <= maxlen:
            nxt_idx = len(history) - 1
        else:
            nxt_idx = np.random.randint(maxlen, len(history))

        # Window = the (up to) maxlen items directly before the target.
        # The former slice [min(0, nxt_idx - 1 - maxlen) : nxt_idx - 1]
        # skipped history[nxt_idx - 1], and its negative start truncated or
        # emptied the window (a two-item user gave an all-padding sample).
        nxt = history[nxt_idx]
        idx = maxlen - 1
        for item in reversed(history[max(0, nxt_idx - maxlen):nxt_idx]):
            seq[idx] = item
            pos[idx] = nxt
            nxt = item
            idx -= 1

        ts = set(history)
        for i in range(len(neg)):
            neg[i] = random_neq(1, itemnum + 1, ts, user_train, usernum)

        curr_rel = user
        support_triples = [[seq[k], curr_rel, pos[k]] for k in range(maxlen - 1)]
        # Block j pairs the support heads with neg[j*maxlen : j*maxlen + maxlen - 1].
        support_negative_triples = [
            [seq[k], curr_rel, neg[j * maxlen + k]]
            for j in range(num_neg)
            for k in range(maxlen - 1)
        ]
        query_triples = [[seq[-1], curr_rel, pos[-1]]]
        negative_triples = [[seq[-1], curr_rel, neg[-1]]]
        return (support_triples, support_negative_triples, query_triples,
                negative_triples, curr_rel)

    np.random.seed(SEED)

    while True:
        one_batch = [sample() for _ in range(batch_size)]
        support, support_negative, query, negative, curr_rel = zip(*one_batch)
        result_queue.put(([support, support_negative, query, negative], curr_rel))
-
class WarpSampler(object):
    """Run background worker processes that produce training batches.

    Each worker runs ``sample_function_mixed`` with its own seed and pushes
    batches into a shared bounded queue; ``next_batch`` pops one. Usable as a
    context manager, in which case ``close()`` is called on exit.
    """

    def __init__(self, User, usernum, itemnum, batch_size=64, maxlen=10, n_workers=1):
        """Start ``n_workers`` daemon worker processes.

        Args:
            User: per-user item histories, passed to workers as ``user_train``.
            usernum: largest user id, forwarded to the workers.
            itemnum: largest item id, forwarded to the workers.
            batch_size: samples per batch.
            maxlen: item-window length.
            n_workers: number of worker processes.
        """
        # Bounded so workers cannot run arbitrarily far ahead of the consumer.
        self.result_queue = Queue(maxsize=n_workers * 10)
        self.processors = []
        for _ in range(n_workers):
            worker = Process(
                target=sample_function_mixed,
                args=(
                    User,
                    usernum,
                    itemnum,
                    batch_size,
                    maxlen,
                    self.result_queue,
                    # Per-worker seed drawn from the parent's NumPy RNG;
                    # randint documents an integer bound, hence int().
                    np.random.randint(int(2e9)),
                ),
            )
            self.processors.append(worker)
            # Daemonic, so the parent does not wait for them when it exits.
            worker.daemon = True
            worker.start()

    def next_batch(self):
        """Block until a worker has produced a batch and return it."""
        return self.result_queue.get()

    def close(self):
        """Terminate and join every worker process."""
        for p in self.processors:
            p.terminate()
            p.join()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
-
|