123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120 |
- import numpy as np
- import torch
-
- import random
-
- from torch.utils.data import Dataset
- from .utils import read_map, get_index_by_name
-
class FastTensorDataLoader:
    """
    A DataLoader-like object for a set of tensors that can be much faster than
    TensorDataset + DataLoader because dataloader grabs individual indices of
    the dataset and calls cat (slow), whereas this class slices whole batches
    out of the stored tensors.
    Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6
    """

    def __init__(self, *tensors, batch_size=32, shuffle=False):
        """
        Initialize a FastTensorDataLoader.

        :param *tensors: tensors to store. Must have the same length @ dim 0.
        :param batch_size: batch size to load. Must be positive.
        :param shuffle: if True, re-permute the stored tensors (all with the
            same permutation) whenever an iterator is created out of this
            object. The tensors passed in by the caller are left untouched;
            only ``self.tensors`` is rebound to the permuted copies.
        :raises ValueError: if no tensor is given, if the tensors differ in
            length @ dim 0, or if batch_size is not positive.
        :returns: A FastTensorDataLoader.
        """
        # Explicit raises instead of ``assert`` so the checks survive ``-O``.
        if not tensors:
            raise ValueError("FastTensorDataLoader needs at least one tensor")
        dataset_len = tensors[0].shape[0]
        if any(t.shape[0] != dataset_len for t in tensors):
            raise ValueError("all tensors must have the same length at dim 0")
        # A zero batch size would divide by zero and a negative one would make
        # __next__ loop forever, so reject both up front.
        if batch_size <= 0:
            raise ValueError("batch_size must be positive")

        self.tensors = tensors
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Number of batches, counting a trailing partial batch (ceil division).
        self.n_batches = (dataset_len + batch_size - 1) // batch_size

        # Cursor into the data; defined here so next() works even when the
        # caller has not gone through iter() first.
        self.i = 0

    def __iter__(self):
        """
        Start a new pass over the data.

        Resets the cursor and, when shuffling is enabled, applies one fresh
        random permutation to every stored tensor so rows stay aligned.
        :returns: this object, which acts as its own iterator.
        """
        if self.shuffle:
            perm = torch.randperm(self.dataset_len)
            self.tensors = [t[perm] for t in self.tensors]
        self.i = 0
        return self

    def __next__(self):
        """
        Return the next batch.

        :returns: a tuple with one slice per stored tensor, each holding at
            most batch_size rows (the last batch may be shorter).
        :raises StopIteration: once every row has been served.
        """
        if self.i >= self.dataset_len:
            raise StopIteration
        stop = self.i + self.batch_size
        batch = tuple(t[self.i:stop] for t in self.tensors)
        self.i = stop
        return batch

    def __len__(self):
        """Return the number of batches produced by one full pass."""
        return self.n_batches
-
class FastSynergyDataset(Dataset):
    """
    In-memory dataset of drug-pair / cell-line synergy samples.

    Each entry of ``samples`` is a list of four tensors:
    ``[drug1 index, drug2 index, cell feature vector, score]``.
    ``raw_samples`` holds the matching ``[drug1 index, drug2 index, cell id,
    score]`` records, where the score is kept as the string read from the file.
    """

    def __init__(self, drugname2drugbankid_file, cell2id_file, cell_feat_file, synergy_score_file, use_folds,
                 train=True):
        """
        Load the lookup tables and build every sample eagerly.

        :param drugname2drugbankid_file: file handed to ``read_map`` (with
            ``keep_str=True``); the keys of the resulting mapping are the drug
            names accepted from the score file.
        :param cell2id_file: file handed to ``read_map``; maps a cell name to
            the index used to pick its row from the cell feature array.
        :param cell_feat_file: array file loaded with ``np.load``; indexed by
            cell id along its first axis.
        :param synergy_score_file: tab-separated file whose first line is
            skipped as a header and whose rows are
            ``drug1, drug2, cell name, score, fold``.
        :param use_folds: container of integer fold ids; only rows whose fold
            is in it are kept.
        :param train: if True, every kept row is also added a second time with
            the two drugs swapped.
        """
        self.drug2id = read_map(drugname2drugbankid_file, keep_str=True)
        self.cell2id = read_map(cell2id_file)
        self.cell_feat = np.load(cell_feat_file)
        self.samples = []
        self.raw_samples = []
        self.train = train
        valid_drugs = set(self.drug2id.keys())
        valid_cells = set(self.cell2id.keys())
        with open(synergy_score_file, 'r') as f:
            f.readline()  # skip the header row
            for line in f:
                # A blank line (e.g. at the end of the file) cannot be
                # unpacked into five fields, so skip it instead of crashing.
                if not line.strip():
                    continue
                drug1, drug2, cellname, score, fold = line.rstrip().split('\t')
                if drug1 not in valid_drugs or drug2 not in valid_drugs or cellname not in valid_cells:
                    continue
                if int(fold) not in use_folds:
                    continue
                drug1_id = get_index_by_name(drug1)
                drug2_id = get_index_by_name(drug2)
                cell_id = self.cell2id[cellname]
                self._add_sample(drug1_id, drug2_id, cell_id, score)
                if train:
                    # Also store the pair with the drugs in the other order.
                    self._add_sample(drug2_id, drug1_id, cell_id, score)

    def _add_sample(self, first_id, second_id, cell_id, score):
        """Append one tensor sample and its raw record for an ordered drug pair."""
        # TODO: specify drug_feat
        # drug1_feat + drug2_feat + cell_feat + score
        self.samples.append([
            torch.IntTensor([first_id]),
            torch.IntTensor([second_id]),
            torch.from_numpy(self.cell_feat[cell_id]).float(),
            torch.FloatTensor([float(score)]),
        ])
        # NOTE: the raw record keeps the score exactly as read (a string).
        self.raw_samples.append([first_id, second_id, cell_id, score])

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    def __getitem__(self, item):
        """Return the sample (list of four tensors) stored at position ``item``."""
        return self.samples[item]

    def cell_feat_len(self):
        """Return the size of the last axis of the cell feature array."""
        return self.cell_feat.shape[-1]

    def tensor_samples(self, indices=None):
        """
        Stack the selected samples into four batch tensors.

        :param indices: iterable of sample positions; ``None`` selects every
            sample in storage order.
        :returns: ``(d1, d2, c, y)``, each the per-sample tensors stacked
            along a new leading dimension.
        :raises RuntimeError: from ``torch.stack`` when the selection is
            empty, since it requires at least one tensor.
        """
        if indices is None:
            indices = range(len(self))
        # Resolve the selection once so one-shot iterables are supported.
        picked = [self.samples[i] for i in indices]
        d1 = torch.stack([sample[0] for sample in picked], dim=0)
        d2 = torch.stack([sample[1] for sample in picked], dim=0)
        c = torch.stack([sample[2] for sample in picked], dim=0)
        y = torch.stack([sample[3] for sample in picked], dim=0)
        return d1, d2, c, y
|