From 551c2056c7cff2a2ebd4404d0cf5faff2bfbedff Mon Sep 17 00:00:00 2001
From: "hugo.madgeleon" <hugo.madgeleon@stud.unibas.ch>
Date: Sun, 11 Dec 2022 19:52:22 +0100
Subject: [PATCH] tanya's changes

---
 term_frag_sel/cli.py           | 16 +++++------
 term_frag_sel/fragmentation.py | 50 ++++++++++++++++++++++++++++------
 tests/test_frag.py             | 44 ++++++++++++++++++++++++++----
 tests/test_utils.py            |  6 ----
 4 files changed, 89 insertions(+), 27 deletions(-)

diff --git a/term_frag_sel/cli.py b/term_frag_sel/cli.py
index a9b8e95..2495d85 100644
--- a/term_frag_sel/cli.py
+++ b/term_frag_sel/cli.py
@@ -35,8 +35,8 @@ def main(args: argparse.Namespace):
         fasta_dict = fasta_parse[split:splits[i+1]]
         term_frags = fragmentation(fasta_dict, seq_counts,
                                    args.mean, args.std,
-                                   args.A_prob, args.T_prob,
-                                   args.G_prob, args.C_prob)
+                                   args.a_prob, args.t_prob,
+                                   args.g_prob, args.c_prob)
 
         logger.info("Writing batch %s sequences to %s...", i, args.output)
         with open(args.output, 'a', encoding="utf-8") as out_file:
