A Persian grapheme-to-phoneme (G2P) model designed for homograph disambiguation, fine-tuned using the HomoRich dataset to improve pronunciation accuracy.

GE2PE.py (4.7 KB)

from transformers import AutoTokenizer, T5ForConditionalGeneration
from parsivar import Normalizer

class GE2PE:
    def __init__(self, model_path='./content/checkpoint-320', GPU=False, dictionary=None):
        """
        model_path: path to where the GE2PE transformer is saved.
        GPU: boolean indicating use of the GPU in generation.
        dictionary: a dictionary of user-defined word pronunciations.
        """
        self.GPU = GPU
        self.model = T5ForConditionalGeneration.from_pretrained(model_path)
        if self.GPU:
            self.model = self.model.cuda()
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.dictionary = dictionary
        self.norma = Normalizer(pinglish_conversion_needed=True)

    def is_vowel(self, char):
        # '/' is one of the vowel symbols in the model's phoneme alphabet.
        return char in ['a', '/', 'i', 'e', 'u', 'o']

    def rules(self, grapheme, phoneme):
        # Override the predicted phonemes with the short vowels (fatha, kasra, zamma)
        # explicitly written as diacritics in the grapheme string.
        grapheme = grapheme.replace('آ', 'ءا')  # expand 'آ' into two characters
        words = grapheme.split(' ')
        prons = phoneme.replace('1', '').split(' ')
        if len(words) != len(prons):
            # The rules need a one-to-one word/pronunciation alignment.
            return phoneme
        for i in range(len(words)):
            if 'ِ' not in words[i] and 'ُ' not in words[i] and 'َ' not in words[i]:
                continue
            for j in range(len(words[i])):
                if words[i][j] == 'َ':  # fatha -> '/'
                    if j == len(words[i]) - 1 and prons[i][-1] != '/':
                        prons[i] = prons[i] + '/'
                    elif self.is_vowel(prons[i][j]):
                        prons[i] = prons[i][:j] + '/' + prons[i][j+1:]
                    else:
                        prons[i] = prons[i][:j] + '/' + prons[i][j:]
                if words[i][j] == 'ِ':  # kasra -> 'e'
                    if j == len(words[i]) - 1 and prons[i][-1] != 'e':
                        prons[i] = prons[i] + 'e'
                    elif self.is_vowel(prons[i][j]):
                        prons[i] = prons[i][:j] + 'e' + prons[i][j+1:]
                    else:
                        prons[i] = prons[i][:j] + 'e' + prons[i][j:]
                if words[i][j] == 'ُ':  # zamma -> 'o'
                    if j == len(words[i]) - 1 and prons[i][-1] != 'o':
                        prons[i] = prons[i] + 'o'
                    elif self.is_vowel(prons[i][j]):
                        prons[i] = prons[i][:j] + 'o' + prons[i][j+1:]
                    else:
                        prons[i] = prons[i][:j] + 'o' + prons[i][j:]
        return ' '.join(prons)

    def lexicon(self, grapheme, phoneme):
        # Replace the predicted pronunciation of any word found in the user-defined
        # dictionary, preserving the model's trailing '1' marker when present.
        words = grapheme.split(' ')
        prons = phoneme.split(' ')
        output = prons
        for i in range(len(words)):
            try:
                output[i] = self.dictionary[words[i]]
                if prons[i][-1] == '1' and output[i][-1] != 'e':
                    output[i] = output[i] + 'e1'
                elif prons[i][-1] == '1' and output[i][-1] == 'e':
                    output[i] = output[i] + 'ye1'
            except (KeyError, IndexError, TypeError):
                # Word not in the dictionary (or no dictionary given): keep the model output.
                pass
        return ' '.join(output)

    def generate(self, input_list, batch_size=10, use_rules=False, use_dict=False):
        """
        input_list: list of sentences to be phonemized.
        batch_size: inference batch size.
        use_rules: boolean indicating the use of rules to apply written short vowels.
        use_dict: boolean indicating the use of the self-defined dictionary.
        Returns the list of phonemized sentences.
        """
        output_list = []
        # Normalize the text and unify the Arabic kaf with the Persian kaf.
        input_list = [self.norma.normalize(text).replace('ك', 'ک') for text in input_list]
        # Keep a copy that still carries the short-vowel diacritics for the rule-based
        # post-processing, and strip them from the model input.
        diacritized = input_list
        input_list = [text.replace('ِ', '').replace('ُ', '').replace('َ', '') for text in input_list]
        for i in range(0, len(input_list), batch_size):
            in_ids = self.tokenizer(input_list[i:i + batch_size], padding=True,
                                    add_special_tokens=False,
                                    return_attention_mask=True, return_tensors='pt')
            if self.GPU:
                out_ids = self.model.generate(in_ids['input_ids'].cuda(),
                                              attention_mask=in_ids['attention_mask'].cuda(),
                                              num_beams=5, min_length=1, max_length=512,
                                              early_stopping=True)
            else:
                out_ids = self.model.generate(in_ids['input_ids'],
                                              attention_mask=in_ids['attention_mask'],
                                              num_beams=5, min_length=1, max_length=512,
                                              early_stopping=True)
            output_list += self.tokenizer.batch_decode(out_ids, skip_special_tokens=True)
        if use_dict:
            for i in range(len(input_list)):
                output_list[i] = self.lexicon(input_list[i], output_list[i])
        if use_rules:
            for i in range(len(input_list)):
                output_list[i] = self.rules(diacritized[i], output_list[i])
        return [text.strip() for text in output_list]
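
A minimal usage sketch: the checkpoint path is the class default, and the input sentence is an arbitrary placeholder.

    from GE2PE import GE2PE

    # Load the fine-tuned checkpoint; set GPU=True to run generation on CUDA.
    g2p = GE2PE(model_path='./content/checkpoint-320', GPU=False)

    # Phonemize a placeholder sentence ("This is a book.").
    sentences = ['این یک کتاب است']
    print(g2p.generate(sentences, batch_size=10, use_rules=False, use_dict=False))

Passing use_dict=True requires constructing GE2PE with a dictionary that maps words to pronunciations in the model's phoneme alphabet, and use_rules=True overrides the predicted short vowels with any fatha, kasra, or zamma diacritics written in the input text.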