From bdd0ee5d8e80fe0093ffa4b43a198225ec6f3098 Mon Sep 17 00:00:00 2001
From: "hugo.madgeleon" <hugo.madgeleon@stud.unibas.ch>
Date: Sun, 11 Dec 2022 20:36:45 +0100
Subject: [PATCH] fixes to imports

---
 .gitignore                |  3 ++-
 term_frag_sel/cli.py      | 20 +++++++--------
 tests/test_cli.py         |  6 +++--
 tests/test_files/test.csv | 51 +++++++++++++++++++++++++++++++++++++++
 tests/test_frag.py        | 38 +++++++++++++++--------------
 5 files changed, 87 insertions(+), 31 deletions(-)
 create mode 100644 tests/test_files/test.csv

diff --git a/.gitignore b/.gitignore
index 67cda8e..3816760 100644
--- a/.gitignore
+++ b/.gitignore
@@ -60,4 +60,5 @@ __pycache__/
 *_cache
 *egg-info/
 .coverage
-build/
\ No newline at end of file
+build/
+*play.py
\ No newline at end of file
diff --git a/term_frag_sel/cli.py b/term_frag_sel/cli.py
index 2495d85..843d17b 100644
--- a/term_frag_sel/cli.py
+++ b/term_frag_sel/cli.py
@@ -26,13 +26,10 @@ def main(args: argparse.Namespace):
     fasta, seq_counts = file_validation(args.fasta, args.counts, args.sep)
 
     logger.info("Fragmentation of %s...", args.fasta)
-    fasta_parse = {}
-    for record in fasta:
-        fasta_parse[record.id] = record.seq
-    splits = np.arange(0, len(list(fasta_parse))+args.size, args.size)
+    splits = np.arange(0, len(list(fasta))+args.size, args.size)
 
     for i, split in enumerate(splits):
