A Persian grapheme-to-phoneme (G2P) model designed for homograph disambiguation, fine-tuned using the HomoRich dataset to improve pronunciation accuracy.
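Once fine-tuned with the script below, the model can be used for inference roughly as in the following sketch. Assumptions: the checkpoint directory (e.g. `./phase3-30-ep`, as saved by the script) and that the input is a raw Persian sentence matching the dataset's `Grapheme` column; the output string follows the `Mapped Phoneme` format used for training.

from transformers import AutoTokenizer, T5ForConditionalGeneration

ckpt = "./phase3-30-ep"  # assumed: one of the checkpoint directories saved by finetune-ge2pe.py
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = T5ForConditionalGeneration.from_pretrained(ckpt)

sentence = "..."  # a Persian input sentence (a 'Grapheme' string)
inputs = tokenizer(sentence, add_special_tokens=False, return_tensors="pt")  # mirrors the training collator
output_ids = model.generate(**inputs, num_beams=5, max_length=512)           # mirrors the script's generation settings
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])       # predicted phoneme sequence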

finetune-ge2pe.py 12KB

# %%
import os
from dataclasses import dataclass
from typing import Dict, List, Union

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from transformers import (
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    T5ForConditionalGeneration,
    set_seed,
)

os.environ["WANDB_DISABLED"] = "true"
# %%
set_seed(41)

# %%
def prepare_dataset(batch):
    batch['input_ids'] = batch['Grapheme']
    batch['labels'] = batch['Mapped Phoneme']
    return batch
# %%
# Data collator for padding
@dataclass
class DataCollatorWithPadding:
    tokenizer: AutoTokenizer
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        words = [feature["input_ids"] for feature in features]
        prons = [feature["labels"] for feature in features]
        batch = self.tokenizer(words, padding=self.padding, add_special_tokens=False, return_attention_mask=True, return_tensors='pt')
        pron_batch = self.tokenizer(prons, padding=self.padding, add_special_tokens=True, return_attention_mask=True, return_tensors='pt')
        batch['labels'] = pron_batch['input_ids'].masked_fill(pron_batch.attention_mask.ne(1), -100)
        return batch
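# Note on label masking: padded positions in the phoneme batch (attention_mask == 0)
# are replaced with -100, the ignore index of the cross-entropy loss in
# T5ForConditionalGeneration, so padding tokens never contribute to the training loss.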
# %%
# Compute metrics (CER and WER)
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    wer = wer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer, 'wer': wer}

# Set up the evaluation metrics
cer_metric = evaluate.load("cer")
wer_metric = evaluate.load('wer')
# %% [markdown]
# # Phase 1

# %%
def load_pronuncation_dictionary(path, train=True, homograph_only=False, human=False) -> Dataset:
    # path = '/media/external_10TB/mahta_fetrat/PersianG2P_final.csv'
    # Read the CSV file
    df = pd.read_csv(path, index_col=[0])
    if homograph_only:
        if human:
            df = df[df['Source'] == 'human']
        else:
            df = df[df['Source'] != 'human']
    # Drop unnecessary columns
    df = df.drop(['Source', 'Source ID'], axis=1)
    # Drop rows where 'Mapped Phoneme' is NaN
    df = df.dropna(subset=['Mapped Phoneme'])
    # Filter out rows with overly long phoneme sequences
    Plen = np.array([len(i) for i in df['Mapped Phoneme']])
    df = df.iloc[Plen < 512, :]
    # Filter rows based on the 'Homograph Grapheme' column
    if homograph_only:
        df = df[df['Homograph Grapheme'].notna() & (df['Homograph Grapheme'] != '')]
    else:
        df = df[df['Homograph Grapheme'].isna() | (df['Homograph Grapheme'] == '')]
    # Shuffle the DataFrame
    df = df.sample(frac=1)
    # Split into train and dev sets (the last 90 rows are held out)
    if train:
        return Dataset.from_pandas(df.iloc[:len(df)-90, :])
    else:
        return Dataset.from_pandas(df.iloc[len(df)-90:, :])
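# The three fine-tuning phases below reuse this loader with different filters:
#   Phase 1: rows without a 'Homograph Grapheme'.
#   Phase 2: homograph rows whose 'Source' is not 'human'.
#   Phase 3: homograph rows whose 'Source' is 'human', read from the augmented CSV.
# In each case the last 90 shuffled rows are held out as the dev set.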
# %%
# Load the Phase 1 datasets (rows without a 'Homograph Grapheme')
train_data = load_pronuncation_dictionary('PersianG2P_final.csv', train=True)
train_data = train_data.map(prepare_dataset)
train_dataset = train_data

dev_data = load_pronuncation_dictionary('PersianG2P_final.csv', train=False)
dev_data = dev_data.map(prepare_dataset)
dev_dataset = dev_data

# Load tokenizer and model from the base checkpoint
checkpoint_path = "checkpoint-320"  # Path to the checkpoint to fine-tune
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)

# Data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./phase1-30-ep",      # Directory to save the fine-tuned model
    predict_with_generate=True,
    generation_num_beams=5,
    generation_max_length=512,
    evaluation_strategy="steps",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=100,
    num_train_epochs=5,               # Fewer epochs for this phase
    learning_rate=5e-4,
    warmup_steps=1000,
    logging_steps=1000,
    save_steps=4000,
    eval_steps=1000,
    save_total_limit=2,               # Keep only the last 2 checkpoints
    load_best_model_at_end=True,      # Load the best model at the end of training
    fp16=False,                       # Disable FP16
)
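# Note: predict_with_generate=True makes evaluation use beam-search generation
# (generation_num_beams / generation_max_length above), so compute_metrics scores
# decoded phoneme strings rather than teacher-forced logits.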
# Trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    data_collator=data_collator,
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model
trainer.save_model("./phase1-30-ep")
# %%
import matplotlib.pyplot as plt

# Extract training and validation loss from the log history
train_loss = []
val_loss = []
for log in trainer.state.log_history:
    if "loss" in log:
        train_loss.append(log["loss"])
    if "eval_loss" in log:
        val_loss.append(log["eval_loss"])

# Plot the training and validation loss
plt.figure(figsize=(10, 6))
plt.plot(train_loss, label="Training Loss", marker="o")
plt.plot(val_loss, label="Validation Loss", marker="o")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.grid()

# Save the plot to disk
plt.savefig("phase1-30-ep.png")
# Optionally, close the plot to free up memory
plt.close()
# %% [markdown]
# # Phase 2

# %%
# Load the Phase 2 datasets (homograph rows from non-human sources)
train_data = load_pronuncation_dictionary('PersianG2P_final.csv',
                                          train=True,
                                          homograph_only=True)
train_data = train_data.map(prepare_dataset)
train_dataset = train_data

dev_data = load_pronuncation_dictionary('PersianG2P_final.csv',
                                        train=False,
                                        homograph_only=True)
dev_data = dev_data.map(prepare_dataset)
dev_dataset = dev_data

# Load tokenizer and model from the previous fine-tuning phase
checkpoint_path = "./phase1-30-ep"  # Path to the model from Phase 1
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)

# Data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./phase2-30-ep",      # Directory to save the Phase 2 model
    predict_with_generate=True,
    generation_num_beams=5,
    generation_max_length=512,
    evaluation_strategy="steps",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=100,
    num_train_epochs=30,              # More epochs for this phase
    learning_rate=5e-4,
    warmup_steps=1000,
    logging_steps=1000,
    save_steps=4000,
    eval_steps=1000,
    save_total_limit=2,               # Keep only the last 2 checkpoints
    load_best_model_at_end=True,      # Load the best model at the end of training
    fp16=False,                       # Disable FP16
)

# Trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    data_collator=data_collator,
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model
trainer.save_model("./phase2-30-ep")
# %%
import matplotlib.pyplot as plt

# Extract training and validation loss from the log history
train_loss = []
val_loss = []
for log in trainer.state.log_history:
    if "loss" in log:
        train_loss.append(log["loss"])
    if "eval_loss" in log:
        val_loss.append(log["eval_loss"])

# Plot the training and validation loss
plt.figure(figsize=(10, 6))
plt.plot(train_loss, label="Training Loss", marker="o")
plt.plot(val_loss, label="Validation Loss", marker="o")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.grid()

# Save the plot to disk
plt.savefig("phase2-30-ep.png")
# Optionally, close the plot to free up memory
plt.close()
# %% [markdown]
# # Phase 3

# %%
# Load the Phase 3 datasets (human-labeled homograph rows from the augmented CSV)
train_data = load_pronuncation_dictionary('PersianG2P_final_augmented_final.csv',
                                          train=True,
                                          homograph_only=True,
                                          human=True)
train_data = train_data.map(prepare_dataset)
train_dataset = train_data

dev_data = load_pronuncation_dictionary('PersianG2P_final_augmented_final.csv',
                                        train=False,
                                        homograph_only=True,
                                        human=True)
dev_data = dev_data.map(prepare_dataset)
dev_dataset = dev_data

# Load tokenizer and model from the previous fine-tuning phase
checkpoint_path = "./phase2-30-ep"  # Path to the model from Phase 2
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
model = T5ForConditionalGeneration.from_pretrained(checkpoint_path)

# Data collator
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./phase3-30-ep",      # Directory to save the final fine-tuned model
    predict_with_generate=True,
    generation_num_beams=5,
    generation_max_length=512,
    evaluation_strategy="steps",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=100,
    num_train_epochs=50,              # Most epochs in this phase
    learning_rate=5e-4,
    warmup_steps=1000,
    logging_steps=1000,
    save_steps=4000,
    eval_steps=1000,
    save_total_limit=2,               # Keep only the last 2 checkpoints
    load_best_model_at_end=True,      # Load the best model at the end of training
    fp16=False,                       # Disable FP16
)

# Trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,
    data_collator=data_collator,
)

# Fine-tune the model
trainer.train()

# Save the fine-tuned model
trainer.save_model("./phase3-30-ep")
# %%
import matplotlib.pyplot as plt

# Extract training and validation loss from the log history
train_loss = []
val_loss = []
for log in trainer.state.log_history:
    if "loss" in log:
        train_loss.append(log["loss"])
    if "eval_loss" in log:
        val_loss.append(log["eval_loss"])

# Plot the training and validation loss
plt.figure(figsize=(10, 6))
plt.plot(train_loss, label="Training Loss", marker="o")
plt.plot(val_loss, label="Validation Loss", marker="o")
plt.xlabel("Steps")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.grid()

# Save the plot to disk
plt.savefig("phase3-30-ep.png")
# Optionally, close the plot to free up memory
plt.close()