From 0510212c576c47c09111ae4076eb2dc261ccd0ec Mon Sep 17 00:00:00 2001
From: garion0000 <michele.garioni@unibas.ch>
Date: Fri, 17 Dec 2021 18:26:11 +0100
Subject: [PATCH] feat: add gene_count module

---
 requirements.txt  |   1 +
 src/gene_count.py | 124 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 125 insertions(+)
 create mode 100644 src/gene_count.py

diff --git a/requirements.txt b/requirements.txt
index e69de29..e317a24 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -0,0 +1 @@
+pysam==0.18.0
diff --git a/src/gene_count.py b/src/gene_count.py
new file mode 100644
index 0000000..572385f
--- /dev/null
+++ b/src/gene_count.py
@@ -0,0 +1,124 @@
+"""Module containing function to count reads per gene.
+
+Function:
+    CountReads: Take as input an aligned BAM file and an
+    annotated GTF file, and count the reads belonging to
+    each gene.
+"""
+import logging
+import pysam
+
+from pathlib import Path
+
+
+LOG = logging.getLogger(__name__)
+
+
+def CountReads(bam_file: Path, gtf_file: Path, output_file: Path) -> None:
+    """Count reads per gene and write the counts to a csv file.
+
+    Each alignment gives every gene it overlaps the fraction of the read
+    covered by that gene, divided by the number of genes it overlaps.
+    A read name found in several gene-overlapping alignments has its
+    accumulated scores divided by the number of those alignments.
+
+    Note:
+        Alignments are grouped by read name only, so records sharing a
+        name are weighted as multiple alignments of one read.
+
+    Args:
+        bam_file (path): Path to aligned BAM file.
+        gtf_file (path): Path to annotated GTF file.
+        output_file (path): Path to output csv file; one 'gene,count'
+            row per gene overlapped by at least one read.
+
+    Raises:
+        OSError: If the GTF file cannot be read or the output file cannot
+            be written.
+        IndexError: If a GTF feature line has fewer than nine columns or
+            its first attribute holds no quoted value.
+        ValueError: If a gene start or end coordinate is not an integer.
+    """
+    # Store GTF annotations in a dictionary. '{chr:[(start,end,gene)]}'
+    gtf_dc = {}
+    with open(gtf_file) as f1:
+        LOG.info("Reading annotations...")
+        for line in f1:
+            # Skip all comment/header lines ('#!' and '##' alike) and blank
+            # lines: neither has a feature column to look at.
+            if line.startswith('#') or not line.strip():
+                continue
+            fields = line.split('\t')
+            # Only 'gene' features (third GTF column) are annotated.
+            if fields[2] != 'gene':
+                continue
+            # The gene id is the quoted value of the first attribute.
+            gene_id = fields[8].split(';')[0].split('"')[1]
+            # Convert the coordinates once here, not once per read.
+            gtf_dc.setdefault(fields[0], []).append(
+                (int(fields[3]), int(fields[4]), gene_id))
+    LOG.info("Annotations read.")
+
+    # Gene count dictionary {gene:count}. Genes are registered when first
+    # hit, so the output rows follow the order of the alignments.
+    count_dc = {}
+    # Mapping history of every read name, needed to down-weight reads with
+    # multiple alignments: {id_read:[{gene:score}, n_occurrence]}
+    history = {}
+    # Read through the BAM file; 'with' closes it even when a record fails.
+    # We assume that the BAM is formatted as STAR output.
+    with pysam.AlignmentFile(bam_file, 'rb', ignore_truncation=True,
+                             require_index=False) as bam_seq:
+        LOG.info("Reading alignments...")
+        for read in bam_seq.fetch(until_eof=True):
+            # Use the segment attributes rather than parsing str(read):
+            # that string is only an approximate SAM line and can show the
+            # numeric reference id where the chromosome name is expected.
+            read_len = read.query_length
+            # Unmapped or sequence-less records cannot be placed on a gene.
+            if read.is_unmapped or not read_len:
+                continue
+            # reference_start is 0-based while GTF coordinates are 1-based.
+            read_start = read.reference_start + 1
+            # NOTE(review): as before, the end is start + read length, so
+            # the CIGAR (e.g. spliced gaps) is ignored -- confirm intended.
+            read_end = read_start + read_len
+            # Fraction of the read covered by each overlapped gene. A
+            # chromosome missing from the GTF simply has no genes.
+            scores = {}
+            for g_start, g_end, gene in gtf_dc.get(read.reference_name, ()):
+                # The read spans [read_start, read_end) and the gene spans
+                # [g_start, g_end] inclusive, hence the '+ 1'.
+                overlap = min(read_end, g_end + 1) - max(read_start, g_start)
+                if overlap > 0:
+                    scores[gene] = scores.get(gene, 0.0) + overlap / read_len
+            # A read overlapping no gene is not counted at all.
+            if not scores:
+                continue
+            # Normalize once per alignment by the number of genes hit;
+            # doing it inside the gene loop would divide the earlier
+            # genes again at each further match.
+            n_genes = len(scores)
+            entry = history.setdefault(read.query_name, [{}, 0])
+            # Only alignments that overlap a gene count as occurrences.
+            entry[1] += 1
+            for gene, score in scores.items():
+                entry[0][gene] = entry[0].get(gene, 0.0) + score / n_genes
+                # Reserve the gene's output row; totals are added below.
+                count_dc.setdefault(gene, 0.0)
+    LOG.info("Alignments read.")
+
+    # Add each read's accumulated scores divided by its alignment count.
+    # Summing after the pass replaces the add/subtract corrections, which
+    # over-counted reads that had three or more alignments.
+    for read_scores, n_occurrence in history.values():
+        for gene, score in read_scores.items():
+            count_dc[gene] += score / n_occurrence
+
+    # Write one 'gene,count' row per gene; 'with' closes the file even if
+    # a write fails.
+    with open(output_file, 'w') as myfile:
+        LOG.info('writing output...')
+        for (k, v) in count_dc.items():
+            myfile.write(','.join([str(k), str(v)]) + '\n')
+    LOG.info('output written.')
-- 
GitLab