
model.py

from transformers import GPT2Tokenizer, GPT2Config, GPT2LMHeadModel
import torch
import transformers
from torch import nn
from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
import os


# Loads a GPT-2 model and its tokenizer from the local cache.
# Note: cache_dir is concatenated directly with the "gpt2/" prefix (no separator is added).
def load_model(model_name, cache_dir="."):
    tokenizer = GPT2Tokenizer.from_pretrained(f"{cache_dir}gpt2/{model_name}-tokenizer")
    model = GPT2LMHeadModel.from_pretrained(f"{cache_dir}gpt2/{model_name}-model")
    add_pad_token(model, tokenizer)
    model.requires_grad_(False)  # freeze everything; prepare_model unfreezes what is needed
    return model, tokenizer


# Adds a padding token to the tokenizer and the model embedding layer
def add_pad_token(model, tokenizer):
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))
    # Initialize the new [PAD] embedding with the mean of all existing embeddings
    a = model.get_input_embeddings().weight
    a.data[-1] = a.data[:-1].mean(dim=0)


# Returns the number of trainable parameters of the model
def get_number_of_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Returns the total number of parameters of the model
def get_number_of_parameters(model):
    return sum(p.numel() for p in model.parameters())
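
# Illustrative usage sketch (not part of the original module): wiring the helpers above
# together. The "gpt2-small" model name and the cache layout are assumptions; adjust
# them to match the local cache.
def _example_load_and_count():
    model, tokenizer = load_model("gpt2-small", cache_dir="./")
    total = get_number_of_parameters(model)
    trainable = get_number_of_trainable_parameters(model)
    # load_model freezes the whole model, so trainable is expected to be 0 here
    print(f"total={total}, trainable={trainable}")
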
# Mutates model structure and adjusts trainable parameters
def prepare_model(model, cfg):
    if cfg.peft_mode == 'bitfit':
        for a, b in model.named_parameters():
            if 'bias' in a:
                b.requires_grad = True
    elif cfg.peft_mode == 'lora':
        model.requires_grad_(True)
        model = convert_gpt2_attention_to_lora(model, cfg.rank, cfg.alpha, cfg.drop_out)
        mark_only_lora_as_trainable(model)
    elif cfg.peft_mode == 'lorabitfit':
        model.requires_grad_(True)
        model = convert_gpt2_attention_to_lora(model, cfg.rank, cfg.alpha, cfg.drop_out)
        mark_only_lora_as_trainable(model)
        if cfg.two_step_training == 0:
            for a, b in model.named_parameters():
                if 'bias' in a:
                    b.requires_grad = True
    elif cfg.peft_mode == 'full':
        model.requires_grad_(True)
    elif cfg.peft_mode == 'adapter':
        model.requires_grad_(False)
        bottleneck_size = model.config.n_embd // cfg.reduction_factor
        mutate_model_adapter(model, bottleneck_size, model.config.n_embd)
        for a, b in model.named_parameters():
            if 'adapter' in a:
                b.requires_grad = True
    elif cfg.peft_mode == 'adapterbitfit':
        model.requires_grad_(False)
        bottleneck_size = model.config.n_embd // cfg.reduction_factor
        mutate_model_adapter(model, bottleneck_size, model.config.n_embd)
        if cfg.two_step_training == 0:
            for a, b in model.named_parameters():
                if 'adapter' in a or 'bias' in a:
                    b.requires_grad = True
        else:
            for a, b in model.named_parameters():
                if 'adapter' in a:
                    b.requires_grad = True
    model.to(cfg.device)
    return model
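
# Illustrative sketch (not part of the original module): prepare_model expects a config
# object exposing the fields used above (peft_mode, rank, alpha, drop_out,
# two_step_training, reduction_factor, device). The SimpleNamespace below is a
# hypothetical stand-in for the project's real config.
def _example_prepare_lora():
    from types import SimpleNamespace
    cfg = SimpleNamespace(peft_mode='lora', rank=4, alpha=32, drop_out=0.0,
                          two_step_training=0, reduction_factor=16, device='cpu')
    model, tokenizer = load_model("gpt2-small", cache_dir="./")
    model = prepare_model(model, cfg)
    # Only the injected lora_A / lora_B matrices should remain trainable
    print(get_number_of_trainable_parameters(model))
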
# Saves only the parameters trained under the given PEFT mode
def save_model(model, peft_mode, save_path, model_name):
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    if peft_mode == "bitfit":
        bias_params = {}
        for name, param in model.named_parameters():
            if 'bias' in name:
                bias_params[name] = param.data.clone()
        torch.save(bias_params, f'{save_path}/{model_name}.pth')
    elif peft_mode == 'lora':
        lora_params = {}
        for name, param in model.named_parameters():
            if 'lora' in name:
                lora_params[name] = param.data.clone()
        torch.save(lora_params, f'{save_path}/{model_name}.pth')
    elif peft_mode == 'lorabitfit':
        lorabitfit_params = {}
        for name, param in model.named_parameters():
            if 'lora' in name or 'bias' in name:
                lorabitfit_params[name] = param.data.clone()
        torch.save(lorabitfit_params, f'{save_path}/{model_name}.pth')
    elif peft_mode == 'full':
        # Full fine-tuning checkpoints are not saved by this helper
        pass
    elif peft_mode == 'adapter':
        adapter_params = {}
        for name, param in model.named_parameters():
            if 'adapter' in name:
                adapter_params[name] = param.data.clone()
        torch.save(adapter_params, f'{save_path}/{model_name}.pth')
    elif peft_mode == 'adapterbitfit':
        adapterbitfit_params = {}
        for name, param in model.named_parameters():
            if 'adapter' in name or 'bias' in name:
                adapterbitfit_params[name] = param.data.clone()
        torch.save(adapterbitfit_params, f'{save_path}/{model_name}.pth')


# Loads previously saved PEFT parameters back into a model; parameters whose
# names are not present in the checkpoint are left untouched
def load_model_weights(model, peft_mode, path):
    if peft_mode == 'full':
        pass
    else:
        model_weights = torch.load(path)
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in model_weights:
                    param.copy_(model_weights[name])
    return model
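
# Illustrative sketch (not part of the original module): saving the PEFT weights of a
# trained model and restoring them into a freshly prepared copy. The paths and names
# are placeholders; prepare_model must run first so that the LoRA/adapter parameter
# names exist in the fresh model before load_model_weights copies them over.
def _example_save_and_reload(trained_model, cfg):
    save_model(trained_model, cfg.peft_mode, save_path="./checkpoints", model_name="gpt2-small-lora")
    fresh_model, _ = load_model("gpt2-small", cache_dir="./")
    fresh_model = prepare_model(fresh_model, cfg)
    fresh_model = load_model_weights(fresh_model, cfg.peft_mode, "./checkpoints/gpt2-small-lora.pth")
    return fresh_model
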
# Copyright (c) Xuechen Li. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
LoRA layers.

This version does not merge the LoRA weights into the pretrained weights for
zero-latency inference, which keeps the code easier to read and maintain.

Adapted from
https://github.com/microsoft/LoRA
https://www.microsoft.com/en-us/research/project/dp-transformers/
"""


class MYDPMergedLinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        pretrained_module,
        lora_r=0,
        lora_alpha=1.,
        lora_dropout=0.,
    ):
        super(MYDPMergedLinear, self).__init__()
        self.pretrained_module = pretrained_module
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(p=lora_dropout)
        if self.lora_r > 0:
            self.lora_A = nn.Linear(in_features=in_features, out_features=lora_r, bias=False)
            self.lora_B = nn.Linear(in_features=lora_r, out_features=out_features, bias=False)
            self.scaling = self.lora_alpha / lora_r
        self.reset_parameters()

    def forward(self, x: torch.Tensor):
        # pretrained(x) + scaling * B(A(dropout(x)))
        result = self.pretrained_module(x)
        if self.lora_r > 0:
            after_dropout = self.lora_dropout(x)
            after_A = self.lora_A(after_dropout)
            after_B = self.lora_B(after_A)
            result += after_B * self.scaling
        return result

    def reset_parameters(self):
        # self.linear.reset_parameters()
        if self.lora_r > 0:
            self.lora_A.reset_parameters()
            # B starts at zero, so the LoRA branch is a no-op before training
            self.lora_B.weight.data.zero_()

    @staticmethod
    def from_transformers_conv1d(
        original_layer,
        lora_r=0,
        lora_alpha=1.,
        lora_dropout=0.,
    ) -> "MYDPMergedLinear":
        # transformers' Conv1D stores its weight as (in_features, out_features)
        lora_layer = MYDPMergedLinear(
            in_features=original_layer.weight.shape[0],
            out_features=original_layer.weight.shape[1],
            pretrained_module=original_layer,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        ).to(original_layer.weight.device)
        return lora_layer
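
# Illustrative sketch (not part of the original module): wrapping a single Conv1D layer,
# as GPT-2 uses for attn.c_attn. Because lora_B is zero-initialised, the wrapped layer
# should match the original exactly before any training. The Conv1D import path may
# differ across transformers versions.
def _example_wrap_single_conv1d():
    from transformers.pytorch_utils import Conv1D
    base = Conv1D(nf=3 * 768, nx=768)  # same shape as GPT-2 small's attn.c_attn
    wrapped = MYDPMergedLinear.from_transformers_conv1d(base, lora_r=4, lora_alpha=32)
    x = torch.randn(2, 10, 768)
    assert torch.allclose(wrapped(x), base(x))  # LoRA branch contributes nothing yet
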
def convert_gpt2_attention_to_lora(
    model: transformers.GPT2PreTrainedModel,
    lora_r=0,
    lora_alpha=1.,
    lora_dropout=0.,
) -> transformers.GPT2PreTrainedModel:
    if not isinstance(model, transformers.GPT2PreTrainedModel):
        raise TypeError("Requires a GPT2 model")
    if not hasattr(model, "h") and hasattr(model, "transformer"):
        transformer = model.transformer
    else:
        transformer = model
    # Wrap the attention projection (c_attn) of every transformer block with a LoRA layer
    for h_i in transformer.h:
        new_layer = MYDPMergedLinear.from_transformers_conv1d(
            original_layer=h_i.attn.c_attn,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        h_i.attn.c_attn = new_layer
    return model


def mutate_model(model: torch.nn.Module, lora_r=0, lora_alpha=1., lora_dropout=0.):
    for name, module in model.named_children():
        if name == "c_attn":
            new_layer = MYDPMergedLinear.from_transformers_conv1d(
                original_layer=module,
                lora_r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
            )
            setattr(model, name, new_layer)
        else:
            mutate_model(module, lora_r, lora_alpha, lora_dropout)  # recursively call the function on the module


def mark_only_lora_as_trainable(model: torch.nn.Module) -> None:
    model.requires_grad_(True)
    for n, p in model.named_parameters():
        if 'lora_' not in n:
            p.requires_grad = False
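
# Illustrative sketch (not part of the original module): the LoRA path that prepare_model
# takes, shown in isolation. The rank/alpha/dropout values are arbitrary examples.
def _example_lora_conversion(model):
    model = convert_gpt2_attention_to_lora(model, lora_r=4, lora_alpha=32, lora_dropout=0.1)
    mark_only_lora_as_trainable(model)
    trainable_names = [n for n, p in model.named_parameters() if p.requires_grad]
    # Every remaining trainable tensor should belong to a lora_A / lora_B projection
    assert all('lora_' in n for n in trainable_names)
    return model
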
class AdapterLayer(nn.Module):
    def __init__(
        self,
        emb_dim: int,
        bottleneck_size: int,
        bias=True
    ):
        super().__init__()
        # Bottleneck MLP: project down, apply a non-linearity, project back up
        self.sharif_llm_adapter = nn.Sequential(
            nn.Linear(emb_dim, bottleneck_size, bias=bias),
            nn.ReLU(),
            nn.Linear(bottleneck_size, emb_dim, bias=bias)
        )

    def forward(self, x: torch.Tensor):
        # Residual connection around the adapter
        output = x + self.sharif_llm_adapter(x)
        return output


class FeedForwardAdapterWrapper(nn.Module):
    def __init__(
        self,
        original_module: GPT2MLP,
        bottleneck_size: int,
        emb_dim,
        bias=True
    ):
        super().__init__()
        assert isinstance(original_module, GPT2MLP)
        self.original_module = original_module
        self.adapter = AdapterLayer(emb_dim, bottleneck_size, bias=bias)

    def forward(self, x: torch.Tensor):
        output = self.original_module(x)
        output = self.adapter(output)
        return output


def mutate_model_recursive_adapter(model: nn.Module, bottleneck_size: int, emb_dim, bias=True):
    for name, module in model.named_children():
        if isinstance(module, GPT2MLP):
            feed_forward_with_adapter = FeedForwardAdapterWrapper(module, bottleneck_size, emb_dim, bias)
            setattr(model, name, feed_forward_with_adapter)
        else:
            mutate_model_recursive_adapter(module, bottleneck_size, emb_dim, bias)  # recursively call the function on the module


def mutate_model_adapter(model: nn.Module, bottleneck_size: int, emb_dim, bias=True):
    if hasattr(model, '_mutated'):
        print("Model already contains adapter layers!\nTry reloading the model.")
        return
    mutate_model_recursive_adapter(model, bottleneck_size, emb_dim, bias)
    model._mutated = True
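
# Illustrative sketch (not part of the original module): inserting adapters into every
# GPT2MLP block and unfreezing only the adapter parameters, mirroring the 'adapter'
# branch of prepare_model. reduction_factor=16 is an example value.
def _example_add_adapters(model):
    bottleneck_size = model.config.n_embd // 16
    model.requires_grad_(False)
    mutate_model_adapter(model, bottleneck_size, model.config.n_embd)
    for name, param in model.named_parameters():
        if 'adapter' in name:
            param.requires_grad = True
    return model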