Skip to content
Snippets Groups Projects
polyA_classifier.py 3.72 KiB
"""Module for classifying polyA tails as internal or real."""

import torch
from torch.nn import Linear, ReLU, Sequential, MaxPool1d, Module, BatchNorm1d, Conv1d
import numpy as np
from typing import Union


class Net(Module):
    """Small 1D CNN: two conv -> batch-norm -> ReLU -> max-pool stages and a linear head.

    Each stage keeps 4 channels and each stride-2 max-pool halves the sequence
    length, so the flattened feature size 4 * 50 used by the head corresponds
    to 200-long inputs with 4 channels. The head produces 10 outputs per sample.
    """

    def __init__(self):
        """Construct the convolutional stages and the linear head."""
        super().__init__()

        # Attribute names and layer order define the state-dict keys
        # (e.g. ``cnn_layers.0.weight``); keep them stable so saved weights load.
        stage_layers = []
        for _ in range(2):
            stage_layers.extend([
                Conv1d(4, 4, kernel_size=3, stride=1, padding=1),
                BatchNorm1d(4),
                ReLU(inplace=True),
                MaxPool1d(kernel_size=2, stride=2),
            ])
        self.cnn_layers = Sequential(*stage_layers)

        head = Linear(4 * 50, 10)
        self.linear_layers = Sequential(head)

    def forward(self, x):
        """Apply the CNN stages, flatten each sample, and run the linear head."""
        features = self.cnn_layers(x)
        batch_size = features.size(0)
        flat = features.view(batch_size, -1)
        return self.linear_layers(flat)


class PolyAClassifier:
    """Classifier object using the state-dict of a pretrained pytorch model."""

    # One-hot encoding per base; T and U share a channel (DNA vs. RNA spelling).
    enum = {
        'A': [1, 0, 0, 0],
        'U': [0, 1, 0, 0],
        'T': [0, 1, 0, 0],
        'G': [0, 0, 1, 0],
        'C': [0, 0, 0, 1]
    }

    # Required input length. The default Net's Linear(4 * 50, ...) head relies
    # on it (two stride-2 max-pools: 200 -> 100 -> 50).
    _SEQ_LEN = 200

    def __init__(self, model=Net, state_dict_path: str = './models/internal_priming.pth'):
        """Build the model, load pretrained weights, and switch it to inference mode.

        Args:
            model: A zero-argument callable, typically a ``torch.nn.Module``
                subclass, that constructs a network matching the state dict.
            state_dict_path: Path to a state dict saved from a trained instance
                of ``model``.
        """
        self.model = model()
        # map_location='cpu': classify() feeds CPU tensors, and this lets weights
        # saved from a GPU run load on CPU-only machines.
        # torch.load unpickles the file, so only load state dicts from trusted sources.
        state_dict = torch.load(state_dict_path, map_location='cpu')
        self.model.load_state_dict(state_dict)
        # Without eval(), BatchNorm normalises with per-call batch statistics
        # (so a prediction depends on the rest of the batch) and updates its
        # running statistics on every call, even under torch.no_grad().
        self.model.eval()

    def classify(self, sequence: Union[str, list[str]]) -> Union[int, list[int]]:
        """Classify sequence(s) of bases as internal priming or a real polyA tail.

        Args:
            sequence: A single sequence, or a list of sequences, each exactly
                200 characters drawn from A, C, G, T, U (case-insensitive).

        Returns:
            The argmax class index per sequence, with 1 for real polyA and 0
            for internal priming. A single ``int`` is returned whenever exactly
            one sequence was classified, including a one-element list.
            Otherwise a ``list[int]`` is returned in input order, and an empty
            list yields ``[]``.

        Raises:
            TypeError: If sequence is not a str or a list.
            ValueError: If a sequence contains a character other than A, C, G, T, U.
            ValueError: If some or all sequences are not of length 200.
        """
        if isinstance(sequence, str):
            sequences = [sequence]
        elif isinstance(sequence, list):
            sequences = sequence
        else:
            raise TypeError('Type of sequence input wrong, should be str or list(str).')

        if not sequences:
            # Nothing to classify; also avoids indexing the shape of an empty array.
            return []

        try:
            enum_seqs = [[self.enum[base.upper()] for base in seq] for seq in sequences]
        except KeyError as err:
            raise ValueError('String contains non-allowed Letters: only A, T, U, G, C allowed') from err

        # Validate lengths explicitly rather than relying on numpy rejecting ragged input.
        lengths = {len(encoded) for encoded in enum_seqs}
        if len(lengths) > 1:
            raise ValueError(f'Not all sequences of length {self._SEQ_LEN}')
        if lengths != {self._SEQ_LEN}:
            raise ValueError(f'Sequences not of length {self._SEQ_LEN}')

        batch = np.array(enum_seqs, dtype=np.float32)  # shape (N, 200, 4)
        # NOTE(review): this is a reshape, not a transpose. The (N, 200, 4)
        # one-hot buffer is reinterpreted as (N, 4, 200), so each "channel" row
        # holds interleaved values from consecutive positions rather than one
        # base type. It is kept because the pretrained weights were presumably
        # trained on the same layout. Confirm against the training code before
        # switching to batch.transpose(0, 2, 1).
        batch = batch.reshape(batch.shape[0], 4, batch.shape[1])
        inputs = torch.from_numpy(batch)

        with torch.no_grad():
            logits = self.model(inputs)

        # argmax over raw logits equals argmax over exp/softmax (monotonic)
        # and cannot overflow. NOTE(review): the default Net emits 10 logits,
        # but only classes 0 and 1 are documented. Presumably it was trained
        # with two classes; confirm.
        predictions = logits.argmax(dim=1).tolist()

        if len(predictions) == 1:
            return predictions[0]
        return predictions