diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py index c44ec74c0071cd633b2763b6f5d95d2ec89aef9c..75738e34f90450f88612eaf2e3d7e40b3b40aae5 100644 --- a/modules/mol/alg/pymod/ligand_scoring.py +++ b/modules/mol/alg/pymod/ligand_scoring.py @@ -87,6 +87,9 @@ class LigandScorer: based on residue numbers. This can be assumed in benchmarking setups such as CAMEO/CASP. :type resnum_alignments: :class:`bool` + :param check_resnames: On by default. Enforces residue name matches + between mapped model and target residues. + :type check_resnames: :class:`bool` :param chain_mapper: a chain mapper initialized for the target structure. If None (default), a chain mapper will be initialized lazily as required. @@ -103,9 +106,9 @@ class LigandScorer: :param lddt_bs_radius: :class:`float` """ def __init__(self, model, target, model_ligands=None, target_ligands=None, - resnum_alignments=False, chain_mapper=None, - substructure_match=False, radius=4.0, lddt_pli_radius=6.0, - lddt_bs_radius=10.0): + resnum_alignments=False, check_resnames=True, + chain_mapper=None, substructure_match=False, + radius=4.0, lddt_pli_radius=6.0, lddt_bs_radius=10.0): if isinstance(model, mol.EntityView): self.model = mol.CreateEntityFromView(model, False) @@ -135,6 +138,7 @@ class LigandScorer: self._chain_mapper = chain_mapper self.resnum_alignments = resnum_alignments + self.check_resnames = check_resnames self.substructure_match = substructure_match self.radius = radius self.lddt_pli_radius = lddt_pli_radius @@ -453,7 +457,8 @@ class LigandScorer: n_cont, n_cons = lddt_scorer.lDDT( mdl_bs_ent, chain_mapping={"A": "A", "_": "_"}, no_intrachain=True, - return_dist_test=True) + return_dist_test=True, + check_resnames = self.check_resnames) # Save results? best_lddt = self._lddt_pli_matrix[target_i, model_i]["lddt_pli"] diff --git a/modules/mol/alg/tests/test_ligand_scoring.py b/modules/mol/alg/tests/test_ligand_scoring.py index 8e3eb94a7f3258c63cf066f9e9c4c4a539327fa6..a09d350fb7a5713119bb91fef05f970717b6c198 100644 --- a/modules/mol/alg/tests/test_ligand_scoring.py +++ b/modules/mol/alg/tests/test_ligand_scoring.py @@ -263,6 +263,20 @@ class TestLigandScoring(unittest.TestCase): assert sc._lddt_pli_matrix[1, 0]["bs_num_overlap_res"] == 15 assert sc._lddt_pli_matrix[5, 0]["bs_num_overlap_res"] == 15 + def test_check_resnames(self): + """Test check_resname argument works + """ + # 4C0A has mismatching sequence and fails with check_resnames=True + trg, trg_seqres = io.LoadMMCIF(os.path.join('testfiles', "1r8q.cif.gz"), seqres=True) + trg_4c0a, _ = io.LoadMMCIF(os.path.join('testfiles', "4c0a.cif.gz"), seqres=True) + + with self.assertRaises(RuntimeError): + sc = LigandScorer(trg, trg_4c0a, None, None, check_resnames=True) + sc._compute_scores() + + sc = LigandScorer(trg, trg_4c0a, None, None, check_resnames=False) + sc._compute_scores() + if __name__ == "__main__": from ost import testutils