You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

Untitled.ipynb 23KB

3 months ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "id": "ab3de9ac-742b-4e14-a59e-4aaac49973de",
  7. "metadata": {
  8. "tags": []
  9. },
  10. "outputs": [],
  11. "source": [
  12. "from transformers import T5Model\n",
  13. "import numpy as np\n",
  14. "from sklearn.decomposition import PCA\n",
  15. "from tqdm.notebook import tqdm"
  16. ]
  17. },
  18. {
  19. "cell_type": "code",
  20. "execution_count": 2,
  21. "id": "6801369d-4153-4674-a558-f4237d52fddc",
  22. "metadata": {
  23. "tags": []
  24. },
  25. "outputs": [],
  26. "source": [
  27. "model = T5Model.from_pretrained('google/t5-large-lm-adapt')"
  28. ]
  29. },
  30. {
  31. "cell_type": "code",
  32. "execution_count": 3,
  33. "id": "bfe07ee9-aed8-47cd-8950-62e2c9f0026a",
  34. "metadata": {
  35. "tags": []
  36. },
  37. "outputs": [],
  38. "source": [
  39. "weitghts = model.get_encoder().get_input_embeddings().weight.detach().clone().numpy()"
  40. ]
  41. },
  42. {
  43. "cell_type": "code",
  44. "execution_count": 4,
  45. "id": "f686d3fe-a427-45d1-83f5-5bbab7251800",
  46. "metadata": {
  47. "tags": []
  48. },
  49. "outputs": [
  50. {
  51. "data": {
  52. "application/vnd.jupyter.widget-view+json": {
  53. "model_id": "1344e2e9d9744feeaee9cbf5ac1c0054",
  54. "version_major": 2,
  55. "version_minor": 0
  56. },
  57. "text/plain": [
  58. " 0%| | 0/63 [00:00<?, ?it/s]"
  59. ]
  60. },
  61. "metadata": {},
  62. "output_type": "display_data"
  63. }
  64. ],
  65. "source": [
  66. "def calc_loss(n_components, X):\n",
  67. " pca = PCA(n_components=n_components)\n",
  68. " pca.fit(X)\n",
  69. " reduced = pca.transform(X)\n",
  70. " reconstruct = np.dot(reduced, pca.components_) + pca.mean_\n",
  71. " loss = ((reconstruct - X) ** 2).mean()\n",
  72. " return loss\n",
  73. "\n",
  74. "x = []\n",
  75. "y = []\n",
  76. "for n_components in tqdm(range(16, 1024, 16)):\n",
  77. " x.append(n_components)\n",
  78. " y.append(calc_loss(n_components, weitghts))"
  79. ]
  80. },
  81. {
  82. "cell_type": "code",
  83. "execution_count": 6,
  84. "id": "d1df28d1-b26c-4064-90c5-c2bcb080484a",
  85. "metadata": {
  86. "tags": []
  87. },
  88. "outputs": [
  89. {
  90. "data": {
  91. "text/plain": [
  92. "[<matplotlib.lines.Line2D at 0x7f1ee440ae60>]"
  93. ]
  94. },
  95. "execution_count": 6,
  96. "metadata": {},
  97. "output_type": "execute_result"
  98. },
  99. {
  100. "data": {
  101. "image/png": "",
  102. "text/plain": [
  103. "<Figure size 640x480 with 1 Axes>"
  104. ]
  105. },
  106. "metadata": {},
  107. "output_type": "display_data"
  108. }
  109. ],
  110. "source": [
  111. "import matplotlib.pyplot as plt\n",
  112. "\n",
  113. "plt.plot(x, np.sqrt(y))"
  114. ]
  115. },
  116. {
  117. "cell_type": "code",
  118. "execution_count": null,
  119. "id": "712f8c9b-ad73-4811-941d-aeee49257e2d",
  120. "metadata": {},
  121. "outputs": [],
  122. "source": []
  123. }
  124. ],
  125. "metadata": {
  126. "kernelspec": {
  127. "display_name": "Python [conda env:deep]",
  128. "language": "python",
  129. "name": "conda-env-deep-py"
  130. },
  131. "language_info": {
  132. "codemirror_mode": {
  133. "name": "ipython",
  134. "version": 3
  135. },
  136. "file_extension": ".py",
  137. "mimetype": "text/x-python",
  138. "name": "python",
  139. "nbconvert_exporter": "python",
  140. "pygments_lexer": "ipython3",
  141. "version": "3.10.11"
  142. }
  143. },
  144. "nbformat": 4,
  145. "nbformat_minor": 5
  146. }