Skip to content
Snippets Groups Projects
Commit 5eef04d8 authored by TheRiPtide's avatar TheRiPtide
Browse files

fix: mypy now passes

parent 728d941b
No related branches found
No related tags found
1 merge request!23feat: deep-leaning poly(A) classifier
Pipeline #13792 failed
......@@ -49,7 +49,7 @@ class PolyAClassifier:
'C': 1.0
}
def __init__(self, model: Module = Net, state_dict_path: str = './models/internal_priming.pth'):
def __init__(self, model = Net, state_dict_path: str = './models/internal_priming.pth'):
"""Returns a stateless classifier with the model loaded.
Args:
......@@ -108,11 +108,11 @@ class PolyAClassifier:
if test_shape[1] != 200:
raise ValueError('Sequences not of length 200')
test = torch.from_numpy(test)
tens = torch.from_numpy(test)
# make prediction
with torch.no_grad():
output = self.model(test.cpu())
output = self.model(tens.cpu())
softmax = torch.exp(output).cpu()
prob = list(softmax.numpy())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment