diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py index 844e0052c7dbbac21c8f07630a1498dae2fd26ad..d4ab0e9db79f26cd48e99a596d043bb50a05f260 100644 --- a/modules/mol/alg/pymod/chain_mapping.py +++ b/modules/mol/alg/pymod/chain_mapping.py @@ -17,6 +17,7 @@ from ost import mol from ost import geom from ost.mol.alg import lddt +from ost.mol.alg import qsscore class MappingResult: @@ -728,7 +729,7 @@ class ChainMapper: are estimated using all possible combinations of target and model chains within the same chem groups and build the basis for further extension. - There are two extension strategies: + There are three extension strategies: * **greedy_single**: Iteratively add the model/target chain pair that adds the most conserved contacts based on the GDT-TS metric @@ -748,8 +749,8 @@ class ChainMapper: :param model: Model to map :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` :param strategy: Strategy to extend mappings from initial transforms, - see description above. Must be in ["single", - "iterative"] + see description above. Must be in ["greedy_single", + "greedy_iterative", "greedy_iterative_rmsd"] :type strategy: :class:`str` :param single_chain_gdtts_thresh: Minimal GDT-TS score for model/target chain pair to be added to mapping. @@ -871,6 +872,69 @@ class ChainMapper: alns) + def GetQSScoreMapping(self, model, contact_d = 12.0, strategy = "naive"): + """ Identify chain mapping based on QSScore + + Scoring is based on CA/C3' positions which are present in all chains of + a :attr:`chem_groups` as well as the *model* chains which are mapped to + that respective chem group. + + There is currently one sampling strategy: + + * **naive**: Naively iterate all possible mappings and return best based + on QS score. + + :param model: Model to map + :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` + :param contact_d: Max distance between two residues to be considered as + contact in qs scoring + :type contact_d: :class:`float` + :param strategy: Strategy for sampling, must be in ["naive"] + :type strategy: :class:`str` + :returns: A :class:`MappingResult` + """ + + strategies = ["naive"] + if strategy not in strategies: + raise RuntimeError(f"strategy must be {strategies}") + + chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model) + ref_mdl_alns = _GetRefMdlAlns(self.chem_groups, + self.chem_group_alignments, + chem_mapping, + chem_group_alns) + + # check for the simplest case + one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping) + if one_to_one is not None: + alns = dict() + for ref_group, mdl_group in zip(self.chem_groups, one_to_one): + for ref_ch, mdl_ch in zip(ref_group, mdl_group): + if ref_ch is not None and mdl_ch is not None: + aln = ref_mdl_alns[(ref_ch, mdl_ch)] + aln.AttachView(0, self.target.Select(f"cname={ref_ch}")) + aln.AttachView(1, mdl.Select(f"cname={mdl_ch}")) + alns[(ref_ch, mdl_ch)] = aln + return MappingResult(self.target, mdl, self.chem_groups, one_to_one, + alns) + + if strategy == "naive": + mapping = _QSScoreNaive(self.target, mdl, self.chem_groups, + chem_mapping, ref_mdl_alns, contact_d, + self.n_max_naive) + + alns = dict() + for ref_group, mdl_group in zip(self.chem_groups, mapping): + for ref_ch, mdl_ch in zip(ref_group, mdl_group): + if ref_ch is not None and mdl_ch is not None: + aln = ref_mdl_alns[(ref_ch, mdl_ch)] + aln.AttachView(0, self.target.Select(f"cname={ref_ch}")) + aln.AttachView(1, mdl.Select(f"cname={mdl_ch}")) + alns[(ref_ch, mdl_ch)] = aln + + return MappingResult(self.target, mdl, self.chem_groups, mapping, + alns) + def GetRepr(self, substructure, model, topn=1, inclusion_radius=15.0, thresholds=[0.5, 1.0, 2.0, 4.0], bb_only=False, only_interchain=False): @@ -2339,6 +2403,17 @@ def _IterativeRigidRMSD(initial_transforms, initial_mappings, chem_groups, return best_mapping +def _QSScoreNaive(trg, mdl, chem_groups, chem_mapping, ref_mdl_alns, contact_d, + n_max_naive): + best_mapping = None + best_score = -1.0 + qs_scorer = qsscore.QSScorer(trg, chem_groups, mdl, ref_mdl_alns) + for mapping in _ChainMappings(chem_groups, chem_mapping, n_max_naive): + score = qs_scorer.GetQSScore(mapping, check=False) + if score > best_score: + best_mapping = mapping + best_score = score + return best_mapping def _GetRefPos(trg, mdl, trg_msas, mdl_alns, max_pos = None): """ Extracts reference positions which are present in trg and mdl diff --git a/modules/mol/alg/tests/test_chain_mapping.py b/modules/mol/alg/tests/test_chain_mapping.py index c60a9697cec982c7787ffaf6fb091817631de885..6e41fad79ec8ef34fdf2b43fe7496fe287d8065f 100644 --- a/modules/mol/alg/tests/test_chain_mapping.py +++ b/modules/mol/alg/tests/test_chain_mapping.py @@ -267,6 +267,9 @@ class TestChainMapper(unittest.TestCase): greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_iterative_rmsd") self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']]) + naive_qsscore_res = mapper.GetQSScoreMapping(mdl, strategy="naive") + self.assertEqual(naive_qsscore_res.mapping, [['X', 'Y'],[None],['Z']]) + if __name__ == "__main__": from ost import testutils