From 1628ccb38cf8c92650b3cf4a638a95f34b3d2636 Mon Sep 17 00:00:00 2001
From: Lorenzo Pantolini <lorenzo.pantolini@unibas.ch>
Date: Thu, 8 Aug 2024 14:21:25 +0200
Subject: [PATCH] fix cosine_sim_matrix input

---
 eba/alignments.py     | 4 ++--
 eba/score_matrices.py | 5 +++--
 eba_example.py        | 2 +-
 3 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/eba/alignments.py b/eba/alignments.py
index beb6676..1f7ba41 100644
--- a/eba/alignments.py
+++ b/eba/alignments.py
@@ -3,7 +3,7 @@ import numba as nb
 
 MIN_FLOAT64 = np.finfo(np.float64).min
 
-@nb.njit(cache=False)
+@nb.njit
 def _make_dtw_matrix(
     score_matrix: np.ndarray,
     gap_open_penalty: float = 0.0,
@@ -78,7 +78,7 @@ def _make_dtw_matrix(
     return matrix, backtrack
 
 
-@nb.njit(cache=False)
+@nb.njit
 def _get_dtw_alignment(start_direction, backtrack: np.ndarray, n1, m1):
     """
     Finds optimal warping path from a backtrack matrix
diff --git a/eba/score_matrices.py b/eba/score_matrices.py
index 6534eec..6cacfdb 100644
--- a/eba/score_matrices.py
+++ b/eba/score_matrices.py
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from scipy import spatial
 
 def compute_similarity_matrix(embedding1, embedding2, l=1, p=2):
@@ -48,7 +49,7 @@ def compute_similarity_matrix_plain(embedding1, embedding2, l=1, p=2):
 
 
 
-def compute_cosine_similarity_matrix(embedding1, embedding2, l=1, p=2):
+def compute_cosine_similarity_matrix(embedding1, embedding2):
     """ Take as input 2 sequence embeddings (at a residue level) and returns the cosine similarity matrix
         with the signal enhancement based on Z-scores. The signal enhancement seems to be redundant 
         when used with the cosine similarity score, therefore we don't recommend this version.
@@ -84,5 +85,5 @@ def compute_cosine_similarity_matrix_plain(embedding1, embedding2):
         :type embedding2: pytorch tensor
     """
     
-    return torch.tensor(1-spatial.distance.cdist(embedding1, embedding2, 'cosine'))
+    return torch.tensor(1-spatial.distance.cdist(embedding1.cpu().numpy(), embedding2.cpu().numpy(), 'cosine'))
 
diff --git a/eba_example.py b/eba_example.py
index f1fde89..83d35f5 100644
--- a/eba_example.py
+++ b/eba_example.py
@@ -20,7 +20,7 @@ print(emb1.shape)
 similarity_matrix = sm.compute_similarity_matrix(emb1, emb2)
 eba_results = methods.compute_eba(similarity_matrix)
 ### to return the alignment itself use:
-#eba_results = eba.EBA(similarity_matrix, extensive_output=True)
+#eba_results = methods.compute_eba(similarity_matrix, extensive_output=True)
 
 ### show results
 print('EBA raw: ', eba_results['EBA_raw'])
-- 
GitLab