diff --git a/models/internal_priming.pth b/models/internal_priming.pth index 7eb6ff99f2a924ff4bc1a60a12fc26199d0b0e09..eeab5fda317b000d341c6d4f6badf410ba6e96f0 100644 Binary files a/models/internal_priming.pth and b/models/internal_priming.pth differ diff --git a/notebooks/internal_priming.ipynb b/notebooks/internal_priming.ipynb index 55fd43adb98a8c02168d59027944348527e9e485..d1c274f997859f2959cf0e04a2284ea5ced0a835 100644 --- a/notebooks/internal_priming.ipynb +++ b/notebooks/internal_priming.ipynb @@ -22,32 +22,13 @@ }, { "cell_type": "code", -<<<<<<< HEAD -<<<<<<< HEAD - "execution_count": 80, -======= - "execution_count": null, ->>>>>>> d2ef840 (chore: started cnn notebook) -======= - "execution_count": 80, ->>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df + "execution_count": 14, "outputs": [], "source": [ "# importing the libraries\n", "import pandas as pd\n", "import numpy as np\n", -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD - "import matplotlib.pyplot as plt\n", -======= ->>>>>>> d2ef840 (chore: started cnn notebook) -======= "import matplotlib.pyplot as plt\n", ->>>>>>> 93ea318 (chore: added training function for cnn) -======= - "import matplotlib.pyplot as plt\n", ->>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df "\n", "# for creating validation set\n", "from sklearn.model_selection import train_test_split\n", @@ -59,18 +40,9 @@ "# PyTorch libraries and modules\n", "import torch\n", "from torch.autograd import Variable\n", -<<<<<<< HEAD -<<<<<<< HEAD "from torch.nn import Linear, ReLU, CrossEntropyLoss, Sequential, MaxPool1d, Module, Softmax, BatchNorm1d, Dropout, Conv1d\n", - "from torch.optim import Adam\n", -======= - "from torch.nn import Linear, ReLU, CrossEntropyLoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout\n", "from torch.optim import Adam, SGD\n", ->>>>>>> d2ef840 (chore: started cnn notebook) -======= - "from torch.nn import Linear, ReLU, CrossEntropyLoss, Sequential, MaxPool1d, Module, Softmax, BatchNorm1d, Dropout, Conv1d\n", - "from torch.optim import Adam\n", ->>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df + "from torchsummary import summary\n", "\n", "\n", "# adding the nn\n", @@ -79,10 +51,6 @@ " super(Net, self).__init__()\n", "\n", " self.cnn_layers = Sequential(\n", -<<<<<<< HEAD -<<<<<<< HEAD -======= ->>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df " # Defining a 1D convolution layer\n", " Conv1d(1, 4, kernel_size=3, stride=1, padding=1),\n", " BatchNorm1d(4),\n", @@ -97,25 +65,6 @@ "\n", " self.linear_layers = Sequential(\n", " Linear(4 * 50, 10)\n", -<<<<<<< HEAD -======= - " # Defining a 2D convolution layer\n", - " Conv2d(1, 4, kernel_size=3, stride=1, padding=1),\n", - " BatchNorm2d(4),\n", - " ReLU(inplace=True),\n", - " MaxPool2d(kernel_size=2, stride=2),\n", - " # Defining another 2D convolution layer\n", - " Conv2d(4, 4, kernel_size=3, stride=1, padding=1),\n", - " BatchNorm2d(4),\n", - " ReLU(inplace=True),\n", - " MaxPool2d(kernel_size=2, stride=2),\n", - " )\n", - "\n", - " self.linear_layers = Sequential(\n", - " Linear(4 * 7 * 7, 10)\n", ->>>>>>> d2ef840 (chore: started cnn notebook) -======= ->>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df " )\n", "\n", " # Defining the forward pass\n", @@ -123,13 +72,6 @@ " x = self.cnn_layers(x)\n", " x = x.view(x.size(0), -1)\n", " x = self.linear_layers(x)\n", -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD -======= ->>>>>>> 93ea318 (chore: added training function for cnn) -======= ->>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df " return x\n", "\n", "# defining training 
function\n", @@ -164,15 +106,6 @@ " tr_loss = loss_train.item()\n", "\n", " return loss_train, loss_val" -<<<<<<< HEAD -<<<<<<< HEAD -======= - " return x" ->>>>>>> d2ef840 (chore: started cnn notebook) -======= ->>>>>>> 93ea318 (chore: added training function for cnn) -======= ->>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df ], "metadata": { "collapsed": false, @@ -195,26 +128,22 @@ }, { "cell_type": "code", -<<<<<<< HEAD -<<<<<<< HEAD -======= ->>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df - "execution_count": 81, + "execution_count": 15, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 20000/20000 [00:00<00:00, 97752.58it/s]\n" + "100%|██████████| 20000/20000 [00:00<00:00, 27099.07it/s]\n" ] } ], "source": [ "enum = {\n", - " 'A': 0.0,\n", - " 'U': 1/3,\n", - " 'G': 2/3,\n", - " 'C': 1.0\n", + " 'A': [1, 0, 0, 0],\n", + " 'U': [0, 1, 0, 0],\n", + " 'G': [0, 0, 1, 0],\n", + " 'C': [0, 0, 0, 1]\n", "}\n", "\n", "# TODO: Get test data from issues 25 and 26\n", @@ -256,7 +185,7 @@ }, { "cell_type": "code", - "execution_count": 82, + "execution_count": 17, "outputs": [], "source": [ "# TODO: reshape shape from [n, l] to [n, 1, l]\n", @@ -270,31 +199,14 @@ "train_shape = train_x.shape\n", "val_shape = val_x.shape\n", "\n", - "train_x = train_x.reshape(train_shape[0], 1, train_shape[1])\n", - "val_x = val_x.reshape(val_shape[0], 1, val_shape[1])\n", + "train_x = train_x.reshape(train_shape[0], 1, train_shape[1], 4)\n", + "val_x = val_x.reshape(val_shape[0], 1, val_shape[1], 4)\n", "\n", "train_x = torch.from_numpy(train_x)\n", "train_y = torch.from_numpy(train_y)\n", "\n", "val_x = torch.from_numpy(val_x)\n", "val_y = torch.from_numpy(val_y)" -<<<<<<< HEAD -======= - "execution_count": null, - "outputs": [], - "source": [ - "# TODO: Get test data from issues 25 and 26\n", - "train_x = []\n", - "train_y = []\n", - "test_x = []\n", - "test_y = []\n", - "\n", - "train_x, val_x, train_y, val_y = train_test_split(train_x, train_y, test_size = 0.1)\n", - "\n", - "# TODO: reshape shape from [n, l] to [n, 1, l]\n" ->>>>>>> d2ef840 (chore: started cnn notebook) -======= ->>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df ], "metadata": { "collapsed": false, @@ -317,17 +229,32 @@ }, { "cell_type": "code", -<<<<<<< HEAD -<<<<<<< HEAD -======= ->>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df - "execution_count": 83, + "execution_count": 18, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 25/25 [00:18<00:00, 1.34it/s]\n" + " 0%| | 0/25 [00:00<?, ?it/s]\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "Expected 3-dimensional input for 3-dimensional weight [4, 1, 3], but got 4-dimensional input of size [18000, 1, 200, 4] instead", + "output_type": "error", + "traceback": [ + "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[1;31mRuntimeError\u001B[0m Traceback (most recent call last)", + "\u001B[1;32m~\\AppData\\Local\\Temp/ipykernel_14744/999922600.py\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 24\u001B[0m \u001B[1;31m# training the model\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 25\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0mepoch\u001B[0m \u001B[1;32min\u001B[0m 
\u001B[0mtqdm\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mrange\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mn_epochs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 26\u001B[1;33m \u001B[0mtrain_loss\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mval_loss\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mtrain\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 27\u001B[0m \u001B[0mtrain_losses\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mappend\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mtrain_loss\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 28\u001B[0m \u001B[0mval_losses\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mappend\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mval_loss\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n", + "\u001B[1;32m~\\AppData\\Local\\Temp/ipykernel_14744/2669949571.py\u001B[0m in \u001B[0;36mtrain\u001B[1;34m()\u001B[0m\n\u001B[0;32m 67\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 68\u001B[0m \u001B[1;31m# prediction for training and validation set\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 69\u001B[1;33m \u001B[0moutput_train\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mx_train\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 70\u001B[0m \u001B[0moutput_val\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodel\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mx_val\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 71\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n", + "\u001B[1;32mc:\\users\\gzaug\\onedrive\\dokumente\\uni\\programming in life sciences\\scrna-seq-simulation\\venv\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[1;34m(self, *input, **kwargs)\u001B[0m\n\u001B[0;32m 1100\u001B[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[0;32m 1101\u001B[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[1;32m-> 1102\u001B[1;33m \u001B[1;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m*\u001B[0m\u001B[0minput\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 1103\u001B[0m \u001B[1;31m# Do not call functions when jit is used\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1104\u001B[0m \u001B[0mfull_backward_hooks\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n", + "\u001B[1;32m~\\AppData\\Local\\Temp/ipykernel_14744/2669949571.py\u001B[0m in \u001B[0;36mforward\u001B[1;34m(self, x)\u001B[0m\n\u001B[0;32m 43\u001B[0m \u001B[1;31m# Defining the forward pass\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 44\u001B[0m \u001B[1;32mdef\u001B[0m \u001B[0mforward\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m,\u001B[0m 
\u001B[0mx\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 45\u001B[1;33m \u001B[0mx\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mcnn_layers\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mx\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 46\u001B[0m \u001B[0mx\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mx\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mview\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mx\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msize\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;36m0\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m-\u001B[0m\u001B[1;36m1\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 47\u001B[0m \u001B[0mx\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mlinear_layers\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mx\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n", + "\u001B[1;32mc:\\users\\gzaug\\onedrive\\dokumente\\uni\\programming in life sciences\\scrna-seq-simulation\\venv\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[1;34m(self, *input, **kwargs)\u001B[0m\n\u001B[0;32m 1100\u001B[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[0;32m 1101\u001B[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[1;32m-> 1102\u001B[1;33m \u001B[1;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m*\u001B[0m\u001B[0minput\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 1103\u001B[0m \u001B[1;31m# Do not call functions when jit is used\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1104\u001B[0m \u001B[0mfull_backward_hooks\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n", + "\u001B[1;32mc:\\users\\gzaug\\onedrive\\dokumente\\uni\\programming in life sciences\\scrna-seq-simulation\\venv\\lib\\site-packages\\torch\\nn\\modules\\container.py\u001B[0m in \u001B[0;36mforward\u001B[1;34m(self, input)\u001B[0m\n\u001B[0;32m 139\u001B[0m \u001B[1;32mdef\u001B[0m \u001B[0mforward\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0minput\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 140\u001B[0m \u001B[1;32mfor\u001B[0m \u001B[0mmodule\u001B[0m \u001B[1;32min\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 141\u001B[1;33m \u001B[0minput\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mmodule\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0minput\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 142\u001B[0m \u001B[1;32mreturn\u001B[0m \u001B[0minput\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 143\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n", + "\u001B[1;32mc:\\users\\gzaug\\onedrive\\dokumente\\uni\\programming in life 
sciences\\scrna-seq-simulation\\venv\\lib\\site-packages\\torch\\nn\\modules\\module.py\u001B[0m in \u001B[0;36m_call_impl\u001B[1;34m(self, *input, **kwargs)\u001B[0m\n\u001B[0;32m 1100\u001B[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001B[0;32m 1101\u001B[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001B[1;32m-> 1102\u001B[1;33m \u001B[1;32mreturn\u001B[0m \u001B[0mforward_call\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m*\u001B[0m\u001B[0minput\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 1103\u001B[0m \u001B[1;31m# Do not call functions when jit is used\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 1104\u001B[0m \u001B[0mfull_backward_hooks\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mnon_full_backward_hooks\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n", + "\u001B[1;32mc:\\users\\gzaug\\onedrive\\dokumente\\uni\\programming in life sciences\\scrna-seq-simulation\\venv\\lib\\site-packages\\torch\\nn\\modules\\conv.py\u001B[0m in \u001B[0;36mforward\u001B[1;34m(self, input)\u001B[0m\n\u001B[0;32m 299\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 300\u001B[0m \u001B[1;32mdef\u001B[0m \u001B[0mforward\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0minput\u001B[0m\u001B[1;33m:\u001B[0m \u001B[0mTensor\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;33m->\u001B[0m \u001B[0mTensor\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 301\u001B[1;33m \u001B[1;32mreturn\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m_conv_forward\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0minput\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mweight\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mbias\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m 302\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 303\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n", + "\u001B[1;32mc:\\users\\gzaug\\onedrive\\dokumente\\uni\\programming in life sciences\\scrna-seq-simulation\\venv\\lib\\site-packages\\torch\\nn\\modules\\conv.py\u001B[0m in \u001B[0;36m_conv_forward\u001B[1;34m(self, input, weight, bias)\u001B[0m\n\u001B[0;32m 295\u001B[0m \u001B[0mweight\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mbias\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mstride\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m 296\u001B[0m _single(0), self.dilation, self.groups)\n\u001B[1;32m--> 297\u001B[1;33m return F.conv1d(input, weight, bias, self.stride,\n\u001B[0m\u001B[0;32m 298\u001B[0m self.padding, self.dilation, self.groups)\n\u001B[0;32m 299\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n", + "\u001B[1;31mRuntimeError\u001B[0m: Expected 3-dimensional input for 3-dimensional weight [4, 1, 3], but got 4-dimensional input of size [18000, 1, 200, 4] instead" ] } ], @@ -341,37 +268,10 @@ "# defining the loss function\n", "criterion = CrossEntropyLoss()\n", "\n", -<<<<<<< HEAD -======= - "execution_count": null, - "outputs": [], - "source": [ - "# 
defining the model\n", - "model = Net()\n", - "\n", - "# defining the optimizer\n", - "optimizer = Adam(model.parameters(), lr=0.07)\n", - "\n", - "# defining the loss function\n", - "criterion = CrossEntropyLoss()\n", -<<<<<<< HEAD ->>>>>>> d2ef840 (chore: started cnn notebook) -======= - "\n", ->>>>>>> 93ea318 (chore: added training function for cnn) -======= ->>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df "# checking if GPU is available\n", "if torch.cuda.is_available():\n", "    model = model.cuda()\n", "    criterion = criterion.cuda()\n", -<<<<<<< HEAD -<<<<<<< HEAD -<<<<<<< HEAD -======= ->>>>>>> 93ea318 (chore: added training function for cnn) -======= ->>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df "\n", "# defining the number of epochs\n", "n_epochs = 25\n",
@@ -383,10 +283,6 @@ "val_losses = []\n", "\n", "# training the model\n", -<<<<<<< HEAD -<<<<<<< HEAD -======= ->>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df "for epoch in tqdm(range(n_epochs)):\n", "    train_loss, val_loss = train()\n", "    train_losses.append(train_loss)\n", "    val_losses.append(val_loss)\n",
@@ -413,19 +309,8 @@ }, { "cell_type": "code", - "execution_count": 84, - "outputs": [ - { - "data": { - "text/plain": "<Figure size 432x288 with 1 Axes>", - "image/png": "iVBORw0KGgo... [~3.5 kB of base64-encoded PNG data (the training/validation loss plot output) elided] ...AElFTkSuQmCC\n" - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], + "execution_count": null, + "outputs": [], "source": [ "train_losses_list = [train_loss.item() for train_loss in train_losses]\n", "val_losses_list = [val_loss.item() for val_loss in val_losses]\n",
@@ -445,17 +330,8 @@ }, { "cell_type": "code", - "execution_count": 85, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.9995\n", - "0.9995\n" - ] - } - ], + "execution_count": null, + "outputs": [], "source": [ "# prediction for training set\n", "with torch.no_grad():\n",
@@ -499,28 +375,14 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": null, "outputs": [], "source": [ - "torch.save(model.state_dict(), '../models/internal_priming.pth')" -<<<<<<< HEAD -======= - "\n" ->>>>>>> d2ef840 (chore: started cnn notebook) -======= - "for epoch in range(n_epochs):\n", "    train_loss, val_loss = train()\n", "    train_losses.append(train_loss)\n", "    val_losses.append(val_loss)\n", + "torch.save(model.state_dict(), '../models/internal_priming.pth')\n", "\n", - "# plotting the training and validation loss\n", - "plt.plot(train_losses, label='Training loss')\n", - "plt.plot(val_losses, label='Validation loss')\n", - "plt.legend()\n", - "plt.show()" ->>>>>>> 93ea318 (chore: added training function for cnn) -======= ->>>>>>> 
fb8e822ed92fba85e584305fcb18bdf45ad601df + "for name, param in model.named_parameters():\n", + " if param.requires_grad:\n", + " print(name, param.data)" ], "metadata": { "collapsed": false,