-        fasta_dict = fasta_parse[split:splits[i+1]]
+        fasta_dict = fasta[split:splits[i+1]]
         term_frags = fragmentation(fasta_dict, seq_counts,
                                    args.mean, args.std,
                                    args.a_prob, args.t_prob,
@@ -55,14 +52,17 @@ def file_validation(fasta_file: str,
         sep (str): Separator for counts file.
 
     Returns:
-        tuple: fasta and sequence counts variables
+        tuple: fasta dict and sequence counts pd.DataFrame
     """
-    with open(fasta_file, "r", encoding="utf-8") as handle:
-        fasta = SeqIO.parse(handle, "fasta")
-    if not any(fasta):
+    fasta_sequences = SeqIO.parse(open(fasta_file, "r", encoding="utf-8"),'fasta')
+    if not any(fasta_sequences):
         raise ValueError("Input FASTA file is either empty or \
             incorrect file type.")
 
+    fasta_dict = {}
+    for record in fasta_sequences:
+        fasta_dict[record.id] = str(record.seq).upper()
+
     count_path = Path(counts_file)
     if not count_path.is_file():
         logger.exception("Input counts file does not exist or isn't a file.")
@@ -72,7 +72,7 @@ def file_validation(fasta_file: str,
         else:
             seq_counts = pd.read_table(counts_file, names=["seqID", "count"])
 
-    return fasta, seq_counts
+    return fasta_dict, seq_counts
 
 
 def parse_arguments() -> argparse.Namespace:
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 53a94c3..7865c82 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,13 +1,15 @@
 """Test cli.py functions."""
 import pytest
+import pandas as pd
 # from Bio import SeqIO
 
 from term_frag_sel.cli import file_validation  # type: ignore
 
 FASTA_FILE = "tests/test_files/test.fasta"
-
+CSV_FILE = "tests/test_files/test.csv"
 
 def test_file():
-    """Test check_positive function."""
+    """Test file_validation return types and missing-file handling."""
+    assert [type(x) for x in file_validation(FASTA_FILE, CSV_FILE, ",")] == [dict, pd.DataFrame]
     with pytest.raises(FileNotFoundError):
         file_validation("", "", ",")
diff --git a/tests/test_files/test.csv b/tests/test_files/test.csv
new file mode 100644
index 0000000..011ccf1
--- /dev/null
+++ b/tests/test_files/test.csv
@@ -0,0 +1,51 @@
+,seqID,count
+0,1,120
+1,2,60
+2,3,76
+3,4,141
+4,5,107
+5,6,59
+6,7,121
+7,8,72
+8,9,114
+9,10,126
+10,11,77
+11,12,130
+12,13,146
+13,14,69
+14,15,62
+15,16,93
+16,17,103
+17,18,150
+18,19,140
+19,20,56
+20,21,103
+21,22,115
+22,23,113
+23,24,53
+24,25,95
+25,26,112
+26,27,102
+27,28,53
+28,29,139
+29,30,143
+30,31,150
+31,32,126
+32,33,50
+33,34,94
+34,35,81
+35,36,116
+36,37,51
+37,38,110
+38,39,98
+39,40,60
+40,41,57
+41,42,73
+42,43,143
+43,44,116
+44,45,98
+45,46,139
+46,47,89
+47,48,63
+48,49,68
+49,50,84
diff --git a/tests/test_frag.py b/tests/test_frag.py
index a90eaf6..ba3a0b1 100644
--- a/tests/test_frag.py
+++ b/tests/test_frag.py
@@ -1,17 +1,19 @@
 """Test utils.py functions."""
-import random
+import pandas as pd
 import pytest
 from Bio import SeqIO
 
 from term_frag_sel.fragmentation import fragmentation, get_cut_number  # type: ignore
 
-with open("tests/test_files/test.fasta", "r", encoding="utf-8") as handle:
-    fasta = SeqIO.parse(handle, "fasta")
+FASTA_FILE = "tests/test_files/test.fasta"
+CSV_FILE = "tests/test_files/test.csv"
 
-seq_counts = {}
-for entry in fasta:
-    name, sequence = entry.id, str(entry.seq)
-    seq_counts[name] = random.randint(50,150)
+fasta_sequences = SeqIO.parse(open(FASTA_FILE, "r", encoding="utf-8"),'fasta')
+fasta_dict = {}
+for record in fasta_sequences:
+    fasta_dict[record.id] = str(record.seq).upper()
+
+seq_counts = pd.read_csv(CSV_FILE, names=["seqID", "count"])
 
 MU = 100
 STD = 10
@@ -35,23 +37,23 @@ def test_get_cut_number():
 
 def test_fragmentation():
     """Test fragmentation function."""
-    assert isinstance(fragmentation(fasta, seq_counts), list)
+    assert isinstance(fragmentation(fasta_dict, seq_counts), list)  # default parameters
 
     with pytest.raises(ValueError):
-        fragmentation(fasta, seq_counts, mean_length = 'a', std = 10,
+        fragmentation(fasta_dict, seq_counts, mean_length = 'a', std = 10,  # bad mean_length
                       a_prob = 0.22, t_prob = 0.25,g_prob = 0.25, c_prob = 0.28)
     with pytest.raises(ValueError):
-        fragmentation(fasta, seq_counts, mean_length = 100, std = 'a',
+        fragmentation(fasta_dict, seq_counts, mean_length = 100, std = 'a',  # bad std
                       a_prob = 0.22, t_prob = 0.25,g_prob = 0.25, c_prob = 0.28)
     with pytest.raises(ValueError):
-        fragmentation(fasta, seq_counts, mean_length = 100, std = 10,
-                      a_prob = 'a', t_prob = 0.25,g_prob = 0.25, c_prob = 0.28)
+        fragmentation(fasta_dict, seq_counts, mean_length = 100, std = 10,  # bad a_prob
+                      a_prob = 'a', t_prob = 0.25, g_prob = 0.25, c_prob = 0.28)
     with pytest.raises(ValueError):
-        fragmentation(fasta, seq_counts, mean_length = 100, std = 10,
-                      a_prob = 0.22, t_prob = 'a',g_prob = 0.25, c_prob = 0.28)
+        fragmentation(fasta_dict, seq_counts, mean_length = 100, std = 10,  # bad t_prob
+                      a_prob = 0.22, t_prob = 'a', g_prob = 0.25, c_prob = 0.28)
     with pytest.raises(ValueError):
-        fragmentation(fasta, seq_counts, mean_length = 100, std = 10,
-                      a_prob = 0.22, t_prob = 0.25,g_prob = 'a', c_prob = 0.28)
+        fragmentation(fasta_dict, seq_counts, mean_length = 100, std = 10,  # bad g_prob
+                      a_prob = 0.22, t_prob = 0.25, g_prob = 'a', c_prob = 0.28)
     with pytest.raises(ValueError):
-        fragmentation(fasta, seq_counts, mean_length = 100, std = 10,
-                      a_prob = 0.22, t_prob = 0.25,g_prob = 0.25, c_prob = 'a')
+        fragmentation(fasta_dict, seq_counts, mean_length = 100, std = 10,  # bad c_prob
+                      a_prob = 0.22, t_prob = 0.25, g_prob = 0.25, c_prob = 'a')
-- 
GitLab