Skip to content
Snippets Groups Projects
Commit 2200b3f5 authored by TheRiPtide's avatar TheRiPtide
Browse files

Merge branch 'dl-onehot' into deep-learning

parents 485bdca3 c1c0b313
Branches
No related tags found
1 merge request!23feat: deep-leaning poly(A) classifier
Pipeline #13868 failed
No preview for this file type
This diff is collapsed.
...@@ -15,7 +15,7 @@ class Net(Module): ...@@ -15,7 +15,7 @@ class Net(Module):
self.cnn_layers = Sequential( self.cnn_layers = Sequential(
# Defining a 1D convolution layer # Defining a 1D convolution layer
Conv1d(1, 4, kernel_size=3, stride=1, padding=1), Conv1d(4, 4, kernel_size=3, stride=1, padding=1),
BatchNorm1d(4), BatchNorm1d(4),
ReLU(inplace=True), ReLU(inplace=True),
MaxPool1d(kernel_size=2, stride=2), MaxPool1d(kernel_size=2, stride=2),
...@@ -42,11 +42,11 @@ class PolyAClassifier: ...@@ -42,11 +42,11 @@ class PolyAClassifier:
"""Classifier object using the state-dict of a pretrained pytorch model.""" """Classifier object using the state-dict of a pretrained pytorch model."""
enum = { enum = {
'A': 0.0, 'A': [1, 0, 0, 0],
'U': 1 / 3, 'U': [0, 1, 0, 0],
'T': 1 / 3, 'T': [0, 1, 0, 0],
'G': 2 / 3, 'G': [0, 0, 1, 0],
'C': 1.0 'C': [0, 0, 0, 1]
} }
def __init__(self, model=Net, state_dict_path: str = './models/internal_priming.pth'): def __init__(self, model=Net, state_dict_path: str = './models/internal_priming.pth'):
...@@ -103,7 +103,7 @@ class PolyAClassifier: ...@@ -103,7 +103,7 @@ class PolyAClassifier:
raise ValueError('Not all sequences of length 200') raise ValueError('Not all sequences of length 200')
test_shape = test.shape test_shape = test.shape
test = test.reshape(test_shape[0], 1, test_shape[1]) test = test.reshape(test_shape[0], 4, test_shape[1])
if test_shape[1] != 200: if test_shape[1] != 200:
raise ValueError('Sequences not of length 200') raise ValueError('Sequences not of length 200')
......
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment