Browse Source

Merge pull request #1 from mahsayazdani/gpu-specific

Merge gpu-specific to main
main
Mahsa Yazdani 1 year ago
parent
commit
19efc41fde
No account linked to committer's email address
4 changed files with 17 additions and 14 deletions
  1. 10
    5
      drug/datasets.py
  2. 1
    5
      drug/models.py
  3. 4
    1
      predictor/cross_validation.py
  4. 2
    3
      predictor/model/models.py

+ 10
- 5
drug/datasets.py View File

@@ -9,9 +9,11 @@ import os


class DDInteractionDataset(Dataset):
def __init__(self, root = "\\drug/data/", transform=None, pre_transform=None, pre_filter=None):
def __init__(self, root = "\\drug/data/", transform=None, pre_transform=None, pre_filter=None, gpu_id=None):
self.gpu_id = gpu_id
super(DDInteractionDataset, self).__init__(os.path.dirname(os.path.abspath(os.path.dirname( __file__ ))) + "/drug/data/", transform, pre_transform, pre_filter)


@property
def num_features(self):
return self._num_features
@@ -88,6 +90,9 @@ class DDInteractionDataset(Dataset):

# ---------------------------------------------------------------
data = Data(x = node_features, edge_index = edge_index)
if self.gpu_id is not None:
data = data.cuda(self.gpu_id)

if self.pre_filter is not None and not self.pre_filter(data):
pass
@@ -105,8 +110,8 @@ class DDInteractionDataset(Dataset):
data = torch.load(osp.join(self.processed_dir, 'ddi_graph_dataset.pt'))
return data

ddiDataset = DDInteractionDataset(root = "drug/data/")
print(ddiDataset.get().edge_index.t())
# run for checking
# ddiDataset = DDInteractionDataset(root = "drug/data/")
# print(ddiDataset.get().edge_index.t())
# print(ddiDataset.get().x)
print(ddiDataset.num_features)
# print(ddiDataset.num_features)

+ 1
- 5
drug/models.py View File

@@ -19,21 +19,17 @@ from torch_geometric.nn import GCNConv

# base from this notebook: https://colab.research.google.com/drive/1LJir3T6M6Omc2Vn2GV2cDW_GV2YfI53_?usp=sharing#scrollTo=jNsToorfSgS0
class GCN(torch.nn.Module):
def __init__(self, num_features, hidden_channels, gpu_id=None): # num_features = dataset.num_features
def __init__(self, num_features, hidden_channels): # num_features = dataset.num_features
super(GCN, self).__init__()
torch.manual_seed(42)

# Initialize the layers
self.conv1 = GCNConv(num_features, hidden_channels)
self.conv2 = GCNConv(hidden_channels, num_features)
self.gpu_id = gpu_id

def forward(self, x, edge_index):
# First Message Passing Layer (Transformation)
x = x.to(torch.float32)
if self.gpu_id is not None:
x = x.cuda(self.gpu_id)
edge_index = edge_index.cuda(self.gpu_id)
x = self.conv1(x, edge_index)
x = x.relu()
x = F.dropout(x, p=0.5, training=self.training)

+ 4
- 1
predictor/cross_validation.py View File

@@ -36,7 +36,8 @@ def step_batch(model, batch, loss_func, gpu_id=None, train=True):
drug1_id = drug1_id.cuda(gpu_id)
drug2_id = drug2_id.cuda(gpu_id)
cell_feat = cell_feat.cuda(gpu_id)
pass
y_true = y_true.cuda(gpu_id)

if train:
y_pred = model(drug1_id, drug2_id, cell_feat)
else:
@@ -155,6 +156,8 @@ def cv(args, out_dir):
logging.info("Inner loop completed. Mean valid loss: {:.4f}".format(inner_loss))
logging.info("-" * n_delimiter)
losses.append(inner_loss)

torch.cuda.memory_summary(device=None, abbreviated=False)
gc.collect()
torch.cuda.empty_cache()
time.sleep(10)

+ 2
- 3
predictor/model/models.py View File

@@ -16,9 +16,8 @@ from model.utils import get_FP_by_negative_index
class Connector(nn.Module):
def __init__(self, gpu_id=None):
super(Connector, self).__init__()

self.ddiDataset = DDInteractionDataset()
self.gcn = GCN(self.ddiDataset.num_features, self.ddiDataset.num_features // 2, gpu_id)
self.ddiDataset = DDInteractionDataset(gpu_id = gpu_id)
self.gcn = GCN(self.ddiDataset.num_features, self.ddiDataset.num_features // 2)
#Cell line features
# np.load('cell_feat.npy')

Loading…
Cancel
Save