import re

import numpy as np
import pandas as pd


def fasta_process(fasta_file):
    """Parse a FASTA file into a mapping of sequence ID to sequence.

    The ID of a record is the first whitespace-delimited token that follows
    ``>`` on its header line; anything after it on that line is ignored.
    Sequence lines are concatenated (with all whitespace removed) until the
    next header. Records that share an ID are concatenated under that ID,
    and a header that is not followed by any sequence line produces no
    entry. Blank lines are skipped.

    Args:
        fasta_file: Path of the FASTA file to read.

    Returns:
        dict: sequence ID (str) -> full sequence (str), in order of first
        appearance.

    Raises:
        ValueError: If a header line carries no ID, or if sequence data
            appears before the first header.
    """
    header_pattern = re.compile(r">(\S+)")

    # Collect per-ID pieces and join once at the end instead of growing
    # strings with repeated concatenation.
    pieces = {}
    seq_id = None
    with open(fasta_file, "r") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue  # blank lines are common at the end of FASTA files

            if line.startswith(">"):
                match = header_pattern.match(line)
                if match is None:
                    raise ValueError(
                        f"{fasta_file}: line {line_no}: header has no sequence ID"
                    )
                seq_id = match.group(1)
                continue

            if seq_id is None:
                raise ValueError(
                    f"{fasta_file}: line {line_no}: sequence data before first header"
                )
            # split/join drops any whitespace embedded in the sequence line.
            pieces.setdefault(seq_id, []).append("".join(line.split()))

    return {key: "".join(parts) for key, parts in pieces.items()}

def _draw_cut_sites(positions_by_nuc, nucs, freqs, n_cuts):
    """Draw up to ``n_cuts`` distinct cut positions and return them sorted.

    For each cut a nucleotide is drawn according to ``freqs`` and then one
    not-yet-used position of that nucleotide is picked uniformly. A draw is
    skipped when the chosen nucleotide has no unused position left, so fewer
    than ``n_cuts`` positions may be returned.
    """
    if n_cuts <= 0:
        return []

    # Work on copies: positions are consumed as they are chosen.
    pools = {nuc: list(positions_by_nuc[nuc]) for nuc in nucs}
    cuts = []
    for nuc in np.random.choice(nucs, n_cuts, p=freqs):
        pool = pools[str(nuc)]
        if not pool:
            continue  # nucleotide absent or all of its positions already cut
        # Swap-and-pop removes a uniformly chosen element in O(1).
        idx = np.random.randint(len(pool))
        pool[idx], pool[-1] = pool[-1], pool[idx]
        cuts.append(pool.pop())
    cuts.sort()
    return cuts


def fragmentation(fasta_file, counts_file, mean_length, std):
    """Simulate random fragmentation and collect size-selected fragments.

    Every sequence in ``fasta_file`` is fragmented once per copy listed for
    its ID in ``counts_file``. Each copy receives
    ``int(len(seq) / mean_length)`` cuts at nucleotide-biased random
    positions. Among the resulting fragments (including the tail after the
    last cut), the last one along the sequence whose length lies within
    ``[mean_length - std, mean_length + std]`` is kept; a copy with no such
    fragment contributes nothing.

    Args:
        fasta_file: Path of the FASTA file with the sequences.
        counts_file: Path of a header-less two-column CSV (sequence ID,
            copy count).
        mean_length: Target fragment length; must be positive.
        std: Accepted deviation from ``mean_length``.

    Returns:
        list[str]: One selected fragment per copy that produced one.

    Raises:
        ValueError: If ``mean_length`` is not positive.
    """
    if mean_length <= 0:
        raise ValueError("mean_length must be positive")

    fasta = fasta_process(fasta_file)
    # Read IDs as strings so numeric-looking IDs still match the FASTA keys.
    seq_counts = pd.read_csv(
        counts_file, names=["seqID", "count"], dtype={"seqID": str}
    )

    nucs = ['A', 'T', 'G', 'C']
    # Per-nucleotide cut weights, in the order of ``nucs``.
    # NOTE(review): presumably derived from the paper cited below -- confirm.
    mononuc_freqs = [0.22, 0.25, 0.23, 0.30]

    min_len = mean_length - std
    max_len = mean_length + std

    term_frags = []
    for seq_id, seq in fasta.items():
        # Sum over matching rows; an ID missing from the counts file sums to 0.
        n_copies = int(seq_counts.loc[seq_counts["seqID"] == seq_id, "count"].sum())
        if n_copies <= 0:
            continue

        n_cuts = int(len(seq) / mean_length)

        # Nucleotide positions are the same for every copy, so compute them
        # once per sequence. Matching is done on an upper-cased copy so
        # lowercase (e.g. soft-masked) sequences are handled too.
        upper_seq = seq.upper()
        positions_by_nuc = {
            nuc: [m.start() for m in re.finditer(nuc, upper_seq)] for nuc in nucs
        }

        for _ in range(n_copies):
            # non-uniformly random DNA fragmentation implementation based on
            # https://www.nature.com/articles/srep04532#Sec1
            # assume fragmentation by sonication for NGS workflow
            cuts = _draw_cut_sites(positions_by_nuc, nucs, mononuc_freqs, n_cuts)

            # Fragment boundaries: sequence start, every cut, sequence end.
            bounds = [0] + cuts + [len(seq)]
            term_frag = ""
            for start, end in zip(bounds, bounds[1:]):
                fragment = seq[start:end]
                if min_len <= len(fragment) <= max_len:
                    term_frag = fragment  # later fragments override earlier ones
            if term_frag:
                term_frags.append(term_frag)
    return term_frags