A Persian grapheme-to-phoneme (G2P) model designed for homograph disambiguation, fine-tuned using the HomoRich dataset to improve pronunciation accuracy.

finetune-t5.py 12KB

# %%
import os
from dataclasses import dataclass
from typing import Dict, List, Union

import numpy as np
import pandas as pd
import torch
import evaluate
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5Config,
    T5ForConditionalGeneration,
    set_seed,
)

os.environ["WANDB_DISABLED"] = "true"
# %%
set_seed(41)

# %%
def prepare_dataset(batch):
    # Use the raw grapheme string as the model input and the mapped phoneme
    # string as the target; the data collator tokenizes both on the fly.
    batch['input_ids'] = batch['Grapheme']
    batch['labels'] = batch['Mapped Phoneme']
    return batch
# %%
# Data collator for padding
@dataclass
class DataCollatorWithPadding:
    tokenizer: AutoTokenizer
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        words = [feature["input_ids"] for feature in features]
        prons = [feature["labels"] for feature in features]
        batch = self.tokenizer(words, padding=self.padding, add_special_tokens=False, return_attention_mask=True, return_tensors='pt')
        pron_batch = self.tokenizer(prons, padding=self.padding, add_special_tokens=True, return_attention_mask=True, return_tensors='pt')
        # Replace padding positions in the labels with -100 so they are ignored by the loss
        batch['labels'] = pron_batch['input_ids'].masked_fill(pron_batch.attention_mask.ne(1), -100)
        return batch
# %%
# Compute metrics (CER and WER)
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer, "wer": wer}

# Setting up the evaluation metrics
cer_metric = evaluate.load("cer")
wer_metric = evaluate.load("wer")
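# For intuition (illustrative example, not part of the original script):
# CER is character-level edits divided by the reference length, and WER is
# word-level edits divided by the number of reference words, e.g.
#   cer_metric.compute(predictions=["hallo"], references=["hello"])                    # -> 0.2
#   wer_metric.compute(predictions=["in the kitchen"], references=["in the garden"])   # -> 1/3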
# %% [markdown]
# # Phase 1

# %%
def load_pronuncation_dictionary(path, train=True, homograph_only=False, human=False) -> Dataset:
    # Read the CSV file
    df = pd.read_csv(path, index_col=[0])
    if homograph_only:
        if human:
            df = df[df['Source'] == 'human']
        else:
            df = df[df['Source'] != 'human']
    # Drop unnecessary columns
    df = df.drop(['Source', 'Source ID'], axis=1)
    # Drop rows where 'Mapped Phoneme' is NaN
    df = df.dropna(subset=['Mapped Phoneme'])
    # Keep rows whose phoneme string is shorter than 512 characters
    phoneme_lengths = np.array([len(p) for p in df['Mapped Phoneme']])
    df = df.iloc[phoneme_lengths < 512, :]
    # Filter rows based on the 'Homograph Grapheme' column
    if homograph_only:
        df = df[df['Homograph Grapheme'].notna() & (df['Homograph Grapheme'] != '')]
    else:
        df = df[df['Homograph Grapheme'].isna() | (df['Homograph Grapheme'] == '')]
    # Shuffle the DataFrame
    df = df.sample(frac=1)
    # Split into train and test sets: the last 90 rows are held out for evaluation
    if train:
        return Dataset.from_pandas(df.iloc[:len(df) - 90, :])
    else:
        return Dataset.from_pandas(df.iloc[len(df) - 90:, :])
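# Expected columns in the CSV (inferred from the accesses above): an index
# column, 'Grapheme', 'Mapped Phoneme', 'Homograph Grapheme', 'Source', and
# 'Source ID'. `homograph_only` keeps rows that carry a homograph grapheme;
# `human` then restricts those rows by whether 'Source' equals 'human'.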
# %%
# Load datasets (Phase 1: only rows WITHOUT a 'Homograph Grapheme')
train_data = load_pronuncation_dictionary('PersianG2P_final.csv', train=True)
train_data = train_data.map(prepare_dataset)
train_dataset = train_data
dev_data = load_pronuncation_dictionary('PersianG2P_final.csv', train=False)
dev_data = dev_data.map(prepare_dataset)
dev_dataset = dev_data

# # Optionally resume from an existing checkpoint instead of training from scratch:
# checkpoint_path = "checkpoint-320"  # Path to your checkpoint
# tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
# model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)

# Initialize a small ByT5 model from scratch (random weights, reduced size)
tokenizer = AutoTokenizer.from_pretrained('google/byt5-small')
config = T5Config.from_pretrained('google/byt5-small')
config.num_decoder_layers = 2
config.num_layers = 2
config.d_kv = 64
config.d_model = 512
config.d_ff = 512
print('Initializing a ByT5 model...')
model = T5ForConditionalGeneration(config)

# Data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
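# Optional sanity check (not in the original script): collate two mapped
# training examples and inspect the padded tensor shapes.
sample_batch = data_collator([train_dataset[0], train_dataset[1]])
print({k: v.shape for k, v in sample_batch.items()})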
# Training arguments (default values)
training_args = Seq2SeqTrainingArguments(
    output_dir="./phase1-t5",  # Directory to save the fine-tuned model
    predict_with_generate=True,
    generation_num_beams=5,
    generation_max_length=512,
    evaluation_strategy="steps",
    per_device_train_batch_size=32,  # Default batch size
    per_device_eval_batch_size=100,  # Default batch size
    num_train_epochs=5,  # Fewer epochs for this step
    learning_rate=5e-4,  # Default learning rate
    warmup_steps=1000,  # Default warmup steps
    logging_steps=1000,  # Default logging steps
    save_steps=4000,  # Default save steps
    eval_steps=1000,  # Default evaluation steps
    save_total_limit=2,  # Keep only the last 2 checkpoints
    load_best_model_at_end=True,  # Load the best model at the end of training
    fp16=False,  # Disable FP16 by default
    remove_unused_columns=False,
)

# Trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    data_collator=data_collator,
)

# Fine-tune the model
trainer.train()
# Save the fine-tuned model
trainer.save_model("./phase1-t5")
# %%
import matplotlib.pyplot as plt

# Extract training and validation loss from the log history
train_loss = []
val_loss = []
for log in trainer.state.log_history:
    if "loss" in log:
        train_loss.append(log["loss"])
    if "eval_loss" in log:
        val_loss.append(log["eval_loss"])

# Plot the training and validation loss
plt.figure(figsize=(10, 6))
plt.plot(train_loss, label="Training Loss", marker="o")
plt.plot(val_loss, label="Validation Loss", marker="o")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.grid()
# Save the plot to disk
plt.savefig("phase1-t5.png")
# Optionally, close the plot to free up memory
plt.close()
# %% [markdown]
# # Phase 2

# %%
# Load datasets (Phase 2: only rows WITH a 'Homograph Grapheme', non-human sources)
train_data = load_pronuncation_dictionary('PersianG2P_final.csv',
                                          train=True,
                                          homograph_only=True)
train_data = train_data.map(prepare_dataset)
train_dataset = train_data
dev_data = load_pronuncation_dictionary('PersianG2P_final.csv',
                                        train=False,
                                        homograph_only=True)
dev_data = dev_data.map(prepare_dataset)
dev_dataset = dev_data

# Load tokenizer and model from the previous fine-tuning step
checkpoint_path = "./phase1-t5"  # Path to the model from Phase 1
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)

# Data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Training arguments (default values)
training_args = Seq2SeqTrainingArguments(
    output_dir="./phase2-t5",  # Directory to save the Phase 2 model
    predict_with_generate=True,
    generation_num_beams=5,
    generation_max_length=512,
    evaluation_strategy="steps",
    per_device_train_batch_size=32,  # Default batch size
    per_device_eval_batch_size=100,  # Default batch size
    num_train_epochs=30,  # More epochs for this step
    learning_rate=5e-4,  # Same learning rate as Phase 1
    warmup_steps=1000,  # Default warmup steps
    logging_steps=1000,  # Default logging steps
    save_steps=4000,  # Default save steps
    eval_steps=1000,  # Default evaluation steps
    save_total_limit=2,  # Keep only the last 2 checkpoints
    load_best_model_at_end=True,  # Load the best model at the end of training
    fp16=False,  # Disable FP16 by default
)

# Trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    data_collator=data_collator,
)

# Fine-tune the model
trainer.train()
# Save the fine-tuned model
trainer.save_model("./phase2-t5")
# %%
import matplotlib.pyplot as plt

# Extract training and validation loss from the log history
train_loss = []
val_loss = []
for log in trainer.state.log_history:
    if "loss" in log:
        train_loss.append(log["loss"])
    if "eval_loss" in log:
        val_loss.append(log["eval_loss"])

# Plot the training and validation loss
plt.figure(figsize=(10, 6))
plt.plot(train_loss, label="Training Loss", marker="o")
plt.plot(val_loss, label="Validation Loss", marker="o")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.grid()
# Save the plot to disk
plt.savefig("phase2-t5.png")
# Optionally, close the plot to free up memory
plt.close()
# %% [markdown]
# # Phase 3

# %%
# Load datasets (Phase 3: only rows WITH a 'Homograph Grapheme' from human sources)
train_data = load_pronuncation_dictionary('PersianG2P_final_augmented_final.csv',
                                          train=True,
                                          homograph_only=True,
                                          human=True)
train_data = train_data.map(prepare_dataset)
train_dataset = train_data
dev_data = load_pronuncation_dictionary('PersianG2P_final_augmented_final.csv',
                                        train=False,
                                        homograph_only=True,
                                        human=True)
dev_data = dev_data.map(prepare_dataset)
dev_dataset = dev_data

# Load tokenizer and model from the previous fine-tuning step
checkpoint_path = "./phase2-t5"  # Path to the model from Phase 2
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)

# Data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# Training arguments (default values)
training_args = Seq2SeqTrainingArguments(
    output_dir="./phase3-t5",  # Directory to save the final fine-tuned model
    predict_with_generate=True,
    generation_num_beams=5,
    generation_max_length=512,
    evaluation_strategy="steps",
    per_device_train_batch_size=32,  # Default batch size
    per_device_eval_batch_size=100,  # Default batch size
    num_train_epochs=50,  # More epochs for this step
    learning_rate=5e-4,  # Same learning rate as the earlier phases
    warmup_steps=1000,  # Default warmup steps
    logging_steps=1000,  # Default logging steps
    save_steps=4000,  # Default save steps
    eval_steps=1000,  # Default evaluation steps
    save_total_limit=2,  # Keep only the last 2 checkpoints
    load_best_model_at_end=True,  # Load the best model at the end of training
    fp16=False,  # Disable FP16 by default
)

# Trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    data_collator=data_collator,
)

# Fine-tune the model
trainer.train()
# Save the fine-tuned model
trainer.save_model("./phase3-t5")
# %%
import matplotlib.pyplot as plt

# Extract training and validation loss from the log history
train_loss = []
val_loss = []
for log in trainer.state.log_history:
    if "loss" in log:
        train_loss.append(log["loss"])
    if "eval_loss" in log:
        val_loss.append(log["eval_loss"])

# Plot the training and validation loss
plt.figure(figsize=(10, 6))
plt.plot(train_loss, label="Training Loss", marker="o")
plt.plot(val_loss, label="Validation Loss", marker="o")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.grid()
# Save the plot to disk
plt.savefig("phase3-t5.png")
# Optionally, close the plot to free up memory
plt.close()
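# %%
# Optional usage sketch (not part of the original training flow): load the
# Phase 3 checkpoint and generate a phoneme sequence for one input. It
# assumes ./phase3-t5 exists and that inputs are plain Persian text, as in
# the 'Grapheme' column above.
tokenizer = AutoTokenizer.from_pretrained("./phase3-t5")
model = T5ForConditionalGeneration.from_pretrained("./phase3-t5")
model.eval()
sample = "سلام"  # placeholder; replace with any Persian word or sentence
inputs = tokenizer(sample, add_special_tokens=False, return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, num_beams=5, max_length=512)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])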