Browse Source

some model bugs fixed

main
rezav 2 years ago
parent
commit
519a0efb36
3 changed files with 5 additions and 2 deletions
  1. 1
    0
      drug/models.py
  2. 2
    1
      predictor/cross_validation.py
  3. 2
    1
      predictor/model/models.py

+ 1
- 0
drug/models.py View File

@@ -29,6 +29,7 @@ class GCN(torch.nn.Module):

def forward(self, x, edge_index):
# First Message Passing Layer (Transformation)
x = x.to(torch.float32)
x = self.conv1(x, edge_index)
x = x.relu()
x = F.dropout(x, p=0.5, training=self.training)

+ 2
- 1
predictor/cross_validation.py View File

@@ -76,6 +76,7 @@ def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch
trn_loss /= train_loader.dataset_len
val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id)
val_loss /= valid_loader.dataset_len
print("loss epoch " + str(epoch) + ": " + str(trn_loss))
if val_loss < min_loss:
angry = 0
min_loss = val_loss
@@ -104,7 +105,7 @@ def create_model(data, hidden_size, gpu_id=None):
def cv(args, out_dir):
save_args(args, os.path.join(out_dir, 'args.json'))
test_loss_file = os.path.join(out_dir, 'test_loss.pkl')
print("cuda available: " + str(torch.cuda.is_available()))
if torch.cuda.is_available() and (args.gpu is not None):
gpu_id = args.gpu
else:

+ 2
- 1
predictor/model/models.py View File

@@ -26,10 +26,11 @@ class Connector(nn.Module):
x = self.ddiDataset.get().x
edge_index = self.ddiDataset.get().edge_index
x = self.gcn(x, edge_index)
drug1_idx = torch.flatten(drug1_idx)
drug2_idx = torch.flatten(drug2_idx)
drug1_feat = x[drug1_idx]
drug2_feat = x[drug2_idx]
feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)

return feat



Loading…
Cancel
Save