You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

data.py 13KB


  1. from datasets import load_from_disk
  2. import torch
  3. from torch.utils.data import DataLoader, WeightedRandomSampler
  4. import copy
  5. import sys
  6. import torch
  7. from torch.utils.data.dataset import Dataset
  8. from transformers.tokenization_utils import PreTrainedTokenizer
  9. from dataclasses import dataclass
  10. from typing import Any, Callable, Dict, List, NewType, Tuple, Union
  11. from torch.nn.utils.rnn import pad_sequence
  12. from transformers.tokenization_utils import PreTrainedTokenizer
  13. from transformers.tokenization_utils_base import BatchEncoding
  14. def load_dataset(dataset_name, path, toy_example):
  15. dataset = load_from_disk(f"{path}saved_datasets/{dataset_name}")
  16. # toy example for develop
  17. if toy_example == 1:
  18. dataset["train"] = dataset["train"].select(range(1024))
  19. dataset["validation"] = dataset["validation"].select(range(512))
  20. return dataset
  21. def load_dataloaders(dataset, dataset_name, batch_size, virtual_batch_size, tokenizer, seq_length, dp=1):
  22. data_collator = DataCollatorForData2TextLanguageModeling(tokenizer)
  23. if dataset_name == 'e2e_nlg':
  24. train_dataset = E2ETextDataset(tokenizer,
  25. dataset["train"]["meaning_representation"],
  26. dataset["train"]["human_reference"],
  27. seq_length,
  28. tokenizer.bos_token,
  29. tokenizer.eos_token,
  30. seq_length)
  31. validation_dataset = E2ETextDataset(tokenizer,
  32. dataset["validation"]["meaning_representation"],
  33. dataset["validation"]["human_reference"],
  34. seq_length,
  35. tokenizer.bos_token,
  36. tokenizer.eos_token,
  37. seq_length)
  38. train_data_size = len(dataset["train"])
  39. if dp == 1:
  40. sampler = WeightedRandomSampler([virtual_batch_size/train_data_size for _ in range(train_data_size)], num_samples=train_data_size, replacement=True)
  41. train_loader = DataLoader(train_dataset, batch_size=virtual_batch_size, sampler=sampler, drop_last=True, collate_fn=data_collator)
  42. else:
  43. train_loader = DataLoader(train_dataset, batch_size=virtual_batch_size, collate_fn=data_collator)
  44. validation_loader = DataLoader(validation_dataset, batch_size=batch_size, collate_fn=data_collator)
  45. elif dataset_name == 'dart':
  46. pass
  47. return train_loader, validation_loader
  48. # Copyright (c) Xuechen Li. All Rights Reserved.
  49. #
  50. # Licensed under the Apache License, Version 2.0 (the "License");
  51. # you may not use this file except in compliance with the License.
  52. # You may obtain a copy of the License at
  53. #
  54. # http://www.apache.org/licenses/LICENSE-2.0
  55. #
  56. # Unless required by applicable law or agreed to in writing, software
  57. # distributed under the License is distributed on an "AS IS" BASIS,
  58. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  59. # See the License for the specific language governing permissions and
  60. # limitations under the License.
  61. class E2ETextDataset(Dataset):
  62. def __init__(
  63. self,
  64. tokenizer: PreTrainedTokenizer,
  65. src_lines,
  66. tgt_lines,
  67. block_size: int,
  68. bos_tok: str,
  69. eos_tok: str,
  70. max_seq_len=sys.maxsize,
  71. max_examples=sys.maxsize,
  72. **_,
  73. ):
  74. src_lines = src_lines
  75. tgt_lines = tgt_lines
  76. edited_sents = []
  77. for src, tgt in zip(src_lines, tgt_lines):
  78. sent = ' {} {} '.format(src, bos_tok) + tgt + ' {}'.format(eos_tok)
  79. edited_sents.append(sent)
  80. # --- Filter out super long sentences ---
  81. new_src_lines, new_tgt_lines, new_edited_sents = [], [], []
  82. for src_line, tgt_line, edited_sent in zip(src_lines, tgt_lines, edited_sents):
  83. tokenized_edited_sent = tokenizer.tokenize(edited_sent)
  84. if len(tokenized_edited_sent) <= max_seq_len:
  85. new_src_lines.append(src_line)
  86. new_tgt_lines.append(tgt_line)
  87. new_edited_sents.append(edited_sent)
  88. del src_line, tgt_line, edited_sent
  89. src_lines, tgt_lines, edited_sents = new_src_lines, new_tgt_lines, new_edited_sents
  90. # ---------------------------------------
  91. # --- Truncate the dataset if necessary; this must be after the length filtering. ---
  92. src_lines = src_lines[:max_examples]
  93. tgt_lines = tgt_lines[:max_examples]
  94. edited_sents = edited_sents[:max_examples]
  95. # ---
  96. batch_encoding = tokenizer(
  97. edited_sents,
  98. add_special_tokens=True,
  99. truncation=True,
  100. max_length=block_size,
  101. is_split_into_words=False,
  102. )
  103. self.examples = batch_encoding["input_ids"]
  104. self.labels = copy.deepcopy(self.examples)
  105. # split into category words:
  106. ssl_lst = []
  107. for ss in src_lines:
  108. ssl = [la.split(':')[0].strip() for la in ss.split('|')]
  109. ssl_lst.append(ssl)
  110. self.src_cat = tokenizer(
  111. ssl_lst,
  112. add_special_tokens=True,
  113. truncation=True,
  114. max_length=block_size,
  115. is_split_into_words=True
  116. )['input_ids']
  117. self.src_sent = []
  118. self.tgt_sent = []
  119. # temp_src_len = 0
  120. # temp_tgt_len = 0
  121. # temp_count = 0
  122. separator = tokenizer(bos_tok, add_special_tokens=False)['input_ids'][0]
  123. for i, elem in enumerate(self.labels):
  124. sep_idx = elem.index(separator) + 1
  125. self.src_sent.append(self.examples[i][:sep_idx - 1])
  126. self.tgt_sent.append(self.examples[i][sep_idx - 1:])
  127. self.labels[i][:sep_idx] = [-100] * sep_idx # Doesn't contribute to loss.
  128. # temp_src_len += sep_idx - 1
  129. # temp_tgt_len += len(elem) - (sep_idx - 1)
  130. # temp_count += 1
  131. # print('tgt_avg: ', temp_tgt_len / temp_count)
  132. # print('src_avg: ', temp_src_len / temp_count)
  133. # print('ratios: ', temp_src_len / temp_tgt_len)
  134. # print(self.labels[0])
  135. # print(self.examples[0])
  136. # print(edited_sents[0])
  137. # print(self.src_sent[0])
  138. # print(self.tgt_sent[0])
  139. # print(self.src_cat[0])
  140. assert len(self.src_cat) == len(self.examples)
  141. def __len__(self):
  142. return len(self.examples)
  143. def __getitem__(self, i):
  144. return (
  145. torch.tensor(self.examples[i], dtype=torch.long),
  146. torch.tensor(self.labels[i], dtype=torch.long),
  147. torch.tensor(self.src_sent[i], dtype=torch.long),
  148. torch.tensor(self.tgt_sent[i], dtype=torch.long),
  149. torch.tensor(self.src_cat[i], dtype=torch.long),
  150. )
  151. # InputDataClass = NewType("InputDataClass", Any)
  152. """
  153. A DataCollator is a function that takes a list of samples from a Dataset
  154. and collate them into a batch, as a dictionary of Tensors.
  155. """
  156. # DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, torch.Tensor]])
  157. @dataclass
  158. class DataCollatorForData2TextLanguageModeling:
  159. """
  160. Data collator used for language modeling.
  161. - collates batches of tensors, honoring their tokenizer's pad_token
  162. - preprocesses batches for masked language modeling
  163. """
  164. tokenizer: PreTrainedTokenizer
  165. mlm: bool = False
  166. format_mode: str = 'cat'
  167. mlm_probability: float = 0.15
  168. def __call__(
  169. self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
  170. ) -> Dict[str, torch.Tensor]:
  171. if isinstance(examples[0], (dict, BatchEncoding)):
  172. examples = [e["input_ids"] for e in examples]
  173. input_ids, labels, src, tgt, cate = zip(*examples)
  174. if self.mlm:
  175. inputs, labels = self.mask_tokens(batch)
  176. return {"input_ids": inputs, "labels": labels}
  177. else:
  178. if self.format_mode == 'cat':
  179. mode_input = 3
  180. elif self.format_mode == 'peek':
  181. mode_input = 1
  182. elif self.format_mode == 'nopeek':
  183. mode_input = 2
  184. elif self.format_mode == 'infix':
  185. mode_input = 4
  186. # mode_input = 1 # means that we take the input again.
  187. # mode_input = 2 # means that we do not peek at src again.
  188. # mode_input = 3 # means that we look at the categories, and see the input again.
  189. if mode_input == 1:
  190. # input, batch
  191. batch = self._tensorize_batch(input_ids)
  192. labels = self._tensorize_batch(labels)
  193. src = self._tensorize_batch(src)
  194. cate_batch, cate_attn = None, None
  195. # tgt = self._tensorize_batch(tgt)
  196. elif mode_input == 2:
  197. # nopeek.
  198. batch = self._tensorize_batch(tgt)
  199. labels = batch.clone()
  200. src = self._tensorize_batch(src)
  201. cate_batch, cate_attn = None, None
  202. elif mode_input == 3:
  203. batch = self._tensorize_batch(input_ids)
  204. labels = self._tensorize_batch(labels)
  205. src = self._tensorize_batch(cate)
  206. cate_batch, cate_attn = None, None
  207. elif mode_input == 4:
  208. batch = self._tensorize_batch(tgt)
  209. labels = batch.clone()
  210. src = self._tensorize_batch(src)
  211. cate_batch = self._tensorize_batch(cate)
  212. cate_attn = (cate_batch != self.tokenizer.pad_token_id)
  213. labels[labels == self.tokenizer.pad_token_id] = -100 # tgt
  214. src_attn = (src != self.tokenizer.pad_token_id) # src
  215. tgt_attn = (batch != self.tokenizer.pad_token_id) # tgt
  216. if cate_batch is None:
  217. return {"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn':tgt_attn,
  218. 'src':src}
  219. else:
  220. return {"input_ids": batch, "labels": labels, 'src_attn': src_attn, 'tgt_attn': tgt_attn,
  221. 'src': src, "cate_batch":cate_batch, "cate_attn":cate_attn}
  222. def _tensorize_batch(
  223. self, examples: List[Union[List[int], torch.Tensor, Dict[str, torch.Tensor]]]
  224. ) -> torch.Tensor:
  225. # In order to accept both lists of lists and lists of Tensors
  226. if isinstance(examples[0], (list, tuple)):
  227. examples = [torch.tensor(e, dtype=torch.long) for e in examples]
  228. length_of_first = examples[0].size(0)
  229. are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
  230. if are_tensors_same_length:
  231. return torch.stack(examples, dim=0)
  232. else:
  233. if self.tokenizer._pad_token is None:
  234. raise ValueError(
  235. "You are attempting to pad samples but the tokenizer you are using"
  236. f" ({self.tokenizer.__class__.__name__}) does not have one."
  237. )
  238. return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
  239. def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  240. """
  241. Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
  242. """
  243. if self.tokenizer.mask_token is None:
  244. raise ValueError(
  245. "This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
  246. )
  247. labels = inputs.clone()
  248. # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
  249. probability_matrix = torch.full(labels.shape, self.mlm_probability)
  250. special_tokens_mask = [
  251. self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
  252. ]
  253. probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
  254. if self.tokenizer._pad_token is not None:
  255. padding_mask = labels.eq(self.tokenizer.pad_token_id)
  256. probability_matrix.masked_fill_(padding_mask, value=0.0)
  257. masked_indices = torch.bernoulli(probability_matrix).bool()
  258. labels[~masked_indices] = -100 # We only compute loss on masked tokens
  259. # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
  260. indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
  261. inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
  262. # 10% of the time, we replace masked input tokens with random word
  263. indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
  264. random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
  265. inputs[indices_random] = random_words[indices_random]
  266. # The rest of the time (10% of the time) we keep the masked input tokens unchanged
  267. return inputs, labels