"""Tests the transcript structure generator module."""

import pytest
from transcript_structure import Generate_transcript_structure as Gts

# Directory holding every fixture used by this test module.
_RESOURCE_DIR = './tests/resources/test_transcript_structure'

# Gene-count csv fixtures: one with a title row, one without.
TEST_CSV_TITLE = _RESOURCE_DIR + '/Rik_5_Rp1_5_title.csv'
TEST_CSV_NO_TITLE = _RESOURCE_DIR + '/Rik_5_Rp1_5_no_title.csv'

# GTF fixture with the coordinates of the genes listed in GENE_KEYS.
GENE_COORDS = _RESOURCE_DIR + '/RP1_RIK.gtf'
GENE_KEYS = ['Rp1', '1700034P13Rik']

# Values passed as the third BuildTranscriptStructure argument
# (presumably an intron-retention probability — confirm in the builder).
P_INTRON_0: float = 0
P_INTRON_0_2: float = 0.2
P_INTRON_1: float = 1


@pytest.mark.parametrize(
        "test_input",
        [(TEST_CSV_TITLE, GENE_COORDS, P_INTRON_0),
         (TEST_CSV_NO_TITLE, GENE_COORDS, P_INTRON_0)
         ],
)
def test_csv_2_dict(test_input):
    """Check that the gene-count csv is parsed into ``gene_count_dict``.

    Every non-blank data row of the csv (``gene,count``) must appear in the
    dictionary, in file order, and an optional title row must be ignored.
    """
    csv_path, gtf_path, p_intron = test_input
    builder = Gts.BuildTranscriptStructure(csv_path, gtf_path, p_intron)
    # Double-underscore methods are name-mangled inside the class body, so
    # from module level they are only reachable under the mangled name.
    builder._BuildTranscriptStructure__csv_2_dict()

    # Read the file that was parsed for *this* parametrization, dropping
    # newlines and blank lines so a trailing newline cannot cause a mismatch.
    with open(csv_path) as csv_file:
        csv_lines = [line.rstrip('\n') for line in csv_file if line.strip()]

    # A title row has a non-numeric count column. Strip before testing:
    # "5\n".isnumeric() is False and would misclassify a data row.
    if not csv_lines[0].split(',')[1].strip().isnumeric():
        del csv_lines[0]  # Removes title.

    dict_lines = [f'{gene},{count}' for gene, count in builder.gene_count_dict.items()]
    # Comparing whole lists also catches extra or missing dictionary entries.
    assert csv_lines == dict_lines


def test_gtf_2_dict():
    """Check that the GTF fixture is parsed into ``gene_sequences_dict``.

    Both genes must be present with their five entries, the 1700034P13Rik
    gene line must be stored verbatim, and the number of stored exons must
    match the fixture.
    """
    builder = Gts.BuildTranscriptStructure(TEST_CSV_TITLE, GENE_COORDS, P_INTRON_0)
    # Double-underscore methods are name-mangled inside the class body, so
    # from module level they are only reachable under the mangled name.
    builder._BuildTranscriptStructure__gtf_2_dict()

    assert len(builder.gene_sequences_dict) == len(GENE_KEYS)  # Two genes read in the dictionary.
    assert len(builder.gene_sequences_dict[GENE_KEYS[0]]) == 5
    assert len(builder.gene_sequences_dict[GENE_KEYS[1]]) == 5
    gene_line_rik = ('1\thavana\tgene\t9747648\t9791924\t.\t+\t.\tgene_id "ENSMUSG00000097893"; gene_version "8"; '
                     'gene_name "1700034P13Rik"; gene_source "havana"; gene_biotype "lncRNA";\n')
    assert gene_line_rik == builder.gene_sequences_dict[GENE_KEYS[1]]['gene_line']

    with open(GENE_COORDS) as gtf:
        lines = gtf.readlines()
    # Assumes the fixture holds exactly one gene line and one transcript line
    # per gene in GENE_KEYS, and that every remaining line is an exon.
    numb_exons_gtf = len(lines) - 2 * len(GENE_KEYS)
    numb_exons_dict = sum(
        len(builder.gene_sequences_dict[gene_key]['exon_seq']) for gene_key in GENE_KEYS
    )
    assert numb_exons_gtf == numb_exons_dict


