Skip to content
Snippets Groups Projects

feat: add function to calculate mean and variance

Closed Reto Tschannen requested to merge issue12 into main
Files
3
+ 80
0
"""Calculates mean and the variance for a given gene accross all cells."""
import csv
from glob import glob
import io
import os
import matplotlib.pyplot as plt
def mean_variance(filepath: str, output_dir: str = os.getcwd()+'/') -> str:
"""For observed gene counts calcuclate mean and var.
At the moment the function does not check the import files format,
be careful, and only add text files in the format
geneid number_of_transcipts.
Args:
directory with text files of gene expression
counts in individual cells.
Returns:
1. Path to Csv-formatted table with GeneID, Mean, Variance of the
count, and Scatterplot of mean vs variance for all genes.
Raises:
ValueError: If there are no files in directory
"""
# Open each file in the input directory, raises error if no file is found
files = [file for file in glob(filepath)]
if len(files) == 0:
raise ValueError('No files in directory:', filepath)
# Creates all required dictionaries to construct the mean and variance
gene_counts = {}
occurence = {}
mean = {}
variance = {}
# Adds together all gene counts in gene_counts, and occurences in occurence
for file_name in files:
with io.open(file_name, 'r') as fh:
for line in fh:
geneid, copies = str(line.split()[0]), int(line.split()[1])
if geneid not in gene_counts:
gene_counts[geneid] = [copies]
occurence[geneid] = 1
else:
gene_counts[geneid] += [copies]
occurence[geneid] += 1
# Calculates mean of each gene
for i in gene_counts:
mean[i] = sum(gene_counts[i])/occurence[i]
# Calculates the variance
for i in gene_counts:
for j in range(0, len(gene_counts[i])):
variance[i] = (gene_counts[i][j]-mean[i])**2/occurence[i]
# Plots mean against variance
plt.scatter(mean.values(), variance.values())
for value in list(mean.keys()):
plt.text(mean[value], variance[value], value)
plt.xlabel('mean')
plt.ylabel('variance')
plt.title('Mean gene expression vs. variance')
plt.savefig(output_dir+'/meanvarianceplot.png')
plt.show()
# Constructs csv file and saves it in the users directory
path = output_dir+'/results_mean_var_function.csv'
with open(path, 'w') as csv_file:
filewriter = csv.writer(csv_file, delimiter=',', quotechar='|',
quoting=csv.QUOTE_MINIMAL)
filewriter.writerow(['geneid', 'mean', 'variance'])
for id in gene_counts.keys():
filewriter.writerow([id, mean[id], variance[id]])
return output_dir
Loading