Skip to content
Snippets Groups Projects

feat: add function to calculate mean and variance

Closed Reto Tschannen requested to merge issue12 into main
Files
3
+ 74
0
import csv
from glob import glob
import io
import os
import matplotlib.pyplot as plt
def mean_variance(filepath: str, output_dir: str = os.getcwd()+'/') -> str:
    """Compute the per-gene mean and variance of transcript counts.

    Every file matched by ``filepath`` is read line by line. Each non-blank
    line must hold a gene id followed by an integer count, separated by
    whitespace (``geneid number_of_transcripts``). The counts observed for
    the same gene id across all files are pooled, and their mean and
    population variance (sum of squared deviations divided by the number
    of observations) are computed.

    Two files are written into ``output_dir``:

    * ``meanvarianceplot.png`` - scatter plot of mean vs. variance, with
      every point labelled by its gene id.
    * ``results_mean_var_function.csv`` - table with the columns
      ``geneid``, ``mean`` and ``variance``.

    The plot is additionally displayed via ``plt.show()``.

    The file format is not validated: a line with fewer than two fields
    raises ``IndexError`` and a non-integer count raises ``ValueError``.

    Args:
        filepath: Pattern handed to ``glob`` to select the input files,
            e.g. ``'counts/*.txt'``.
        output_dir: Directory that receives the plot and the CSV table.
            NOTE: the default is evaluated once, when the function is
            defined, so it is the working directory at import time.

    Returns:
        ``output_dir``, i.e. the directory the result files were written to.

    Raises:
        ValueError: If ``filepath`` matches no files.
    """
    files = glob(filepath)
    if not files:
        raise ValueError(f'No files in directory: {filepath}')

    # Pool the counts of every gene over all input files.
    gene_counts = {}
    for file_name in files:
        with open(file_name, 'r') as fh:
            for line in fh:
                fields = line.split()
                if not fields:
                    # Blank lines (e.g. a trailing newline) carry no data.
                    continue
                gene_counts.setdefault(fields[0], []).append(int(fields[1]))

    # Population mean and variance per gene. The squared deviations must be
    # summed over all observations before dividing by their number.
    mean = {}
    variance = {}
    for geneid, counts in gene_counts.items():
        n_obs = len(counts)
        mu = sum(counts) / n_obs
        mean[geneid] = mu
        variance[geneid] = sum((count - mu) ** 2 for count in counts) / n_obs

    # Plot mean against variance; both dicts share the same key order, so
    # the value lists line up point by point.
    plt.scatter(list(mean.values()), list(variance.values()))
    for geneid in mean:
        plt.text(mean[geneid], variance[geneid], geneid)
    plt.xlabel('mean')
    plt.ylabel('variance')
    plt.title('Mean gene expression vs. variance')
    plt.savefig(os.path.join(output_dir, 'meanvarianceplot.png'))
    plt.show()

    # Write the results table. newline='' lets the csv module control line
    # endings, which avoids blank rows on platforms that translate '\n'.
    path = os.path.join(output_dir, 'results_mean_var_function.csv')
    with open(path, 'w', newline='') as csv_file:
        filewriter = csv.writer(csv_file, delimiter=',', quotechar='|',
                                quoting=csv.QUOTE_MINIMAL)
        filewriter.writerow(['geneid', 'mean', 'variance'])
        for geneid in gene_counts:
            filewriter.writerow([geneid, mean[geneid], variance[geneid]])
    return output_dir
Loading