Browse Source

using LSTM instead of MLP

RNN
mohamad maheri 2 years ago
parent
commit
1c015e3b35
3 changed files with 14 additions and 27 deletions
  1. 2
    1
      main.py
  2. 11
    25
      models.py
  3. 1
    1
      trainer.py

+ 2
- 1
main.py View File

args.add_argument("-m", "--margin", default=1, type=float) args.add_argument("-m", "--margin", default=1, type=float)
args.add_argument("-p", "--dropout_p", default=0.5, type=float) args.add_argument("-p", "--dropout_p", default=0.5, type=float)


args.add_argument("-gpu", "--device", default=0, type=int)
args.add_argument("-gpu", "--device", default=1, type=int)




args = args.parse_args() args = args.parse_args()
params[k] = v params[k] = v


params['device'] = torch.device('cuda:'+str(args.device)) params['device'] = torch.device('cuda:'+str(args.device))
# params['device'] = torch.device('cpu')


return params, args return params, args



+ 11
- 25
models.py View File

self.embedding = nn.Embedding(num_ent + 1, self.es) self.embedding = nn.Embedding(num_ent + 1, self.es)
nn.init.xavier_uniform_(self.embedding.weight) nn.init.xavier_uniform_(self.embedding.weight)



def forward(self, triples): def forward(self, triples):
idx = [[[t[0], t[2]] for t in batch] for batch in triples] idx = [[[t[0], t[2]] for t in batch] for batch in triples]
idx = torch.LongTensor(idx).to(self.device) idx = torch.LongTensor(idx).to(self.device)
self.embed_size = embed_size self.embed_size = embed_size
self.K = K self.K = K
self.out_size = out_size self.out_size = out_size
self.rel_fc1 = nn.Sequential(OrderedDict([
('fc', nn.Linear(2*embed_size, num_hidden1)),
('bn', nn.BatchNorm1d(K)),
('relu', nn.LeakyReLU()),
('drop', nn.Dropout(p=dropout_p)),
]))
self.rel_fc2 = nn.Sequential(OrderedDict([
('fc', nn.Linear(num_hidden1, num_hidden2)),
('bn', nn.BatchNorm1d(K)),
('relu', nn.LeakyReLU()),
('drop', nn.Dropout(p=dropout_p)),
]))
self.rel_fc3 = nn.Sequential(OrderedDict([
('fc', nn.Linear(num_hidden2, out_size)),
('bn', nn.BatchNorm1d(K)),
]))
nn.init.xavier_normal_(self.rel_fc1.fc.weight)
nn.init.xavier_normal_(self.rel_fc2.fc.weight)
nn.init.xavier_normal_(self.rel_fc3.fc.weight)
self.hidden_size = out_size
self.rnn = nn.LSTM(embed_size,self.hidden_size,1)
# nn.init.xavier_normal_(self.rnn.all_weights)


def forward(self, inputs): def forward(self, inputs):
size = inputs.shape size = inputs.shape
x = inputs.contiguous().view(size[0], size[1], -1)
x = self.rel_fc1(x)
x = self.rel_fc2(x)
x = self.rel_fc3(x)
x = torch.mean(x, 1)
x = inputs.contiguous().view(size[0]*2, size[2], size[3])
x = x.transpose(1,0)


_,(x,c) = self.rnn(x)
x = x[-1]

x = x.squeeze(0)
x = x.reshape(size[0],2,100)
x = torch.mean(x, 1)
return x.view(size[0], 1, 1, self.out_size) return x.view(size[0], 1, 1, self.out_size)





+ 1
- 1
trainer.py View File

print('Finish') print('Finish')


def eval(self, istest=False, epoch=None): def eval(self, istest=False, epoch=None):
self.MetaTL.eval()
# self.MetaTL.eval()
self.MetaTL.rel_q_sharing = dict() self.MetaTL.rel_q_sharing = dict()



Loading…
Cancel
Save