diff --git a/src/polyA_classifier/__init__.py b/src/polyA_classifier/__init__.py index 0f5ce032e98c6174c7674e9a49fb1373faa53533..75219a0067cd6f611222dab87c8b479d5830b853 100644 --- a/src/polyA_classifier/__init__.py +++ b/src/polyA_classifier/__init__.py @@ -1 +1 @@ -"""Module for classifying poly(A) sequences.""" +"""Package for classifying poly(A) sequences.""" diff --git a/src/polyA_classifier/cli.py b/src/polyA_classifier/cli.py index 397ccf83519f4762711b5670f10c22667a8cc969..5d0e3208b7eea483fc2419868943856519acebf1 100644 --- a/src/polyA_classifier/cli.py +++ b/src/polyA_classifier/cli.py @@ -1,18 +1,16 @@ """Command-line interface for the poly(A) classifier.""" - -import sys -import argparse +from sys import stdout +from argparse import ArgumentParser from src.polyA_classifier.polyA_classifier import PolyAClassifier -parser = argparse.ArgumentParser() +parser = ArgumentParser() parser.add_argument('data', action='store', help='str or list(str) of length 200 chars to classify.') parser.add_argument('-p', '--path', action='store', help='Path to state-dict for Net-model.', default='./models/internal_priming.pth') -if __name__ == '__main__': - +def main(): args = parser.parse_args() if args.data[0] == '[': @@ -26,5 +24,9 @@ if __name__ == '__main__': classifier = PolyAClassifier(state_dict_path=args.path) result = classifier.classify(data) + stdout.write(str(result)) + + +if __name__ == '__main__': - sys.stdout.write(str(result)) + main() diff --git a/tests/test_poly_a.py b/tests/test_poly_a.py index 8c95e90a166916eee199eaf9e2b1acba18cf9ae1..0c77a64895ba4c6b61ce8373e930abbe13f552cc 100644 --- a/tests/test_poly_a.py +++ b/tests/test_poly_a.py @@ -1,5 +1,4 @@ -"""Tests for poly_a module.""" - +"""Tests for poly_a package.""" import pytest import os from src.poly_a import generate_poly_a @@ -18,7 +17,7 @@ class TestGeneratePolyA(): def test_passes_set_all_args(self): res = generate_poly_a( length=10, - weights=(1,0,0,0), + weights=(1, 0, 0, 0), ) assert isinstance(res, str) assert len(res) == 10 @@ -40,10 +39,10 @@ class TestGeneratePolyA(): @pytest.mark.parametrize( "weights, expected", [ - ((0,0,1), ValueError), - (('a', 0,0,1), ValueError), - ((0,0,0,-1), ValueError), - ((0,0,0,0), ValueError), + ((0, 0, 1), ValueError), + (('a', 0, 0, 1), ValueError), + ((0, 0, 0, -1), ValueError), + ((0, 0, 0, 0), ValueError), ] ) def test_wrong_weights(self, expected, weights): @@ -67,11 +66,7 @@ class TestClassifyPolyA: ] ) def test_passes_set_all_args(self, sample, expected): - - print(os.getcwd()) - model = PolyAClassifier(Net, './tests/resources/internal_priming_test_model.pth') - pred = model.classify(sample) assert pred == expected @@ -84,7 +79,6 @@ class TestClassifyPolyA: ] ) def test_wrong_input_length(self, sample, expected): - model = PolyAClassifier(state_dict_path='./tests/resources/internal_priming_test_model.pth') with pytest.raises(expected): @@ -104,5 +98,3 @@ class TestClassifyPolyA: with pytest.raises(expected): model.classify(sample) - -