From bbca50d7de855d2a0e884f8fb6d29a9bcd8073ca Mon Sep 17 00:00:00 2001
From: "hugo.madgeleon" <hugo.madgeleon@stud.unibas.ch>
Date: Sun, 11 Dec 2022 22:59:42 +0100
Subject: [PATCH] final linter check

---
 .gitignore                     |   2 +-
 term_frag_sel/cli.py           |  42 +++++++++-----
 term_frag_sel/fragmentation.py |  54 +++++++++---------
 tests/test_cli.py              |   8 +--
 tests/test_files/test.csv      | 100 ++++++++++++++++-----------------
 tests/test_files/test_tab.csv  |  96 +++++++++++++++----------------
 tests/test_frag.py             |  59 +++++++++----------
 7 files changed, 187 insertions(+), 174 deletions(-)

diff --git a/.gitignore b/.gitignore
index 3816760..20d1246 100644
--- a/.gitignore
+++ b/.gitignore
@@ -61,4 +61,4 @@ __pycache__/
 *egg-info/
 .coverage
 build/
-*play.py
\ No newline at end of file
+*/play.py
\ No newline at end of file
diff --git a/term_frag_sel/cli.py b/term_frag_sel/cli.py
index 2ca6cf2..30c2535 100644
--- a/term_frag_sel/cli.py
+++ b/term_frag_sel/cli.py
@@ -10,6 +10,13 @@ import pandas as pd  # type: ignore
 from term_frag_sel.fragmentation import fragmentation
 from term_frag_sel.utils import check_positive, check_prob
 
+logging.basicConfig(
+    format='[%(asctime)s: %(levelname)s] %(message)s \
+        (module "%(module)s")',
+    level=logging.INFO,
+)
+logger = logging.getLogger("main")
+
 
 def main(args: argparse.Namespace):
     """Use CLI arguments to fragment sequences and output text file \
@@ -31,12 +38,14 @@ def main(args: argparse.Namespace):
     logger.info("Fragmentation of %s...", args.fasta)
     splits = np.arange(0, len(list(fasta))+args.size, args.size)
 
+    nuc_probs = {'A': args.a_prob, 'T': args.t_prob,
+                 'G': args.g_prob, 'C': args.c_prob}
     for i, split in enumerate(splits):
         fasta_dict = fasta[split:splits[i+1]]
         term_frags = fragmentation(fasta_dict, seq_counts,
+                                   nuc_probs,
                                    args.mean, args.std,
-                                   args.a_prob, args.t_prob,
-                                   args.g_prob, args.c_prob)
+                                   )
 
         logger.info("Writing batch %s sequences to %s...", i, args.output)
         with open(args.output, 'a', encoding="utf-8") as out_file:
@@ -57,23 +66,28 @@ def file_validation(fasta_file: str,
     Returns:
         tuple: fasta dict and sequence counts pd.DataFrame
     """
-    fasta_sequences = SeqIO.parse(open(fasta_file, "r", encoding="utf-8"),'fasta')
-    if not any(fasta_sequences):
-        raise ValueError("Input FASTA file is either empty or \
-            incorrect file type.")
-
     fasta_dict = {}
-    for record in fasta_sequences:
-        fasta_dict[record.id] = str(record.seq).upper()
+    with open(fasta_file, "r", encoding="utf-8") as handle:
+        fasta_sequences = SeqIO.parse(handle, "fasta")
+
+        if not any(fasta_sequences):
+            raise ValueError("Input FASTA file is either empty or \
+                incorrect file type.")
+
+        for record in fasta_sequences:
+            fasta_dict[record.id] = str(record.seq).upper()
 
     count_path = Path(counts_file)
     if not count_path.is_file():
-        logger.exception("Input counts file does not exist or isn't a file.")
+        raise FileNotFoundError("Input counts file does not exist or \
+            isn't a file.")
+
+    if sep == ",":
+        seq_counts = pd.read_csv(counts_file, names=["seqID", "count"])
+        seq_counts = seq_counts.astype({"seqID": str})
     else:
