From 485bdca3683f70a560965b0e641e13044a1873e4 Mon Sep 17 00:00:00 2001
From: TheRiPtide <g.zaugg97@gmail.com>
Date: Wed, 22 Dec 2021 10:32:24 +0100
Subject: [PATCH] refactor: small changes to adjust style and format

---
 src/polyA_classifier/__init__.py |  2 +-
 src/polyA_classifier/cli.py      | 16 +++++++++-------
 tests/test_poly_a.py             | 20 ++++++--------------
 3 files changed, 16 insertions(+), 22 deletions(-)

diff --git a/src/polyA_classifier/__init__.py b/src/polyA_classifier/__init__.py
index 0f5ce03..75219a0 100644
--- a/src/polyA_classifier/__init__.py
+++ b/src/polyA_classifier/__init__.py
@@ -1 +1 @@
-"""Module for classifying poly(A) sequences."""
+"""Package for classifying poly(A) sequences."""
diff --git a/src/polyA_classifier/cli.py b/src/polyA_classifier/cli.py
index 397ccf8..5d0e320 100644
--- a/src/polyA_classifier/cli.py
+++ b/src/polyA_classifier/cli.py
@@ -1,18 +1,16 @@
 """Command-line interface for the poly(A) classifier."""
-
-import sys
-import argparse
+from sys import stdout
+from argparse import ArgumentParser
 from src.polyA_classifier.polyA_classifier import PolyAClassifier
 
-parser = argparse.ArgumentParser()
+parser = ArgumentParser()
 
 parser.add_argument('data', action='store', help='str or list(str) of length 200 chars to classify.')
 parser.add_argument('-p', '--path', action='store', help='Path to state-dict for Net-model.',
                     default='./models/internal_priming.pth')
 
 
-if __name__ == '__main__':
-
+def main():
     args = parser.parse_args()
 
     if args.data[0] == '[':
@@ -26,5 +24,9 @@ if __name__ == '__main__':
     classifier = PolyAClassifier(state_dict_path=args.path)
 
     result = classifier.classify(data)
+    stdout.write(str(result))
+
+
+if __name__ == '__main__':
 
-    sys.stdout.write(str(result))
+    main()
diff --git a/tests/test_poly_a.py b/tests/test_poly_a.py
index 8c95e90..0c77a64 100644
--- a/tests/test_poly_a.py
+++ b/tests/test_poly_a.py
@@ -1,5 +1,4 @@
-"""Tests for poly_a module."""
-
+"""Tests for poly_a package."""
 import pytest
 import os
 from src.poly_a import generate_poly_a
@@ -18,7 +17,7 @@ class TestGeneratePolyA():
     def test_passes_set_all_args(self):
         res = generate_poly_a(
             length=10,
-            weights=(1,0,0,0),
+            weights=(1, 0, 0, 0),
         )
         assert isinstance(res, str)
         assert len(res) == 10
@@ -40,10 +39,10 @@ class TestGeneratePolyA():
     @pytest.mark.parametrize(
         "weights, expected",
         [
-            ((0,0,1), ValueError),
-            (('a', 0,0,1), ValueError),
-            ((0,0,0,-1), ValueError),
-            ((0,0,0,0), ValueError),
+            ((0, 0, 1), ValueError),
+            (('a', 0, 0, 1), ValueError),
+            ((0, 0, 0, -1), ValueError),
+            ((0, 0, 0, 0), ValueError),
         ]
     )
     def test_wrong_weights(self, expected, weights):
@@ -67,11 +66,7 @@ class TestClassifyPolyA:
         ]
     )
     def test_passes_set_all_args(self, sample, expected):
-
-        print(os.getcwd())
-
         model = PolyAClassifier(Net, './tests/resources/internal_priming_test_model.pth')
-
         pred = model.classify(sample)
 
         assert pred == expected
@@ -84,7 +79,6 @@ class TestClassifyPolyA:
         ]
     )
     def test_wrong_input_length(self, sample, expected):
-
         model = PolyAClassifier(state_dict_path='./tests/resources/internal_priming_test_model.pth')
 
         with pytest.raises(expected):
@@ -104,5 +98,3 @@ class TestClassifyPolyA:
 
         with pytest.raises(expected):
             model.classify(sample)
-
-
-- 
GitLab