Browse Source

some model bugs fixed

main
rezav 2 years ago
parent
commit
519a0efb36
3 changed files with 5 additions and 2 deletions
  1. 1
    0
      drug/models.py
  2. 2
    1
      predictor/cross_validation.py
  3. 2
    1
      predictor/model/models.py

+ 1
- 0
drug/models.py View File



def forward(self, x, edge_index): def forward(self, x, edge_index):
# First Message Passing Layer (Transformation) # First Message Passing Layer (Transformation)
x = x.to(torch.float32)
x = self.conv1(x, edge_index) x = self.conv1(x, edge_index)
x = x.relu() x = x.relu()
x = F.dropout(x, p=0.5, training=self.training) x = F.dropout(x, p=0.5, training=self.training)

+ 2
- 1
predictor/cross_validation.py View File

trn_loss /= train_loader.dataset_len trn_loss /= train_loader.dataset_len
val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id) val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id)
val_loss /= valid_loader.dataset_len val_loss /= valid_loader.dataset_len
print("loss epoch " + str(epoch) + ": " + str(trn_loss))
if val_loss < min_loss: if val_loss < min_loss:
angry = 0 angry = 0
min_loss = val_loss min_loss = val_loss
def cv(args, out_dir): def cv(args, out_dir):
save_args(args, os.path.join(out_dir, 'args.json')) save_args(args, os.path.join(out_dir, 'args.json'))
test_loss_file = os.path.join(out_dir, 'test_loss.pkl') test_loss_file = os.path.join(out_dir, 'test_loss.pkl')
print("cuda available: " + str(torch.cuda.is_available()))
if torch.cuda.is_available() and (args.gpu is not None): if torch.cuda.is_available() and (args.gpu is not None):
gpu_id = args.gpu gpu_id = args.gpu
else: else:

+ 2
- 1
predictor/model/models.py View File

x = self.ddiDataset.get().x x = self.ddiDataset.get().x
edge_index = self.ddiDataset.get().edge_index edge_index = self.ddiDataset.get().edge_index
x = self.gcn(x, edge_index) x = self.gcn(x, edge_index)
drug1_idx = torch.flatten(drug1_idx)
drug2_idx = torch.flatten(drug2_idx)
drug1_feat = x[drug1_idx] drug1_feat = x[drug1_idx]
drug2_feat = x[drug2_idx] drug2_feat = x[drug2_idx]
feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)

return feat return feat





Loading…
Cancel
Save