Skip to content
Snippets Groups Projects
Commit c1c0b313 authored by TheRiPtide's avatar TheRiPtide
Browse files

feat: Better visualizability of first convolutional layer with onehot encoding

parent 465fcb01
No related branches found
No related tags found
1 merge request!23feat: deep-leaning poly(A) classifier
No preview for this file type
This diff is collapsed.
......@@ -15,7 +15,7 @@ class Net(Module):
self.cnn_layers = Sequential(
# Defining a 1D convolution layer
Conv1d(1, 4, kernel_size=3, stride=1, padding=1),
Conv1d(4, 4, kernel_size=3, stride=1, padding=1),
BatchNorm1d(4),
ReLU(inplace=True),
MaxPool1d(kernel_size=2, stride=2),
......@@ -42,11 +42,11 @@ class PolyAClassifier:
"""Classifier object using the state-dict of a pretrained pytorch model."""
enum = {
'A': 0.0,
'U': 1 / 3,
'T': 1 / 3,
'G': 2 / 3,
'C': 1.0
'A': [1, 0, 0, 0],
'U': [0, 1, 0, 0],
'T': [0, 1, 0, 0],
'G': [0, 0, 1, 0],
'C': [0, 0, 0, 1]
}
def __init__(self, model=Net, state_dict_path: str = './models/internal_priming.pth'):
......@@ -103,7 +103,7 @@ class PolyAClassifier:
raise ValueError('Not all sequences of length 200')
test_shape = test.shape
test = test.reshape(test_shape[0], 1, test_shape[1])
test = test.reshape(test_shape[0], 4, test_shape[1])
if test_shape[1] != 200:
raise ValueError('Sequences not of length 200')
......
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment