From bccddafe3c07e6fee92fbdfe67dc161c6344ff0e Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Mon, 29 Aug 2022 11:55:26 +0200
Subject: [PATCH] lDDT: introduce MappingResult object

Besides raw mapping, this also contains alignments etc for further processing
---
 modules/mol/alg/pymod/chain_mapping.py      | 192 ++++++++++++++++----
 modules/mol/alg/tests/test_chain_mapping.py |  27 +--
 2 files changed, 173 insertions(+), 46 deletions(-)

diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py
index e9a026203..06a3b3200 100644
--- a/modules/mol/alg/pymod/chain_mapping.py
+++ b/modules/mol/alg/pymod/chain_mapping.py
@@ -18,12 +18,83 @@ from ost import geom
 
 from ost.mol.alg import lddt
 
+
+class MappingResult:
+    """ Result object for the chain mapping functions in :class:`ChainMapper`
+
+    Constructor is directly called within the functions, no need to construct
+    such objects yourself.
+    """
+    def __init__(self, target, model, chem_groups, mapping, alns):
+        self._target = target
+        self._model = model
+        self._chem_groups = chem_groups
+        self._mapping = mapping
+
+    @property
+    def target(self):
+        """ Target/reference structure, i.e. :attr:`ChainMapper.target`
+
+        :type: :class:`ost.mol.EntityView`
+        """
+        return self._target
+
+    @property
+    def model(self):
+        """ Model structure that gets mapped onto :attr:`~target`
+
+        Underwent same processing as :attr:`ChainMapper.target`, i.e.
+        only contains peptide/nucleotide chains of sufficient size.
+
+        :type: :class:`ost.mol.EntityView`
+        """
+        return self._model
+
+    @property
+    def chem_groups(self):
+        """ Groups of chemically equivalent chains in :attr:`~target`
+
+        Same as :attr:`ChainMapper.chem_group`
+
+        :class:`list` of :class:`list` of :class:`str` (chain names)
+        """
+        return self._chem_groups
+
+
+    @property
+    def mapping(self):
+        """ Mapping of :attr:`model` chains onto :attr:`~target`
+
+        Exact same shape as :attr:`~chem_groups` but containing the names of the
+        mapped chains in :attr:`~model`. May contain None for :attr:`~target`
+        chains that are not covered. No guarantee that all chains in
+        :attr:`~model` are mapped.
+
+        :class:`list` of :class:`list` of :class:`str` (chain names)
+        """
+        return self._mapping
+
+    @property
+    def alns(self):
+        """ Alignments of mapped chains in :attr:`~target` and :attr:`~model`
+
+        Each alignment is accessible with ``alns[(t_chain,m_chain)]``. First
+        sequence is the sequence of :attr:`target` chain, second sequence the
+        one from :attr:`~model`. The respective :class:`ost.mol.EntityView` are
+        attached with :func:`ost.seq.ConstSequenceHandle.AttachView`.
+
+        :type: :class:`dict` with key: :class:`tuple` of :class:`str`, value:
+               :class:`ost.seq.AlignmentHandle`
+        """
+        return self._aln
+
+
 class ReprResult:
 
     """ Result object for :func:`ChainMapper.GetRepr`
 
     Constructor is directly called within the function, no need to construct
-    such object yourself.
+    such objects yourself.
 
     :param lDDT: lDDT for this mapping. Depends on how you call
                  :func:`ChainMapper.GetRepr` whether this is backbone only or
@@ -511,24 +582,30 @@ class ChainMapper:
         :type inclusion_radius: :class:`float`
         :param thresholds: Thresholds for lDDT
         :type thresholds: :class:`list` of :class:`float`
-        :returns: A :class:`list` of :class:`list` that reflects
-                  :attr:`~chem_groups` but is filled with the respective model
-                  chains. Target chains without mapped model chains are set to
-                  None.
+        :returns: A :class:`MappingResult`
         """
         chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
 
