Skip to content
Snippets Groups Projects
Commit c70be5be authored by TheRiPtide's avatar TheRiPtide
Browse files

chore: flake8

parent 6d6f4480
No related branches found
No related tags found
1 merge request!23feat: deep-leaning poly(A) classifier
Pipeline #13780 failed
<<<<<<< HEAD
"""Module for classifying polyA tails as internal or real."""
=======
>>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df
import torch
from torch.nn import Linear, ReLU, Sequential, MaxPool1d, Module, BatchNorm1d, Conv1d
import numpy as np
......@@ -10,13 +7,10 @@ from typing import Union
class Net(Module):
<<<<<<< HEAD
"""Two layer 1D convolutional neural net"""
=======
>>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df
"""Two layer 1D convolutional neural net."""
def __init__(self):
"""Returns Net object."""
super(Net, self).__init__()
self.cnn_layers = Sequential(
......@@ -36,14 +30,8 @@ class Net(Module):
Linear(4 * 50, 10)
)
<<<<<<< HEAD
def forward(self, x):
"""Forward pass function."""
=======
# Defining the forward pass
def forward(self, x):
>>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df
x = self.cnn_layers(x)
x = x.view(x.size(0), -1)
x = self.linear_layers(x)
......@@ -51,23 +39,16 @@ class Net(Module):
class PolyAClassifier:
<<<<<<< HEAD
"""Classifier object using the state-dict of a pretrained pytorch model"""
=======
>>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df
"""Classifier object using the state-dict of a pretrained pytorch model."""
enum = {
'A': 0.0,
'U': 1 / 3,
<<<<<<< HEAD
'T': 1 / 3,
=======
>>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df
'G': 2 / 3,
'C': 1.0
}
<<<<<<< HEAD
def __init__(self, model: Module = Net, state_dict_path: str = './models/internal_priming.pth'):
"""Returns a stateless classifier with the model loaded.
......@@ -75,10 +56,6 @@ class PolyAClassifier:
model: An object subclassing the pytorch Module
state_dict_path: A path to a saved state-dict of said object at a trained state.
"""
=======
def __init__(self, model=Net, state_dict_path='./models/internal_priming.pth'):
>>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df
self.model = model()
self.model.load_state_dict(torch.load(state_dict_path))
......@@ -94,13 +71,9 @@ class PolyAClassifier:
Raises:
TypeError: If sequence is not str or list(str)
ValueError: If some or all sequences are not of length 200
<<<<<<< HEAD
ValueError: If non-allowed letters in string
=======
>>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df
"""
if type(sequence) is list:
sequences = [list(seq) for seq in sequence]
......@@ -115,7 +88,6 @@ class PolyAClassifier:
enum_seqs = []
<<<<<<< HEAD
try:
for s in sequences:
enum_sequence = [self.enum[key.upper()] for key in s]
......@@ -136,27 +108,6 @@ class PolyAClassifier:
if test_shape[1] != 200:
raise ValueError('Sequences not of length 200')
=======
for s in sequences:
enum_sequence = [self.enum[key] for key in s]
enum_seqs.append(enum_sequence)
# convert to ndarray and reshape for pytorch
test = np.array(enum_seqs, dtype=np.float32)
try:
test_shape = test.shape
test = test.reshape(test_shape[0], 1, test_shape[1])
if test_shape[1] != 200:
raise ValueError('Sequences not of length 200')
except IndexError:
raise ValueError('Not all sequences of length 200')
>>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df
test = torch.from_numpy(test)
# make prediction
......@@ -173,29 +124,4 @@ class PolyAClassifier:
else:
<<<<<<< HEAD
return predictions.tolist()
=======
return predictions
if __name__ == '__main__':
mod = PolyAClassifier(state_dict_path='../models/internal_priming.pth')
real_str = 'CGCCGGAAGAACGAAUCUCCCACUGCCCGGGCAUCCAAUGGACUUCAUAGGAAUGGCAGCUGAUAACACCGCCCCCUGUGGCGCGCCAGAGGGCGCGCUUCGUGUAGGCUUCGAUGUCGCGGUAAAAUUCUUGGAUUAAAGAAGGGGCCCUGUGGUAGCAAGUUUUUUAUUCUGUGGGCGCUCUUACGCGUGUAUUGUCU'
fake_str = 'GUUUGAGGCGCAUGACGCGUUUCGGGGGCCUUGCGUCGCCCACGCCGGCGUUCUCUUUAAAAGGAGCAACGACACCACGCCCCAUGGACCAUGCCGCAGGGUGAACGUCGUCCCGCAACUGCCGUGCACCCGUCAAAAGGAGGCGUCUUCAAAAAAAAAACAAAAUAAAAACACAUACCGCGGCGCGUAUUAGAGCGGCG'
list_test = [real_str, fake_str]
pred = mod.classify(real_str)
print(pred)
pred = mod.classify(fake_str)
print(pred)
pred = mod.classify(list_test)
print(pred)
>>>>>>> fb8e822ed92fba85e584305fcb18bdf45ad601df
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment