Skip to content
Snippets Groups Projects

fix: add cli to sample_from_input

Closed Michele Garioni requested to merge feature into main
4 files
+ 190
4
Compare changes
  • Side-by-side
  • Inline
Files
4
src/gene_count.py 0 → 100644
+ 124
0
 
"""Module containing function to count reads per gene.
 
 
Function:
 
count_reads: Take as input an aligned BAM file and an
 
annotated GTF file, and counts the reads belonging to
 
each gene.
 
"""
 
import logging
 
import pysam
 
 
from pathlib import Path
 
 
 
LOG = logging.getLogger(__name__)
 
 
 
def CountReads(bam_file: Path, gtf_file: Path, output_file: Path) -> None:
 
"""Reads counts per gene.
 
 
Args:
 
bam_file (path): Path to aligned BAM file.
 
gtf_file (path): Path to annotated GTF file.
 
output_file (path): Path to output csv file.
 
"""
 
# Store GTF annotations in a dictionary. '{chr:[(start,end,gene)]}'
 
with open(gtf_file) as f1:
 
LOG.info("Reading annotations...")
 
gtf_dc = {}
 
# Read gtf skipping header and non-gene lines
 
for line in f1:
 
if line.startswith('#!'):
 
continue
 
fields = line.split('\t')
 
if fields[2] != 'gene':
 
continue
 
# extract gene id
 
gene_id = fields[8].split(';')
 
gene_id = gene_id[0]
 
quote_indx = [i for i, c in enumerate(gene_id) if c == '"']
 
gene_id = gene_id[quote_indx[0]+1:quote_indx[1]]
 
# build gtf dictionary
 
if fields[0] not in gtf_dc:
 
gtf_dc[fields[0]] = [(fields[3], fields[4], gene_id)]
 
else:
 
gtf_dc[fields[0]].append((fields[3], fields[4], gene_id))
 
LOG.info("Annotations read.")
 
 
# Initialize empty gene count dictionary {gene:count}
 
count_dc = {}
 
# Initialize empty 'history' dictionary
 
# to detect multiple alignments {id_read:[{gene:count},n_occurence]}
 
history = {}
 
# Read trough BAM file and add counts to genes
 
# we assume that the BAM is formatted as STAR output
 
bam_seq = pysam.AlignmentFile(bam_file, 'rb', ignore_truncation=True, require_index=False)
 
LOG.info("Reading alignments...")
 
for line in bam_seq.fetch(until_eof=True):
 
line_fields = str(line).split('\t')
 
read_start = int(line_fields[3])
 
read_end = read_start+len(line_fields[9])
 
read_range = (read_start, read_end)
 
# extract the data from gtf's matching chromosome
 
genes = gtf_dc[line_fields[2]]
 
# initialize empty list to account for same read overlapping two genes
 
# it should not happen in mRNAseq... discard the read?
 
g = []
 
n = []
 
# run through the gene intervals, find the overlaps with the reads
 
# and memorize the found genes
 
flag = 0
 
for i in genes:
 
gene_range = (int(i[0]), int(i[1]))
 
if read_range[1] < gene_range[0] or read_range[0] > gene_range[1]:
 
continue
 
flag = 1
 
g.append(i[2])
 
# assign to the gene the value of the overlap between gene and read
 
overlap = range(max(read_range[0], gene_range[0]), min(read_range[1], gene_range[1]))
 
n.append(len(overlap)/len(line_fields[9]))
 
# normalize per n genes
 
n = [x/len(g) for x in n]
 
# update the count dictionary
 
if flag == 0:
 
continue
 
for g_iter in range(len(g)):
 
if g[g_iter] not in count_dc:
 
count_dc[g[g_iter]] = n[g_iter]
 
else:
 
count_dc[g[g_iter]] += n[g_iter]
 
# if is the first time we have this read, add it to the history
 
if line_fields[0] not in history:
 
history[line_fields[0]] = [dict(zip(g, n)), 1]
 
# else, update history and correct the count dictionary
 
else:
 
genes_matched = history[line_fields[0]]
 
old_occurrence = genes_matched[1]
 
# update occurrence
 
new_occurence = old_occurrence + 1
 
genes_matched[1] = new_occurence
 
# update mapping history of this read
 
for g_iter in range(len(g)):
 
if g[g_iter] not in genes_matched[0]:
 
genes_matched[0][g[g_iter]] = n[g_iter]
 
else:
 
genes_matched[0][g[g_iter]] += n[g_iter]
 
# correct count dictionary by mapping history
 
for (gene_m, score) in genes_matched[0].items():
 
# subtract last added score from this read
 
count_dc[gene_m] = count_dc[gene_m] - (score / old_occurrence)
 
# divide the original score by the new nuber of alignment
 
updt_count = score / new_occurence
 
# re-add the updated score to the count dictionary
 
count_dc[gene_m] = count_dc[gene_m] + updt_count
 
bam_seq.close()
 
LOG.info("Alignments read.")
 
 
# write dictionary
 
myfile = open(output_file, 'w')
 
LOG.info('writing output...')
 
for (k, v) in count_dc.items():
 
line = ','.join([str(k), str(v)])
 
myfile.write(line + '\n')
 
myfile.close()
 
LOG.info('output written.')
Loading