Browse Source

improve speed (middle of the work)

main
MahsaYazdani 1 year ago
parent
commit
428d5f8d91
3 changed files with 58 additions and 14 deletions
  1. 6
    2
      predictor/cross_validation.py
  2. 37
    12
      predictor/model/models.py
  3. 15
    0
      predictor/model/utils.py

+ 6
- 2
predictor/cross_validation.py View File

epoch_loss = 0 epoch_loss = 0


ddi_graph.n_id = torch.arange(ddi_graph.num_nodes) ddi_graph.n_id = torch.arange(ddi_graph.num_nodes)
start = time.time()
# print("len of batch loader: ",len(loader)) # print("len of batch loader: ",len(loader))
for i, batch in enumerate(loader): for i, batch in enumerate(loader):
# print(f"this is batch {i} ------------------") # print(f"this is batch {i} ------------------")
# print(type(sampled_data)) # print(type(sampled_data))
# print("Sampled_data: ") # print("Sampled_data: ")
# print(sampled_data.batch_size) # print(sampled_data.batch_size)

start_inner = time.time()
for subgraph in neighbor_loader: for subgraph in neighbor_loader:
# print("this is subgraph in cross_validation:") # print("this is subgraph in cross_validation:")
loss.backward() loss.backward()
optimizer.step() optimizer.step()
epoch_loss += loss.item() epoch_loss += loss.item()
print("for on subgraphs time: ", time.time() - start_inner)
# print("epoch_loss: ", epoch_loss) # print("epoch_loss: ", epoch_loss)
print("train epoch time: ", time.time() - start)
return epoch_loss return epoch_loss




if angry >= patience: if angry >= patience:
break break
end_time = time.time() end_time = time.time()
print("epoch time (min): ", (end_time - start_time)/60)
print("epoch time: ", (end_time - start_time))
if sl: if sl:
model.load_state_dict(torch.load(find_best_model(mdl_dir))) model.load_state_dict(torch.load(find_best_model(mdl_dir)))
return min_loss return min_loss

+ 37
- 12
predictor/model/models.py View File

import os import os
import sys import sys
import pandas as pd import pandas as pd
import time
PROJ_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..'))) PROJ_DIR = os.path.dirname(os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..')))


sys.path.insert(0, PROJ_DIR) sys.path.insert(0, PROJ_DIR)
from drug.models import GCN from drug.models import GCN
from drug.datasets import DDInteractionDataset from drug.datasets import DDInteractionDataset
from model.utils import get_FP_by_negative_index
from model.utils import get_FP_by_negative_index, get_FP_by_negative_indices
from const import Drug2FP_FILE from const import Drug2FP_FILE




#drug2_feat = x[drug2_idx] #drug2_feat = x[drug2_idx]
drug1_feat = torch.empty((len(drug1_idx), len(x[0]))) drug1_feat = torch.empty((len(drug1_idx), len(x[0])))
drug2_feat = torch.empty((len(drug2_idx), len(x[0]))) drug2_feat = torch.empty((len(drug2_idx), len(x[0])))
for index, element in enumerate(drug1_idx):
x_element = element
if element >= 0:
x_element = (node_indices == element).nonzero().squeeze()
drug1_feat[index] = (x[x_element])
for index, element in enumerate(drug2_idx):
x_element = element
if element >= 0:
x_element = (node_indices == element).nonzero().squeeze()
drug2_feat[index] = (x[x_element])

print("x shape: ", x.size())
print("node_indices: ", node_indices.size())

start_time = time.time()
# for index, element in enumerate(drug1_idx):
# x_element = element
# if element >= 0:
# x_element = (node_indices == element).nonzero().squeeze()
# drug1_feat[index] = (x[x_element])
# for index, element in enumerate(drug2_idx):
# x_element = element
# if element >= 0:
# x_element = (node_indices == element).nonzero().squeeze()
# drug2_feat[index] = (x[x_element])

mask_positive = (drug1_idx >= 0)
x_elements_positive = (node_indices.unsqueeze(-1) == drug1_idx[mask_positive]).nonzero(as_tuple=True)[0]
drug1_feat[mask_positive] = x[x_elements_positive]

mask_negative = ~mask_positive
drug1_feat[mask_negative] = get_FP_by_negative_indices(drug1_idx[mask_negative], self.drug2FP_df)

mask_positive = (drug2_idx >= 0)
x_elements_positive = (node_indices.unsqueeze(-1) == drug2_idx[mask_positive]).nonzero(as_tuple=True)[0]
drug2_feat[mask_positive] = x[x_elements_positive]
if self.gpu_id is not None: if self.gpu_id is not None:
drug1_feat = drug1_feat.cuda(self.gpu_id) drug1_feat = drug1_feat.cuda(self.gpu_id)
drug2_feat = drug2_feat.cuda(self.gpu_id) drug2_feat = drug2_feat.cuda(self.gpu_id)

print("first: ", time.time() - start_time)
start_time = time.time()

for i, x in enumerate(drug1_idx): for i, x in enumerate(drug1_idx):
if x < 0: if x < 0:
drug1_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df) drug1_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df)
for i, x in enumerate(drug2_idx): for i, x in enumerate(drug2_idx):
if x < 0: if x < 0:
drug2_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df) drug2_feat[i] = get_FP_by_negative_index(x,self.drug2FP_df)
print("second: ", time.time() - start_time)
feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1) feat = torch.cat([drug1_feat, drug2_feat, cell_feat], 1)
return feat return feat


# prev input: self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor, subgraph: related subgraph for the batch # prev input: self, drug1_feat: torch.Tensor, drug2_feat: torch.Tensor, cell_feat: torch.Tensor, subgraph: related subgraph for the batch
def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph): def forward(self, drug1_idx, drug2_idx, cell_feat, subgraph):
start = time.time()
feat = self.connector(drug1_idx, drug2_idx, cell_feat, subgraph) feat = self.connector(drug1_idx, drug2_idx, cell_feat, subgraph)
print("Connector forward time: ", time.time() - start)
out = self.layers(feat) out = self.layers(feat)
return out return out



+ 15
- 0
predictor/model/utils.py View File



array = np.array(list(drug2FP_df.iloc[-index])[1:]) array = np.array(list(drug2FP_df.iloc[-index])[1:])
return torch.tensor(array, dtype=torch.float32) return torch.tensor(array, dtype=torch.float32)


def get_FP_by_negative_indices(indices, drug2FP_df = None):
# converting negative indices to real indices
indices *= -1
print(indices)
print(type(drug2FP_df))

if drug2FP_df is None or drug2FP_df.empty:
print("load Drug2FP_FILE")
drug2FP_df = pd.read_csv(Drug2FP_FILE)

# array = np.array(list(drug2FP_df.iloc[[indices]])[1:])
print(drug2FP_df.iloc[[indices]])
# return torch.tensor(array, dtype=torch.float32)

Loading…
Cancel
Save