Commit 96fbed82 authored by TheRiPtide's avatar TheRiPtide
chore: added training function for cnn

parent 1f63980c
1 merge request: !23 feat: deep-learning poly(A) classifier
%% Cell type:markdown id: tags:
# Issue 21: Inferring the code of internal priming by deep learning
In real data sets we would like to distinguish bona fide poly(A) sites from internal-priming sites. To do this, we want to construct a classifier that uses the sequence flanking the sites. As a deep-learning architecture we can use a convolutional neural network, e.g. starting from a NumPy implementation (https://pypi.org/project/numpycnn/).
Reference: https://www.analyticsvidhya.com/blog/2019/10/building-image-classification-models-cnn-pytorch/
Input: sequences of bona fide and internally-primed poly(A) sites (#16)
Output: classifier based on the nucleotide sequence around the sites
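%% Cell type:markdown id: tags:
Before the network can see them, the flanking sequences must be converted into numeric arrays. Below is a minimal one-hot encoding sketch, assuming fixed-length sequences over the alphabet A/C/G/T; the helper `one_hot_encode` is illustrative and not part of this repository.
%% Cell type:code id: tags:
``` python
import numpy as np

# hypothetical helper: one row per position, one column per nucleotide
def one_hot_encode(seq):
    mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    encoded = np.zeros((len(seq), 4), dtype=np.float32)
    for i, base in enumerate(seq.upper()):
        encoded[i, mapping[base]] = 1.0
    return encoded

print(one_hot_encode('ACGT'))
# [[1. 0. 0. 0.]
#  [0. 1. 0. 0.]
#  [0. 0. 1. 0.]
#  [0. 0. 0. 1.]]
```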
%% Cell type:code id: tags:
``` python
# importing the libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# for creating the validation set
from sklearn.model_selection import train_test_split

# for evaluating the model
from sklearn.metrics import accuracy_score
from tqdm import tqdm

# PyTorch libraries and modules
import torch
from torch.nn import Linear, ReLU, CrossEntropyLoss, Sequential, Conv2d, MaxPool2d, Module, Softmax, BatchNorm2d, Dropout
from torch.optim import Adam, SGD


# defining the network
class Net(Module):

    def __init__(self):
        super(Net, self).__init__()
        self.cnn_layers = Sequential(
            # first 2D convolution layer: 1 input channel -> 4 feature maps
            Conv2d(1, 4, kernel_size=3, stride=1, padding=1),
            BatchNorm2d(4),
            ReLU(inplace=True),
            MaxPool2d(kernel_size=2, stride=2),
            # second 2D convolution layer: 4 -> 4 feature maps
            Conv2d(4, 4, kernel_size=3, stride=1, padding=1),
            BatchNorm2d(4),
            ReLU(inplace=True),
            MaxPool2d(kernel_size=2, stride=2),
        )
        # NOTE: 4 * 7 * 7 input features assume 28x28 inputs (each 2x2 pooling
        # halves both spatial dimensions); the 10 output classes are carried
        # over from the image-classification tutorial and will need to become 2
        # for the binary poly(A) / internal-priming task
        self.linear_layers = Sequential(
            Linear(4 * 7 * 7, 10)
        )

    # defining the forward pass
    def forward(self, x):
        x = self.cnn_layers(x)
        x = x.view(x.size(0), -1)
        x = self.linear_layers(x)
        return x


# defining the training function (one full-batch update per call)
def train():
    model.train()
    # getting the training and validation sets
    x_train, y_train = train_x, train_y
    x_val, y_val = val_x, val_y
    # moving the data to the GPU if one is available
    if torch.cuda.is_available():
        x_train = x_train.cuda()
        y_train = y_train.cuda()
        x_val = x_val.cuda()
        y_val = y_val.cuda()
    # clearing the gradients of the model parameters
    optimizer.zero_grad()
    # prediction and loss for the training set
    output_train = model(x_train)
    loss_train = criterion(output_train, y_train)
    # computing the updated weights of all the model parameters
    loss_train.backward()
    optimizer.step()
    # computing the validation loss without tracking gradients
    with torch.no_grad():
        output_val = model(x_val)
        loss_val = criterion(output_val, y_val)
    # returning plain floats so the losses can be collected and plotted
    return loss_train.item(), loss_val.item()
```
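%% Cell type:markdown id: tags:
As a sanity check on the hard-coded `4 * 7 * 7`: the padded 3x3 convolutions preserve spatial size, and each 2x2 max-pooling halves both dimensions, so a 28x28 input becomes 14x14 and then 7x7 with 4 channels. A quick shape probe with a dummy tensor (illustrative only):
%% Cell type:code id: tags:
``` python
dummy = torch.randn(1, 1, 28, 28)
print(Net().cnn_layers(dummy).shape)  # expected: torch.Size([1, 4, 7, 7])
```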
%% Cell type:markdown id: tags:
## Load data
%% Cell type:code id: tags:
``` python
# TODO: Get test data from issues 25 and 26
train_x = []
train_y = []
test_x = []
test_y = []

# holding out 10% of the training data as a validation set
train_x, val_x, train_y, val_y = train_test_split(train_x, train_y, test_size=0.1)
# TODO: reshape shape from [n, l] to [n, 1, l]
```
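%% Cell type:markdown id: tags:
Once real sequences are available, the arrays need the reshape described in the TODO above and a conversion to tensors. A minimal sketch, assuming `train_x`/`val_x` arrive as 2-D float arrays of shape [n, l] and the labels as integer class indices; note that `Net` as written uses `Conv2d` and therefore expects 4-D input [n, 1, h, w], while a [n, 1, l] tensor would match a `Conv1d`-based variant.
%% Cell type:code id: tags:
``` python
# hypothetical conversion step, to run once real data replaces the placeholders
train_x = torch.from_numpy(np.asarray(train_x, dtype=np.float32)).unsqueeze(1)  # [n, l] -> [n, 1, l]
val_x = torch.from_numpy(np.asarray(val_x, dtype=np.float32)).unsqueeze(1)
# CrossEntropyLoss expects integer class labels
train_y = torch.from_numpy(np.asarray(train_y, dtype=np.int64))
val_y = torch.from_numpy(np.asarray(val_y, dtype=np.int64))
```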
%% Cell type:markdown id: tags:
## Model call and loss function definition
%% Cell type:code id: tags:
``` python
# defining the model
model = Net()
# defining the optimizer
optimizer = Adam(model.parameters(), lr=0.07)
# defining the loss function
criterion = CrossEntropyLoss()

# moving the model and loss to the GPU if one is available
if torch.cuda.is_available():
    model = model.cuda()
    criterion = criterion.cuda()

# defining the number of epochs
n_epochs = 25
# empty lists to store the training and validation losses
train_losses = []
val_losses = []

# training the model
for epoch in tqdm(range(n_epochs)):
    train_loss, val_loss = train()
    train_losses.append(train_loss)
    val_losses.append(val_loss)

# plotting the training and validation loss
plt.plot(train_losses, label='Training loss')
plt.plot(val_losses, label='Validation loss')
plt.legend()
plt.show()
```
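%% Cell type:markdown id: tags:
Once training has run, the already-imported `accuracy_score` can report validation accuracy. A minimal sketch, assuming `val_x` and `val_y` have been converted to tensors as above:
%% Cell type:code id: tags:
``` python
model.eval()
with torch.no_grad():
    x = val_x.cuda() if torch.cuda.is_available() else val_x
    predictions = model(x).argmax(dim=1).cpu().numpy()
print('validation accuracy:', accuracy_score(val_y, predictions))
```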