-        # check for the simplest case
-        one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping)
-        if one_to_one is not None:
-            return one_to_one
-
         # all possible ref/mdl alns given chem mapping
         ref_mdl_alns =  _GetRefMdlAlns(self.chem_groups,
                                        self.chem_group_alignments,
                                        chem_mapping,
                                        chem_group_alns)
 
+        # check for the simplest case
+        one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping)
+        if one_to_one is not None:
+            alns = dict()
+            for ref_group, mdl_group in zip(self.chem_groups, one_to_one):
+                for ref_ch, mdl_ch in zip(ref_group, mdl_group):
+                    if ref_ch is not None and mdl_ch is not None:
+                        aln = ref_mdl_alns[(ref_ch, mdl_ch)]
+                        aln.AttachView(0, self.target.Select(f"cname={ref_ch}"))
+                        aln.AttachView(1, mdl.Select(f"cname={mdl_ch}"))
+                        alns[(ref_ch, mdl_ch)] = aln
+            return MappingResult(self.target, mdl, self.chem_groups, one_to_one,
+                                 alns)
+
         best_mapping = None
         best_lddt = -1.0
 
@@ -570,8 +647,17 @@ class ChainMapper:
                     best_mapping = mapping
                     best_lddt = lDDT
 
-        return best_mapping
+        alns = dict()
+        for ref_group, mdl_group in zip(self.chem_groups, best_mapping):
+            for ref_ch, mdl_ch in zip(ref_group, mdl_group):
+                if ref_ch is not None and mdl_ch is not None:
+                    aln = ref_mdl_alns[(ref_ch, mdl_ch)]
+                    aln.AttachView(0, self.target.Select(f"cname={ref_ch}"))
+                    aln.AttachView(1, mdl.Select(f"cname={mdl_ch}"))
+                    alns[(ref_ch, mdl_ch)] = aln
 
