From 485bdca3683f70a560965b0e641e13044a1873e4 Mon Sep 17 00:00:00 2001 From: TheRiPtide <g.zaugg97@gmail.com> Date: Wed, 22 Dec 2021 10:32:24 +0100 Subject: [PATCH] refactor: small changes to adjust style and format --- src/polyA_classifier/__init__.py | 2 +- src/polyA_classifier/cli.py | 16 +++++++++------- tests/test_poly_a.py | 20 ++++++-------------- 3 files changed, 16 insertions(+), 22 deletions(-) diff --git a/src/polyA_classifier/__init__.py b/src/polyA_classifier/__init__.py index 0f5ce03..75219a0 100644 --- a/src/polyA_classifier/__init__.py +++ b/src/polyA_classifier/__init__.py @@ -1 +1 @@ -"""Module for classifying poly(A) sequences.""" +"""Package for classifying poly(A) sequences.""" diff --git a/src/polyA_classifier/cli.py b/src/polyA_classifier/cli.py index 397ccf8..5d0e320 100644 --- a/src/polyA_classifier/cli.py +++ b/src/polyA_classifier/cli.py @@ -1,18 +1,16 @@ """Command-line interface for the poly(A) classifier.""" - -import sys -import argparse +from sys import stdout +from argparse import ArgumentParser from src.polyA_classifier.polyA_classifier import PolyAClassifier -parser = argparse.ArgumentParser() +parser = ArgumentParser() parser.add_argument('data', action='store', help='str or list(str) of length 200 chars to classify.') parser.add_argument('-p', '--path', action='store', help='Path to state-dict for Net-model.', default='./models/internal_priming.pth') -if __name__ == '__main__': - +def main(): args = parser.parse_args() if args.data[0] == '[': @@ -26,5 +24,9 @@ if __name__ == '__main__': classifier = PolyAClassifier(state_dict_path=args.path) result = classifier.classify(data) + stdout.write(str(result)) + + +if __name__ == '__main__': - sys.stdout.write(str(result)) + main() diff --git a/tests/test_poly_a.py b/tests/test_poly_a.py index 8c95e90..0c77a64 100644 --- a/tests/test_poly_a.py +++ b/tests/test_poly_a.py @@ -1,5 +1,4 @@ -"""Tests for poly_a module.""" - +"""Tests for poly_a package.""" import pytest import os from src.poly_a import generate_poly_a @@ -18,7 +17,7 @@ class TestGeneratePolyA(): def test_passes_set_all_args(self): res = generate_poly_a( length=10, - weights=(1,0,0,0), + weights=(1, 0, 0, 0), ) assert isinstance(res, str) assert len(res) == 10 @@ -40,10 +39,10 @@ class TestGeneratePolyA(): @pytest.mark.parametrize( "weights, expected", [ - ((0,0,1), ValueError), - (('a', 0,0,1), ValueError), - ((0,0,0,-1), ValueError), - ((0,0,0,0), ValueError), + ((0, 0, 1), ValueError), + (('a', 0, 0, 1), ValueError), + ((0, 0, 0, -1), ValueError), + ((0, 0, 0, 0), ValueError), ] ) def test_wrong_weights(self, expected, weights): @@ -67,11 +66,7 @@ class TestClassifyPolyA: ] ) def test_passes_set_all_args(self, sample, expected): - - print(os.getcwd()) - model = PolyAClassifier(Net, './tests/resources/internal_priming_test_model.pth') - pred = model.classify(sample) assert pred == expected @@ -84,7 +79,6 @@ class TestClassifyPolyA: ] ) def test_wrong_input_length(self, sample, expected): - model = PolyAClassifier(state_dict_path='./tests/resources/internal_priming_test_model.pth') with pytest.raises(expected): @@ -104,5 +98,3 @@ class TestClassifyPolyA: with pytest.raises(expected): model.classify(sample) - - -- GitLab