Skip to content
Snippets Groups Projects
Commit 05728f04 authored by TheRiPtide's avatar TheRiPtide
Browse files

chore: rebase 2

parent 3fa94f50
No related branches found
No related tags found
1 merge request!23feat: deep-leaning poly(A) classifier
......@@ -32,10 +32,14 @@
"# importing the libraries\n",
"import pandas as pd\n",
"import numpy as np\n",
<<<<<<< HEAD
<<<<<<< HEAD
"import matplotlib.pyplot as plt\n",
=======
>>>>>>> d2ef840 (chore: started cnn notebook)
=======
"import matplotlib.pyplot as plt\n",
>>>>>>> 93ea318 (chore: added training function for cnn)
"\n",
"# for creating validation set\n",
"from sklearn.model_selection import train_test_split\n",
......@@ -100,6 +104,9 @@
" x = x.view(x.size(0), -1)\n",
" x = self.linear_layers(x)\n",
<<<<<<< HEAD
<<<<<<< HEAD
=======
>>>>>>> 93ea318 (chore: added training function for cnn)
" return x\n",
"\n",
"# defining training function\n",
......@@ -134,9 +141,12 @@
" tr_loss = loss_train.item()\n",
"\n",
" return loss_train, loss_val"
<<<<<<< HEAD
=======
" return x"
>>>>>>> d2ef840 (chore: started cnn notebook)
=======
>>>>>>> 93ea318 (chore: added training function for cnn)
],
"metadata": {
"collapsed": false,
......@@ -325,16 +335,25 @@
"source": [
"# defining the model\n",
"model = Net()\n",
"\n",
"# defining the optimizer\n",
"optimizer = Adam(model.parameters(), lr=0.07)\n",
"\n",
"# defining the loss function\n",
"criterion = CrossEntropyLoss()\n",
<<<<<<< HEAD
>>>>>>> d2ef840 (chore: started cnn notebook)
=======
"\n",
>>>>>>> 93ea318 (chore: added training function for cnn)
"# checking if GPU is available\n",
"if torch.cuda.is_available():\n",
" model = model.cuda()\n",
" criterion = criterion.cuda()\n",
<<<<<<< HEAD
<<<<<<< HEAD
=======
>>>>>>> 93ea318 (chore: added training function for cnn)
"\n",
"# defining the number of epochs\n",
"n_epochs = 25\n",
......@@ -346,6 +365,7 @@
"val_losses = []\n",
"\n",
"# training the model\n",
<<<<<<< HEAD
"for epoch in tqdm(range(n_epochs)):\n",
" train_loss, val_loss = train()\n",
" train_losses.append(train_loss)\n",
......@@ -465,6 +485,18 @@
=======
"\n"
>>>>>>> d2ef840 (chore: started cnn notebook)
=======
"for epoch in range(n_epochs):\n",
" train_loss, val_loss = train()\n",
" train_losses.append(train_loss)\n",
" val_losses.append(val_loss)\n",
"\n",
"# plotting the training and validation loss\n",
"plt.plot(train_losses, label='Training loss')\n",
"plt.plot(val_losses, label='Validation loss')\n",
"plt.legend()\n",
"plt.show()"
>>>>>>> 93ea318 (chore: added training function for cnn)
],
"metadata": {
"collapsed": false,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment