diff --git a/term_frag_sel/cli.py b/term_frag_sel/cli.py index a9b8e957a60f3a64c33e2e15cd10de951b887850..2495d85e64b8a03076ba8423bff2595658088975 100644 --- a/term_frag_sel/cli.py +++ b/term_frag_sel/cli.py @@ -35,8 +35,8 @@ def main(args: argparse.Namespace): fasta_dict = fasta_parse[split:splits[i+1]] term_frags = fragmentation(fasta_dict, seq_counts, args.mean, args.std, - args.A_prob, args.T_prob, - args.G_prob, args.C_prob) + args.a_prob, args.t_prob, + args.g_prob, args.c_prob) logger.info("Writing batch %s sequences to %s...", i, args.output) with open(args.output, 'a', encoding="utf-8") as out_file: @@ -97,21 +97,21 @@ def parse_arguments() -> argparse.Namespace: help="output file path") parser.add_argument('--mean', required=False, default=300, type=check_positive, - help="Mean fragment length (default: 10)") + help="Mean fragment length (default: 300)") parser.add_argument('--std', required=False, default=60, type=check_positive, help="Standard deviation fragment length \ - (defafult: 1)") - parser.add_argument('-a', '--A_prob', required=False, default=0.22, + (default: 60)") + parser.add_argument('-a', '--a_prob', required=False, default=0.22, type=check_prob, help="Probability cut happens after nucleotide A") - parser.add_argument('-t', '--T_prob', required=False, default=0.25, + parser.add_argument('-t', '--t_prob', required=False, default=0.25, type=check_prob, help="Probability cut happens after nucleotide T") - parser.add_argument('-g', '--G_prob', required=False, default=0.25, + parser.add_argument('-g', '--g_prob', required=False, default=0.25, type=check_prob, help="Probability cut happens after nucleotide G") - parser.add_argument('-c', '--C_prob', required=False, default=0.28, + parser.add_argument('-c', '--c_prob', required=False, default=0.28, type=check_prob, help="Probability cut happens after nucleotide C") parser.add_argument('-s', '--size', required=False, default=10000, diff --git a/term_frag_sel/fragmentation.py 
b/term_frag_sel/fragmentation.py index 302fe6f41a8602688af4abcf8681ef830e91ba48..7444ff8b78c1f2cc51410030b52c61aa7298f5c0 100644 --- a/term_frag_sel/fragmentation.py +++ b/term_frag_sel/fragmentation.py @@ -4,6 +4,38 @@ import re import numpy as np import pandas as pd # type: ignore +def get_cut_number(seq_len: int, mean: float) -> int: + """Get the number of cuts for a particular sequence. + + Args: + seq_len (int): length of sequence/number of nucleotides in the sequence + mean (float): mean fragment length + + Returns: + int: number of cuts + """ + cuts_distribution = [] # distribution of cut numbers (n_cuts) + + # 1000 iterations should not be too computationally inefficient given the nature of the code + for _ in range(1000): + n_cuts = 0 + len_sum = 0 + while True: + len_sum += np.random.exponential(scale = mean) + if len_sum < seq_len: + n_cuts += 1 + else: + cuts_distribution.append(n_cuts) + break + + cuts_distribution.sort() + cut_counts = {x:cuts_distribution.count(x) for x in cuts_distribution} + cut_probs = [x/1000 for x in cut_counts.values()] + + # randomly pick no. of cuts from cut distribution based on probability of cut numbers + n_cuts = np.random.choice(list(cut_counts.keys()), p = cut_probs) + + return n_cuts def fragmentation(fasta: dict, seq_counts: pd.DataFrame, mean_length: int = 100, std: int = 10, @@ -32,7 +64,10 @@ def fragmentation(fasta: dict, seq_counts: pd.DataFrame, for seq_id, seq in fasta.items(): counts = seq_counts[seq_counts["seqID"] == seq_id]["count"] for _ in range(counts): - n_cuts = int(len(seq)/mean_length) + seq_len = len(seq) + + # pick no. 
of cuts from exponential fragment length distribution + n_cuts = get_cut_number(seq_len, mean_length) # non-uniformly random DNA fragmentation implementation based on # https://www.nature.com/articles/srep04532#Sec1 @@ -50,13 +85,12 @@ def fragmentation(fasta: dict, seq_counts: pd.DataFrame, cuts.sort() cuts.insert(0, 0) term_frag = "" - for i, val in enumerate(cuts): - if i == len(cuts)-1: - fragment = seq[val+1:cuts[-1]] - else: - fragment = seq[val:cuts[i+1]] - if mean_length-std <= len(fragment) <= mean_length+std: - term_frag = fragment + + # check if 3' fragment is in the correct size range + fragment = seq[cuts[-1]:len(seq)] + if mean_length-std <= len(fragment) <= mean_length+std: + term_frag = fragment + if term_frag != "": term_frags.append(term_frag) diff --git a/tests/test_frag.py b/tests/test_frag.py index d97b176054f801a93796832ffc54c324b4076046..a90eaf620f739c062d9b4c5557790dceb77e80ba 100644 --- a/tests/test_frag.py +++ b/tests/test_frag.py @@ -1,13 +1,17 @@ """Test utils.py functions.""" +import random import pytest from Bio import SeqIO -from term_frag_sel.fragmentation import fragmentation # type: ignore +from term_frag_sel.fragmentation import fragmentation, get_cut_number # type: ignore with open("tests/test_files/test.fasta", "r", encoding="utf-8") as handle: fasta = SeqIO.parse(handle, "fasta") -# have to create the counts file +seq_counts = {} +for entry in fasta: + name, sequence = entry.id, str(entry.seq) + seq_counts[name] = random.randint(50,150) MU = 100 STD = 10 @@ -16,8 +20,38 @@ C_PROB = 0.28 T_PROB = 0.25 G_PROB = 0.25 +def test_get_cut_number(): + """Test get_cut_number function.""" + assert get_cut_number(100,20,4) <= 15 and get_cut_number(100,20,4) >= 0 -def test_frag(): - """Test fragmentation function.""" + with pytest.raises(ValueError): + get_cut_number(seq_len = 'a', mean = 4, std = 2.3) # error if seq_len not int + with pytest.raises(ValueError): + get_cut_number(seq_len = 10, mean = 'a', std = 2.3) # error if mean not int + with 
pytest.raises(ValueError): + get_cut_number(seq_len = 10, mean = 3, std = 'a') # error if std not float with pytest.raises(TypeError): - fragmentation() + get_cut_number(seq_len = 10, mean = 3) # error if there is a missing argument + +def test_fragmentation(): + """Test fragmentation function.""" + assert isinstance(fragmentation(fasta, seq_counts), list) + + with pytest.raises(ValueError): + fragmentation(fasta, seq_counts, mean_length = 'a', std = 10, + a_prob = 0.22, t_prob = 0.25,g_prob = 0.25, c_prob = 0.28) + with pytest.raises(ValueError): + fragmentation(fasta, seq_counts, mean_length = 100, std = 'a', + a_prob = 0.22, t_prob = 0.25,g_prob = 0.25, c_prob = 0.28) + with pytest.raises(ValueError): + fragmentation(fasta, seq_counts, mean_length = 100, std = 10, + a_prob = 'a', t_prob = 0.25,g_prob = 0.25, c_prob = 0.28) + with pytest.raises(ValueError): + fragmentation(fasta, seq_counts, mean_length = 100, std = 10, + a_prob = 0.22, t_prob = 'a',g_prob = 0.25, c_prob = 0.28) + with pytest.raises(ValueError): + fragmentation(fasta, seq_counts, mean_length = 100, std = 10, + a_prob = 0.22, t_prob = 0.25,g_prob = 'a', c_prob = 0.28) + with pytest.raises(ValueError): + fragmentation(fasta, seq_counts, mean_length = 100, std = 10, + a_prob = 0.22, t_prob = 0.25,g_prob = 0.25, c_prob = 'a') diff --git a/tests/test_utils.py b/tests/test_utils.py index 36084a66eb91d452b0bfd525b3e8b91c786f2b81..c7d7be2d3fb43f5511322d1bdfe2d4107ecc38ae 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,9 +1,4 @@ """Test utils.py functions.""" -<<<<<<< HEAD -import pytest - -from -======= import argparse import pytest @@ -40,4 +35,3 @@ def test_prob(): check_prob("string") with pytest.raises(ValueError): check_prob("") ->>>>>>> hugo_new