Skip to content
Snippets Groups Projects
meanvariancefunction.py 2.42 KiB
import csv
from glob import glob
import io
import os

import matplotlib.pyplot as plt


def _read_gene_counts(files):
    """Return a dict mapping each gene ID to the list of counts seen for it.

    Every non-blank line of each file is split on whitespace; the first field
    is taken as the gene ID and the second is parsed with ``int``. A line with
    fewer than two fields raises IndexError; a non-integer count raises
    ValueError.
    """
    gene_counts = {}
    for file_name in files:
        with open(file_name, 'r') as fh:
            for line in fh:
                fields = line.split()
                if not fields:
                    # Blank lines (e.g. a trailing empty line) carry no data.
                    continue
                geneid, copies = fields[0], int(fields[1])
                gene_counts.setdefault(geneid, []).append(copies)
    return gene_counts


def _mean_and_variance(gene_counts):
    """Return ``(mean, variance)`` dicts keyed by gene ID.

    The variance is the population variance: the average squared deviation
    from the mean, dividing by the number of observations N (not N - 1).
    """
    mean = {}
    variance = {}
    for geneid, counts in gene_counts.items():
        n = len(counts)
        mu = sum(counts) / n
        mean[geneid] = mu
        # Sum the squared deviations over every observation, then average.
        variance[geneid] = sum((c - mu) ** 2 for c in counts) / n
    return mean, variance


def _write_results_csv(path, mean, variance):
    """Write a ``geneid,mean,variance`` CSV table to ``path``."""
    # newline='' lets the csv module control line endings itself; without it
    # Windows output gets an empty row after every record.
    with open(path, 'w', newline='') as csv_file:
        filewriter = csv.writer(csv_file, delimiter=',', quotechar='|',
                                quoting=csv.QUOTE_MINIMAL)
        filewriter.writerow(['geneid', 'mean', 'variance'])
        for geneid in mean:
            filewriter.writerow([geneid, mean[geneid], variance[geneid]])


def _plot_mean_variance(mean, variance):
    """Display a scatterplot of mean vs. variance, labelling each point."""
    genes = list(mean)
    xs = [mean[g] for g in genes]
    ys = [variance[g] for g in genes]
    # Start a fresh figure so repeated calls do not overlay earlier points.
    plt.figure()
    plt.scatter(xs, ys)
    for gene, x, y in zip(genes, xs, ys):
        plt.text(x, y, gene)
    plt.xlabel('mean')
    plt.ylabel('variance')
    plt.title('Mean gene expression vs. variance')
    plt.show()


def mean_variance(filepath: str, *, output_path: "str | None" = None,
                  show_plot: bool = True) -> str:
    """Compute the per-gene mean and variance of counts across cell files.

    Each input file holds whitespace-separated ``<gene_id> <integer count>``
    lines; blank lines are skipped. For every gene, the counts from all files
    in which it appears are pooled, and the mean and population variance
    (dividing by N) are computed. A gene absent from a file contributes no
    observation for that file; it is not counted as zero.

    Args:
        filepath: A directory whose files are all read, or a glob pattern
            matching the count files (e.g. ``"counts/*.txt"``).
        output_path: Where to write the CSV table. Defaults to
            ``results_mean_var_function.csv`` in the user's home directory.
        show_plot: If True (the default), display a scatterplot of mean vs.
            variance via ``matplotlib.pyplot.show``.

    Returns:
        Path of the written CSV file, with columns ``geneid, mean, variance``.
        Rows are in first-seen order, with files processed in sorted order.

    Raises:
        ValueError: If no files match ``filepath``, or a count is not an
            integer.
        IndexError: If a non-blank line has fewer than two fields.
    """
    # A directory means "every file inside it"; anything else is a glob
    # pattern. Sorting makes the output row order deterministic.
    if os.path.isdir(filepath):
        pattern = os.path.join(filepath, '*')
    else:
        pattern = filepath
    files = sorted(f for f in glob(pattern) if os.path.isfile(f))
    if not files:
        raise ValueError(f'No files in directory: {filepath}')

    gene_counts = _read_gene_counts(files)
    mean, variance = _mean_and_variance(gene_counts)

    if output_path is None:
        output_path = os.path.join(os.path.expanduser('~'),
                                   'results_mean_var_function.csv')
    # Save the table before plotting: plt.show() may block until the window
    # is closed, and a plotting failure should not lose the computed results.
    _write_results_csv(output_path, mean, variance)
    if show_plot:
        _plot_mean_variance(mean, variance)
    return output_path