From 93cfe50c43dfb1bcfe92ecfb88026d9fbc9c497a Mon Sep 17 00:00:00 2001
From: "t.nandan" <t.nandan@stud.unibas.ch>
Date: Sat, 10 Dec 2022 20:47:31 +0000
Subject: [PATCH] updated frag files and test

---
 term_frag_sel/cli.py           |  6 ++--
 term_frag_sel/fragmentation.py | 62 ++++++++++++++++++++++++++--------
 term_frag_sel/utils.py         |  2 +-
 tests/test_cli.py              | 10 ++++--
 tests/test_frag.py             | 52 +++++++++++++++++++++++-----
 tests/test_utils.py            |  9 ++---
 6 files changed, 105 insertions(+), 36 deletions(-)

diff --git a/term_frag_sel/cli.py b/term_frag_sel/cli.py
index a9b8e95..f51f42c 100644
--- a/term_frag_sel/cli.py
+++ b/term_frag_sel/cli.py
@@ -7,7 +7,7 @@ import numpy as np
 import pandas as pd  # type: ignore
 
 
-from term_frag_sel.fragmentation import fragmentation
+from term_frag_sel.fragmentation import fragmentation 
 from term_frag_sel.utils import check_positive, check_prob
 
 
@@ -97,11 +97,11 @@ def parse_arguments() -> argparse.Namespace:
                         help="output file path")
     parser.add_argument('--mean', required=False, default=300,
                         type=check_positive,
-                        help="Mean fragment length (default: 10)")
+                        help="Mean fragment length (default: 300)")
     parser.add_argument('--std', required=False, default=60,
                         type=check_positive,
                         help="Standard deviation fragment length \
-                            (defafult: 1)")
+                            (defafult: 60)")
     parser.add_argument('-a', '--A_prob', required=False, default=0.22,
                         type=check_prob,
                         help="Probability cut happens after nucleotide A")
diff --git a/term_frag_sel/fragmentation.py b/term_frag_sel/fragmentation.py
index 302fe6f..06a5506 100644
--- a/term_frag_sel/fragmentation.py
+++ b/term_frag_sel/fragmentation.py
@@ -4,11 +4,43 @@ import re
 import numpy as np
 import pandas as pd  # type: ignore
 
+def get_cut_number(seq_len: int, mean: float, std: float) -> int:
+    """Sample the number of cuts for a sequence.
+
+    Simulates 1000 fragmentations with exponentially distributed fragment
+    lengths (scale ``mean``) and returns one simulated cut count at random.
+
+    Args:
+        seq_len: length of sequence/number of nucleotides in the sequence
+        mean: mean fragment length, must be positive
+        std: standard deviation of fragment lengths (validated but unused)
+
+    Returns:
+        int: number of cuts
+
+    Raises:
+        ValueError: if an argument is not numeric or mean is not positive
+    """
+    try:
+        length, scale, _ = float(seq_len), float(mean), float(std)
+    except (TypeError, ValueError) as err:
+        raise ValueError("seq_len, mean and std must be numeric") from err
+    if scale <= 0:  # a zero scale never advances, so the loop would not end
+        raise ValueError("mean must be positive")
+    n_sim = 1000  # number of simulated fragmentations
+    sim_cuts = np.zeros(n_sim, dtype=int)
+    for i in range(n_sim):
+        len_sum = np.random.exponential(scale=scale)
+        while len_sum < length:  # each fragment ending inside seq is a cut
+            sim_cuts[i] += 1
+            len_sum += np.random.exponential(scale=scale)
+    values, counts = np.unique(sim_cuts, return_counts=True)
+    return int(np.random.choice(values, p=counts / n_sim))
 
 def fragmentation(fasta: dict, seq_counts: pd.DataFrame,
                   mean_length: int = 100, std: int = 10,
-                  a_prob: float = 0.22, t_prob: float = 0.25,
-                  g_prob: float = 0.25, c_prob: float = 0.28
+                  A_prob: float = 0.22, T_prob: float = 0.25,
+                  G_prob: float = 0.25, C_prob: float = 0.28
                   ) -> list:
     """Fragment cDNA sequences and select terminal fragment.
 
@@ -17,22 +49,23 @@ def fragmentation(fasta: dict, seq_counts: pd.DataFrame,
         counts_file (pd.DataFrame): dataframe with sequence counts and IDs
         mean_length (int): mean length of desired fragments
         std (int): standard deviation of desired fragment lengths
-        a_prob (float): probability of nucleotide A
-        t_prob (float): probability of nucleotide T
-        g_prob (float): probability of nucleotide G
-        c_prob (float): probability of nucleotide C
+        A_prob (float): probability of nucleotide A
+        T_prob (float): probability of nucleotide T
+        G_prob (float): probability of nucleotide G
+        C_prob (float): probability of nucleotide C
 
     Returns:
         list: list of selected terminal fragments
     """
     # calculated using https://www.nature.com/articles/srep04532#MOESM1
-    nuc_probs = {'A': a_prob, 'T': t_prob, 'G': g_prob, 'C': c_prob}
+    nuc_probs = {'A': A_prob, 'T': T_prob, 'G': G_prob, 'C': C_prob}
 
     term_frags = []
     for seq_id, seq in fasta.items():
         counts = seq_counts[seq_counts["seqID"] == seq_id]["count"]
         for _ in range(counts):
-            n_cuts = int(len(seq)/mean_length)
+            seq_len = len(seq)
+            n_cuts = get_cut_number(seq_len, mean_length, std) # pick no. of cuts by simulating exponential fragment lengths
 
             # non-uniformly random DNA fragmentation implementation based on
             # https://www.nature.com/articles/srep04532#Sec1
@@ -50,14 +83,13 @@ def fragmentation(fasta: dict, seq_counts: pd.DataFrame,
             cuts.sort()
             cuts.insert(0, 0)
             term_frag = ""
-            for i, val in enumerate(cuts):
-                if i == len(cuts)-1:
-                    fragment = seq[val+1:cuts[-1]]
-                else:
-                    fragment = seq[val:cuts[i+1]]
-                if mean_length-std <= len(fragment) <= mean_length+std:
+            
+            # check if 3' fragment is in the correct size range
+            fragment = seq[cuts[-1]:len(seq)] 
+            if mean_length-std <= len(fragment) <= mean_length+std:
                     term_frag = fragment
+            
             if term_frag != "":
                 term_frags.append(term_frag)
 
-    return term_frags
+    return term_frags 
\ No newline at end of file
diff --git a/term_frag_sel/utils.py b/term_frag_sel/utils.py
index d77cc14..125e9eb 100644
--- a/term_frag_sel/utils.py
+++ b/term_frag_sel/utils.py
@@ -28,7 +28,7 @@ def check_positive(value: str) -> int:
 
 
 def check_prob(value: str) -> float:
-    """Check probability value is within ]0,1] range.
+    """Check probability value is within [0,1] range.
 
     Args:
         value (str): command line parameter
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 53a94c3..1cba49d 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,13 +1,19 @@
 """Test cli.py functions."""
 import pytest
+import argparse
 # from Bio import SeqIO
 
-from term_frag_sel.cli import file_validation  # type: ignore
+from term_frag_sel.cli import file_validation, parse_arguments  # type: ignore
 
 FASTA_FILE = "tests/test_files/test.fasta"
 
-
 def test_file():
     """Test check_positive function."""
     with pytest.raises(FileNotFoundError):
         file_validation("", "", ",")
+
+def test_parser(monkeypatch):
+    """Test that parse_arguments rejects an invalid command line."""
+    monkeypatch.setattr("sys.argv", ["cli.py", "--mean", "string"])
+    with pytest.raises((SystemExit, argparse.ArgumentError)):
+        parse_arguments()
\ No newline at end of file
diff --git a/tests/test_frag.py b/tests/test_frag.py
index d97b176..ddbfe98 100644
--- a/tests/test_frag.py
+++ b/tests/test_frag.py
@@ -1,13 +1,19 @@
 """Test utils.py functions."""
 import pytest
 from Bio import SeqIO
+import pandas as pd 
+import random
 
-from term_frag_sel.fragmentation import fragmentation  # type: ignore
+from term_frag_sel.fragmentation import fragmentation, get_cut_number  # type: ignore
 
-with open("tests/test_files/test.fasta", "r", encoding="utf-8") as handle:
+with open("test_files/test.fasta", "r", encoding="utf-8") as handle:
     fasta = SeqIO.parse(handle, "fasta")
-
-# have to create the counts file
+    
+# Give every record in the test FASTA a random count between 50 and 150.
+seq_counts = {}
+with open("test_files/test.fasta", "r", encoding="utf-8") as counts_handle:
+    for record in SeqIO.parse(counts_handle, "fasta"):
+        seq_counts[record.id] = random.randint(50, 150)
 
 MU = 100
 STD = 10
@@ -16,8 +22,38 @@ C_PROB = 0.28
 T_PROB = 0.25
 G_PROB = 0.25
 
-
-def test_frag():
-    """Test fragmentation function."""
+def test_get_cut_number():
+    """Test get_cut_number function."""
+    assert 0 <= get_cut_number(100, 20, 4) <= 15  # both bounds, one draw
+
+    with pytest.raises(ValueError):
+        get_cut_number(seq_len='a', mean=4, std=2.3)  # non-numeric seq_len
+    with pytest.raises(ValueError):
+        get_cut_number(seq_len=10, mean='a', std=2.3)  # non-numeric mean
+    with pytest.raises(ValueError):
+        get_cut_number(seq_len=10, mean=3, std='a')  # non-numeric std
     with pytest.raises(TypeError):
-        fragmentation()
+        get_cut_number(seq_len=10, mean=3)  # missing std argument
+
+def test_fragmentation():
+    """Test fragmentation function."""
+    assert isinstance(fragmentation(fasta, seq_counts), list)
+
+    with pytest.raises(ValueError):  # non-numeric mean_length
+        fragmentation(fasta, seq_counts, mean_length='a', std=10,
+                      A_prob=0.22, T_prob=0.25, G_prob=0.25, C_prob=0.28)
+    with pytest.raises(ValueError):  # non-numeric std
+        fragmentation(fasta, seq_counts, mean_length=100, std='a',
+                      A_prob=0.22, T_prob=0.25, G_prob=0.25, C_prob=0.28)
+    with pytest.raises(ValueError):  # non-numeric A_prob
+        fragmentation(fasta, seq_counts, mean_length=100, std=10,
+                      A_prob='a', T_prob=0.25, G_prob=0.25, C_prob=0.28)
+    with pytest.raises(ValueError):  # non-numeric T_prob
+        fragmentation(fasta, seq_counts, mean_length=100, std=10,
+                      A_prob=0.22, T_prob='a', G_prob=0.25, C_prob=0.28)
+    with pytest.raises(ValueError):  # non-numeric G_prob
+        fragmentation(fasta, seq_counts, mean_length=100, std=10,
+                      A_prob=0.22, T_prob=0.25, G_prob='a', C_prob=0.28)
+    with pytest.raises(ValueError):  # non-numeric C_prob
+        fragmentation(fasta, seq_counts, mean_length=100, std=10,
+                      A_prob=0.22, T_prob=0.25, G_prob=0.25, C_prob='a')
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 36084a6..00c8ed7 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,11 +1,7 @@
 """Test utils.py functions."""
-<<<<<<< HEAD
-import pytest
 
-from 
-=======
-import argparse
 import pytest
+import argparse
 
 from term_frag_sel.utils import check_positive, check_prob  # type: ignore
 
@@ -23,7 +19,7 @@ def test_positive():
     with pytest.raises(argparse.ArgumentTypeError):
         check_positive("string")
     with pytest.raises(argparse.ArgumentTypeError):
-        check_positive("")
+        check_positive("") 
 
 
 def test_prob():
@@ -40,4 +36,3 @@ def test_prob():
         check_prob("string")
     with pytest.raises(ValueError):
         check_prob("")
->>>>>>> hugo_new
-- 
GitLab