
datasets.py

import numpy as np
import torch
import random
from torch.utils.data import Dataset
from .utils import read_map, get_index_by_name, read_files, get_DTI_data, get_FP_data


class FastTensorDataLoader:
    """
    A DataLoader-like object for a set of tensors that can be much faster than
    TensorDataset + DataLoader because DataLoader grabs individual indices of
    the dataset and calls cat (slow).
    Source: https://discuss.pytorch.org/t/dataloader-much-slower-than-manual-batching/27014/6
    """

    def __init__(self, *tensors, batch_size=32, shuffle=False):
        """
        Initialize a FastTensorDataLoader.

        :param *tensors: tensors to store. Must have the same length at dim 0.
        :param batch_size: batch size to load.
        :param shuffle: if True, shuffle the data *in-place* whenever an
            iterator is created out of this object.
        :returns: A FastTensorDataLoader.
        """
        assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
        self.tensors = tensors
        self.dataset_len = self.tensors[0].shape[0]
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Calculate the number of batches (round up for a final partial batch).
        n_batches, remainder = divmod(self.dataset_len, self.batch_size)
        if remainder > 0:
            n_batches += 1
        self.n_batches = n_batches

    def __iter__(self):
        if self.shuffle:
            r = torch.randperm(self.dataset_len)
            self.tensors = [t[r] for t in self.tensors]
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.dataset_len:
            raise StopIteration
        batch = tuple(t[self.i:self.i + self.batch_size] for t in self.tensors)
        self.i += self.batch_size
        return batch

    def __len__(self):
        return self.n_batches
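

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how FastTensorDataLoader is meant to be iterated, using
# synthetic tensors; the shapes and the _demo_fast_tensor_loader name are
# assumptions added for demonstration only.
def _demo_fast_tensor_loader():
    x = torch.randn(100, 16)   # 100 samples with 16 arbitrary features
    y = torch.randn(100, 1)    # matching targets
    loader = FastTensorDataLoader(x, y, batch_size=32, shuffle=True)
    for xb, yb in loader:      # each batch is a tuple of aligned tensor slices
        assert xb.shape[0] == yb.shape[0]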


class FastSynergyDataset(Dataset):

    def __init__(self, drug2id_file, cell2id_file, cell_feat_file, synergy_score_file,
                 dti_feat_file, fp_feat_file, use_folds, train=True):
        self.drug2id = read_map(drug2id_file, keep_str=True)
        self.cell2id = read_map(cell2id_file)
        self.cell_feat = np.load(cell_feat_file)
        self.samples = []
        self.raw_samples = []
        self.train = train
        dti = get_DTI_data(dti_feat_file)
        fp = get_FP_data(fp_feat_file)
        valid_drugs = set(self.drug2id.keys())
        valid_cells = set(self.cell2id.keys())
        files_dict = read_files()
        with open(synergy_score_file, 'r') as f:
            f.readline()  # skip the header line
            for line in f:
                drug1, drug2, cellname, score, fold = line.rstrip().split('\t')
                # fold = random.randint(0, 4)
                if drug1 in valid_drugs and drug2 in valid_drugs and cellname in valid_cells:
                    if int(fold) in use_folds:
                        drug1_id = get_index_by_name(drug1, files_dict)
                        drug2_id = get_index_by_name(drug2, files_dict)
                        # Sample layout: drug1 id, drug2 id, drug1/drug2 fingerprint,
                        # drug1/drug2 DTI features, cell-line feature, synergy score.
                        sample = [
                            torch.IntTensor([drug1_id]),
                            torch.IntTensor([drug2_id]),
                            fp[drug1],
                            fp[drug2],
                            dti[drug1],
                            dti[drug2],
                            torch.from_numpy(self.cell_feat[self.cell2id[cellname]]).float(),
                            torch.FloatTensor([float(score)]),
                        ]
                        self.samples.append(sample)
                        raw_sample = [drug1, drug2, self.cell2id[cellname], score]
                        self.raw_samples.append(raw_sample)
                        if train:
                            # For training, also add the pair with the drug order
                            # reversed (simple symmetry augmentation).
                            sample = [
                                torch.IntTensor([drug2_id]),
                                torch.IntTensor([drug1_id]),
                                fp[drug2],
                                fp[drug1],
                                dti[drug2],
                                dti[drug1],
                                torch.from_numpy(self.cell_feat[self.cell2id[cellname]]).float(),
                                torch.FloatTensor([float(score)]),
                            ]
                            self.samples.append(sample)
                            raw_sample = [drug2, drug1, self.cell2id[cellname], score]
                            self.raw_samples.append(raw_sample)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, item):
        return self.samples[item]

    def cell_feat_len(self):
        return self.cell_feat.shape[-1]

    def tensor_samples(self, indices=None):
        # Stack the selected samples field-by-field into batched tensors.
        if indices is None:
            indices = list(range(len(self)))
        i1 = torch.cat([torch.unsqueeze(self.samples[i][0], 0) for i in indices], dim=0)
        i2 = torch.cat([torch.unsqueeze(self.samples[i][1], 0) for i in indices], dim=0)
        d1 = torch.cat([torch.unsqueeze(self.samples[i][2], 0) for i in indices], dim=0)
        d2 = torch.cat([torch.unsqueeze(self.samples[i][3], 0) for i in indices], dim=0)
        t1 = torch.cat([torch.unsqueeze(self.samples[i][4], 0) for i in indices], dim=0)
        t2 = torch.cat([torch.unsqueeze(self.samples[i][5], 0) for i in indices], dim=0)
        c = torch.cat([torch.unsqueeze(self.samples[i][6], 0) for i in indices], dim=0)
        y = torch.cat([torch.unsqueeze(self.samples[i][7], 0) for i in indices], dim=0)
        return i1, i2, d1, d2, t1, t2, c, y
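

# --- Illustrative usage sketch (not part of the original module) ---
# Shows how FastSynergyDataset.tensor_samples() can feed FastTensorDataLoader.
# Every path below is a placeholder assumption, not a file known to ship with
# this repository; because of the relative import at the top, this block only
# runs when the module is executed as part of its package (e.g. via `python -m`).
if __name__ == "__main__":
    train_data = FastSynergyDataset(
        "data/drug2id.tsv",    # assumed drug -> id map
        "data/cell2id.tsv",    # assumed cell line -> id map
        "data/cell_feat.npy",  # assumed cell-line feature matrix
        "data/synergy.tsv",    # assumed tab-separated: drug1, drug2, cell, score, fold
        "data/dti_feat.tsv",   # assumed drug-target interaction features
        "data/fp_feat.tsv",    # assumed drug fingerprint features
        use_folds=[0, 1, 2, 3],
        train=True,
    )
    # tensor_samples() stacks all stored samples into eight aligned tensors,
    # which FastTensorDataLoader then slices directly into batches.
    loader = FastTensorDataLoader(*train_data.tensor_samples(), batch_size=128, shuffle=True)
    for i1, i2, d1, d2, t1, t2, c, y in loader:
        pass  # a model forward/backward pass would go here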