Skip to content
Snippets Groups Projects
Commit 485bdca3 authored by TheRiPtide's avatar TheRiPtide
Browse files

refactor: small changes to adjust style and format

parent bcd472ba
No related branches found
No related tags found
1 merge request!23feat: deep-leaning poly(A) classifier
Pipeline #13865 failed
This commit is part of merge request !23. Comments created here will be created in the context of that merge request.
"""Module for classifying poly(A) sequences.""" """Package for classifying poly(A) sequences."""
"""Command-line interface for the poly(A) classifier.""" """Command-line interface for the poly(A) classifier."""
from sys import stdout
import sys from argparse import ArgumentParser
import argparse
from src.polyA_classifier.polyA_classifier import PolyAClassifier from src.polyA_classifier.polyA_classifier import PolyAClassifier
parser = argparse.ArgumentParser() parser = ArgumentParser()
parser.add_argument('data', action='store', help='str or list(str) of length 200 chars to classify.') parser.add_argument('data', action='store', help='str or list(str) of length 200 chars to classify.')
parser.add_argument('-p', '--path', action='store', help='Path to state-dict for Net-model.', parser.add_argument('-p', '--path', action='store', help='Path to state-dict for Net-model.',
default='./models/internal_priming.pth') default='./models/internal_priming.pth')
if __name__ == '__main__': def main():
args = parser.parse_args() args = parser.parse_args()
if args.data[0] == '[': if args.data[0] == '[':
...@@ -26,5 +24,9 @@ if __name__ == '__main__': ...@@ -26,5 +24,9 @@ if __name__ == '__main__':
classifier = PolyAClassifier(state_dict_path=args.path) classifier = PolyAClassifier(state_dict_path=args.path)
result = classifier.classify(data) result = classifier.classify(data)
stdout.write(str(result))
if __name__ == '__main__':
sys.stdout.write(str(result)) main()
"""Tests for poly_a module.""" """Tests for poly_a package."""
import pytest import pytest
import os import os
from src.poly_a import generate_poly_a from src.poly_a import generate_poly_a
...@@ -18,7 +17,7 @@ class TestGeneratePolyA(): ...@@ -18,7 +17,7 @@ class TestGeneratePolyA():
def test_passes_set_all_args(self): def test_passes_set_all_args(self):
res = generate_poly_a( res = generate_poly_a(
length=10, length=10,
weights=(1,0,0,0), weights=(1, 0, 0, 0),
) )
assert isinstance(res, str) assert isinstance(res, str)
assert len(res) == 10 assert len(res) == 10
...@@ -40,10 +39,10 @@ class TestGeneratePolyA(): ...@@ -40,10 +39,10 @@ class TestGeneratePolyA():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"weights, expected", "weights, expected",
[ [
((0,0,1), ValueError), ((0, 0, 1), ValueError),
(('a', 0,0,1), ValueError), (('a', 0, 0, 1), ValueError),
((0,0,0,-1), ValueError), ((0, 0, 0, -1), ValueError),
((0,0,0,0), ValueError), ((0, 0, 0, 0), ValueError),
] ]
) )
def test_wrong_weights(self, expected, weights): def test_wrong_weights(self, expected, weights):
...@@ -67,11 +66,7 @@ class TestClassifyPolyA: ...@@ -67,11 +66,7 @@ class TestClassifyPolyA:
] ]
) )
def test_passes_set_all_args(self, sample, expected): def test_passes_set_all_args(self, sample, expected):
print(os.getcwd())
model = PolyAClassifier(Net, './tests/resources/internal_priming_test_model.pth') model = PolyAClassifier(Net, './tests/resources/internal_priming_test_model.pth')
pred = model.classify(sample) pred = model.classify(sample)
assert pred == expected assert pred == expected
...@@ -84,7 +79,6 @@ class TestClassifyPolyA: ...@@ -84,7 +79,6 @@ class TestClassifyPolyA:
] ]
) )
def test_wrong_input_length(self, sample, expected): def test_wrong_input_length(self, sample, expected):
model = PolyAClassifier(state_dict_path='./tests/resources/internal_priming_test_model.pth') model = PolyAClassifier(state_dict_path='./tests/resources/internal_priming_test_model.pth')
with pytest.raises(expected): with pytest.raises(expected):
...@@ -104,5 +98,3 @@ class TestClassifyPolyA: ...@@ -104,5 +98,3 @@ class TestClassifyPolyA:
with pytest.raises(expected): with pytest.raises(expected):
model.classify(sample) model.classify(sample)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment