Browse Source

using LSTM instead of MLP

RNN
mohamad maheri 3 years ago
parent
commit
1c015e3b35
3 changed files with 14 additions and 27 deletions
  1. 2
    1
      main.py
  2. 11
    25
      models.py
  3. 1
    1
      trainer.py

+ 2
- 1
main.py View File

@@ -24,7 +24,7 @@ def get_params():
args.add_argument("-m", "--margin", default=1, type=float)
args.add_argument("-p", "--dropout_p", default=0.5, type=float)

- args.add_argument("-gpu", "--device", default=0, type=int)
+ args.add_argument("-gpu", "--device", default=1, type=int)


args = args.parse_args()
@@ -33,6 +33,7 @@ def get_params():
params[k] = v

params['device'] = torch.device('cuda:'+str(args.device))
+ # params['device'] = torch.device('cpu')

return params, args


+ 11
- 25
models.py View File

@@ -12,7 +12,6 @@ class Embedding(nn.Module):
self.embedding = nn.Embedding(num_ent + 1, self.es)
nn.init.xavier_uniform_(self.embedding.weight)


def forward(self, triples):
idx = [[[t[0], t[2]] for t in batch] for batch in triples]
idx = torch.LongTensor(idx).to(self.device)
@@ -24,34 +23,21 @@ class MetaLearner(nn.Module):
self.embed_size = embed_size
self.K = K
self.out_size = out_size
- self.rel_fc1 = nn.Sequential(OrderedDict([
- ('fc', nn.Linear(2*embed_size, num_hidden1)),
- ('bn', nn.BatchNorm1d(K)),
- ('relu', nn.LeakyReLU()),
- ('drop', nn.Dropout(p=dropout_p)),
- ]))
- self.rel_fc2 = nn.Sequential(OrderedDict([
- ('fc', nn.Linear(num_hidden1, num_hidden2)),
- ('bn', nn.BatchNorm1d(K)),
- ('relu', nn.LeakyReLU()),
- ('drop', nn.Dropout(p=dropout_p)),
- ]))
- self.rel_fc3 = nn.Sequential(OrderedDict([
- ('fc', nn.Linear(num_hidden2, out_size)),
- ('bn', nn.BatchNorm1d(K)),
- ]))
- nn.init.xavier_normal_(self.rel_fc1.fc.weight)
- nn.init.xavier_normal_(self.rel_fc2.fc.weight)
- nn.init.xavier_normal_(self.rel_fc3.fc.weight)
+ self.hidden_size = out_size
+ self.rnn = nn.LSTM(embed_size,self.hidden_size,1)
+ # nn.init.xavier_normal_(self.rnn.all_weights)

def forward(self, inputs):
size = inputs.shape
- x = inputs.contiguous().view(size[0], size[1], -1)
- x = self.rel_fc1(x)
- x = self.rel_fc2(x)
- x = self.rel_fc3(x)
- x = torch.mean(x, 1)
+ x = inputs.contiguous().view(size[0]*2, size[2], size[3])
+ x = x.transpose(1,0)
+
+ _,(x,c) = self.rnn(x)
+ x = x[-1]
+
+ x = x.squeeze(0)
+ x = x.reshape(size[0],2,100)
x = torch.mean(x, 1)
return x.view(size[0], 1, 1, self.out_size)



+ 1
- 1
trainer.py View File

@@ -88,7 +88,7 @@ class Trainer:
print('Finish')

def eval(self, istest=False, epoch=None):
- self.MetaTL.eval()
+ # self.MetaTL.eval()
self.MetaTL.rel_q_sharing = dict()


Loading…
Cancel
Save