
improve speed (work in progress)

main · MahsaYazdani · 1 year ago · commit 428d5f8d91
3 changed files with 58 additions and 14 deletions
  1. predictor/cross_validation.py  +6  -2
  2. predictor/model/models.py      +37 -12
  3. predictor/model/utils.py       +15 -0

predictor/cross_validation.py  +6 -2

@@ -69,7 +69,7 @@ def train_epoch(model, loader, loss_func, optimizer, ddi_graph, gpu_id=None):
    epoch_loss = 0

    ddi_graph.n_id = torch.arange(ddi_graph.num_nodes)
    start = time.time()
    # print("len of batch loader: ",len(loader))
    for i, batch in enumerate(loader):
        # print(f"this is batch {i} ------------------")
@@ -93,6 +93,8 @@ def train_epoch(model, loader, loss_func, optimizer, ddi_graph, gpu_id=None):
        # print(type(sampled_data))
        # print("Sampled_data: ")
        # print(sampled_data.batch_size)

        start_inner = time.time()
        for subgraph in neighbor_loader:
            # print("this is subgraph in cross_validation:")
@@ -102,8 +104,10 @@ def train_epoch(model, loader, loss_func, optimizer, ddi_graph, gpu_id=None):
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print("for on subgraphs time: ", time.time() - start_inner)
        # print("epoch_loss: ", epoch_loss)
    print("train epoch time: ", time.time() - start)
    return epoch_loss


@@ -155,7 +159,7 @@ def train_model(model, optimizer, loss_func, train_loader, valid_loader, n_epoch
        if angry >= patience:
            break
        end_time = time.time()
        print("epoch time (min): ", (end_time - start_time)/60)
        print("epoch time: ", (end_time - start_time))
    if sl:
        model.load_state_dict(torch.load(find_best_model(mdl_dir)))
    return min_loss
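The changes in this file time each stage with paired time.time() calls and print the elapsed seconds directly. Purely as an illustrative alternative, not something this commit introduces (the helper name timed is invented here), the same pattern can be kept in one place with a small context manager:

    import time
    from contextlib import contextmanager

    @contextmanager
    def timed(label):
        # Illustrative helper, not part of this commit:
        # prints the elapsed seconds for the wrapped block.
        start = time.perf_counter()
        try:
            yield
        finally:
            print(label, time.perf_counter() - start)

    # usage sketch:
    # with timed("train epoch time: "):
    #     for i, batch in enumerate(loader):
    #         ...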

predictor/model/models.py  +37 -12

@@ -4,13 +4,13 @@ import torch.nn.functional as F
import os
import sys
import pandas as pd
import time
PROJ_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')))

sys.path.insert(0, PROJ_DIR)
from drug.models import GCN
from drug.datasets import DDInteractionDataset
from model.utils import get_FP_by_negative_index
from model.utils import get_FP_by_negative_index, get_FP_by_negative_indices
from const import Drug2FP_FILE


@@ -55,25 +55,48 @@ class Connector(nn.Module):
        #drug2_feat = x[drug2_idx]
        drug1_feat = torch.empty((len(drug1_idx), len(x[0])))
        drug2_feat = torch.empty((len(drug2_idx), len(x[0])))
        for index, element in enumerate(drug1_idx):
            x_element = element
            if element >= 0:
                x_element = (node_indices == element).nonzero().squeeze()
            drug1_feat[index] = (x[x_element])
        for index, element in enumerate(drug2_idx):
            x_element = element
            if element >= 0:
                x_element = (node_indices == element).nonzero().squeeze()
            drug2_feat[index] = (x[x_element])

        print("x shape: ", x.size())
        print("node_indices: ", node_indices.size())

        start_time = time.time()
        # for index, element in enumerate(drug1_idx):
        #     x_element = element
        #     if element >= 0:
        #         x_element = (node_indices == element).nonzero().squeeze()
        #     drug1_feat[index] = (x[x_element])
        # for index, element in enumerate(drug2_idx):
        #     x_element = element
        #     if element >= 0:
        #         x_element = (node_indices == element).nonzero().squeeze()
        #     drug2_feat[index] = (x[x_element])

        mask_positive = (drug1_idx >= 0)
        x_elements_positive = (node_indices.unsqueeze(-1) == drug1_idx[mask_positive]).nonzero(as_tuple=True)[0]
        drug1_feat[mask_positive] = x[x_elements_positive]

        mask_negative = ~mask_positive
        drug1_feat[mask_negative] = get_FP_by_negative_indices(drug1_idx[mask_negative], self.drug2FP_df)

        mask_positive = (drug2_idx >= 0)
        x_elements_positive = (node_indices.unsqueeze(-1) == drug2_idx[mask_positive]).nonzero(as_tuple=True)[0]
        drug2_feat[mask_positive] = x[x_elements_positive]
        if self.gpu_id is not None:
            drug1_feat = drug1_feat.cuda(self.gpu_id)
            drug2_feat = drug2_feat.cuda(self.gpu_id)

        print("first: ", time.time() - start_time)
        start_time = time.time()

        for i, x in enumerate(drug1_idx):
            if x < 0:
                drug1_feat[i] = get_FP_by_negative_index(x, self.drug2FP_df)
        for i, x in enumerate(drug2_idx):
            if x < 0:
                drug2_feat[i] = get_FP_by_negative_index(x, self.drug2FP_df)
        print("second: ", time.time() - start_time)
        feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
        return feat
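The hunk above replaces the per-element Python loops with a mask-based batch lookup. As a standalone, hedged illustration (toy tensors, not code from the repository; the names only mirror the diff), the idea is to map each non-negative drug index to its row of the node embedding matrix x through node_indices. The sketch adds an explicit sort by the query column so matched rows follow the order of the query indices; whether the diff needs that step depends on how node_indices and the batch indices are ordered.

    import torch

    x = torch.arange(12, dtype=torch.float32).reshape(4, 3)   # 4 node embeddings of size 3 (toy data)
    node_indices = torch.tensor([7, 2, 9, 5])                 # global node ids for the rows of x
    drug1_idx = torch.tensor([9, -1, 7])                      # a negative id means the drug is outside the graph

    mask_positive = drug1_idx >= 0
    match = node_indices.unsqueeze(-1) == drug1_idx[mask_positive]   # (num_nodes, num_positive) boolean matrix
    rows, cols = match.nonzero(as_tuple=True)
    # nonzero() orders hits by node position; sorting by the query column keeps
    # row i of the result paired with element i of drug1_idx[mask_positive].
    rows = rows[torch.argsort(cols)]

    drug1_feat = torch.empty((len(drug1_idx), x.size(1)))
    drug1_feat[mask_positive] = x[rows]
    print(drug1_feat[mask_positive])   # rows of x for node ids 9 and 7, in that order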

@@ -95,7 +118,9 @@ class MLP(nn.Module):
    # prev input: self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor, subgraph: related subgraph for the batch
    def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
        start = time.time()
        feat = self.connector(drug1_idx, drug2_idx, cell_feat, subgraph)
        print("Connector forward time: ", time.time() - start)
        out = self.layers(feat)
        return out


predictor/model/utils.py  +15 -0

@@ -117,3 +117,18 @@ def get_FP_by_negative_index(index, drug2FP_df = None):

    array = np.array(list(drug2FP_df.iloc[-index])[1:])
    return torch.tensor(array, dtype=torch.float32)


def get_FP_by_negative_indices(indices, drug2FP_df = None):
    # converting negative indices to real indices
    indices *= -1
    print(indices)
    print(type(drug2FP_df))

    if drug2FP_df is None or drug2FP_df.empty:
        print("load Drug2FP_FILE")
        drug2FP_df = pd.read_csv(Drug2FP_FILE)

    # array = np.array(list(drug2FP_df.iloc[[indices]])[1:])
    print(drug2FP_df.iloc[[indices]])
    # return torch.tensor(array, dtype=torch.float32)
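get_FP_by_negative_indices is left unfinished in this commit: it prints the selected rows and its return is still commented out. Purely as a hedged sketch, not part of the commit, one way to complete it while following the conventions of the single-index version above (a negative index -k reads row k of drug2FP_df and the leading drug-id column is dropped) could look like this; the dtype of indices (a 1-D tensor of still-negative ids) and the availability of np, pd, torch, and Drug2FP_FILE in this module are assumptions.

    def get_FP_by_negative_indices(indices, drug2FP_df=None):
        # Sketch only, not part of this commit: batched counterpart of get_FP_by_negative_index.
        if drug2FP_df is None or drug2FP_df.empty:
            drug2FP_df = pd.read_csv(Drug2FP_FILE)
        rows = (-indices).long().tolist()                   # negative index -k -> row k, as in the scalar version
        array = drug2FP_df.iloc[rows].to_numpy()[:, 1:]     # drop the leading drug-id column
        return torch.tensor(array.astype(np.float32))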
