diff --git a/tsg/main.py b/tsg/main.py
index 27970813973d171aa0c0bd903886353642a957e8..25b7455ff21aac6d344994b20093f06a9c509a0f 100644
--- a/tsg/main.py
+++ b/tsg/main.py
@@ -1,5 +1,6 @@
 import logging
 
+import numpy as np
 import pandas as pd
 
 LOG = logging.getLogger(__name__)
@@ -162,6 +163,115 @@ class Gtf:
     def pick_transcript(self, transcript_id: str) -> pd.DataFrame:
         return self.df.query(f"transcript_id == '{transcript_id}'")
 
+
+class TranscriptGenerator:
+    def __init__(
+        self,
+        transcript_id: str,
+        transcript_count: int,
+        transcript_df: pd.DataFrame,
+        prob_inclusion: float,
+    ):
+        assert len(transcript_df) > 0
+        assert transcript_count > 0
+        assert 0 <= prob_inclusion <= 1
+
+        self.id = transcript_id
+        self.count = transcript_count
+        self.df = transcript_df
+        self.no_exons = len(transcript_df)
+        self.strand = self.df["strand"].unique().item()
+        self.prob_inclusion = prob_inclusion
+
+    def _get_inclusions(self) -> np.ndarray:
+        """Generate a boolean inclusions array with one row per exon and one column per sampled transcript.
+
+        Returns:
+            np.ndarray: inclusions array; True means the downstream intron is retained
+        """
+        inclusion_arr = np.random.rand(self.no_exons, self.count) < self.prob_inclusion
+        if self.strand == "+":
+            inclusion_arr[-1, :] = False
+        elif self.strand == "-":
+            inclusion_arr[-1, :] = False
+
+        return inclusion_arr
+
+    def _get_unique_inclusions(self) -> tuple:
+        inclusion_arr = self._get_inclusions()
+        # Unique intron inclusion patterns and how often each was sampled
+        inclusion_arr_unique, counts = np.unique(
+            inclusion_arr, axis=1, return_counts=True
+        )
+        # Name for each generated transcript
+        names = []
+        for i in range(inclusion_arr_unique.shape[1]):
+            if not np.any(inclusion_arr_unique[:, i]):
+                names.append(self.id)
+            else:
+                names.append(f"{self.id}_{i}")
+
+        return names, inclusion_arr_unique, counts
+
+    def _get_df(self, inclusions: np.ndarray, transcript_id: str) -> pd.DataFrame:
+        """Build the exon annotation for one generated transcript from a boolean vector of intron inclusions.
+
+        Args:
+            inclusions (np.ndarray): boolean vector denoting intron inclusion
+            transcript_id (str): transcript id
+
+        Returns:
+            pd.DataFrame: derived exon annotation
+        """
+        df_generated = self.df.copy()
+        if self.strand == "+":
+            original_end = df_generated["end"]
+            df_generated["end"] = np.where(
+                inclusions, df_generated["start"].shift(periods=-1, fill_value=-1) - 1, original_end
+            )
+        if self.strand == "-":
+            original_start = df_generated["start"]
+            df_generated["start"] = np.where(
+                inclusions, df_generated["end"].shift(periods=-1, fill_value=-1) + 1, original_start
+            )
+
+        original_id = df_generated["exon_id"]
+        df_generated["exon_id"] = np.where(
+            inclusions,
+            df_generated["exon_id"] + "_" + np.arange(len(df_generated)).astype(str),
+            original_id,
+        )
+
+        df_generated["transcript_id"] = transcript_id
+        return df_generated
+
+    def generate_transcripts(self, filename: str) -> None:
+        """Append sampled transcript IDs and their counts to a CSV file.
+
+        Args:
+            filename (str): output CSV file name
+        """
+        ids, inclusions, counts = self._get_unique_inclusions()
+        with open(filename, "a") as fh:
+            for transcript_id, transcript_count in zip(ids, counts):
+                fh.write(f"{transcript_id},{transcript_count}\n")
+
+    def generate_annotations(self, filename: str) -> None:
+        ids, inclusions, counts = self._get_unique_inclusions()
+        n_unique = len(ids)
+
+        try:
+            df = pd.concat(
+                [self._get_df(inclusions[:, i], ids[i]) for i in range(n_unique)]
+            )
+            df = reverse_parse_free_text(df)
+
+            write_gtf(df, filename)
+            LOG.info(f"Transcript {self.id} sampled")
+        except ValueError:
+            LOG.error(f"Transcript {self.id} could not be sampled.")
+
+
 def sample_transcripts(
     input_transcripts_file: str,
     input_annotations_file: str,
@@ -174,3 +284,20 @@ def sample_transcripts(
     annotations = Gtf()
     annotations.read_file(input_annotations_file)
     annotations.parse_free_text()
+
+    # Set up the output file: write the header once, then append records in the loop
+    write_header(output_annotations_file)
+
+    for _, row in transcripts.iterrows():
+        transcript_id = row["id"]
+        transcript_count = row["count"]
+
+        transcript_df = annotations.pick_transcript(transcript_id)
+        transcript_generator = TranscriptGenerator(
+            transcript_id,
+            transcript_count,
+            transcript_df,
+            prob_inclusion=prob_inclusion,
+        )
+        transcript_generator.generate_annotations(output_annotations_file)
+        transcript_generator.generate_transcripts(output_transcripts_file)
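Illustrative usage (not part of the patch): a minimal sketch of how the new sample_transcripts entry point could be driven, assuming its remaining parameters match the names used in its body (output_transcripts_file, output_annotations_file, prob_inclusion); all file paths and the inclusion probability below are placeholders.

    # hypothetical driver script; paths and values are made up for illustration
    import logging

    from tsg.main import sample_transcripts

    # surface the per-transcript LOG.info / LOG.error messages
    logging.basicConfig(level=logging.INFO)

    sample_transcripts(
        input_transcripts_file="transcripts.csv",        # expects columns "id" and "count"
        input_annotations_file="annotation.gtf",          # reference exon annotation
        prob_inclusion=0.05,                              # per-intron inclusion probability
        output_transcripts_file="sampled_transcripts.csv",
        output_annotations_file="sampled_annotation.gtf",
    )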