diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py index 6d1a25ad90c2dd65d3cde449110869b861a34027..7302741a4d7f9e423dd5d1ad0d1edb69648535c4 100644 --- a/modules/mol/alg/pymod/ligand_scoring.py +++ b/modules/mol/alg/pymod/ligand_scoring.py @@ -460,7 +460,7 @@ class LigandScorer: # " - setting to Infinity" % str(err)) # bb_rmsd = float("inf") lddt_pli_full_matrix[target_i, model_i] = { - "lddt_pli": 0, + "lddt_pli": np.nan, "lddt_pli_n_contacts": None, "rmsd": rmsd, # "symmetry_number": i, @@ -496,7 +496,7 @@ class LigandScorer: # Save results? best_lddt = lddt_pli_full_matrix[ target_i, model_i]["lddt_pli"] - if global_lddt > best_lddt: + if global_lddt > best_lddt or np.isnan(best_lddt): lddt_pli_full_matrix[target_i, model_i].update({ "lddt_pli": global_lddt, "lddt_pli_n_contacts": lddt_tot, @@ -515,8 +515,8 @@ class LigandScorer: mat1 = np.copy(mat1) mat2 = np.copy(mat2) assignments = [] - min_mat1 = mat1.min() - while min_mat1 < np.inf: + min_mat1 = LigandScorer._nanmin_nowarn(mat1) + while not np.isnan(min_mat1): best_mat1 = np.argwhere(mat1 == min_mat1) # Multiple "best" - use mat2 to disambiguate if len(best_mat1) > 1: @@ -530,18 +530,25 @@ class LigandScorer: max_i_trg, max_i_mdl = best_mat1[0] # Disable row and column - mat1[max_i_trg, :] = np.inf - mat1[:, max_i_mdl] = np.inf - mat2[max_i_trg, :] = np.inf - mat2[:, max_i_mdl] = np.inf + mat1[max_i_trg, :] = np.nan + mat1[:, max_i_mdl] = np.nan + mat2[max_i_trg, :] = np.nan + mat2[:, max_i_mdl] = np.nan # Save assignments.append((max_i_trg, max_i_mdl)) # Recompute min - min_mat1 = mat1.min() + min_mat1 = LigandScorer._nanmin_nowarn(mat1) return assignments + @staticmethod + def _nanmin_nowarn(array): + """Compute np.nanmin but ignore the RuntimeWarning.""" + with warnings.catch_warnings(): # RuntimeWarning: All-NaN slice encountered + warnings.simplefilter("ignore") + return np.nanmin(array) + @staticmethod def _reverse_lddt(lddt): """Reverse lDDT means turning it from a number between 0 and 1 to a @@ -610,7 +617,7 @@ class LigandScorer: Target ligands are in rows, model ligands in columns. - Infinite values indicate that no RMSD could be computed (i.e. different + NaN values indicate that no RMSD could be computed (i.e. different ligands). :rtype: :class:`~numpy.ndarray` @@ -620,7 +627,7 @@ class LigandScorer: if self._rmsd_matrix is None: # convert shape = self._rmsd_full_matrix.shape - self._rmsd_matrix = np.full(shape, np.inf) + self._rmsd_matrix = np.full(shape, np.nan) for i, j in np.ndindex(shape): if self._rmsd_full_matrix[i, j] is not None: self._rmsd_matrix[i, j] = self._rmsd_full_matrix[ @@ -633,7 +640,7 @@ class LigandScorer: Target ligands are in rows, model ligands in columns. - A value of 0 indicate that no lDDT-PLI could be computed (i.e. different + NaN values indicate that no lDDT-PLI could be computed (i.e. different ligands). :rtype: :class:`~numpy.ndarray` @@ -643,7 +650,7 @@ class LigandScorer: if self._lddt_pli_matrix is None: # convert shape = self._lddt_pli_full_matrix.shape - self._lddt_pli_matrix = np.zeros(shape) + self._lddt_pli_matrix = np.full(shape, np.nan) for i, j in np.ndindex(shape): if self._lddt_pli_full_matrix[i, j] is not None: self._lddt_pli_matrix[i, j] = self._lddt_pli_full_matrix[ diff --git a/modules/mol/alg/tests/test_ligand_scoring.py b/modules/mol/alg/tests/test_ligand_scoring.py index 37aa4c51423d3c9ea4fb2d7b029e7e18c862c018..7586a67fab5f0c0655df0972b1fdc6d62e2e6481 100644 --- a/modules/mol/alg/tests/test_ligand_scoring.py +++ b/modules/mol/alg/tests/test_ligand_scoring.py @@ -251,23 +251,23 @@ class TestLigandScoring(unittest.TestCase): # Check RMSD assert sc.rmsd_matrix.shape == (7, 1) np.testing.assert_almost_equal(sc.rmsd_matrix, np.array( - [[np.inf], + [[np.nan], [0.04244993], - [np.inf], - [np.inf], - [np.inf], + [np.nan], + [np.nan], + [np.nan], [0.29399303], - [np.inf]]), decimal=5) + [np.nan]]), decimal=5) # Check lDDT-PLI self.assertEqual(sc.lddt_pli_matrix.shape, (7, 1)) - self.assertEqual(sc.lddt_pli_matrix[0, 0], 0) + self.assertTrue(np.isnan(sc.lddt_pli_matrix[0, 0])) self.assertAlmostEqual(sc.lddt_pli_matrix[1, 0], 0.99843, 5) - self.assertEqual(sc.lddt_pli_matrix[2, 0], 0) - self.assertEqual(sc.lddt_pli_matrix[3, 0], 0) - self.assertEqual(sc.lddt_pli_matrix[4, 0], 0) + self.assertTrue(np.isnan(sc.lddt_pli_matrix[2, 0])) + self.assertTrue(np.isnan(sc.lddt_pli_matrix[3, 0])) + self.assertTrue(np.isnan(sc.lddt_pli_matrix[4, 0])) self.assertAlmostEqual(sc.lddt_pli_matrix[5, 0], 1.0) - self.assertEqual(sc.lddt_pli_matrix[6, 0], 0) + self.assertTrue(np.isnan(sc.lddt_pli_matrix[6, 0])) def test_check_resnames(self): """Test check_resname argument works