import numpy as np
import torch
import random
from torch.utils.data import Dataset
from .utils import read_map, get_index_by_name
class FastTensorDataLoader:
    """
    A DataLoader-like object for a set of tensors that can be much faster than
    TensorDataset + DataLoader because dataloader grabs individual indices of
    the dataset and calls cat (slow).
    Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6

    The object is its own iterator: ``iter(loader)`` resets the read position
    (and reshuffles when ``shuffle`` is set) and returns the loader itself, so
    only one iteration over a given instance can be in progress at a time.
    """

    def __init__(self, *tensors, batch_size=32, shuffle=False):
        """
        Initialize a FastTensorDataLoader.

        :param *tensors: tensors to store. Must have the same length @ dim 0.
        :param batch_size: batch size to load. Must be positive.
        :param shuffle: if True, shuffle the data *in-place* whenever an
            iterator is created out of this object.
        :raises ValueError: if no tensors are given, if ``batch_size`` is not
            positive, or if the tensors differ in length along dim 0.
        :returns: A FastTensorDataLoader.
        """
        # Explicit checks instead of ``assert``: assertions are removed under
        # ``python -O``, and a non-positive batch size would otherwise divide
        # by zero or make ``__next__`` loop forever.
        if not tensors:
            raise ValueError("FastTensorDataLoader requires at least one tensor")
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size!r}")
        dataset_len = tensors[0].shape[0]
        if any(t.shape[0] != dataset_len for t in tensors):
            raise ValueError("all tensors must have the same length along dim 0")

        self.tensors = tensors
        self.dataset_len = dataset_len
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Number of batches, counting a final partial batch if there is one.
        full_batches, remainder = divmod(self.dataset_len, self.batch_size)
        self.n_batches = full_batches + 1 if remainder > 0 else full_batches

    def __iter__(self):
        """Reset the read position, reshuffling the stored tensors if requested."""
        if self.shuffle:
            # One permutation is applied to every tensor so rows stay aligned.
            r = torch.randperm(self.dataset_len)
            self.tensors = [t[r] for t in self.tensors]
        self.i = 0
        return self

    def __next__(self):
        """Return the next batch as a tuple with one slice per stored tensor."""
        if self.i >= self.dataset_len:
            raise StopIteration
        stop = self.i + self.batch_size
        batch = tuple(t[self.i:stop] for t in self.tensors)
        self.i = stop
        return batch

    def __len__(self):
        """Return the number of batches per pass (last one may be partial)."""
        return self.n_batches
class FastSynergyDataset(Dataset):
    """
    In-memory dataset of drug-pair / cell-line synergy samples.

    Each entry of ``samples`` is a list of four tensors:
    ``[drug1 index (IntTensor, 1 element), drug2 index (IntTensor, 1 element),
    cell feature row (float tensor), score (FloatTensor, 1 element)]``.
    ``raw_samples`` holds the matching ``[drug1_id, drug2_id, cell_id, score]``
    lists, where ``score`` is kept as the string read from the file.
    """

    def __init__(self, drugname2drugbankid_file, cell2id_file, cell_feat_file, synergy_score_file, use_folds,
                 train=True):
        """
        Load the lookup tables and build the samples for the requested folds.

        :param drugname2drugbankid_file: file passed to ``read_map`` (with
            ``keep_str=True``); its keys define the set of accepted drug names.
        :param cell2id_file: file passed to ``read_map``; maps a cell name to
            the row index used to look up ``cell_feat``.
        :param cell_feat_file: ``.npy`` file loaded with ``np.load``; indexed
            by cell id along its first axis.
        :param synergy_score_file: tab-separated file with one header line
            followed by ``drug1, drug2, cellname, score, fold`` rows.
        :param use_folds: container of integer fold ids; only rows whose fold
            is in it are kept.
        :param train: if True, every kept row is also added a second time with
            the two drugs swapped.
        """
        self.drug2id = read_map(drugname2drugbankid_file, keep_str=True)
        self.cell2id = read_map(cell2id_file)
        self.cell_feat = np.load(cell_feat_file)
        self.samples = []
        self.raw_samples = []
        self.train = train
        valid_drugs = set(self.drug2id.keys())
        valid_cells = set(self.cell2id.keys())
        with open(synergy_score_file, 'r') as f:
            f.readline()  # skip the header line
            for line in f:
                if not line.strip():
                    # Tolerate blank lines instead of failing the 5-way unpack.
                    continue
                drug1, drug2, cellname, score, fold = line.rstrip().split('\t')
                if drug1 not in valid_drugs or drug2 not in valid_drugs or cellname not in valid_cells:
                    continue
                if int(fold) not in use_folds:
                    continue
                # NOTE(review): get_index_by_name is defined in .utils; it is
                # assumed to return an integer index for the drug - confirm.
                drug1_id = get_index_by_name(drug1)
                drug2_id = get_index_by_name(drug2)
                cell_id = self.cell2id[cellname]
                # TODO: specify drug_feat
                # drug1_feat + drug2_feat + cell_feat + score
                self._append_sample(drug1_id, drug2_id, cell_id, score)
                if train:
                    # Also store the pair with the drugs swapped.
                    self._append_sample(drug2_id, drug1_id, cell_id, score)

    def _append_sample(self, first_id, second_id, cell_id, score):
        """Store one sample in tensor form (``samples``) and raw form (``raw_samples``)."""
        sample = [
            torch.IntTensor([first_id]),
            torch.IntTensor([second_id]),
            torch.from_numpy(self.cell_feat[cell_id]).float(),
            torch.FloatTensor([float(score)]),
        ]
        self.samples.append(sample)
        self.raw_samples.append([first_id, second_id, cell_id, score])

    def __len__(self):
        """Return the number of stored samples (swapped duplicates included)."""
        return len(self.samples)

    def __getitem__(self, item):
        """Return the ``[drug1, drug2, cell_feat, score]`` tensor list at ``item``."""
        return self.samples[item]

    def cell_feat_len(self):
        """Return the size of the last axis of the cell feature array."""
        return self.cell_feat.shape[-1]

    def tensor_samples(self, indices=None):
        """
        Stack the selected samples into four batched tensors.

        :param indices: iterable of sample positions; ``None`` selects every
            sample in order. One-shot iterables (e.g. generators) are allowed.
        :returns: tuple ``(d1, d2, c, y)`` where each tensor stacks the
            corresponding sample component along a new leading dimension.
        """
        if indices is None:
            indices = range(len(self))
        # Materialize the selection once so the four stacks see the same rows
        # even when ``indices`` can only be iterated a single time.
        picked = [self.samples[i] for i in indices]
        d1 = torch.stack([s[0] for s in picked], dim=0)
        d2 = torch.stack([s[1] for s in picked], dim=0)
        c = torch.stack([s[2] for s in picked], dim=0)
        y = torch.stack([s[3] for s in picked], dim=0)
        return d1, d2, c, y