{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "55d641c5-ae0e-42af-afba-65dab055734e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"OPENAI_TOKEN = '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'"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "86ec3895-06b0-4601-a08f-756d286653b3",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from langchain.chat_models import ChatOpenAI\n",
"from langchain.schema.messages import HumanMessage, SystemMessage\n",
"\n",
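"# temperature=0 keeps the chat model's answers (near-)deterministic across repeated calls\n",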
"chat = ChatOpenAI(openai_api_key=OPENAI_TOKEN, temperature=0)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2e75b407-27a6-4651-b240-0b370424d837",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('/home/msadraei/developer/Thesis')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "79a19f7f-0c9d-44a5-8089-d89f3e8ac43a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from _datasets.glue_helper import SuperGLUEHelper, GLUEHelper"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f57eace5-57d2-4d0c-908d-20c0f5844f8e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"glue_helper = GLUEHelper()\n",
"superglue_helper = SuperGLUEHelper()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "80bc73c9-c8f5-42cb-a024-2b825c0b1bea",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"{'paragraph': 'While this process moved along, diplomacy continued its rounds. Direct pressure on the Taliban had proved unsuccessful. As one NSC staff note put it, \"Under the Taliban, Afghanistan is not so much a state sponsor of terrorism as it is a state sponsored by terrorists.\" In early 2000, the United States began a high-level effort to persuade Pakistan to use its influence over the Taliban. In January 2000, Assistant Secretary of State Karl Inderfurth and the State Department\\'s counterterrorism coordinator, Michael Sheehan, met with General Musharraf in Islamabad, dangling before him the possibility of a presidential visit in March as a reward for Pakistani cooperation. Such a visit was coveted by Musharraf, partly as a sign of his government\\'s legitimacy. He told the two envoys that he would meet with Mullah Omar and press him on Bin Laden. They left, however, reporting to Washington that Pakistan was unlikely in fact to do anything,\" given what it sees as the benefits of Taliban control of Afghanistan.\" President Clinton was scheduled to travel to India. The State Department felt that he should not visit India without also visiting Pakistan. The Secret Service and the CIA, however, warned in the strongest terms that visiting Pakistan would risk the President\\'s life. Counterterrorism officials also argued that Pakistan had not done enough to merit a presidential visit. But President Clinton insisted on including Pakistan in the itinerary for his trip to South Asia. His one-day stopover on March 25, 2000, was the first time a U.S. president had been there since 1969. At his meeting with Musharraf and others, President Clinton concentrated on tensions between Pakistan and India and the dangers of nuclear proliferation, but also discussed Bin Laden. President Clinton told us that when he pulled Musharraf aside for a brief, one-on-one meeting, he pleaded with the general for help regarding Bin Laden.\" I offered him the moon when I went to see him, in terms of better relations with the United States, if he\\'d help us get Bin Laden and deal with another issue or two.\" The U.S. effort continued. ',\n",
" 'question': 'What did the high-level effort to persuade Pakistan include?',\n",
" 'answer': 'Children, Gerd, or Dorian Popa',\n",
" 'idx': {'paragraph': 0, 'question': 0, 'answer': 0},\n",
" 'label': 0}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"superglue_helper.datasets['multirc']['train'][0]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "392f5304-00e8-41ec-aab5-0bd34e6bb3e7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import json\n",
"import numpy as np\n",
"from evaluate import load\n",
"\n",
"prompt_template = 'input = {input}\\noutput = {output}'\n",
"\n",
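"# WiC marks the target word by character spans; prepare_wic wraps that word in ** markers\n",
"# so the model can see which occurrence of the word each sentence is about.\n",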
"def prepare_wic(input_dict_row):\n",
"    word = input_dict_row['word']\n",
"    sent1 = input_dict_row['sentence1']\n",
"    sent2 = input_dict_row['sentence2']\n",
"    slice1 = slice(input_dict_row['start1'], input_dict_row['end1'])\n",
"    slice2 = slice(input_dict_row['start2'], input_dict_row['end2'])\n",
"\n",
"    annotate_word = lambda _sent, _slice: _sent[:_slice.start] + \" ** \" + _sent[_slice] + \" ** \" + _sent[_slice.stop:]\n",
"    input_dict_row['sentence1'] = annotate_word(sent1, slice1)\n",
"    input_dict_row['sentence2'] = annotate_word(sent2, slice2)\n",
"\n",
"    return {\n",
"        'sentence1': input_dict_row['sentence1'],\n",
"        'sentence2': input_dict_row['sentence2']\n",
"    }\n",
"\n",
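"# Build one few-shot prompt per validation example of a (Super)GLUE task, plus a\n",
"# compute_metric closure that maps the model's JSON answers back to label ids and scores them.\n",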
"def make_chatgpt_ready(ds_helper, task_name):\n",
"    ds = ds_helper.datasets[task_name]\n",
"    if task_name == 'wic':\n",
"        ds = {\n",
"            split: [\n",
"                {\n",
"                    **prepare_wic(row),\n",
"                    'label': row['label'],\n",
"                    'idx': 0\n",
"                } for row in ds[split]\n",
"            ]\n",
"            for split in ['train', 'validation']\n",
"        }\n",
"    if task_name not in ['wic', 'boolq', 'cb', 'copa', 'cola', 'mrpc', 'rte', 'sst2', 'multirc']:\n",
"        np.random.seed(42)\n",
"        validation_samples = np.random.choice(range(len(ds['validation'])), replace=False, size=2000).tolist()\n",
"        ds = {\n",
"            'train': ds['train'],\n",
"            'validation': [ds['validation'][idx] for idx in validation_samples]\n",
"        }\n",
"    task_out = ds_helper.get_task_output(task_name)\n",
"\n",
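"    # Gold validation labels; compute_metric maps the model's JSON outputs back to label ids\n",
"    # and scores them (MultiRC additionally needs each prediction paired with its idx record).\n",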
"    all_labels = [row['label'] for row in ds['validation']]\n",
"    if task_name == 'multirc':\n",
"        all_idx = ds['validation']['idx']\n",
"        def compute_metric(y_pred):\n",
"            glue_metric = load(ds_helper.base_name, task_name)\n",
"            y_pred = [\n",
"                task_out.str2int(json.loads(item)['label'])\n",
"                for item in y_pred\n",
"            ]\n",
"            assert len(all_idx) == len(y_pred)\n",
"            y_pred = [\n",
"                {\n",
"                    'prediction': y_pred_item,\n",
"                    'idx': idx\n",
"                } for (y_pred_item, idx) in zip(y_pred, all_idx)\n",
"            ]\n",
"            return glue_metric.compute(predictions=y_pred, references=all_labels)\n",
"    else:\n",
"        def compute_metric(y_pred):\n",
"            glue_metric = load(ds_helper.base_name, task_name)\n",
"            all_preds = [\n",
"                task_out.str2int(json.loads(item)['label'])\n",
"                for item in y_pred\n",
"            ]\n",
"            return glue_metric.compute(predictions=all_preds, references=all_labels)\n",
"\n",
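"    # Use the first training example of each label as the few-shot demonstrations.\n",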
"    few_examples = {}\n",
"    for row in ds['train']:\n",
"        if row['label'] not in few_examples:\n",
"            label = row.pop('label')\n",
"            row.pop('idx')\n",
"            few_examples[label] = row\n",
"\n",
"    class_names = json.dumps(task_out.names)\n",
"    pre_prompt_parts = [f'class_names = {class_names}']\n",
"    for label_id, example in few_examples.items():\n",
"        pre_prompt_parts.append(\n",
"            prompt_template.format(\n",
"                input = json.dumps(example),\n",
"                output = json.dumps({'label': task_out.int2str(label_id)})\n",
"            )\n",
"        )\n",
"\n",
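"    # One prompt per validation row: the class names, the demonstrations, then the unlabeled input.\n",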
"    prompt_str = []\n",
"    for row in ds['validation']:\n",
"        row.pop('label')\n",
"        row.pop('idx')\n",
"        prompt_parts = pre_prompt_parts + [\n",
"            prompt_template.format(\n",
"                input = json.dumps(row),\n",
"                output = ''\n",
"            )\n",
"        ]\n",
"        prompt_str.append('\\n'.join(prompt_parts))\n",
"\n",
"    return prompt_str, compute_metric"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "9304b06b-1c8c-4654-b074-c442f3aa3ed4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def make_chatgpt_ready_stsb(ds_helper, task_name):\n",
"    ds = ds_helper.datasets[task_name]\n",
"    task_out = ds_helper.get_task_output(task_name)\n",
"\n",
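"    # Same construction as make_chatgpt_ready, but label ids are converted to strings up front\n",
"    # and the observed label strings are used as class_names.\n",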
"    all_labels = [row['label'] for row in ds['validation']]\n",
"    def compute_metric(y_pred):\n",
"        glue_metric = load(ds_helper.base_name, task_name)\n",
"        all_preds = [\n",
"            task_out.str2int(json.loads(item)['label'])\n",
"            for item in y_pred\n",
"        ]\n",
"        return glue_metric.compute(predictions=all_preds, references=all_labels)\n",
"\n",
"    few_examples = {}\n",
"    for row in ds['train']:\n",
"        row['label'] = task_out.int2str(row['label'])\n",
"        if row['label'] not in few_examples:\n",
"            label = row.pop('label')\n",
"            row.pop('idx')\n",
"            few_examples[label] = row\n",
"\n",
"    class_names = list(sorted(few_examples.keys()))\n",
"    pre_prompt_parts = [f'class_names = {class_names}']\n",
"    for label_id, example in few_examples.items():\n",
"        pre_prompt_parts.append(\n",
"            prompt_template.format(\n",
"                input = json.dumps(example),\n",
"                output = json.dumps({'label': label_id})\n",
"            )\n",
"        )\n",
"\n",
"    prompt_str = []\n",
"    for row in ds['validation']:\n",
"        row.pop('label')\n",
"        row.pop('idx')\n",
"        prompt_parts = pre_prompt_parts + [\n",
"            prompt_template.format(\n",
"                input = json.dumps(row),\n",
"                output = ''\n",
"            )\n",
"        ]\n",
"        prompt_str.append('\\n'.join(prompt_parts))\n",
"\n",
"    return prompt_str, compute_metric"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "afe4b96f-2948-4544-9397-121a10319bf6",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"task_name = 'multirc'\n",
"prompts, compute_metric = make_chatgpt_ready(superglue_helper, task_name)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6cec4a27-bcfc-4699-9555-9d2cefcdfcaa",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from tqdm import tqdm\n",
"\n",
"all_results = []\n",
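"# Query the chat model once per prompt and collect the raw string outputs.\n",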
"for prompt in tqdm(prompts):\n",
"    messages = [\n",
"        SystemMessage(content=\"You are going to be used as a model for a natural language understanding task. Read the json input and output carefully and, according to the few-shot examples, classify the input. Your output label must be a member of 'class_names'. Your task is to decide, according to the paragraph, whether the answer to the question is True or False.\"),\n",
"        HumanMessage(content=prompt)\n",
"    ]\n",
"    all_results.append(chat.invoke(messages).content)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "57acf17a-8aa1-4f7a-90b3-dd69460d81df",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 504/504 [08:28<00:00, 1.01s/it]\n"
]
}
],
"source": [
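"# Resume from wherever the previous loop stopped; all_results already holds the completed outputs.\n",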
"for prompt in tqdm(prompts[len(all_results):]):\n",
"    messages = [\n",
"        SystemMessage(content=\"You are going to be used as a model for a natural language understanding task. Read the json input and output carefully and, according to the few-shot examples, classify the input. Your output label must be a member of 'class_names'. Your task is to decide, according to the paragraph, whether the answer to the question is True or False.\"),\n",
"        HumanMessage(content=prompt)\n",
"    ]\n",
"    all_results.append(chat.invoke(messages).content)"
]
},
{
"cell_type": "code",
"execution_count": 118,
"id": "8e2ea4da-4710-42fa-befc-0c93fd8e5df0",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# def conv_res(inp):\n",
"#     if 'label' in inp:\n",
"#         return inp\n",
"#     return json.dumps({'label': inp})\n",
"\n",
"# all_results_conv = [conv_res(x) for x in all_results]"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "15f18e92-80ca-4b7c-87e6-20d694e8cca1",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"{'exact_match': 0.3410283315844701,\n",
" 'f1_m': 0.728404774590195,\n",
" 'f1_a': 0.7791361043194783}"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result = compute_metric(all_results)\n",
"result"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "1041840c-4590-4034-8e64-cbdc215a11a8",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"0.555"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
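"# Rough MultiRC summary score: the average of the f1_a and exact_match values reported above.\n",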
"(0.77 + 0.34) / 2"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "6171134d-45ba-4bc8-991c-8fbd1cb7d370",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"with open(f'./{task_name}.json', 'w') as f:\n",
"    json.dump(result, f)"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "2fca5a91-dbba-4768-9b9f-82f56619f2fb",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'class_names = [\"False\", \"True\"]\\ninput = {\"sentence1\": \"Do you want to come over to my ** place ** later?\", \"sentence2\": \"A political system with no ** place ** for the less prominent groups.\"}\\noutput = {\"label\": \"False\"}\\ninput = {\"sentence1\": \"The general ordered the colonel to ** hold ** his position at all costs.\", \"sentence2\": \" ** Hold ** the taxi.\"}\\noutput = {\"label\": \"True\"}\\ninput = {\"sentence1\": \"An emerging professional ** class ** .\", \"sentence2\": \"Apologizing for losing your temper, even though you were badly provoked, showed real ** class ** .\"}\\noutput = '"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prompts[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "229572a2-20ac-43d6-b370-7812deef23cd",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:openai]",
"language": "python",
"name": "conda-env-openai-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}