Skip to content
Snippets Groups Projects

refactor: update main and tests for CI workflow

Merged Mate Balajti requested to merge andri.fraenkl-main-patch-95359 into main
7 files
+ 194
164
Compare changes
  • Side-by-side
  • Inline
Files
7
+ 27
22
"""Tests for main module"""
"""Tests for main module."""
import numpy as np
import pandas as pd
import pandas as pd
import pytest
from tsg.main import Gtf, TranscriptGenerator, dict_to_str, str_to_dict
from tsg.main import Gtf, TranscriptGenerator, dict_to_str, str_to_dict
@@ -10,8 +8,10 @@ class TestFreeTextParsing:
@@ -10,8 +8,10 @@ class TestFreeTextParsing:
"""Test if free text dictionary is correctly parsed."""
"""Test if free text dictionary is correctly parsed."""
def test_str2dict(self):
def test_str2dict(self):
 
"""Test for str2dict function."""
res = str_to_dict(
res = str_to_dict(
'gene_id "GENE2"; transcript_id "TRANSCRIPT2"; exon_number "1"; exon_id "EXON1";'
'gene_id "GENE2"; transcript_id "TRANSCRIPT2"; \
 
exon_number "1"; exon_id "EXON1";'
)
)
assert res == {
assert res == {
@@ -22,6 +22,7 @@ class TestFreeTextParsing:
@@ -22,6 +22,7 @@ class TestFreeTextParsing:
}
}
def test_dict2str(self):
def test_dict2str(self):
 
"""Test for dict2str function."""
res = dict_to_str(
res = dict_to_str(
{
{
"gene_id": "GENE2",
"gene_id": "GENE2",
@@ -31,14 +32,17 @@ class TestFreeTextParsing:
@@ -31,14 +32,17 @@ class TestFreeTextParsing:
}
}
)
)
print(res)
print(res)
assert (
assert res == (
res
'gene_id "GENE2"; '
== 'gene_id "GENE2"; transcript_id "TRANSCRIPT2"; exon_number "1"; exon_id "EXON1";'
'transcript_id "TRANSCRIPT2"; '
 
'exon_number "1"; '
 
'exon_id "EXON1";'
)
)
class TestGtf:
class TestGtf:
"Test if Gtf class works correctly."
"""Test if Gtf class works correctly."""
 
cols = [
cols = [
"seqname",
"seqname",
"source",
"source",
@@ -52,19 +56,21 @@ class TestGtf:
@@ -52,19 +56,21 @@ class TestGtf:
]
]
def test_init(self):
def test_init(self):
 
"""Test for init function."""
annotations = Gtf()
annotations = Gtf()
annotations.read_file("tests/resources/Annotation1.gtf")
annotations.read_file("tests/resources/Annotation1.gtf")
assert annotations.parsed == False
assert annotations.parsed is False
assert annotations.original_columns == self.cols
assert annotations.original_columns == self.cols
assert annotations.free_text_columns == []
assert annotations.free_text_columns == []
def test_parsed(self):
def test_parsed(self):
 
"""Test for parsed function."""
annotations = Gtf()
annotations = Gtf()
annotations.read_file("tests/resources/Annotation1.gtf")
annotations.read_file("tests/resources/Annotation1.gtf")
annotations.parse_key_value()
annotations.parse_key_value()
assert annotations.parsed == True
assert annotations.parsed is True
assert set(annotations.free_text_columns) == set(
assert set(annotations.free_text_columns) == set(
[
[
"gene_id",
"gene_id",
@@ -75,11 +81,14 @@ class TestGtf:
@@ -75,11 +81,14 @@ class TestGtf:
]
]
)
)
assert set(annotations.original_columns) == set(
assert set(annotations.original_columns) == set(
["seqname", "source", "feature", "start", "end", "score", "strand", "frame"]
["seqname", "source", "feature", "start",
 
"end", "score", "strand", "frame"]
)
)
class TestTranscriptGenerator:
class TestTranscriptGenerator:
 
"""Test for TranscriptGenerator class."""
 
cols = [
cols = [
"start",
"start",
"end",
"end",
@@ -98,35 +107,31 @@ class TestTranscriptGenerator:
@@ -98,35 +107,31 @@ class TestTranscriptGenerator:
df2 = pd.DataFrame(columns=["start", "end", "strand"])
df2 = pd.DataFrame(columns=["start", "end", "strand"])
def test_init(self):
def test_init(self):
 
"""Test for init."""
transcripts = TranscriptGenerator("TRANSCRIPT1", 3, self.df1, 0.05)
transcripts = TranscriptGenerator("TRANSCRIPT1", 3, self.df1, 0.05)
assert transcripts.strand == "+"
assert transcripts.strand == "+"
def test_init_2(self):
with pytest.raises(AssertionError):
transcripts = TranscriptGenerator("TRANSCRIPT2", 3, self.df2, 0.05)
def test_init_3(self):
with pytest.raises(AssertionError):
transcripts = TranscriptGenerator("TRANSCRIPT1", 0, self.df1, 0.05)
def test_inclusions(self):
def test_inclusions(self):
 
"""Test for inclusions."""
transcripts = TranscriptGenerator("TRANSCRIPT1", 3, self.df1, 0.5)
transcripts = TranscriptGenerator("TRANSCRIPT1", 3, self.df1, 0.5)
res = transcripts._get_inclusions()
res = transcripts.get_inclusions()
assert res.shape == (3, 3)
assert res.shape == (3, 3)
def test_unique_inclusions(self):
def test_unique_inclusions(self):
 
"""Test for unique inclusions."""
transcripts = TranscriptGenerator("TRANSCRIPT1", 3, self.df1, 0.5)
transcripts = TranscriptGenerator("TRANSCRIPT1", 3, self.df1, 0.5)
res1, res2, res3 = transcripts._get_unique_inclusions()
transcripts.get_unique_inclusions()
def test_get_df(self):
def test_get_df(self):
 
"""Test for get_df function."""
inclusions = [False, True, False]
inclusions = [False, True, False]
expected_end = pd.Series([20, 79, 100], name="end")
expected_end = pd.Series([20, 79, 100], name="end")
transcript_id = "TRANSCRIPT1_1"
transcript_id = "TRANSCRIPT1_1"
transcripts = TranscriptGenerator("TRANSCRIPT1", 3, self.df1, 0.5)
transcripts = TranscriptGenerator("TRANSCRIPT1", 3, self.df1, 0.5)
res = transcripts._get_df(inclusions, transcript_id)
res = transcripts.get_df(inclusions, transcript_id)
assert res["transcript_id"].unique().item() == "TRANSCRIPT1_1"
assert res["transcript_id"].unique().item() == "TRANSCRIPT1_1"
assert res["strand"].unique().item() == "+"
assert res["strand"].unique().item() == "+"
Loading