+        return MappingResult(self.target, mdl, self.chem_groups, best_mapping,
+                             alns)
 
     def GetGreedylDDTMapping(self, model, inclusion_radius=15.0,
                              thresholds=[0.5, 1.0, 2.0, 4.0],
@@ -635,10 +721,7 @@ class ChainMapper:
                                             are extended in an initial search
                                             for high scoring local solutions.
         :type block_blocks_per_chem_group: :class:`int`
-        :returns: A :class:`list` of :class:`list` that reflects
-                  :attr:`~chem_groups` but is filled with the respective model
-                  chains. Target chains without mapped model chains are set to
-                  None.
+        :returns: A :class:`MappingResult`
         """
 
         seed_strategies = ["fast", "full", "block"]
@@ -647,21 +730,25 @@ class ChainMapper:
 
         chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
 
-        # check for the simplest case
-        only_one_to_one = True
-        for ref_chains, mdl_chains in zip(self.chem_groups, chem_mapping):
-            if len(ref_chains) != 1 or len(mdl_chains) not in [0, 1]:
-                only_one_to_one = False
-                break
-        if only_one_to_one:
-            # skip ref chem groups with no mapped mdl chain
-            return [(a,b) for a,b in zip(self.chem_groups, chem_mapping) if len(b) == 1]
-
         ref_mdl_alns =  _GetRefMdlAlns(self.chem_groups,
                                        self.chem_group_alignments,
                                        chem_mapping,
                                        chem_group_alns)
 
+        # check for the simplest case
+        one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping)
+        if one_to_one is not None:
+            alns = dict()
+            for ref_group, mdl_group in zip(self.chem_groups, one_to_one):
+                for ref_ch, mdl_ch in zip(ref_group, mdl_group):
+                    if ref_ch is not None and mdl_ch is not None:
+                        aln = ref_mdl_alns[(ref_ch, mdl_ch)]
+                        aln.AttachView(0, self.target.Select(f"cname={ref_ch}"))
+                        aln.AttachView(1, mdl.Select(f"cname={mdl_ch}"))
+                        alns[(ref_ch, mdl_ch)] = aln
+            return MappingResult(self.target, mdl, self.chem_groups, one_to_one,
+                                 alns)
+
         # setup greedy searcher
         the_greed = _GreedySearcher(self.target, mdl, self.chem_groups,
                                     chem_mapping, ref_mdl_alns,
@@ -670,11 +757,23 @@ class ChainMapper:
                                     steep_opt_rate=steep_opt_rate)
 
         if seed_strategy == "fast":
-            return _FastGreedy(the_greed)
+            mapping = _FastGreedy(the_greed)
         elif seed_strategy == "full":
-            return _FullGreedy(the_greed, full_n_mdl_chains)
+            mapping = _FullGreedy(the_greed, full_n_mdl_chains)
         elif seed_strategy == "block":
-            return _BlockGreedy(the_greed, block_seed_size, block_blocks_per_chem_group)
+            mapping = _BlockGreedy(the_greed, block_seed_size, block_blocks_per_chem_group)
+
+        alns = dict()
+        for ref_group, mdl_group in zip(self.chem_groups, mapping):
+            for ref_ch, mdl_ch in zip(ref_group, mdl_group):
+                if ref_ch is not None and mdl_ch is not None:
+                    aln = ref_mdl_alns[(ref_ch, mdl_ch)]
+                    aln.AttachView(0, self.target.Select(f"cname={ref_ch}"))
+                    aln.AttachView(1, mdl.Select(f"cname={mdl_ch}"))
+                    alns[(ref_ch, mdl_ch)] = aln
+
+        return MappingResult(self.target, mdl, self.chem_groups, mapping,
+                             alns)
 
 
     def GetGreedyRigidMapping(self, model, strategy = "single",
@@ -734,16 +833,31 @@ class ChainMapper:
                                         as oposed to
                                         :func:`ost.mol.alg.SuperposeSVD`
         :type iterative_superposition: :class:`bool`
-        :returns: A :class:`list` of :class:`list` that reflects
-                  :attr:`~chem_groups` but is filled with the respective model
-                  chains. Target chains without mapped model chains are set to
-                  None.
+        :returns: A :class:`MappingResult`
         """
 
         if strategy not in ["single", "iterative", "iterative_rmsd"]:
             raise RuntimeError("strategy must be \"single\" or \"iterative\"")
 
         chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
+        ref_mdl_alns =  _GetRefMdlAlns(self.chem_groups,
+                                       self.chem_group_alignments,
+                                       chem_mapping,
+                                       chem_group_alns)
+
+        # check for the simplest case
+        one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping)
+        if one_to_one is not None:
+            alns = dict()
+            for ref_group, mdl_group in zip(self.chem_groups, one_to_one):
+                for ref_ch, mdl_ch in zip(ref_group, mdl_group):
+                    if ref_ch is not None and mdl_ch is not None:
+                        aln = ref_mdl_alns[(ref_ch, mdl_ch)]
+                        aln.AttachView(0, self.target.Select(f"cname={ref_ch}"))
+                        aln.AttachView(1, mdl.Select(f"cname={mdl_ch}"))
+                        alns[(ref_ch, mdl_ch)] = aln
+            return MappingResult(self.target, mdl, self.chem_groups, one_to_one,
+                                 alns)
 
         trg_group_pos, mdl_group_pos = _GetRefPos(self.target, mdl,
                                                   self.chem_group_alignments,
@@ -804,7 +918,17 @@ class ChainMapper:
                     mapped_mdl_chains.append(None)
             final_mapping.append(mapped_mdl_chains)
 
-        return final_mapping
+        alns = dict()
+        for ref_group, mdl_group in zip(self.chem_groups, final_mapping):
+            for ref_ch, mdl_ch in zip(ref_group, mdl_group):
+                if ref_ch is not None and mdl_ch is not None:
+                    aln = ref_mdl_alns[(ref_ch, mdl_ch)]
+                    aln.AttachView(0, self.target.Select(f"cname={ref_ch}"))
+                    aln.AttachView(1, mdl.Select(f"cname={mdl_ch}"))
+                    alns[(ref_ch, mdl_ch)] = aln
+
+        return MappingResult(self.target, mdl, self.chem_groups, final_mapping,
+                             alns)
 
 
     def GetRepr(self, substructure, model, topn=1, inclusion_radius=15.0,
@@ -2508,4 +2632,4 @@ def _GetTransform(pos_one, pos_two, iterative):
     return res.transformation
 
 # specify public interface
-__all__ = ('ChainMapper', 'ReprResult')
+__all__ = ('ChainMapper', 'ReprResult', 'MappingResult')
diff --git a/modules/mol/alg/tests/test_chain_mapping.py b/modules/mol/alg/tests/test_chain_mapping.py
index 8f3279ef8..655266d03 100644
--- a/modules/mol/alg/tests/test_chain_mapping.py
+++ b/modules/mol/alg/tests/test_chain_mapping.py
@@ -245,24 +245,27 @@ class TestChainMapper(unittest.TestCase):
     # This is not supposed to be in depth algorithm testing, we just check
     # whether the various algorithms return sensible chain mappings
 
-    naive_lddt_mapping = mapper.GetNaivelDDTMapping(mdl)
-    self.assertEqual(naive_lddt_mapping, [['X', 'Y'],[None],['Z']])
+    naive_lddt_res = mapper.GetNaivelDDTMapping(mdl)
+    self.assertEqual(naive_lddt_res.mapping, [['X', 'Y'],[None],['Z']])
 
     # the "fast" strategy produces actually a suboptimal mapping in this case...
-    greedy_lddt_mapping = mapper.GetGreedylDDTMapping(mdl, seed_strategy="fast")
-    self.assertEqual(greedy_lddt_mapping, [['Y', 'X'],[None],['Z']])
+    greedy_lddt_res = mapper.GetGreedylDDTMapping(mdl, seed_strategy="fast")
+    self.assertEqual(greedy_lddt_res.mapping, [['Y', 'X'],[None],['Z']])
 
-    greedy_lddt_mapping = mapper.GetGreedylDDTMapping(mdl, seed_strategy="full")
-    self.assertEqual(greedy_lddt_mapping, [['X', 'Y'],[None],['Z']])
+    greedy_lddt_res = mapper.GetGreedylDDTMapping(mdl, seed_strategy="full")
+    self.assertEqual(greedy_lddt_res.mapping, [['X', 'Y'],[None],['Z']])
 
-    greedy_lddt_mapping = mapper.GetGreedylDDTMapping(mdl, seed_strategy="block")
-    self.assertEqual(greedy_lddt_mapping, [['X', 'Y'],[None],['Z']])
+    greedy_lddt_res = mapper.GetGreedylDDTMapping(mdl, seed_strategy="block")
+    self.assertEqual(greedy_lddt_res.mapping, [['X', 'Y'],[None],['Z']])
 
-    greedy_rigid_mapping = mapper.GetGreedyRigidMapping(mdl, strategy="single")
-    self.assertEqual(greedy_rigid_mapping, [['X', 'Y'],[None],['Z']])
+    greedy_rigid_res = mapper.GetGreedyRigidMapping(mdl, strategy="single")
+    self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']])
 
-    greedy_rigid_mapping = mapper.GetGreedyRigidMapping(mdl, strategy="iterative")
-    self.assertEqual(greedy_rigid_mapping, [['X', 'Y'],[None],['Z']])
+    greedy_rigid_res = mapper.GetGreedyRigidMapping(mdl, strategy="iterative")
+    self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']])
+
+    greedy_rigid_res = mapper.GetGreedyRigidMapping(mdl, strategy="iterative_rmsd")
+    self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']])
 
 
 if __name__ == "__main__":
-- 
GitLab