| def eval_model(model, optimizer, loss_func, train_data, test_data, | def eval_model(model, optimizer, loss_func, train_data, test_data, | ||||
| batch_size, n_epoch, patience, gpu_id, mdl_dir): | |||||
| batch_size, n_epoch, patience, ddi_graph, gpu_id, mdl_dir): | |||||
| tr_indices, es_indices = random_split_indices(len(train_data), test_rate=0.1) | tr_indices, es_indices = random_split_indices(len(train_data), test_rate=0.1) | ||||
| train_loader = FastTensorDataLoader(*train_data.tensor_samples(tr_indices), batch_size=batch_size, shuffle=True) | train_loader = FastTensorDataLoader(*train_data.tensor_samples(tr_indices), batch_size=batch_size, shuffle=True) | ||||
| valid_loader = FastTensorDataLoader(*train_data.tensor_samples(es_indices), batch_size=len(es_indices) // 4) | valid_loader = FastTensorDataLoader(*train_data.tensor_samples(es_indices), batch_size=len(es_indices) // 4) | ||||
| test_loader = FastTensorDataLoader(*test_data.tensor_samples(), batch_size=len(test_data) // 4) | test_loader = FastTensorDataLoader(*test_data.tensor_samples(), batch_size=len(test_data) // 4) | ||||
| train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id, | |||||
| train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch, patience, gpu_id, ddi_graph, | |||||
| sl=True, mdl_dir=mdl_dir) | sl=True, mdl_dir=mdl_dir) | ||||
| test_loss = eval_epoch(model, test_loader, loss_func, gpu_id) | |||||
| test_loss = eval_epoch(model, test_loader, loss_func, ddi_graph, gpu_id) | |||||
| test_loss /= len(test_data) | test_loss /= len(test_data) | ||||
| return test_loss | return test_loss | ||||
| if not os.path.exists(test_mdl_dir): | if not os.path.exists(test_mdl_dir): | ||||
| os.makedirs(test_mdl_dir) | os.makedirs(test_mdl_dir) | ||||
| test_loss = eval_model(model, optimizer, loss_func, train_data, test_data, | test_loss = eval_model(model, optimizer, loss_func, train_data, test_data, | ||||
| args.batch, args.epoch, args.patience, gpu_id, test_mdl_dir) | |||||
| args.batch, args.epoch, args.patience, ddi_graph, gpu_id, test_mdl_dir) | |||||
| test_losses.append(test_loss) | test_losses.append(test_loss) | ||||
| logging.info("Test loss: {:.4f}".format(test_loss)) | logging.info("Test loss: {:.4f}".format(test_loss)) | ||||
| logging.info("*" * n_delimiter + '\n') | logging.info("*" * n_delimiter + '\n') |