Skip to content
Snippets Groups Projects
Commit 485bdca3 authored by TheRiPtide's avatar TheRiPtide
Browse files

refactor: small changes to adjust style and format

parent bcd472ba
No related branches found
No related tags found
1 merge request!23feat: deep-leaning poly(A) classifier
Pipeline #13865 failed
This commit is part of merge request !23. Comments created here will be created in the context of that merge request.
"""Module for classifying poly(A) sequences."""
"""Package for classifying poly(A) sequences."""
"""Command-line interface for the poly(A) classifier."""
import sys
import argparse
from sys import stdout
from argparse import ArgumentParser
from src.polyA_classifier.polyA_classifier import PolyAClassifier
parser = argparse.ArgumentParser()
parser = ArgumentParser()
parser.add_argument('data', action='store', help='str or list(str) of length 200 chars to classify.')
parser.add_argument('-p', '--path', action='store', help='Path to state-dict for Net-model.',
default='./models/internal_priming.pth')
if __name__ == '__main__':
def main():
args = parser.parse_args()
if args.data[0] == '[':
......@@ -26,5 +24,9 @@ if __name__ == '__main__':
classifier = PolyAClassifier(state_dict_path=args.path)
result = classifier.classify(data)
stdout.write(str(result))
if __name__ == '__main__':
sys.stdout.write(str(result))
main()
"""Tests for poly_a module."""
"""Tests for poly_a package."""
import pytest
import os
from src.poly_a import generate_poly_a
......@@ -18,7 +17,7 @@ class TestGeneratePolyA():
def test_passes_set_all_args(self):
res = generate_poly_a(
length=10,
weights=(1,0,0,0),
weights=(1, 0, 0, 0),
)
assert isinstance(res, str)
assert len(res) == 10
......@@ -40,10 +39,10 @@ class TestGeneratePolyA():
@pytest.mark.parametrize(
"weights, expected",
[
((0,0,1), ValueError),
(('a', 0,0,1), ValueError),
((0,0,0,-1), ValueError),
((0,0,0,0), ValueError),
((0, 0, 1), ValueError),
(('a', 0, 0, 1), ValueError),
((0, 0, 0, -1), ValueError),
((0, 0, 0, 0), ValueError),
]
)
def test_wrong_weights(self, expected, weights):
......@@ -67,11 +66,7 @@ class TestClassifyPolyA:
]
)
def test_passes_set_all_args(self, sample, expected):
print(os.getcwd())
model = PolyAClassifier(Net, './tests/resources/internal_priming_test_model.pth')
pred = model.classify(sample)
assert pred == expected
......@@ -84,7 +79,6 @@ class TestClassifyPolyA:
]
)
def test_wrong_input_length(self, sample, expected):
model = PolyAClassifier(state_dict_path='./tests/resources/internal_priming_test_model.pth')
with pytest.raises(expected):
......@@ -104,5 +98,3 @@ class TestClassifyPolyA:
with pytest.raises(expected):
model.classify(sample)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment