"""Tests the transcript structure generator module.""" import pytest from transcript_structure import Generate_transcript_structure as Gts TEST_CSV_TITLE = './tests/resources/test_transcript_structure/Rik_5_Rp1_5_title.csv' TEST_CSV_NO_TITLE = './tests/resources/test_transcript_structure/Rik_5_Rp1_5_no_title.csv' GENE_COORDS = './tests/resources/test_transcript_structure/RP1_RIK.gtf' GENE_KEYS = ['Rp1', '1700034P13Rik'] P_INTRON_0: float = 0 P_INTRON_0_2 = 0.2 P_INTRON_1: float = 1 @pytest.mark.parametrize( "test_input", [(TEST_CSV_TITLE, GENE_COORDS, P_INTRON_0), (TEST_CSV_NO_TITLE, GENE_COORDS, P_INTRON_0) ], ) def test_csv_2_dict(test_input): builder = Gts.BuildTranscriptStructure(test_input[0], test_input[1], test_input[2]) builder.__csv_2_dict() with open(TEST_CSV_TITLE) as csv: csv_lines = csv.readlines() first_line = csv_lines[0].split(',') if not first_line[1].isnumeric(): del(csv_lines[0]) # Removes title. csv_lines[-1] = ''.join([csv_lines[-1], '\n']) # Adds \n to last line of csv. keys = list(builder.gene_count_dict.keys()) for index, line in enumerate(csv_lines): dic_line = ''.join([keys[index], ',', str(builder.gene_count_dict[keys[index]]), '\n']) assert line == dic_line def test_gtf_2_dict(): builder = Gts.BuildTranscriptStructure(TEST_CSV_TITLE, GENE_COORDS, P_INTRON_0) builder.__gtf_2_dict() assert len(builder.gene_sequences_dict) == 2 # Two genes read in the dictionary. assert len(builder.gene_sequences_dict[GENE_KEYS[0]]) == 5 assert len(builder.gene_sequences_dict[GENE_KEYS[1]]) == 5 gene_line_rik = ('1\thavana\tgene\t9747648\t9791924\t.\t+\t.\tgene_id "ENSMUSG00000097893"; gene_version "8"; ' 'gene_name "1700034P13Rik"; gene_source "havana"; gene_biotype "lncRNA";\n') assert gene_line_rik == builder.gene_sequences_dict[GENE_KEYS[1]]['gene_line'] with open(GENE_COORDS) as gtf: lines = gtf.readlines() numb_exons_gtf = len(lines) - 4 # 2x exon + transcript line numb_exons_dict = 0 for gene_key in GENE_KEYS: numb_exons_dict += len(builder.gene_sequences_dict[gene_key]['exon_seq']) assert numb_exons_gtf == numb_exons_dict @pytest.mark.parametrize( "test_input", [(TEST_CSV_TITLE, GENE_COORDS, P_INTRON_0), (TEST_CSV_TITLE, GENE_COORDS, P_INTRON_1) ], ) def test_make_new_transcripts(test_input): builder = Gts.BuildTranscriptStructure(test_input[0], test_input[1], test_input[2]) builder.__csv_2_dict() # Generates dictionary from gene count csv file. builder.__gtf_2_dict() # Generates dictionary from gtf file. builder.__make_new_transcripts() # Generates the differently spliced transcripts. numb_trans_dict = 0 numb_trans_csv = 10 for gene_key in GENE_KEYS: for trans_id in builder.gene_transcript_dict[gene_key]: numb_trans_dict += builder.gene_transcript_dict[gene_key][trans_id] assert numb_trans_csv == numb_trans_csv for gene_key in GENE_KEYS: assert len(builder.gene_transcript_dict[gene_key]) == 1 # All have identical transcript IDs. @pytest.mark.parametrize( "test_input", [(TEST_CSV_TITLE, GENE_COORDS, P_INTRON_0), (TEST_CSV_TITLE, GENE_COORDS, P_INTRON_1) ], ) def test_make_gtf_lines(test_input): builder = Gts.BuildTranscriptStructure(test_input[0], test_input[1], test_input[2]) builder.__csv_2_dict() # Generates dictionary from gene count csv file. builder.__gtf_2_dict() # Generates dictionary from gtf file. builder.__make_new_transcripts() # Generates the differently spliced transcripts. builder.__make_gtf_info() for line in builder.gtf_lines: columns = line.split('\t') assert columns[3] < columns[4] # Tests that the coordinates are increasing. pass @pytest.mark.parametrize( "test_input", [(TEST_CSV_TITLE, GENE_COORDS, P_INTRON_0), (TEST_CSV_TITLE, GENE_COORDS, P_INTRON_1) ], ) def test_sort_gtf_lines(test_input): builder = Gts.BuildTranscriptStructure(test_input[0], test_input[1], test_input[2]) builder.__csv_2_dict() # Generates dictionary from gene count csv file. builder.__gtf_2_dict() # Generates dictionary from gtf file. builder.__make_new_transcripts() # Generates the differently spliced transcripts. builder.__make_gtf_info() starts_before = [] # Verifies that the function actually has to sort. for line in builder.gtf_lines: columns = line.split('\t') if columns[2] == 'gene': starts_before.append(columns[3]) for ii in range(len(starts_before)-1): assert starts_before[ii] > starts_before[ii+1] builder.__sort_gtf_lines() builder.__sort_gtf_lines() starts_after = [] # Verifies that the function sorted. for line in builder.gtf_lines: columns = line.split('\t') if columns[2] == 'gene': starts_after.append(columns[3]) for ii in range(len(starts_before)-1): assert starts_after[ii] < starts_after[ii+1] builder.__sort_gtf_lines() def test_write_gtf(): pass def test_write_csv(): pass