Skip to content
Snippets Groups Projects
Commit f28b078c authored by TheRiPtide's avatar TheRiPtide
Browse files

patch: polyA classifier finished

parent bb566328
No related branches found
No related tags found
1 merge request!23feat: deep-leaning poly(A) classifier
"""Module for classifying polyA tails as internal or real."""
import torch
from torch.nn import Linear, ReLU, Sequential, MaxPool1d, Module, BatchNorm1d, Conv1d
import numpy as np
......@@ -5,6 +7,7 @@ from typing import Union
class Net(Module):
"""Two layer 1D convolutional neural net"""
def __init__(self):
......@@ -27,8 +30,8 @@ class Net(Module):
Linear(4 * 50, 10)
)
# Defining the forward pass
def forward(self, x):
"""Forward pass function."""
x = self.cnn_layers(x)
x = x.view(x.size(0), -1)
......@@ -37,15 +40,23 @@ class Net(Module):
class PolyAClassifier:
"""Classifier object using the state-dict of a pretrained pytorch model"""
enum = {
'A': 0.0,
'U': 1 / 3,
'T': 1 / 3,
'G': 2 / 3,
'C': 1.0
}
def __init__(self, model=Net, state_dict_path='./models/internal_priming.pth'):
def __init__(self, model: Module = Net, state_dict_path: str = './models/internal_priming.pth'):
"""Returns a stateless classifier with the model loaded.
Args:
model: An object subclassing the pytorch Module
state_dict_path: A path to a saved state-dict of said object at a trained state.
"""
self.model = model()
self.model.load_state_dict(torch.load(state_dict_path))
......@@ -62,6 +73,7 @@ class PolyAClassifier:
Raises:
TypeError: If sequence is not str or list(str)
ValueError: If some or all sequences are not of length 200
ValueError: If non-allowed letters in string
"""
......@@ -79,23 +91,27 @@ class PolyAClassifier:
enum_seqs = []
for s in sequences:
enum_sequence = [self.enum[key] for key in s]
enum_seqs.append(enum_sequence)
try:
for s in sequences:
enum_sequence = [self.enum[key.upper()] for key in s]
enum_seqs.append(enum_sequence)
except KeyError:
raise ValueError('String contains non-allowed Letters: only A, T, U, G, C allowed')
# convert to ndarray and reshape for pytorch
test = np.array(enum_seqs, dtype=np.float32)
try:
test_shape = test.shape
test = test.reshape(test_shape[0], 1, test_shape[1])
test = np.array(enum_seqs, dtype=np.float32)
except ValueError:
raise ValueError('Not all sequences of length 200')
if test_shape[1] != 200:
raise ValueError('Sequences not of length 200')
test_shape = test.shape
test = test.reshape(test_shape[0], 1, test_shape[1])
if test_shape[1] != 200:
raise ValueError('Sequences not of length 200')
except IndexError:
raise ValueError('Not all sequences of length 200')
test = torch.from_numpy(test)
......@@ -113,25 +129,4 @@ class PolyAClassifier:
else:
return predictions
if __name__ == '__main__':
mod = PolyAClassifier(state_dict_path='../models/internal_priming.pth')
real_str = 'CGCCGGAAGAACGAAUCUCCCACUGCCCGGGCAUCCAAUGGACUUCAUAGGAAUGGCAGCUGAUAACACCGCCCCCUGUGGCGCGCCAGAGGGCGCGCUUCGUGUAGGCUUCGAUGUCGCGGUAAAAUUCUUGGAUUAAAGAAGGGGCCCUGUGGUAGCAAGUUUUUUAUUCUGUGGGCGCUCUUACGCGUGUAUUGUCU'
fake_str = 'GUUUGAGGCGCAUGACGCGUUUCGGGGGCCUUGCGUCGCCCACGCCGGCGUUCUCUUUAAAAGGAGCAACGACACCACGCCCCAUGGACCAUGCCGCAGGGUGAACGUCGUCCCGCAACUGCCGUGCACCCGUCAAAAGGAGGCGUCUUCAAAAAAAAAACAAAAUAAAAACACAUACCGCGGCGCGUAUUAGAGCGGCG'
list_test = [real_str, fake_str]
pred = mod.classify(real_str)
print(pred)
pred = mod.classify(fake_str)
print(pred)
pred = mod.classify(list_test)
print(pred)
return predictions.tolist()
CGCCGGAAGAACGAAUCUCCCACUGCCCGGGCAUCCAAUGGACUUCAUAGGAAUGGCAGCUGAUAACACCGCCCCCUGUGGCGCGCCAGAGGGCGCGCUUCGUGUAGGCUUCGAUGUCGCGGUAAAAUUCUUGGAUUAAAGAAGGGGCCCUGUGGUAGCAAGUUUUUUAUUCUGUGGGCGCUCUUACGCGUGUAUUGUCU
GUUUGAGGCGCAUGACGCGUUUCGGGGGCCUUGCGUCGCCCACGCCGGCGUUCUCUUUAAAAGGAGCAACGACACCACGCCCCAUGGACCAUGCCGCAGGGUGAACGUCGUCCCGCAACUGCCGUGCACCCGUCAAAAGGAGGCGUCUUCAAAAAAAAAACAAAAUAAAAACACAUACCGCGGCGCGUAUUAGAGCGGCG
GUUUGAGGCGCAUGACGCGUUUCGGGGGCCUUGCGUCGCCCACGCCGGCGUUCUCUUUAAAAGGAGCAACGACACCACGCCCCAUGGACCAUGCCGCAGGGUGAACGUCGUCCCGCAACUGCCGUGCACCCGUCAAAAGGAGGCGUCUUCAAAAAAAAAACAAAAUAAAAACACAUACCGCGGCGCGUAUUAGAGCGGCF
\ No newline at end of file
File added
"""Tests for poly_a module."""
import pytest
import os
from src.poly_a import generate_poly_a
from src.polyA_classifier import PolyAClassifier, Net
import linecache
class TestGeneratePolyA():
......@@ -47,3 +49,60 @@ class TestGeneratePolyA():
def test_wrong_weights(self, expected, weights):
with pytest.raises(expected):
generate_poly_a(weights=weights)
class TestClassifyPolyA:
"""Tests for poly(A) tail classification."""
real = linecache.getline('./tests/resources/internal_priming_examples.txt', 1).strip('\n')
fake = linecache.getline('./tests/resources/internal_priming_examples.txt', 2).strip('\n')
bad = linecache.getline('./tests/resources/internal_priming_examples.txt', 3).strip('\n')
@pytest.mark.parametrize(
'sample, expected',
[
(real, 1),
(fake, 0),
([real, fake], [1, 0])
]
)
def test_passes_set_all_args(self, sample, expected):
print(os.getcwd())
model = PolyAClassifier(Net, './tests/resources/internal_priming_test_model.pth')
pred = model.classify(sample)
assert pred == expected
@pytest.mark.parametrize(
'sample, expected',
[
(real[:-1], ValueError),
([real[:-1], fake], ValueError)
]
)
def test_wrong_input_length(self, sample, expected):
model = PolyAClassifier(state_dict_path='./tests/resources/internal_priming_test_model.pth')
with pytest.raises(expected):
model.classify(sample)
@pytest.mark.parametrize(
'sample, expected',
[
(0, TypeError),
(True, TypeError),
([0, 0.0], TypeError),
(bad, ValueError)
]
)
def test_wrong_input_type(self, sample, expected):
model = PolyAClassifier(state_dict_path='./tests/resources/internal_priming_test_model.pth')
with pytest.raises(expected):
model.classify(sample)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment