Skip to content
Snippets Groups Projects

feat: add function to calculate mean and variance

Closed Reto Tschannen requested to merge issue12 into main
Files
3
+ 77
0
 
"""Calculates mean and the variance for a given gene accross all cells."""
 
 
import csv
 
from glob import glob
 
import io
 
import os
 
 
import matplotlib.pyplot as plt
 
 
 
def mean_variance(filepath: str, output_dir: str = os.getcwd()+'/') -> str:
 
"""Given the counts observed for a given gene across all cells,
 
calculate the mean and variance.
 
 
! At the moment the function does not check the import files format,
 
be careful, and only add text files in the format geneid number_of_transcipts !
 
 
Args:
 
directory with text files of gene expression counts in individual cells.
 
Returns:
 
1. Path to Csv-formatted table with GeneID, Mean, Variance of the count, and
 
Scatterplot of mean vs variance for all genes.
 
 
Raises:
 
ValueError: If there are no files in directory
 
 
"""
 
# Open each file in the input directory, raises error if no file is found
 
files = [file for file in glob(filepath)]
 
 
if len(files) == 0:
 
raise ValueError('No files in directory:', filepath)
 
 
# Creates all required dictionaries to construct the mean and variance
 
gene_counts = {}
 
occurence = {}
 
mean = {}
 
variance = {}
 
 
# Adds together all gene counts in gene_counts, and occurences in occurence
 
for file_name in files:
 
with io.open(file_name, 'r') as fh:
 
for line in fh:
 
geneid, copies = str(line.split()[0]), int(line.split()[1])
 
if geneid not in gene_counts:
 
gene_counts[geneid] = [copies]
 
occurence[geneid] = 1
 
else:
 
gene_counts[geneid] += [copies]
 
occurence[geneid] += 1
 
# Calculates mean of each gene
 
for i in gene_counts:
 
mean[i] = sum(gene_counts[i])/occurence[i]
 
 
# Calculates the variance
 
for i in gene_counts:
 
for j in range(0, len(gene_counts[i])):
 
variance[i] = (gene_counts[i][j]-mean[i])**2/occurence[i]
 
 
# Plots mean against variance
 
plt.scatter(mean.values(), variance.values())
 
for value in list(mean.keys()):
 
plt.text(mean[value], variance[value], value)
 
plt.xlabel('mean')
 
plt.ylabel('variance')
 
plt.title('Mean gene expression vs. variance')
 
plt.savefig(output_dir+'/meanvarianceplot.png')
 
plt.show()
 
# Constructs csv file and saves it in the users directory
 
path = output_dir+'/results_mean_var_function.csv'
 
with open(path, 'w') as csv_file:
 
filewriter = csv.writer(csv_file, delimiter=',', quotechar='|',
 
quoting=csv.QUOTE_MINIMAL)
 
filewriter.writerow(['geneid', 'mean', 'variance'])
 
for id in gene_counts.keys():
 
filewriter.writerow([id, mean[id], variance[id]])
 
return output_dir
Loading