-        if sep == ",":
-            seq_counts = pd.read_csv(counts_file, names=["seqID", "count"])
-        else:
-            seq_counts = pd.read_table(counts_file, names=["seqID", "count"])
+        seq_counts = pd.read_table(counts_file, names=["seqID", "count"])
+        seq_counts = seq_counts.astype({"seqID": str})
 
     return fasta_dict, seq_counts
 
diff --git a/term_frag_sel/fragmentation.py b/term_frag_sel/fragmentation.py
index 7444ff8..ef347ca 100644
--- a/term_frag_sel/fragmentation.py
+++ b/term_frag_sel/fragmentation.py
@@ -4,6 +4,7 @@ import re
 import numpy as np
 import pandas as pd  # type: ignore
 
+
 def get_cut_number(seq_len: int, mean: float) -> int:
     """Get the number of cuts for a particular sequence.
 
@@ -14,14 +15,22 @@ def get_cut_number(seq_len: int, mean: float) -> int:
     Returns:
         int: number of cuts
     """
-    cuts_distribution = [] # distribution of cut numbers (n_cuts)
+    if not isinstance(seq_len, int):
+        raise ValueError(f"Sequence length must be numeric, \
+            not {type(seq_len)}")
+
+    if not isinstance(mean, int):
+        raise ValueError(f"Mean must be numeric, not {type(mean)}")
 
-    # 1000 iterations should not be too computationally inefficient given the nature of the code
+    cuts_distribution = []  # distribution of cut numbers (n_cuts)
+
+    # 1000 iterations should not be too computationally inefficient
+    # given the nature of the code
     for _ in range(1000):
         n_cuts = 0
         len_sum = 0
         while True:
-            len_sum += np.random.exponential(scale = mean)
+            len_sum += np.random.exponential(scale=mean)
             if len_sum < seq_len:
                 n_cuts += 1
             else:
@@ -29,52 +38,48 @@ def get_cut_number(seq_len: int, mean: float) -> int:
                 break
 
     cuts_distribution.sort()
-    cut_counts = {x:cuts_distribution.count(x) for x in cuts_distribution}
+    cut_counts = {x: cuts_distribution.count(x) for x in cuts_distribution}
     cut_probs = [x/1000 for x in cut_counts.values()]
 
-    # randomly ick no. of cut from cut distribution based on probability of cut numbers
-    n_cuts =  np.random.choice(list(cut_counts.keys()), p = cut_probs)
+    # randomly pick no. of cuts from cut distribution based on probability
+    # of cut numbers
+    n_cuts = np.random.choice(list(cut_counts.keys()), p=cut_probs)
 
     return n_cuts
 
+
 def fragmentation(fasta: dict, seq_counts: pd.DataFrame,
-                  mean_length: int = 100, std: int = 10,
-                  a_prob: float = 0.22, t_prob: float = 0.25,
-                  g_prob: float = 0.25, c_prob: float = 0.28
+                  nuc_probs: dict,
+                  mu_length: int = 100, std: int = 10,
                   ) -> list:
     """Fragment cDNA sequences and select terminal fragment.
 
     Args:
         fasta_file (dict): dictionary of {transcript IDs: sequences}
         counts_file (pd.DataFrame): dataframe with sequence counts and IDs
-        mean_length (int): mean length of desired fragments
+        nuc_probs (dict): probability of a cut occurring at a certain \
+            nucleotide. Ordered as A, T, G, C. E.g.: {'A': 0.22, \
+                'T': 0.25, 'G': 0.25, 'C': 0.28}.
+        mu_length (int): mean length of desired fragments
         std (int): standard deviation of desired fragment lengths
-        a_prob (float): probability of nucleotide A
-        t_prob (float): probability of nucleotide T
-        g_prob (float): probability of nucleotide G
-        c_prob (float): probability of nucleotide C
 
     Returns:
         list: list of selected terminal fragments
     """
     # calculated using https://www.nature.com/articles/srep04532#MOESM1
