You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

attempt.ipynb 5.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "id": "e6ecf439-a0db-42e0-a6b9-f512198b0e0e",
  7. "metadata": {
  8. "tags": []
  9. },
  10. "outputs": [],
  11. "source": [
  12. "import torch"
  13. ]
  14. },
  15. {
  16. "cell_type": "code",
  17. "execution_count": 4,
  18. "id": "4bcc7c7e-711a-4cd9-b901-d6ff76938a75",
  19. "metadata": {
  20. "tags": []
  21. },
  22. "outputs": [],
  23. "source": [
  24. "best_path = '/home/msadraei/trained_final/iclr_resp_t5_small_glue-cola/10_attempt/best.pt'\n",
  25. "first_path = '/home/msadraei/trained_final/iclr_resp_t5_small_glue-cola/10_attempt/first.pt'"
  26. ]
  27. },
  28. {
  29. "cell_type": "code",
  30. "execution_count": 5,
  31. "id": "eaa4a300-1e6c-46f0-8f0d-16e9c71c2388",
  32. "metadata": {
  33. "tags": []
  34. },
  35. "outputs": [],
  36. "source": [
  37. "best = torch.load(best_path)\n",
  38. "first = torch.load(first_path)"
  39. ]
  40. },
  41. {
  42. "cell_type": "code",
  43. "execution_count": 8,
  44. "id": "c5e0b6bb-3bde-4526-8a6a-5dac0a3b3cc3",
  45. "metadata": {
  46. "tags": []
  47. },
  48. "outputs": [
  49. {
  50. "name": "stdout",
  51. "output_type": "stream",
  52. "text": [
  53. "sadcl_p_target\n",
  54. "tensor(42.7208, device='cuda:0')\n",
  55. "pretrained_tasks\n",
  56. "tensor(0., device='cuda:0')\n",
  57. "sadcl_attention_score.g_network.0.weight\n",
  58. "tensor(157.3032, device='cuda:0')\n",
  59. "sadcl_attention_score.g_network.2.weight\n",
  60. "tensor(154.6590, device='cuda:0')\n",
  61. "sadcl_attention_score.g_network.3.weight\n",
  62. "tensor(18.1127, device='cuda:0')\n",
  63. "sadcl_attention_score.g_network.3.bias\n",
  64. "tensor(19.0149, device='cuda:0')\n"
  65. ]
  66. }
  67. ],
  68. "source": [
  69. "for key in best.keys():\n",
  70. " print(key)\n",
  71. " v1 = first[key]\n",
  72. " v2 = best[key]\n",
  73. " print(torch.norm(v1 - v2))"
  74. ]
  75. },
  76. {
  77. "cell_type": "code",
  78. "execution_count": 13,
  79. "id": "42815cf2-b8bf-4219-a3fd-ebbe92fb5c32",
  80. "metadata": {},
  81. "outputs": [],
  82. "source": [
  83. "base_path = '/home/msadraei/trained_final/forward_transfer_test_t5_base_superglue-rte/10_combine_128_4tasks_new_impl_tie_50/100'\n",
  84. "last_path = f'{base_path}/last.pt'\n",
  85. "best_path = f'{base_path}/best.pt'\n",
  86. "first_path = f'{base_path}/first.pt'"
  87. ]
  88. },
  89. {
  90. "cell_type": "code",
  91. "execution_count": 14,
  92. "id": "880cb651-ddea-4564-93ab-c5f52e1f02dd",
  93. "metadata": {
  94. "tags": []
  95. },
  96. "outputs": [],
  97. "source": [
  98. "import torch\n",
  99. "last = torch.load(last_path)\n",
  100. "best = torch.load(best_path)\n",
  101. "first = torch.load(first_path)"
  102. ]
  103. },
  104. {
  105. "cell_type": "code",
  106. "execution_count": 15,
  107. "id": "ee4b3287-203f-49b0-8b89-6070f9ff4062",
  108. "metadata": {
  109. "tags": []
  110. },
  111. "outputs": [],
  112. "source": [
  113. "import numpy as np\n",
  114. "def pretrained_coeff(state_dict):\n",
  115. " return np.stack([\n",
  116. " val.cpu().numpy()\n",
  117. " for key, val in state_dict.items()\n",
  118. " if 'sadcl_coeff_pretrained' in key\n",
  119. " ])"
  120. ]
  121. },
  122. {
  123. "cell_type": "code",
  124. "execution_count": 16,
  125. "id": "26518ecd-8cc1-4543-acaf-56637295bbe8",
  126. "metadata": {
  127. "tags": []
  128. },
  129. "outputs": [],
  130. "source": [
  131. "last_coeff = pretrained_coeff(last)\n",
  132. "best_coeff = pretrained_coeff(best)\n",
  133. "first_coeff = pretrained_coeff(first)"
  134. ]
  135. },
  136. {
  137. "cell_type": "code",
  138. "execution_count": 17,
  139. "id": "5a850a65-724a-483d-abb3-b7de6118db31",
  140. "metadata": {
  141. "tags": []
  142. },
  143. "outputs": [
  144. {
  145. "data": {
  146. "text/plain": [
  147. "array([[0.43, 0.42, 0.42, 0.42],\n",
  148. " [0.43, 0.42, 0.42, 0.42],\n",
  149. " [0.43, 0.42, 0.42, 0.42],\n",
  150. " [0.43, 0.42, 0.42, 0.42],\n",
  151. " [0.43, 0.42, 0.42, 0.42],\n",
  152. " [0.43, 0.42, 0.42, 0.42],\n",
  153. " [0.43, 0.42, 0.42, 0.42],\n",
  154. " [0.43, 0.42, 0.42, 0.42],\n",
  155. " [0.43, 0.42, 0.42, 0.42],\n",
  156. " [0.43, 0.42, 0.42, 0.42]], dtype=float32)"
  157. ]
  158. },
  159. "execution_count": 17,
  160. "metadata": {},
  161. "output_type": "execute_result"
  162. }
  163. ],
  164. "source": [
  165. "np.round(last_coeff/ 100 , 2)\n"
  166. ]
  167. },
  168. {
  169. "cell_type": "code",
  170. "execution_count": 65,
  171. "id": "7182b595-5bb3-4c06-88dc-1f50ed774500",
  172. "metadata": {},
  173. "outputs": [
  174. {
  175. "data": {
  176. "text/plain": [
  177. "tensor(34.9105)"
  178. ]
  179. },
  180. "execution_count": 65,
  181. "metadata": {},
  182. "output_type": "execute_result"
  183. }
  184. ],
  185. "source": [
  186. "torch.linalg.vector_norm(torch.Tensor(best_coeff[0]), ord=1)"
  187. ]
  188. },
  189. {
  190. "cell_type": "code",
  191. "execution_count": null,
  192. "id": "9e2a2080-9450-4df2-b20e-4619e3f92c1b",
  193. "metadata": {},
  194. "outputs": [],
  195. "source": []
  196. }
  197. ],
  198. "metadata": {
  199. "kernelspec": {
  200. "display_name": "Python [conda env:deep]",
  201. "language": "python",
  202. "name": "conda-env-deep-py"
  203. },
  204. "language_info": {
  205. "codemirror_mode": {
  206. "name": "ipython",
  207. "version": 3
  208. },
  209. "file_extension": ".py",
  210. "mimetype": "text/x-python",
  211. "name": "python",
  212. "nbconvert_exporter": "python",
  213. "pygments_lexer": "ipython3",
  214. "version": "3.10.13"
  215. }
  216. },
  217. "nbformat": 4,
  218. "nbformat_minor": 5
  219. }