Browse Source

codes uploaded

master
Saeed Dastani 1 year ago
parent
commit
fffc3bba39
4 changed files with 9944 additions and 0 deletions
  1. 934
    0
      3d cnn.ipynb
  2. 2988
    0
      CNN-RNN.ipynb
  3. 731
    0
      MFCC.ipynb
  4. 5291
    0
      captum.ipynb

+ 934
- 0
3d cnn.ipynb View File

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "9f087356",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\saeed\\Desktop\\Master\\bci\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.model_selection import KFold, StratifiedKFold\n",
"import librosa\n",
"import librosa.display\n",
"import IPython.display as ipd\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import scipy.io\n",
"from tqdm import tqdm\n",
"import glob\n",
"import os\n",
"import json\n",
"import pickle\n",
"from einops import rearrange\n",
"from captum.attr import DeepLift\n",
"from captum.attr import visualization as viz"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "ba4bf52c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1913, 62, 20, 11)\n"
]
}
],
"source": [
"with open(\"data/normal_all_data.pkl\", \"rb\") as f:\n",
" all_data = pickle.load(f)\n",
"with open(\"data/all_label.pkl\", \"rb\") as f:\n",
" labels = pickle.load(f)\n",
"with open(\"data/vowel_label.pkl\", \"rb\") as f:\n",
" vowel_label = pickle.load(f)\n",
"with open(\"data/bilab_label.pkl\", \"rb\") as f:\n",
" bilab_label = pickle.load(f)\n",
"with open(\"data/nasal_label.pkl\", \"rb\") as f:\n",
" nasal_label = pickle.load(f)\n",
"with open(\"data/iy_label.pkl\", \"rb\") as f:\n",
" iy_label = pickle.load(f)\n",
"with open(\"data/uw_label.pkl\", \"rb\") as f:\n",
" uw_label = pickle.load(f)\n",
"\n",
"print(all_data.shape)"
]
},
{
"cell_type": "code",
"execution_count": 220,
"id": "17f8364a",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
" 0%| | 0/1913 [00:00<?, ?it/s]C:\\Users\\saeed\\Desktop\\Master\\bci\\lib\\site-packages\\librosa\\util\\decorators.py:88: UserWarning: n_fft=250 is too small for input signal of length=11\n",
" return f(*args, **kwargs)\n",
" 5%|███▉ | 95/1913 [00:23<07:36, 3.98it/s]\n",
"\n",
"KeyboardInterrupt\n",
"\n"
]
}
],
"source": [
"#calculate MFCCs with windowing\n",
"n_mfcc = 20\n",
"framesize = 1 * 250\n",
"hop_size = int(framesize/2)\n",
"\n",
"trials = []\n",
"for i, trial in enumerate(tqdm(data)):\n",
" channels = []\n",
" for j, channel in enumerate(trial):\n",
" mfccs = librosa.feature.mfcc(y=channel, n_mfcc=n_mfcc, n_fft=framesize, hop_length=hop_size, sr=250)\n",
" channels.append(np.array(mfccs))\n",
" trials.append(np.array(channels)) \n",
"mfc_data = np.array(trials)\n",
"\n",
"print(mfc_data.shape)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2830a90b",
"metadata": {},
"outputs": [],
"source": [
"#save as (windows MFCCs)\n",
"with open('data/11_20mfc.pkl', 'wb') as f:\n",
" pickle.dump(mfc_data, f)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "86e2469b",
"metadata": {},
"outputs": [],
"source": [
"class Dataset():\n",
" def __init__(self, data, label, oversample=True):\n",
" self.data = data\n",
" self.label = label\n",
" self.over = oversample\n",
" self.train = None\n",
" self.val = None\n",
" self.test = None\n",
" \n",
" def picturize(self):\n",
" trials = []\n",
" depth = self.data.shape[2]\n",
" for trial in self.data:\n",
" pic = np.zeros((7,9,depth,11))\n",
" pic[0,2] = trial[3]\n",
" pic[0,3] = trial[0]\n",
" pic[0,4] = trial[1]\n",
" pic[0,5] = trial[2]\n",
" pic[0,6] = trial[4]\n",
" pic[1,:] = trial[5:14]\n",
" pic[2,:] = trial[14:23]\n",
" pic[3,:] = trial[23:32]\n",
" pic[4,:] = trial[32:41]\n",
" pic[5,:] = trial[41:50]\n",
" pic[6,0] = trial[50]\n",
" pic[6,1] = trial[51]\n",
" pic[6,2] = trial[52]\n",
" pic[6,3] = trial[58]\n",
" pic[6,4] = trial[53]\n",
" pic[6,5] = trial[60]\n",
" pic[6,6] = trial[54]\n",
" pic[6,7] = trial[55]\n",
" pic[6,8] = trial[56]\n",
" trials.append(pic)\n",
" self.data = np.array(trials)\n",
" return self.data\n",
" \n",
" def split(self, train_idx, test_idx, val_size=0.1, norm=False):\n",
" train_val_data = np.stack([self.data[index] for index in train_idx])\n",
" train_val_label = [self.label[index] for index in train_idx]\n",
" test_data = np.stack([self.data[index] for index in test_idx])\n",
" test_label = [self.label[index] for index in test_idx]\n",
" \n",
" if norm:\n",
" Max = np.max(train_val_data, axis=(0,1,2,4), keepdims=True)\n",
" Min = np.min(train_val_data, axis=(0,1,2,4), keepdims=True)\n",
" train_val_data = (train_val_data-Min)/(Max-Min)\n",
"\n",
" Max_test = np.max(test_data, axis=(0,1,2,4), keepdims=True)\n",
" Min_test = np.min(test_data, axis=(0,1,2,4), keepdims=True)\n",
" test_data = (test_data-Min)/(Max-Min)\n",
" \n",
" train_val = [[train_val_data[i], train_val_label[i]] for i in range(len(train_val_data))]\n",
" self.test = [[test_data[i], test_label[i]] for i in range(len(test_data))]\n",
" \n",
" num_train_val = len(train_val)\n",
" indices = list(range(num_train_val))\n",
" np.random.shuffle(indices)\n",
" split = int(np.floor(val_size*num_train_val))\n",
" train, val = [train_val[i] for i in indices[split:]] ,[train_val[i] for i in indices[:split]]\n",
" \n",
" if self.over:\n",
" train_labels = [data[1] for data in train]\n",
" _, counts = np.unique(train_labels, return_counts=True)\n",
" print(counts)\n",
" if counts[1]>counts[0]:\n",
" label0 = [data for data in train if data[1]==0]\n",
" coef = int(counts[1]/counts[0])\n",
" for i in range(coef):\n",
" train = train + label0\n",
" elif counts[1]<counts[0]:\n",
" label1 = [data for data in train if data[1]==1]\n",
" coef = int(counts[0]/counts[1])\n",
" for i in range(coef):\n",
" train = train + label1\n",
" self.train = train\n",
" self.val = val\n",
" \n",
" return self.train, self.val, self.test\n",
" \n",
" \n",
" def show(self):\n",
" print('data shape = ', self.data.shape)\n",
" \n",
" if self.train is None:\n",
" print('train not creaeted!')\n",
" else:\n",
" print('train shape = ', len(self.train))\n",
" \n",
" if self.val is None:\n",
" print('validation not creaeted!')\n",
" else:\n",
" print('validation shape = ', len(self.val))\n",
" \n",
" if self.test is None:\n",
" print('test not creaeted!')\n",
" else:\n",
" print('test shape = ', len(self.test))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4fb6541e",
"metadata": {},
"outputs": [],
"source": [
"def train_model(train_loader, val_loader, epochs, lr, fold, steps):\n",
" print('creating model...')\n",
" model = cnn3d().float()\n",
" optimizer = torch.optim.Adam(model.parameters(), lr=lr)\n",
" criterion = nn.BCELoss()\n",
"\n",
" scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=steps, max_lr=lr*10)\n",
" scheduler1 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)\n",
" l1_lambda = 0.0001\n",
" \n",
" min_val_loss = np.inf\n",
" max_val_acc = 0\n",
" for epoch in range(epochs):\n",
" print('epoch: ', epoch+1)\n",
" train_loss = 0\n",
" train_correct = 0\n",
" model.train()\n",
" for iteration, (data,label) in enumerate(train_loader):\n",
" optimizer.zero_grad()\n",
" output = model(data.float())\n",
" label = torch.reshape(label, (-1,1))\n",
" label = label.float()\n",
" loss = criterion(output, label)\n",
" for W in model.parameters():\n",
" loss = loss + l1_lambda*W.norm(1)\n",
" loss.backward()\n",
" optimizer.step()\n",
" scheduler.step()\n",
" targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
" #print([output[i].round().item() for i in range(len(label))])\n",
" train_correct += sum(targets)\n",
" train_loss += loss.item()*data.shape[0]\n",
" #scheduler1.step() \n",
" train_acc = train_correct/len(train_loader.sampler) \n",
" train_loss = train_loss/len(train_loader.sampler)\n",
" \n",
" val_loss = 0\n",
" val_correct = 0\n",
" model.eval()\n",
" for data, label in val_loader:\n",
" output = model(data.float())\n",
" label = torch.reshape(label, (-1,1))\n",
" label = label.float()\n",
" loss = criterion(output, label) \n",
" val_loss += loss.item()*data.shape[0]\n",
" targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
" val_correct += sum(targets)\n",
" \n",
" val_loss = val_loss/len(val_loader.sampler)\n",
" val_acc = val_correct/len(val_loader.sampler)\n",
" if val_loss <= min_val_loss:\n",
" print(\"validation loss decreased ({:.6f} ---> {:.6f}), val_acc = {}\".format(min_val_loss, val_loss, val_acc))\n",
" torch.save(model.state_dict(), 'train/model'+str(fold)+'.pt')\n",
" min_val_loss = val_loss\n",
" torch.save(model.state_dict(), 'train/last_model'+str(fold)+'.pt') \n",
" print('epoch {}: train loss = {}, train acc = {},\\nval_loss = {}, val_acc = {}\\n'\n",
" .format(epoch+1, train_loss, train_acc, val_loss, val_acc))\n",
" \n",
" if int(train_acc)==1:\n",
" print('!!! overfitted !!!')\n",
" break\n",
" model.train()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "7d61fc48",
"metadata": {},
"outputs": [],
"source": [
"def evaluate_model(test_loader, fold):\n",
" model =cnn3d().float()\n",
" model.load_state_dict(torch.load('train/model'+str(fold)+'.pt'))\n",
" \n",
" n_correct = 0\n",
" model.eval()\n",
" for data, label in test_loader:\n",
" output = model(data.float())\n",
" targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
" print(targets)\n",
" n_correct += sum(targets) \n",
" test_accs = n_correct/len(test_loader.sampler)\n",
" print('early stoping results:\\n\\t', test_accs)\n",
" \n",
" n_correct = 0\n",
" model.eval()\n",
" for data, label in train_loader:\n",
" output = model(data.float())\n",
" targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
" n_correct += sum(targets)\n",
" \n",
" train_accs = n_correct/len(train_loader.sampler)\n",
" print('\\t', train_accs)\n",
" \n",
" model = cnn3d().float()\n",
" model.load_state_dict(torch.load('train/last_model'+str(fold)+'.pt'))\n",
" \n",
" n_correct = 0\n",
" model.eval()\n",
" for data, label in test_loader:\n",
" output = model(data.float())\n",
" targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
" print(targets)\n",
" n_correct += sum(targets)\n",
" test_accs_over = n_correct/len(test_loader.sampler)\n",
" print('full train results:\\n\\t', test_accs_over)\n",
" \n",
" n_correct = 0\n",
" model.eval()\n",
" for data, label in train_loader:\n",
" output = model(data.float())\n",
" targets = [1 if output[i].round()==label[i] else 0 for i in range(len(label))]\n",
" n_correct += sum(targets)\n",
" train_accs_over = n_correct/len(train_loader.sampler)\n",
" print('\\t', train_accs_over)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "402f6c76",
"metadata": {},
"outputs": [],
"source": [
"def calculate_steps(train_loader, epochs):\n",
" steps = 0\n",
" for epoch in range(epochs):\n",
" for data, label in train_loader:\n",
" steps += 1\n",
" return steps"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "076c9c78",
"metadata": {},
"outputs": [],
"source": [
"class cnn3d(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv3d(20, 16, kernel_size=(3, 3, 3), padding=1)\n",
" self.conv2 = nn.Conv3d(16, 32, kernel_size=(3, 3, 3), padding=0)\n",
" self.pool = nn.MaxPool3d((2, 2, 2), stride=2)\n",
" self.fc1 = nn.Linear(192, 128)\n",
" self.fc2 = nn.Linear(128, 1)\n",
" self.drop = nn.Dropout(0.25)\n",
" self.batch1 = nn.BatchNorm3d(16)\n",
" self.batch2 = nn.BatchNorm3d(32)\n",
" self.batch3 = nn.BatchNorm1d(128)\n",
" \n",
" def forward(self, x):\n",
" x = rearrange(x, 'n h w m t -> n m t h w')\n",
" out = self.pool(F.relu(self.batch1(self.conv1(x))))\n",
" out = F.relu(self.batch2(self.conv2(out)))\n",
" out = out.view(out.size(0), -1)\n",
" out = self.drop(F.relu(self.batch3(self.fc1(out))))\n",
" out = F.sigmoid(self.fc2(out))\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "1b859362",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0.6052],\n",
" [0.5052],\n",
" [0.2035],\n",
" [0.6347]], grad_fn=<SigmoidBackward0>)"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#test model\n",
"model = cnn3d()\n",
"sample = torch.rand((4,7,9,20,11))\n",
"model(sample)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "7c72830d",
"metadata": {},
"outputs": [],
"source": [
"#congig\n",
"\n",
"val_size = 0.25\n",
"n_epochs = 100\n",
"batch_size = 128\n",
"print_every = 10\n",
"lr = 0.00001\n",
"k = 10\n",
"skf=StratifiedKFold(n_splits=k, shuffle=True, random_state=32)"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "96df8fcb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1913, 62, 17, 11)\n"
]
}
],
"source": [
"print(all_data[:,:,3:,:].shape)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "9eb46d39",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"------------fold 0-----------\n",
"[ 291 1258]\n",
"0.4636933284187247\n",
"data shape = (1913, 7, 9, 20, 11)\n",
"train shape = 2713\n",
"validation shape = 172\n",
"test shape = 192\n",
"calculating total steps...\n",
"creating model...\n",
"epoch: 1\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\saeed\\Desktop\\Master\\bci\\lib\\site-packages\\torch\\nn\\functional.py:1960: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n",
" warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"validation loss decreased (inf ---> 0.679123), val_acc = 0.7093023255813954\n",
"epoch 1: train loss = 0.8683970651853318, train acc = 0.5208256542572798,\n",
"val_loss = 0.6791226988614991, val_acc = 0.7093023255813954\n",
"\n",
"epoch: 2\n",
"epoch 2: train loss = 0.8702899437086019, train acc = 0.524511610762993,\n",
"val_loss = 0.7166987934777903, val_acc = 0.4186046511627907\n",
"\n",
"epoch: 3\n",
"epoch 3: train loss = 0.8651134706141103, train acc = 0.5344636933284187,\n",
"val_loss = 0.7987130630848019, val_acc = 0.3430232558139535\n",
"\n",
"epoch: 4\n",
"epoch 4: train loss = 0.8593147574778724, train acc = 0.5407298193881312,\n",
"val_loss = 0.8141745814057284, val_acc = 0.3023255813953488\n",
"\n",
"epoch: 5\n",
"epoch 5: train loss = 0.8594019789780104, train acc = 0.539255436785846,\n",
"val_loss = 0.81581727987112, val_acc = 0.3081395348837209\n",
"\n",
"epoch: 6\n",
"epoch 6: train loss = 0.855004380604371, train acc = 0.5477331367489864,\n",
"val_loss = 0.8083413853201755, val_acc = 0.3372093023255814\n",
"\n",
"epoch: 7\n",
"epoch 7: train loss = 0.8482353424744724, train acc = 0.5547364541098415,\n",
"val_loss = 0.7954940158267354, val_acc = 0.3546511627906977\n",
"\n",
"epoch: 8\n",
"epoch 8: train loss = 0.8403365777464226, train acc = 0.5772207887946922,\n",
"val_loss = 0.7943125145379887, val_acc = 0.3430232558139535\n",
"\n",
"epoch: 9\n",
"epoch 9: train loss = 0.8323755666772862, train acc = 0.5794323626981202,\n",
"val_loss = 0.7893233784409457, val_acc = 0.3546511627906977\n",
"\n",
"epoch: 10\n",
"epoch 10: train loss = 0.8240972711202451, train acc = 0.5927018061186878,\n",
"val_loss = 0.7756274495013925, val_acc = 0.37209302325581395\n",
"\n",
"epoch: 11\n",
"epoch 11: train loss = 0.8123821243341673, train acc = 0.618134906008109,\n",
"val_loss = 0.7595333221346833, val_acc = 0.38953488372093026\n",
"\n",
"epoch: 12\n",
"epoch 12: train loss = 0.8052803598511847, train acc = 0.6376704754883893,\n",
"val_loss = 0.7610093549240468, val_acc = 0.38372093023255816\n",
"\n",
"epoch: 13\n",
"epoch 13: train loss = 0.794128361483847, train acc = 0.6439366015481017,\n",
"val_loss = 0.746018763198409, val_acc = 0.43023255813953487\n",
"\n",
"epoch: 14\n",
"epoch 14: train loss = 0.7812480943869982, train acc = 0.6804275709546628,\n",
"val_loss = 0.7397682084593662, val_acc = 0.42441860465116277\n",
"\n",
"epoch: 15\n",
"epoch 15: train loss = 0.770653328441229, train acc = 0.6859565057132326,\n",
"val_loss = 0.7616525181504183, val_acc = 0.38953488372093026\n",
"\n",
"epoch: 16\n",
"epoch 16: train loss = 0.7537735868897598, train acc = 0.7150755621083671,\n",
"val_loss = 0.692143508168154, val_acc = 0.5348837209302325\n",
"\n",
"epoch: 17\n",
"epoch 17: train loss = 0.7378950803604154, train acc = 0.7412458532989311,\n",
"val_loss = 0.7284444473510565, val_acc = 0.43023255813953487\n",
"\n",
"epoch: 18\n",
"validation loss decreased (0.679123 ---> 0.677328), val_acc = 0.5872093023255814\n",
"epoch 18: train loss = 0.7233544636642718, train acc = 0.7640987836343531,\n",
"val_loss = 0.6773282483566639, val_acc = 0.5872093023255814\n",
"\n",
"epoch: 19\n",
"epoch 19: train loss = 0.699734561411692, train acc = 0.7876889052709178,\n",
"val_loss = 0.6925027619960696, val_acc = 0.5523255813953488\n",
"\n",
"epoch: 20\n",
"validation loss decreased (0.677328 ---> 0.633773), val_acc = 0.6686046511627907\n",
"epoch 20: train loss = 0.6822101758539655, train acc = 0.8112790269074824,\n",
"val_loss = 0.6337725229041521, val_acc = 0.6686046511627907\n",
"\n",
"epoch: 21\n",
"epoch 21: train loss = 0.6548226526905767, train acc = 0.832288978990048,\n",
"val_loss = 0.7115770703138307, val_acc = 0.5174418604651163\n",
"\n",
"epoch: 22\n",
"epoch 22: train loss = 0.6249409928736632, train acc = 0.8698857353483229,\n",
"val_loss = 0.643924496894659, val_acc = 0.5930232558139535\n",
"\n",
"epoch: 23\n",
"epoch 23: train loss = 0.5943754564516227, train acc = 0.8772576483597494,\n",
"val_loss = 0.6660033076308495, val_acc = 0.5697674418604651\n",
"\n",
"epoch: 24\n",
"validation loss decreased (0.633773 ---> 0.579030), val_acc = 0.7151162790697675\n",
"epoch 24: train loss = 0.5627755398817094, train acc = 0.9012163656468853,\n",
"val_loss = 0.5790298899938894, val_acc = 0.7151162790697675\n",
"\n",
"epoch: 25\n",
"epoch 25: train loss = 0.5276576605090961, train acc = 0.92480648728345,\n",
"val_loss = 0.5794833940128947, val_acc = 0.7151162790697675\n",
"\n",
"epoch: 26\n",
"epoch 26: train loss = 0.4913847099425298, train acc = 0.9343899741983045,\n",
"val_loss = 0.6454916582551113, val_acc = 0.622093023255814\n",
"\n",
"epoch: 27\n",
"validation loss decreased (0.579030 ---> 0.547536), val_acc = 0.7441860465116279\n",
"epoch 27: train loss = 0.45694537973254534, train acc = 0.9502395871728714,\n",
"val_loss = 0.547536434129227, val_acc = 0.7441860465116279\n",
"\n",
"epoch: 28\n",
"validation loss decreased (0.547536 ---> 0.514732), val_acc = 0.8023255813953488\n",
"epoch 28: train loss = 0.42175784723546905, train acc = 0.9638776262440103,\n",
"val_loss = 0.5147316435048747, val_acc = 0.8023255813953488\n",
"\n",
"epoch: 29\n",
"epoch 29: train loss = 0.39061159156935704, train acc = 0.9712495392554368,\n",
"val_loss = 0.535475094651067, val_acc = 0.7616279069767442\n",
"\n",
"epoch: 30\n",
"validation loss decreased (0.514732 ---> 0.508574), val_acc = 0.8081395348837209\n",
"epoch 30: train loss = 0.36516626786139894, train acc = 0.9764098783634353,\n",
"val_loss = 0.5085743651833645, val_acc = 0.8081395348837209\n",
"\n",
"epoch: 31\n",
"epoch 31: train loss = 0.33421938774377447, train acc = 0.9826760044231478,\n",
"val_loss = 0.5448793408482574, val_acc = 0.7616279069767442\n",
"\n",
"epoch: 32\n",
"epoch 32: train loss = 0.3104572842646647, train acc = 0.9900479174345743,\n",
"val_loss = 0.5426593824874523, val_acc = 0.7965116279069767\n",
"\n",
"epoch: 33\n",
"epoch 33: train loss = 0.28988734955415046, train acc = 0.9904165130851456,\n",
"val_loss = 0.5494307279586792, val_acc = 0.7848837209302325\n",
"\n",
"epoch: 34\n",
"epoch 34: train loss = 0.26993356656641526, train acc = 0.9941024695908588,\n",
"val_loss = 0.5200380724529887, val_acc = 0.813953488372093\n",
"\n",
"epoch: 35\n",
"epoch 35: train loss = 0.2577803114266401, train acc = 0.9959454478437154,\n",
"val_loss = 0.5192096496737281, val_acc = 0.8197674418604651\n",
"\n",
"epoch: 36\n",
"epoch 36: train loss = 0.24427844815222574, train acc = 0.997788426096572,\n",
"val_loss = 0.5306130952613298, val_acc = 0.8197674418604651\n",
"\n",
"epoch: 37\n",
"validation loss decreased (0.508574 ---> 0.504719), val_acc = 0.8372093023255814\n",
"epoch 37: train loss = 0.23392101167160517, train acc = 0.9981570217471434,\n",
"val_loss = 0.504719233097032, val_acc = 0.8372093023255814\n",
"\n",
"epoch: 38\n",
"epoch 38: train loss = 0.22321982001230103, train acc = 0.9992628086988573,\n",
"val_loss = 0.5188674448534499, val_acc = 0.8313953488372093\n",
"\n",
"epoch: 39\n",
"epoch 39: train loss = 0.2165969933942372, train acc = 0.9996314043494287,\n",
"val_loss = 0.5341297630653825, val_acc = 0.8372093023255814\n",
"\n",
"epoch: 40\n",
"epoch 40: train loss = 0.2110399665799125, train acc = 1.0,\n",
"val_loss = 0.5399310172990311, val_acc = 0.8197674418604651\n",
"\n",
"!!! overfitted !!!\n",
"[0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1]\n",
"[1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0]\n",
"early stoping results:\n",
"\t 0.7552083333333334\n",
"\t 0.9996314043494287\n",
"[1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1]\n",
"[1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1]\n",
"full train results:\n",
"\t 0.7395833333333334\n",
"\t 1.0\n",
"------------fold 1-----------\n",
"[ 283 1266]\n",
"0.47221186124580383\n",
"data shape = (1913, 7, 9, 20, 11)\n",
"train shape = 2681\n",
"validation shape = 172\n",
"test shape = 192\n",
"calculating total steps...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"creating model...\n",
"epoch: 1\n",
"validation loss decreased (inf ---> 0.646471), val_acc = 0.8255813953488372\n",
"epoch 1: train loss = 0.86992137868295, train acc = 0.5330100708690787,\n",
"val_loss = 0.6464709279149078, val_acc = 0.8255813953488372\n",
"\n",
"epoch: 2\n",
"epoch 2: train loss = 0.8734584440939553, train acc = 0.505408429690414,\n",
"val_loss = 0.6706514441689779, val_acc = 0.6046511627906976\n",
"\n",
"epoch: 3\n",
"epoch 3: train loss = 0.8688026436315193, train acc = 0.5210742260350616,\n",
"val_loss = 0.6759761059006979, val_acc = 0.5697674418604651\n",
"\n",
"epoch: 4\n",
"epoch 4: train loss = 0.8661403424790527, train acc = 0.5203282357329355,\n",
"val_loss = 0.6691561743270519, val_acc = 0.5872093023255814\n",
"\n",
"epoch: 5\n",
"epoch 5: train loss = 0.8571998774227212, train acc = 0.5389779932860873,\n",
"val_loss = 0.6676994409672049, val_acc = 0.5872093023255814\n",
"\n",
"epoch: 6\n",
"epoch 6: train loss = 0.8471142522704463, train acc = 0.5479298769116001,\n",
"val_loss = 0.6665740276491919, val_acc = 0.5872093023255814\n",
"\n",
"epoch: 7\n",
"epoch 7: train loss = 0.841995612025839, train acc = 0.5557627750839239,\n",
"val_loss = 0.6650435148283492, val_acc = 0.6162790697674418\n",
"\n",
"epoch: 8\n",
"epoch 8: train loss = 0.8332126938065056, train acc = 0.5747855277881387,\n",
"val_loss = 0.6654224465059679, val_acc = 0.6162790697674418\n",
"\n",
"epoch: 9\n",
"epoch 9: train loss = 0.828857630791392, train acc = 0.5863483774710929,\n",
"val_loss = 0.6667825177658436, val_acc = 0.6046511627906976\n",
"\n",
"epoch: 10\n",
"epoch 10: train loss = 0.8170930889727953, train acc = 0.6094740768370012,\n",
"val_loss = 0.662314846072086, val_acc = 0.6104651162790697\n",
"\n",
"epoch: 11\n",
"epoch 11: train loss = 0.8065087129401876, train acc = 0.6258858634837747,\n",
"val_loss = 0.6654108665710272, val_acc = 0.5755813953488372\n",
"\n",
"epoch: 12\n",
"epoch 12: train loss = 0.7962307475787942, train acc = 0.6478925773964939,\n",
"val_loss = 0.6616917931756308, val_acc = 0.5813953488372093\n",
"\n",
"epoch: 13\n",
"epoch 13: train loss = 0.776722994525679, train acc = 0.6833271167474823,\n",
"val_loss = 0.6591206082077914, val_acc = 0.5872093023255814\n",
"\n",
"epoch: 14\n",
"epoch 14: train loss = 0.7678748263836085, train acc = 0.701603879149571,\n",
"val_loss = 0.6578961167224618, val_acc = 0.6104651162790697\n",
"\n",
"epoch: 15\n",
"epoch 15: train loss = 0.7510278017295914, train acc = 0.7135397239835882,\n",
"val_loss = 0.6544989874196607, val_acc = 0.5988372093023255\n",
"\n",
"epoch: 16\n",
"validation loss decreased (0.646471 ---> 0.642755), val_acc = 0.6104651162790697\n",
"epoch 16: train loss = 0.7301953062776397, train acc = 0.7601641178664678,\n",
"val_loss = 0.6427546046500983, val_acc = 0.6104651162790697\n",
"\n",
"epoch: 17\n",
"validation loss decreased (0.642755 ---> 0.635279), val_acc = 0.6627906976744186\n",
"epoch 17: train loss = 0.7135588408227911, train acc = 0.7859007832898173,\n",
"val_loss = 0.6352790053500685, val_acc = 0.6627906976744186\n",
"\n",
"epoch: 18\n",
"validation loss decreased (0.635279 ---> 0.631912), val_acc = 0.6395348837209303\n",
"epoch 18: train loss = 0.6874780969493471, train acc = 0.8157403953748601,\n",
"val_loss = 0.6319117740143178, val_acc = 0.6395348837209303\n",
"\n",
"epoch: 19\n",
"validation loss decreased (0.631912 ---> 0.623324), val_acc = 0.6627906976744186\n",
"epoch 19: train loss = 0.6639259783290563, train acc = 0.8489369638194704,\n",
"val_loss = 0.6233235237210296, val_acc = 0.6627906976744186\n",
"\n",
"epoch: 20\n",
"validation loss decreased (0.623324 ---> 0.621211), val_acc = 0.6686046511627907\n",
"epoch 20: train loss = 0.6353608048793674, train acc = 0.8776575904513242,\n",
"val_loss = 0.6212111226347989, val_acc = 0.6686046511627907\n",
"\n",
"epoch: 21\n",
"validation loss decreased (0.621211 ---> 0.620210), val_acc = 0.6627906976744186\n",
"epoch 21: train loss = 0.6062298955778629, train acc = 0.8955613577023499,\n",
"val_loss = 0.6202103772828745, val_acc = 0.6627906976744186\n",
"\n",
"epoch: 22\n",
"validation loss decreased (0.620210 ---> 0.578789), val_acc = 0.7267441860465116\n",
"epoch 22: train loss = 0.5769774110314563, train acc = 0.9089891831406192,\n",
"val_loss = 0.5787890664366788, val_acc = 0.7267441860465116\n",
"\n",
"epoch: 23\n",
"epoch 23: train loss = 0.5446096906151535, train acc = 0.9365908243192839,\n",
"val_loss = 0.58748683818551, val_acc = 0.7267441860465116\n",
"\n",
"epoch: 24\n",
"epoch 24: train loss = 0.5065792173602962, train acc = 0.9470346885490488,\n",
"val_loss = 0.6247755801954935, val_acc = 0.6395348837209303\n",
"\n",
"epoch: 25\n",
"validation loss decreased (0.578789 ---> 0.532145), val_acc = 0.7790697674418605\n",
"epoch 25: train loss = 0.47399600875639997, train acc = 0.9593435285341291,\n",
"val_loss = 0.5321454031522884, val_acc = 0.7790697674418605\n",
"\n",
"epoch: 26\n",
"epoch 26: train loss = 0.4397122574603989, train acc = 0.9668034315553897,\n",
"val_loss = 0.5539705656295599, val_acc = 0.7674418604651163\n",
"\n",
"epoch: 27\n",
"epoch 27: train loss = 0.4094221261689486, train acc = 0.9701603879149571,\n",
"val_loss = 0.6161841694698778, val_acc = 0.686046511627907\n",
"\n",
"epoch: 28\n",
"epoch 28: train loss = 0.3771897173559252, train acc = 0.9802312569936591,\n",
"val_loss = 0.5943776341371758, val_acc = 0.7383720930232558\n",
"\n",
"epoch: 29\n",
"epoch 29: train loss = 0.35380360624186363, train acc = 0.9832152182021634,\n",
"val_loss = 0.5820831060409546, val_acc = 0.7383720930232558\n",
"\n",
"epoch: 30\n",
"epoch 30: train loss = 0.3256017004140764, train acc = 0.9891831406191719,\n",
"val_loss = 0.642474444799645, val_acc = 0.7034883720930233\n",
"\n",
"epoch: 31\n",
"epoch 31: train loss = 0.3058795078238281, train acc = 0.990302126072361,\n",
"val_loss = 0.5841574142145556, val_acc = 0.7732558139534884\n",
"\n",
"epoch: 32\n",
"epoch 32: train loss = 0.28646321554141274, train acc = 0.9962700484893696,\n",
"val_loss = 0.6434207239816355, val_acc = 0.7325581395348837\n",
"\n",
"epoch: 33\n",
"epoch 33: train loss = 0.26851461289178874, train acc = 0.9947780678851175,\n",
"val_loss = 0.6213336531506028, val_acc = 0.7558139534883721\n",
"\n",
"epoch: 34\n",
"epoch 34: train loss = 0.2543404344545732, train acc = 0.9977620290936218,\n",
"val_loss = 0.6432186337404473, val_acc = 0.75\n",
"\n",
"epoch: 35\n",
"epoch 35: train loss = 0.24356636835740866, train acc = 0.9970160387914957,\n",
"val_loss = 0.6490278299464736, val_acc = 0.7732558139534884\n",
"\n",
"epoch: 36\n",
"epoch 36: train loss = 0.2339746813105363, train acc = 0.999627004848937,\n",
"val_loss = 0.6225385125293288, val_acc = 0.7558139534883721\n",
"\n",
"epoch: 37\n",
"epoch 37: train loss = 0.22574142918285262, train acc = 1.0,\n",
"val_loss = 0.6539836512055508, val_acc = 0.7616279069767442\n",
"\n",
"!!! overfitted !!!\n",
"[1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0]\n",
"[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1]\n",
"early stoping results:\n",
"\t 0.7760416666666666\n",
"\t 0.9671764267064528\n",
"[1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 1]\n",
"[1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1]\n",
"full train results:\n",
"\t 0.7604166666666666\n",
"\t 0.999627004848937\n",
"------------fold 2-----------\n",
"[ 283 1266]\n",
"0.47221186124580383\n",
"data shape = (1913, 7, 9, 20, 11)\n",
"train shape = 2681\n",
"validation shape = 172\n",
"test shape = 192\n",
"calculating total steps...\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Input \u001b[1;32mIn [47]\u001b[0m, in \u001b[0;36m<cell line: 4>\u001b[1;34m()\u001b[0m\n\u001b[0;32m 12\u001b[0m test_loader \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mutils\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39mDataLoader(test, batch_size\u001b[38;5;241m=\u001b[39mbatch_size, shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n\u001b[0;32m 14\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mcalculating total steps...\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m---> 15\u001b[0m steps \u001b[38;5;241m=\u001b[39m \u001b[43mcalculate_steps\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_epochs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 16\u001b[0m train_model(train_loader, val_loader, epochs\u001b[38;5;241m=\u001b[39mn_epochs, lr\u001b[38;5;241m=\u001b[39mlr, fold\u001b[38;5;241m=\u001b[39mfold, steps\u001b[38;5;241m=\u001b[39msteps)\n\u001b[0;32m 17\u001b[0m evaluate_model(test_loader, fold\u001b[38;5;241m=\u001b[39mfold)\n",
"Input \u001b[1;32mIn [11]\u001b[0m, in \u001b[0;36mcalculate_steps\u001b[1;34m(train_loader, epochs)\u001b[0m\n\u001b[0;32m 2\u001b[0m steps \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[0;32m 3\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m epoch \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(epochs):\n\u001b[1;32m----> 4\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m data, label \u001b[38;5;129;01min\u001b[39;00m train_loader:\n\u001b[0;32m 5\u001b[0m steps \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m 6\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m steps\n",
"File \u001b[1;32m~\\Desktop\\Master\\bci\\lib\\site-packages\\torch\\utils\\data\\dataloader.py:677\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m 676\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__next__\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m--> 677\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprofiler\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrecord_function\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_profile_name\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[0;32m 678\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 679\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[0;32m 680\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n",
"File \u001b[1;32m~\\Desktop\\Master\\bci\\lib\\site-packages\\torch\\autograd\\profiler.py:443\u001b[0m, in \u001b[0;36mrecord_function.__init__\u001b[1;34m(self, name, args)\u001b[0m\n\u001b[0;32m 440\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun_callbacks_on_exit: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m 441\u001b[0m \u001b[38;5;66;03m# Stores underlying RecordFunction as a tensor. TODO: move to custom\u001b[39;00m\n\u001b[0;32m 442\u001b[0m \u001b[38;5;66;03m# class (https://github.com/pytorch/pytorch/issues/35026).\u001b[39;00m\n\u001b[1;32m--> 443\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mhandle: torch\u001b[38;5;241m.\u001b[39mTensor \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mzeros\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"dataset = Dataset(all_data, vowel_label)\n",
"data = dataset.picturize()\n",
"\n",
"for fold, (train_idx, test_idx) in enumerate(skf.split(data, labels)):\n",
" print('------------fold {}-----------'.format(fold))\n",
" train, val, test = dataset.split(train_idx, test_idx)\n",
" train_label = [item[1] for item in train]\n",
" print(sum(train_label)/len(train_label))\n",
" dataset.show()\n",
" train_loader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True)\n",
" val_loader = torch.utils.data.DataLoader(val, batch_size=batch_size, shuffle=True)\n",
" test_loader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=True)\n",
" \n",
" print('calculating total steps...')\n",
" steps = calculate_steps(train_loader, n_epochs)\n",
" train_model(train_loader, val_loader, epochs=n_epochs, lr=lr, fold=fold, steps=steps)\n",
" evaluate_model(test_loader, fold=fold)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d614e0a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

+ 2988
- 0
CNN-RNN.ipynb
File diff suppressed because it is too large
View File


+ 731
- 0
MFCC.ipynb
File diff suppressed because it is too large
View File


+ 5291
- 0
captum.ipynb
File diff suppressed because it is too large
View File


Loading…
Cancel
Save