@@ -97,21 +97,21 @@ def parse_arguments() -> argparse.Namespace:
                         help="output file path")
     parser.add_argument('--mean', required=False, default=300,
                         type=check_positive,
-                        help="Mean fragment length (default: 10)")
+                        help="Mean fragment length (default: 300)")
     parser.add_argument('--std', required=False, default=60,
                         type=check_positive,
                         help="Standard deviation fragment length \
-                            (defafult: 1)")
-    parser.add_argument('-a', '--A_prob', required=False, default=0.22,
+                            (defafult: 60)")
+    parser.add_argument('-a', '--a_prob', required=False, default=0.22,
                         type=check_prob,
                         help="Probability cut happens after nucleotide A")
-    parser.add_argument('-t', '--T_prob', required=False, default=0.25,
+    parser.add_argument('-t', '--t_prob', required=False, default=0.25,
                         type=check_prob,
                         help="Probability cut happens after nucleotide T")
-    parser.add_argument('-g', '--G_prob', required=False, default=0.25,
+    parser.add_argument('-g', '--g_prob', required=False, default=0.25,
                         type=check_prob,
                         help="Probability cut happens after nucleotide G")
-    parser.add_argument('-c', '--C_prob', required=False, default=0.28,
+    parser.add_argument('-c', '--c_prob', required=False, default=0.28,
                         type=check_prob,
                         help="Probability cut happens after nucleotide C")
     parser.add_argument('-s', '--size', required=False, default=10000,
diff --git a/term_frag_sel/fragmentation.py b/term_frag_sel/fragmentation.py
index 302fe6f..7444ff8 100644
--- a/term_frag_sel/fragmentation.py
+++ b/term_frag_sel/fragmentation.py
@@ -4,6 +4,38 @@ import re
 import numpy as np
 import pandas as pd  # type: ignore
 
+def get_cut_number(seq_len: int, mean: float) -> int:
+    """Draw a random number of cuts for one sequence.
+
+    Fragment lengths are modelled as exponentially distributed with the
+    given mean, so the number of cut sites that fall inside a sequence of
+    seq_len nucleotides follows a Poisson distribution with rate
+    seq_len / mean.  That distribution is sampled directly, which yields
+    the same distribution of results as simulating exponential fragment
+    lengths 1000 times and picking from the simulated cut counts, at the
+    cost of a single random draw per call.
+
+    Args:
+        seq_len (int): length of sequence/number of nucleotides in the sequence
+        mean (float): mean fragment length, must be greater than zero
+
+    Returns:
+        int: number of cuts
+
+    Raises:
+        ValueError: if mean is zero or negative
+    """
+
+    # a zero mean would otherwise end in a division by zero (and made the
+    # former sampling loop spin forever), so reject it explicitly
+    if mean <= 0:
+        raise ValueError(f"mean fragment length must be positive, got {mean}")
+
+    # cuts are a Poisson process along the sequence, so one draw is enough
+    n_cuts = np.random.poisson(lam=seq_len / mean)
+
+    # np.random.poisson returns a numpy integer; hand back a built-in int
+    return int(n_cuts)
 
 def fragmentation(fasta: dict, seq_counts: pd.DataFrame,
                   mean_length: int = 100, std: int = 10,
@@ -32,7 +64,10 @@ def fragmentation(fasta: dict, seq_counts: pd.DataFrame,
     for seq_id, seq in fasta.items():
         counts = seq_counts[seq_counts["seqID"] == seq_id]["count"]
         for _ in range(counts):
-            n_cuts = int(len(seq)/mean_length)
+            seq_len = len(seq)
+
+            # draw the no. of cuts (exponential fragment lengths, mean = mean_length)
+            n_cuts = get_cut_number(seq_len, mean_length)
 
             # non-uniformly random DNA fragmentation implementation based on
             # https://www.nature.com/articles/srep04532#Sec1
@@ -50,13 +85,12 @@ def fragmentation(fasta: dict, seq_counts: pd.DataFrame,
             cuts.sort()
             cuts.insert(0, 0)
             term_frag = ""
-            for i, val in enumerate(cuts):
-                if i == len(cuts)-1:
-                    fragment = seq[val+1:cuts[-1]]
-                else:
-                    fragment = seq[val:cuts[i+1]]
-                if mean_length-std <= len(fragment) <= mean_length+std:
-                    term_frag = fragment
+
+            # check if 3' fragment is in the correct size range
+            fragment = seq[cuts[-1]:len(seq)]
+            if mean_length-std <= len(fragment) <= mean_length+std:
+                term_frag = fragment
+
             if term_frag != "":
                 term_frags.append(term_frag)
 
diff --git a/tests/test_frag.py b/tests/test_frag.py
index d97b176..a90eaf6 100644
--- a/tests/test_frag.py
+++ b/tests/test_frag.py
@@ -1,13 +1,17 @@
 """Test utils.py functions."""
+import random
 import pytest
 from Bio import SeqIO
 
-from term_frag_sel.fragmentation import fragmentation  # type: ignore
+from term_frag_sel.fragmentation import fragmentation, get_cut_number  # type: ignore
 
 with open("tests/test_files/test.fasta", "r", encoding="utf-8") as handle:
     fasta = SeqIO.parse(handle, "fasta")
 
-# have to create the counts file
+    # SeqIO.parse is lazy, so consume it while the file handle is still open
+    fasta = {entry.id: str(entry.seq) for entry in fasta}
+# NOTE(review): fragmentation() filters seq_counts["seqID"]; confirm a plain dict is accepted
+seq_counts = {name: random.randint(50, 150) for name in fasta}
 
 MU = 100
 STD = 10
@@ -16,8 +20,38 @@ C_PROB = 0.28
 T_PROB = 0.25
 G_PROB = 0.25
 
+def test_get_cut_number():
+    """Test get_cut_number function."""
+    assert 0 <= get_cut_number(100, 20) <= 100  # signature is (seq_len, mean) only
 
-def test_frag():
-    """Test fragmentation function."""
+    with pytest.raises((TypeError, ValueError)):
+        get_cut_number(seq_len = 'a', mean = 4) # error if seq_len is not a number
+    with pytest.raises((TypeError, ValueError)):
+        get_cut_number(seq_len = 10, mean = 'a') # error if mean is not a number
+    with pytest.raises(TypeError):
+        get_cut_number(seq_len = 10, mean = 3, std = 2.3) # error: std is not a parameter
     with pytest.raises(TypeError):
-        fragmentation()
+        get_cut_number(seq_len = 10) # error if there is a missing argument
+
+def test_fragmentation():
+    """Check fragmentation's return type and its rejection of bad arguments.
+
+    A call with default parameters must return a list.  Afterwards each
+    numeric keyword argument is replaced in turn by the string 'a' (all
+    others keep valid values) and the call must raise ValueError.
+    """
+    result = fragmentation(fasta, seq_counts)
+    assert isinstance(result, list)
+
+    # valid baseline; insertion order fixes the order the checks run in
+    valid_kwargs = {
+        "mean_length": 100,
+        "std": 10,
+        "a_prob": 0.22,
+        "t_prob": 0.25,
+        "g_prob": 0.25,
+        "c_prob": 0.28,
+    }
+    for bad_key in valid_kwargs:
+        with pytest.raises(ValueError):
+            fragmentation(fasta, seq_counts, **{**valid_kwargs, bad_key: 'a'})
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 36084a6..c7d7be2 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,9 +1,4 @@
 """Test utils.py functions."""
-<<<<<<< HEAD
-import pytest
-
-from 
-=======
 import argparse
 import pytest
 
@@ -40,4 +35,3 @@ def test_prob():
         check_prob("string")
     with pytest.raises(ValueError):
         check_prob("")
->>>>>>> hugo_new
-- 
GitLab