diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py index 008719d187af5161095fd75effb6d3ce5aedf825..5e5f4742a61cdf5830891233b3e067bcc9efc8d4 100644 --- a/modules/mol/alg/pymod/ligand_scoring.py +++ b/modules/mol/alg/pymod/ligand_scoring.py @@ -540,49 +540,73 @@ class LigandScorer: min_mat1 = mat1.min() return assignments - def _assign_ligands_rmsd(self): - """Assign (map) ligands between model and target + @staticmethod + def _reverse_lddt(lddt): + """Reverse lDDT means turning it from a number between 0 and 1 to a + number between infinity and 0 (0 being better). + + In practice, this is 1/lDDT. If lDDT is 0, the result is infinity. """ - # Transform lddt_pli to be on the scale of RMSD with warnings.catch_warnings(): # RuntimeWarning: divide by zero warnings.simplefilter("ignore") - mat2 = np.float64(1) / self.lddt_pli_matrix + return np.float64(1) / lddt + + def _assign_ligands_rmsd(self): + """Assign (map) ligands between model and target. - assignments = self._find_ligand_assignment(self.rmsd_matrix, mat2) - self._rmsd = {} - self._rmsd_assignment = {} - self._rmsd_details = {} + Sets self._rmsd, self._rmsd_assignment and self._rmsd_details. + """ + mat2 = self._reverse_lddt(self.lddt_pli_matrix) + + mat_tuple = self._assign_matrices(self.rmsd_matrix, + mat2, + self._rmsd_full_matrix, + "rmsd") + self._rmsd = mat_tuple[0] + self._rmsd_assignment = mat_tuple[1] + self._rmsd_details = mat_tuple[2] + + def _assign_matrices(self, mat1, mat2, data, main_key): + """ + :param mat1: the main ligand assignment criteria (RMSD or lDDT-PLI) + :param mat2: the secondary ligand assignment criteria (lDDT-PLI or RMSD) + :param data: the data (either self._rmsd_full_matrix or self._lddt_pli_matrix + :return: a tuple with 3 dictionaries of matrices containing the main + data, assignement, and details, respectively. + """ + assignments = self._find_ligand_assignment(mat1, mat2) + out_main = {} + out_assignment = {} + out_details = {} for assignment in assignments: trg_idx, mdl_idx = assignment - mdl_lig_qname = self.model_ligands[mdl_idx].qualified_name - self._rmsd[mdl_lig_qname] = self._rmsd_full_matrix[ - trg_idx, mdl_idx]["rmsd"] - self._rmsd_assignment[mdl_lig_qname] = self._rmsd_full_matrix[ + mdl_lig = self.model_ligands[mdl_idx] + mdl_cname = mdl_lig.chain.name + mdl_restuple = (mdl_lig.number.num, mdl_lig.number.ins_code) + if mdl_cname not in out_main: + out_main[mdl_cname] = {} + out_assignment[mdl_cname] = {} + out_details[mdl_cname] = {} + out_main[mdl_cname][mdl_restuple] = data[ + trg_idx, mdl_idx][main_key] + out_assignment[mdl_cname][mdl_restuple] = data[ trg_idx, mdl_idx]["target_ligand"].qualified_name - self._rmsd_details[mdl_lig_qname] = self._rmsd_full_matrix[ + out_details[mdl_cname][mdl_restuple] = data[ trg_idx, mdl_idx] + return out_main, out_assignment, out_details def _assign_ligands_lddt_pli(self): """ Assign ligands based on lDDT-PLI. """ - # Transform lddt_pli to be on the scale of RMSD - with warnings.catch_warnings(): # RuntimeWarning: divide by zero - warnings.simplefilter("ignore") - mat1 = np.float64(1) / self.lddt_pli_matrix - - assignments = self._find_ligand_assignment(mat1, self.rmsd_matrix) - self._lddt_pli = {} - self._lddt_pli_assignment = {} - self._lddt_pli_details = {} - for assignment in assignments: - trg_idx, mdl_idx = assignment - mdl_lig_qname = self.model_ligands[mdl_idx].qualified_name - self._lddt_pli[mdl_lig_qname] = self._lddt_pli_full_matrix[ - trg_idx, mdl_idx]["lddt_pli"] - self._lddt_pli_assignment[mdl_lig_qname] = self._lddt_pli_full_matrix[ - trg_idx, mdl_idx]["target_ligand"].qualified_name - self._lddt_pli_details[mdl_lig_qname] = self._lddt_pli_full_matrix[ - trg_idx, mdl_idx] + mat1 = self._reverse_lddt(self.lddt_pli_matrix) + + mat_tuple = self._assign_matrices(mat1, + self.rmsd_matrix, + self._lddt_pli_full_matrix, + "lddt_pli") + self._lddt_pli = mat_tuple[0] + self._lddt_pli_assignment = mat_tuple[1] + self._lddt_pli_details = mat_tuple[2] @property def rmsd_matrix(self): diff --git a/modules/mol/alg/tests/test_ligand_scoring.py b/modules/mol/alg/tests/test_ligand_scoring.py index 4bff9e812cbbf3d3c191aa335923446ea1db4395..0116a948dbe26cc061351bb0901cd8d0db92e11a 100644 --- a/modules/mol/alg/tests/test_ligand_scoring.py +++ b/modules/mol/alg/tests/test_ligand_scoring.py @@ -24,6 +24,11 @@ class TestLigandScoring(unittest.TestCase): mdl, mdl_seqres = io.LoadMMCIF(os.path.join('testfiles', "P84080_model_02.cif.gz"), seqres=True) sc = LigandScorer(mdl, trg, None, None) + # import ipdb; ipdb.set_trace() + # import ost.mol.alg.scoring + # scr = ost.mol.alg.scoring.Scorer(sc.model, sc.target) + # scr.lddt + # scr.local_lddt assert len(sc.target_ligands) == 7 assert len(sc.model_ligands) == 1 @@ -284,7 +289,7 @@ class TestLigandScoring(unittest.TestCase): trg_4c0a, _ = io.LoadMMCIF(os.path.join('testfiles', "4c0a.cif.gz"), seqres=True) sc = LigandScorer(trg, trg_4c0a, None, None, check_resnames=False) - expected_keys = {"J.G3D1", "F.G3D1"} + expected_keys = {"J", "F"} self.assertFalse(expected_keys.symmetric_difference(sc.rmsd.keys())) self.assertFalse(expected_keys.symmetric_difference(sc.rmsd_assignment.keys())) self.assertFalse(expected_keys.symmetric_difference(sc.rmsd_details.keys())) @@ -293,42 +298,44 @@ class TestLigandScoring(unittest.TestCase): self.assertFalse(expected_keys.symmetric_difference(sc.lddt_pli_details.keys())) # rmsd - self.assertAlmostEqual(sc.rmsd["J.G3D1"], 3.8498, 3) - self.assertAlmostEqual(sc.rmsd["F.G3D1"], 57.6295, 3) + self.assertAlmostEqual(sc.rmsd["J"][(1, "\x00")], 3.8498, 3) + self.assertAlmostEqual(sc.rmsd["F"][(1, "\x00")], 57.6295, 3) # rmsd_assignment - self.assertEqual(sc.rmsd_assignment, {'J.G3D1': 'L.G3D1', 'F.G3D1': 'I.G3D1'}) + self.assertEqual(sc.rmsd_assignment, {'J': {(1, "\x00"): 'L.G3D1'}, + 'F': {(1, "\x00"): 'I.G3D1'}}) # rmsd_details - self.assertEqual(sc.rmsd_details["J.G3D1"]["chain_mapping"], {'A': 'B', 'H': 'C'}) - self.assertEqual(sc.rmsd_details["J.G3D1"]["bs_num_res"], 16) - self.assertEqual(sc.rmsd_details["J.G3D1"]["bs_num_overlap_res"], 16) - self.assertEqual(sc.rmsd_details["J.G3D1"]["target_ligand"].qualified_name, 'L.G3D1') - self.assertEqual(sc.rmsd_details["J.G3D1"]["model_ligand"].qualified_name, 'J.G3D1') - self.assertEqual(sc.rmsd_details["F.G3D1"]["chain_mapping"], {'F': 'B', 'C': 'C'}) - self.assertEqual(sc.rmsd_details["F.G3D1"]["bs_num_res"], 15) - self.assertEqual(sc.rmsd_details["F.G3D1"]["bs_num_overlap_res"], 15) - self.assertEqual(sc.rmsd_details["F.G3D1"]["target_ligand"].qualified_name, 'I.G3D1') - self.assertEqual(sc.rmsd_details["F.G3D1"]["model_ligand"].qualified_name, 'F.G3D1') + self.assertEqual(sc.rmsd_details["J"][(1, "\x00")]["chain_mapping"], {'A': 'B', 'H': 'C'}) + self.assertEqual(sc.rmsd_details["J"][(1, "\x00")]["bs_num_res"], 16) + self.assertEqual(sc.rmsd_details["J"][(1, "\x00")]["bs_num_overlap_res"], 16) + self.assertEqual(sc.rmsd_details["J"][(1, "\x00")]["target_ligand"].qualified_name, 'L.G3D1') + self.assertEqual(sc.rmsd_details["J"][(1, "\x00")]["model_ligand"].qualified_name, 'J.G3D1') + self.assertEqual(sc.rmsd_details["F"][(1, "\x00")]["chain_mapping"], {'F': 'B', 'C': 'C'}) + self.assertEqual(sc.rmsd_details["F"][(1, "\x00")]["bs_num_res"], 15) + self.assertEqual(sc.rmsd_details["F"][(1, "\x00")]["bs_num_overlap_res"], 15) + self.assertEqual(sc.rmsd_details["F"][(1, "\x00")]["target_ligand"].qualified_name, 'I.G3D1') + self.assertEqual(sc.rmsd_details["F"][(1, "\x00")]["model_ligand"].qualified_name, 'F.G3D1') # lddt_pli - self.assertAlmostEqual(sc.lddt_pli["J.G3D1"], 0.91194, 5) - self.assertAlmostEqual(sc.lddt_pli["F.G3D1"], 0.0014598, 6) + self.assertAlmostEqual(sc.lddt_pli["J"][(1, "\x00")], 0.91194, 5) + self.assertAlmostEqual(sc.lddt_pli["F"][(1, "\x00")], 0.0014598, 6) # lddt_pli_assignment - self.assertEqual(sc.lddt_pli_assignment, {'J.G3D1': 'I.G3D1', 'F.G3D1': 'J.G3D1'}) + self.assertEqual(sc.lddt_pli_assignment, {'J': {(1, "\x00"): 'I.G3D1'}, + 'F': {(1, "\x00"): 'J.G3D1'}}) # lddt_pli_details - self.assertAlmostEqual(sc.lddt_pli_details["J.G3D1"]["rmsd"], 4.1008, 4) - self.assertEqual(sc.lddt_pli_details["J.G3D1"]["lddt_pli_n_contacts"], 5224) - self.assertEqual(sc.lddt_pli_details["J.G3D1"]["chain_mapping"], {'F': 'B', 'C': 'C'}) - self.assertEqual(sc.lddt_pli_details["J.G3D1"]["bs_num_res"], 15) - self.assertEqual(sc.lddt_pli_details["J.G3D1"]["bs_num_overlap_res"], 15) - self.assertEqual(sc.lddt_pli_details["J.G3D1"]["target_ligand"].qualified_name, 'I.G3D1') - self.assertEqual(sc.lddt_pli_details["J.G3D1"]["model_ligand"].qualified_name, 'J.G3D1') - self.assertAlmostEqual(sc.lddt_pli_details["F.G3D1"]["rmsd"], 57.7868, 4) - self.assertEqual(sc.lddt_pli_details["F.G3D1"]["lddt_pli_n_contacts"], 5480) - self.assertEqual(sc.lddt_pli_details["F.G3D1"]["chain_mapping"], {'E': 'B', 'D': 'C'}) - self.assertEqual(sc.lddt_pli_details["F.G3D1"]["bs_num_res"], 16) - self.assertEqual(sc.lddt_pli_details["F.G3D1"]["bs_num_overlap_res"], 16) - self.assertEqual(sc.lddt_pli_details["F.G3D1"]["target_ligand"].qualified_name, 'J.G3D1') - self.assertEqual(sc.lddt_pli_details["F.G3D1"]["model_ligand"].qualified_name, 'F.G3D1') + self.assertAlmostEqual(sc.lddt_pli_details["J"][(1, "\x00")]["rmsd"], 4.1008, 4) + self.assertEqual(sc.lddt_pli_details["J"][(1, "\x00")]["lddt_pli_n_contacts"], 5224) + self.assertEqual(sc.lddt_pli_details["J"][(1, "\x00")]["chain_mapping"], {'F': 'B', 'C': 'C'}) + self.assertEqual(sc.lddt_pli_details["J"][(1, "\x00")]["bs_num_res"], 15) + self.assertEqual(sc.lddt_pli_details["J"][(1, "\x00")]["bs_num_overlap_res"], 15) + self.assertEqual(sc.lddt_pli_details["J"][(1, "\x00")]["target_ligand"].qualified_name, 'I.G3D1') + self.assertEqual(sc.lddt_pli_details["J"][(1, "\x00")]["model_ligand"].qualified_name, 'J.G3D1') + self.assertAlmostEqual(sc.lddt_pli_details["F"][(1, "\x00")]["rmsd"], 57.7868, 4) + self.assertEqual(sc.lddt_pli_details["F"][(1, "\x00")]["lddt_pli_n_contacts"], 5480) + self.assertEqual(sc.lddt_pli_details["F"][(1, "\x00")]["chain_mapping"], {'E': 'B', 'D': 'C'}) + self.assertEqual(sc.lddt_pli_details["F"][(1, "\x00")]["bs_num_res"], 16) + self.assertEqual(sc.lddt_pli_details["F"][(1, "\x00")]["bs_num_overlap_res"], 16) + self.assertEqual(sc.lddt_pli_details["F"][(1, "\x00")]["target_ligand"].qualified_name, 'J.G3D1') + self.assertEqual(sc.lddt_pli_details["F"][(1, "\x00")]["model_ligand"].qualified_name, 'F.G3D1') if __name__ == "__main__":