You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

Untitled.ipynb 3.6KB

3 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "id": "93e252d5-c7d2-48bd-9d21-70bb5694a026",
  7. "metadata": {
  8. "tags": []
  9. },
  10. "outputs": [],
  11. "source": [
  12. "from _mydelta.multi_prompt import MultiPrompt"
  13. ]
  14. },
  15. {
  16. "cell_type": "code",
  17. "execution_count": 2,
  18. "id": "c9cd7bc9-cd12-4e77-9176-d71c614a6094",
  19. "metadata": {
  20. "tags": []
  21. },
  22. "outputs": [],
  23. "source": [
  24. "from pathlib import Path\n",
  25. "path = Path('/disks/ssd/trained_final/cont_thesis/cont_thesis_t5_small_glue-cola/10_combine_128_simple')\n",
  26. "best_out = MultiPrompt.get_saved_final_emb(\n",
  27. " config_path=path / 'config.json',\n",
  28. " weights_path=path / 'best.pt'\n",
  29. ")"
  30. ]
  31. },
  32. {
  33. "cell_type": "code",
  34. "execution_count": 3,
  35. "id": "853f0084-5b12-40e0-a6ea-da6cd96bcd88",
  36. "metadata": {
  37. "tags": []
  38. },
  39. "outputs": [
  40. {
  41. "data": {
  42. "text/plain": [
  43. "torch.Size([10, 512])"
  44. ]
  45. },
  46. "execution_count": 3,
  47. "metadata": {},
  48. "output_type": "execute_result"
  49. }
  50. ],
  51. "source": [
  52. "best_out.shape"
  53. ]
  54. },
  55. {
  56. "cell_type": "code",
  57. "execution_count": 4,
  58. "id": "0807f193-4cb5-4d84-9210-3581e2e49c51",
  59. "metadata": {
  60. "tags": []
  61. },
  62. "outputs": [],
  63. "source": [
  64. "import torch\n",
  65. "\n",
  66. "sd = torch.load(path / 'best.pt')"
  67. ]
  68. },
  69. {
  70. "cell_type": "code",
  71. "execution_count": 7,
  72. "id": "73685dcd-d842-4265-b1db-760124840212",
  73. "metadata": {
  74. "tags": []
  75. },
  76. "outputs": [
  77. {
  78. "data": {
  79. "text/plain": [
  80. "tensor([0.3015], device='cuda:0')"
  81. ]
  82. },
  83. "execution_count": 7,
  84. "metadata": {},
  85. "output_type": "execute_result"
  86. }
  87. ],
  88. "source": [
  89. "sd['prompts.2.sadcl_coeff_pretrained']"
  90. ]
  91. },
  92. {
  93. "cell_type": "code",
  94. "execution_count": 16,
  95. "id": "dffe272c-97d5-41de-ac31-fd2702163670",
  96. "metadata": {},
  97. "outputs": [],
  98. "source": [
  99. "from accelerate import Accelerator\n",
  100. "import accelerate.utils.other as auo\n",
  101. "import accelerate.logging as al"
  102. ]
  103. },
  104. {
  105. "cell_type": "code",
  106. "execution_count": 25,
  107. "id": "8d184d14-a9b7-41ae-b5f8-cf977b7009fd",
  108. "metadata": {
  109. "tags": []
  110. },
  111. "outputs": [],
  112. "source": [
  113. "# Accelerator()\n",
  114. "\n",
  115. "al"
  116. ]
  117. },
  118. {
  119. "cell_type": "code",
  120. "execution_count": 28,
  121. "id": "972a0e50-43aa-44eb-8c10-3e86fba0819d",
  122. "metadata": {
  123. "tags": []
  124. },
  125. "outputs": [
  126. {
  127. "data": {
  128. "text/plain": [
  129. "50"
  130. ]
  131. },
  132. "execution_count": 28,
  133. "metadata": {},
  134. "output_type": "execute_result"
  135. }
  136. ],
  137. "source": [
  138. "auo.logger.getEffectiveLevel()"
  139. ]
  140. },
  141. {
  142. "cell_type": "code",
  143. "execution_count": 18,
  144. "id": "7a247b50-57a0-43cd-9a8d-18d58ea1fd27",
  145. "metadata": {
  146. "tags": []
  147. },
  148. "outputs": [
  149. {
  150. "name": "stdout",
  151. "output_type": "stream",
  152. "text": [
  153. "__main__\n"
  154. ]
  155. }
  156. ],
  157. "source": [
  158. "print(__name__)"
  159. ]
  160. },
  161. {
  162. "cell_type": "code",
  163. "execution_count": null,
  164. "id": "6abe432e-bb4b-4610-899d-e7759512181c",
  165. "metadata": {},
  166. "outputs": [],
  167. "source": []
  168. }
  169. ],
  170. "metadata": {
  171. "kernelspec": {
  172. "display_name": "Python [conda env:deep]",
  173. "language": "python",
  174. "name": "conda-env-deep-py"
  175. },
  176. "language_info": {
  177. "codemirror_mode": {
  178. "name": "ipython",
  179. "version": 3
  180. },
  181. "file_extension": ".py",
  182. "mimetype": "text/x-python",
  183. "name": "python",
  184. "nbconvert_exporter": "python",
  185. "pygments_lexer": "ipython3",
  186. "version": "3.10.13"
  187. }
  188. },
  189. "nbformat": 4,
  190. "nbformat_minor": 5
  191. }