You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

CNN-RNN.ipynb 353KB

1 year ago

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "id": "902d927c",
  7. "metadata": {},
  8. "outputs": [
  9. {
  10. "name": "stderr",
  11. "output_type": "stream",
  12. "text": [
  13. "C:\\Users\\saeed\\Desktop\\Master\\bci\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
  14. " from .autonotebook import tqdm as notebook_tqdm\n"
  15. ]
  16. }
  17. ],
  18. "source": [
  19. "import torch\n",
  20. "import torch.nn as nn\n",
  21. "import torch.nn.functional as F\n",
  22. "from sklearn.model_selection import train_test_split\n",
  23. "from sklearn.model_selection import KFold, StratifiedKFold\n",
  24. "from sklearn.manifold import TSNE\n",
  25. "import librosa\n",
  26. "import librosa.display\n",
  27. "import IPython.display as ipd\n",
  28. "import matplotlib.pyplot as plt\n",
  29. "import numpy as np\n",
  30. "import scipy.io\n",
  31. "from tqdm import tqdm\n",
  32. "import glob\n",
  33. "import os\n",
  34. "import json\n",
  35. "import pickle\n",
  36. "from einops import rearrange"
  37. ]
  38. },
  39. {
  40. "cell_type": "code",
  41. "execution_count": null,
  42. "id": "eb0bbf34",
  43. "metadata": {},
  44. "outputs": [],
  45. "source": [
  46. "'''all:\n",
  47. "with open(\"data/all/X_train.pkl\", \"rb\") as f:\n",
  48. " X__train = pickle.load(f)\n",
  49. "with open(\"data/all/X_test.pkl\", \"rb\") as f:\n",
  50. " X__test = pickle.load(f)\n",
  51. "with open(\"data/all/y_train.pkl\", \"rb\") as f:\n",
  52. " y__train = pickle.load(f)\n",
  53. "with open(\"data/all/y_test.pkl\", \"rb\") as f:\n",
  54. " y__test = pickle.load(f)\n",
  55. " \n",
  56. " vowel:\n",
  57. "with open(\"data/vowel/X_vowel_train.pkl\", \"rb\") as f:\n",
  58. " X__train = pickle.load(f)\n",
  59. "with open(\"data/vowel/X_vowel_test.pkl\", \"rb\") as f:\n",
  60. " X__test = pickle.load(f)\n",
  61. "with open(\"data/vowel/y_vowel_train.pkl\", \"rb\") as f:\n",
  62. " y__train = pickle.load(f)\n",
  63. "with open(\"data/vowel/y_vowel_test.pkl\", \"rb\") as f:\n",
  64. " y__test = pickle.load(f)\n",
  65. " \n",
  66. " nasal:\n",
  67. "with open(\"data/nasal/X_nasal_train.pkl\", \"rb\") as f:\n",
  68. " X__train = pickle.load(f)\n",
  69. "with open(\"data/nasal/X_nasal_test.pkl\", \"rb\") as f:\n",
  70. " X__test = pickle.load(f)\n",
  71. "with open(\"data/nasal/y_nasal_train.pkl\", \"rb\") as f:\n",
  72. " y__train = pickle.load(f)\n",
  73. "with open(\"data/nasal/y_nasal_test.pkl\", \"rb\") as f:\n",
  74. " y__test = pickle.load(f)\n",
  75. " \n",
  76. " bilabial:\n",
  77. "with open(\"data/bilabial/X_bilabial_train.pkl\", \"rb\") as f:\n",
  78. " X__train = pickle.load(f)\n",
  79. "with open(\"data/bilabial/X_bilabial_test.pkl\", \"rb\") as f:\n",
  80. " X__test = pickle.load(f)\n",
  81. "with open(\"data/bilabial/y_bilabial_train.pkl\", \"rb\") as f:\n",
  82. " y__train = pickle.load(f)\n",
  83. "with open(\"data/bilabial/y_bilabial_test.pkl\", \"rb\") as f:\n",
  84. " y__test = pickle.load(f)\n",
  85. " \n",
  86. " iy:\n",
  87. "with open(\"data/iy/X_iy_train.pkl\", \"rb\") as f:\n",
  88. " X__train = pickle.load(f)\n",
  89. "with open(\"data/iy/X_iy_test.pkl\", \"rb\") as f:\n",
  90. " X__test = pickle.load(f)\n",
  91. "with open(\"data/iy/y_iy_train.pkl\", \"rb\") as f:\n",
  92. " y__train = pickle.load(f)\n",
  93. "with open(\"data/iy/y_iy_test.pkl\", \"rb\") as f:\n",
  94. " y__test = pickle.load(f)\n",
  95. " \n",
  96. " uw:\n",
  97. "with open(\"data/uw/X_uw_train.pkl\", \"rb\") as f:\n",
  98. " X__train = pickle.load(f)\n",
  99. "with open(\"data/uw/X_uw_test.pkl\", \"rb\") as f:\n",
  100. " X__test = pickle.load(f)\n",
  101. "with open(\"data/uw/y_uw_train.pkl\", \"rb\") as f:\n",
  102. " y__train = pickle.load(f)\n",
  103. "with open(\"data/uw/y_uw_test.pkl\", \"rb\") as f:\n",
  104. " y__test = pickle.load(f)"
  105. ]
  106. },
  107. {
  108. "cell_type": "code",
  109. "execution_count": 6,
  110. "id": "5eed7b79",
  111. "metadata": {},
  112. "outputs": [
  113. {
  114. "name": "stdout",
  115. "output_type": "stream",
  116. "text": [
  117. "(1721, 62, 1250) (192, 62, 1250) 1721 192\n"
  118. ]
  119. }
  120. ],
  121. "source": [
  122. "#load train and test data and labels\n",
  123. "with open(\"data/uw/X_uw_train.pkl\", \"rb\") as f:\n",
  124. " X__train = pickle.load(f)\n",
  125. "with open(\"data/uw/X_uw_test.pkl\", \"rb\") as f:\n",
  126. " X__test = pickle.load(f)\n",
  127. "with open(\"data/uw/y_uw_train.pkl\", \"rb\") as f:\n",
  128. " y__train = pickle.load(f)\n",
  129. "with open(\"data/uw/y_uw_test.pkl\", \"rb\") as f:\n",
  130. " y__test = pickle.load(f)\n",
  131. "print(X__train.shape, X__test.shape, len(y__train), len(y__test))"
  132. ]
  133. },
  134. {
  135. "cell_type": "code",
  136. "execution_count": 7,
  137. "id": "4901d867",
  138. "metadata": {},
  139. "outputs": [
  140. {
  141. "name": "stderr",
  142. "output_type": "stream",
  143. "text": [
  144. "100%|██████████████████████████████████████████████████████████████████████████████| 1721/1721 [05:15<00:00, 5.45it/s]\n",
  145. "100%|████████████████████████████████████████████████████████████████████████████████| 192/192 [00:46<00:00, 4.17it/s]"
  146. ]
  147. },
  148. {
  149. "name": "stdout",
  150. "output_type": "stream",
  151. "text": [
  152. "(1721, 62, 20, 11) (192, 62, 20, 11)\n"
  153. ]
  154. },
  155. {
  156. "name": "stderr",
  157. "output_type": "stream",
  158. "text": [
  159. "\n"
  160. ]
  161. }
  162. ],
  163. "source": [
  164. "#set parameters for MFCC extraction\n",
  165. "n_mfcc = 20\n",
  166. "framesize = 1 * 250\n",
  167. "hop_size = int(framesize/2)\n",
  168. "\n",
  169. "#calculate MFCCs and put them in a matrix with shape (#trial, #channels) for each sub. put all new subs in MFCC_data \n",
  170. "trials = []\n",
  171. "for i, trial in enumerate(tqdm(X__train)):\n",
  172. " channels = []\n",
  173. " for j, channel in enumerate(trial):\n",
  174. " mfccs = librosa.feature.mfcc(y=channel, n_mfcc=n_mfcc, n_fft=framesize, hop_length=hop_size, sr=250)\n",
  175. " channels.append(np.array(mfccs))\n",
  176. " trials.append(np.array(channels)) \n",
  177. "mfc_train = np.array(trials)\n",
  178. "trials = []\n",
  179. "for i, trial in enumerate(tqdm(X__test)):\n",
  180. " channels = []\n",
  181. " for j, channel in enumerate(trial):\n",
  182. " mfccs = librosa.feature.mfcc(y=channel, n_mfcc=n_mfcc, n_fft=framesize, hop_length=hop_size, sr=250)\n",
  183. " channels.append(np.array(mfccs))\n",
  184. " trials.append(np.array(channels)) \n",
  185. "mfc_test = np.array(trials)\n",
  186. "print(mfc_train.shape, mfc_test.shape)"
  187. ]
  188. },
  189. {
  190. "cell_type": "code",
  191. "execution_count": 8,
  192. "id": "dec5d37e",
  193. "metadata": {},
  194. "outputs": [
  195. {
  196. "name": "stdout",
  197. "output_type": "stream",
  198. "text": [
  199. "(1721, 62, 10, 11) (192, 62, 10, 11)\n"
  200. ]
  201. }
  202. ],
  203. "source": [
  204. "mfc_train1 = mfc_train[:,:,0:10,:]\n",
  205. "mfc_test1 = mfc_test[:,:,0:10,:]\n",
  206. "print(mfc_train1.shape, mfc_test1.shape)"
  207. ]
  208. },
  209. {
  210. "cell_type": "code",
  211. "execution_count": 9,
  212. "id": "93b33cdb",
  213. "metadata": {},
  214. "outputs": [],
  215. "source": [
  216. "with open('uw_20mfc_train.pkl', 'wb') as f:\n",
  217. " pickle.dump(mfc_train1, f)\n",
  218. "with open('uw_20mfc_test.pkl', 'wb') as f:\n",
  219. " pickle.dump(mfc_test1, f)"
  220. ]
  221. },
  222. {
  223. "cell_type": "code",
  224. "execution_count": 188,
  225. "id": "3fb31100",
  226. "metadata": {},
  227. "outputs": [],
  228. "source": [
  229. "with open('bilab_20mfc_train.pkl', 'rb') as f:\n",
  230. " mfc_train = pickle.load(f)\n",
  231. "with open('bilab_20mfc_test.pkl', 'rb') as f:\n",
  232. " mfc_test = pickle.load(f)\n",
  233. "with open(\"data/bilabial/y_bilabial_train.pkl\", \"rb\") as f:\n",
  234. " y__train = pickle.load(f)\n",
  235. "with open(\"data/bilabial/y_bilabial_test.pkl\", \"rb\") as f:\n",
  236. " y__test = pickle.load(f)"
  237. ]
  238. },
  239. {
  240. "cell_type": "code",
  241. "execution_count": 2,
  242. "id": "7cbb9002",
  243. "metadata": {},
  244. "outputs": [
  245. {
  246. "name": "stdout",
  247. "output_type": "stream",
  248. "text": [
  249. "(1913, 62, 20, 11)\n"
  250. ]
  251. }
  252. ],
  253. "source": [
  254. "with open(\"data/normal_all_data.pkl\", \"rb\") as f:\n",
  255. " all_data = pickle.load(f)\n",
  256. "with open(\"data/11_20mfc.pkl\", \"rb\") as f:\n",
  257. " mfc_data = pickle.load(f)\n",
  258. "with open(\"data/all_label.pkl\", \"rb\") as f:\n",
  259. " labels = pickle.load(f)\n",
  260. "with open(\"data/vowel_label.pkl\", \"rb\") as f:\n",
  261. " vowel_label = pickle.load(f)\n",
  262. "with open(\"data/bilab_label.pkl\", \"rb\") as f:\n",
  263. " bilab_label = pickle.load(f)\n",
  264. "with open(\"data/nasal_label.pkl\", \"rb\") as f:\n",
  265. " nasal_label = pickle.load(f)\n",
  266. "with open(\"data/iy_label.pkl\", \"rb\") as f:\n",
  267. " iy_label = pickle.load(f)\n",
  268. "with open(\"data/uw_label.pkl\", \"rb\") as f:\n",
  269. " uw_label = pickle.load(f)\n",
  270. "\n",
  271. "print(mfc_data.shape)"
  272. ]
  273. },
  274. {
  275. "cell_type": "code",
  276. "execution_count": 3,
  277. "id": "dc4650d9",
  278. "metadata": {},
  279. "outputs": [
  280. {
  281. "data": {
  282. "text/plain": [
  283. "0.36382645060115004"
  284. ]
  285. },
  286. "execution_count": 3,
  287. "metadata": {},
  288. "output_type": "execute_result"
  289. }
  290. ],
  291. "source": [
  292. "sum(nasal_label)/len(nasal_label)"
  293. ]
  294. },
  295. {
  296. "cell_type": "code",
  297. "execution_count": 189,
  298. "id": "56e00da9",
  299. "metadata": {},
  300. "outputs": [
  301. {
  302. "data": {
  303. "text/plain": [
  304. "(1721, 62, 10, 11)"
  305. ]
  306. },
  307. "execution_count": 189,
  308. "metadata": {},
  309. "output_type": "execute_result"
  310. }
  311. ],
  312. "source": [
  313. "mfc_train = mfc_train[:,:,2:12,:]\n",
  314. "mfc_test = mfc_test[:,:,2:12,:]\n",
  315. "mfc_train.shape"
  316. ]
  317. },
  318. {
  319. "cell_type": "code",
  320. "execution_count": 197,
  321. "id": "198c529c",
  322. "metadata": {},
  323. "outputs": [
  324. {
  325. "name": "stderr",
  326. "output_type": "stream",
  327. "text": [
  328. "C:\\Users\\saeed\\Desktop\\Master\\bci\\lib\\site-packages\\sklearn\\manifold\\_t_sne.py:795: FutureWarning: The default initialization in TSNE will change from 'random' to 'pca' in 1.2.\n",
  329. " warnings.warn(\n",
  330. "C:\\Users\\saeed\\Desktop\\Master\\bci\\lib\\site-packages\\sklearn\\manifold\\_t_sne.py:805: FutureWarning: The default learning rate in TSNE will change from 200.0 to 'auto' in 1.2.\n",
  331. " warnings.warn(\n"
  332. ]
  333. }
  334. ],
  335. "source": [
  336. "data = mfc_train[:,:,:10,:]\n",
  337. "data = rearrange(data, 'n c m t -> n (c m t)')\n",
  338. "tsne = TSNE(n_components=2, random_state=0)\n",
  339. "X = tsne.fit_transform(data)"
  340. ]
  341. },
  342. {
  343. "cell_type": "code",
  344. "execution_count": 198,
  345. "id": "f33647be",
  346. "metadata": {},
  347. "outputs": [
  348. {
  349. "data": {
  350. "image/png": "\n",
  351. "text/plain": [
  352. "<Figure size 576x504 with 1 Axes>"
  353. ]
  354. },
  355. "metadata": {
  356. "needs_background": "light"
  357. },
  358. "output_type": "display_data"
  359. }
  360. ],
  361. "source": [
  362. "l1, l2, l3, l4, l5, l6, l7, l8, l9, l10, l11 = [], [], [], [], [], [], [], [], [], [], []\n",
  363. "for i, label in enumerate(y__train):\n",
  364. " if label==0:\n",
  365. " l1.append(X[i])\n",
  366. " elif label==1:\n",
  367. " l2.append(X[i])\n",
  368. " elif label==2:\n",
  369. " l3.append(X[i])\n",
  370. " elif label==3:\n",
  371. " l4.append(X[i])\n",
  372. " elif label==4:\n",
  373. " l5.append(X[i])\n",
  374. " elif label==5:\n",
  375. " l6.append(X[i])\n",
  376. " elif label==6:\n",
  377. " l7.append(X[i])\n",
  378. " elif label==7:\n",
  379. " l8.append(X[i])\n",
  380. " elif label==8:\n",
  381. " l9.append(X[i])\n",
  382. " elif label==9:\n",
  383. " l10.append(X[i])\n",
  384. " elif label==10:\n",
  385. " l11.append(X[i])\n",
  386. "plt.figure(figsize=(8, 7))\n",
  387. "for x in l1:\n",
  388. " plt.scatter(x[0], x[1], c='r')\n",
  389. "for x in l2:\n",
  390. " plt.scatter(x[0], x[1], c='g')\n",
  391. "for x in l3:\n",
  392. " plt.scatter(x[0], x[1], c='b')\n",
  393. "for x in l4:\n",
  394. " plt.scatter(x[0], x[1], c='c')\n",
  395. "for x in l5:\n",
  396. " plt.scatter(x[0], x[1], c='y')\n",
  397. "for x in l6:\n",
  398. " plt.scatter(x[0], x[1], c='m')\n",
  399. "for x in l7:\n",
  400. " plt.scatter(x[0], x[1], c='k')\n",
  401. "for x in l8:\n",
  402. " plt.scatter(x[0], x[1], c='w')\n",
  403. "for x in l9:\n",
  404. " plt.scatter(x[0], x[1], c='orange')\n",
  405. "for x in l10:\n",
  406. " plt.scatter(x[0], x[1], c='purple')\n",
  407. "for x in l11:\n",
  408. " plt.scatter(x[0], x[1], c='bisque')\n",
  409. "plt.show()\n"
  410. ]
  411. },
  412. {
  413. "cell_type": "code",
  414. "execution_count": 20,
  415. "id": "c8a54e81",
  416. "metadata": {},
  417. "outputs": [
  418. {
  419. "data": {
  420. "text/plain": [
  421. "(array([4, 3, 0, 1, 0, 3, 0, 0, 0, 1], dtype=int64),\n",
  422. " array([1. , 1.8, 2.6, 3.4, 4.2, 5. , 5.8, 6.6, 7.4, 8.2, 9. ]))"
  423. ]
  424. },
  425. "execution_count": 20,
  426. "metadata": {},
  427. "output_type": "execute_result"
  428. }
  429. ],
  430. "source": [
  431. "a = np.array([[1,1,1,1],[2,2,2,4],[5,5,5,9]])\n",
  432. "np.histogram(a)"
  433. ]
  434. },
  435. {
  436. "cell_type": "code",
  437. "execution_count": 190,
  438. "id": "f440d5eb",
  439. "metadata": {
  440. "scrolled": true
  441. },
  442. "outputs": [
  443. {
  444. "data": {
  445. "image/png": "\n",
  446. "text/plain": [
  447. "<Figure size 720x288 with 1 Axes>"
  448. ]
  449. },
  450. "metadata": {
  451. "needs_background": "light"
  452. },
  453. "output_type": "display_data"
  454. },
  455. {
  456. "data": {
  457. "image/png": "\n",
  458. "text/plain": [
  459. "<Figure size 720x288 with 1 Axes>"
  460. ]
  461. },
  462. "metadata": {
  463. "needs_background": "light"
  464. },
  465. "output_type": "display_data"
  466. },
  467. {
  468. "data": {
  469. "image/png": "\n",
  470. "text/plain": [
  471. "<Figure size 720x288 with 1 Axes>"
  472. ]
  473. },
  474. "metadata": {
  475. "needs_background": "light"
  476. },
  477. "output_type": "display_data"
  478. },
  479. {
  480. "data": {
  481. "image/png": "\n",
  482. "text/plain": [
  483. "<Figure size 720x288 with 1 Axes>"
  484. ]
  485. },
  486. "metadata": {
  487. "needs_background": "light"
  488. },
  489. "output_type": "display_data"
  490. },
  491. {
  492. "data": {
  493. "image/png": "\n",
  494. "text/plain": [
  495. "<Figure size 720x288 with 1 Axes>"
  496. ]
  497. },
  498. "metadata": {
  499. "needs_background": "light"
  500. },
  501. "output_type": "display_data"
  502. },
  503. {
  504. "data": {
  505. "image/png": "\n",
  506. "text/plain": [
  507. "<Figure size 720x288 with 1 Axes>"
  508. ]
  509. },
  510. "metadata": {
  511. "needs_background": "light"
  512. },
  513. "output_type": "display_data"
  514. },
  515. {
  516. "data": {
  517. "image/png": "\n",
  518. "text/plain": [
  519. "<Figure size 720x288 with 1 Axes>"
  520. ]
  521. },
  522. "metadata": {
  523. "needs_background": "light"
  524. },
  525. "output_type": "display_data"
  526. },
  527. {
  528. "data": {
  529. "image/png": "\n",
  530. "text/plain": [
  531. "<Figure size 720x288 with 1 Axes>"
  532. ]
  533. },
  534. "metadata": {
  535. "needs_background": "light"
  536. },
  537. "output_type": "display_data"
  538. },
  539. {
  540. "data": {
  541. "image/png": "\n",
  542. "text/plain": [
  543. "<Figure size 720x288 with 1 Axes>"
  544. ]
  545. },
  546. "metadata": {
  547. "needs_background": "light"
  548. },
  549. "output_type": "display_data"
  550. },
  551. {
  552. "data": {
  553. "image/png": "\n",
  554. "text/plain": [
  555. "<Figure size 720x288 with 1 Axes>"
  556. ]
  557. },
  558. "metadata": {
  559. "needs_background": "light"
  560. },
  561. "output_type": "display_data"
  562. }
  563. ],
  564. "source": [
  565. "data = rearrange(mfc_train, 'n c m t -> m (n c t)')\n",
  566. "for m in data:\n",
  567. " plt.figure(figsize=(10, 4))\n",
  568. " hist = np.histogram(m, bins=1000)\n",
  569. " plt.plot(hist[0])"
  570. ]
  571. },
  572. {
  573. "cell_type": "code",
  574. "execution_count": 5,
  575. "id": "2a4fc02f",
  576. "metadata": {},
  577. "outputs": [
  578. {
  579. "data": {
  580. "text/plain": [
  581. "(1913, 62, 20, 11)"
  582. ]
  583. },
  584. "execution_count": 5,
  585. "metadata": {},
  586. "output_type": "execute_result"
  587. }
  588. ],
  589. "source": [
  590. "all_data.shape"
  591. ]
  592. },
  593. {
  594. "cell_type": "code",
  595. "execution_count": 5,
  596. "id": "57c5fc40",
  597. "metadata": {
  598. "scrolled": false
  599. },
  600. "outputs": [
  601. {
  602. "data": {
  603. "text/plain": [
  604. "'trials = []\\nfor trial in mfc_test:\\n pic = np.zeros((7,9,10,11))\\n pic[0,2] = trial[3]\\n pic[0,3] = trial[0]\\n pic[0,4] = trial[1]\\n pic[0,5] = trial[2]\\n pic[0,6] = trial[4]\\n pic[1,:] = trial[5:14]\\n pic[2,:] = trial[14:23]\\n pic[3,:] = trial[23:32]\\n pic[4,:] = trial[32:41]\\n pic[5,:] = trial[41:50]\\n pic[6,0] = trial[50]\\n pic[6,1] = trial[51]\\n pic[6,2] = trial[52]\\n pic[6,3] = trial[58]\\n pic[6,4] = trial[53]\\n pic[6,5] = trial[60]\\n pic[6,6] = trial[54]\\n pic[6,7] = trial[55]\\n pic[6,8] = trial[56]\\n trials.append(pic) \\npicture_data_test = np.array(trials)'"
  605. ]
  606. },
  607. "execution_count": 5,
  608. "metadata": {},
  609. "output_type": "execute_result"
  610. }
  611. ],
  612. "source": [
  613. "\n",
  614. "trials = []\n",
  615. "for trial in all_data:\n",
  616. " pic = np.zeros((7,9,20,11))\n",
  617. " pic[0,2] = trial[3]\n",
  618. " pic[0,3] = trial[0]\n",
  619. " pic[0,4] = trial[1]\n",
  620. " pic[0,5] = trial[2]\n",
  621. " pic[0,6] = trial[4]\n",
  622. " pic[1,:] = trial[5:14]\n",
  623. " pic[2,:] = trial[14:23]\n",
  624. " pic[3,:] = trial[23:32]\n",
  625. " pic[4,:] = trial[32:41]\n",
  626. " pic[5,:] = trial[41:50]\n",
  627. " pic[6,0] = trial[50]\n",
  628. " pic[6,1] = trial[51]\n",
  629. " pic[6,2] = trial[52]\n",
  630. " pic[6,3] = trial[58]\n",
  631. " pic[6,4] = trial[53]\n",
  632. " pic[6,5] = trial[60]\n",
  633. " pic[6,6] = trial[54]\n",
  634. " pic[6,7] = trial[55]\n",
  635. " pic[6,8] = trial[56]\n",
  636. " trials.append(pic)\n",
  637. "picture_data_train = np.array(trials)\n",
  638. "'''trials = []\n",
  639. "for trial in mfc_test:\n",
  640. " pic = np.zeros((7,9,10,11))\n",
  641. " pic[0,2] = trial[3]\n",
  642. " pic[0,3] = trial[0]\n",
  643. " pic[0,4] = trial[1]\n",
  644. " pic[0,5] = trial[2]\n",
  645. " pic[0,6] = trial[4]\n",
  646. " pic[1,:] = trial[5:14]\n",
  647. " pic[2,:] = trial[14:23]\n",
  648. " pic[3,:] = trial[23:32]\n",
  649. " pic[4,:] = trial[32:41]\n",
  650. " pic[5,:] = trial[41:50]\n",
  651. " pic[6,0] = trial[50]\n",
  652. " pic[6,1] = trial[51]\n",
  653. " pic[6,2] = trial[52]\n",
  654. " pic[6,3] = trial[58]\n",
  655. " pic[6,4] = trial[53]\n",
  656. " pic[6,5] = trial[60]\n",
  657. " pic[6,6] = trial[54]\n",
  658. " pic[6,7] = trial[55]\n",
  659. " pic[6,8] = trial[56]\n",
  660. " trials.append(pic) \n",
  661. "picture_data_test = np.array(trials)'''"
  662. ]
  663. },
  664. {
  665. "cell_type": "code",
  666. "execution_count": 4,
  667. "id": "650f9466",
  668. "metadata": {},
  669. "outputs": [
  670. {
  671. "data": {
  672. "text/plain": [
  673. "(1913, 62, 20, 11)"
  674. ]
  675. },
  676. "execution_count": 4,
  677. "metadata": {},
  678. "output_type": "execute_result"
  679. }
  680. ],
  681. "source": [
  682. "all_data.shape"
  683. ]
  684. },
  685. {
  686. "cell_type": "code",
  687. "execution_count": 6,
  688. "id": "49d1f044",
  689. "metadata": {},
  690. "outputs": [
  691. {
  692. "data": {
  693. "text/plain": [
  694. "(1913, 7, 9, 20, 11)"
  695. ]
  696. },
  697. "execution_count": 6,
  698. "metadata": {},
  699. "output_type": "execute_result"
  700. }
  701. ],
  702. "source": [
  703. "picture_data_train.shape"
  704. ]
  705. },
  706. {
  707. "cell_type": "code",
  708. "execution_count": 7,
  709. "id": "60b82988",
  710. "metadata": {},
  711. "outputs": [
  712. {
  713. "ename": "NameError",
  714. "evalue": "name 'picture_data_test' is not defined",
  715. "output_type": "error",
  716. "traceback": [
  717. "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
  718. "\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
  719. "Input \u001b[1;32mIn [7]\u001b[0m, in \u001b[0;36m<cell line: 1>\u001b[1;34m()\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[38;5;28mprint\u001b[39m(picture_data_train\u001b[38;5;241m.\u001b[39mshape, \u001b[43mpicture_data_test\u001b[49m\u001b[38;5;241m.\u001b[39mshape)\n",
  720. "\u001b[1;31mNameError\u001b[0m: name 'picture_data_test' is not defined"
  721. ]
  722. }
  723. ],
  724. "source": [
  725. "print(picture_data_train.shape, picture_data_test.shape)"
  726. ]
  727. },
  728. {
  729. "cell_type": "code",
  730. "execution_count": 8,
  731. "id": "6cad2f67",
  732. "metadata": {},
  733. "outputs": [],
  734. "source": [
  735. "dataset = picture_data_train\n",
  736. "labels = nasal_label"
  737. ]
  738. },
  739. {
  740. "cell_type": "code",
  741. "execution_count": 193,
  742. "id": "142b725f",
  743. "metadata": {},
  744. "outputs": [
  745. {
  746. "name": "stdout",
  747. "output_type": "stream",
  748. "text": [
  749. "(1721, 62, 10, 11) 1913 (7, 9, 10, 11) 1913\n"
  750. ]
  751. }
  752. ],
  753. "source": [
  754. "#concat train and test data. concat train and test labels\n",
  755. "\n",
  756. "dataset = []\n",
  757. "for data in picture_data_train:\n",
  758. " dataset.append(data)\n",
  759. "for data in picture_data_test:\n",
  760. " dataset.append(data)\n",
  761. "\n",
  762. "labels = y__train + y__test\n",
  763. " \n",
  764. "print(mfc_train.shape, len(dataset), dataset[0].shape, len(labels))"
  765. ]
  766. },
  767. {
  768. "cell_type": "code",
  769. "execution_count": 9,
  770. "id": "f287c055",
  771. "metadata": {},
  772. "outputs": [],
  773. "source": [
  774. "#config\n",
  775. "val_size = 0.15\n",
  776. "n_epochs = 100\n",
  777. "batch_size = 128\n",
  778. "print_every = 10\n",
  779. "k = 10\n",
  780. "skf=StratifiedKFold(n_splits=k, shuffle=True, random_state=42)"
  781. ]
  782. },
  783. {
  784. "cell_type": "code",
  785. "execution_count": 10,
  786. "id": "959308f6",
  787. "metadata": {},
  788. "outputs": [],
  789. "source": [
  790. "#model\n",
  791. "\n",
  792. "class CNN_RNN(nn.Module):\n",
  793. " def __init__(self):\n",
  794. " super().__init__()\n",
  795. " self.conv1 = nn.Conv2d(20, 16, 3)\n",
  796. " #torch.nn.init.xavier_normal_(self.conv1.weight)\n",
  797. " self.pool = nn.MaxPool2d(2, 1)\n",
  798. " self.conv2 = nn.Conv2d(16, 32, 3)\n",
  799. " #torch.nn.init.xavier_normal_(self.conv2.weight)\n",
  800. " self.lstm = nn.LSTM(input_size=256, hidden_size=128, num_layers=2, batch_first=True)\n",
  801. " self.fc = nn.Linear(128, 2)\n",
  802. " #torch.nn.init.xavier_normal_(self.fc.weight)\n",
  803. " self.batch1 = nn.BatchNorm2d(16)\n",
  804. " self.batch2 = nn.BatchNorm2d(32)\n",
  805. " self.relu1 = nn.ReLU()\n",
  806. " self.relu2 = nn.ReLU()\n",
  807. " \n",
  808. " \n",
  809. " def forward(self, x):\n",
  810. " hidden = torch.zeros(2, x.shape[0], 128), torch.zeros(2, x.shape[0], 128)\n",
  811. " # (batch, heigth, width, feature, time)\n",
  812. " #print(x.shape)\n",
  813. " x = rearrange(x, 'batch heigth width feature time -> (batch time) feature heigth width')\n",
  814. " #print(x.shape)\n",
  815. " out = self.pool(self.relu1(self.batch1(self.conv1(x))))\n",
  816. " #print(out.shape)\n",
  817. " out = self.relu2(self.batch2(self.conv2(out)))\n",
  818. " #print(out.shape)\n",
  819. " out = rearrange(out, '(batch time) channel heigth width -> batch time (channel heigth width)', time=11)\n",
  820. " #print(out.shape)\n",
  821. " out, hidden = self.lstm(out, hidden) \n",
  822. " out = out[:,-1,:]\n",
  823. " out = self.fc(out)\n",
  824. " return out\n",
  825. " \n",
  826. "class FC(nn.Module):\n",
  827. " def __init__(self, hidden1=500):\n",
  828. " super(FC, self).__init__()\n",
  829. " self.fc1 = nn.Linear(6820, hidden1)\n",
  830. " torch.nn.init.xavier_normal(self.fc1.weight)\n",
  831. " self.fc2 = nn.Linear(hidden1, 1)\n",
  832. " torch.nn.init.xavier_normal(self.fc2.weight)\n",
  833. " self.dropout = nn.Dropout(0.3)\n",
  834. " \n",
  835. " def forward(self, x):\n",
  836. " x = x.view(-1, 6820)\n",
  837. " x = F.relu(self.fc1(x))\n",
  838. " #x = self.dropout(x)\n",
  839. " x = F.sigmoid(self.fc2(x))\n",
  840. " return x"
  841. ]
  842. },
  843. {
  844. "cell_type": "code",
  845. "execution_count": 11,
  846. "id": "6c240889",
  847. "metadata": {
  848. "scrolled": false
  849. },
  850. "outputs": [
  851. {
  852. "name": "stdout",
  853. "output_type": "stream",
  854. "text": [
  855. "-----------------------------Fold 1---------------\n",
  856. "preparing dataloaders...\n",
  857. "coef when 0 > 1 1\n",
  858. "creating model...\n",
  859. "calculating total steps...\n",
  860. "epoch: 1\n",
  861. "validation loss decreased (inf ---> 0.703975), val_acc = 0.39534884691238403\n",
  862. "validation acc increased (0.000000 ---> 0.395349)\n",
  863. "validation acc increased (0.395349 ---> 0.395349)\n",
  864. "epoch 1: train loss = 0.6921196075978445, l1loss = 1.4813192892626534, train acc = 0.5274282693862915,\n",
  865. "val_loss = 0.7102676463681598, val_acc = 0.39534884691238403\n",
  866. "\n",
  867. "epoch: 2\n",
  868. "validation acc increased (0.395349 ---> 0.395349)\n",
  869. "validation acc increased (0.395349 ---> 0.395349)\n",
  870. "epoch 2: train loss = 0.6915809742517668, l1loss = 1.4611903595816387, train acc = 0.5274282693862915,\n",
  871. "val_loss = 0.7109477325927379, val_acc = 0.39534884691238403\n",
  872. "\n",
  873. "epoch: 3\n",
  874. "validation acc increased (0.395349 ---> 0.395349)\n",
  875. "validation acc increased (0.395349 ---> 0.395349)\n",
  876. "epoch 3: train loss = 0.6911844878743867, l1loss = 1.4379649762238578, train acc = 0.5274282693862915,\n",
  877. "val_loss = 0.7080174167026845, val_acc = 0.39534884691238403\n",
  878. "\n",
  879. "epoch: 4\n",
  880. "validation acc increased (0.395349 ---> 0.395349)\n",
  881. "validation acc increased (0.395349 ---> 0.395349)\n",
  882. "epoch 4: train loss = 0.6908970355807565, l1loss = 1.4091885588429474, train acc = 0.5274282693862915,\n",
  883. "val_loss = 0.7074278080186178, val_acc = 0.39534884691238403\n",
  884. "\n",
  885. "epoch: 5\n",
  886. "validation acc increased (0.395349 ---> 0.395349)\n",
  887. "validation acc increased (0.395349 ---> 0.395349)\n",
  888. "epoch 5: train loss = 0.6903622140522032, l1loss = 1.3728010045504173, train acc = 0.5274282693862915,\n",
  889. "val_loss = 0.7060032583946405, val_acc = 0.39534884691238403\n",
  890. "\n",
  891. "epoch: 6\n",
  892. "validation acc increased (0.395349 ---> 0.395349)\n",
  893. "validation acc increased (0.395349 ---> 0.395349)\n",
  894. "epoch 6: train loss = 0.6897485629961487, l1loss = 1.3268929230968858, train acc = 0.5279315710067749,\n",
  895. "val_loss = 0.7049394624177799, val_acc = 0.40310078859329224\n",
  896. "\n",
  897. "epoch: 7\n",
  898. "validation acc increased (0.395349 ---> 0.403101)\n",
  899. "validation acc increased (0.403101 ---> 0.406977)\n",
  900. "epoch 7: train loss = 0.6890018312068338, l1loss = 1.2698341833027271, train acc = 0.5294413566589355,\n",
  901. "val_loss = 0.7049807642781457, val_acc = 0.40310078859329224\n",
  902. "\n",
  903. "epoch: 8\n",
  904. "validation loss decreased (0.703975 ---> 0.700499), val_acc = 0.44573643803596497\n",
  905. "validation acc increased (0.406977 ---> 0.445736)\n",
  906. "epoch 8: train loss = 0.6881676978001837, l1loss = 1.2007527305783492, train acc = 0.5430296659469604,\n",
  907. "val_loss = 0.700941942458929, val_acc = 0.44186046719551086\n",
  908. "\n",
  909. "epoch: 9\n",
  910. "epoch 9: train loss = 0.6866489719312636, l1loss = 1.119547625252753, train acc = 0.534977376461029,\n",
  911. "val_loss = 0.7109170130981032, val_acc = 0.41860464215278625\n",
  912. "\n",
  913. "epoch: 10\n",
  914. "validation acc increased (0.445736 ---> 0.453488)\n",
  915. "epoch 10: train loss = 0.6828876420308071, l1loss = 1.026746424707145, train acc = 0.5812783241271973,\n",
  916. "val_loss = 0.7096735031105751, val_acc = 0.44961240887641907\n",
  917. "\n",
  918. "epoch: 11\n",
  919. "validation acc increased (0.453488 ---> 0.457364)\n",
  920. "validation loss decreased (0.700499 ---> 0.691370), val_acc = 0.5155038833618164\n",
  921. "validation acc increased (0.457364 ---> 0.515504)\n",
  922. "epoch 11: train loss = 0.6756821570530811, l1loss = 0.9239571331491648, train acc = 0.606441855430603,\n",
  923. "val_loss = 0.7130429273427918, val_acc = 0.4651162922382355\n",
  924. "\n",
  925. "epoch: 12\n",
  926. "validation loss decreased (0.691370 ---> 0.687802), val_acc = 0.5426356792449951\n",
  927. "validation acc increased (0.515504 ---> 0.542636)\n",
  928. "epoch 12: train loss = 0.6597185275275083, l1loss = 0.8148023310653156, train acc = 0.6170105934143066,\n",
  929. "val_loss = 0.7542387963265411, val_acc = 0.45736435055732727\n",
  930. "\n",
  931. "epoch: 13\n",
  932. "epoch 13: train loss = 0.6367104136589848, l1loss = 0.7034740930054736, train acc = 0.6371414065361023,\n",
  933. "val_loss = 0.7009645912998407, val_acc = 0.5348837375640869\n",
  934. "\n",
  935. "epoch: 14\n",
  936. "validation acc increased (0.542636 ---> 0.542636)\n",
  937. "validation acc increased (0.542636 ---> 0.589147)\n",
  938. "epoch 14: train loss = 0.5822035336860413, l1loss = 0.594905492913861, train acc = 0.7121288180351257,\n",
  939. "val_loss = 0.7594922891882963, val_acc = 0.5581395626068115\n",
  940. "\n",
  941. "epoch: 15\n",
  942. "epoch 15: train loss = 0.5122891185148073, l1loss = 0.4949571643124049, train acc = 0.7488676309585571,\n",
  943. "val_loss = 0.8891759768936985, val_acc = 0.4883720874786377\n",
  944. "\n",
  945. "epoch: 16\n",
  946. "epoch 16: train loss = 0.39873673111035346, l1loss = 0.40662467002748665, train acc = 0.8349270224571228,\n",
  947. "val_loss = 1.1463301422059997, val_acc = 0.5581395626068115\n",
  948. "\n",
  949. "epoch: 17\n",
  950. "epoch 17: train loss = 0.2904471275670839, l1loss = 0.3356658784540442, train acc = 0.8917967081069946,\n",
  951. "val_loss = 1.2913649451825047, val_acc = 0.5\n",
  952. "\n",
  953. "epoch: 18\n",
  954. "epoch 18: train loss = 0.22612523326906175, l1loss = 0.28161148739689007, train acc = 0.9068948030471802,\n",
  955. "val_loss = 1.6142784735953162, val_acc = 0.5813953280448914\n",
  956. "\n",
  957. "epoch: 19\n",
  958. "epoch 19: train loss = 0.1450976847132427, l1loss = 0.247697628168347, train acc = 0.9531957507133484,\n",
  959. "val_loss = 1.5399292397868725, val_acc = 0.5620155334472656\n",
  960. "\n",
  961. "epoch: 20\n",
  962. "epoch 20: train loss = 0.11934498384422433, l1loss = 0.23167835120496044, train acc = 0.9557121396064758,\n",
  963. "val_loss = 1.7687409589456957, val_acc = 0.4844961166381836\n",
  964. "\n",
  965. "epoch: 21\n",
  966. "epoch 21: train loss = 0.06861750959270134, l1loss = 0.22078036307298662, train acc = 0.9773527979850769,\n",
  967. "val_loss = 2.424526694393175, val_acc = 0.5852712988853455\n",
  968. "\n",
  969. "epoch: 22\n",
  970. "epoch 22: train loss = 0.04279914844610551, l1loss = 0.21008030950993103, train acc = 0.9899345636367798,\n",
  971. "val_loss = 1.830416483472484, val_acc = 0.5310077667236328\n",
  972. "\n",
  973. "epoch: 23\n",
  974. "epoch 23: train loss = 0.032723448187240295, l1loss = 0.20147271192607774, train acc = 0.9924509525299072,\n",
  975. "val_loss = 2.2050810784332513, val_acc = 0.4922480583190918\n",
  976. "\n",
  977. "epoch: 24\n",
  978. "epoch 24: train loss = 0.02353123065494807, l1loss = 0.19340241685787157, train acc = 0.9964771270751953,\n",
  979. "val_loss = 2.1999229627062182, val_acc = 0.5736433863639832\n",
  980. "\n",
  981. "epoch: 25\n",
  982. "epoch 25: train loss = 0.032386840027243756, l1loss = 0.18697129591931994, train acc = 0.9904378652572632,\n",
  983. "val_loss = 2.24283122277075, val_acc = 0.5465116500854492\n",
  984. "\n",
  985. "epoch: 26\n",
  986. "epoch 26: train loss = 0.05402209995587267, l1loss = 0.18239764653982443, train acc = 0.9798691272735596,\n",
  987. "val_loss = 2.1508019247720407, val_acc = 0.5891472697257996\n",
  988. "\n",
  989. "epoch: 27\n",
  990. "validation acc increased (0.589147 ---> 0.596899)\n",
  991. "epoch 27: train loss = 0.08302161952074773, l1loss = 0.17984621981219784, train acc = 0.9667841196060181,\n",
  992. "val_loss = 2.134554007256678, val_acc = 0.5155038833618164\n",
  993. "\n",
  994. "epoch: 28\n",
  995. "validation acc increased (0.596899 ---> 0.608527)\n",
  996. "epoch 28: train loss = 0.08317013363863393, l1loss = 0.17741221547966696, train acc = 0.9672873616218567,\n",
  997. "val_loss = 1.8020918643058732, val_acc = 0.5813953280448914\n",
  998. "\n",
  999. "epoch: 29\n",
  1000. "epoch 29: train loss = 0.09245095192021201, l1loss = 0.1736254430201165, train acc = 0.9642677307128906,\n",
  1001. "val_loss = 2.263783588426101, val_acc = 0.6356589198112488\n",
  1002. "\n",
  1003. "epoch: 30\n",
  1004. "validation acc increased (0.608527 ---> 0.612403)\n",
  1005. "epoch 30: train loss = 0.03266293139847968, l1loss = 0.167058823288057, train acc = 0.9939607381820679,\n",
  1006. "val_loss = 1.8842465008876121, val_acc = 0.5775193572044373\n",
  1007. "\n",
  1008. "epoch: 31\n",
  1009. "epoch 31: train loss = 0.009013392774094783, l1loss = 0.16012636431510208, train acc = 1.0,\n",
  1010. "val_loss = 1.957195477892262, val_acc = 0.565891444683075\n",
  1011. "\n",
  1012. "!!! overfitted !!!\n",
  1013. "tensor(74)\n",
  1014. "tensor(38)\n",
  1015. "early stoping results:\n",
  1016. "\t [tensor(0.5833)]\n",
  1017. "\t [tensor(0.5737)]\n",
  1018. "tensor(67)\n",
  1019. "tensor(35)\n",
  1020. "full train results:\n",
  1021. "\t [tensor(0.5312)]\n",
  1022. "\t [tensor(0.9995)]\n",
  1023. "tensor(77)\n",
  1024. "tensor(39)\n",
  1025. "best accs results:\n",
  1026. "\t [tensor(0.6042)]\n",
  1027. "\t [tensor(0.6064)]\n",
  1028. "[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1]\n",
  1029. "-----------------------------Fold 2---------------\n",
  1030. "preparing dataloaders...\n"
  1031. ]
  1032. },
  1033. {
  1034. "name": "stdout",
  1035. "output_type": "stream",
  1036. "text": [
  1037. "coef when 0 > 1 1\n",
  1038. "creating model...\n",
  1039. "calculating total steps...\n",
  1040. "epoch: 1\n",
  1041. "validation loss decreased (inf ---> 0.685728), val_acc = 0.643410861492157\n",
  1042. "validation acc increased (0.000000 ---> 0.643411)\n",
  1043. "validation acc increased (0.643411 ---> 0.643411)\n",
  1044. "epoch 1: train loss = 0.6923627504481037, l1loss = 1.486846898458096, train acc = 0.5132699012756348,\n",
  1045. "val_loss = 0.6893179130184558, val_acc = 0.643410861492157\n",
  1046. "\n",
  1047. "epoch: 2\n",
  1048. "validation acc increased (0.643411 ---> 0.643411)\n",
  1049. "epoch 2: train loss = 0.690640912243409, l1loss = 1.4669305257458178, train acc = 0.5353029370307922,\n",
  1050. "val_loss = 0.6956856370896332, val_acc = 0.3759689927101135\n",
  1051. "\n",
  1052. "epoch: 3\n",
  1053. "epoch 3: train loss = 0.6898900556098716, l1loss = 1.4436278268582952, train acc = 0.5348021984100342,\n",
  1054. "val_loss = 0.7077792864437251, val_acc = 0.356589138507843\n",
  1055. "\n",
  1056. "epoch: 4\n",
  1057. "epoch 4: train loss = 0.6896268491692464, l1loss = 1.41467766739096, train acc = 0.5348021984100342,\n",
  1058. "val_loss = 0.716421481712844, val_acc = 0.356589138507843\n",
  1059. "\n",
  1060. "epoch: 5\n",
  1061. "epoch 5: train loss = 0.6888132073250065, l1loss = 1.3779625838914393, train acc = 0.5348021984100342,\n",
  1062. "val_loss = 0.7130461672479792, val_acc = 0.356589138507843\n",
  1063. "\n",
  1064. "epoch: 6\n",
  1065. "epoch 6: train loss = 0.6880481983042981, l1loss = 1.3316323196643223, train acc = 0.5353029370307922,\n",
  1066. "val_loss = 0.7145648002624512, val_acc = 0.356589138507843\n",
  1067. "\n",
  1068. "epoch: 7\n",
  1069. "epoch 7: train loss = 0.6870320929129481, l1loss = 1.2742788348725633, train acc = 0.5373059511184692,\n",
  1070. "val_loss = 0.7159286654272745, val_acc = 0.3527131676673889\n",
  1071. "\n",
  1072. "epoch: 8\n",
  1073. "epoch 8: train loss = 0.6852890617201551, l1loss = 1.205061872063963, train acc = 0.5448172092437744,\n",
  1074. "val_loss = 0.7084857664367025, val_acc = 0.39147287607192993\n",
  1075. "\n",
  1076. "epoch: 9\n",
  1077. "epoch 9: train loss = 0.6824781329678367, l1loss = 1.1237744811659285, train acc = 0.5553330183029175,\n",
  1078. "val_loss = 0.7270350784294365, val_acc = 0.3488371968269348\n",
  1079. "\n",
  1080. "epoch: 10\n",
  1081. "epoch 10: train loss = 0.6778379736629557, l1loss = 1.0311813556557008, train acc = 0.5823735594749451,\n",
  1082. "val_loss = 0.721730862477029, val_acc = 0.38759690523147583\n",
  1083. "\n",
  1084. "epoch: 11\n",
  1085. "epoch 11: train loss = 0.6686287961068247, l1loss = 0.9295802932712037, train acc = 0.6119178533554077,\n",
  1086. "val_loss = 0.7199302791625031, val_acc = 0.4844961166381836\n",
  1087. "\n",
  1088. "epoch: 12\n",
  1089. "epoch 12: train loss = 0.6473726140121608, l1loss = 0.8220474930124756, train acc = 0.6479719877243042,\n",
  1090. "val_loss = 0.7618678161340167, val_acc = 0.44186046719551086\n",
  1091. "\n",
  1092. "epoch: 13\n",
  1093. "epoch 13: train loss = 0.6235634726050142, l1loss = 0.7134110764616659, train acc = 0.6604907512664795,\n",
  1094. "val_loss = 0.7779512396154478, val_acc = 0.44961240887641907\n",
  1095. "\n",
  1096. "epoch: 14\n",
  1097. "epoch 14: train loss = 0.5916976472348885, l1loss = 0.6070883066939543, train acc = 0.6935403347015381,\n",
  1098. "val_loss = 0.7408287622207819, val_acc = 0.5348837375640869\n",
  1099. "\n",
  1100. "epoch: 15\n",
  1101. "epoch 15: train loss = 0.5074248785487878, l1loss = 0.5080034436287735, train acc = 0.7611417174339294,\n",
  1102. "val_loss = 0.9050711670587229, val_acc = 0.44186046719551086\n",
  1103. "\n",
  1104. "epoch: 16\n",
  1105. "epoch 16: train loss = 0.43145891812125864, l1loss = 0.41937342439822456, train acc = 0.8137205839157104,\n",
  1106. "val_loss = 0.8636050733715989, val_acc = 0.569767415523529\n",
  1107. "\n",
  1108. "epoch: 17\n",
  1109. "epoch 17: train loss = 0.33662977910185077, l1loss = 0.3446814462908161, train acc = 0.8642964363098145,\n",
  1110. "val_loss = 0.9003204245899998, val_acc = 0.569767415523529\n",
  1111. "\n",
  1112. "epoch: 18\n",
  1113. "epoch 18: train loss = 0.24278796418314882, l1loss = 0.28942148506074294, train acc = 0.9108663201332092,\n",
  1114. "val_loss = 1.8521458714507346, val_acc = 0.6356589198112488\n",
  1115. "\n",
  1116. "epoch: 19\n",
  1117. "epoch 19: train loss = 0.30993230177512093, l1loss = 0.25730218698397, train acc = 0.8688032031059265,\n",
  1118. "val_loss = 1.1883431146311205, val_acc = 0.5775193572044373\n",
  1119. "\n",
  1120. "epoch: 20\n",
  1121. "epoch 20: train loss = 0.1617210699160456, l1loss = 0.2381301120736925, train acc = 0.947421133518219,\n",
  1122. "val_loss = 1.535721439261769, val_acc = 0.45348837971687317\n",
  1123. "\n",
  1124. "epoch: 21\n",
  1125. "epoch 21: train loss = 0.12891582561900272, l1loss = 0.2264571423433874, train acc = 0.9594391584396362,\n",
  1126. "val_loss = 2.21431942595992, val_acc = 0.42248061299324036\n",
  1127. "\n",
  1128. "epoch: 22\n",
  1129. "epoch 22: train loss = 0.10533707820226028, l1loss = 0.22040308403801667, train acc = 0.9609414339065552,\n",
  1130. "val_loss = 1.6402112753816354, val_acc = 0.5968992114067078\n",
  1131. "\n",
  1132. "epoch: 23\n",
  1133. "epoch 23: train loss = 0.05758423222024619, l1loss = 0.21273435418702508, train acc = 0.9834752082824707,\n",
  1134. "val_loss = 1.7986319601073746, val_acc = 0.569767415523529\n",
  1135. "\n",
  1136. "epoch: 24\n",
  1137. "epoch 24: train loss = 0.027758270035729626, l1loss = 0.20449841585109158, train acc = 0.993990957736969,\n",
  1138. "val_loss = 1.8702633103658988, val_acc = 0.5620155334472656\n",
  1139. "\n",
  1140. "epoch: 25\n",
  1141. "epoch 25: train loss = 0.017600334982089717, l1loss = 0.19607467192766603, train acc = 0.9969955086708069,\n",
  1142. "val_loss = 1.982220847015233, val_acc = 0.5155038833618164\n",
  1143. "\n",
  1144. "epoch: 26\n",
  1145. "epoch 26: train loss = 0.012600567348536457, l1loss = 0.1878898390006828, train acc = 0.998497724533081,\n",
  1146. "val_loss = 2.2955358172153204, val_acc = 0.4961240291595459\n",
  1147. "\n",
  1148. "epoch: 27\n",
  1149. "epoch 27: train loss = 0.009947275661042126, l1loss = 0.18061071683540306, train acc = 1.0,\n",
  1150. "val_loss = 1.9011228786882504, val_acc = 0.5193798542022705\n",
  1151. "\n",
  1152. "!!! overfitted !!!\n",
  1153. "tensor(81)\n",
  1154. "tensor(41)\n",
  1155. "early stoping results:\n",
  1156. "\t [tensor(0.5833), tensor(0.6354)]\n",
  1157. "\t [tensor(0.5737), tensor(0.4652)]\n",
  1158. "tensor(75)\n",
  1159. "tensor(28)\n",
  1160. "full train results:\n",
  1161. "\t [tensor(0.5312), tensor(0.5365)]\n",
  1162. "\t [tensor(0.9995), tensor(1.)]\n",
  1163. "tensor(83)\n",
  1164. "tensor(39)\n",
  1165. "best accs results:\n",
  1166. "\t [tensor(0.6042), tensor(0.6354)]\n",
  1167. "\t [tensor(0.6064), tensor(0.4652)]\n",
  1168. "[0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1]\n",
  1169. "-----------------------------Fold 3---------------\n",
  1170. "preparing dataloaders...\n",
  1171. "coef when 0 > 1 1\n",
  1172. "creating model...\n",
  1173. "calculating total steps...\n",
  1174. "epoch: 1\n",
  1175. "validation loss decreased (inf ---> 0.695385), val_acc = 0.3217054307460785\n",
  1176. "validation acc increased (0.000000 ---> 0.321705)\n",
  1177. "validation loss decreased (0.695385 ---> 0.694880), val_acc = 0.3217054307460785\n",
  1178. "validation acc increased (0.321705 ---> 0.321705)\n",
  1179. "epoch 1: train loss = 0.6919325688781434, l1loss = 1.4844660898029864, train acc = 0.5353938341140747,\n",
  1180. "val_loss = 0.696869459263114, val_acc = 0.3217054307460785\n",
  1181. "\n",
  1182. "epoch: 2\n",
  1183. "validation acc increased (0.321705 ---> 0.321705)\n",
  1184. "validation acc increased (0.321705 ---> 0.321705)\n",
  1185. "epoch 2: train loss = 0.6903208250538302, l1loss = 1.4645608618394923, train acc = 0.5408773422241211,\n",
  1186. "val_loss = 0.7094094993532166, val_acc = 0.3217054307460785\n",
  1187. "\n",
  1188. "epoch: 3\n",
  1189. "validation acc increased (0.321705 ---> 0.321705)\n",
  1190. "validation acc increased (0.321705 ---> 0.321705)\n",
  1191. "epoch 3: train loss = 0.6888462312794398, l1loss = 1.4412288264049253, train acc = 0.5413758754730225,\n",
  1192. "val_loss = 0.7203627150188121, val_acc = 0.3217054307460785\n",
  1193. "\n",
  1194. "epoch: 4\n",
  1195. "validation acc increased (0.321705 ---> 0.321705)\n",
  1196. "validation acc increased (0.321705 ---> 0.321705)\n",
  1197. "epoch 4: train loss = 0.6883660529926791, l1loss = 1.4121827829632898, train acc = 0.5413758754730225,\n",
  1198. "val_loss = 0.7268370288287023, val_acc = 0.3217054307460785\n",
  1199. "\n",
  1200. "epoch: 5\n",
  1201. "validation acc increased (0.321705 ---> 0.321705)\n",
  1202. "validation acc increased (0.321705 ---> 0.321705)\n",
  1203. "epoch 5: train loss = 0.6879507299078069, l1loss = 1.375367067509135, train acc = 0.5413758754730225,\n",
  1204. "val_loss = 0.7243463373923487, val_acc = 0.3217054307460785\n",
  1205. "\n",
  1206. "epoch: 6\n",
  1207. "validation acc increased (0.321705 ---> 0.321705)\n",
  1208. "validation acc increased (0.321705 ---> 0.321705)\n",
  1209. "epoch 6: train loss = 0.6870695925424486, l1loss = 1.3291421180232572, train acc = 0.5413758754730225,\n",
  1210. "val_loss = 0.7238615367763727, val_acc = 0.3217054307460785\n",
  1211. "\n",
  1212. "epoch: 7\n",
  1213. "validation acc increased (0.321705 ---> 0.321705)\n"
  1214. ]
  1215. },
  1216. {
  1217. "name": "stdout",
  1218. "output_type": "stream",
  1219. "text": [
  1220. "validation acc increased (0.321705 ---> 0.321705)\n",
  1221. "epoch 7: train loss = 0.6860224290835417, l1loss = 1.271995098854704, train acc = 0.5413758754730225,\n",
  1222. "val_loss = 0.7287569512692533, val_acc = 0.3217054307460785\n",
  1223. "\n",
  1224. "epoch: 8\n",
  1225. "validation acc increased (0.321705 ---> 0.321705)\n",
  1226. "validation acc increased (0.321705 ---> 0.321705)\n",
  1227. "epoch 8: train loss = 0.6840933574993138, l1loss = 1.2031428917336202, train acc = 0.5413758754730225,\n",
  1228. "val_loss = 0.7273619147234185, val_acc = 0.3217054307460785\n",
  1229. "\n",
  1230. "epoch: 9\n",
  1231. "validation acc increased (0.321705 ---> 0.321705)\n",
  1232. "validation acc increased (0.321705 ---> 0.321705)\n",
  1233. "epoch 9: train loss = 0.6811107120746868, l1loss = 1.1223840780058507, train acc = 0.5488534569740295,\n",
  1234. "val_loss = 0.7214035119197165, val_acc = 0.3682170510292053\n",
  1235. "\n",
  1236. "epoch: 10\n",
  1237. "validation acc increased (0.321705 ---> 0.368217)\n",
  1238. "validation acc increased (0.368217 ---> 0.379845)\n",
  1239. "epoch 10: train loss = 0.6758952574620575, l1loss = 1.0303770931268619, train acc = 0.5947158336639404,\n",
  1240. "val_loss = 0.7813697923985563, val_acc = 0.3178294599056244\n",
  1241. "\n",
  1242. "epoch: 11\n",
  1243. "epoch 11: train loss = 0.6630049520213963, l1loss = 0.9292449082952673, train acc = 0.6096709966659546,\n",
  1244. "val_loss = 0.7966540109279544, val_acc = 0.3372093141078949\n",
  1245. "\n",
  1246. "epoch: 12\n",
  1247. "epoch 12: train loss = 0.6346044056080393, l1loss = 0.8222741206526637, train acc = 0.6495513319969177,\n",
  1248. "val_loss = 1.1256857214047926, val_acc = 0.3178294599056244\n",
  1249. "\n",
  1250. "epoch: 13\n",
  1251. "epoch 13: train loss = 0.5827583085266448, l1loss = 0.7148247218559889, train acc = 0.7008973360061646,\n",
  1252. "val_loss = 0.9412916704665782, val_acc = 0.38759690523147583\n",
  1253. "\n",
  1254. "epoch: 14\n",
  1255. "validation acc increased (0.379845 ---> 0.395349)\n",
  1256. "epoch 14: train loss = 0.521890408674242, l1loss = 0.6116738263536189, train acc = 0.7442672252655029,\n",
  1257. "val_loss = 1.0655778175176576, val_acc = 0.43798449635505676\n",
  1258. "\n",
  1259. "epoch: 15\n",
  1260. "validation acc increased (0.395349 ---> 0.492248)\n",
  1261. "epoch 15: train loss = 0.43515452614809913, l1loss = 0.5178841115648701, train acc = 0.8125622868537903,\n",
  1262. "val_loss = 1.1385327799375666, val_acc = 0.46124032139778137\n",
  1263. "\n",
  1264. "epoch: 16\n",
  1265. "validation acc increased (0.492248 ---> 0.534884)\n",
  1266. "epoch 16: train loss = 0.36243219033314483, l1loss = 0.4364416953454821, train acc = 0.8459621071815491,\n",
  1267. "val_loss = 1.0747009594311085, val_acc = 0.5426356792449951\n",
  1268. "\n",
  1269. "epoch: 17\n",
  1270. "validation acc increased (0.534884 ---> 0.534884)\n",
  1271. "validation acc increased (0.534884 ---> 0.577519)\n",
  1272. "epoch 17: train loss = 0.26884116884599535, l1loss = 0.37061460038362926, train acc = 0.8918245434761047,\n",
  1273. "val_loss = 1.05600640709086, val_acc = 0.5503876209259033\n",
  1274. "\n",
  1275. "epoch: 18\n",
  1276. "epoch 18: train loss = 0.22932734512022937, l1loss = 0.320850499219457, train acc = 0.9082751870155334,\n",
  1277. "val_loss = 1.3345337708791096, val_acc = 0.6085271239280701\n",
  1278. "\n",
  1279. "epoch: 19\n",
  1280. "validation acc increased (0.577519 ---> 0.600775)\n",
  1281. "validation acc increased (0.600775 ---> 0.643411)\n",
  1282. "epoch 19: train loss = 0.19543913260830245, l1loss = 0.28918363494150423, train acc = 0.9252243041992188,\n",
  1283. "val_loss = 1.4304973793887468, val_acc = 0.6589147448539734\n",
  1284. "\n",
  1285. "epoch: 20\n",
  1286. "validation acc increased (0.643411 ---> 0.662791)\n",
  1287. "epoch 20: train loss = 0.12316853676366425, l1loss = 0.27048087601409715, train acc = 0.9526420831680298,\n",
  1288. "val_loss = 1.3577455007290655, val_acc = 0.569767415523529\n",
  1289. "\n",
  1290. "epoch: 21\n",
  1291. "epoch 21: train loss = 0.04757456736990129, l1loss = 0.25799264823451473, train acc = 0.989531397819519,\n",
  1292. "val_loss = 1.6723881137463474, val_acc = 0.565891444683075\n",
  1293. "\n",
  1294. "epoch: 22\n",
  1295. "epoch 22: train loss = 0.02772095015389494, l1loss = 0.24743169342592017, train acc = 0.9950149655342102,\n",
  1296. "val_loss = 1.8008276674564379, val_acc = 0.5930232405662537\n",
  1297. "\n",
  1298. "epoch: 23\n",
  1299. "epoch 23: train loss = 0.03146858667120976, l1loss = 0.2371698537561021, train acc = 0.9940179586410522,\n",
  1300. "val_loss = 1.9554185228234575, val_acc = 0.6124030947685242\n",
  1301. "\n",
  1302. "epoch: 24\n",
  1303. "epoch 24: train loss = 0.10344059998587384, l1loss = 0.23161495476218305, train acc = 0.9561315774917603,\n",
  1304. "val_loss = 1.9649058589495199, val_acc = 0.538759708404541\n",
  1305. "\n",
  1306. "epoch: 25\n",
  1307. "epoch 25: train loss = 0.1441766292168754, l1loss = 0.22718537122694588, train acc = 0.9411764740943909,\n",
  1308. "val_loss = 1.2929859771284946, val_acc = 0.5891472697257996\n",
  1309. "\n",
  1310. "epoch: 26\n",
  1311. "epoch 26: train loss = 0.048959363251924515, l1loss = 0.21805687045171515, train acc = 0.9880359172821045,\n",
  1312. "val_loss = 1.5774617324503817, val_acc = 0.6124030947685242\n",
  1313. "\n",
  1314. "epoch: 27\n",
  1315. "epoch 27: train loss = 0.012611901554244673, l1loss = 0.20935222295559533, train acc = 0.9995014667510986,\n",
  1316. "val_loss = 1.885564662227335, val_acc = 0.6201550364494324\n",
  1317. "\n",
  1318. "epoch: 28\n",
  1319. "epoch 28: train loss = 0.006157807443116384, l1loss = 0.19971955642624128, train acc = 1.0,\n",
  1320. "val_loss = 1.8642971059993314, val_acc = 0.5852712988853455\n",
  1321. "\n",
  1322. "!!! overfitted !!!\n",
  1323. "tensor(51)\n",
  1324. "tensor(19)\n",
  1325. "early stoping results:\n",
  1326. "\t [tensor(0.5833), tensor(0.6354), tensor(0.3646)]\n",
  1327. "\t [tensor(0.5737), tensor(0.4652), tensor(0.5414)]\n",
  1328. "tensor(80)\n",
  1329. "tensor(34)\n",
  1330. "full train results:\n",
  1331. "\t [tensor(0.5312), tensor(0.5365), tensor(0.5938)]\n",
  1332. "\t [tensor(0.9995), tensor(1.), tensor(1.)]\n",
  1333. "tensor(76)\n",
  1334. "tensor(36)\n",
  1335. "best accs results:\n",
  1336. "\t [tensor(0.6042), tensor(0.6354), tensor(0.5833)]\n",
  1337. "\t [tensor(0.6064), tensor(0.4652), tensor(0.6461)]\n",
  1338. "[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1]\n",
  1339. "-----------------------------Fold 4---------------\n",
  1340. "preparing dataloaders...\n",
  1341. "coef when 0 > 1 1\n",
  1342. "creating model...\n",
  1343. "calculating total steps...\n",
  1344. "epoch: 1\n",
  1345. "validation loss decreased (inf ---> 0.699053), val_acc = 0.3255814015865326\n",
  1346. "validation acc increased (0.000000 ---> 0.325581)\n",
  1347. "validation acc increased (0.325581 ---> 0.325581)\n",
  1348. "epoch 1: train loss = 0.6915451809986757, l1loss = 1.482684857348502, train acc = 0.5393818616867065,\n",
  1349. "val_loss = 0.7069418707559275, val_acc = 0.3255814015865326\n",
  1350. "\n",
  1351. "epoch: 2\n",
  1352. "validation acc increased (0.325581 ---> 0.325581)\n",
  1353. "validation acc increased (0.325581 ---> 0.325581)\n",
  1354. "epoch 2: train loss = 0.6905436350602809, l1loss = 1.4623768342694161, train acc = 0.5403788685798645,\n",
  1355. "val_loss = 0.7145971189173617, val_acc = 0.3255814015865326\n",
  1356. "\n",
  1357. "epoch: 3\n",
  1358. "validation acc increased (0.325581 ---> 0.325581)\n",
  1359. "validation acc increased (0.325581 ---> 0.325581)\n",
  1360. "epoch 3: train loss = 0.6898230171631483, l1loss = 1.438654139533951, train acc = 0.5403788685798645,\n",
  1361. "val_loss = 0.7176525223162747, val_acc = 0.3255814015865326\n",
  1362. "\n",
  1363. "epoch: 4\n",
  1364. "validation acc increased (0.325581 ---> 0.325581)\n",
  1365. "validation acc increased (0.325581 ---> 0.325581)\n",
  1366. "epoch 4: train loss = 0.6894210968036595, l1loss = 1.4091600701435731, train acc = 0.5403788685798645,\n",
  1367. "val_loss = 0.7221627840700076, val_acc = 0.3255814015865326\n",
  1368. "\n",
  1369. "epoch: 5\n",
  1370. "validation acc increased (0.325581 ---> 0.325581)\n",
  1371. "validation acc increased (0.325581 ---> 0.325581)\n",
  1372. "epoch 5: train loss = 0.688976728904283, l1loss = 1.3717759866657429, train acc = 0.5403788685798645,\n",
  1373. "val_loss = 0.7216660745384157, val_acc = 0.3255814015865326\n",
  1374. "\n",
  1375. "epoch: 6\n",
  1376. "validation acc increased (0.325581 ---> 0.325581)\n",
  1377. "validation acc increased (0.325581 ---> 0.325581)\n",
  1378. "epoch 6: train loss = 0.6885826580545839, l1loss = 1.324531928254505, train acc = 0.5403788685798645,\n",
  1379. "val_loss = 0.7219718865645949, val_acc = 0.3255814015865326\n",
  1380. "\n",
  1381. "epoch: 7\n",
  1382. "validation acc increased (0.325581 ---> 0.325581)\n",
  1383. "validation acc increased (0.325581 ---> 0.325581)\n",
  1384. "epoch 7: train loss = 0.6881202621569306, l1loss = 1.26606709008678, train acc = 0.5403788685798645,\n",
  1385. "val_loss = 0.7203695473744888, val_acc = 0.3255814015865326\n",
  1386. "\n",
  1387. "epoch: 8\n",
  1388. "validation acc increased (0.325581 ---> 0.325581)\n",
  1389. "validation acc increased (0.325581 ---> 0.325581)\n",
  1390. "epoch 8: train loss = 0.6871786932883448, l1loss = 1.1954490566776614, train acc = 0.5403788685798645,\n",
  1391. "val_loss = 0.7235223895819612, val_acc = 0.3255814015865326\n",
  1392. "\n",
  1393. "epoch: 9\n"
  1394. ]
  1395. },
  1396. {
  1397. "name": "stdout",
  1398. "output_type": "stream",
  1399. "text": [
  1400. "validation acc increased (0.325581 ---> 0.325581)\n",
  1401. "validation acc increased (0.325581 ---> 0.325581)\n",
  1402. "epoch 9: train loss = 0.6859321622168673, l1loss = 1.1123363999046334, train acc = 0.5403788685798645,\n",
  1403. "val_loss = 0.7276398294655851, val_acc = 0.3255814015865326\n",
  1404. "\n",
  1405. "epoch: 10\n",
  1406. "validation acc increased (0.325581 ---> 0.325581)\n",
  1407. "validation acc increased (0.325581 ---> 0.325581)\n",
  1408. "epoch 10: train loss = 0.6838702996849182, l1loss = 1.0172734350887158, train acc = 0.5403788685798645,\n",
  1409. "val_loss = 0.728483495324157, val_acc = 0.3294573724269867\n",
  1410. "\n",
  1411. "epoch: 11\n",
  1412. "validation acc increased (0.325581 ---> 0.329457)\n",
  1413. "validation loss decreased (0.699053 ---> 0.696120), val_acc = 0.4883720874786377\n",
  1414. "validation acc increased (0.329457 ---> 0.488372)\n",
  1415. "epoch 11: train loss = 0.6797237210235709, l1loss = 0.9119024760464015, train acc = 0.5677965879440308,\n",
  1416. "val_loss = 0.6912261903748032, val_acc = 0.5077519416809082\n",
  1417. "\n",
  1418. "epoch: 12\n",
  1419. "validation loss decreased (0.696120 ---> 0.692533), val_acc = 0.5155038833618164\n",
  1420. "validation acc increased (0.488372 ---> 0.515504)\n",
  1421. "epoch 12: train loss = 0.6665826211660238, l1loss = 0.7991669664710969, train acc = 0.6091724634170532,\n",
  1422. "val_loss = 0.7750974052636198, val_acc = 0.3527131676673889\n",
  1423. "\n",
  1424. "epoch: 13\n",
  1425. "epoch 13: train loss = 0.6396277128521015, l1loss = 0.6842512581664567, train acc = 0.6535393595695496,\n",
  1426. "val_loss = 0.637429671694142, val_acc = 0.6472868323326111\n",
  1427. "\n",
  1428. "epoch: 14\n",
  1429. "validation loss decreased (0.692533 ---> 0.639330), val_acc = 0.6666666865348816\n",
  1430. "validation acc increased (0.515504 ---> 0.666667)\n",
  1431. "epoch 14: train loss = 0.5922036307640114, l1loss = 0.5737616670571438, train acc = 0.6919242143630981,\n",
  1432. "val_loss = 0.7422484321187633, val_acc = 0.5465116500854492\n",
  1433. "\n",
  1434. "epoch: 15\n",
  1435. "epoch 15: train loss = 0.4942521263927905, l1loss = 0.4735739782586293, train acc = 0.776669979095459,\n",
  1436. "val_loss = 0.8247309513101282, val_acc = 0.6472868323326111\n",
  1437. "\n",
  1438. "epoch: 16\n",
  1439. "epoch 16: train loss = 0.4356083417402307, l1loss = 0.38790225777645054, train acc = 0.7911266088485718,\n",
  1440. "val_loss = 0.8231048981348673, val_acc = 0.6317829489707947\n",
  1441. "\n",
  1442. "epoch: 17\n",
  1443. "epoch 17: train loss = 0.3488304465563442, l1loss = 0.3178104176181382, train acc = 0.8534396886825562,\n",
  1444. "val_loss = 0.9343741550001987, val_acc = 0.5852712988853455\n",
  1445. "\n",
  1446. "epoch: 18\n",
  1447. "epoch 18: train loss = 0.2602038436670484, l1loss = 0.2649042825934181, train acc = 0.8983050584793091,\n",
  1448. "val_loss = 1.105872887511586, val_acc = 0.5968992114067078\n",
  1449. "\n",
  1450. "epoch: 19\n",
  1451. "epoch 19: train loss = 0.16879668974448533, l1loss = 0.2346440425898475, train acc = 0.9421734809875488,\n",
  1452. "val_loss = 1.0993869045908138, val_acc = 0.5852712988853455\n",
  1453. "\n",
  1454. "epoch: 20\n",
  1455. "epoch 20: train loss = 0.12419395824729029, l1loss = 0.22091067336492262, train acc = 0.9596211314201355,\n",
  1456. "val_loss = 1.4720232468242793, val_acc = 0.6317829489707947\n",
  1457. "\n",
  1458. "epoch: 21\n",
  1459. "epoch 21: train loss = 0.1183394953786792, l1loss = 0.20997104164610356, train acc = 0.9586241245269775,\n",
  1460. "val_loss = 1.61035810514938, val_acc = 0.6124030947685242\n",
  1461. "\n",
  1462. "epoch: 22\n",
  1463. "epoch 22: train loss = 0.048374136068588, l1loss = 0.20068507425390475, train acc = 0.9885343909263611,\n",
  1464. "val_loss = 1.7387195844749213, val_acc = 0.6395348906517029\n",
  1465. "\n",
  1466. "epoch: 23\n",
  1467. "epoch 23: train loss = 0.02681651372362229, l1loss = 0.19265636617497933, train acc = 0.9950149655342102,\n",
  1468. "val_loss = 1.9591693790384042, val_acc = 0.6356589198112488\n",
  1469. "\n",
  1470. "epoch: 24\n",
  1471. "epoch 24: train loss = 0.12268258048851018, l1loss = 0.19024286480808067, train acc = 0.9506480693817139,\n",
  1472. "val_loss = 1.5421607235605403, val_acc = 0.6007751822471619\n",
  1473. "\n",
  1474. "epoch: 25\n",
  1475. "epoch 25: train loss = 0.12608384469511028, l1loss = 0.1905509570422225, train acc = 0.9526420831680298,\n",
  1476. "val_loss = 1.6125494640089513, val_acc = 0.5775193572044373\n",
  1477. "\n",
  1478. "epoch: 26\n",
  1479. "epoch 26: train loss = 0.04950928593338546, l1loss = 0.1808391643366809, train acc = 0.9845463633537292,\n",
  1480. "val_loss = 1.5583655085674553, val_acc = 0.5968992114067078\n",
  1481. "\n",
  1482. "epoch: 27\n",
  1483. "epoch 27: train loss = 0.023618471229324492, l1loss = 0.1748058880611526, train acc = 0.9955134391784668,\n",
  1484. "val_loss = 1.6373267991598262, val_acc = 0.6317829489707947\n",
  1485. "\n",
  1486. "epoch: 28\n",
  1487. "epoch 28: train loss = 0.009286384627497861, l1loss = 0.16702499122526923, train acc = 0.9995014667510986,\n",
  1488. "val_loss = 1.7560179769530777, val_acc = 0.5736433863639832\n",
  1489. "\n",
  1490. "epoch: 29\n",
  1491. "epoch 29: train loss = 0.004661366911762852, l1loss = 0.1595693781808509, train acc = 1.0,\n",
  1492. "val_loss = 1.8593065498411194, val_acc = 0.5736433863639832\n",
  1493. "\n",
  1494. "!!! overfitted !!!\n",
  1495. "tensor(77)\n",
  1496. "tensor(40)\n",
  1497. "early stoping results:\n",
  1498. "\t [tensor(0.5833), tensor(0.6354), tensor(0.3646), tensor(0.6126)]\n",
  1499. "\t [tensor(0.5737), tensor(0.4652), tensor(0.5414), tensor(0.5115)]\n",
  1500. "tensor(67)\n",
  1501. "tensor(39)\n",
  1502. "full train results:\n",
  1503. "\t [tensor(0.5312), tensor(0.5365), tensor(0.5938), tensor(0.5550)]\n",
  1504. "\t [tensor(0.9995), tensor(1.), tensor(1.), tensor(1.)]\n",
  1505. "tensor(72)\n",
  1506. "tensor(45)\n",
  1507. "best accs results:\n",
  1508. "\t [tensor(0.6042), tensor(0.6354), tensor(0.5833), tensor(0.6126)]\n",
  1509. "\t [tensor(0.6064), tensor(0.4652), tensor(0.6461), tensor(0.5115)]\n",
  1510. "[0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1]\n",
  1511. "-----------------------------Fold 5---------------\n",
  1512. "preparing dataloaders...\n",
  1513. "coef when 0 > 1 1\n",
  1514. "creating model...\n",
  1515. "calculating total steps...\n",
  1516. "epoch: 1\n",
  1517. "validation loss decreased (inf ---> 0.680743), val_acc = 0.5930232405662537\n",
  1518. "validation acc increased (0.000000 ---> 0.593023)\n",
  1519. "validation acc increased (0.593023 ---> 0.593023)\n",
  1520. "epoch 1: train loss = 0.6970352851473715, l1loss = 1.4848045129920133, train acc = 0.4750629663467407,\n",
  1521. "val_loss = 0.6833919952082079, val_acc = 0.5930232405662537\n",
  1522. "\n",
  1523. "epoch: 2\n",
  1524. "validation acc increased (0.593023 ---> 0.593023)\n",
  1525. "validation acc increased (0.593023 ---> 0.593023)\n",
  1526. "epoch 2: train loss = 0.6939090573517441, l1loss = 1.4651734015202942, train acc = 0.49068009853363037,\n",
  1527. "val_loss = 0.6874888668688692, val_acc = 0.5930232405662537\n",
  1528. "\n",
  1529. "epoch: 3\n",
  1530. "validation acc increased (0.593023 ---> 0.593023)\n",
  1531. "epoch 3: train loss = 0.6920081238902787, l1loss = 1.4423043520084255, train acc = 0.5219143629074097,\n",
  1532. "val_loss = 0.6943495569303054, val_acc = 0.46124032139778137\n",
  1533. "\n",
  1534. "epoch: 4\n",
  1535. "epoch 4: train loss = 0.6908973218511875, l1loss = 1.4135601722323323, train acc = 0.5249370336532593,\n",
  1536. "val_loss = 0.7015382374903952, val_acc = 0.40697672963142395\n",
  1537. "\n",
  1538. "epoch: 5\n",
  1539. "epoch 5: train loss = 0.6905066842095978, l1loss = 1.37683726215843, train acc = 0.5249370336532593,\n",
  1540. "val_loss = 0.7027538501939108, val_acc = 0.40697672963142395\n",
  1541. "\n",
  1542. "epoch: 6\n",
  1543. "epoch 6: train loss = 0.689899053081157, l1loss = 1.33054918318011, train acc = 0.5254408121109009,\n",
  1544. "val_loss = 0.7039248813954435, val_acc = 0.40697672963142395\n",
  1545. "\n",
  1546. "epoch: 7\n",
  1547. "epoch 7: train loss = 0.6889980239291635, l1loss = 1.2732717847343655, train acc = 0.5254408121109009,\n",
  1548. "val_loss = 0.7034929233003956, val_acc = 0.40697672963142395\n",
  1549. "\n",
  1550. "epoch: 8\n",
  1551. "epoch 8: train loss = 0.6880159931158839, l1loss = 1.2041282832172115, train acc = 0.5471032857894897,\n",
  1552. "val_loss = 0.7000225662261017, val_acc = 0.43410852551460266\n",
  1553. "\n",
  1554. "epoch: 9\n",
  1555. "epoch 9: train loss = 0.6862827549953605, l1loss = 1.1230295440712263, train acc = 0.5496221780776978,\n",
  1556. "val_loss = 0.7079056478285974, val_acc = 0.41860464215278625\n",
  1557. "\n",
  1558. "epoch: 10\n",
  1559. "epoch 10: train loss = 0.6821412534197272, l1loss = 1.0302875974916992, train acc = 0.575818657875061,\n",
  1560. "val_loss = 0.7016733090082804, val_acc = 0.4728682041168213\n",
  1561. "\n",
  1562. "epoch: 11\n",
  1563. "epoch 11: train loss = 0.6741724207058962, l1loss = 0.9277493146564858, train acc = 0.6120907068252563,\n",
  1564. "val_loss = 0.7347481019737184, val_acc = 0.43798449635505676\n",
  1565. "\n",
  1566. "epoch: 12\n",
  1567. "epoch 12: train loss = 0.6619942795419573, l1loss = 0.8190199558020239, train acc = 0.6171284914016724,\n",
  1568. "val_loss = 0.7454949183057445, val_acc = 0.4728682041168213\n",
  1569. "\n",
  1570. "epoch: 13\n"
  1571. ]
  1572. },
  1573. {
  1574. "name": "stdout",
  1575. "output_type": "stream",
  1576. "text": [
  1577. "epoch 13: train loss = 0.637912023097502, l1loss = 0.7075121821023955, train acc = 0.6392946839332581,\n",
  1578. "val_loss = 0.7158694054729254, val_acc = 0.4922480583190918\n",
  1579. "\n",
  1580. "epoch: 14\n",
  1581. "epoch 14: train loss = 0.606015481966869, l1loss = 0.5979608490425034, train acc = 0.6806045174598694,\n",
  1582. "val_loss = 0.7537827140601107, val_acc = 0.5852712988853455\n",
  1583. "\n",
  1584. "epoch: 15\n",
  1585. "epoch 15: train loss = 0.5577827427189056, l1loss = 0.49493083063541193, train acc = 0.7193954586982727,\n",
  1586. "val_loss = 0.7263471793758777, val_acc = 0.565891444683075\n",
  1587. "\n",
  1588. "epoch: 16\n",
  1589. "epoch 16: train loss = 0.4718546643965791, l1loss = 0.40398407203124215, train acc = 0.7874055504798889,\n",
  1590. "val_loss = 1.424316991669263, val_acc = 0.41860464215278625\n",
  1591. "\n",
  1592. "epoch: 17\n",
  1593. "epoch 17: train loss = 0.3839408636843828, l1loss = 0.3292035478038211, train acc = 0.8372796177864075,\n",
  1594. "val_loss = 1.0002119476481002, val_acc = 0.5193798542022705\n",
  1595. "\n",
  1596. "epoch: 18\n",
  1597. "epoch 18: train loss = 0.3173008106217276, l1loss = 0.27490125777439145, train acc = 0.8755667209625244,\n",
  1598. "val_loss = 1.1687451869018317, val_acc = 0.5775193572044373\n",
  1599. "\n",
  1600. "epoch: 19\n",
  1601. "epoch 19: train loss = 0.16879188284765864, l1loss = 0.2434915311315498, train acc = 0.9465994834899902,\n",
  1602. "val_loss = 1.2804822773896447, val_acc = 0.5930232405662537\n",
  1603. "\n",
  1604. "epoch: 20\n",
  1605. "epoch 20: train loss = 0.13567791892689482, l1loss = 0.2295223989249477, train acc = 0.9491183757781982,\n",
  1606. "val_loss = 1.385039623393569, val_acc = 0.5620155334472656\n",
  1607. "\n",
  1608. "epoch: 21\n",
  1609. "validation acc increased (0.593023 ---> 0.624031)\n",
  1610. "epoch 21: train loss = 0.13815371947789973, l1loss = 0.22059971798277023, train acc = 0.9465994834899902,\n",
  1611. "val_loss = 2.150006371875142, val_acc = 0.5968992114067078\n",
  1612. "\n",
  1613. "epoch: 22\n",
  1614. "epoch 22: train loss = 0.10025845674858885, l1loss = 0.21158002366497175, train acc = 0.9622166156768799,\n",
  1615. "val_loss = 1.5413802104402883, val_acc = 0.5426356792449951\n",
  1616. "\n",
  1617. "epoch: 23\n",
  1618. "epoch 23: train loss = 0.030761046739534406, l1loss = 0.2031533922371396, train acc = 0.9939546585083008,\n",
  1619. "val_loss = 1.933851578438929, val_acc = 0.5891472697257996\n",
  1620. "\n",
  1621. "epoch: 24\n",
  1622. "epoch 24: train loss = 0.012612140763748052, l1loss = 0.19433939178434367, train acc = 0.9994962215423584,\n",
  1623. "val_loss = 1.86677174789961, val_acc = 0.5852712988853455\n",
  1624. "\n",
  1625. "epoch: 25\n",
  1626. "epoch 25: train loss = 0.007508058874616845, l1loss = 0.1846848609315358, train acc = 1.0,\n",
  1627. "val_loss = 1.7752613688624181, val_acc = 0.5930232405662537\n",
  1628. "\n",
  1629. "!!! overfitted !!!\n",
  1630. "tensor(80)\n",
  1631. "tensor(41)\n",
  1632. "early stoping results:\n",
  1633. "\t [tensor(0.5833), tensor(0.6354), tensor(0.3646), tensor(0.6126), tensor(0.6335)]\n",
  1634. "\t [tensor(0.5737), tensor(0.4652), tensor(0.5414), tensor(0.5115), tensor(0.4751)]\n",
  1635. "tensor(64)\n",
  1636. "tensor(36)\n",
  1637. "full train results:\n",
  1638. "\t [tensor(0.5312), tensor(0.5365), tensor(0.5938), tensor(0.5550), tensor(0.5236)]\n",
  1639. "\t [tensor(0.9995), tensor(1.), tensor(1.), tensor(1.), tensor(1.)]\n",
  1640. "tensor(75)\n",
  1641. "tensor(32)\n",
  1642. "best accs results:\n",
  1643. "\t [tensor(0.6042), tensor(0.6354), tensor(0.5833), tensor(0.6126), tensor(0.5602)]\n",
  1644. "\t [tensor(0.6064), tensor(0.4652), tensor(0.6461), tensor(0.5115), tensor(0.9521)]\n",
  1645. "[0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1]\n",
  1646. "-----------------------------Fold 6---------------\n",
  1647. "preparing dataloaders...\n",
  1648. "coef when 0 > 1 1\n",
  1649. "creating model...\n",
  1650. "calculating total steps...\n",
  1651. "epoch: 1\n",
  1652. "validation loss decreased (inf ---> 0.720092), val_acc = 0.3720930218696594\n",
  1653. "validation acc increased (0.000000 ---> 0.372093)\n",
  1654. "validation loss decreased (0.720092 ---> 0.713030), val_acc = 0.3720930218696594\n",
  1655. "validation acc increased (0.372093 ---> 0.372093)\n",
  1656. "epoch 1: train loss = 0.6914953364651564, l1loss = 1.4853989268497096, train acc = 0.5315948128700256,\n",
  1657. "val_loss = 0.7110218992528989, val_acc = 0.3720930218696594\n",
  1658. "\n",
  1659. "epoch: 2\n",
  1660. "validation loss decreased (0.713030 ---> 0.710726), val_acc = 0.3720930218696594\n",
  1661. "validation acc increased (0.372093 ---> 0.372093)\n",
  1662. "validation loss decreased (0.710726 ---> 0.709240), val_acc = 0.3720930218696594\n",
  1663. "validation acc increased (0.372093 ---> 0.372093)\n",
  1664. "epoch 2: train loss = 0.6908083540745221, l1loss = 1.4647457792144363, train acc = 0.5315948128700256,\n",
  1665. "val_loss = 0.7090809151183727, val_acc = 0.3720930218696594\n",
  1666. "\n",
  1667. "epoch: 3\n",
  1668. "validation loss decreased (0.709240 ---> 0.709072), val_acc = 0.3720930218696594\n",
  1669. "validation acc increased (0.372093 ---> 0.372093)\n",
  1670. "validation loss decreased (0.709072 ---> 0.709056), val_acc = 0.3720930218696594\n",
  1671. "validation acc increased (0.372093 ---> 0.372093)\n",
  1672. "epoch 3: train loss = 0.6904735120990451, l1loss = 1.4407883356423412, train acc = 0.5315948128700256,\n",
  1673. "val_loss = 0.7095573918763981, val_acc = 0.3720930218696594\n",
  1674. "\n",
  1675. "epoch: 4\n",
  1676. "validation acc increased (0.372093 ---> 0.372093)\n",
  1677. "validation acc increased (0.372093 ---> 0.372093)\n",
  1678. "epoch 4: train loss = 0.6901337791947927, l1loss = 1.4112883624485288, train acc = 0.5315948128700256,\n",
  1679. "val_loss = 0.7101693338201952, val_acc = 0.3720930218696594\n",
  1680. "\n",
  1681. "epoch: 5\n",
  1682. "validation acc increased (0.372093 ---> 0.372093)\n",
  1683. "validation acc increased (0.372093 ---> 0.372093)\n",
  1684. "epoch 5: train loss = 0.6895435992310733, l1loss = 1.3740417406097458, train acc = 0.5315948128700256,\n",
  1685. "val_loss = 0.7115494068278823, val_acc = 0.3720930218696594\n",
  1686. "\n",
  1687. "epoch: 6\n",
  1688. "validation acc increased (0.372093 ---> 0.372093)\n",
  1689. "validation acc increased (0.372093 ---> 0.372093)\n",
  1690. "epoch 6: train loss = 0.6889353504510916, l1loss = 1.3271373138504259, train acc = 0.5315948128700256,\n",
  1691. "val_loss = 0.7145411603210509, val_acc = 0.3720930218696594\n",
  1692. "\n",
  1693. "epoch: 7\n",
  1694. "validation acc increased (0.372093 ---> 0.372093)\n",
  1695. "validation acc increased (0.372093 ---> 0.372093)\n",
  1696. "epoch 7: train loss = 0.6881078454771395, l1loss = 1.2690683422739073, train acc = 0.5315948128700256,\n",
  1697. "val_loss = 0.7155752542407013, val_acc = 0.3720930218696594\n",
  1698. "\n",
  1699. "epoch: 8\n",
  1700. "validation acc increased (0.372093 ---> 0.372093)\n",
  1701. "validation loss decreased (0.709056 ---> 0.707930), val_acc = 0.3682170510292053\n",
  1702. "epoch 8: train loss = 0.6867945939152984, l1loss = 1.198984569627997, train acc = 0.5461384057998657,\n",
  1703. "val_loss = 0.7080177447592565, val_acc = 0.3720930218696594\n",
  1704. "\n",
  1705. "epoch: 9\n",
  1706. "validation acc increased (0.372093 ---> 0.372093)\n",
  1707. "epoch 9: train loss = 0.6847166879963851, l1loss = 1.1167998019766545, train acc = 0.5511534810066223,\n",
  1708. "val_loss = 0.715700043726337, val_acc = 0.41472867131233215\n",
  1709. "\n",
  1710. "epoch: 10\n",
  1711. "validation acc increased (0.372093 ---> 0.414729)\n",
  1712. "epoch 10: train loss = 0.6799348440906827, l1loss = 1.0231370177407682, train acc = 0.5797392129898071,\n",
  1713. "val_loss = 0.7157923727072486, val_acc = 0.41860464215278625\n",
  1714. "\n",
  1715. "epoch: 11\n",
  1716. "validation acc increased (0.414729 ---> 0.449612)\n",
  1717. "validation loss decreased (0.707930 ---> 0.696232), val_acc = 0.5116279125213623\n",
  1718. "validation acc increased (0.449612 ---> 0.511628)\n",
  1719. "epoch 11: train loss = 0.672254242533547, l1loss = 0.9196460943523357, train acc = 0.6078234910964966,\n",
  1720. "val_loss = 0.7157423847405485, val_acc = 0.45348837971687317\n",
  1721. "\n",
  1722. "epoch: 12\n",
  1723. "validation acc increased (0.511628 ---> 0.515504)\n",
  1724. "epoch 12: train loss = 0.6543741083312536, l1loss = 0.8098590250240046, train acc = 0.6424272656440735,\n",
  1725. "val_loss = 0.8028264553971993, val_acc = 0.43410852551460266\n",
  1726. "\n",
  1727. "epoch: 13\n",
  1728. "validation acc increased (0.515504 ---> 0.554264)\n",
  1729. "epoch 13: train loss = 0.6407015342769795, l1loss = 0.6978255600006196, train acc = 0.6399197578430176,\n",
  1730. "val_loss = 0.7011320503183114, val_acc = 0.6085271239280701\n",
  1731. "\n",
  1732. "epoch: 14\n",
  1733. "validation acc increased (0.554264 ---> 0.558140)\n",
  1734. "validation acc increased (0.558140 ---> 0.600775)\n",
  1735. "epoch 14: train loss = 0.5929560238045221, l1loss = 0.5874572954660912, train acc = 0.6965897679328918,\n",
  1736. "val_loss = 0.7719989308091098, val_acc = 0.5\n",
  1737. "\n",
  1738. "epoch: 15\n",
  1739. "validation acc increased (0.600775 ---> 0.616279)\n",
  1740. "epoch 15: train loss = 0.5356296086263513, l1loss = 0.4843397445700233, train acc = 0.7311936020851135,\n",
  1741. "val_loss = 1.2483898142511531, val_acc = 0.3798449635505676\n",
  1742. "\n",
  1743. "epoch: 16\n"
  1744. ]
  1745. },
  1746. {
  1747. "name": "stdout",
  1748. "output_type": "stream",
  1749. "text": [
  1750. "epoch 16: train loss = 0.4664823236831332, l1loss = 0.39368396988484183, train acc = 0.7918756008148193,\n",
  1751. "val_loss = 0.9508252638254979, val_acc = 0.5968992114067078\n",
  1752. "\n",
  1753. "epoch: 17\n",
  1754. "epoch 17: train loss = 0.3485482805889612, l1loss = 0.32071299114102947, train acc = 0.8515546917915344,\n",
  1755. "val_loss = 1.4121340199034342, val_acc = 0.43023255467414856\n",
  1756. "\n",
  1757. "epoch: 18\n",
  1758. "epoch 18: train loss = 0.2641293948428442, l1loss = 0.26571550107408787, train acc = 0.8966900706291199,\n",
  1759. "val_loss = 2.8454202820164287, val_acc = 0.356589138507843\n",
  1760. "\n",
  1761. "epoch: 19\n",
  1762. "epoch 19: train loss = 0.2165779782872023, l1loss = 0.23378948680496026, train acc = 0.9132397174835205,\n",
  1763. "val_loss = 1.407352695169375, val_acc = 0.5193798542022705\n",
  1764. "\n",
  1765. "epoch: 20\n",
  1766. "epoch 20: train loss = 0.14256087192890038, l1loss = 0.21813757950766993, train acc = 0.9538615942001343,\n",
  1767. "val_loss = 1.7007370909979178, val_acc = 0.4883720874786377\n",
  1768. "\n",
  1769. "epoch: 21\n",
  1770. "epoch 21: train loss = 0.0572467178868554, l1loss = 0.20835738851050795, train acc = 0.9869608879089355,\n",
  1771. "val_loss = 1.6843582021086947, val_acc = 0.5581395626068115\n",
  1772. "\n",
  1773. "epoch: 22\n",
  1774. "epoch 22: train loss = 0.03993170134117874, l1loss = 0.20134628987838415, train acc = 0.9904714226722717,\n",
  1775. "val_loss = 1.7536737527770474, val_acc = 0.565891444683075\n",
  1776. "\n",
  1777. "epoch: 23\n",
  1778. "epoch 23: train loss = 0.037888755216641555, l1loss = 0.19423958542593264, train acc = 0.9899699091911316,\n",
  1779. "val_loss = 1.9527307499286741, val_acc = 0.5581395626068115\n",
  1780. "\n",
  1781. "epoch: 24\n",
  1782. "epoch 24: train loss = 0.05171581631976599, l1loss = 0.18909807017478444, train acc = 0.9849548935890198,\n",
  1783. "val_loss = 2.5161986521972244, val_acc = 0.5852712988853455\n",
  1784. "\n",
  1785. "epoch: 25\n",
  1786. "epoch 25: train loss = 0.07914294568875967, l1loss = 0.18467072341422977, train acc = 0.9719157218933105,\n",
  1787. "val_loss = 3.2461071727190944, val_acc = 0.45348837971687317\n",
  1788. "\n",
  1789. "epoch: 26\n",
  1790. "epoch 26: train loss = 0.057620100116538425, l1loss = 0.18046283077751263, train acc = 0.9789367914199829,\n",
  1791. "val_loss = 2.2363041315891947, val_acc = 0.6279069781303406\n",
  1792. "\n",
  1793. "epoch: 27\n",
  1794. "validation acc increased (0.616279 ---> 0.624031)\n",
  1795. "epoch 27: train loss = 0.035453362097083746, l1loss = 0.17314500579145273, train acc = 0.9899699091911316,\n",
  1796. "val_loss = 2.3932818331459695, val_acc = 0.5116279125213623\n",
  1797. "\n",
  1798. "epoch: 28\n",
  1799. "epoch 28: train loss = 0.017172363360665566, l1loss = 0.1665714439621182, train acc = 0.9964894652366638,\n",
  1800. "val_loss = 2.209194153778313, val_acc = 0.5736433863639832\n",
  1801. "\n",
  1802. "epoch: 29\n",
  1803. "epoch 29: train loss = 0.010634527754677514, l1loss = 0.16049564522929272, train acc = 0.9989969730377197,\n",
  1804. "val_loss = 2.0558686441229295, val_acc = 0.5116279125213623\n",
  1805. "\n",
  1806. "epoch: 30\n",
  1807. "epoch 30: train loss = 0.005607646728614523, l1loss = 0.15422261991316719, train acc = 1.0,\n",
  1808. "val_loss = 2.0884801685348036, val_acc = 0.5581395626068115\n",
  1809. "\n",
  1810. "!!! overfitted !!!\n",
  1811. "tensor(65)\n",
  1812. "tensor(33)\n",
  1813. "early stoping results:\n",
  1814. "\t [tensor(0.5833), tensor(0.6354), tensor(0.3646), tensor(0.6126), tensor(0.6335), tensor(0.5131)]\n",
  1815. "\t [tensor(0.5737), tensor(0.4652), tensor(0.5414), tensor(0.5115), tensor(0.4751), tensor(0.5817)]\n",
  1816. "tensor(73)\n",
  1817. "tensor(35)\n",
  1818. "full train results:\n",
  1819. "\t [tensor(0.5312), tensor(0.5365), tensor(0.5938), tensor(0.5550), tensor(0.5236), tensor(0.5654)]\n",
  1820. "\t [tensor(0.9995), tensor(1.), tensor(1.), tensor(1.), tensor(1.), tensor(1.)]\n",
  1821. "tensor(77)\n",
  1822. "tensor(43)\n",
  1823. "best accs results:\n",
  1824. "\t [tensor(0.6042), tensor(0.6354), tensor(0.5833), tensor(0.6126), tensor(0.5602), tensor(0.6283)]\n",
  1825. "\t [tensor(0.6064), tensor(0.4652), tensor(0.6461), tensor(0.5115), tensor(0.9521), tensor(0.7914)]\n",
  1826. "[0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0]\n",
  1827. "-----------------------------Fold 7---------------\n",
  1828. "preparing dataloaders...\n",
  1829. "coef when 0 > 1 1\n",
  1830. "creating model...\n",
  1831. "calculating total steps...\n",
  1832. "epoch: 1\n",
  1833. "validation loss decreased (inf ---> 0.680365), val_acc = 0.643410861492157\n",
  1834. "validation acc increased (0.000000 ---> 0.643411)\n",
  1835. "validation acc increased (0.643411 ---> 0.643411)\n",
  1836. "epoch 1: train loss = 0.6937421182204987, l1loss = 1.4836181387297804, train acc = 0.48824411630630493,\n",
  1837. "val_loss = 0.6849656885908556, val_acc = 0.643410861492157\n",
  1838. "\n",
  1839. "epoch: 2\n",
  1840. "validation acc increased (0.643411 ---> 0.643411)\n",
  1841. "epoch 2: train loss = 0.6914882965538727, l1loss = 1.46388185632533, train acc = 0.5357679128646851,\n",
  1842. "val_loss = 0.6942343249801517, val_acc = 0.44573643803596497\n",
  1843. "\n",
  1844. "epoch: 3\n",
  1845. "epoch 3: train loss = 0.6897589440581917, l1loss = 1.4407587112218754, train acc = 0.5352676510810852,\n",
  1846. "val_loss = 0.7051206457522489, val_acc = 0.356589138507843\n",
  1847. "\n",
  1848. "epoch: 4\n",
  1849. "epoch 4: train loss = 0.689329623609498, l1loss = 1.4119728003340164, train acc = 0.5352676510810852,\n",
  1850. "val_loss = 0.7143772563268972, val_acc = 0.356589138507843\n",
  1851. "\n",
  1852. "epoch: 5\n",
  1853. "epoch 5: train loss = 0.6887832490905277, l1loss = 1.375555465255993, train acc = 0.5352676510810852,\n",
  1854. "val_loss = 0.7188024733417718, val_acc = 0.356589138507843\n",
  1855. "\n",
  1856. "epoch: 6\n",
  1857. "epoch 6: train loss = 0.6878896750409106, l1loss = 1.3297643308463007, train acc = 0.5352676510810852,\n",
  1858. "val_loss = 0.7172646388527035, val_acc = 0.356589138507843\n",
  1859. "\n",
  1860. "epoch: 7\n",
  1861. "epoch 7: train loss = 0.6867665288864583, l1loss = 1.273135135625827, train acc = 0.5357679128646851,\n",
  1862. "val_loss = 0.7201978106831395, val_acc = 0.356589138507843\n",
  1863. "\n",
  1864. "epoch: 8\n",
  1865. "epoch 8: train loss = 0.685014221297317, l1loss = 1.2046409373047233, train acc = 0.5402701497077942,\n",
  1866. "val_loss = 0.7213430737340173, val_acc = 0.356589138507843\n",
  1867. "\n",
  1868. "epoch: 9\n",
  1869. "epoch 9: train loss = 0.681378725679473, l1loss = 1.1243868157409203, train acc = 0.5477738976478577,\n",
  1870. "val_loss = 0.73843981591306, val_acc = 0.356589138507843\n",
  1871. "\n",
  1872. "epoch: 10\n",
  1873. "epoch 10: train loss = 0.6753599447808067, l1loss = 1.033160796071244, train acc = 0.5892946720123291,\n",
  1874. "val_loss = 0.7419206292130226, val_acc = 0.3682170510292053\n",
  1875. "\n",
  1876. "epoch: 11\n",
  1877. "epoch 11: train loss = 0.6624713613129426, l1loss = 0.9328325188654909, train acc = 0.614307165145874,\n",
  1878. "val_loss = 0.688416297583617, val_acc = 0.5348837375640869\n",
  1879. "\n",
  1880. "epoch: 12\n",
  1881. "epoch 12: train loss = 0.6369260627070089, l1loss = 0.8273587237005534, train acc = 0.6518259048461914,\n",
  1882. "val_loss = 0.7440652865772099, val_acc = 0.4689922332763672\n",
  1883. "\n",
  1884. "epoch: 13\n",
  1885. "epoch 13: train loss = 0.593051583365001, l1loss = 0.7211097616622184, train acc = 0.6943472027778625,\n",
  1886. "val_loss = 0.7471253086430157, val_acc = 0.6356589198112488\n",
  1887. "\n",
  1888. "epoch: 14\n",
  1889. "epoch 14: train loss = 0.5171202323387837, l1loss = 0.619771498659362, train acc = 0.751375675201416,\n",
  1890. "val_loss = 0.8451099765393161, val_acc = 0.5930232405662537\n",
  1891. "\n",
  1892. "epoch: 15\n",
  1893. "epoch 15: train loss = 0.4243705461715805, l1loss = 0.5292422390538731, train acc = 0.8104051947593689,\n",
  1894. "val_loss = 0.9671652009782865, val_acc = 0.5348837375640869\n",
  1895. "\n",
  1896. "epoch: 16\n",
  1897. "epoch 16: train loss = 0.3493039304164125, l1loss = 0.45268860447400805, train acc = 0.8559279441833496,\n",
  1898. "val_loss = 1.1312974515811418, val_acc = 0.43798449635505676\n",
  1899. "\n",
  1900. "epoch: 17\n",
  1901. "epoch 17: train loss = 0.2992922716763331, l1loss = 0.3898296642148417, train acc = 0.8874437212944031,\n",
  1902. "val_loss = 1.2580841289934261, val_acc = 0.5775193572044373\n",
  1903. "\n",
  1904. "epoch: 18\n",
  1905. "epoch 18: train loss = 0.17743351521821188, l1loss = 0.3432618670310898, train acc = 0.9384692311286926,\n",
  1906. "val_loss = 1.4614967828573182, val_acc = 0.45348837971687317\n",
  1907. "\n",
  1908. "epoch: 19\n",
  1909. "epoch 19: train loss = 0.12819814724123674, l1loss = 0.3091015477309291, train acc = 0.9579789638519287,\n",
  1910. "val_loss = 2.0725105788356575, val_acc = 0.6395348906517029\n",
  1911. "\n",
  1912. "epoch: 20\n",
  1913. "epoch 20: train loss = 0.14354813504794528, l1loss = 0.2900429670067654, train acc = 0.9454727172851562,\n",
  1914. "val_loss = 2.0894346117049225, val_acc = 0.43798449635505676\n",
  1915. "\n",
  1916. "epoch: 21\n",
  1917. "epoch 21: train loss = 0.09740429143400386, l1loss = 0.276254555369807, train acc = 0.9649825096130371,\n",
  1918. "val_loss = 1.6611481997393822, val_acc = 0.5155038833618164\n",
  1919. "\n",
  1920. "epoch: 22\n"
  1921. ]
  1922. },
  1923. {
  1924. "name": "stdout",
  1925. "output_type": "stream",
  1926. "text": [
  1927. "epoch 22: train loss = 0.03162910269686077, l1loss = 0.2640630051963266, train acc = 0.9929965138435364,\n",
  1928. "val_loss = 2.8269113385400106, val_acc = 0.41472867131233215\n",
  1929. "\n",
  1930. "epoch: 23\n",
  1931. "epoch 23: train loss = 0.02073831153961794, l1loss = 0.25253280331606626, train acc = 0.9969984889030457,\n",
  1932. "val_loss = 1.9658752079157866, val_acc = 0.604651153087616\n",
  1933. "\n",
  1934. "epoch: 24\n",
  1935. "epoch 24: train loss = 0.012647862177856524, l1loss = 0.24088280176508123, train acc = 0.9984992742538452,\n",
  1936. "val_loss = 2.1244828877320816, val_acc = 0.5348837375640869\n",
  1937. "\n",
  1938. "epoch: 25\n",
  1939. "epoch 25: train loss = 0.006505623935028098, l1loss = 0.2303655490748819, train acc = 1.0,\n",
  1940. "val_loss = 2.1552340097205582, val_acc = 0.4961240291595459\n",
  1941. "\n",
  1942. "!!! overfitted !!!\n",
  1943. "tensor(80)\n",
  1944. "tensor(42)\n",
  1945. "early stoping results:\n",
  1946. "\t [tensor(0.5833), tensor(0.6354), tensor(0.3646), tensor(0.6126), tensor(0.6335), tensor(0.5131), tensor(0.6387)]\n",
  1947. "\t [tensor(0.5737), tensor(0.4652), tensor(0.5414), tensor(0.5115), tensor(0.4751), tensor(0.5817), tensor(0.4647)]\n",
  1948. "tensor(79)\n",
  1949. "tensor(36)\n",
  1950. "full train results:\n",
  1951. "\t [tensor(0.5312), tensor(0.5365), tensor(0.5938), tensor(0.5550), tensor(0.5236), tensor(0.5654), tensor(0.6021)]\n",
  1952. "\t [tensor(0.9995), tensor(1.), tensor(1.), tensor(1.), tensor(1.), tensor(1.), tensor(0.9990)]\n",
  1953. "tensor(81)\n",
  1954. "tensor(41)\n",
  1955. "best accs results:\n",
  1956. "\t [tensor(0.6042), tensor(0.6354), tensor(0.5833), tensor(0.6126), tensor(0.5602), tensor(0.6283), tensor(0.6387)]\n",
  1957. "\t [tensor(0.6064), tensor(0.4652), tensor(0.6461), tensor(0.5115), tensor(0.9521), tensor(0.7914), tensor(0.4647)]\n",
  1958. "[1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1]\n",
  1959. "-----------------------------Fold 8---------------\n",
  1960. "preparing dataloaders...\n",
  1961. "coef when 0 > 1 1\n",
  1962. "creating model...\n",
  1963. "calculating total steps...\n",
  1964. "epoch: 1\n",
  1965. "validation loss decreased (inf ---> 0.704280), val_acc = 0.356589138507843\n",
  1966. "validation acc increased (0.000000 ---> 0.356589)\n",
  1967. "validation acc increased (0.356589 ---> 0.356589)\n",
  1968. "epoch 1: train loss = 0.6914031615908471, l1loss = 1.4855487954324338, train acc = 0.5352676510810852,\n",
  1969. "val_loss = 0.7122278712516608, val_acc = 0.356589138507843\n",
  1970. "\n",
  1971. "epoch: 2\n",
  1972. "validation acc increased (0.356589 ---> 0.356589)\n",
  1973. "validation acc increased (0.356589 ---> 0.356589)\n",
  1974. "epoch 2: train loss = 0.6906723660370777, l1loss = 1.4651722050476457, train acc = 0.5352676510810852,\n",
  1975. "val_loss = 0.7149748654328576, val_acc = 0.356589138507843\n",
  1976. "\n",
  1977. "epoch: 3\n",
  1978. "validation acc increased (0.356589 ---> 0.356589)\n",
  1979. "validation acc increased (0.356589 ---> 0.356589)\n",
  1980. "epoch 3: train loss = 0.6902823968193661, l1loss = 1.4415272097995486, train acc = 0.5352676510810852,\n",
  1981. "val_loss = 0.7170244894286458, val_acc = 0.356589138507843\n",
  1982. "\n",
  1983. "epoch: 4\n",
  1984. "validation acc increased (0.356589 ---> 0.356589)\n",
  1985. "validation acc increased (0.356589 ---> 0.356589)\n",
  1986. "epoch 4: train loss = 0.6898218136897619, l1loss = 1.412299975327935, train acc = 0.5352676510810852,\n",
  1987. "val_loss = 0.7177939595178117, val_acc = 0.356589138507843\n",
  1988. "\n",
  1989. "epoch: 5\n",
  1990. "validation acc increased (0.356589 ---> 0.356589)\n",
  1991. "validation acc increased (0.356589 ---> 0.356589)\n",
  1992. "epoch 5: train loss = 0.6893795061672015, l1loss = 1.3753384496045744, train acc = 0.5352676510810852,\n",
  1993. "val_loss = 0.715468819751296, val_acc = 0.356589138507843\n",
  1994. "\n",
  1995. "epoch: 6\n",
  1996. "validation acc increased (0.356589 ---> 0.356589)\n",
  1997. "validation acc increased (0.356589 ---> 0.356589)\n",
  1998. "epoch 6: train loss = 0.6886429281876408, l1loss = 1.3287717812415538, train acc = 0.5352676510810852,\n",
  1999. "val_loss = 0.7151580522226733, val_acc = 0.356589138507843\n",
  2000. "\n",
  2001. "epoch: 7\n",
  2002. "validation acc increased (0.356589 ---> 0.356589)\n",
  2003. "validation acc increased (0.356589 ---> 0.356589)\n",
  2004. "epoch 7: train loss = 0.6878388487976631, l1loss = 1.27108275150883, train acc = 0.5352676510810852,\n",
  2005. "val_loss = 0.7185324430465698, val_acc = 0.356589138507843\n",
  2006. "\n",
  2007. "epoch: 8\n",
  2008. "validation acc increased (0.356589 ---> 0.356589)\n",
  2009. "validation acc increased (0.356589 ---> 0.375969)\n",
  2010. "epoch 8: train loss = 0.687089113010771, l1loss = 1.2014408532591567, train acc = 0.5377689003944397,\n",
  2011. "val_loss = 0.7098801186842512, val_acc = 0.3759689927101135\n",
  2012. "\n",
  2013. "epoch: 9\n",
  2014. "epoch 9: train loss = 0.6857657482052756, l1loss = 1.1198362330903764, train acc = 0.5357679128646851,\n",
  2015. "val_loss = 0.7146235921586207, val_acc = 0.3720930218696594\n",
  2016. "\n",
  2017. "epoch: 10\n",
  2018. "validation acc increased (0.375969 ---> 0.391473)\n",
  2019. "validation loss decreased (0.704280 ---> 0.700316), val_acc = 0.43410852551460266\n",
  2020. "validation acc increased (0.391473 ---> 0.434109)\n",
  2021. "epoch 10: train loss = 0.6807256672846311, l1loss = 1.0265806999547653, train acc = 0.5732866525650024,\n",
  2022. "val_loss = 0.6975175391796024, val_acc = 0.44573643803596497\n",
  2023. "\n",
  2024. "epoch: 11\n",
  2025. "validation loss decreased (0.700316 ---> 0.692875), val_acc = 0.4651162922382355\n",
  2026. "validation acc increased (0.434109 ---> 0.465116)\n",
  2027. "validation loss decreased (0.692875 ---> 0.687179), val_acc = 0.5348837375640869\n",
  2028. "validation acc increased (0.465116 ---> 0.534884)\n",
  2029. "epoch 11: train loss = 0.673976903053568, l1loss = 0.9234890513505979, train acc = 0.5992996692657471,\n",
  2030. "val_loss = 0.7138355277305426, val_acc = 0.43798449635505676\n",
  2031. "\n",
  2032. "epoch: 12\n",
  2033. "validation loss decreased (0.687179 ---> 0.655739), val_acc = 0.6472868323326111\n",
  2034. "validation acc increased (0.534884 ---> 0.647287)\n",
  2035. "epoch 12: train loss = 0.6544435339727779, l1loss = 0.8135775458878312, train acc = 0.6203101277351379,\n",
  2036. "val_loss = 0.8063760263513226, val_acc = 0.38759690523147583\n",
  2037. "\n",
  2038. "epoch: 13\n",
  2039. "epoch 13: train loss = 0.6309017556318347, l1loss = 0.7015269837300738, train acc = 0.639319658279419,\n",
  2040. "val_loss = 0.8180261628572331, val_acc = 0.42635658383369446\n",
  2041. "\n",
  2042. "epoch: 14\n",
  2043. "epoch 14: train loss = 0.6293001123939292, l1loss = 0.5926253258913621, train acc = 0.6403201818466187,\n",
  2044. "val_loss = 0.6947672367095947, val_acc = 0.5581395626068115\n",
  2045. "\n",
  2046. "epoch: 15\n",
  2047. "epoch 15: train loss = 0.5548516084934366, l1loss = 0.4896452505329718, train acc = 0.7323662042617798,\n",
  2048. "val_loss = 0.8717553079590317, val_acc = 0.46124032139778137\n",
  2049. "\n",
  2050. "epoch: 16\n",
  2051. "epoch 16: train loss = 0.4710256685430614, l1loss = 0.3975880754745382, train acc = 0.7793896794319153,\n",
  2052. "val_loss = 0.8850018007810726, val_acc = 0.6317829489707947\n",
  2053. "\n",
  2054. "epoch: 17\n",
  2055. "epoch 17: train loss = 0.40750853099365003, l1loss = 0.3214303405956366, train acc = 0.8174086809158325,\n",
  2056. "val_loss = 0.8342532101412152, val_acc = 0.6007751822471619\n",
  2057. "\n",
  2058. "epoch: 18\n",
  2059. "epoch 18: train loss = 0.33526081428222504, l1loss = 0.2646906571411502, train acc = 0.8569284677505493,\n",
  2060. "val_loss = 1.3167366464008656, val_acc = 0.4844961166381836\n",
  2061. "\n",
  2062. "epoch: 19\n",
  2063. "epoch 19: train loss = 0.24304662869237792, l1loss = 0.23158110821318661, train acc = 0.9064532518386841,\n",
  2064. "val_loss = 1.1697414286376895, val_acc = 0.6472868323326111\n",
  2065. "\n",
  2066. "epoch: 20\n",
  2067. "epoch 20: train loss = 0.20508646148034726, l1loss = 0.21585142629602183, train acc = 0.9144572019577026,\n",
  2068. "val_loss = 1.3775318278822788, val_acc = 0.6007751822471619\n",
  2069. "\n",
  2070. "epoch: 21\n",
  2071. "epoch 21: train loss = 0.12821232205468336, l1loss = 0.2060455378366745, train acc = 0.9564782381057739,\n",
  2072. "val_loss = 1.4499094604059708, val_acc = 0.6162790656089783\n",
  2073. "\n",
  2074. "epoch: 22\n",
  2075. "epoch 22: train loss = 0.05814009320071007, l1loss = 0.19943478927873504, train acc = 0.9834917187690735,\n",
  2076. "val_loss = 1.5952766163404597, val_acc = 0.5581395626068115\n",
  2077. "\n",
  2078. "epoch: 23\n",
  2079. "epoch 23: train loss = 0.043333539148497247, l1loss = 0.19244094265288506, train acc = 0.9879940152168274,\n",
  2080. "val_loss = 2.111579608518717, val_acc = 0.5155038833618164\n",
  2081. "\n",
  2082. "epoch: 24\n",
  2083. "epoch 24: train loss = 0.05830567527288434, l1loss = 0.18697950598029509, train acc = 0.980990469455719,\n",
  2084. "val_loss = 1.6956204277600428, val_acc = 0.5620155334472656\n",
  2085. "\n",
  2086. "epoch: 25\n",
  2087. "epoch 25: train loss = 0.05772003486730147, l1loss = 0.18347380734819838, train acc = 0.9799900054931641,\n",
  2088. "val_loss = 2.1974048300306928, val_acc = 0.4961240291595459\n",
  2089. "\n",
  2090. "epoch: 26\n"
  2091. ]
  2092. },
  2093. {
  2094. "name": "stdout",
  2095. "output_type": "stream",
  2096. "text": [
  2097. "epoch 26: train loss = 0.10959840359062836, l1loss = 0.18189016606969916, train acc = 0.9599800109863281,\n",
  2098. "val_loss = 2.2403581733851468, val_acc = 0.5775193572044373\n",
  2099. "\n",
  2100. "epoch: 27\n",
  2101. "epoch 27: train loss = 0.05633044383279409, l1loss = 0.17777645592393726, train acc = 0.9814907312393188,\n",
  2102. "val_loss = 3.4118495135344276, val_acc = 0.41085270047187805\n",
  2103. "\n",
  2104. "epoch: 28\n",
  2105. "epoch 28: train loss = 0.04293046985380631, l1loss = 0.17149413792624243, train acc = 0.9854927659034729,\n",
  2106. "val_loss = 2.1026175188463787, val_acc = 0.5852712988853455\n",
  2107. "\n",
  2108. "epoch: 29\n",
  2109. "epoch 29: train loss = 0.018007467739869203, l1loss = 0.16545488913098594, train acc = 0.994997501373291,\n",
  2110. "val_loss = 1.9395856857299805, val_acc = 0.569767415523529\n",
  2111. "\n",
  2112. "epoch: 30\n",
  2113. "epoch 30: train loss = 0.012381990377992959, l1loss = 0.15917553348264557, train acc = 0.9984992742538452,\n",
  2114. "val_loss = 2.3071911690337945, val_acc = 0.5116279125213623\n",
  2115. "\n",
  2116. "epoch: 31\n",
  2117. "epoch 31: train loss = 0.00495839840150147, l1loss = 0.15346068114772804, train acc = 0.9994997382164001,\n",
  2118. "val_loss = 2.027354578639186, val_acc = 0.5968992114067078\n",
  2119. "\n",
  2120. "epoch: 32\n",
  2121. "epoch 32: train loss = 0.006408099178271391, l1loss = 0.14823033495537336, train acc = 1.0,\n",
  2122. "val_loss = 1.955235477565795, val_acc = 0.569767415523529\n",
  2123. "\n",
  2124. "!!! overfitted !!!\n",
  2125. "tensor(78)\n",
  2126. "tensor(43)\n",
  2127. "early stoping results:\n",
  2128. "\t [tensor(0.5833), tensor(0.6354), tensor(0.3646), tensor(0.6126), tensor(0.6335), tensor(0.5131), tensor(0.6387), tensor(0.6335)]\n",
  2129. "\t [tensor(0.5737), tensor(0.4652), tensor(0.5414), tensor(0.5115), tensor(0.4751), tensor(0.5817), tensor(0.4647), tensor(0.5123)]\n",
  2130. "tensor(64)\n",
  2131. "tensor(37)\n",
  2132. "full train results:\n",
  2133. "\t [tensor(0.5312), tensor(0.5365), tensor(0.5938), tensor(0.5550), tensor(0.5236), tensor(0.5654), tensor(0.6021), tensor(0.5288)]\n",
  2134. "\t [tensor(0.9995), tensor(1.), tensor(1.), tensor(1.), tensor(1.), tensor(1.), tensor(0.9990), tensor(1.)]\n",
  2135. "tensor(74)\n",
  2136. "tensor(47)\n",
  2137. "best accs results:\n",
  2138. "\t [tensor(0.6042), tensor(0.6354), tensor(0.5833), tensor(0.6126), tensor(0.5602), tensor(0.6283), tensor(0.6387), tensor(0.6335)]\n",
  2139. "\t [tensor(0.6064), tensor(0.4652), tensor(0.6461), tensor(0.5115), tensor(0.9521), tensor(0.7914), tensor(0.4647), tensor(0.5123)]\n",
  2140. "[1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0]\n",
  2141. "-----------------------------Fold 9---------------\n",
  2142. "preparing dataloaders...\n",
  2143. "coef when 0 > 1 1\n",
  2144. "creating model...\n",
  2145. "calculating total steps...\n",
  2146. "epoch: 1\n",
  2147. "validation loss decreased (inf ---> 0.688638), val_acc = 0.6627907156944275\n",
  2148. "validation acc increased (0.000000 ---> 0.662791)\n",
  2149. "epoch 1: train loss = 0.6933125124244157, l1loss = 1.483249993381386, train acc = 0.5079840421676636,\n",
  2150. "val_loss = 0.7031067840812742, val_acc = 0.3372093141078949\n",
  2151. "\n",
  2152. "epoch: 2\n",
  2153. "epoch 2: train loss = 0.6910889809003133, l1loss = 1.4634534901011729, train acc = 0.5394211411476135,\n",
  2154. "val_loss = 0.7151631089143975, val_acc = 0.3372093141078949\n",
  2155. "\n",
  2156. "epoch: 3\n",
  2157. "epoch 3: train loss = 0.690039867650487, l1loss = 1.4402860893222862, train acc = 0.538922131061554,\n",
  2158. "val_loss = 0.7211975704791934, val_acc = 0.3372093141078949\n",
  2159. "\n",
  2160. "epoch: 4\n",
  2161. "epoch 4: train loss = 0.6893733679177518, l1loss = 1.411496847451566, train acc = 0.538922131061554,\n",
  2162. "val_loss = 0.721796124942543, val_acc = 0.3372093141078949\n",
  2163. "\n",
  2164. "epoch: 5\n",
  2165. "epoch 5: train loss = 0.6888950928718506, l1loss = 1.3750012099385975, train acc = 0.538922131061554,\n",
  2166. "val_loss = 0.7237857270610425, val_acc = 0.3372093141078949\n",
  2167. "\n",
  2168. "epoch: 6\n",
  2169. "epoch 6: train loss = 0.688418993335998, l1loss = 1.328940457926539, train acc = 0.538922131061554,\n",
  2170. "val_loss = 0.7278173852336499, val_acc = 0.3372093141078949\n",
  2171. "\n",
  2172. "epoch: 7\n",
  2173. "epoch 7: train loss = 0.6875811150212012, l1loss = 1.271831663068897, train acc = 0.538922131061554,\n",
  2174. "val_loss = 0.7228789542072503, val_acc = 0.3372093141078949\n",
  2175. "\n",
  2176. "epoch: 8\n",
  2177. "epoch 8: train loss = 0.6866033931454261, l1loss = 1.2027254354454087, train acc = 0.538922131061554,\n",
  2178. "val_loss = 0.7285417683364809, val_acc = 0.3372093141078949\n",
  2179. "\n",
  2180. "epoch: 9\n",
  2181. "epoch 9: train loss = 0.6847465848494433, l1loss = 1.1213995715577207, train acc = 0.5434131622314453,\n",
  2182. "val_loss = 0.7351138185161029, val_acc = 0.3372093141078949\n",
  2183. "\n",
  2184. "epoch: 10\n",
  2185. "epoch 10: train loss = 0.6824056646543111, l1loss = 1.0283273393046595, train acc = 0.5508981943130493,\n",
  2186. "val_loss = 0.7407340176345766, val_acc = 0.3333333432674408\n",
  2187. "\n",
  2188. "epoch: 11\n",
  2189. "epoch 11: train loss = 0.6756234464055287, l1loss = 0.9253893102000573, train acc = 0.5743513107299805,\n",
  2190. "val_loss = 0.7472086497979571, val_acc = 0.3449612259864807\n",
  2191. "\n",
  2192. "epoch: 12\n",
  2193. "epoch 12: train loss = 0.6602505656059631, l1loss = 0.8152883985561287, train acc = 0.6162674427032471,\n",
  2194. "val_loss = 0.7113825424697048, val_acc = 0.4922480583190918\n",
  2195. "\n",
  2196. "epoch: 13\n",
  2197. "validation loss decreased (0.688638 ---> 0.667874), val_acc = 0.6550387740135193\n",
  2198. "epoch 13: train loss = 0.6188409713927857, l1loss = 0.704593446916211, train acc = 0.6756486892700195,\n",
  2199. "val_loss = 0.7028621758601462, val_acc = 0.5930232405662537\n",
  2200. "\n",
  2201. "epoch: 14\n",
  2202. "epoch 14: train loss = 0.5674543769773609, l1loss = 0.5986516261529066, train acc = 0.7080838084220886,\n",
  2203. "val_loss = 0.7258701370668041, val_acc = 0.538759708404541\n",
  2204. "\n",
  2205. "epoch: 15\n",
  2206. "epoch 15: train loss = 0.5195131490449467, l1loss = 0.5008506833198304, train acc = 0.7554890513420105,\n",
  2207. "val_loss = 0.7875658929810043, val_acc = 0.5310077667236328\n",
  2208. "\n",
  2209. "epoch: 16\n",
  2210. "epoch 16: train loss = 0.4449116040370659, l1loss = 0.4148543224244299, train acc = 0.8003991842269897,\n",
  2211. "val_loss = 0.9337220635525015, val_acc = 0.5968992114067078\n",
  2212. "\n",
  2213. "epoch: 17\n",
  2214. "epoch 17: train loss = 0.3717906239503872, l1loss = 0.34227201402068375, train acc = 0.8303393125534058,\n",
  2215. "val_loss = 0.9147336353627287, val_acc = 0.5968992114067078\n",
  2216. "\n",
  2217. "epoch: 18\n",
  2218. "epoch 18: train loss = 0.2456365298487231, l1loss = 0.2895531968442266, train acc = 0.9016966223716736,\n",
  2219. "val_loss = 1.224231436271076, val_acc = 0.5193798542022705\n",
  2220. "\n",
  2221. "epoch: 19\n",
  2222. "epoch 19: train loss = 0.2006593743365206, l1loss = 0.25701168244946265, train acc = 0.9211576581001282,\n",
  2223. "val_loss = 1.278371164040972, val_acc = 0.5\n",
  2224. "\n",
  2225. "epoch: 20\n",
  2226. "epoch 20: train loss = 0.2456775011416681, l1loss = 0.2422227506747027, train acc = 0.8952096104621887,\n",
  2227. "val_loss = 1.319511427435764, val_acc = 0.6356589198112488\n",
  2228. "\n",
  2229. "epoch: 21\n",
  2230. "epoch 21: train loss = 0.10676692857475814, l1loss = 0.2299664880701168, train acc = 0.9675648808479309,\n",
  2231. "val_loss = 1.6988744661789532, val_acc = 0.5038759708404541\n",
  2232. "\n",
  2233. "epoch: 22\n",
  2234. "epoch 22: train loss = 0.06530988484085677, l1loss = 0.2210731065558816, train acc = 0.9795409440994263,\n",
  2235. "val_loss = 1.772416866102884, val_acc = 0.4806201457977295\n",
  2236. "\n",
  2237. "epoch: 23\n",
  2238. "epoch 23: train loss = 0.16125833081033178, l1loss = 0.21659656228895435, train acc = 0.9346307516098022,\n",
  2239. "val_loss = 2.660450824471407, val_acc = 0.4767441749572754\n",
  2240. "\n",
  2241. "epoch: 24\n",
  2242. "epoch 24: train loss = 0.052061744673761304, l1loss = 0.2090652662063549, train acc = 0.9890219569206238,\n",
  2243. "val_loss = 1.6618810494740803, val_acc = 0.5736433863639832\n",
  2244. "\n",
  2245. "epoch: 25\n",
  2246. "epoch 25: train loss = 0.025476095487852533, l1loss = 0.1998468009000291, train acc = 0.9960079789161682,\n",
  2247. "val_loss = 1.7238089602455036, val_acc = 0.604651153087616\n",
  2248. "\n",
  2249. "epoch: 26\n",
  2250. "epoch 26: train loss = 0.02256299738220112, l1loss = 0.19226344292630215, train acc = 0.9915169477462769,\n",
  2251. "val_loss = 2.2160921032114547, val_acc = 0.5232558250427246\n",
  2252. "\n",
  2253. "epoch: 27\n",
  2254. "epoch 27: train loss = 0.0615687958464889, l1loss = 0.18798830180230017, train acc = 0.9765468835830688,\n",
  2255. "val_loss = 2.2735826359238733, val_acc = 0.6124030947685242\n",
  2256. "\n",
  2257. "epoch: 28\n",
  2258. "epoch 28: train loss = 0.22541590647664136, l1loss = 0.19491120203288492, train acc = 0.9036926031112671,\n",
  2259. "val_loss = 1.4376536790252656, val_acc = 0.4728682041168213\n",
  2260. "\n",
  2261. "epoch: 29\n"
  2262. ]
  2263. },
  2264. {
  2265. "name": "stdout",
  2266. "output_type": "stream",
  2267. "text": [
  2268. "epoch 29: train loss = 0.0631909598698635, l1loss = 0.1892870161050332, train acc = 0.9780439138412476,\n",
  2269. "val_loss = 2.1181586296983466, val_acc = 0.4728682041168213\n",
  2270. "\n",
  2271. "epoch: 30\n",
  2272. "epoch 30: train loss = 0.01965019029905339, l1loss = 0.17764762243229948, train acc = 0.9970059990882874,\n",
  2273. "val_loss = 1.861120948495791, val_acc = 0.5465116500854492\n",
  2274. "\n",
  2275. "epoch: 31\n",
  2276. "epoch 31: train loss = 0.008150098203103282, l1loss = 0.16982665143446057, train acc = 1.0,\n",
  2277. "val_loss = 1.9675745483516722, val_acc = 0.538759708404541\n",
  2278. "\n",
  2279. "!!! overfitted !!!\n",
  2280. "tensor(80)\n",
  2281. "tensor(42)\n",
  2282. "early stoping results:\n",
  2283. "\t [tensor(0.5833), tensor(0.6354), tensor(0.3646), tensor(0.6126), tensor(0.6335), tensor(0.5131), tensor(0.6387), tensor(0.6335), tensor(0.6387)]\n",
  2284. "\t [tensor(0.5737), tensor(0.4652), tensor(0.5414), tensor(0.5115), tensor(0.4751), tensor(0.5817), tensor(0.4647), tensor(0.5123), tensor(0.4775)]\n",
  2285. "tensor(80)\n",
  2286. "tensor(33)\n",
  2287. "full train results:\n",
  2288. "\t [tensor(0.5312), tensor(0.5365), tensor(0.5938), tensor(0.5550), tensor(0.5236), tensor(0.5654), tensor(0.6021), tensor(0.5288), tensor(0.5916)]\n",
  2289. "\t [tensor(0.9995), tensor(1.), tensor(1.), tensor(1.), tensor(1.), tensor(1.), tensor(0.9990), tensor(1.), tensor(0.9990)]\n",
  2290. "tensor(87)\n",
  2291. "tensor(35)\n",
  2292. "best accs results:\n",
  2293. "\t [tensor(0.6042), tensor(0.6354), tensor(0.5833), tensor(0.6126), tensor(0.5602), tensor(0.6283), tensor(0.6387), tensor(0.6335), tensor(0.6387)]\n",
  2294. "\t [tensor(0.6064), tensor(0.4652), tensor(0.6461), tensor(0.5115), tensor(0.9521), tensor(0.7914), tensor(0.4647), tensor(0.5123), tensor(0.4611)]\n",
  2295. "[1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0]\n",
  2296. "-----------------------------Fold 10---------------\n",
  2297. "preparing dataloaders...\n",
  2298. "coef when 0 > 1 1\n",
  2299. "creating model...\n",
  2300. "calculating total steps...\n",
  2301. "epoch: 1\n",
  2302. "validation loss decreased (inf ---> 0.672692), val_acc = 0.6550387740135193\n",
  2303. "validation acc increased (0.000000 ---> 0.655039)\n",
  2304. "validation acc increased (0.655039 ---> 0.655039)\n",
  2305. "epoch 1: train loss = 0.6991829303475646, l1loss = 1.4839057058959335, train acc = 0.4625374674797058,\n",
  2306. "val_loss = 0.6754856922829798, val_acc = 0.6550387740135193\n",
  2307. "\n",
  2308. "epoch: 2\n",
  2309. "validation acc increased (0.655039 ---> 0.655039)\n",
  2310. "validation acc increased (0.655039 ---> 0.655039)\n",
  2311. "epoch 2: train loss = 0.6939829035357876, l1loss = 1.4645228752723107, train acc = 0.49450549483299255,\n",
  2312. "val_loss = 0.6890089049819828, val_acc = 0.6550387740135193\n",
  2313. "\n",
  2314. "epoch: 3\n",
  2315. "validation acc increased (0.655039 ---> 0.658915)\n",
  2316. "epoch 3: train loss = 0.6903928536754269, l1loss = 1.4417118324742808, train acc = 0.5394605398178101,\n",
  2317. "val_loss = 0.7046850910482481, val_acc = 0.3449612259864807\n",
  2318. "\n",
  2319. "epoch: 4\n",
  2320. "epoch 4: train loss = 0.6891815625466071, l1loss = 1.4131007423172226, train acc = 0.5374625325202942,\n",
  2321. "val_loss = 0.719056154406348, val_acc = 0.3449612259864807\n",
  2322. "\n",
  2323. "epoch: 5\n",
  2324. "epoch 5: train loss = 0.6889630373898562, l1loss = 1.376648348885459, train acc = 0.5374625325202942,\n",
  2325. "val_loss = 0.7238496870033501, val_acc = 0.3449612259864807\n",
  2326. "\n",
  2327. "epoch: 6\n",
  2328. "epoch 6: train loss = 0.6881177404543737, l1loss = 1.3303997759576087, train acc = 0.5374625325202942,\n",
  2329. "val_loss = 0.7200313155041185, val_acc = 0.3449612259864807\n",
  2330. "\n",
  2331. "epoch: 7\n",
  2332. "epoch 7: train loss = 0.6869667621401997, l1loss = 1.2731195258332062, train acc = 0.5374625325202942,\n",
  2333. "val_loss = 0.7144513324249623, val_acc = 0.3449612259864807\n",
  2334. "\n",
  2335. "epoch: 8\n",
  2336. "epoch 8: train loss = 0.6860746289347555, l1loss = 1.2041123083659582, train acc = 0.546953022480011,\n",
  2337. "val_loss = 0.7131708472274071, val_acc = 0.3488371968269348\n",
  2338. "\n",
  2339. "epoch: 9\n",
  2340. "epoch 9: train loss = 0.6840823413608791, l1loss = 1.1235749534555486, train acc = 0.5389610528945923,\n",
  2341. "val_loss = 0.7343198874200038, val_acc = 0.3449612259864807\n",
  2342. "\n",
  2343. "epoch: 10\n",
  2344. "epoch 10: train loss = 0.6786974231441776, l1loss = 1.03154473943072, train acc = 0.580919086933136,\n",
  2345. "val_loss = 0.7059320497882459, val_acc = 0.44961240887641907\n",
  2346. "\n",
  2347. "epoch: 11\n",
  2348. "epoch 11: train loss = 0.6691220754629129, l1loss = 0.9298823592546103, train acc = 0.6143856048583984,\n",
  2349. "val_loss = 0.8592184712720472, val_acc = 0.3449612259864807\n",
  2350. "\n",
  2351. "epoch: 12\n",
  2352. "epoch 12: train loss = 0.6548706972515667, l1loss = 0.8217827181716065, train acc = 0.6238760948181152,\n",
  2353. "val_loss = 0.7483981838522031, val_acc = 0.43798449635505676\n",
  2354. "\n",
  2355. "epoch: 13\n",
  2356. "epoch 13: train loss = 0.6237266789544951, l1loss = 0.7114838892882401, train acc = 0.6543456315994263,\n",
  2357. "val_loss = 0.9002822369568108, val_acc = 0.3759689927101135\n",
  2358. "\n",
  2359. "epoch: 14\n",
  2360. "epoch 14: train loss = 0.5636290751851641, l1loss = 0.6049504718342266, train acc = 0.726773202419281,\n",
  2361. "val_loss = 0.7929284221442171, val_acc = 0.569767415523529\n",
  2362. "\n",
  2363. "epoch: 15\n",
  2364. "epoch 15: train loss = 0.500030604811696, l1loss = 0.5065306650472807, train acc = 0.7647352814674377,\n",
  2365. "val_loss = 0.8221829214761424, val_acc = 0.5426356792449951\n",
  2366. "\n",
  2367. "epoch: 16\n",
  2368. "epoch 16: train loss = 0.432999442179839, l1loss = 0.4196841673893886, train acc = 0.8066933155059814,\n",
  2369. "val_loss = 0.9935263626335203, val_acc = 0.6589147448539734\n",
  2370. "\n",
  2371. "epoch: 17\n",
  2372. "epoch 17: train loss = 0.3120971677246151, l1loss = 0.34845078170120897, train acc = 0.8761239051818848,\n",
  2373. "val_loss = 1.0305215822633846, val_acc = 0.643410861492157\n",
  2374. "\n",
  2375. "epoch: 18\n",
  2376. "epoch 18: train loss = 0.21123728623518814, l1loss = 0.2952757461683138, train acc = 0.9240759015083313,\n",
  2377. "val_loss = 1.2388130657432614, val_acc = 0.569767415523529\n",
  2378. "\n",
  2379. "epoch: 19\n",
  2380. "epoch 19: train loss = 0.1134848499021211, l1loss = 0.2615779391475967, train acc = 0.9630369544029236,\n",
  2381. "val_loss = 1.480191317177558, val_acc = 0.5891472697257996\n",
  2382. "\n",
  2383. "epoch: 20\n",
  2384. "epoch 20: train loss = 0.127544645990406, l1loss = 0.24558655981655483, train acc = 0.9500499367713928,\n",
  2385. "val_loss = 1.6672501929299257, val_acc = 0.5813953280448914\n",
  2386. "\n",
  2387. "epoch: 21\n",
  2388. "epoch 21: train loss = 0.07767527621168714, l1loss = 0.23245742854538498, train acc = 0.9760239720344543,\n",
  2389. "val_loss = 1.5230142534241196, val_acc = 0.6007751822471619\n",
  2390. "\n",
  2391. "epoch: 22\n",
  2392. "epoch 22: train loss = 0.07988126125636992, l1loss = 0.22322249095458965, train acc = 0.9710289835929871,\n",
  2393. "val_loss = 1.4321563678194387, val_acc = 0.5968992114067078\n",
  2394. "\n",
  2395. "epoch: 23\n",
  2396. "epoch 23: train loss = 0.04452426001228593, l1loss = 0.213434277267961, train acc = 0.9900099635124207,\n",
  2397. "val_loss = 1.7915307799527465, val_acc = 0.643410861492157\n",
  2398. "\n",
  2399. "epoch: 24\n",
  2400. "epoch 24: train loss = 0.028254755680541414, l1loss = 0.20372207654820573, train acc = 0.9940059781074524,\n",
  2401. "val_loss = 1.940356590951136, val_acc = 0.604651153087616\n",
  2402. "\n",
  2403. "epoch: 25\n",
  2404. "epoch 25: train loss = 0.01316713175264659, l1loss = 0.19553671531624847, train acc = 0.9985014796257019,\n",
  2405. "val_loss = 1.7937852486159451, val_acc = 0.569767415523529\n",
  2406. "\n",
  2407. "epoch: 26\n",
  2408. "epoch 26: train loss = 0.005751309507469703, l1loss = 0.18567656404845842, train acc = 1.0,\n",
  2409. "val_loss = 1.7987251420353734, val_acc = 0.604651153087616\n",
  2410. "\n",
  2411. "!!! overfitted !!!\n",
  2412. "tensor(80)\n",
  2413. "tensor(42)\n",
  2414. "early stoping results:\n",
  2415. "\t [tensor(0.5833), tensor(0.6354), tensor(0.3646), tensor(0.6126), tensor(0.6335), tensor(0.5131), tensor(0.6387), tensor(0.6335), tensor(0.6387), tensor(0.6387)]\n",
  2416. "\t [tensor(0.5737), tensor(0.4652), tensor(0.5414), tensor(0.5115), tensor(0.4751), tensor(0.5817), tensor(0.4647), tensor(0.5123), tensor(0.4775), tensor(0.4625)]\n",
  2417. "tensor(75)\n",
  2418. "tensor(28)\n",
  2419. "full train results:\n",
  2420. "\t [tensor(0.5312), tensor(0.5365), tensor(0.5938), tensor(0.5550), tensor(0.5236), tensor(0.5654), tensor(0.6021), tensor(0.5288), tensor(0.5916), tensor(0.5393)]\n",
  2421. "\t [tensor(0.9995), tensor(1.), tensor(1.), tensor(1.), tensor(1.), tensor(1.), tensor(0.9990), tensor(1.), tensor(0.9990), tensor(1.)]\n",
  2422. "tensor(85)\n",
  2423. "tensor(39)\n",
  2424. "best accs results:\n",
  2425. "\t [tensor(0.6042), tensor(0.6354), tensor(0.5833), tensor(0.6126), tensor(0.5602), tensor(0.6283), tensor(0.6387), tensor(0.6335), tensor(0.6387), tensor(0.6492)]\n"
  2426. ]
  2427. },
  2428. {
  2429. "name": "stdout",
  2430. "output_type": "stream",
  2431. "text": [
  2432. "\t [tensor(0.6064), tensor(0.4652), tensor(0.6461), tensor(0.5115), tensor(0.9521), tensor(0.7914), tensor(0.4647), tensor(0.5123), tensor(0.4611), tensor(0.4670)]\n",
  2433. "[1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0]\n"
  2434. ]
  2435. }
  2436. ],
  2437. "source": [
  2438. "train_accs, test_accs = [], []\n",
  2439. "train_accs_over, test_accs_over = [], []\n",
  2440. "train_accs_acc, test_accs_acc = [], []\n",
  2441. "\n",
  2442. "for fold, (train_val_idx, test_idx) in enumerate(skf.split(dataset, labels)):\n",
  2443. " \n",
  2444. " print('-----------------------------Fold {}---------------'.format(fold + 1))\n",
  2445. "\n",
  2446. " \n",
  2447. " print('preparing dataloaders...')\n",
  2448. " train_val_data = np.stack([dataset[index] for index in train_val_idx])\n",
  2449. " train_val_label = [labels[index] for index in train_val_idx]\n",
  2450. " test_data = np.stack([dataset[index] for index in test_idx])\n",
  2451. " test_label = [labels[index] for index in test_idx]\n",
  2452. " \n",
  2453. " \n",
  2454. " Max = np.max(train_val_data, axis=(0,1,2,4), keepdims=True)\n",
  2455. " Min = np.min(train_val_data, axis=(0,1,2,4), keepdims=True)\n",
  2456. " train_val_data = (train_val_data-Min)/(Max-Min)\n",
  2457. " \n",
  2458. " Max_test = np.max(test_data, axis=(0,1,2,4), keepdims=True)\n",
  2459. " Min_test = np.min(test_data, axis=(0,1,2,4), keepdims=True)\n",
  2460. " test_data = (test_data-Min)/(Max-Min)\n",
  2461. " \n",
  2462. " \n",
  2463. " train_val = [[train_val_data[i], train_val_label[i]] for i in range(len(train_val_data))]\n",
  2464. " test = [[test_data[i], test_label[i]] for i in range(len(test_data))]\n",
  2465. " \n",
  2466. " num_train_val = len(train_val)\n",
  2467. " indices = list(range(num_train_val))\n",
  2468. " np.random.shuffle(indices)\n",
  2469. " split = int(np.floor(val_size*num_train_val))\n",
  2470. " train, val = [train_val[i] for i in indices[split:]] ,[train_val[i] for i in indices[:split]]\n",
  2471. " \n",
  2472. " train_labels = [data[1] for data in train]\n",
  2473. " \n",
  2474. " oversample = 1\n",
  2475. " _, counts = np.unique(train_labels, return_counts=True)\n",
  2476. " if oversample==1:\n",
  2477. " if counts[1]>counts[0]:\n",
  2478. " label0 = [data for data in train if data[1]==0]\n",
  2479. " coef = int(counts[1]/counts[0])\n",
  2480. " print('coef when 1 > 0', coef)\n",
  2481. " for i in range(coef):\n",
  2482. " train = train + label0\n",
  2483. " elif counts[1]<counts[0]:\n",
  2484. " label1 = [data for data in train if data[1]==1]\n",
  2485. " coef = int(counts[0]/counts[1])\n",
  2486. " print('coef when 0 > 1', coef)\n",
  2487. " for i in range(coef):\n",
  2488. " train = train + label1\n",
  2489. " \n",
  2490. "\n",
  2491. " train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)\n",
  2492. " val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size, shuffle=True)\n",
  2493. " test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=True)\n",
  2494. " \n",
  2495. " print('creating model...')\n",
  2496. " model = CNN_RNN().float()\n",
  2497. " optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)\n",
  2498. " criterion = nn.CrossEntropyLoss()\n",
  2499. " \n",
  2500. " print('calculating total steps...')\n",
  2501. " steps = 0\n",
  2502. " for epoch in range(n_epochs):\n",
  2503. " for data, label in train_loader:\n",
  2504. " steps += 1\n",
  2505. "\n",
  2506. " scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=steps, max_lr=0.001)\n",
  2507. " scheduler1 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)\n",
  2508. " l1_lambda = 0.0001\n",
  2509. " \n",
  2510. " min_val_loss = np.inf\n",
  2511. " max_val_acc = 0\n",
  2512. " \n",
  2513. " for epoch in range(n_epochs):\n",
  2514. " print('epoch: ', epoch+1)\n",
  2515. " train_loss = 0\n",
  2516. " l1_loss = 0\n",
  2517. " train_correct = 0\n",
  2518. " model.train()\n",
  2519. " '''for name, param in model.named_parameters():\n",
  2520. " print(name, param.data)\n",
  2521. " break'''\n",
  2522. " for iteration, (data,label) in enumerate(train_loader):\n",
  2523. " #print('\\ndata = ', torch.amax(data, axis=(0,1,2,4)), torch.amin(data, axis=(0,1,2,4)))\n",
  2524. " optimizer.zero_grad()\n",
  2525. " output = model(data.float())\n",
  2526. " '''label = torch.reshape(label, (-1,1))\n",
  2527. " label = label.float()'''\n",
  2528. " loss = criterion(output, label)\n",
  2529. " add_loss = loss\n",
  2530. " ex_loss = 0\n",
  2531. " for W in model.parameters():\n",
  2532. " ex_loss += l1_lambda*W.norm(1)\n",
  2533. " loss = loss + l1_lambda*W.norm(1) \n",
  2534. " loss.backward()\n",
  2535. " optimizer.step()\n",
  2536. " scheduler.step()\n",
  2537. " #targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
  2538. " targets = sum(torch.argmax(output,dim=1)==label)\n",
  2539. " train_correct += targets\n",
  2540. " train_loss += add_loss.item()*data.shape[0]\n",
  2541. " l1_loss += ex_loss.item()*data.shape[0]\n",
  2542. " \n",
  2543. " if iteration % print_every == 0:\n",
  2544. " is_training = True\n",
  2545. " val_loss = 0\n",
  2546. " val_correct = 0\n",
  2547. " model.eval()\n",
  2548. " for data, label in val_loader:\n",
  2549. " output = model(data.float())\n",
  2550. " '''label = torch.reshape(label, (-1,1))\n",
  2551. " label = label.float()'''\n",
  2552. " loss = criterion(output, label) \n",
  2553. " val_loss += loss.item()*data.shape[0]\n",
  2554. " #targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
  2555. " targets = sum(torch.argmax(output,dim=1)==label)\n",
  2556. " val_correct += targets\n",
  2557. " val_loss = val_loss/len(val_loader.sampler)\n",
  2558. " val_acc = val_correct/len(val_loader.sampler)\n",
  2559. "\n",
  2560. " if val_loss <= min_val_loss:\n",
  2561. " print(\"validation loss decreased ({:.6f} ---> {:.6f}), val_acc = {}\".format(min_val_loss, val_loss, val_acc))\n",
  2562. " torch.save(model.state_dict(), 'captum/nasal/model'+str(fold)+'.pt')\n",
  2563. " min_val_loss = val_loss\n",
  2564. " if val_acc >= max_val_acc:\n",
  2565. " print(\"validation acc increased ({:.6f} ---> {:.6f})\".format(max_val_acc, val_acc))\n",
  2566. " torch.save(model.state_dict(), 'captum/nasal/model'+str(fold)+'_acc.pt')\n",
  2567. " max_val_acc = val_acc\n",
  2568. " torch.save(model.state_dict(), 'captum/nasal/last_model'+str(fold)+'.pt')\n",
  2569. " model.train(mode=is_training)\n",
  2570. " \n",
  2571. " train_acc = train_correct/len(train_loader.sampler) \n",
  2572. " train_loss = train_loss/len(train_loader.sampler)\n",
  2573. " loss1 = l1_loss/len(train_loader.sampler)\n",
  2574. " \n",
  2575. " val_loss = 0\n",
  2576. " val_correct = 0\n",
  2577. " model.eval()\n",
  2578. " for data, label in val_loader:\n",
  2579. " output = model(data.float())\n",
  2580. " '''label = torch.reshape(label, (-1,1))\n",
  2581. " label = label.float()'''\n",
  2582. " loss = criterion(output, label) \n",
  2583. " val_loss += loss.item()*data.shape[0]\n",
  2584. " #targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
  2585. " targets = sum(torch.argmax(output,dim=1)==label)\n",
  2586. " val_correct += targets\n",
  2587. " \n",
  2588. " val_loss = val_loss/len(val_loader.sampler)\n",
  2589. " val_acc = val_correct/len(val_loader.sampler)\n",
  2590. " \n",
  2591. " print('epoch {}: train loss = {}, l1loss = {}, train acc = {},\\nval_loss = {}, val_acc = {}\\n'\n",
  2592. " .format(epoch+1, train_loss, loss1, train_acc, val_loss, val_acc))\n",
  2593. " if int(train_acc)==1:\n",
  2594. " print('!!! overfitted !!!')\n",
  2595. " break\n",
  2596. " model.train()\n",
  2597. " #scheduler1.step(val_loss)\n",
  2598. " \n",
  2599. " model =CNN_RNN().float()\n",
  2600. " model.load_state_dict(torch.load('captum/nasal/model'+str(fold)+'.pt'))\n",
  2601. " \n",
  2602. " n_correct = 0\n",
  2603. " model.eval()\n",
  2604. " for data, label in test_loader:\n",
  2605. " output = model(data.float())\n",
  2606. " #targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
  2607. " targets = sum(torch.argmax(output,dim=1)==label)\n",
  2608. " print(targets)\n",
  2609. " n_correct += targets\n",
  2610. " \n",
  2611. " test_accs.append(n_correct/len(test_loader.sampler))\n",
  2612. " print('early stoping results:\\n\\t', test_accs)\n",
  2613. " \n",
  2614. " n_correct = 0\n",
  2615. " model.eval()\n",
  2616. " for data, label in train_loader:\n",
  2617. " output = model(data.float())\n",
  2618. " #targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
  2619. " targets = sum(torch.argmax(output,dim=1)==label)\n",
  2620. " n_correct += targets\n",
  2621. " \n",
  2622. " train_accs.append(n_correct/len(train_loader.sampler))\n",
  2623. " print('\\t', train_accs)\n",
  2624. " \n",
  2625. " model = CNN_RNN().float()\n",
  2626. " model.load_state_dict(torch.load('captum/nasal/last_model'+str(fold)+'.pt'))\n",
  2627. " \n",
  2628. " n_correct = 0\n",
  2629. " model.eval()\n",
  2630. " for data, label in test_loader:\n",
  2631. " output = model(data.float())\n",
  2632. " #targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
  2633. " targets = sum(torch.argmax(output,dim=1)==label)\n",
  2634. " print(targets)\n",
  2635. " n_correct += targets\n",
  2636. " test_accs_over.append(n_correct/len(test_loader.sampler))\n",
  2637. " print('full train results:\\n\\t', test_accs_over)\n",
  2638. " \n",
  2639. " n_correct = 0\n",
  2640. " model.eval()\n",
  2641. " for data, label in train_loader:\n",
  2642. " output = model(data.float())\n",
  2643. " #targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
  2644. " targets = sum(torch.argmax(output,dim=1)==label)\n",
  2645. " n_correct += targets\n",
  2646. " train_accs_over.append(n_correct/len(train_loader.sampler))\n",
  2647. " print('\\t', train_accs_over)\n",
  2648. " \n",
  2649. " model = CNN_RNN().float()\n",
  2650. " model.load_state_dict(torch.load('captum/nasal/model'+str(fold)+'_acc.pt'))\n",
  2651. " \n",
  2652. " n_correct = 0\n",
  2653. " model.eval()\n",
  2654. " for data, label in test_loader:\n",
  2655. " output = model(data.float())\n",
  2656. " #targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
  2657. " targets = sum(torch.argmax(output,dim=1)==label)\n",
  2658. " print(targets)\n",
  2659. " n_correct += targets\n",
  2660. " test_accs_acc.append(n_correct/len(test_loader.sampler))\n",
  2661. " print('best accs results:\\n\\t', test_accs_acc)\n",
  2662. " \n",
  2663. " n_correct = 0\n",
  2664. " model.eval()\n",
  2665. " for data, label in train_loader:\n",
  2666. " output = model(data.float())\n",
  2667. " #targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
  2668. " targets = sum(torch.argmax(output,dim=1)==label)\n",
  2669. " n_correct += targets\n",
  2670. " train_accs_acc.append(n_correct/len(train_loader.sampler))\n",
  2671. " print('\\t', train_accs_acc)\n",
  2672. " print(test_label)"
  2673. ]
  2674. },
  2675. {
  2676. "cell_type": "code",
  2677. "execution_count": 16,
  2678. "id": "93a5fc77",
  2679. "metadata": {},
  2680. "outputs": [
  2681. {
  2682. "data": {
  2683. "text/plain": [
  2684. "0.8107657068062828"
  2685. ]
  2686. },
  2687. "execution_count": 16,
  2688. "metadata": {},
  2689. "output_type": "execute_result"
  2690. }
  2691. ],
  2692. "source": [
  2693. "sum(test_accs_acc) / len(test_accs_acc)"
  2694. ]
  2695. },
  2696. {
  2697. "cell_type": "code",
  2698. "execution_count": null,
  2699. "id": "74a75907",
  2700. "metadata": {},
  2701. "outputs": [],
  2702. "source": [
  2703. "train_accs, test_accs = [], []\n",
  2704. "train_accs_over, test_accs_over = [], []\n",
  2705. "train_accs_acc, test_accs_acc = [], []"
  2706. ]
  2707. },
  2708. {
  2709. "cell_type": "code",
  2710. "execution_count": 17,
  2711. "id": "078e125c",
  2712. "metadata": {},
  2713. "outputs": [
  2714. {
  2715. "data": {
  2716. "text/plain": [
  2717. "[0.8020833333333334,\n",
  2718. " 0.8177083333333334,\n",
  2719. " 0.8177083333333334,\n",
  2720. " 0.8272251308900523,\n",
  2721. " 0.7958115183246073,\n",
  2722. " 0.8219895287958116,\n",
  2723. " 0.8167539267015707,\n",
  2724. " 0.8167539267015707,\n",
  2725. " 0.806282722513089,\n",
  2726. " 0.7853403141361257]"
  2727. ]
  2728. },
  2729. "execution_count": 17,
  2730. "metadata": {},
  2731. "output_type": "execute_result"
  2732. }
  2733. ],
  2734. "source": [
  2735. "test_accs_acc"
  2736. ]
  2737. },
  2738. {
  2739. "cell_type": "code",
  2740. "execution_count": null,
  2741. "id": "5abb4609",
  2742. "metadata": {},
  2743. "outputs": [],
  2744. "source": []
  2745. },
  2746. {
  2747. "cell_type": "code",
  2748. "execution_count": null,
  2749. "id": "5518dc3d",
  2750. "metadata": {},
  2751. "outputs": [],
  2752. "source": []
  2753. },
  2754. {
  2755. "cell_type": "code",
  2756. "execution_count": null,
  2757. "id": "f62e3322",
  2758. "metadata": {},
  2759. "outputs": [],
  2760. "source": []
  2761. },
  2762. {
  2763. "cell_type": "code",
  2764. "execution_count": 23,
  2765. "id": "91c032f1",
  2766. "metadata": {},
  2767. "outputs": [
  2768. {
  2769. "data": {
  2770. "text/plain": [
  2771. "0.6354166666666667"
  2772. ]
  2773. },
  2774. "execution_count": 23,
  2775. "metadata": {},
  2776. "output_type": "execute_result"
  2777. }
  2778. ],
  2779. "source": [
  2780. "1 - sum(y__test)/len(y__test)"
  2781. ]
  2782. },
  2783. {
  2784. "cell_type": "code",
  2785. "execution_count": null,
  2786. "id": "d49fc26a",
  2787. "metadata": {},
  2788. "outputs": [],
  2789. "source": [
  2790. "+ l1_lambda * sum(p.abs().sum() for p in model.parameters())"
  2791. ]
  2792. },
  2793. {
  2794. "cell_type": "code",
  2795. "execution_count": null,
  2796. "id": "97c7cbed",
  2797. "metadata": {},
  2798. "outputs": [],
  2799. "source": []
  2800. },
  2801. {
  2802. "cell_type": "code",
  2803. "execution_count": null,
  2804. "id": "d8a75e16",
  2805. "metadata": {},
  2806. "outputs": [],
  2807. "source": []
  2808. },
  2809. {
  2810. "cell_type": "code",
  2811. "execution_count": 218,
  2812. "id": "c02888f1",
  2813. "metadata": {},
  2814. "outputs": [
  2815. {
  2816. "name": "stdout",
  2817. "output_type": "stream",
  2818. "text": [
  2819. "CNN_RNN(\n",
  2820. " (conv1): Conv2d(10, 16, kernel_size=(3, 3), stride=(1, 1))\n",
  2821. " (pool): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)\n",
  2822. " (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))\n",
  2823. " (batchnorm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
  2824. " (lstm): LSTM(256, 128, num_layers=2, batch_first=True)\n",
  2825. " (fc): Linear(in_features=128, out_features=2, bias=True)\n",
  2826. ")\n"
  2827. ]
  2828. },
  2829. {
  2830. "name": "stderr",
  2831. "output_type": "stream",
  2832. "text": [
  2833. "C:\\Users\\saeed\\AppData\\Local\\Temp\\ipykernel_16720\\3694016727.py:7: UserWarning: nn.init.xavier_normal is now deprecated in favor of nn.init.xavier_normal_.\n",
  2834. " torch.nn.init.xavier_normal(self.conv1.weight)\n",
  2835. "C:\\Users\\saeed\\AppData\\Local\\Temp\\ipykernel_16720\\3694016727.py:10: UserWarning: nn.init.xavier_normal is now deprecated in favor of nn.init.xavier_normal_.\n",
  2836. " torch.nn.init.xavier_normal(self.conv2.weight)\n",
  2837. "C:\\Users\\saeed\\AppData\\Local\\Temp\\ipykernel_16720\\3694016727.py:14: UserWarning: nn.init.xavier_normal is now deprecated in favor of nn.init.xavier_normal_.\n",
  2838. " torch.nn.init.xavier_normal(self.fc.weight)\n"
  2839. ]
  2840. }
  2841. ],
  2842. "source": [
  2843. "model = CNN_RNN().float()\n",
  2844. "model.load_state_dict(torch.load('train/fc1_10_20_11_vowel/model.pt'))\n",
  2845. "print(model)"
  2846. ]
  2847. },
  2848. {
  2849. "cell_type": "code",
  2850. "execution_count": 219,
  2851. "id": "fd9bd508",
  2852. "metadata": {},
  2853. "outputs": [
  2854. {
  2855. "name": "stdout",
  2856. "output_type": "stream",
  2857. "text": [
  2858. "0.7864583333333334\n"
  2859. ]
  2860. }
  2861. ],
  2862. "source": [
  2863. "n_correct = 0\n",
  2864. "\n",
  2865. "model.eval()\n",
  2866. "with torch.no_grad():\n",
  2867. "    for data, label in test_loader:\n",
  2868. "        output = model(data.float())\n",
  2869. "        pred = torch.argmax(output, dim=1)\n",
  2870. "        n_correct += torch.sum(pred == label).item()\n",
  2871. "test_acc = n_correct / len(test_loader.sampler)\n",
  2872. "print(test_acc)"
  2873. ]
  2874. },
  2875. {
  2876. "cell_type": "code",
  2877. "execution_count": 53,
  2878. "id": "c7fd86ea",
  2879. "metadata": {},
  2880. "outputs": [
  2881. {
  2882. "name": "stdout",
  2883. "output_type": "stream",
  2884. "text": [
  2885. "0.3645833333333333\n"
  2886. ]
  2887. }
  2888. ],
  2889. "source": [
  2890. "print(sum(test_label)/len(test_label))"
  2891. ]
  2892. },
  2893. {
  2894. "cell_type": "code",
  2895. "execution_count": null,
  2896. "id": "20273a2b",
  2897. "metadata": {},
  2898. "outputs": [],
  2899. "source": []
  2900. },
  2901. {
  2902. "cell_type": "code",
  2903. "execution_count": null,
  2904. "id": "037040c9",
  2905. "metadata": {},
  2906. "outputs": [],
  2907. "source": []
  2908. },
  2909. {
  2910. "cell_type": "code",
  2911. "execution_count": 178,
  2912. "id": "99ab9dcd",
  2913. "metadata": {},
  2914. "outputs": [],
  2915. "source": [
  2916. "a = [1,2,3,4,5,6]"
  2917. ]
  2918. },
  2919. {
  2920. "cell_type": "code",
  2921. "execution_count": 179,
  2922. "id": "5189a944",
  2923. "metadata": {},
  2924. "outputs": [],
  2925. "source": [
  2926. "np.random.shuffle(a)"
  2927. ]
  2928. },
  2929. {
  2930. "cell_type": "code",
  2931. "execution_count": 180,
  2932. "id": "354e419a",
  2933. "metadata": {},
  2934. "outputs": [
  2935. {
  2936. "data": {
  2937. "text/plain": [
  2938. "[4, 2, 6, 3, 5, 1]"
  2939. ]
  2940. },
  2941. "execution_count": 180,
  2942. "metadata": {},
  2943. "output_type": "execute_result"
  2944. }
  2945. ],
  2946. "source": [
  2947. "a"
  2948. ]
  2949. },
  2950. {
  2951. "cell_type": "code",
  2952. "execution_count": null,
  2953. "id": "55070b74",
  2954. "metadata": {},
  2955. "outputs": [],
  2956. "source": []
  2957. },
  2958. {
  2959. "cell_type": "code",
  2960. "execution_count": null,
  2961. "id": "1d9174a3",
  2962. "metadata": {},
  2963. "outputs": [],
  2964. "source": []
  2965. }
  2966. ],
  2967. "metadata": {
  2968. "kernelspec": {
  2969. "display_name": "Python 3 (ipykernel)",
  2970. "language": "python",
  2971. "name": "python3"
  2972. },
  2973. "language_info": {
  2974. "codemirror_mode": {
  2975. "name": "ipython",
  2976. "version": 3
  2977. },
  2978. "file_extension": ".py",
  2979. "mimetype": "text/x-python",
  2980. "name": "python",
  2981. "nbconvert_exporter": "python",
  2982. "pygments_lexer": "ipython3",
  2983. "version": "3.9.7"
  2984. }
  2985. },
  2986. "nbformat": 4,
  2987. "nbformat_minor": 5
  2988. }