def forward(self, x, edge_index): | def forward(self, x, edge_index): | ||||
# First Message Passing Layer (Transformation) | # First Message Passing Layer (Transformation) | ||||
x = x.to(torch.float32) | |||||
x = self.conv1(x, edge_index) | x = self.conv1(x, edge_index) | ||||
x = x.relu() | x = x.relu() | ||||
x = F.dropout(x, p=0.5, training=self.training) | x = F.dropout(x, p=0.5, training=self.training) |
trn_loss /= train_loader.dataset_len | trn_loss /= train_loader.dataset_len | ||||
val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id) | val_loss = eval_epoch(model, valid_loader, loss_func, gpu_id) | ||||
val_loss /= valid_loader.dataset_len | val_loss /= valid_loader.dataset_len | ||||
print("loss epoch " + str(epoch) + ": " + str(trn_loss)) | |||||
if val_loss < min_loss: | if val_loss < min_loss: | ||||
angry = 0 | angry = 0 | ||||
min_loss = val_loss | min_loss = val_loss | ||||
def cv(args, out_dir): | def cv(args, out_dir): | ||||
save_args(args, os.path.join(out_dir, 'args.json')) | save_args(args, os.path.join(out_dir, 'args.json')) | ||||
test_loss_file = os.path.join(out_dir, 'test_loss.pkl') | test_loss_file = os.path.join(out_dir, 'test_loss.pkl') | ||||
print("cuda available: " + str(torch.cuda.is_available())) | |||||
if torch.cuda.is_available() and (args.gpu is not None): | if torch.cuda.is_available() and (args.gpu is not None): | ||||
gpu_id = args.gpu | gpu_id = args.gpu | ||||
else: | else: |
x = self.ddiDataset.get().x | x = self.ddiDataset.get().x | ||||
edge_index = self.ddiDataset.get().edge_index | edge_index = self.ddiDataset.get().edge_index | ||||
x = self.gcn(x, edge_index) | x = self.gcn(x, edge_index) | ||||
drug1_idx = torch.flatten(drug1_idx) | |||||
drug2_idx = torch.flatten(drug2_idx) | |||||
drug1_feat = x[drug1_idx] | drug1_feat = x[drug1_idx] | ||||
drug2_feat = x[drug2_idx] | drug2_feat = x[drug2_idx] | ||||
feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) | feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) | ||||
return feat | return feat | ||||