import numpy as np
import torch
import random
from torch.utils.data import Dataset
from .utils import read_map, get_index_by_name


class FastTensorDataLoader:
    """
    A DataLoader-like object for a set of tensors that can be much faster than
    TensorDataset + DataLoader because dataloader grabs individual indices of
    the dataset and calls cat (slow).
    Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6
    """

    def __init__(self, *tensors, batch_size=32, shuffle=False):
        """
        Initialize a FastTensorDataLoader.

        :param *tensors: tensors to store. Must have the same length @ dim 0.
        :param batch_size: batch size to load.
        :param shuffle: if True, shuffle the data *in-place* whenever an
            iterator is created out of this object.
        :returns: A FastTensorDataLoader.
        """
        assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
        self.tensors = tensors
        self.dataset_len = self.tensors[0].shape[0]
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Number of batches, rounding up so a final partial batch is counted.
        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
        if remainder > 0:
            n_batches += 1
        self.n_batches = n_batches

    def __iter__(self):
        if self.shuffle:
            # One shared permutation keeps rows of all tensors aligned.
            # NOTE(review): this rebinds self.tensors (tuple -> list) and
            # shuffles in place, as the class docstring advertises.
            r = torch.randperm(self.dataset_len)
            self.tensors = [t[r] for t in self.tensors]
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.dataset_len:
            raise StopIteration
        # Slicing past the end is safe: the last batch is simply smaller.
        batch = tuple(t[self.i:self.i + self.batch_size] for t in self.tensors)
        self.i += self.batch_size
        return batch

    def __len__(self):
        return self.n_batches


class FastSynergyDataset(Dataset):
    """Drug-pair synergy dataset: (drug1 id, drug2 id, cell features, score)."""

    def __init__(self, drugname2drugbankid_file, cell2id_file, cell_feat_file,
                 synergy_score_file, use_folds, train=True):
        """Load synergy samples belonging to the requested CV folds.

        :param drugname2drugbankid_file: map file of drug name -> DrugBank id
            (values kept as strings).
        :param cell2id_file: map file of cell line name -> row index into the
            cell feature matrix.
        :param cell_feat_file: .npy file holding the cell-line feature matrix.
        :param synergy_score_file: TSV with a header row, then rows of
            drug1, drug2, cell line, score, fold.
        :param use_folds: collection of integer fold ids to keep.
        :param train: if True, additionally store the swapped (drug2, drug1)
            version of every sample — augmentation that presumably relies on
            synergy being symmetric in the drug pair.
        """
        self.drug2id = read_map(drugname2drugbankid_file, keep_str=True)
        self.cell2id = read_map(cell2id_file)
        self.cell_feat = np.load(cell_feat_file)
        self.samples = []
        # Plain-id view of the same samples (ids/indices instead of tensors).
        self.raw_samples = []
        self.train = train
        valid_drugs = set(self.drug2id.keys())
        valid_cells = set(self.cell2id.keys())
        with open(synergy_score_file, 'r') as f:
            f.readline()  # skip header
            for line in f:
                drug1, drug2, cellname, score, fold = line.rstrip().split('\t')
                # Keep only rows whose drugs and cell line are known and whose
                # fold is among the requested ones.
                if drug1 in valid_drugs and drug2 in valid_drugs and cellname in valid_cells:
                    if int(fold) in use_folds:
                        # NOTE(review): ids come from get_index_by_name, not
                        # from self.drug2id — confirm the two are consistent.
                        drug1_id = get_index_by_name(drug1)
                        drug2_id = get_index_by_name(drug2)
                        cell_idx = self.cell2id[cellname]
                        # Built once and shared by the swapped sample below;
                        # downstream batching copies (torch.cat), so sharing
                        # is safe as long as callers do not mutate samples.
                        cell_feat = torch.from_numpy(self.cell_feat[cell_idx]).float()
                        score_t = torch.FloatTensor([float(score)])
                        # TODO: specify drug_feat
                        # drug1_feat + drug2_feat + cell_feat + score
                        sample = [
                            torch.IntTensor([drug1_id]),
                            torch.IntTensor([drug2_id]),
                            cell_feat,
                            score_t,
                        ]
                        self.samples.append(sample)
                        self.raw_samples.append([drug1_id, drug2_id, cell_idx, score])
                        if train:
                            # Augmentation: same pair in the opposite order.
                            sample = [
                                torch.IntTensor([drug2_id]),
                                torch.IntTensor([drug1_id]),
                                cell_feat,
                                score_t,
                            ]
                            self.samples.append(sample)
                            self.raw_samples.append([drug2_id, drug1_id, cell_idx, score])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        return self.samples[item]

    def cell_feat_len(self):
        """Return the dimensionality of one cell-line feature vector."""
        return self.cell_feat.shape[-1]

    def tensor_samples(self, indices=None):
        """Stack the selected samples into four batched tensors.

        :param indices: sample indices to include; defaults to all samples.
        :returns: tuple (d1, d2, c, y) — drug1 ids, drug2 ids, cell features,
            and scores, each stacked along a new dim 0.
        """
        if indices is None:
            indices = list(range(len(self)))
        d1 = torch.cat([torch.unsqueeze(self.samples[i][0], 0) for i in indices], dim=0)
        d2 = torch.cat([torch.unsqueeze(self.samples[i][1], 0) for i in indices], dim=0)
        c = torch.cat([torch.unsqueeze(self.samples[i][2], 0) for i in indices], dim=0)
        y = torch.cat([torch.unsqueeze(self.samples[i][3], 0) for i in indices], dim=0)
        return d1, d2, c, y