You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

MFCC.ipynb 71KB


  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "id": "9be4a5bf",
  7. "metadata": {},
  8. "outputs": [],
  9. "source": [
  10. "import librosa\n",
  11. "import librosa.display\n",
  12. "import IPython.display as ipd\n",
  13. "import matplotlib.pyplot as plt\n",
  14. "import numpy as np\n",
  15. "import scipy.io\n",
  16. "from tqdm import tqdm\n",
  17. "import glob\n",
  18. "import os\n",
  19. "import json\n",
  20. "import pickle\n",
  21. "from einops import rearrange "
  22. ]
  23. },
  24. {
  25. "cell_type": "code",
  26. "execution_count": 3,
  27. "id": "41a0eba0",
  28. "metadata": {
  29. "scrolled": true
  30. },
  31. "outputs": [
  32. {
  33. "name": "stdout",
  34. "output_type": "stream",
  35. "text": [
  36. "['P09', 'P10', 'P11']\n"
  37. ]
  38. }
  39. ],
  40. "source": [
  41. "#import subject names and data\n",
  42. "sub_names = []\n",
  43. "sub_dict_data = []\n",
  44. "mat_files = glob.glob('Preprocessed_data/*.mat')\n",
  45. "for i, file in enumerate(mat_files):\n",
  46. " if i<3:\n",
  47. " sub_names.append(file.split('.')[0].split('\\\\')[1])\n",
  48. " sub_dict_data.append(scipy.io.loadmat(file))\n",
  49. "print(sub_names)"
  50. ]
  51. },
  52. {
  53. "cell_type": "code",
  54. "execution_count": 52,
  55. "id": "c58a6bf8",
  56. "metadata": {},
  57. "outputs": [
  58. {
  59. "name": "stdout",
  60. "output_type": "stream",
  61. "text": [
  62. "131\n",
  63. "[6, 5, 2, 1, 4, 6, 1, 2, 5, 3, 2, 6, 3, 1, 6, 2, 5, 1, 3, 0, 0, 6, 2, 3, 5, 4, 2, 0, 5, 4, 2, 5, 1, 3, 6, 4, 1, 4, 5, 1, 5, 0, 0, 5, 6, 1, 1, 6, 2, 2, 3, 3, 0, 6, 3, 2, 1, 3, 4, 3, 4, 4, 2, 3, 5, 0, 4, 1, 6, 0, 5, 4, 4, 0, 6, 2, 6, 5, 0, 0, 4, 1, 3, 9, 10, 8, 7, 8, 9, 7, 10, 7, 9, 9, 7, 7, 8, 10, 8, 10, 10, 10, 10, 7, 8, 7, 10, 8, 7, 8, 9, 9, 7, 7, 8, 9, 10, 9, 8, 7, 7, 9, 9, 9, 10, 10, 9, 8, 10, 8, 8]\n"
  64. ]
  65. }
  66. ],
  67. "source": [
  68. "P8_labels = []\n",
  69. "labels = []\n",
  70. "with open('C:/Users/saeed/Desktop/Master/KARA ONE Data/MM08/kinect_data/labels.txt') as f:\n",
  71. " lines = f.readlines()\n",
  72. "print(len(lines))\n",
  73. "for line in lines:\n",
  74. " line = line.split('\\n')[0]\n",
  75. " if line == '/iy/' or line == 'iy':\n",
  76. " labels.append(0)\n",
  77. " elif line == '/uw/' or line == 'uw':\n",
  78. " labels.append(1)\n",
  79. " elif line == '/piy/' or line == 'piy':\n",
  80. " labels.append(2)\n",
  81. " elif line == '/tiy/' or line == 'tiy':\n",
  82. " labels.append(3)\n",
  83. " elif line == '/diy/' or line == 'diy':\n",
  84. " labels.append(4)\n",
  85. " elif line == '/m/' or line == 'm':\n",
  86. " labels.append(5)\n",
  87. " elif line == '/n/' or line == 'n':\n",
  88. " labels.append(6)\n",
  89. " elif line == '/pat/' or line == 'pat':\n",
  90. " labels.append(7)\n",
  91. " elif line == '/pot/' or line == 'pot':\n",
  92. " labels.append(8)\n",
  93. " elif line == '/knew/' or line == 'knew':\n",
  94. " labels.append(9)\n",
  95. " elif line == '/gnaw/' or line == 'gnaw':\n",
  96. " labels.append(10)\n",
  97. "P8_labels = labels\n",
  98. "print(P8_labels) "
  99. ]
  100. },
  101. {
  102. "cell_type": "code",
  103. "execution_count": 53,
  104. "id": "8679e86e",
  105. "metadata": {},
  106. "outputs": [],
  107. "source": [
  108. "labels = [P9_labels, P10_labels, P11_labels, P12_labels, P14_labels, P15_labels, P16_labels, P18_labels, P19_labels,\n",
  109. " P2_labels, P20_labels, P21_labels, P5_labels, P8_labels]"
  110. ]
  111. },
  112. {
  113. "cell_type": "code",
  114. "execution_count": 54,
  115. "id": "623771dc",
  116. "metadata": {},
  117. "outputs": [
  118. {
  119. "name": "stdout",
  120. "output_type": "stream",
  121. "text": [
  122. "[5, 0, 5, 6, 3, 1, 5, 3, 5, 5, 3, 6, 0, 4, 3, 4, 3, 6, 5, 4, 5, 3, 2, 6, 1, 0, 6, 4, 2, 3, 0, 0, 5, 2, 4, 2, 1, 0, 3, 6, 2, 0, 0, 4, 2, 1, 4, 2, 3, 6, 0, 3, 1, 5, 0, 3, 2, 4, 1, 0, 3, 0, 2, 1, 5, 2, 2, 4, 1, 4, 5, 5, 1, 1, 6, 1, 4, 4, 6, 1, 6, 6, 6, 2, 7, 8, 10, 7, 10, 10, 8, 8, 9, 10, 8, 7, 7, 9, 10, 9, 9, 8, 7, 8, 7, 9, 10, 9, 9, 8, 8, 8, 10, 7, 9, 7, 7, 7, 10, 8, 9, 8, 10, 10, 10, 10, 7, 9, 9, 7, 8, 9]\n"
  123. ]
  124. }
  125. ],
  126. "source": [
  127. "print(labels[1])"
  128. ]
  129. },
  130. {
  131. "cell_type": "code",
  132. "execution_count": 53,
  133. "id": "c4ef23a6",
  134. "metadata": {},
  135. "outputs": [],
  136. "source": [
  137. "with open(\"labels.pkl\", \"wb\") as f:\n",
  138. " pickle.dump(labels, f)"
  139. ]
  140. },
  141. {
  142. "cell_type": "code",
  143. "execution_count": 4,
  144. "id": "bb8b91e2",
  145. "metadata": {},
  146. "outputs": [],
  147. "source": [
  148. "#extract numpy data from dict data\n",
  149. "sub_data = []\n",
  150. "for dict_data, name in zip(sub_dict_data, sub_names):\n",
  151. " if name == 'P2':\n",
  152. " sub_data.append(dict_data['P2'])\n",
  153. " else:\n",
  154. " sub_data.append(dict_data['A'])"
  155. ]
  156. },
  157. {
  158. "cell_type": "code",
  159. "execution_count": 5,
  160. "id": "545e70be",
  161. "metadata": {
  162. "scrolled": true
  163. },
  164. "outputs": [
  165. {
  166. "name": "stdout",
  167. "output_type": "stream",
  168. "text": [
  169. "number of subjects = 3\n",
  170. "data shapes:\n",
  171. "(62, 1250, 132)\n",
  172. "(62, 1250, 132)\n",
  173. "(62, 1250, 132)\n",
  174. "number of all trials = 396\n"
  175. ]
  176. }
  177. ],
  178. "source": [
  179. "#printing number of subjects, all data shapes and number of all trials\n",
  180. "print('number of subjects = ', len(sub_data))\n",
  181. "print('data shapes:')\n",
  182. "count = 0\n",
  183. "for data in sub_data:\n",
  184. " count += data.shape[2]\n",
  185. " print(data.shape)\n",
  186. "print('number of all trials = ', count) "
  187. ]
  188. },
  189. {
  190. "cell_type": "code",
  191. "execution_count": 6,
  192. "id": "10d50aad",
  193. "metadata": {
  194. "scrolled": true
  195. },
  196. "outputs": [
  197. {
  198. "name": "stdout",
  199. "output_type": "stream",
  200. "text": [
  201. "(62, 1250, 132)\n",
  202. "(62, 1250, 132)\n",
  203. "(62, 1250, 132)\n"
  204. ]
  205. }
  206. ],
  207. "source": [
  208. "for data in sub_data:\n",
  209. " print(data.shape)"
  210. ]
  211. },
  212. {
  213. "cell_type": "code",
  214. "execution_count": 20,
  215. "id": "9fd1a38c",
  216. "metadata": {
  217. "scrolled": false
  218. },
  219. "outputs": [
  220. {
  221. "data": {
  222. "image/png": "\n",
  223. "text/plain": [
  224. "<Figure size 576x216 with 1 Axes>"
  225. ]
  226. },
  227. "metadata": {
  228. "needs_background": "light"
  229. },
  230. "output_type": "display_data"
  231. },
  232. {
  233. "data": {
  234. "image/png": "\n",
  235. "text/plain": [
  236. "<Figure size 576x216 with 1 Axes>"
  237. ]
  238. },
  239. "metadata": {
  240. "needs_background": "light"
  241. },
  242. "output_type": "display_data"
  243. },
  244. {
  245. "data": {
  246. "image/png": "\n",
  247. "text/plain": [
  248. "<Figure size 576x216 with 1 Axes>"
  249. ]
  250. },
  251. "metadata": {
  252. "needs_background": "light"
  253. },
  254. "output_type": "display_data"
  255. }
  256. ],
  257. "source": [
  258. "for i,data in enumerate(sub_mfc):\n",
  259. " hist_data = rearrange(data, 'n c m t -> m (n c t)')\n",
  260. " plt.figure(figsize=(8, 3))\n",
  261. " hist = np.histogram(hist_data[0], bins=1000)\n",
  262. " ax = plt.gca()\n",
  263. " ax.set_facecolor((0.98,0.98,0.98))\n",
  264. " plt.grid()\n",
  265. " plt.xticks(fontsize=16)\n",
  266. " plt.yticks(fontsize=16)\n",
  267. " plt.xlabel('Time', fontsize=16)\n",
  268. " plt.ylabel('Vlue', fontsize=16)\n",
  269. " plt.plot(hist[1][1:],hist[0])\n",
  270. " plt.savefig('p'+str(i)+'.png', dpi=300)"
  271. ]
  272. },
  273. {
  274. "cell_type": "code",
  275. "execution_count": 7,
  276. "id": "8c89f9a5",
  277. "metadata": {
  278. "scrolled": true
  279. },
  280. "outputs": [
  281. {
  282. "name": "stderr",
  283. "output_type": "stream",
  284. "text": [
  285. "100%|████████████████████████████████████████████████████████████████████████████████| 132/132 [00:37<00:00, 3.48it/s]\n"
  286. ]
  287. },
  288. {
  289. "name": "stdout",
  290. "output_type": "stream",
  291. "text": [
  292. "(132, 62, 20, 11)\n"
  293. ]
  294. },
  295. {
  296. "name": "stderr",
  297. "output_type": "stream",
  298. "text": [
  299. "100%|████████████████████████████████████████████████████████████████████████████████| 132/132 [00:39<00:00, 3.33it/s]\n"
  300. ]
  301. },
  302. {
  303. "name": "stdout",
  304. "output_type": "stream",
  305. "text": [
  306. "(132, 62, 20, 11)\n"
  307. ]
  308. },
  309. {
  310. "name": "stderr",
  311. "output_type": "stream",
  312. "text": [
  313. "100%|████████████████████████████████████████████████████████████████████████████████| 132/132 [00:39<00:00, 3.36it/s]"
  314. ]
  315. },
  316. {
  317. "name": "stdout",
  318. "output_type": "stream",
  319. "text": [
  320. "(132, 62, 20, 11)\n",
  321. "(396, 62, 20, 11)\n"
  322. ]
  323. },
  324. {
  325. "name": "stderr",
  326. "output_type": "stream",
  327. "text": [
  328. "\n"
  329. ]
  330. }
  331. ],
  332. "source": [
  333. "#reshaping the data to (trial, channel, sample) \n",
  334. "n_mfcc = 20\n",
  335. "framesize = 1 * 250\n",
  336. "hop_size = int(framesize/2)\n",
  337. "\n",
  338. "sub_mfc = []\n",
  339. "\n",
  340. "for i, data in enumerate(sub_data):\n",
  341. " data = rearrange(data, 'c s t -> t c s')\n",
  342. " trials = []\n",
  343. " for j, trial in enumerate(tqdm(data)):\n",
  344. " channels = []\n",
  345. " for k, channel in enumerate(trial):\n",
  346. " mfccs = librosa.feature.mfcc(y=channel, n_mfcc=n_mfcc, n_fft=framesize, hop_length=hop_size, sr=250)\n",
  347. " channels.append(np.array(mfccs))\n",
  348. " trials.append(np.array(channels)) \n",
  349. " data = np.array(trials)\n",
  350. " print(data.shape)\n",
  351. " sub_mfc.append(data)\n",
  352. " Max = np.max(data, axis=(0,1,3), keepdims=True)\n",
  353. " Min = np.min(data, axis=(0,1,3), keepdims=True)\n",
  354. " data = (data-Min)/(Max-Min)\n",
  355. " if i == 0:\n",
  356. " all_data = data\n",
  357. " else:\n",
  358. " all_data = np.vstack((all_data, data))\n",
  359. "print(all_data.shape)"
  360. ]
  361. },
  362. {
  363. "cell_type": "code",
  364. "execution_count": 24,
  365. "id": "6aedd8d3",
  366. "metadata": {},
  367. "outputs": [],
  368. "source": [
  369. "with open('normal_all_data.pkl', 'wb') as f:\n",
  370. " pickle.dump(all_data, f)"
  371. ]
  372. },
  373. {
  374. "cell_type": "code",
  375. "execution_count": 6,
  376. "id": "fcc23565",
  377. "metadata": {},
  378. "outputs": [],
  379. "source": [
  380. "#set parameters for MFCC extraction\n",
  381. "n_mfcc = 20\n",
  382. "framesize = 1 * 250\n",
  383. "hop_size = int(framesize/2)"
  384. ]
  385. },
  386. {
  387. "cell_type": "code",
  388. "execution_count": 45,
  389. "id": "78c6ffa2",
  390. "metadata": {},
  391. "outputs": [
  392. {
  393. "name": "stderr",
  394. "output_type": "stream",
  395. "text": [
  396. "100%|██████████████████████████████████████████████████████████████████████████████████| 14/14 [05:12<00:00, 22.29s/it]\n"
  397. ]
  398. }
  399. ],
  400. "source": [
  401. "#calculate MFCCs and put them in a matrix with shape (#trial, #channels) for each sub. put all new subs in MFCC_data \n",
  402. "MFCC_data = []\n",
  403. "for sub in tqdm(reshaped_data):\n",
  404. " trials = []\n",
  405. " for i, trial in enumerate(sub):\n",
  406. " channels = []\n",
  407. " for j, channel in enumerate(trial):\n",
  408. " mfccs = librosa.feature.mfcc(y=channel, n_mfcc=n_mfcc, n_fft=framesize, hop_length=hop_size, sr=250)\n",
  409. " channels.append(np.array(mfccs))\n",
  410. " trials.append(np.array(channels))\n",
  411. " MFCC_data.append(np.array(trials)) "
  412. ]
  413. },
  414. {
  415. "cell_type": "code",
  416. "execution_count": 51,
  417. "id": "e56998d7",
  418. "metadata": {},
  419. "outputs": [],
  420. "source": [
  421. "with open(\"MFCCs.pkl\", \"wb\") as f:\n",
  422. " pickle.dump(MFCC_data, f)"
  423. ]
  424. },
  425. {
  426. "cell_type": "code",
  427. "execution_count": 4,
  428. "id": "86bc95b2",
  429. "metadata": {
  430. "scrolled": true
  431. },
  432. "outputs": [],
  433. "source": [
  434. "with open(\"MFCCs.pkl\", \"rb\") as f:\n",
  435. " mfccs = pickle.load(f)"
  436. ]
  437. },
  438. {
  439. "cell_type": "code",
  440. "execution_count": 56,
  441. "id": "2f243666",
  442. "metadata": {},
  443. "outputs": [
  444. {
  445. "data": {
  446. "text/plain": [
  447. "(132, 62, 20, 11)"
  448. ]
  449. },
  450. "execution_count": 56,
  451. "metadata": {},
  452. "output_type": "execute_result"
  453. }
  454. ],
  455. "source": [
  456. "mfccs[0].shape"
  457. ]
  458. },
  459. {
  460. "cell_type": "code",
  461. "execution_count": 77,
  462. "id": "20783211",
  463. "metadata": {},
  464. "outputs": [],
  465. "source": [
  466. "all_labels = []\n",
  467. "for P_label in labels:\n",
  468. " for label in P_label:\n",
  469. " all_labels.append(label)"
  470. ]
  471. },
  472. {
  473. "cell_type": "code",
  474. "execution_count": 79,
  475. "id": "88491149",
  476. "metadata": {},
  477. "outputs": [],
  478. "source": [
  479. "with open(\"all_labels.pkl\", \"wb\") as f:\n",
  480. " pickle.dump(all_labels, f)"
  481. ]
  482. },
  483. {
  484. "cell_type": "code",
  485. "execution_count": 81,
  486. "id": "451b0af3",
  487. "metadata": {},
  488. "outputs": [],
  489. "source": [
  490. "for i, mfcc in enumerate(mfccs):\n",
  491. " if i==0:\n",
  492. " all_mfccs = mfcc\n",
  493. " else:\n",
  494. " all_mfccs = np.vstack((all_mfccs, mfcc))"
  495. ]
  496. },
  497. {
  498. "cell_type": "code",
  499. "execution_count": 82,
  500. "id": "e596f778",
  501. "metadata": {},
  502. "outputs": [
  503. {
  504. "data": {
  505. "text/plain": [
  506. "(1913, 62, 20, 11)"
  507. ]
  508. },
  509. "execution_count": 82,
  510. "metadata": {},
  511. "output_type": "execute_result"
  512. }
  513. ],
  514. "source": [
  515. "all_mfccs.shape"
  516. ]
  517. },
  518. {
  519. "cell_type": "code",
  520. "execution_count": 83,
  521. "id": "fb164790",
  522. "metadata": {},
  523. "outputs": [],
  524. "source": [
  525. "with open(\"all_data.pkl\", \"wb\") as f:\n",
  526. " pickle.dump(all_mfccs, f)"
  527. ]
  528. },
  529. {
  530. "cell_type": "code",
  531. "execution_count": 74,
  532. "id": "71bdd5bf",
  533. "metadata": {},
  534. "outputs": [],
  535. "source": [
  536. "a = np.random.randint(0, 20, size=(2,3,4,5))\n",
  537. "b = np.random.randint(0, 20, size=(3,3,4,5))\n",
  538. "c = np.random.randint(0, 20, size=(4,3,4,5))"
  539. ]
  540. },
  541. {
  542. "cell_type": "code",
  543. "execution_count": 56,
  544. "id": "1efdbb66",
  545. "metadata": {},
  546. "outputs": [],
  547. "source": [
  548. "with open('all_label.pkl', 'wb') as f:\n",
  549. " pickle.dump(all_label, f)"
  550. ]
  551. },
  552. {
  553. "cell_type": "code",
  554. "execution_count": 55,
  555. "id": "9c938cac",
  556. "metadata": {},
  557. "outputs": [
  558. {
  559. "data": {
  560. "text/plain": [
  561. "1913"
  562. ]
  563. },
  564. "execution_count": 55,
  565. "metadata": {},
  566. "output_type": "execute_result"
  567. }
  568. ],
  569. "source": [
  570. "all_label = [label for p in labels for label in p]\n",
  571. "len(all_label)"
  572. ]
  573. },
  574. {
  575. "cell_type": "code",
  576. "execution_count": null,
  577. "id": "48bfe619",
  578. "metadata": {},
  579. "outputs": [],
  580. "source": [
  581. "with open(\"all_label.pkl\", \"rb\") as f:\n",
  582. " labels = pickle.load(f)\n",
  583. "vowel_labels, nasal_labels, bilabial_labels, iy_labels, uw_labels = [], [], [], [], [] \n",
  584. "for label in labels:\n",
  585. " if label==0:\n",
  586. " vowel_labels.append(0)\n",
  587. " nasal_labels.append(0)\n",
  588. " bilabial_labels.append(0)\n",
  589. " iy_labels.append(1)\n",
  590. " uw_labels.append(0)\n",
  591. " elif label==1:\n",
  592. " vowel_labels.append(0)\n",
  593. " nasal_labels.append(0)\n",
  594. " bilabial_labels.append(0)\n",
  595. " iy_labels.append(0)\n",
  596. " uw_labels.append(1)\n",
  597. " elif label==2:\n",
  598. " vowel_labels.append(1)\n",
  599. " nasal_labels.append(0)\n",
  600. " bilabial_labels.append(1)\n",
  601. " iy_labels.append(1)\n",
  602. " uw_labels.append(0)\n",
  603. " elif label==3:\n",
  604. " vowel_labels.append(1)\n",
  605. " nasal_labels.append(0)\n",
  606. " bilabial_labels.append(0)\n",
  607. " iy_labels.append(1)\n",
  608. " uw_labels.append(0)\n",
  609. " elif label==4:\n",
  610. " vowel_labels.append(1)\n",
  611. " nasal_labels.append(0)\n",
  612. " bilabial_labels.append(0)\n",
  613. " iy_labels.append(1)\n",
  614. " uw_labels.append(0)\n",
  615. " elif label==5:\n",
  616. " vowel_labels.append(1)\n",
  617. " nasal_labels.append(1)\n",
  618. " bilabial_labels.append(1)\n",
  619. " iy_labels.append(0)\n",
  620. " uw_labels.append(0)\n",
  621. " elif label==6:\n",
  622. " vowel_labels.append(1)\n",
  623. " nasal_labels.append(1)\n",
  624. " bilabial_labels.append(0)\n",
  625. " iy_labels.append(0)\n",
  626. " uw_labels.append(0)\n",
  627. " elif label==7:\n",
  628. " vowel_labels.append(1)\n",
  629. " nasal_labels.append(0)\n",
  630. " bilabial_labels.append(1)\n",
  631. " iy_labels.append(0)\n",
  632. " uw_labels.append(0)\n",
  633. " elif label==8:\n",
  634. " vowel_labels.append(1)\n",
  635. " nasal_labels.append(0)\n",
  636. " bilabial_labels.append(1)\n",
  637. " iy_labels.append(0)\n",
  638. " uw_labels.append(0)\n",
  639. " elif label==9:\n",
  640. " vowel_labels.append(1)\n",
  641. " nasal_labels.append(1)\n",
  642. " bilabial_labels.append(0)\n",
  643. " iy_labels.append(0)\n",
  644. " uw_labels.append(0)\n",
  645. " elif label==10:\n",
  646. " vowel_labels.append(1)\n",
  647. " nasal_labels.append(1)\n",
  648. " bilabial_labels.append(0)\n",
  649. " iy_labels.append(0)\n",
  650. " uw_labels.append(0)\n",
  651. " \n",
  652. "with open('vowel_label.pkl', 'wb') as f:\n",
  653. " pickle.dump(vowel_labels, f)\n",
  654. "with open('bilab_label.pkl', 'wb') as f:\n",
  655. " pickle.dump(bilabi_labels, f)\n",
  656. "with open('nasal_label.pkl', 'wb') as f:\n",
  657. " pickle.dump(vowel_labels, f)\n",
  658. "with open('iy_label.pkl', 'wb') as f:\n",
  659. " pickle.dump(vowel_labels, f)\n",
  660. "with open('uw_label.pkl', 'wb') as f:\n",
  661. " pickle.dump(vowel_labels, f)"
  662. ]
  663. },
  664. {
  665. "cell_type": "code",
  666. "execution_count": null,
  667. "id": "c0edd8b2",
  668. "metadata": {},
  669. "outputs": [],
  670. "source": [
  671. "#train-test split for tasks\n",
  672. "X_vowel_train, X_vowel_test, y_vowel_train, y_vowel_test = train_test_split(all_data, vowel_labels, \n",
  673. " stratify=vowel_labels,\n",
  674. " test_size=0.1)\n",
  675. "X_nasal_train, X_nasal_test, y_nasal_train, y_nasal_test = train_test_split(all_data, nasal_labels, \n",
  676. " stratify=nasal_labels, \n",
  677. " test_size=0.1)\n",
  678. "X_bilabial_train, X_bilabial_test, y_bilabial_train, y_bilabial_test = train_test_split(all_data, bilabial_labels, \n",
  679. " stratify=bilabial_labels, \n",
  680. " test_size=0.1)\n",
  681. "X_iy_train, X_iy_test, y_iy_train, y_iy_test = train_test_split(all_data, iy_labels, \n",
  682. " stratify=iy_labels, \n",
  683. " test_size=0.1)\n",
  684. "X_uw_train, X_uw_test, y_uw_train, y_uw_test = train_test_split(all_data, uw_labels, \n",
  685. " stratify=uw_labels, \n",
  686. " test_size=0.1)\n",
  687. "X_train, X_test, y_train, y_test = train_test_split(all_data, all_labels, \n",
  688. " stratify=all_labels, \n",
  689. " test_size=0.1)"
  690. ]
  691. },
  692. {
  693. "cell_type": "code",
  694. "execution_count": null,
  695. "id": "05c2e96b",
  696. "metadata": {},
  697. "outputs": [],
  698. "source": [
  699. "with open(\"data/uw/X_uw_train.pkl\", \"wb\") as f:\n",
  700. " pickle.dump(X_uw_train, f)\n",
  701. "with open(\"data/uw/X_uw_test.pkl\", \"wb\") as f:\n",
  702. " pickle.dump(X_uw_test, f)\n",
  703. "with open(\"data/uw/y_uw_train.pkl\", \"wb\") as f:\n",
  704. " pickle.dump(y_uw_train, f)\n",
  705. "with open(\"data/uw/y_uw_test.pkl\", \"wb\") as f:\n",
  706. " pickle.dump(y_uw_test, f)"
  707. ]
  708. }
  709. ],
  710. "metadata": {
  711. "kernelspec": {
  712. "display_name": "Python 3 (ipykernel)",
  713. "language": "python",
  714. "name": "python3"
  715. },
  716. "language_info": {
  717. "codemirror_mode": {
  718. "name": "ipython",
  719. "version": 3
  720. },
  721. "file_extension": ".py",
  722. "mimetype": "text/x-python",
  723. "name": "python",
  724. "nbconvert_exporter": "python",
  725. "pygments_lexer": "ipython3",
  726. "version": "3.9.7"
  727. }
  728. },
  729. "nbformat": 4,
  730. "nbformat_minor": 5
  731. }