# meanvariancefunction.py
"""Calculate the mean and variance of each gene's counts across all cells."""

import csv
from glob import glob
import io
import os

import matplotlib.pyplot as plt


def _read_gene_counts(files):
    """Collect every observed count per gene from the given count files.

    Each non-blank line is split on whitespace; the first field is used as
    the gene id and the second is converted with ``int``. Blank lines are
    skipped. Any other malformed line is not handled and raises
    ``IndexError`` or ``ValueError``.

    Args:
        files: iterable of paths to the count files.

    Returns:
        Dict mapping gene id to the list of counts seen for that gene, in
        first-seen order.
    """
    gene_counts = {}
    for file_name in files:
        with open(file_name, 'r') as fh:
            for line in fh:
                fields = line.split()
                if not fields:
                    continue
                gene_counts.setdefault(fields[0], []).append(int(fields[1]))
    return gene_counts


def _mean_and_variance(counts):
    """Return ``(mean, population variance)`` of a non-empty list of counts."""
    n = len(counts)
    mean = sum(counts) / n
    # Sum the squared deviations over *all* observations before dividing;
    # this is the population variance (divisor n, not n - 1).
    variance = sum((count - mean) ** 2 for count in counts) / n
    return mean, variance


def mean_variance(filepath: str, output_dir: str = os.getcwd()+'/') -> str:
    """Compute per-gene mean and variance of counts and save table and plot.

    The input files are not validated: every non-blank line must have the
    form ``geneid number_of_transcripts`` (whitespace separated).

    Args:
        filepath: glob pattern matching the count files, or a directory, in
            which case every regular file directly inside it is read.
        output_dir: directory that receives ``meanvarianceplot.png`` and
            ``results_mean_var_function.csv``. NOTE: the default is the
            working directory at the time this module was imported, because
            default values are evaluated once at definition time.

    Returns:
        ``output_dir``, unchanged, i.e. the directory holding the csv table
        (columns geneid, mean, variance) and the scatterplot of mean
        against variance.

    Raises:
        ValueError: If no files match ``filepath``.

    """
    # Accept a plain directory as well as a glob pattern, and ignore matches
    # that are not regular files (they cannot be opened for reading).
    if os.path.isdir(filepath):
        pattern = os.path.join(filepath, '*')
    else:
        pattern = filepath
    # Sorted so the row order of the csv does not depend on glob's ordering.
    files = sorted(name for name in glob(pattern) if os.path.isfile(name))

    if not files:
        raise ValueError('No files in directory: {}'.format(filepath))

    gene_counts = _read_gene_counts(files)

    mean = {}
    variance = {}
    for geneid, counts in gene_counts.items():
        mean[geneid], variance[geneid] = _mean_and_variance(counts)

    # Plots mean against variance, labelling each point with its gene id.
    # Lists are passed because dict views are not sequences.
    plt.scatter(list(mean.values()), list(variance.values()))
    for geneid in mean:
        plt.text(mean[geneid], variance[geneid], geneid)
    plt.xlabel('mean')
    plt.ylabel('variance')
    plt.title('Mean gene expression vs. variance')
    plt.savefig(os.path.join(output_dir, 'meanvarianceplot.png'))
    plt.show()

    # Constructs the csv file; newline='' lets the csv module control line
    # endings, as its documentation requires.
    path = os.path.join(output_dir, 'results_mean_var_function.csv')
    with open(path, 'w', newline='') as csv_file:
        filewriter = csv.writer(csv_file, delimiter=',', quotechar='|',
                                quoting=csv.QUOTE_MINIMAL)
        filewriter.writerow(['geneid', 'mean', 'variance'])
        for geneid in gene_counts:
            filewriter.writerow([geneid, mean[geneid], variance[geneid]])
    return output_dir