You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

02_gpt_custom.ipynb 53KB

3 months ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "id": "19c25879-e13a-4f5e-8b5a-67d6bb77c3f6",
  6. "metadata": {
  7. "tags": []
  8. },
  9. "source": [
  10. "# Intro"
  11. ]
  12. },
  13. {
  14. "cell_type": "code",
  15. "execution_count": 1,
  16. "id": "ca485005-54c1-4126-8c1e-53ca633b7f26",
  17. "metadata": {
  18. "tags": []
  19. },
  20. "outputs": [
  21. {
  22. "name": "stdout",
  23. "output_type": "stream",
  24. "text": [
  25. "Python version is: 3.10.11\n",
  26. "Torch version is: 1.13.1+cu117\n",
  27. "Nvidia device is: NVIDIA GeForce RTX 4090\n",
  28. "Transformers version is: 4.32.1\n",
  29. "Adapterhub not found!!!\n"
  30. ]
  31. }
  32. ],
  33. "source": [
  34. "from transformers import GPT2TokenizerFast, GPT2Model, DataCollatorWithPadding\n",
  35. "from transformers.modeling_outputs import SequenceClassifierOutputWithPast\n",
  36. "import torch\n",
  37. "import torch.nn as nn\n",
  38. "from utils import print_system_info\n",
  39. "from typing import Literal, Optional, List, Dict, Callable\n",
  40. "from types import SimpleNamespace\n",
  41. "from dataclasses import dataclass\n",
  42. "\n",
  43. "print_system_info()"
  44. ]
  45. },
  46. {
  47. "cell_type": "code",
  48. "execution_count": 2,
  49. "id": "931ebd25-5e5a-4fdf-b2db-92d4ccf7f88e",
  50. "metadata": {
  51. "tags": []
  52. },
  53. "outputs": [],
  54. "source": [
  55. "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
  56. "MODEL_NAME = 'gpt2'\n",
  57. "NAMESPACE = 'sadcl'\n",
  58. "\n",
  59. "INIT_TEXT = \"sentiment or value or relation of the previous text is\"\n",
  60. "N_LAST_LAYERS = 10\n",
  61. "\n",
  62. "N_TOKENS = 5"
  63. ]
  64. },
  65. {
  66. "cell_type": "markdown",
  67. "id": "e879d47e-0b67-452a-91c2-f36383efbed8",
  68. "metadata": {},
  69. "source": [
  70. "# Class"
  71. ]
  72. },
  73. {
  74. "cell_type": "code",
  75. "execution_count": 3,
  76. "id": "7c61fcde-f9e7-4d30-b989-511010e6298b",
  77. "metadata": {
  78. "tags": []
  79. },
  80. "outputs": [],
  81. "source": [
  82. "def initialize_embedding(\n",
  83. " emb_dim: int,\n",
  84. " n_tokens: int, \n",
  85. " random_range: float,\n",
  86. " initialize_from: Optional[torch.Tensor]\n",
  87. "):\n",
  88. " if initialize_from is None:\n",
  89. " return torch.FloatTensor(n_tokens, emb_dim).uniform_(-random_range, random_range)\n",
  90. "\n",
  91. " assert initialize_from.shape[0] >= n_tokens\n",
  92. " assert initialize_from.shape[1] == emb_dim\n",
  93. " return initialize_from[:n_tokens, :].detach().clone()\n",
  94. "\n",
  95. "class SoftEmbedding(nn.Module):\n",
  96. " def __init__(\n",
  97. " self,\n",
  98. " emb_dim: int,\n",
  99. " n_tokens: int,\n",
  100. " first_layer_flag: bool = False,\n",
  101. " random_range: float = 0.1,\n",
  102. " initialize_from: Optional[torch.Tensor] = None\n",
  103. " ):\n",
  104. " super().__init__()\n",
  105. " \n",
  106. " self.emb_dim = emb_dim\n",
  107. " self.n_tokens = n_tokens\n",
  108. " self.first_layer_flag = first_layer_flag\n",
  109. " \n",
  110. " self.sadcl_learned_embedding = nn.parameter.Parameter(\n",
  111. " initialize_embedding(\n",
  112. " emb_dim,\n",
  113. " n_tokens,\n",
  114. " random_range,\n",
  115. " initialize_from\n",
  116. " )\n",
  117. " )\n",
  118. " # self.sadcl_mlp = nn.Sequential(\n",
  119. " # nn.Linear(emb_dim, 24, bias=False),\n",
  120. " # nn.ReLU(),\n",
  121. " # nn.Linear(24, 768, bias=False)\n",
  122. " # )\n",
  123. "\n",
  124. " assert self.sadcl_learned_embedding.shape == (n_tokens, emb_dim)\n",
  125. " \n",
  126. " def forward(self, input_embedding, attention_mask, sequnce_lengths):\n",
  127. " # input_embedding.shape = (batch_size, num_of_input_tokens+n_tokens, emb_dim)\n",
  128. " # output_embedding = []\n",
  129. " \n",
  130. " learned_embedding = self.sadcl_learned_embedding# + self.sadcl_mlp(self.sadcl_learned_embedding)\n",
  131. " \n",
  132. " batch_size = input_embedding.size(0)\n",
  133. " learned_embedding = learned_embedding.repeat(batch_size, 1, 1) # (batch_size, n_tokens, emb_dim)\n",
  134. " \n",
  135. " attention_mask_shift = torch.zeros((batch_size, 1, 1, self.n_tokens), device=attention_mask.device)\n",
  136. " attention_mask = torch.cat([attention_mask_shift, attention_mask[:, :, :, :-self.n_tokens]], dim=-1)\n",
  137. " if self.first_layer_flag:\n",
  138. " output_embedding = torch.cat([learned_embedding, input_embedding[:, :-self.n_tokens]], dim=1)\n",
  139. " else:\n",
  140. " output_embedding = torch.cat([learned_embedding, input_embedding[:, self.n_tokens:]], dim=1)\n",
  141. " # print(attention_mask == 0)\n",
  142. " return output_embedding, attention_mask\n",
  143. " \n",
  144. " def get_weights(self):\n",
  145. " return self.sadcl_learned_embedding.detach().clone()\n",
  146. "\n",
  147. "\n",
  148. "class GPT2ModuleWrapper(nn.Module):\n",
  149. " def __init__(\n",
  150. " self,\n",
  151. " module,\n",
  152. " emb_dim:int,\n",
  153. " n_tokens:int,\n",
  154. " get_sequnce_lengths: Callable[[], List[int]],\n",
  155. " first_layer_flag:bool,\n",
  156. " initialize_from:Optional[torch.Tensor] = None\n",
  157. " ):\n",
  158. " super().__init__()\n",
  159. " self.original_module = module\n",
  160. " self.soft_prompt = SoftEmbedding(\n",
  161. " emb_dim=emb_dim,\n",
  162. " n_tokens=n_tokens,\n",
  163. " first_layer_flag=first_layer_flag,\n",
  164. " initialize_from=initialize_from\n",
  165. " )\n",
  166. " self.get_sequnce_lengths = get_sequnce_lengths\n",
  167. " \n",
  168. " \n",
  169. " def forward(self, hidden_states, *args, **kwargs):\n",
  170. " output_embedding, attention_mask = self.soft_prompt(\n",
  171. " hidden_states,\n",
  172. " kwargs['attention_mask'],\n",
  173. " self.get_sequnce_lengths()\n",
  174. " )\n",
  175. " kwargs['attention_mask'] = attention_mask\n",
  176. " return self.original_module(output_embedding, *args, **kwargs)\n",
  177. "\n",
  178. "class GPT2Injector:\n",
  179. " def __init__(self):\n",
  180. " self.sequnce_lengths = None\n",
  181. " \n",
  182. " def get_sequnce_lengths(self):\n",
  183. " return self.sequnce_lengths\n",
  184. " \n",
  185. " def _mutate_model_forward(self, model):\n",
  186. " old_forward = model.forward\n",
  187. " pad_token_id = model.config.pad_token_id\n",
  188. " def new_forward(*args, **kwargs):\n",
  189. " input_ids = kwargs['input_ids']\n",
  190. " self.sequnce_lengths = (\n",
  191. " torch.eq(input_ids, pad_token_id).long().argmax(-1) - 1\n",
  192. " ).detach().cpu().tolist()\n",
  193. " return old_forward(*args, **kwargs)\n",
  194. " model.forward = new_forward\n",
  195. " \n",
  196. " def _reverse_mutate_model_forward(self, model):\n",
  197. " orig_class = type(model)\n",
  198. " model.forward = orig_class.forward.__get__(model, orig_class)\n",
  199. " \n",
  200. " def mutate(self, model, n_layers, n_tokens, init_prompts):\n",
  201. " self._mutate_model_forward(model)\n",
  202. " module_list = model.h\n",
  203. " start = len(module_list) - n_layers\n",
  204. " for idx in range(start, len(module_list)):\n",
  205. " module_list[idx] = GPT2ModuleWrapper(\n",
  206. " module=module_list[idx],\n",
  207. " emb_dim=model.embed_dim,\n",
  208. " n_tokens=n_tokens,\n",
  209. " get_sequnce_lengths=self.get_sequnce_lengths,\n",
  210. " first_layer_flag=(idx == start),\n",
  211. " initialize_from=init_prompts[idx][0]\n",
  212. " )\n",
  213. " return module_list[start:]\n",
  214. " \n",
  215. " def reverse_mutate(self, model):\n",
  216. " self._reverse_mutate_model_forward(model)\n",
  217. " module_list = model.h\n",
  218. " for idx in range(len(module_list)):\n",
  219. " if type(module_list[idx]) is GPT2ModuleWrapper:\n",
  220. " module_list[idx] = module_list[idx].original_module\n"
  221. ]
  222. },
  223. {
  224. "cell_type": "code",
  225. "execution_count": 4,
  226. "id": "f215af71-8f06-4466-a1cc-bf27b1193627",
  227. "metadata": {
  228. "tags": []
  229. },
  230. "outputs": [],
  231. "source": [
  232. "class MixHeadModel(nn.Module):\n",
  233. " def __init__(self, model, head):\n",
  234. " super().__init__()\n",
  235. " self.model = model\n",
  236. " self.sadcl_head = head\n",
  237. " \n",
  238. " def forward(self, *args, **kwargs):\n",
  239. " labels = kwargs.pop('labels', None)\n",
  240. " transformer_outputs = self.model(*args, **kwargs)\n",
  241. " out = self.sadcl_head(\n",
  242. " transformer_outputs=transformer_outputs,\n",
  243. " labels=labels\n",
  244. " )\n",
  245. " return out"
  246. ]
  247. },
  248. {
  249. "cell_type": "code",
  250. "execution_count": 5,
  251. "id": "cea800ea-d538-4aab-8aca-41feaba49b7d",
  252. "metadata": {
  253. "tags": []
  254. },
  255. "outputs": [],
  256. "source": [
  257. "class GPT2ClassificationHead(nn.Module):\n",
  258. " def __init__(\n",
  259. " self,\n",
  260. " emb_dim: int,\n",
  261. " n_labels: int,\n",
  262. " get_sequnce_lengths: Callable[[], List[int]],\n",
  263. " n_tokens: int,\n",
  264. " init_range: float,\n",
  265. " bias=True\n",
  266. " ):\n",
  267. " super().__init__()\n",
  268. " \n",
  269. " self.get_sequnce_lengths = get_sequnce_lengths\n",
  270. " self.n_labels = n_labels\n",
  271. " self.n_tokens = n_tokens\n",
  272. " self.loss_func = nn.CrossEntropyLoss()\n",
  273. " \n",
  274. " self.score = nn.Linear(emb_dim, n_labels, bias) # Bias is false in huggingface implementation\n",
  275. " \n",
  276. " self._init_weights(init_range)\n",
  277. " \n",
  278. " def _init_weights(self, init_range):\n",
  279. " self.score.weight.data.normal_(mean=0.0, std=init_range)\n",
  280. " if self.score.bias is not None:\n",
  281. " self.score.bias.data.zero_()\n",
  282. " \n",
  283. " def forward(self, transformer_outputs, labels=None):\n",
  284. " last_text_token_per_batch = self.get_sequnce_lengths()\n",
  285. " last_prompt_token_per_batch = [\n",
  286. " seqlen + self.n_tokens for seqlen in last_text_token_per_batch\n",
  287. " ]\n",
  288. " last_hidden_state = transformer_outputs.last_hidden_state\n",
  289. " batch_size = last_hidden_state.size(0)\n",
  290. " \n",
  291. " # last_text_token = last_hidden_state[range(batch_size), last_text_token_per_batch]\n",
  292. " last_prompt_token = last_hidden_state[range(batch_size), last_prompt_token_per_batch]\n",
  293. " logits = self.score(last_prompt_token)\n",
  294. " \n",
  295. " loss = None\n",
  296. " if labels is not None:\n",
  297. " loss = self.loss_func(logits.view(-1, self.n_labels), labels.view(-1))\n",
  298. " \n",
  299. " return SequenceClassifierOutputWithPast(\n",
  300. " loss=loss,\n",
  301. " logits=logits,\n",
  302. " past_key_values=transformer_outputs.past_key_values,\n",
  303. " hidden_states=transformer_outputs.hidden_states,\n",
  304. " attentions=transformer_outputs.attentions,\n",
  305. " )"
  306. ]
  307. },
  308. {
  309. "cell_type": "code",
  310. "execution_count": 6,
  311. "id": "577784eb-ab61-424d-a633-7b030d6d06d3",
  312. "metadata": {
  313. "tags": []
  314. },
  315. "outputs": [],
  316. "source": [
  317. "@dataclass\n",
  318. "class PEFTConfig:\n",
  319. " name: str\n",
  320. " kind: Literal['regression', 'classification', 'generation']\n",
  321. " n_labels: Optional[int] # only for classification\n",
  322. " @classmethod\n",
  323. " def classification(cls, name: str, n_labels: int):\n",
  324. " return cls(name=name, n_labels=n_labels, kind='classification')\n",
  325. "\n",
  326. "class GPT2LLL:\n",
  327. " def __init__(\n",
  328. " self,\n",
  329. " n_tokens=N_TOKENS,\n",
  330. " n_last_layers=N_LAST_LAYERS,\n",
  331. " model_name=MODEL_NAME,\n",
  332. " device=DEVICE,\n",
  333. " init_text=INIT_TEXT\n",
  334. " ):\n",
  335. " self.n_tokens = n_tokens\n",
  336. " self.n_last_layers = n_last_layers\n",
  337. " self.model_name = model_name\n",
  338. " self.device = device\n",
  339. " \n",
  340. " self.pefts = {}\n",
  341. " \n",
  342. " self.tokenizer = GPT2TokenizerFast.from_pretrained(model_name, add_prefix_space=True)\n",
  343. " self.tokenizer.pad_token = self.tokenizer.eos_token\n",
  344. " \n",
  345. " self.model = GPT2Model.from_pretrained(model_name, pad_token_id=self.tokenizer.pad_token_id)\n",
  346. " self.model.to(device);\n",
  347. " \n",
  348. " init_tokens = self.tokenizer(init_text, return_tensors='pt').to(device)\n",
  349. " with torch.no_grad():\n",
  350. " self.init_prompts = self.model(**init_tokens, output_hidden_states=True).hidden_states\n",
  351. " \n",
  352. " self.current_peft_name = None\n",
  353. " self.current_mix_model = None\n",
  354. " \n",
  355. " @property\n",
  356. " def current_peft(self):\n",
  357. " if self.current_peft_name is None:\n",
  358. " return None\n",
  359. " return self.pefts[self.current_peft_name]\n",
  360. " \n",
  361. " def generate_tokenizer_map(self):\n",
  362. " n_tokens = self.n_tokens\n",
  363. " tokenizer = self.tokenizer\n",
  364. " def return_function(rows):\n",
  365. " outputs_dict = tokenizer(rows)\n",
  366. " for row in outputs_dict['input_ids']:\n",
  367. " row.extend([tokenizer.pad_token_id] * n_tokens)\n",
  368. " for row in outputs_dict['attention_mask']:\n",
  369. " row.extend([0] * n_tokens)\n",
  370. " return outputs_dict\n",
  371. " return return_function\n",
  372. " \n",
  373. " def activate_peft(self, name):\n",
  374. " self.current_peft_name = name\n",
  375. " \n",
  376. " self.current_peft.injector.mutate(\n",
  377. " model=self.model,\n",
  378. " n_layers=self.n_last_layers,\n",
  379. " n_tokens=self.n_tokens,\n",
  380. " init_prompts=self.init_prompts\n",
  381. " )\n",
  382. " self.current_mix_model = MixHeadModel(\n",
  383. " head=self.current_peft.head,\n",
  384. " model=self.model\n",
  385. " )\n",
  386. " \n",
  387. " def auto_freeze(self):\n",
  388. " print(\"Unfreezed params are:\")\n",
  389. " for param_name, weights in self.current_mix_model.named_parameters():\n",
  390. " if NAMESPACE in param_name:\n",
  391. " weights.requires_grad = True\n",
  392. " print(\"- \" + param_name)\n",
  393. " else:\n",
  394. " weights.requires_grad = False\n",
  395. " \n",
  396. " def add_peft(self, config: PEFTConfig):\n",
  397. " assert config.name not in self.pefts\n",
  398. " injector = GPT2Injector()\n",
  399. " head = GPT2ClassificationHead(\n",
  400. " emb_dim=self.model.embed_dim,\n",
  401. " n_labels=config.n_labels,\n",
  402. " get_sequnce_lengths=injector.get_sequnce_lengths,\n",
  403. " n_tokens=self.n_tokens,\n",
  404. " init_range=self.model.config.initializer_range,\n",
  405. " bias=False\n",
  406. " )\n",
  407. " head.to(self.device)\n",
  408. " self.pefts[config.name] = SimpleNamespace(\n",
  409. " head=head,\n",
  410. " injector=injector\n",
  411. " )"
  412. ]
  413. },
  414. {
  415. "cell_type": "markdown",
  416. "id": "8fcfcb04-6513-4321-917f-d13c2dba886e",
  417. "metadata": {
  418. "tags": []
  419. },
  420. "source": [
  421. "# Train"
  422. ]
  423. },
  424. {
  425. "cell_type": "markdown",
  426. "id": "003fb992-fb75-4655-b60a-284ef0dcf4eb",
  427. "metadata": {
  428. "tags": []
  429. },
  430. "source": [
  431. "## Prepare"
  432. ]
  433. },
  434. {
  435. "cell_type": "code",
  436. "execution_count": 7,
  437. "id": "cfe42619-bb12-430e-9359-5ee2d2e40bdc",
  438. "metadata": {
  439. "tags": []
  440. },
  441. "outputs": [
  442. {
  443. "name": "stdout",
  444. "output_type": "stream",
  445. "text": [
  446. "Unfreezed params are:\n",
  447. "- model.h.2.soft_prompt.sadcl_learned_embedding\n",
  448. "- model.h.3.soft_prompt.sadcl_learned_embedding\n",
  449. "- model.h.4.soft_prompt.sadcl_learned_embedding\n",
  450. "- model.h.5.soft_prompt.sadcl_learned_embedding\n",
  451. "- model.h.6.soft_prompt.sadcl_learned_embedding\n",
  452. "- model.h.7.soft_prompt.sadcl_learned_embedding\n",
  453. "- model.h.8.soft_prompt.sadcl_learned_embedding\n",
  454. "- model.h.9.soft_prompt.sadcl_learned_embedding\n",
  455. "- model.h.10.soft_prompt.sadcl_learned_embedding\n",
  456. "- model.h.11.soft_prompt.sadcl_learned_embedding\n",
  457. "- sadcl_head.score.weight\n"
  458. ]
  459. }
  460. ],
  461. "source": [
  462. "peft_name = 'peft1'\n",
  463. "\n",
  464. "manager = GPT2LLL()\n",
  465. "manager.add_peft(PEFTConfig.classification(name=peft_name, n_labels=2))\n",
  466. "manager.activate_peft(peft_name)\n",
  467. "manager.auto_freeze()"
  468. ]
  469. },
  470. {
  471. "cell_type": "code",
  472. "execution_count": 8,
  473. "id": "072bf63b-de2f-4c05-a6a2-6fde3bb5aa6d",
  474. "metadata": {
  475. "tags": []
  476. },
  477. "outputs": [],
  478. "source": [
  479. "from config import load_config\n",
  480. "config = load_config('config.yaml')"
  481. ]
  482. },
  483. {
  484. "cell_type": "code",
  485. "execution_count": 9,
  486. "id": "3b3827aa-a61c-4e34-9e67-86768bd8b446",
  487. "metadata": {
  488. "tags": []
  489. },
  490. "outputs": [
  491. {
  492. "name": "stderr",
  493. "output_type": "stream",
  494. "text": [
  495. "Found cached dataset glue (/home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n"
  496. ]
  497. },
  498. {
  499. "data": {
  500. "application/vnd.jupyter.widget-view+json": {
  501. "model_id": "ac4726d36f6241c59be6dbeee759fce2",
  502. "version_major": 2,
  503. "version_minor": 0
  504. },
  505. "text/plain": [
  506. " 0%| | 0/3 [00:00<?, ?it/s]"
  507. ]
  508. },
  509. "metadata": {},
  510. "output_type": "display_data"
  511. },
  512. {
  513. "name": "stderr",
  514. "output_type": "stream",
  515. "text": [
  516. "Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-f7a02c6d65621ecd.arrow\n",
  517. "Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-c36341ab82d2d37d.arrow\n",
  518. "Loading cached processed dataset at /home/mohalisad/.cache/huggingface/datasets/glue/cola/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-9f7663dac81ea13b.arrow\n"
  519. ]
  520. }
  521. ],
  522. "source": [
  523. "from datasets import load_dataset\n",
  524. "dataset = load_dataset('glue', 'cola')\n",
  525. "tokenizer_map = manager.generate_tokenizer_map()\n",
  526. "dataset = dataset.map(lambda x: tokenizer_map(x['sentence']), batched=True)\n",
  527. "dataset.set_format(type='torch', columns=[\n",
  528. " 'input_ids', 'attention_mask', 'label' # 'token_type_ids',\n",
  529. "])"
  530. ]
  531. },
  532. {
  533. "cell_type": "markdown",
  534. "id": "5331fedd-e6ec-4387-a1ec-55488d144f45",
  535. "metadata": {},
  536. "source": [
  537. "## Training"
  538. ]
  539. },
  540. {
  541. "cell_type": "code",
  542. "execution_count": 15,
  543. "id": "e13c9012-089f-45c1-baea-4f0850ccfbaa",
  544. "metadata": {
  545. "tags": []
  546. },
  547. "outputs": [
  548. {
  549. "data": {
  550. "text/html": [
  551. "\n",
  552. " <div>\n",
  553. " \n",
  554. " <progress value='14609' max='42880' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
  555. " [14609/42880 04:58 < 09:37, 48.96 it/s, Epoch 54.51/160]\n",
  556. " </div>\n",
  557. " <table border=\"1\" class=\"dataframe\">\n",
  558. " <thead>\n",
  559. " <tr style=\"text-align: left;\">\n",
  560. " <th>Epoch</th>\n",
  561. " <th>Training Loss</th>\n",
  562. " <th>Validation Loss</th>\n",
  563. " <th>Accuracy</th>\n",
  564. " <th>F1-score-1</th>\n",
  565. " <th>F1-score-ma</th>\n",
  566. " </tr>\n",
  567. " </thead>\n",
  568. " <tbody>\n",
  569. " <tr>\n",
  570. " <td>1</td>\n",
  571. " <td>No log</td>\n",
  572. " <td>0.617917</td>\n",
  573. " <td>0.691275</td>\n",
  574. " <td>0.817253</td>\n",
  575. " <td>0.411713</td>\n",
  576. " </tr>\n",
  577. " <tr>\n",
  578. " <td>2</td>\n",
  579. " <td>0.618200</td>\n",
  580. " <td>0.620259</td>\n",
  581. " <td>0.691275</td>\n",
  582. " <td>0.817253</td>\n",
  583. " <td>0.411713</td>\n",
  584. " </tr>\n",
  585. " <tr>\n",
  586. " <td>3</td>\n",
  587. " <td>0.618200</td>\n",
  588. " <td>0.612236</td>\n",
  589. " <td>0.691275</td>\n",
  590. " <td>0.817253</td>\n",
  591. " <td>0.411713</td>\n",
  592. " </tr>\n",
  593. " <tr>\n",
  594. " <td>4</td>\n",
  595. " <td>0.616500</td>\n",
  596. " <td>0.613789</td>\n",
  597. " <td>0.691275</td>\n",
  598. " <td>0.817253</td>\n",
  599. " <td>0.411713</td>\n",
  600. " </tr>\n",
  601. " <tr>\n",
  602. " <td>5</td>\n",
  603. " <td>0.616500</td>\n",
  604. " <td>0.615989</td>\n",
  605. " <td>0.691275</td>\n",
  606. " <td>0.817253</td>\n",
  607. " <td>0.411713</td>\n",
  608. " </tr>\n",
  609. " <tr>\n",
  610. " <td>6</td>\n",
  611. " <td>0.612800</td>\n",
  612. " <td>0.614961</td>\n",
  613. " <td>0.691275</td>\n",
  614. " <td>0.817253</td>\n",
  615. " <td>0.411713</td>\n",
  616. " </tr>\n",
  617. " <tr>\n",
  618. " <td>7</td>\n",
  619. " <td>0.612800</td>\n",
  620. " <td>0.612622</td>\n",
  621. " <td>0.691275</td>\n",
  622. " <td>0.817253</td>\n",
  623. " <td>0.411713</td>\n",
  624. " </tr>\n",
  625. " <tr>\n",
  626. " <td>8</td>\n",
  627. " <td>0.611300</td>\n",
  628. " <td>0.613691</td>\n",
  629. " <td>0.691275</td>\n",
  630. " <td>0.817253</td>\n",
  631. " <td>0.411713</td>\n",
  632. " </tr>\n",
  633. " <tr>\n",
  634. " <td>9</td>\n",
  635. " <td>0.611300</td>\n",
  636. " <td>0.613889</td>\n",
  637. " <td>0.691275</td>\n",
  638. " <td>0.817253</td>\n",
  639. " <td>0.411713</td>\n",
  640. " </tr>\n",
  641. " <tr>\n",
  642. " <td>10</td>\n",
  643. " <td>0.609400</td>\n",
  644. " <td>0.616157</td>\n",
  645. " <td>0.691275</td>\n",
  646. " <td>0.817253</td>\n",
  647. " <td>0.411713</td>\n",
  648. " </tr>\n",
  649. " <tr>\n",
  650. " <td>11</td>\n",
  651. " <td>0.609400</td>\n",
  652. " <td>0.614404</td>\n",
  653. " <td>0.691275</td>\n",
  654. " <td>0.817253</td>\n",
  655. " <td>0.411713</td>\n",
  656. " </tr>\n",
  657. " <tr>\n",
  658. " <td>12</td>\n",
  659. " <td>0.609700</td>\n",
  660. " <td>0.614005</td>\n",
  661. " <td>0.691275</td>\n",
  662. " <td>0.817253</td>\n",
  663. " <td>0.411713</td>\n",
  664. " </tr>\n",
  665. " <tr>\n",
  666. " <td>13</td>\n",
  667. " <td>0.609700</td>\n",
  668. " <td>0.611722</td>\n",
  669. " <td>0.691275</td>\n",
  670. " <td>0.817253</td>\n",
  671. " <td>0.411713</td>\n",
  672. " </tr>\n",
  673. " <tr>\n",
  674. " <td>14</td>\n",
  675. " <td>0.607100</td>\n",
  676. " <td>0.609891</td>\n",
  677. " <td>0.692234</td>\n",
  678. " <td>0.817717</td>\n",
  679. " <td>0.415012</td>\n",
  680. " </tr>\n",
  681. " <tr>\n",
  682. " <td>15</td>\n",
  683. " <td>0.606600</td>\n",
  684. " <td>0.612338</td>\n",
  685. " <td>0.691275</td>\n",
  686. " <td>0.817253</td>\n",
  687. " <td>0.411713</td>\n",
  688. " </tr>\n",
  689. " <tr>\n",
  690. " <td>16</td>\n",
  691. " <td>0.606600</td>\n",
  692. " <td>0.614802</td>\n",
  693. " <td>0.691275</td>\n",
  694. " <td>0.817253</td>\n",
  695. " <td>0.411713</td>\n",
  696. " </tr>\n",
  697. " <tr>\n",
  698. " <td>17</td>\n",
  699. " <td>0.604600</td>\n",
  700. " <td>0.614289</td>\n",
  701. " <td>0.691275</td>\n",
  702. " <td>0.817253</td>\n",
  703. " <td>0.411713</td>\n",
  704. " </tr>\n",
  705. " <tr>\n",
  706. " <td>18</td>\n",
  707. " <td>0.604600</td>\n",
  708. " <td>0.610662</td>\n",
  709. " <td>0.692234</td>\n",
  710. " <td>0.817717</td>\n",
  711. " <td>0.415012</td>\n",
  712. " </tr>\n",
  713. " <tr>\n",
  714. " <td>19</td>\n",
  715. " <td>0.603600</td>\n",
  716. " <td>0.610867</td>\n",
  717. " <td>0.692234</td>\n",
  718. " <td>0.817717</td>\n",
  719. " <td>0.415012</td>\n",
  720. " </tr>\n",
  721. " <tr>\n",
  722. " <td>20</td>\n",
  723. " <td>0.603600</td>\n",
  724. " <td>0.615460</td>\n",
  725. " <td>0.691275</td>\n",
  726. " <td>0.817253</td>\n",
  727. " <td>0.411713</td>\n",
  728. " </tr>\n",
  729. " <tr>\n",
  730. " <td>21</td>\n",
  731. " <td>0.602600</td>\n",
  732. " <td>0.612030</td>\n",
  733. " <td>0.692234</td>\n",
  734. " <td>0.817717</td>\n",
  735. " <td>0.415012</td>\n",
  736. " </tr>\n",
  737. " <tr>\n",
  738. " <td>22</td>\n",
  739. " <td>0.602600</td>\n",
  740. " <td>0.611254</td>\n",
  741. " <td>0.692234</td>\n",
  742. " <td>0.817717</td>\n",
  743. " <td>0.415012</td>\n",
  744. " </tr>\n",
  745. " <tr>\n",
  746. " <td>23</td>\n",
  747. " <td>0.601900</td>\n",
  748. " <td>0.612736</td>\n",
  749. " <td>0.691275</td>\n",
  750. " <td>0.817253</td>\n",
  751. " <td>0.411713</td>\n",
  752. " </tr>\n",
  753. " <tr>\n",
  754. " <td>24</td>\n",
  755. " <td>0.601900</td>\n",
  756. " <td>0.613839</td>\n",
  757. " <td>0.691275</td>\n",
  758. " <td>0.817253</td>\n",
  759. " <td>0.411713</td>\n",
  760. " </tr>\n",
  761. " <tr>\n",
  762. " <td>25</td>\n",
  763. " <td>0.604800</td>\n",
  764. " <td>0.612303</td>\n",
  765. " <td>0.691275</td>\n",
  766. " <td>0.817253</td>\n",
  767. " <td>0.411713</td>\n",
  768. " </tr>\n",
  769. " <tr>\n",
  770. " <td>26</td>\n",
  771. " <td>0.604800</td>\n",
  772. " <td>0.612139</td>\n",
  773. " <td>0.691275</td>\n",
  774. " <td>0.817253</td>\n",
  775. " <td>0.411713</td>\n",
  776. " </tr>\n",
  777. " <tr>\n",
  778. " <td>27</td>\n",
  779. " <td>0.603400</td>\n",
  780. " <td>0.612106</td>\n",
  781. " <td>0.691275</td>\n",
  782. " <td>0.817253</td>\n",
  783. " <td>0.411713</td>\n",
  784. " </tr>\n",
  785. " <tr>\n",
  786. " <td>28</td>\n",
  787. " <td>0.602300</td>\n",
  788. " <td>0.614560</td>\n",
  789. " <td>0.691275</td>\n",
  790. " <td>0.817253</td>\n",
  791. " <td>0.411713</td>\n",
  792. " </tr>\n",
  793. " <tr>\n",
  794. " <td>29</td>\n",
  795. " <td>0.602300</td>\n",
  796. " <td>0.613581</td>\n",
  797. " <td>0.691275</td>\n",
  798. " <td>0.817253</td>\n",
  799. " <td>0.411713</td>\n",
  800. " </tr>\n",
  801. " <tr>\n",
  802. " <td>30</td>\n",
  803. " <td>0.602800</td>\n",
  804. " <td>0.615965</td>\n",
  805. " <td>0.691275</td>\n",
  806. " <td>0.817253</td>\n",
  807. " <td>0.411713</td>\n",
  808. " </tr>\n",
  809. " <tr>\n",
  810. " <td>31</td>\n",
  811. " <td>0.602800</td>\n",
  812. " <td>0.613715</td>\n",
  813. " <td>0.692234</td>\n",
  814. " <td>0.817717</td>\n",
  815. " <td>0.415012</td>\n",
  816. " </tr>\n",
  817. " <tr>\n",
  818. " <td>32</td>\n",
  819. " <td>0.601400</td>\n",
  820. " <td>0.613545</td>\n",
  821. " <td>0.692234</td>\n",
  822. " <td>0.817717</td>\n",
  823. " <td>0.415012</td>\n",
  824. " </tr>\n",
  825. " <tr>\n",
  826. " <td>33</td>\n",
  827. " <td>0.601400</td>\n",
  828. " <td>0.612631</td>\n",
  829. " <td>0.692234</td>\n",
  830. " <td>0.817717</td>\n",
  831. " <td>0.415012</td>\n",
  832. " </tr>\n",
  833. " <tr>\n",
  834. " <td>34</td>\n",
  835. " <td>0.601400</td>\n",
  836. " <td>0.611881</td>\n",
  837. " <td>0.692234</td>\n",
  838. " <td>0.817717</td>\n",
  839. " <td>0.415012</td>\n",
  840. " </tr>\n",
  841. " <tr>\n",
  842. " <td>35</td>\n",
  843. " <td>0.601400</td>\n",
  844. " <td>0.614503</td>\n",
  845. " <td>0.691275</td>\n",
  846. " <td>0.817253</td>\n",
  847. " <td>0.411713</td>\n",
  848. " </tr>\n",
  849. " <tr>\n",
  850. " <td>36</td>\n",
  851. " <td>0.600700</td>\n",
  852. " <td>0.610912</td>\n",
  853. " <td>0.692234</td>\n",
  854. " <td>0.817717</td>\n",
  855. " <td>0.415012</td>\n",
  856. " </tr>\n",
  857. " <tr>\n",
  858. " <td>37</td>\n",
  859. " <td>0.600700</td>\n",
  860. " <td>0.611916</td>\n",
  861. " <td>0.692234</td>\n",
  862. " <td>0.817717</td>\n",
  863. " <td>0.415012</td>\n",
  864. " </tr>\n",
  865. " <tr>\n",
  866. " <td>38</td>\n",
  867. " <td>0.600800</td>\n",
  868. " <td>0.611409</td>\n",
  869. " <td>0.692234</td>\n",
  870. " <td>0.817717</td>\n",
  871. " <td>0.415012</td>\n",
  872. " </tr>\n",
  873. " <tr>\n",
  874. " <td>39</td>\n",
  875. " <td>0.600800</td>\n",
  876. " <td>0.613652</td>\n",
  877. " <td>0.692234</td>\n",
  878. " <td>0.817717</td>\n",
  879. " <td>0.415012</td>\n",
  880. " </tr>\n",
  881. " <tr>\n",
  882. " <td>40</td>\n",
  883. " <td>0.600600</td>\n",
  884. " <td>0.612413</td>\n",
  885. " <td>0.692234</td>\n",
  886. " <td>0.817717</td>\n",
  887. " <td>0.415012</td>\n",
  888. " </tr>\n",
  889. " <tr>\n",
  890. " <td>41</td>\n",
  891. " <td>0.600600</td>\n",
  892. " <td>0.613673</td>\n",
  893. " <td>0.691275</td>\n",
  894. " <td>0.817253</td>\n",
  895. " <td>0.411713</td>\n",
  896. " </tr>\n",
  897. " <tr>\n",
  898. " <td>42</td>\n",
  899. " <td>0.600400</td>\n",
  900. " <td>0.611154</td>\n",
  901. " <td>0.692234</td>\n",
  902. " <td>0.817717</td>\n",
  903. " <td>0.415012</td>\n",
  904. " </tr>\n",
  905. " <tr>\n",
  906. " <td>43</td>\n",
  907. " <td>0.600000</td>\n",
  908. " <td>0.611216</td>\n",
  909. " <td>0.692234</td>\n",
  910. " <td>0.817717</td>\n",
  911. " <td>0.415012</td>\n",
  912. " </tr>\n",
  913. " <tr>\n",
  914. " <td>44</td>\n",
  915. " <td>0.600000</td>\n",
  916. " <td>0.610118</td>\n",
  917. " <td>0.692234</td>\n",
  918. " <td>0.817717</td>\n",
  919. " <td>0.415012</td>\n",
  920. " </tr>\n",
  921. " <tr>\n",
  922. " <td>45</td>\n",
  923. " <td>0.601900</td>\n",
  924. " <td>0.611573</td>\n",
  925. " <td>0.692234</td>\n",
  926. " <td>0.817717</td>\n",
  927. " <td>0.415012</td>\n",
  928. " </tr>\n",
  929. " <tr>\n",
  930. " <td>46</td>\n",
  931. " <td>0.601900</td>\n",
  932. " <td>0.613571</td>\n",
  933. " <td>0.692234</td>\n",
  934. " <td>0.817924</td>\n",
  935. " <td>0.412058</td>\n",
  936. " </tr>\n",
  937. " <tr>\n",
  938. " <td>47</td>\n",
  939. " <td>0.598700</td>\n",
  940. " <td>0.611853</td>\n",
  941. " <td>0.691275</td>\n",
  942. " <td>0.817253</td>\n",
  943. " <td>0.411713</td>\n",
  944. " </tr>\n",
  945. " <tr>\n",
  946. " <td>48</td>\n",
  947. " <td>0.598700</td>\n",
  948. " <td>0.611213</td>\n",
  949. " <td>0.691275</td>\n",
  950. " <td>0.817253</td>\n",
  951. " <td>0.411713</td>\n",
  952. " </tr>\n",
  953. " <tr>\n",
  954. " <td>49</td>\n",
  955. " <td>0.597600</td>\n",
  956. " <td>0.611855</td>\n",
  957. " <td>0.692234</td>\n",
  958. " <td>0.817924</td>\n",
  959. " <td>0.412058</td>\n",
  960. " </tr>\n",
  961. " <tr>\n",
  962. " <td>50</td>\n",
  963. " <td>0.597600</td>\n",
  964. " <td>0.611871</td>\n",
  965. " <td>0.692234</td>\n",
  966. " <td>0.817924</td>\n",
  967. " <td>0.412058</td>\n",
  968. " </tr>\n",
  969. " <tr>\n",
  970. " <td>51</td>\n",
  971. " <td>0.600100</td>\n",
  972. " <td>0.612086</td>\n",
  973. " <td>0.692234</td>\n",
  974. " <td>0.817924</td>\n",
  975. " <td>0.412058</td>\n",
  976. " </tr>\n",
  977. " <tr>\n",
  978. " <td>52</td>\n",
  979. " <td>0.600100</td>\n",
  980. " <td>0.610666</td>\n",
  981. " <td>0.692234</td>\n",
  982. " <td>0.817924</td>\n",
  983. " <td>0.412058</td>\n",
  984. " </tr>\n",
  985. " <tr>\n",
  986. " <td>53</td>\n",
  987. " <td>0.599600</td>\n",
  988. " <td>0.613406</td>\n",
  989. " <td>0.692234</td>\n",
  990. " <td>0.817924</td>\n",
  991. " <td>0.412058</td>\n",
  992. " </tr>\n",
  993. " <tr>\n",
  994. " <td>54</td>\n",
  995. " <td>0.599600</td>\n",
  996. " <td>0.617041</td>\n",
  997. " <td>0.692234</td>\n",
  998. " <td>0.817924</td>\n",
  999. " <td>0.412058</td>\n",
  1000. " </tr>\n",
  1001. " </tbody>\n",
  1002. "</table><p>"
  1003. ],
  1004. "text/plain": [
  1005. "<IPython.core.display.HTML object>"
  1006. ]
  1007. },
  1008. "metadata": {},
  1009. "output_type": "display_data"
  1010. },
  1011. {
  1012. "ename": "KeyboardInterrupt",
  1013. "evalue": "",
  1014. "output_type": "error",
  1015. "traceback": [
  1016. "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
  1017. "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
  1018. "Cell \u001b[0;32mIn[15], line 43\u001b[0m\n\u001b[1;32m 34\u001b[0m trainer \u001b[38;5;241m=\u001b[39m Trainer(\n\u001b[1;32m 35\u001b[0m model\u001b[38;5;241m=\u001b[39mmanager\u001b[38;5;241m.\u001b[39mcurrent_mix_model, \u001b[38;5;66;03m# manager.current_mix_model\u001b[39;00m\n\u001b[1;32m 36\u001b[0m args\u001b[38;5;241m=\u001b[39mtraining_args,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 40\u001b[0m compute_metrics\u001b[38;5;241m=\u001b[39mcompute_metrics\n\u001b[1;32m 41\u001b[0m )\n\u001b[1;32m 42\u001b[0m \u001b[38;5;66;03m# trainer.label_names = ['labels']\u001b[39;00m\n\u001b[0;32m---> 43\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mpast_key_values\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
  1019. "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/transformers/trainer.py:1555\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1553\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1554\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1555\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1556\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1557\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1558\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1559\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1560\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
  1020. "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/transformers/trainer.py:1837\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1834\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 1836\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 1837\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1839\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 1840\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 1841\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_tpu_available()\n\u001b[1;32m 1842\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 1843\u001b[0m ):\n\u001b[1;32m 1844\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 1845\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m 
\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
  1021. "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/transformers/trainer.py:2693\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 2691\u001b[0m scaled_loss\u001b[38;5;241m.\u001b[39mbackward()\n\u001b[1;32m 2692\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 2693\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maccelerator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mloss\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2695\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss\u001b[38;5;241m.\u001b[39mdetach() \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mgradient_accumulation_steps\n",
  1022. "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/accelerate/accelerator.py:1923\u001b[0m, in \u001b[0;36mAccelerator.backward\u001b[0;34m(self, loss, **kwargs)\u001b[0m\n\u001b[1;32m 1921\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscaler\u001b[38;5;241m.\u001b[39mscale(loss)\u001b[38;5;241m.\u001b[39mbackward(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 1922\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1923\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
  1023. "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/torch/_tensor.py:488\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 480\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 481\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 486\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 487\u001b[0m )\n\u001b[0;32m--> 488\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 489\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 490\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
  1024. "File \u001b[0;32m~/anaconda3/envs/deep/lib/python3.10/site-packages/torch/autograd/__init__.py:197\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 192\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 196\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 197\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 198\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
  1025. "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
  1026. ]
  1027. }
  1028. ],
  1029. "source": [
  1030. "from transformers import TrainingArguments, Trainer, DataCollatorWithPadding\n",
  1031. "from sklearn.metrics import classification_report\n",
  1032. "\n",
  1033. "\n",
  1034. "def compute_metrics(pred):\n",
  1035. " true_labels = pred.label_ids.ravel()\n",
  1036. " pred_labels = pred.predictions.argmax(-1).ravel()\n",
  1037. " report = classification_report(true_labels, pred_labels, output_dict=True)\n",
  1038. " return {\n",
  1039. " 'accuracy': report['accuracy'],\n",
  1040. " 'f1-score-1': report['1']['f1-score'],\n",
  1041. " 'f1-score-ma': report['macro avg']['f1-score']\n",
  1042. " }\n",
  1043. "\n",
  1044. "col_fn = DataCollatorWithPadding(\n",
  1045. " manager.tokenizer, return_tensors='pt', padding='longest'\n",
  1046. ")\n",
  1047. "\n",
  1048. "training_args = TrainingArguments(\n",
  1049. " evaluation_strategy=\"epoch\",\n",
  1050. " save_strategy=\"epoch\",\n",
  1051. " # The next 2 lines are important to ensure the dataset labels are properly passed to the model\n",
  1052. " remove_unused_columns=False,\n",
  1053. " label_names=['labels'],\n",
  1054. " **{\n",
  1055. " 'output_dir': '/disks/part4/trash',\n",
  1056. " 'num_train_epochs': 160,\n",
  1057. " 'learning_rate': 0.00001,\n",
  1058. " 'per_device_train_batch_size': 32,\n",
  1059. " 'per_device_eval_batch_size': 32\n",
  1060. " }\n",
  1061. ")\n",
  1062. "\n",
  1063. "trainer = Trainer(\n",
  1064. " model=manager.current_mix_model, # manager.current_mix_model\n",
  1065. " args=training_args,\n",
  1066. " train_dataset=dataset['train'],\n",
  1067. " eval_dataset=dataset['validation'],\n",
  1068. " data_collator=col_fn,\n",
  1069. " compute_metrics=compute_metrics\n",
  1070. ")\n",
  1071. "# trainer.label_names = ['labels']\n",
  1072. "trainer.train(ignore_keys_for_eval=[\"past_key_values\"])"
  1073. ]
  1074. },
  1075. {
  1076. "cell_type": "code",
  1077. "execution_count": 14,
  1078. "id": "60f9a209-1a05-4f6c-b450-773f788b93a0",
  1079. "metadata": {
  1080. "tags": []
  1081. },
  1082. "outputs": [
  1083. {
  1084. "data": {
  1085. "text/plain": [
  1086. "0.6912751677852349"
  1087. ]
  1088. },
  1089. "execution_count": 14,
  1090. "metadata": {},
  1091. "output_type": "execute_result"
  1092. }
  1093. ],
  1094. "source": [
  1095. "import numpy as np\n",
  1096. "np.mean(dataset['validation']['label'].numpy())"
  1097. ]
  1098. },
  1099. {
  1100. "cell_type": "markdown",
  1101. "id": "8806b4a8-2f34-4e42-b0cd-d6b53d2b465b",
  1102. "metadata": {},
  1103. "source": [
  1104. "# debug"
  1105. ]
  1106. },
  1107. {
  1108. "cell_type": "code",
  1109. "execution_count": null,
  1110. "id": "559e0faa-f179-4b0a-b4b3-246260fc9056",
  1111. "metadata": {
  1112. "tags": []
  1113. },
  1114. "outputs": [],
  1115. "source": [
  1116. "inputs = col_fn(dataset['validation'][0:50]).to(DEVICE)\n",
  1117. "outputs = manager.current_mix_model(**inputs)\n",
  1118. "outputs.loss.backward()"
  1119. ]
  1120. },
  1121. {
  1122. "cell_type": "code",
  1123. "execution_count": null,
  1124. "id": "43cf9e69-b9a8-4059-a8e5-cab21635388c",
  1125. "metadata": {
  1126. "tags": []
  1127. },
  1128. "outputs": [],
  1129. "source": [
  1130. "for i in range(6, 12):\n",
  1131. " o = manager.current_mix_model.model.h[i].soft_prompt.sadcl_learned_embedding.grad.abs().sum().item()\n",
  1132. " print(i, o)"
  1133. ]
  1134. },
  1135. {
  1136. "cell_type": "code",
  1137. "execution_count": null,
  1138. "id": "ff6980bd-52c0-497b-b272-0be58044ee2f",
  1139. "metadata": {
  1140. "tags": []
  1141. },
  1142. "outputs": [],
  1143. "source": [
  1144. "manager.current_mix_model.sadcl_head.score.weight.grad"
  1145. ]
  1146. },
  1147. {
  1148. "cell_type": "code",
  1149. "execution_count": null,
  1150. "id": "dd30fcbb-8385-4cf0-b2ce-26fa961e26c9",
  1151. "metadata": {
  1152. "tags": []
  1153. },
  1154. "outputs": [],
  1155. "source": [
  1156. "raise Exception()"
  1157. ]
  1158. },
  1159. {
  1160. "cell_type": "code",
  1161. "execution_count": null,
  1162. "id": "0a788c20-5f98-447e-970a-4e96a4694976",
  1163. "metadata": {
  1164. "tags": []
  1165. },
  1166. "outputs": [],
  1167. "source": [
  1168. "from transformers import GPT2ForSequenceClassification\n",
  1169. "\n",
  1170. "mtest = GPT2ForSequenceClassification.from_pretrained('gpt2', pad_token_id=manager.tokenizer.pad_token_id)\n",
  1171. "mtest.to(DEVICE)\n",
  1172. "\n",
  1173. "training_args = TrainingArguments(\n",
  1174. " evaluation_strategy=\"epoch\",\n",
  1175. " save_strategy=\"epoch\",\n",
  1176. " # The next 2 lines are important to ensure the dataset labels are properly passed to the model\n",
  1177. " remove_unused_columns=False,\n",
  1178. " label_names=['labels'],\n",
  1179. " **\n",
  1180. " {\n",
  1181. " 'output_dir': '/home/mohalisad/Developer/Thesis/cp3',\n",
  1182. " 'num_train_epochs': 80,\n",
  1183. " 'learning_rate': 0.00001,\n",
  1184. " 'per_device_train_batch_size': 32,\n",
  1185. " 'per_device_eval_batch_size': 32\n",
  1186. " }\n",
  1187. ")\n",
  1188. "\n",
  1189. "trainer = Trainer(\n",
  1190. " model=mtest, # manager.current_mix_model\n",
  1191. " args=training_args,\n",
  1192. " train_dataset=dataset['train'],\n",
  1193. " eval_dataset=dataset['validation'],\n",
  1194. " data_collator=col_fn,\n",
  1195. " compute_metrics=compute_metrics\n",
  1196. ")\n",
  1197. "# trainer.label_names = ['labels']\n",
  1198. "trainer.train()"
  1199. ]
  1200. },
  1201. {
  1202. "cell_type": "markdown",
  1203. "id": "ebb6c1f3-104d-4185-a6db-1aade4a4c9c9",
  1204. "metadata": {},
  1205. "source": [
  1206. "# Trash"
  1207. ]
  1208. },
  1209. {
  1210. "cell_type": "code",
  1211. "execution_count": null,
  1212. "id": "f135d55d-b0cf-4b2b-bc4a-209ee70ca88b",
  1213. "metadata": {
  1214. "tags": []
  1215. },
  1216. "outputs": [],
  1217. "source": [
  1218. "def map_inputs(str_list):\n",
  1219. " tokens = manager.generate_tokenizer_map()(str_list)\n",
  1220. " col_fn = DataCollatorWithPadding(manager.tokenizer)\n",
  1221. " return col_fn(tokens).to(DEVICE)\n",
  1222. " \n",
  1223. "inputs = map_inputs([\"Hello, my dog is cute\", \"bye\", \"why are\"])\n",
  1224. "label = torch.tensor([0, 1, 1], device=DEVICE)\n",
  1225. "outputs = manager.current_mix_model(label=label, **inputs)"
  1226. ]
  1227. },
  1228. {
  1229. "cell_type": "code",
  1230. "execution_count": null,
  1231. "id": "1a6d396c-59b4-43f8-9961-620ae96df172",
  1232. "metadata": {
  1233. "tags": []
  1234. },
  1235. "outputs": [],
  1236. "source": [
  1237. "token_ids = manager.tokenizer(INIT_TEXT, return_tensors='pt')['input_ids'].to(DEVICE)"
  1238. ]
  1239. },
  1240. {
  1241. "cell_type": "code",
  1242. "execution_count": null,
  1243. "id": "abb508ae-8fa0-47bf-a988-41147b739685",
  1244. "metadata": {
  1245. "tags": []
  1246. },
  1247. "outputs": [],
  1248. "source": [
  1249. "token_ids"
  1250. ]
  1251. },
  1252. {
  1253. "cell_type": "code",
  1254. "execution_count": null,
  1255. "id": "e29dee57-e30e-4248-b579-e0a77391339c",
  1256. "metadata": {
  1257. "tags": []
  1258. },
  1259. "outputs": [],
  1260. "source": [
  1261. "manager.model.wte(token_ids).shape"
  1262. ]
  1263. },
  1264. {
  1265. "cell_type": "code",
  1266. "execution_count": null,
  1267. "id": "5ef27833-e9a5-44eb-a235-7a63c1d273d9",
  1268. "metadata": {
  1269. "tags": []
  1270. },
  1271. "outputs": [],
  1272. "source": [
  1273. "outputs.loss"
  1274. ]
  1275. },
  1276. {
  1277. "cell_type": "code",
  1278. "execution_count": null,
  1279. "id": "6a6f1ffd-ca5c-4381-94f5-5d09285a4c93",
  1280. "metadata": {
  1281. "tags": []
  1282. },
  1283. "outputs": [],
  1284. "source": [
  1285. "manager.model.h[9].original_module.attn.c_attn.weight.grad"
  1286. ]
  1287. },
  1288. {
  1289. "cell_type": "code",
  1290. "execution_count": null,
  1291. "id": "09e588e0-f7bd-4b55-9201-207e6065da06",
  1292. "metadata": {
  1293. "tags": []
  1294. },
  1295. "outputs": [],
  1296. "source": [
  1297. "(torch.tensor([0, 1, 0]) == 0).any()\n"
  1298. ]
  1299. },
  1300. {
  1301. "cell_type": "code",
  1302. "execution_count": null,
  1303. "id": "eb556578-7ec2-4d08-a215-029c677e4878",
  1304. "metadata": {
  1305. "tags": []
  1306. },
  1307. "outputs": [],
  1308. "source": [
  1309. "manager.model.h[9].soft_prompt.sadcl_learned_embedding.grad"
  1310. ]
  1311. },
  1312. {
  1313. "cell_type": "code",
  1314. "execution_count": null,
  1315. "id": "8c30d534-2edb-4ab9-bde9-44e4b83de259",
  1316. "metadata": {
  1317. "tags": []
  1318. },
  1319. "outputs": [],
  1320. "source": [
  1321. "outputs.last_hidden_state.sum().backward()"
  1322. ]
  1323. },
  1324. {
  1325. "cell_type": "code",
  1326. "execution_count": null,
  1327. "id": "c13de4da-54f1-4611-93cd-c2d821112d0c",
  1328. "metadata": {},
  1329. "outputs": [],
  1330. "source": [
  1331. "\n",
  1332. "\n",
  1333. "last_hidden_states = outputs.last_hidden_state\n",
  1334. "inputs = tokenizer([\"Hello, my dog is cute\", \"bye\"])\n",
  1335. "outputs = model(**inputs)\n",
  1336. "\n",
  1337. "last_hidden_states = outputs.last_hidden_state"
  1338. ]
  1339. },
  1340. {
  1341. "cell_type": "code",
  1342. "execution_count": null,
  1343. "id": "057a79bc-efe1-4b73-b990-d4d6d445ff3e",
  1344. "metadata": {
  1345. "tags": []
  1346. },
  1347. "outputs": [],
  1348. "source": [
  1349. "inputs"
  1350. ]
  1351. },
  1352. {
  1353. "cell_type": "code",
  1354. "execution_count": null,
  1355. "id": "10235ef7-ed2c-47dd-ab5e-096072bc6cd0",
  1356. "metadata": {
  1357. "tags": []
  1358. },
  1359. "outputs": [],
  1360. "source": [
  1361. "\n",
  1362. "inputs = tokenize_dataset([\"Hello, my dog is cute\", \"bye\"])\n",
  1363. "inputs"
  1364. ]
  1365. },
  1366. {
  1367. "cell_type": "code",
  1368. "execution_count": null,
  1369. "id": "a8ca1b81-f0a6-40f5-853d-a154585b61b3",
  1370. "metadata": {
  1371. "tags": []
  1372. },
  1373. "outputs": [],
  1374. "source": [
  1375. "tokenizer.eos_token"
  1376. ]
  1377. },
  1378. {
  1379. "cell_type": "code",
  1380. "execution_count": null,
  1381. "id": "7d88682f-9e3c-4e80-84d0-448f7ed93bc4",
  1382. "metadata": {},
  1383. "outputs": [],
  1384. "source": [
  1385. "x = nn.Parameter(torch.arange(27).reshape(3, 3, 3).float())\n",
  1386. "x"
  1387. ]
  1388. },
  1389. {
  1390. "cell_type": "code",
  1391. "execution_count": null,
  1392. "id": "3115c7f4-94a3-4526-8373-ec524fe46c1c",
  1393. "metadata": {
  1394. "tags": []
  1395. },
  1396. "outputs": [],
  1397. "source": [
  1398. "b = nn.Parameter(torch.tensor([7, 7, 7]).float())\n",
  1399. "b"
  1400. ]
  1401. },
  1402. {
  1403. "cell_type": "code",
  1404. "execution_count": null,
  1405. "id": "30ec47b9-1636-491b-85c6-24415933fb3d",
  1406. "metadata": {
  1407. "tags": []
  1408. },
  1409. "outputs": [],
  1410. "source": [
  1411. "x[0, 0, :] = torch.tensor([7, 7, 7])"
  1412. ]
  1413. },
  1414. {
  1415. "cell_type": "code",
  1416. "execution_count": null,
  1417. "id": "b0a87c7c-9740-412e-9ad3-a7f3f06e90bc",
  1418. "metadata": {},
  1419. "outputs": [],
  1420. "source": []
  1421. }
  1422. ],
  1423. "metadata": {
  1424. "kernelspec": {
  1425. "display_name": "Python [conda env:deep]",
  1426. "language": "python",
  1427. "name": "conda-env-deep-py"
  1428. },
  1429. "language_info": {
  1430. "codemirror_mode": {
  1431. "name": "ipython",
  1432. "version": 3
  1433. },
  1434. "file_extension": ".py",
  1435. "mimetype": "text/x-python",
  1436. "name": "python",
  1437. "nbconvert_exporter": "python",
  1438. "pygments_lexer": "ipython3",
  1439. "version": "3.10.11"
  1440. }
  1441. },
  1442. "nbformat": 4,
  1443. "nbformat_minor": 5
  1444. }