diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py index 85983ed88d08ffe4ceab26136da0c90928562103..692788bb8938aad5e8335a9475cd1585d888ccef 100644 --- a/modules/mol/alg/pymod/chain_mapping.py +++ b/modules/mol/alg/pymod/chain_mapping.py @@ -1231,7 +1231,8 @@ class ChainMapper: def GetRepr(self, substructure, model, topn=1, inclusion_radius=15.0, thresholds=[0.5, 1.0, 2.0, 4.0], bb_only=False, - only_interchain=False, chem_mapping_result = None): + only_interchain=False, chem_mapping_result = None, + global_mapping = None): """ Identify *topn* representations of *substructure* in *model* *substructure* defines a subset of :attr:`~target` for which one @@ -1267,12 +1268,32 @@ class ChainMapper: *model*. If set, *model* parameter is not used. :type chem_mapping_result: :class:`tuple` + :param global_mapping: Pro param. Specify a global mapping result. This + fully defines the desired representation in the + model but extracts it and enriches it with all + the nice attributes of :class:`ReprResult`. + The target attribute in *global_mapping* must be + of the same entity as self.target and the model + attribute of *global_mapping* must be of the same + entity as *model*. + :type global_mapping: :class:`MappingResult` :returns: :class:`list` of :class:`ReprResult` """ if topn < 1: raise RuntimeError("topn must be >= 1") + if global_mapping is not None: + # ensure that this mapping is derived from the same structures + if global_mapping.target.handle.GetHashCode() != \ + self.target.handle.GetHashCode(): + raise RuntimeError("global_mapping.target must be the same " + "entity as self.target") + if global_mapping.model.handle.GetHashCode() != \ + model.handle.GetHashCode(): + raise RuntimeError("global_mapping.model must be the same " + "entity as model param") + # check whether substructure really is a subset of self.target for r in substructure.residues: ch_name = r.GetChain().GetName() @@ -1377,9 +1398,31 @@ class ChainMapper: inclusion_radius = inclusion_radius, bb_only = bb_only) scored_mappings = list() - for mapping in _ChainMappings(substructure_chem_groups, - substructure_chem_mapping, - self.n_max_naive): + + if global_mapping: + # construct mapping of substructure from global mapping + flat_mapping = global_mapping.GetFlatMapping() + mapping = list() + for chem_group, chem_mapping in zip(substructure_chem_groups, + substructure_chem_mapping): + chem_group_mapping = list() + for ch in chem_group: + if ch in flat_mapping: + mdl_ch = flat_mapping[ch] + if mdl_ch in chem_mapping: + chem_group_mapping.append(mdl_ch) + else: + chem_group_mapping.append(None) + else: + chem_group_mapping.append(None) + mapping.append(chem_group_mapping) + mappings = [mapping] + else: + mappings = list(_ChainMappings(substructure_chem_groups, + substructure_chem_mapping, + self.n_max_naive)) + + for mapping in mappings: # chain_mapping and alns as input for lDDT computation lddt_chain_mapping = dict() lddt_alns = dict() diff --git a/modules/mol/alg/tests/test_chain_mapping.py b/modules/mol/alg/tests/test_chain_mapping.py index 16fd363e596e42a180cacf049ef444390a05816c..3d4ccbde250a3f9e2f4b39e7b4b9a5010a18645f 100644 --- a/modules/mol/alg/tests/test_chain_mapping.py +++ b/modules/mol/alg/tests/test_chain_mapping.py @@ -352,12 +352,61 @@ class TestChainMapper(unittest.TestCase): self.assertEqual(ref_aln.GetSequence(1).GetString(), aln.GetSequence(1).GetString()) + def test_get_repr(self): + + ref, ref_seqres = io.LoadMMCIF(os.path.join("testfiles", "1r8q.cif.gz"), + seqres=True) + mdl, mdl_seqres = io.LoadMMCIF(os.path.join("testfiles", "4c0a.cif.gz"), + seqres=True) + + pep_ref = ref.Select("peptide=true") + lig_ref = ref.Select("cname=K") + + # create view of reference binding site + ref_residues_hashes = set() # helper to keep track of added residues + for ligand_at in lig_ref.atoms: + close_atoms = pep_ref.FindWithin(ligand_at.GetPos(), 10.0) + for close_at in close_atoms: + ref_res = close_at.GetResidue() + h = ref_res.handle.GetHashCode() + if h not in ref_residues_hashes: + ref_residues_hashes.add(h) + + ref_bs = ref.CreateEmptyView() + for ch in ref.chains: + for r in ch.residues: + if r.handle.GetHashCode() in ref_residues_hashes: + ref_bs.AddResidue(r, mol.ViewAddFlag.INCLUDE_ALL) + + chain_mapper = ChainMapper(ref) + global_mapping = chain_mapper.GetQSScoreMapping(mdl) + flat_mapping = global_mapping.GetFlatMapping() + + # find optimal representation of binding site + optimal_repr_result = chain_mapper.GetRepr(ref_bs, mdl)[0] + self.assertTrue(optimal_repr_result.lDDT > 0.6) # exp result: 0.6047 + # mapping should be different than overall mapping + repr_mapping = optimal_repr_result.GetFlatChainMapping() + for ref_ch, mdl_ch in repr_mapping.items(): + self.assertNotEqual(mdl_ch, flat_mapping[ref_ch]) + + # enforce usage of global mapping, which gives a different pocket + # with slightly lower lDDT + global_repr_result = \ + chain_mapper.GetRepr(ref_bs, mdl, global_mapping=global_mapping)[0] + self.assertTrue(global_repr_result.lDDT < 0.6) # exp result: 0.5914 + + # ensure that mapping from global_repr_result corresponds to global + # mapping + repr_mapping = global_repr_result.GetFlatChainMapping() + for ref_ch, mdl_ch in repr_mapping.items(): + self.assertEqual(mdl_ch, flat_mapping[ref_ch]) def test_misc(self): - # check for triggered error when no chain fulfills length threshold - ref = _LoadFile("3l1p.1.pdb").Select("cname=A and rnum<8") - self.assertRaises(Exception, ChainMapper, ref) + # check for triggered error when no chain fulfills length threshold + ref = _LoadFile("3l1p.1.pdb").Select("cname=A and rnum<8") + self.assertRaises(Exception, ChainMapper, ref) if __name__ == "__main__":