@pytest.mark.parametrize(
        "test_input",
        [(TEST_CSV_TITLE, GENE_COORDS, P_INTRON_0),
         (TEST_CSV_TITLE, GENE_COORDS, P_INTRON_1)
         ],
)
def test_make_new_transcripts(test_input):
    """Check the transcripts generated from the gene counts.

    The total number of generated transcripts over ``GENE_KEYS`` must equal the
    expected count, and each gene must end up with a single transcript ID.
    """
    csv_path, gtf_path, p_intron = test_input
    builder = Gts.BuildTranscriptStructure(csv_path, gtf_path, p_intron)
    # Double-underscore methods are name-mangled inside the class body, so
    # from module level they are only reachable under the mangled names.
    builder._BuildTranscriptStructure__csv_2_dict()  # Generates dictionary from gene count csv file.
    builder._BuildTranscriptStructure__gtf_2_dict()  # Generates dictionary from gtf file.
    builder._BuildTranscriptStructure__make_new_transcripts()  # Generates the differently spliced transcripts.

    # Expected total; presumably 5 + 5 as suggested by the fixture name Rik_5_Rp1_5.
    numb_trans_csv = 10
    # Per gene, the values of gene_transcript_dict are summed as transcript counts.
    numb_trans_dict = sum(
        sum(builder.gene_transcript_dict[gene_key].values())
        for gene_key in GENE_KEYS
    )
    assert numb_trans_dict == numb_trans_csv

    for gene_key in GENE_KEYS:
        assert len(builder.gene_transcript_dict[gene_key]) == 1  # All have identical transcript IDs.


@pytest.mark.parametrize(
        "test_input",
        [(TEST_CSV_TITLE, GENE_COORDS, P_INTRON_0),
         (TEST_CSV_TITLE, GENE_COORDS, P_INTRON_1)
         ],
)
def test_make_gtf_lines(test_input):
    """Check that every generated GTF line has start < end coordinates."""
    csv_path, gtf_path, p_intron = test_input
    builder = Gts.BuildTranscriptStructure(csv_path, gtf_path, p_intron)
    # Double-underscore methods are name-mangled inside the class body, so
    # from module level they are only reachable under the mangled names.
    builder._BuildTranscriptStructure__csv_2_dict()  # Generates dictionary from gene count csv file.
    builder._BuildTranscriptStructure__gtf_2_dict()  # Generates dictionary from gtf file.
    builder._BuildTranscriptStructure__make_new_transcripts()  # Generates the differently spliced transcripts.
    builder._BuildTranscriptStructure__make_gtf_info()

    for line in builder.gtf_lines:
        columns = line.split('\t')
        # Compare numerically: as strings, '999' < '1000' is False.
        start, end = int(columns[3]), int(columns[4])
        assert start < end, line  # Tests that the coordinates are increasing.


@pytest.mark.parametrize(
        "test_input",
        [(TEST_CSV_TITLE, GENE_COORDS, P_INTRON_0),
         (TEST_CSV_TITLE, GENE_COORDS, P_INTRON_1)
         ],
)
def test_sort_gtf_lines(test_input):
    """Check that sorting orders the ``gene`` records by increasing start."""
    csv_path, gtf_path, p_intron = test_input
    builder = Gts.BuildTranscriptStructure(csv_path, gtf_path, p_intron)
    # Double-underscore methods are name-mangled inside the class body, so
    # from module level they are only reachable under the mangled names.
    builder._BuildTranscriptStructure__csv_2_dict()  # Generates dictionary from gene count csv file.
    builder._BuildTranscriptStructure__gtf_2_dict()  # Generates dictionary from gtf file.
    builder._BuildTranscriptStructure__make_new_transcripts()  # Generates the differently spliced transcripts.
    builder._BuildTranscriptStructure__make_gtf_info()

    def gene_starts():
        """Return the start coordinates (as ints) of all ``gene`` records."""
        starts = []
        for line in builder.gtf_lines:
            columns = line.split('\t')
            if columns[2] == 'gene':
                # Numeric, not lexicographic, comparison of coordinates.
                starts.append(int(columns[3]))
        return starts

    # Precondition: gene records arrive in strictly descending start order,
    # so the function actually has to sort.
    starts_before = gene_starts()
    assert len(starts_before) > 1
    assert all(a > b for a, b in zip(starts_before, starts_before[1:]))

    builder._BuildTranscriptStructure__sort_gtf_lines()

    # Postcondition: same gene records, now in strictly increasing order.
    starts_after = gene_starts()
    assert len(starts_after) == len(starts_before)
    assert all(a < b for a, b in zip(starts_after, starts_after[1:]))


def test_write_gtf():
    """Placeholder for testing GTF output; skipped until implemented.

    Skipping (rather than an empty body) keeps the missing coverage visible in
    the pytest report instead of counting as a pass.
    """
    pytest.skip("test_write_gtf is not implemented yet.")


def test_write_csv():
    """Placeholder for testing csv output; skipped until implemented.

    Skipping (rather than an empty body) keeps the missing coverage visible in
    the pytest report instead of counting as a pass.
    """
    pytest.skip("test_write_csv is not implemented yet.")