Skip to content
Snippets Groups Projects
Commit 485bdca3 authored by TheRiPtide's avatar TheRiPtide
Browse files

refactor: small changes to adjust style and format

parent bcd472ba
No related branches found
No related tags found
1 merge request!23feat: deep-leaning poly(A) classifier
Pipeline #13865 failed
"""Module for classifying poly(A) sequences."""
"""Package for classifying poly(A) sequences."""
"""Command-line interface for the poly(A) classifier."""
import sys
import argparse
from sys import stdout
from argparse import ArgumentParser
from src.polyA_classifier.polyA_classifier import PolyAClassifier
parser = argparse.ArgumentParser()
parser = ArgumentParser()
parser.add_argument('data', action='store', help='str or list(str) of length 200 chars to classify.')
parser.add_argument('-p', '--path', action='store', help='Path to state-dict for Net-model.',
default='./models/internal_priming.pth')
if __name__ == '__main__':
def main():
args = parser.parse_args()
if args.data[0] == '[':
......@@ -26,5 +24,9 @@ if __name__ == '__main__':
classifier = PolyAClassifier(state_dict_path=args.path)
result = classifier.classify(data)
stdout.write(str(result))
if __name__ == '__main__':
sys.stdout.write(str(result))
main()
"""Tests for poly_a module."""
"""Tests for poly_a package."""
import pytest
import os
from src.poly_a import generate_poly_a
......@@ -18,7 +17,7 @@ class TestGeneratePolyA():
def test_passes_set_all_args(self):
res = generate_poly_a(
length=10,
weights=(1,0,0,0),
weights=(1, 0, 0, 0),
)
assert isinstance(res, str)
assert len(res) == 10
......@@ -40,10 +39,10 @@ class TestGeneratePolyA():
@pytest.mark.parametrize(
"weights, expected",
[
((0,0,1), ValueError),
(('a', 0,0,1), ValueError),
((0,0,0,-1), ValueError),
((0,0,0,0), ValueError),
((0, 0, 1), ValueError),
(('a', 0, 0, 1), ValueError),
((0, 0, 0, -1), ValueError),
((0, 0, 0, 0), ValueError),
]
)
def test_wrong_weights(self, expected, weights):
......@@ -67,11 +66,7 @@ class TestClassifyPolyA:
]
)
def test_passes_set_all_args(self, sample, expected):
print(os.getcwd())
model = PolyAClassifier(Net, './tests/resources/internal_priming_test_model.pth')
pred = model.classify(sample)
assert pred == expected
......@@ -84,7 +79,6 @@ class TestClassifyPolyA:
]
)
def test_wrong_input_length(self, sample, expected):
model = PolyAClassifier(state_dict_path='./tests/resources/internal_priming_test_model.pth')
with pytest.raises(expected):
......@@ -104,5 +98,3 @@ class TestClassifyPolyA:
with pytest.raises(expected):
model.classify(sample)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment