diff --git a/term_frag_sel/cli.py b/term_frag_sel/cli.py index 93fafbb197c9a582e0281f33e443163bb0fca26e..67b9f7d554b0f2376955a2c77338d2b726cce1e4 100644 --- a/term_frag_sel/cli.py +++ b/term_frag_sel/cli.py @@ -8,7 +8,7 @@ import pandas as pd # type: ignore from term_frag_sel.fragmentation import fragmentation -from term_frag_sel.utils import check_positive, check_prob +from term_frag_sel.utils import check_positive logging.basicConfig( format='[%(asctime)s: %(levelname)s] %(message)s \ @@ -38,14 +38,10 @@ def main(args: argparse.Namespace): logger.info("Fragmentation of %s...", args.fasta) splits = np.arange(0, len(list(fasta))+args.size, args.size) - nuc_probs = {'A': args.a_prob, 'T': args.t_prob, - 'G': args.g_prob, 'C': args.c_prob} for i, split in enumerate(splits): fasta_dict = fasta[split:splits[i+1]] term_frags = fragmentation(fasta_dict, seq_counts, - nuc_probs, - args.mean, args.std, - ) + args.mean, args.std) logger.info("Writing batch %s sequences to %s...", i, args.output) with open(args.output, 'a', encoding="utf-8") as out_file: @@ -119,18 +115,6 @@ def parse_arguments() -> argparse.Namespace: type=check_positive, help="Standard deviation fragment length \ (defafult: 60)") - parser.add_argument('-a', '--a_prob', required=False, default=0.22, - type=check_prob, - help="Probability cut happens after nucleotide A") - parser.add_argument('-t', '--t_prob', required=False, default=0.25, - type=check_prob, - help="Probability cut happens after nucleotide T") - parser.add_argument('-g', '--g_prob', required=False, default=0.25, - type=check_prob, - help="Probability cut happens after nucleotide G") - parser.add_argument('-c', '--c_prob', required=False, default=0.28, - type=check_prob, - help="Probability cut happens after nucleotide C") parser.add_argument('-s', '--size', required=False, default=10000, type=check_positive, help="Chunk size for batch processing") diff --git a/term_frag_sel/fragmentation.py b/term_frag_sel/fragmentation.py index ef347cab82a652ede9f1ea3d2cd57f27920bb0cd..9cd8a0ff760a81bc17cb6673ec80dcb932a0d290 100644 --- a/term_frag_sel/fragmentation.py +++ b/term_frag_sel/fragmentation.py @@ -1,98 +1,46 @@ """Fragment sequences.""" -import re - -import numpy as np +import random import pandas as pd # type: ignore -def get_cut_number(seq_len: int, mean: float) -> int: - """Get the number of cuts for a particular sequence. - - Args: - seq_len (int): length of sequence/number of nucleotides in the sequence - mean (float): mean fragment length - - Returns: - int: number of cuts - """ - if not isinstance(seq_len, int): - raise ValueError(f"Sequence length must be numeric, \ - not {type(seq_len)}") - - if not isinstance(mean, int): - raise ValueError(f"Mean must be numeric, not {type(mean)}") - - cuts_distribution = [] # distribution of cut numbers (n_cuts) - - # 1000 iterations should not be too computationally inefficient - # given the nature of the code - for _ in range(1000): - n_cuts = 0 - len_sum = 0 - while True: - len_sum += np.random.exponential(scale=mean) - if len_sum < seq_len: - n_cuts += 1 - else: - cuts_distribution.append(n_cuts) - break - - cuts_distribution.sort() - cut_counts = {x: cuts_distribution.count(x) for x in cuts_distribution} - cut_probs = [x/1000 for x in cut_counts.values()] - - # randomly ick no. of cut from cut distribution based on probability - # of cut numbers - n_cuts = np.random.choice(list(cut_counts.keys()), p=cut_probs) - - return n_cuts - - def fragmentation(fasta: dict, seq_counts: pd.DataFrame, - nuc_probs: dict, - mu_length: int = 100, std: int = 10, - ) -> list: + mean: int = 100, std: int = 10) -> list: """Fragment cDNA sequences and select terminal fragment. Args: fasta_file (dict): dictionary of {transcript IDs: sequences} counts_file (pd.DataFrame): dataframe with sequence counts and IDs - nuc_probs (dict): probability of cut occuring a certain nucleotide. \ - Ordered as A, T, G, C. E.g: {'A': 0.22, 'T': 0.25, \ - 'G': 0.25, 'C': 0.28}. - mu_length (int): mean length of desired fragments + mean (int): mean length of desired fragments std (int): standard deviation of desired fragment lengths Returns: list: list of selected terminal fragments """ - # calculated using https://www.nature.com/articles/srep04532#MOESM1 + if not isinstance(mean, int): + raise ValueError(f"Mean must be numeric, not {type(mean)}") + + if not isinstance(std, int): + raise ValueError(f"Std must be numeric, not {type(mean)}") + term_frags = [] for seq_id, seq in fasta.items(): counts = seq_counts[seq_counts["seqID"] == seq_id]["count"] for _ in range(counts.iloc[0]): - # pick no. of cuts from gauss fragment length distribution - # non-uniformly random DNA fragmentation implementation based on - # https://www.nature.com/articles/srep04532#Sec1 - # assume fragmentation by sonication for NGS workflow cuts = [] - cut_nucs = np.random.choice(list(nuc_probs.keys()), - size=get_cut_number(len(seq), - mu_length), - p=list(nuc_probs.values())) - for nuc in cut_nucs: - nuc_pos = [x.start() for x in re.finditer(nuc, seq)] - pos = np.random.choice(nuc_pos) - while pos in cuts: - pos = np.random.choice(nuc_pos) - cuts.append(pos) + seq_len = len(seq) + prob_cut_per_base = 1/mean + + for i in range(seq_len): + rand_prob = random.uniform(0, 1) + if rand_prob < prob_cut_per_base: + cuts.append(i) cuts.sort() cuts.insert(0, 0) term_frag = "" # check if 3' fragment is in the correct size range - if mu_length-std <= len(seq[cuts[-1]:len(seq)]) <= mu_length+std: + if mean - std <= len(seq[cuts[-1]:len(seq)]) <= mean + std: term_frag = seq[cuts[-1]:len(seq)] if term_frag != "": diff --git a/term_frag_sel/utils.py b/term_frag_sel/utils.py index d77cc1497952b11b89f9754aad33e93dd5d78965..a565dcefe272521ce141f5925f047cf7c761d51b 100644 --- a/term_frag_sel/utils.py +++ b/term_frag_sel/utils.py @@ -25,22 +25,3 @@ def check_positive(value: str) -> int: got: {value}""") from exc else: return ivalue - - -def check_prob(value: str) -> float: - """Check probability value is within ]0,1] range. - - Args: - value (str): command line parameter - - Raises: - argparse.ArgumentTypeError: received a value outside valid range - - Returns: - float: float version of input value - """ - pvalue = float(value) - if pvalue <= 0 or pvalue > 1: - raise argparse.ArgumentTypeError("""Expected a positive float between - 0 and 1, but got {value}""") - return pvalue diff --git a/tests/test_frag.py b/tests/test_frag.py index d6eaccd78626a6129fe4f1c4a05a9e66da43328e..da2d5bb0be820bfb92250c6fa581da98956eec60 100644 --- a/tests/test_frag.py +++ b/tests/test_frag.py @@ -3,7 +3,7 @@ import pandas as pd import pytest from Bio import SeqIO -from term_frag_sel.fragmentation import fragmentation, get_cut_number +from term_frag_sel.fragmentation import fragmentation FASTA_FILE = "tests/test_files/test.fasta" CSV_FILE = "tests/test_files/test.csv" @@ -17,38 +17,13 @@ with open(FASTA_FILE, "r", encoding="utf-8") as handle: seq_counts = pd.read_csv(CSV_FILE, names=["seqID", "count"]) seq_counts = seq_counts.astype({"seqID": str}) -NUC = {'A': 0.22, 'T': 0.25, 'G': 0.25, 'C': 0.28} - - -def test_get_cut_number(): - """Test get_cut_number function.""" - assert get_cut_number(100, 20) <= 15 and get_cut_number(100, 20) >= 0 - - with pytest.raises(ValueError): - get_cut_number(seq_len='a', mean=4) - with pytest.raises(ValueError): - get_cut_number(seq_len=10, mean='a') - def test_fragmentation(): """Test fragmentation function.""" - assert isinstance(fragmentation(fasta_dict, seq_counts, - NUC), list) + assert isinstance(fragmentation(fasta_dict, seq_counts), list) - # no need to check string mean or std since it's checked at CLI - with pytest.raises(ValueError): - nuc_probs = {'A': 'a', 'T': 0.25, - 'G': 0.25, 'C': 0.28} - fragmentation(fasta_dict, seq_counts, nuc_probs) with pytest.raises(ValueError): - nuc_probs = {'A': 0.22, 'T': 'a', - 'G': 0.25, 'C': 0.28} - fragmentation(fasta_dict, seq_counts, nuc_probs) - with pytest.raises(ValueError): - nuc_probs = {'A': 0.22, 'T': 0.25, - 'G': 'a', 'C': 0.28} - fragmentation(fasta_dict, seq_counts, nuc_probs) + fragmentation(fasta_dict, seq_counts, mean='a') + with pytest.raises(ValueError): - nuc_probs = {'A': 0.22, 'T': 0.25, - 'G': 0.25, 'C': 'a'} - fragmentation(fasta_dict, seq_counts, nuc_probs) + fragmentation(fasta_dict, seq_counts, std='a') diff --git a/tests/test_utils.py b/tests/test_utils.py index c7d7be2d3fb43f5511322d1bdfe2d4107ecc38ae..32cf347c7924c5b881622bcdab509d34d97f414b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ import argparse import pytest -from term_frag_sel.utils import check_positive, check_prob # type: ignore +from term_frag_sel.utils import check_positive # type: ignore def test_positive(): @@ -19,19 +19,3 @@ def test_positive(): check_positive("string") with pytest.raises(argparse.ArgumentTypeError): check_positive("") - - -def test_prob(): - """Test check_prob function.""" - assert check_prob("0.1") == 0.1 - assert check_prob("1") == 1.0 - with pytest.raises(argparse.ArgumentTypeError): - check_prob("0") - with pytest.raises(argparse.ArgumentTypeError): - check_prob("10") - with pytest.raises(argparse.ArgumentTypeError): - check_prob("-1") - with pytest.raises(ValueError): - check_prob("string") - with pytest.raises(ValueError): - check_prob("")