-    nuc_probs = {'A': a_prob, 'T': t_prob, 'G': g_prob, 'C': c_prob}
-
     term_frags = []
     for seq_id, seq in fasta.items():
         counts = seq_counts[seq_counts["seqID"] == seq_id]["count"]
-        for _ in range(counts):
-            seq_len = len(seq)
-
+        for _ in range(counts.iloc[0]):
             # pick no. of cuts from gauss fragment length distribution
-            n_cuts = get_cut_number(seq_len, mean_length)
-
             # non-uniformly random DNA fragmentation implementation based on
             # https://www.nature.com/articles/srep04532#Sec1
             # assume fragmentation by sonication for NGS workflow
             cuts = []
             cut_nucs = np.random.choice(list(nuc_probs.keys()),
-                                        n_cuts, p=list(nuc_probs.values()))
+                                        size=get_cut_number(len(seq),
+                                                            mu_length),
+                                        p=list(nuc_probs.values()))
             for nuc in cut_nucs:
                 nuc_pos = [x.start() for x in re.finditer(nuc, seq)]
                 pos = np.random.choice(nuc_pos)
@@ -87,9 +92,8 @@ def fragmentation(fasta: dict, seq_counts: pd.DataFrame,
             term_frag = ""
 
             # check if 3' fragment is in the correct size range
-            fragment = seq[cuts[-1]:len(seq)]
-            if mean_length-std <= len(fragment) <= mean_length+std:
-                term_frag = fragment
+            if mu_length-std <= len(seq[cuts[-1]:len(seq)]) <= mu_length+std:
+                term_frag = seq[cuts[-1]:len(seq)]
 
             if term_frag != "":
                 term_frags.append(term_frag)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 3b4b121..4ffe56a 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,7 +1,5 @@
 """Test cli.py functions."""
 import pytest
-import pandas as pd
-# from Bio import SeqIO
 
 from term_frag_sel.cli import file_validation, main  # type: ignore
 
@@ -12,10 +10,11 @@ TAB_FILE = "tests/test_files/test_tab.csv"
 _, csv_counts = file_validation(FASTA_FILE, CSV_FILE, ",")
 _, tab_counts = file_validation(FASTA_FILE, TAB_FILE, "\t")
 
+
 def test_file():
     """Test file_validation function."""
-    assert isinstance(file_validation(FASTA_FILE, CSV_FILE, ","), tuple[dict, pd.DataFrame])
-    assert isinstance(file_validation(FASTA_FILE, TAB_FILE, "\t"), tuple[dict, pd.DataFrame])
+    assert isinstance(file_validation(FASTA_FILE, CSV_FILE, ","), tuple)
+    assert isinstance(file_validation(FASTA_FILE, TAB_FILE, "\t"), tuple)
     assert csv_counts.equals(tab_counts)
 
     with pytest.raises(FileNotFoundError):
@@ -25,6 +24,7 @@ def test_file():
     with pytest.raises(ValueError):
         file_validation(CSV_FILE, CSV_FILE, ",")
 
+
 def test_main():
     """Test main() function."""
     with pytest.raises(TypeError):
diff --git a/tests/test_files/test.csv b/tests/test_files/test.csv
index 82dfa39..d6ffada 100644
--- a/tests/test_files/test.csv
+++ b/tests/test_files/test.csv
@@ -1,50 +1,50 @@
-0,1,86
-1,2,60
-2,3,56
-3,4,129
-4,5,138
-5,6,89
-6,7,107
-7,8,52
-8,9,81
-9,10,83
-10,11,139
-11,12,66
-12,13,105
-13,14,59
-14,15,93
-15,16,134
-16,17,50
-17,18,64
-18,19,110
-19,20,119
-20,21,54
-21,22,105
-22,23,124
-23,24,109
-24,25,125
-25,26,98
-26,27,88
-27,28,55
-28,29,70
-29,30,147
-30,31,123
-31,32,66
-32,33,54
-33,34,60
-34,35,79
-35,36,121
-36,37,69
-37,38,63
-38,39,121
-39,40,149
-40,41,52
-41,42,100
-42,43,54
-43,44,81
-44,45,150
-45,46,116
-46,47,128
-47,48,73
-48,49,144
-49,50,89
+0,1,99
+1,2,130
+2,3,144
+3,4,52
+4,5,98
+5,6,85
+6,7,119
+7,8,143
+8,9,112
+9,10,87
+10,11,53
+11,12,50
+12,13,129
+13,14,112
+14,15,56
+15,16,82
+16,17,130
+17,18,98
+18,19,87
+19,20,75
+20,21,140
+21,22,141
+22,23,74
+23,24,54
+24,25,56
+25,26,56
+26,27,56
+27,28,99
+28,29,101
+29,30,101
+30,31,62
+31,32,96
+32,33,131
+33,34,117
+34,35,53
+35,36,81
+36,37,114
+37,38,106
+38,39,67
+39,40,121
+40,41,134
+41,42,105
+42,43,91
+43,44,90
+44,45,145
+45,46,59
+46,47,84
+47,48,62
+48,49,50
+49,50,86
diff --git a/tests/test_files/test_tab.csv b/tests/test_files/test_tab.csv
index 8b7c982..ece6e9e 100644
--- a/tests/test_files/test_tab.csv
+++ b/tests/test_files/test_tab.csv
@@ -1,50 +1,50 @@
-0	1	86
-1	2	102
-2	3	81
-3	4	91
-4	5	121
-5	6	82
-6	7	103
-7	8	125
-8	9	135
-9	10	133
-10	11	110
-11	12	108
-12	13	89
-13	14	125
-14	15	76
-15	16	85
+0	1	99
+1	2	130
+2	3	144
+3	4	52
+4	5	98
+5	6	85
+6	7	119
+7	8	143
+8	9	112
+9	10	87
+10	11	53
+11	12	50
+12	13	129
+13	14	112
+14	15	56
+15	16	82
 16	17	130
-17	18	63
-18	19	137
-19	20	55
-20	21	148
-21	22	101
-22	23	145
-23	24	99
-24	25	50
-25	26	101
-26	27	134
-27	28	60
-28	29	107
-29	30	134
-30	31	76
-31	32	118
-32	33	99
-33	34	64
-34	35	97
-35	36	118
-36	37	131
-37	38	142
-38	39	119
-39	40	50
-40	41	90
-41	42	65
-42	43	140
-43	44	145
-44	45	84
-45	46	144
-46	47	103
-47	48	96
+17	18	98
+18	19	87
+19	20	75
+20	21	140
+21	22	141
+22	23	74
+23	24	54
+24	25	56
+25	26	56
+26	27	56
+27	28	99
+28	29	101
+29	30	101
+30	31	62
+31	32	96
+32	33	131
+33	34	117
+34	35	53
+35	36	81
+36	37	114
+37	38	106
+38	39	67
+39	40	121
+40	41	134
+41	42	105
+42	43	91
+43	44	90
+44	45	145
+45	46	59
+46	47	84
+47	48	62
 48	49	50
-49	50	112
+49	50	86
diff --git a/tests/test_frag.py b/tests/test_frag.py
index ba3a0b1..d6eaccd 100644
--- a/tests/test_frag.py
+++ b/tests/test_frag.py
@@ -3,57 +3,52 @@ import pandas as pd
 import pytest
 from Bio import SeqIO
 
-from term_frag_sel.fragmentation import fragmentation, get_cut_number  # type: ignore
+from term_frag_sel.fragmentation import fragmentation, get_cut_number
 
 FASTA_FILE = "tests/test_files/test.fasta"
 CSV_FILE = "tests/test_files/test.csv"
 
-fasta_sequences = SeqIO.parse(open(FASTA_FILE, "r", encoding="utf-8"),'fasta')
 fasta_dict = {}
-for record in fasta_sequences:
-    fasta_dict[record.id] = str(record.seq).upper()
+with open(FASTA_FILE, "r", encoding="utf-8") as handle:
+    fasta_sequences = SeqIO.parse(handle, "fasta")
+    for record in fasta_sequences:
+        fasta_dict[record.id] = str(record.seq).upper()
 
 seq_counts = pd.read_csv(CSV_FILE, names=["seqID", "count"])
+seq_counts = seq_counts.astype({"seqID": str})
+
+NUC = {'A': 0.22, 'T': 0.25, 'G': 0.25, 'C': 0.28}
 
-MU = 100
-STD = 10
-A_PROB = 0.22
-C_PROB = 0.28
-T_PROB = 0.25
-G_PROB = 0.25
 
 def test_get_cut_number():
     """Test get_cut_number function."""
-    assert get_cut_number(100,20,4) <= 15 and get_cut_number(100,20,4) >= 0
+    assert get_cut_number(100, 20) <= 15 and get_cut_number(100, 20) >= 0
 
     with pytest.raises(ValueError):
-        get_cut_number(seq_len = 'a', mean = 4, std = 2.3) # error if seq_len not int
-    with pytest.raises(ValueError):
-        get_cut_number(seq_len = 10, mean = 'a', std = 2.3) # error if mean not int
+        get_cut_number(seq_len='a', mean=4)
     with pytest.raises(ValueError):
-        get_cut_number(seq_len = 10, mean = 3, std = 'a') # error if std not float
-    with pytest.raises(TypeError):
-        get_cut_number(seq_len = 10, mean = 3) # error if there is a missing argument
+        get_cut_number(seq_len=10, mean='a')
+
 
 def test_fragmentation():
     """Test fragmentation function."""
-    assert isinstance(fragmentation(fasta_dict, seq_counts), list)
+    assert isinstance(fragmentation(fasta_dict, seq_counts,
+                                    NUC), list)
 
+    # no need to check string mean or std since it's checked at CLI
     with pytest.raises(ValueError):
-        fragmentation(fasta_dict, seq_counts, mean_length = 'a', std = 10,
-                      a_prob = 0.22, t_prob = 0.25,g_prob = 0.25, c_prob = 0.28)
-    with pytest.raises(ValueError):
-        fragmentation(fasta_dict, seq_counts, mean_length = 100, std = 'a',
-                      a_prob = 0.22, t_prob = 0.25,g_prob = 0.25, c_prob = 0.28)
-    with pytest.raises(ValueError):
-        fragmentation(fasta_dict, seq_counts, mean_length = 100, std = 10,
-                      a_prob = 'a', t_prob = 0.25, g_prob = 0.25, c_prob = 0.28)
+        nuc_probs = {'A': 'a', 'T': 0.25,
+                     'G': 0.25, 'C': 0.28}
+        fragmentation(fasta_dict, seq_counts, nuc_probs)
     with pytest.raises(ValueError):
-        fragmentation(fasta_dict, seq_counts, mean_length = 100, std = 10,
-                      a_prob = 0.22, t_prob = 'a', g_prob = 0.25, c_prob = 0.28)
+        nuc_probs = {'A': 0.22, 'T': 'a',
+                     'G': 0.25, 'C': 0.28}
+        fragmentation(fasta_dict, seq_counts, nuc_probs)
     with pytest.raises(ValueError):
-        fragmentation(fasta_dict, seq_counts, mean_length = 100, std = 10,
-                      a_prob = 0.22, t_prob = 0.25, g_prob = 'a', c_prob = 0.28)
+        nuc_probs = {'A': 0.22, 'T': 0.25,
+                     'G': 'a', 'C': 0.28}
+        fragmentation(fasta_dict, seq_counts, nuc_probs)
     with pytest.raises(ValueError):
-        fragmentation(fasta_dict, seq_counts, mean_length = 100, std = 10,
-                      a_prob = 0.22, t_prob = 0.25, g_prob = 0.25, c_prob = 'a')
+        nuc_probs = {'A': 0.22, 'T': 0.25,
+                     'G': 0.25, 'C': 'a'}
+        fragmentation(fasta_dict, seq_counts, nuc_probs)
-- 
GitLab