Browse Source

cuda setting fixed

RNN
mohamad maheri 2 years ago
parent
commit
8e21c390fc
3 changed files with 5 additions and 3 deletions
  1. 1
    0
      main.py
  2. 3
    2
      models.py
  3. 1
    1
      trainer.py

+ 1
- 0
main.py View File

@@ -34,6 +34,7 @@ def get_params():

params['device'] = torch.device('cuda:'+str(args.device))
# params['device'] = torch.device('cpu')
print("gpu:",params['device'])

return params, args


+ 3
- 2
models.py View File

@@ -6,7 +6,8 @@ from torch.nn import functional as F
class Embedding(nn.Module):
def __init__(self, num_ent, parameter):
super(Embedding, self).__init__()
self.device = torch.device('cuda:0')
# self.device = torch.device('cuda:0')
self.device = torch.device(parameter['device'])
self.es = parameter['embed_dim']
self.embedding = nn.Embedding(num_ent + 1, self.es)
@@ -55,7 +56,7 @@ class EmbeddingLearner(nn.Module):
class MetaTL(nn.Module):
def __init__(self, itemnum, parameter):
super(MetaTL, self).__init__()
self.device = torch.device('cuda:0')
self.device = torch.device(parameter['device'])
self.beta = parameter['beta']
# self.dropout_p = parameter['dropout_p']
self.embed_dim = parameter['embed_dim']

+ 1
- 1
trainer.py View File

@@ -21,7 +21,7 @@ class Trainer:
# self.print_epoch = parameter['print_epoch']
# self.eval_epoch = parameter['eval_epoch']
self.eval_epoch = 1000
self.device = torch.device('cuda:0')
self.device = torch.device(parameter['device'])

self.MetaTL = MetaTL(itemnum, parameter)
self.MetaTL.to(self.device)

Loading…
Cancel
Save