|
|
@@ -6,7 +6,8 @@ from torch.nn import functional as F |
|
|
|
class Embedding(nn.Module): |
|
|
|
def __init__(self, num_ent, parameter): |
|
|
|
super(Embedding, self).__init__() |
|
|
|
self.device = torch.device('cuda:0') |
|
|
|
# self.device = torch.device('cuda:0') |
|
|
|
self.device = torch.device(parameter['device']) |
|
|
|
self.es = parameter['embed_dim'] |
|
|
|
|
|
|
|
self.embedding = nn.Embedding(num_ent + 1, self.es) |
|
|
@@ -55,7 +56,7 @@ class EmbeddingLearner(nn.Module): |
|
|
|
class MetaTL(nn.Module): |
|
|
|
def __init__(self, itemnum, parameter): |
|
|
|
super(MetaTL, self).__init__() |
|
|
|
self.device = torch.device('cuda:0') |
|
|
|
self.device = torch.device(parameter['device']) |
|
|
|
self.beta = parameter['beta'] |
|
|
|
# self.dropout_p = parameter['dropout_p'] |
|
|
|
self.embed_dim = parameter['embed_dim'] |