Browse Source

cuda setting fixed

RNN
mohamad maheri 2 years ago
parent
commit
8e21c390fc
3 changed files with 5 additions and 3 deletions
  1. 1
    0
      main.py
  2. 3
    2
      models.py
  3. 1
    1
      trainer.py

+ 1
- 0
main.py View File



params['device'] = torch.device('cuda:'+str(args.device)) params['device'] = torch.device('cuda:'+str(args.device))
# params['device'] = torch.device('cpu') # params['device'] = torch.device('cpu')
print("gpu:",params['device'])


return params, args return params, args



+ 3
- 2
models.py View File

class Embedding(nn.Module): class Embedding(nn.Module):
def __init__(self, num_ent, parameter): def __init__(self, num_ent, parameter):
super(Embedding, self).__init__() super(Embedding, self).__init__()
self.device = torch.device('cuda:0')
# self.device = torch.device('cuda:0')
self.device = torch.device(parameter['device'])
self.es = parameter['embed_dim'] self.es = parameter['embed_dim']
self.embedding = nn.Embedding(num_ent + 1, self.es) self.embedding = nn.Embedding(num_ent + 1, self.es)
class MetaTL(nn.Module): class MetaTL(nn.Module):
def __init__(self, itemnum, parameter): def __init__(self, itemnum, parameter):
super(MetaTL, self).__init__() super(MetaTL, self).__init__()
self.device = torch.device('cuda:0')
self.device = torch.device(parameter['device'])
self.beta = parameter['beta'] self.beta = parameter['beta']
# self.dropout_p = parameter['dropout_p'] # self.dropout_p = parameter['dropout_p']
self.embed_dim = parameter['embed_dim'] self.embed_dim = parameter['embed_dim']

+ 1
- 1
trainer.py View File

# self.print_epoch = parameter['print_epoch'] # self.print_epoch = parameter['print_epoch']
# self.eval_epoch = parameter['eval_epoch'] # self.eval_epoch = parameter['eval_epoch']
self.eval_epoch = 1000 self.eval_epoch = 1000
self.device = torch.device('cuda:0')
self.device = torch.device(parameter['device'])


self.MetaTL = MetaTL(itemnum, parameter) self.MetaTL = MetaTL(itemnum, parameter)
self.MetaTL.to(self.device) self.MetaTL.to(self.device)

Loading…
Cancel
Save