From e2c3585a9595ce82bae4bf9249c74355f43217b4 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Tue, 6 Sep 2022 14:56:50 +0200
Subject: [PATCH] implement QS score based greedy chain mapping strategies

---
 modules/mol/alg/pymod/chain_mapping.py      | 687 ++++++++++++++++----
 modules/mol/alg/pymod/qsscore.py            |  32 +-
 modules/mol/alg/tests/test_chain_mapping.py |  25 +-
 3 files changed, 618 insertions(+), 126 deletions(-)

diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py
index d4ab0e9db..7573286c8 100644
--- a/modules/mol/alg/pymod/chain_mapping.py
+++ b/modules/mol/alg/pymod/chain_mapping.py
@@ -690,11 +690,11 @@ class ChainMapper:
                                  self.n_max_naive)
         else:
             # its one of the greedy strategies - setup greedy searcher
-            the_greed = _GreedySearcher(self.target, mdl, self.chem_groups,
-                                        chem_mapping, ref_mdl_alns,
-                                        inclusion_radius=inclusion_radius,
-                                        thresholds=thresholds,
-                                        steep_opt_rate=steep_opt_rate)
+            the_greed = _lDDTGreedySearcher(self.target, mdl, self.chem_groups,
+                                            chem_mapping, ref_mdl_alns,
+                                            inclusion_radius=inclusion_radius,
+                                            thresholds=thresholds,
+                                            steep_opt_rate=steep_opt_rate)
             if strategy == "greedy_fast":
                 mapping = _lDDTGreedyFast(the_greed)
             elif strategy == "greedy_full":
@@ -716,7 +716,106 @@ class ChainMapper:
                              alns)
 
 
-    def GetRigidMapping(self, model, strategy = "greedy_single",
+    def GetQSScoreMapping(self, model, contact_d = 12.0, strategy = "naive",
+                          full_n_mdl_chains = None, block_seed_size = 5,
+                          block_blocks_per_chem_group = 5,
+                          steep_opt_rate = None):
+        """ Identify chain mapping based on QSScore
+
+        Scoring is based on CA/C3' positions which are present in all chains of
+        a :attr:`chem_groups` as well as the *model* chains which are mapped to
+        that respective chem group. QS score is not defined for single chains.
+        The greedy strategies that require to identify starting seeds thus
+        often rely on single chain lDDTs.
+
+        The following strategies are available:
+
+        * **naive**: Naively iterate all possible mappings and return best based
+                     on QS score.
+
+        * **greedy_fast**: perform all vs. all single chain lDDTs within the
+          respective ref/mdl chem groups. The mapping with highest number of
+          conserved contacts is selected as seed for greedy extension.
+          Extension is based on QS score.
+
+        * **greedy_full**: try multiple seeds for greedy extension, i.e. try
+          all ref/mdl chain combinations within the respective chem groups and
+          retain the mapping leading to the best QS score. Optionally, you can
+          reduce the number of mdl chains per ref chain to the
+          *full_n_mdl_chains* best scoring with respect to single chain lDDT.
+
+        * **greedy_block**: try multiple seeds for greedy extension, i.e. try
+          all ref/mdl chain combinations within the respective chem groups and
+          compute single chain lDDTs. The *block_blocks_per_chem_group* best
+          scoring ones are extend by *block_seed_size* chains and the block with
+          with best QS score is exhaustively extended.
+
+        :param model: Model to map
+        :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
+        :param contact_d: Max distance between two residues to be considered as 
+                          contact in qs scoring
+        :type contact_d: :class:`float` 
+        :param strategy: Strategy for sampling, must be in ["naive"]
+        :type strategy: :class:`str`
+        :returns: A :class:`MappingResult`
+        """
+
+        strategies = ["naive", "greedy_fast", "greedy_full", "greedy_block"]
+        if strategy not in strategies:
+            raise RuntimeError(f"strategy must be {strategies}")
+
+        chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
+        ref_mdl_alns =  _GetRefMdlAlns(self.chem_groups,
+                                       self.chem_group_alignments,
+                                       chem_mapping,
+                                       chem_group_alns)
+
+        # check for the simplest case
+        one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping)
+        if one_to_one is not None:
+            alns = dict()
+            for ref_group, mdl_group in zip(self.chem_groups, one_to_one):
+                for ref_ch, mdl_ch in zip(ref_group, mdl_group):
+                    if ref_ch is not None and mdl_ch is not None:
+                        aln = ref_mdl_alns[(ref_ch, mdl_ch)]
+                        aln.AttachView(0, self.target.Select(f"cname={ref_ch}"))
+                        aln.AttachView(1, mdl.Select(f"cname={mdl_ch}"))
+                        alns[(ref_ch, mdl_ch)] = aln
+            return MappingResult(self.target, mdl, self.chem_groups, one_to_one,
+                                 alns)
+
+        if strategy == "naive":
+            mapping = _QSScoreNaive(self.target, mdl, self.chem_groups,
+                                    chem_mapping, ref_mdl_alns, contact_d,
+                                    self.n_max_naive)
+        else:
+            # its one of the greedy strategies - setup greedy searcher
+
+            the_greed = _QSScoreGreedySearcher(self.target, mdl, self.chem_groups,
+                                            chem_mapping, ref_mdl_alns,
+                                            contact_d = contact_d,
+                                            steep_opt_rate=steep_opt_rate)
+            if strategy == "greedy_fast":
+                mapping = _QSScoreGreedyFast(the_greed)
+            elif strategy == "greedy_full":
+                mapping = _QSScoreGreedyFull(the_greed, full_n_mdl_chains)
+            elif strategy == "greedy_block":
+                mapping = _QSScoreGreedyBlock(the_greed, block_seed_size,
+                                              block_blocks_per_chem_group)
+
+        alns = dict()
+        for ref_group, mdl_group in zip(self.chem_groups, mapping):
+            for ref_ch, mdl_ch in zip(ref_group, mdl_group):
+                if ref_ch is not None and mdl_ch is not None:
+                    aln = ref_mdl_alns[(ref_ch, mdl_ch)]
+                    aln.AttachView(0, self.target.Select(f"cname={ref_ch}"))
+                    aln.AttachView(1, mdl.Select(f"cname={mdl_ch}"))
+                    alns[(ref_ch, mdl_ch)] = aln
+
+        return MappingResult(self.target, mdl, self.chem_groups, mapping,
+                             alns)
+
+    def GetRigidMapping(self, model, strategy = "greedy_single_gdtts",
                         single_chain_gdtts_thresh=0.4, subsampling=None,
                         first_complete=False, iterative_superposition=False):
         """Identify chain mapping based on rigid superposition
@@ -731,14 +830,14 @@ class ChainMapper:
 
         There are three extension strategies:
 
-        * **greedy_single**: Iteratively add the model/target chain pair that
-          adds the most conserved contacts based on the GDT-TS metric
+        * **greedy_single_gdtts**: Iteratively add the model/target chain pair
+          that adds the most conserved contacts based on the GDT-TS metric
           (Number of CA/C3' atoms within [8, 4, 2, 1] Angstrom). The mapping
           with highest GDT-TS score is returned. However, that mapping is not
           guaranteed to be complete (see *single_chain_gdtts_thresh*).
 
-        * **greedy_iterative**: Same as single except that the transformation
-          gets updated with each added chain pair.
+        * **greedy_iterative_gdtts**: Same as single except that the
+          transformation gets updated with each added chain pair.
 
         * **greedy_iterative_rmsd**: Same as iterative, i.e. the transformation
           gets updated with each added chain pair. However,
@@ -776,7 +875,8 @@ class ChainMapper:
         :returns: A :class:`MappingResult`
         """
 
-        strategies = ["greedy_single", "greedy_iterative", "greedy_iterative_rmsd"]
+        strategies = ["greedy_single_gdtts", "greedy_iterative_gdtts",
+                      "greedy_iterative_rmsd"]
         if strategy not in strategies:
             raise RuntimeError(f"strategy must be {strategies}")
 
@@ -825,22 +925,24 @@ class ChainMapper:
                             initial_transforms.append(transform)
                             initial_mappings.append((t,m))
 
-
-        if strategy == "greedy_single":
-            mapping = _SingleRigid(initial_transforms, initial_mappings,
-                                   self.chem_groups, chem_mapping,
-                                   trg_group_pos, mdl_group_pos,
-                                   single_chain_gdtts_thresh,
-                                   iterative_superposition, first_complete,
-                                   len(self.target.chains), len(mdl.chains))
-
-        elif strategy == "greedy_iterative":
-            mapping = _IterativeRigid(initial_transforms, initial_mappings,
-                                      self.chem_groups, chem_mapping,
-                                      trg_group_pos, mdl_group_pos,
-                                      single_chain_gdtts_thresh,
-                                      iterative_superposition, first_complete,
-                                      len(self.target.chains), len(mdl.chains))
+        if strategy == "greedy_single_gdtts":
+            mapping = _SingleRigidGDTTS(initial_transforms, initial_mappings,
+                                        self.chem_groups, chem_mapping,
+                                        trg_group_pos, mdl_group_pos,
+                                        single_chain_gdtts_thresh,
+                                        iterative_superposition, first_complete,
+                                        len(self.target.chains),
+                                        len(mdl.chains))
+
+        elif strategy == "greedy_iterative_gdtts":
+            mapping = _IterativeRigidGDTTS(initial_transforms, initial_mappings,
+                                           self.chem_groups, chem_mapping,
+                                           trg_group_pos, mdl_group_pos,
+                                           single_chain_gdtts_thresh,
+                                           iterative_superposition,
+                                           first_complete,
+                                           len(self.target.chains),
+                                           len(mdl.chains))
 
         elif strategy == "greedy_iterative_rmsd":
             mapping = _IterativeRigidRMSD(initial_transforms, initial_mappings,
@@ -872,69 +974,6 @@ class ChainMapper:
                              alns)
 
 
-    def GetQSScoreMapping(self, model, contact_d = 12.0, strategy = "naive"):
-        """ Identify chain mapping based on QSScore
-
-        Scoring is based on CA/C3' positions which are present in all chains of
-        a :attr:`chem_groups` as well as the *model* chains which are mapped to
-        that respective chem group.
-
-        There is currently one sampling strategy:
-
-        * **naive**: Naively iterate all possible mappings and return best based
-                     on QS score.
-
-        :param model: Model to map
-        :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
-        :param contact_d: Max distance between two residues to be considered as 
-                          contact in qs scoring
-        :type contact_d: :class:`float` 
-        :param strategy: Strategy for sampling, must be in ["naive"]
-        :type strategy: :class:`str`
-        :returns: A :class:`MappingResult`
-        """
-
-        strategies = ["naive"]
-        if strategy not in strategies:
-            raise RuntimeError(f"strategy must be {strategies}")
-
-        chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
-        ref_mdl_alns =  _GetRefMdlAlns(self.chem_groups,
-                                       self.chem_group_alignments,
-                                       chem_mapping,
-                                       chem_group_alns)
-
-        # check for the simplest case
-        one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping)
-        if one_to_one is not None:
-            alns = dict()
-            for ref_group, mdl_group in zip(self.chem_groups, one_to_one):
-                for ref_ch, mdl_ch in zip(ref_group, mdl_group):
-                    if ref_ch is not None and mdl_ch is not None:
-                        aln = ref_mdl_alns[(ref_ch, mdl_ch)]
-                        aln.AttachView(0, self.target.Select(f"cname={ref_ch}"))
-                        aln.AttachView(1, mdl.Select(f"cname={mdl_ch}"))
-                        alns[(ref_ch, mdl_ch)] = aln
-            return MappingResult(self.target, mdl, self.chem_groups, one_to_one,
-                                 alns)
-
-        if strategy == "naive":
-            mapping = _QSScoreNaive(self.target, mdl, self.chem_groups,
-                                    chem_mapping, ref_mdl_alns, contact_d,
-                                    self.n_max_naive)
-
-        alns = dict()
-        for ref_group, mdl_group in zip(self.chem_groups, mapping):
-            for ref_ch, mdl_ch in zip(ref_group, mdl_group):
-                if ref_ch is not None and mdl_ch is not None:
-                    aln = ref_mdl_alns[(ref_ch, mdl_ch)]
-                    aln.AttachView(0, self.target.Select(f"cname={ref_ch}"))
-                    aln.AttachView(1, mdl.Select(f"cname={mdl_ch}"))
-                    alns[(ref_ch, mdl_ch)] = aln
-
-        return MappingResult(self.target, mdl, self.chem_groups, mapping,
-                             alns)
-
     def GetRepr(self, substructure, model, topn=1, inclusion_radius=15.0,
                 thresholds=[0.5, 1.0, 2.0, 4.0], bb_only=False,
                 only_interchain=False):
@@ -1742,7 +1781,7 @@ class _lDDTDecomposer:
             self.interface_cache[k2] = conserved
         return self.interface_cache[k1]
 
-class _GreedySearcher(_lDDTDecomposer):
+class _lDDTGreedySearcher(_lDDTDecomposer):
     def __init__(self, ref, mdl, ref_chem_groups, mdl_chem_groups,
                  ref_mdl_alns, inclusion_radius = 15.0,
                  thresholds = [0.5, 1.0, 2.0, 4.0],
@@ -2157,10 +2196,449 @@ def _lDDTGreedyBlock(the_greed, seed_size, blocks_per_chem_group):
     return final_mapping
 
 
-def _SingleRigid(initial_transforms, initial_mappings, chem_groups,
-                 chem_mapping, trg_group_pos, mdl_group_pos,
-                 single_chain_gdtts_thresh, iterative_superposition,
-                 first_complete, n_trg_chains, n_mdl_chains):
+class _QSScoreGreedySearcher(qsscore.QSScorer):
+    def __init__(self, ref, mdl, ref_chem_groups, mdl_chem_groups,
+                 ref_mdl_alns, contact_d = 12.0,
+                 steep_opt_rate = None):
+        """ Greedy extension of already existing but incomplete chain mappings
+        """
+        super().__init__(ref, ref_chem_groups, mdl, ref_mdl_alns,
+                         contact_d = contact_d)
+        self.ref = ref
+        self.mdl = mdl
+        self.ref_mdl_alns = ref_mdl_alns
+        self.steep_opt_rate = steep_opt_rate
+
+        self.neighbors = {k: set() for k in self.qsent1.chain_names}
+        for p in self.qsent1.interacting_chains:
+            self.neighbors[p[0]].add(p[1])
+            self.neighbors[p[1]].add(p[0])
+
+        self.mdl_neighbors = {k: set() for k in self.qsent2.chain_names}
+        for p in self.qsent2.interacting_chains:
+            self.mdl_neighbors[p[0]].add(p[1])
+            self.mdl_neighbors[p[1]].add(p[0])
+
+        assert(len(ref_chem_groups) == len(mdl_chem_groups))
+        self.ref_chem_groups = ref_chem_groups
+        self.mdl_chem_groups = mdl_chem_groups
+        self.ref_ch_group_mapper = dict()
+        self.mdl_ch_group_mapper = dict()
+        for g_idx, (ref_g, mdl_g) in enumerate(zip(ref_chem_groups,
+                                                   mdl_chem_groups)):
+            for ch in ref_g:
+                self.ref_ch_group_mapper[ch] = g_idx
+            for ch in mdl_g:
+                self.mdl_ch_group_mapper[ch] = g_idx
+
+        # cache for lDDT based single chain conserved contacts
+        # used to identify starting points for further extension by QS score
+        # key: tuple (ref_ch, mdl_ch) value: number of conserved contacts
+        self.single_chain_scorer = dict()
+        self.single_chain_cache = dict()
+        for ch in self.ref.chains:
+            single_chain_ref = self.ref.Select(f"cname={ch.GetName()}")
+            self.single_chain_scorer[ch.GetName()] = \
+            lddt.lDDTScorer(single_chain_ref, bb_only = True)
+
+    def SCCounts(self, ref_ch, mdl_ch):
+        if not (ref_ch, mdl_ch) in self.single_chain_cache:
+            alns = dict()
+            alns[mdl_ch] = self.ref_mdl_alns[(ref_ch, mdl_ch)]
+            mdl_sel = self.mdl.Select(f"cname={mdl_ch}")
+            s = self.single_chain_scorer[ref_ch]
+            _,_,_,conserved,_,_,_ = s.lDDT(mdl_sel,
+                                           residue_mapping=alns,
+                                           return_dist_test=True,
+                                           no_interchain=True,
+                                           chain_mapping={mdl_ch: ref_ch},
+                                           check_resnames=False)
+            self.single_chain_cache[(ref_ch, mdl_ch)] = conserved
+        return self.single_chain_cache[(ref_ch, mdl_ch)]
+
+    def ExtendMapping(self, mapping, max_ext = None):
+
+        if len(mapping) == 0:
+            raise RuntimError("Mapping must contain a starting point")
+
+        for ref_ch, mdl_ch in mapping.items():
+            assert(ref_ch in self.ref_ch_group_mapper)
+            assert(mdl_ch in self.mdl_ch_group_mapper)
+            assert(self.ref_ch_group_mapper[ref_ch] == \
+                   self.mdl_ch_group_mapper[mdl_ch])
+
+        # Ref chains onto which we can map. The algorithm starts with a mapping
+        # on ref_ch. From there we can start to expand to connected neighbors.
+        # All neighbors that we can reach from the already mapped chains are
+        # stored in this set which will be updated during runtime
+        map_targets = set()
+        for ref_ch in mapping.keys():
+            map_targets.update(self.neighbors[ref_ch])
+
+        # remove the already mapped chains
+        for ref_ch in mapping.keys():
+            map_targets.discard(ref_ch)
+
+        if len(map_targets) == 0:
+            return mapping # nothing to extend
+
+        # keep track of what model chains are not yet mapped for each chem group
+        free_mdl_chains = list()
+        for chem_group in self.mdl_chem_groups:
+            tmp = [x for x in chem_group if x not in mapping.values()]
+            free_mdl_chains.append(set(tmp))
+
+        # keep track of what ref chains got a mapping
+        newly_mapped_ref_chains = list()
+
+        something_happened = True
+        while something_happened:
+            something_happened=False
+
+            if self.steep_opt_rate is not None:
+                n_chains = len(newly_mapped_ref_chains)
+                if n_chains > 0 and n_chains % self.steep_opt_rate == 0:
+                    mapping = self._SteepOpt(mapping, newly_mapped_ref_chains)
+
+            if max_ext is not None and len(newly_mapped_ref_chains) >= max_ext:
+                break
+
+            # nominator and denominator to determine current QS score
+            nominator, denominator = self._FromFlatMapping(mapping)
+            old_score = 0.0
+            if denominator != 0.0:
+                old_score = nominator/denominator
+
+            max_diff = 0.0
+            max_mapping = None
+            for ref_ch in map_targets:
+                chem_group_idx = self.ref_ch_group_mapper[ref_ch]
+                for mdl_ch in free_mdl_chains[chem_group_idx]:
+                    nominator_diff = 0.0
+                    denominator_diff = 0.0
+                    for neighbor in self.neighbors[ref_ch]:
+                        if neighbor in mapping and mapping[neighbor] in \
+                        self.mdl_neighbors[mdl_ch]:
+                            # it's a newly added interface if (ref_ch, mdl_ch)
+                            # are added to mapping
+                            int1 = (ref_ch, neighbor)
+                            int2 = (mdl_ch, mapping[neighbor])
+                            a, b = self._MappedInterfaceScores(int1, int2)
+                            nominator_diff += a
+                            denominator_diff += b
+                            # the respective interface penalties are subtracted
+                            # from denominator
+                            denominator_diff -= self._InterfacePenalty1(int1)
+                            denominator_diff -= self._InterfacePenalty2(int2)
+
+                    if nominator_diff > 0:
+                        # Only accept a new solution if its actually connected
+                        # i.e. nominator_diff > 0.
+                        new_nominator = nominator + nominator_diff
+                        new_denominator = denominator + denominator_diff
+                        new_score = 0.0
+                        if new_denominator != 0.0:
+                            new_score = new_nominator/new_denominator
+                        diff = new_score - old_score
+                        if diff > max_diff:
+                            max_diff = diff
+                            max_mapping = (ref_ch, mdl_ch)
+     
+            if max_mapping is not None:
+                something_happened = True
+                # assign new found mapping
+                mapping[max_mapping[0]] = max_mapping[1]
+
+                # add all neighboring chains to map targets as they are now
+                # reachable
+                for neighbor in self.neighbors[max_mapping[0]]:
+                    if neighbor not in mapping:
+                        map_targets.add(neighbor)
+
+                # remove the ref chain from map targets
+                map_targets.remove(max_mapping[0])
+
+                # remove the mdl chain from free_mdl_chains - its taken...
+                chem_group_idx = self.ref_ch_group_mapper[max_mapping[0]]
+                free_mdl_chains[chem_group_idx].remove(max_mapping[1])
+
+                # keep track of what ref chains got a mapping
+                newly_mapped_ref_chains.append(max_mapping[0])
+
+        return mapping
+
+    def _SteepOpt(self, mapping, chains_to_optimize=None):
+
+        # just optimize ALL ref chains if nothing specified
+        if chains_to_optimize is None:
+            chains_to_optimize = mapping.keys()
+
+        # make sure that we only have ref chains which are actually mapped
+        ref_chains = [x for x in chains_to_optimize if mapping[x] is not None]
+
+        # group ref chains to be optimized into chem groups
+        tmp = dict()
+        for ch in ref_chains:
+            chem_group_idx = self.ref_ch_group_mapper[ch] 
+            if chem_group_idx in tmp:
+                tmp[chem_group_idx].append(ch)
+            else:
+                tmp[chem_group_idx] = [ch]
+        chem_groups = list(tmp.values())
+
+        # try all possible mapping swaps. Swaps that improve the score are
+        # immediately accepted and we start all over again
+        nominator, denominator = self._fromFlatMapping(mapping)
+        current_score = 0.0
+        if denominator != 0.0:
+            current_score = nominator / denominator
+        something_happened = True
+        while something_happened:
+            something_happened = False
+            for chem_group in chem_groups:
+                if something_happened:
+                    break
+                for ch1, ch2 in itertools.combinations(chem_group, 2):
+                    swapped_mapping = dict(mapping)
+                    swapped_mapping[ch1] = mapping[ch2]
+                    swapped_mapping[ch2] = mapping[ch1]
+                    a, b = self._FromFlatMapping(swapped_mapping)
+                    score = 0.0
+                    if b != 0.0:
+                        score = a/b
+                    if score > current_score:
+                        something_happened = True
+                        mapping = swapped_mapping
+                        current_score = score
+                        break        
+
+        return mapping
+
+
+def _QSScoreNaive(trg, mdl, chem_groups, chem_mapping, ref_mdl_alns, contact_d,
+                  n_max_naive):
+    best_mapping = None
+    best_score = -1.0
+    # qs_scorer implements caching, score calculation is thus as fast as it gets
+    # you'll just hit a wall when the number of possible mappings becomes large
+    qs_scorer = qsscore.QSScorer(trg, chem_groups, mdl, ref_mdl_alns)
+    for mapping in _ChainMappings(chem_groups, chem_mapping, n_max_naive):
+        score = qs_scorer.GetQSScore(mapping, check=False)
+        if score > best_score:
+            best_mapping = mapping
+            best_score = score
+    return best_mapping
+
+
+def _QSScoreGreedyFast(the_greed):
+
+    something_happened = True
+    mapping = dict()
+
+    while something_happened:
+        something_happened = False
+        # search for best scoring starting point, we're using lDDT here
+        n_best = 0
+        best_seed = None
+        mapped_ref_chains = set(mapping.keys())
+        mapped_mdl_chains = set(mapping.values())
+        for ref_chains, mdl_chains in zip(the_greed.ref_chem_groups,
+                                          the_greed.mdl_chem_groups):
+            for ref_ch in ref_chains:
+                if ref_ch not in mapped_ref_chains:
+                    for mdl_ch in mdl_chains:
+                        if mdl_ch not in mapped_mdl_chains:
+                            n = the_greed.SCCounts(ref_ch, mdl_ch)
+                            if n > n_best:
+                                n_best = n
+                                best_seed = (ref_ch, mdl_ch)
+        if n_best == 0:
+            break # no proper seed found anymore...
+        # add seed to mapping and start the greed
+        mapping[best_seed[0]] = best_seed[1]
+        mapping = the_greed.ExtendMapping(mapping)
+        something_happened = True
+
+
+    # translate mapping format and return
+    final_mapping = list()
+    for ref_chains in the_greed.ref_chem_groups:
+        mapped_mdl_chains = list()
+        for ref_ch in ref_chains:
+            if ref_ch in mapping:
+                mapped_mdl_chains.append(mapping[ref_ch])
+            else:
+                mapped_mdl_chains.append(None)
+        final_mapping.append(mapped_mdl_chains)
+
+    return final_mapping
+
+
+def _QSScoreGreedyFull(the_greed, n_mdl_chains):
+    """ Uses each reference chain as starting point for expansion
+
+    However, not all mdl chain are mapped onto these reference chains,
+    that's controlled by *n_mdl_chains*
+    """
+
+    if n_mdl_chains is not None and n_mdl_chains < 1:
+        raise RuntimeError("n_mdl_chains must be None or >= 1")
+
+    something_happened = True
+    mapping = dict()
+
+    while something_happened:
+        something_happened = False
+        # Try all possible starting points and keep the one giving the best QS score
+        best_score = 0.0
+        best_mapping = None
+        mapped_ref_chains = set(mapping.keys())
+        mapped_mdl_chains = set(mapping.values())
+        for ref_chains, mdl_chains in zip(the_greed.ref_chem_groups,
+                                          the_greed.mdl_chem_groups):
+            for ref_ch in ref_chains:
+                if ref_ch not in mapped_ref_chains:
+                    seeds = list()
+                    for mdl_ch in mdl_chains:
+                        if mdl_ch not in mapped_mdl_chains:
+                            seeds.append((ref_ch, mdl_ch))
+                    if n_mdl_chains is not None and n_mdl_chains < len(seeds):
+                        counts = [the_greed.SCCounts(s[0], s[1]) for s in seeds]
+                        tmp = [(a,b) for a,b in zip(counts, seeds)]
+                        tmp.sort(reverse=True)
+                        seeds = [item[1] for item in tmp[:n_mdl_chains]]
+                    for seed in seeds:
+                        tmp_mapping = dict(mapping)
+                        tmp_mapping[seed[0]] = seed[1]
+                        tmp_mapping = the_greed.ExtendMapping(tmp_mapping)
+                        a, b = the_greed._FromFlatMapping(tmp_mapping)
+                        tmp_score = 0.0
+                        if b != 0.0:
+                            tmp_score = a/b
+                        if tmp_score > best_score:
+                            best_score = tmp_score
+                            best_mapping = tmp_mapping
+
+        if best_score == 0.0:
+            break # no proper mapping found anymore...
+
+        something_happened = True
+        mapping = best_mapping
+
+    # translate mapping format and return
+    final_mapping = list()
+    for ref_chains in the_greed.ref_chem_groups:
+        mapped_mdl_chains = list()
+        for ref_ch in ref_chains:
+            if ref_ch in mapping:
+                mapped_mdl_chains.append(mapping[ref_ch])
+            else:
+                mapped_mdl_chains.append(None)
+        final_mapping.append(mapped_mdl_chains)
+
+    return final_mapping
+
+
+def _QSScoreGreedyBlock(the_greed, seed_size, blocks_per_chem_group):
+    """ try multiple seeds, i.e. try all ref/mdl chain combinations within the
+    respective chem groups and compute single chain lDDTs. The
+    *blocks_per_chem_group* best scoring ones are extend by *seed_size* chains
+    and the best scoring one with respect to QS score is exhaustively extended.
+    """
+
+    if seed_size is None or seed_size < 1:
+        raise RuntimeError(f"seed_size must be an int >= 1 (got {seed_size})")
+
+    if blocks_per_chem_group is None or blocks_per_chem_group < 1:
+        raise RuntimeError(f"blocks_per_chem_group must be an int >= 1 "
+                           f"(got {blocks_per_chem_group})")
+
+    max_ext = seed_size - 1 #  -1 => start seed already has size 1
+
+    ref_chem_groups = copy.deepcopy(the_greed.ref_chem_groups)
+    mdl_chem_groups = copy.deepcopy(the_greed.mdl_chem_groups)
+
+    mapping = dict()
+
+    something_happened = True
+    while something_happened:
+        something_happened = False
+        starting_blocks = list()
+        for ref_chains, mdl_chains in zip(ref_chem_groups, mdl_chem_groups):
+            if len(mdl_chains) == 0:
+                continue # nothing to map
+
+            # Identify starting seeds for *blocks_per_chem_group* blocks
+            # thats done with lDDT
+            seeds = list()
+            for ref_ch in ref_chains:
+                seeds += [(ref_ch, mdl_ch) for mdl_ch in mdl_chains]
+            counts = [the_greed.SCCounts(s[0], s[1]) for s in seeds]
+            tmp = [(a,b) for a,b in zip(counts, seeds)]
+            tmp.sort(reverse=True)
+            seeds = [item[1] for item in tmp[:blocks_per_chem_group]]
+
+            # extend starting seeds to *seed_size* and retain best scoring block
+            # for further extension
+            best_score = 0.0
+            best_mapping = None
+            for s in seeds:
+                seed = dict(mapping)
+                seed.update({s[0]: s[1]})  
+                seed = the_greed.ExtendMapping(seed, max_ext = max_ext)
+                a, b = the_greed._FromFlatMapping(seed)
+                seed_score = 0.0
+                if b != 0.0:
+                    seed_score = a/b
+                if seed_score > best_score:
+                    best_score = seed_score
+                    best_mapping = seed
+            if best_mapping != None:
+                starting_blocks.append(best_mapping)
+
+        # fully expand initial starting blocks
+        best_score = 0.0
+        best_mapping = None
+        for seed in starting_blocks:
+            seed = the_greed.ExtendMapping(seed)
+            a, b = the_greed._FromFlatMapping(seed)
+            seed_score = 0.0
+            if b != 0.0:
+                seed_score = a/b
+            if seed_score > best_score:
+                best_score = seed_score
+                best_mapping = seed
+
+        if best_score == 0.0:
+            break # no proper mapping found anymore
+
+        something_happened = True
+        mapping.update(best_mapping)
+        for ref_ch, mdl_ch in best_mapping.items():
+            for group_idx in range(len(ref_chem_groups)):
+                if ref_ch in ref_chem_groups[group_idx]:
+                    ref_chem_groups[group_idx].remove(ref_ch)
+                if mdl_ch in mdl_chem_groups[group_idx]:
+                    mdl_chem_groups[group_idx].remove(mdl_ch)
+
+    # translate mapping format and return
+    final_mapping = list()
+    for ref_chains in the_greed.ref_chem_groups:
+        mapped_mdl_chains = list()
+        for ref_ch in ref_chains:
+            if ref_ch in mapping:
+                mapped_mdl_chains.append(mapping[ref_ch])
+            else:
+                mapped_mdl_chains.append(None)
+        final_mapping.append(mapped_mdl_chains)
+
+    return final_mapping
+
+
+def _SingleRigidGDTTS(initial_transforms, initial_mappings, chem_groups,
+                      chem_mapping, trg_group_pos, mdl_group_pos,
+                      single_chain_gdtts_thresh, iterative_superposition,
+                      first_complete, n_trg_chains, n_mdl_chains):
     """ Takes initial transforms and sequentially adds chain pairs with
     best scoring gdtts that fulfill single_chain_gdtts_thresh. The mapping
     from the transform that leads to best overall gdtts score is returned.
@@ -2221,10 +2699,10 @@ def _SingleRigid(initial_transforms, initial_mappings, chem_groups,
     return best_mapping
 
 
-def _IterativeRigid(initial_transforms, initial_mappings, chem_groups,
-                    chem_mapping, trg_group_pos, mdl_group_pos,
-                    single_chain_gdtts_thresh, iterative_superposition,
-                    first_complete, n_trg_chains, n_mdl_chains):
+def _IterativeRigidGDTTS(initial_transforms, initial_mappings, chem_groups,
+                         chem_mapping, trg_group_pos, mdl_group_pos,
+                         single_chain_gdtts_thresh, iterative_superposition,
+                         first_complete, n_trg_chains, n_mdl_chains):
     """ Takes initial transforms and sequentially adds chain pairs with
     best scoring gdtts that fulfill single_chain_gdtts_thresh. With each
     added chain pair, the transform gets updated. Thus the naming iterative.
@@ -2403,17 +2881,6 @@ def _IterativeRigidRMSD(initial_transforms, initial_mappings, chem_groups,
 
     return best_mapping
 
-def _QSScoreNaive(trg, mdl, chem_groups, chem_mapping, ref_mdl_alns, contact_d,
-                  n_max_naive):
-    best_mapping = None
-    best_score = -1.0
-    qs_scorer = qsscore.QSScorer(trg, chem_groups, mdl, ref_mdl_alns)
-    for mapping in _ChainMappings(chem_groups, chem_mapping, n_max_naive):
-        score = qs_scorer.GetQSScore(mapping, check=False)
-        if score > best_score:
-            best_mapping = mapping
-            best_score = score
-    return best_mapping
 
 def _GetRefPos(trg, mdl, trg_msas, mdl_alns, max_pos = None):
     """ Extracts reference positions which are present in trg and mdl
diff --git a/modules/mol/alg/pymod/qsscore.py b/modules/mol/alg/pymod/qsscore.py
index 54c4a3d5a..391499ae5 100644
--- a/modules/mol/alg/pymod/qsscore.py
+++ b/modules/mol/alg/pymod/qsscore.py
@@ -326,6 +326,16 @@ class QSScorer:
         for a, b in zip(self.chem_groups, mapping):
             flat_mapping.update({x: y for x, y in zip(a, b) if y is not None})
 
+        # refers to equation 6 in Bertoni et al., 2017
+        nominator, denominator = self._FromFlatMapping(flat_mapping)
+
+        if denominator > 0.0:
+            return nominator / denominator
+        else:
+            return 0.0
+
+    def _FromFlatMapping(self, flat_mapping):
+
         # refers to equation 6 in Bertoni et al., 2017
         nominator = 0.0
         denominator= 0.0
@@ -349,15 +359,18 @@ class QSScorer:
             if int2 not in processed_qsent2_interfaces:
                 denominator += self._InterfacePenalty2(int2)
 
-        if denominator > 0.0:
-            return nominator / denominator
-        else:
-            return 0.0
+        return (nominator, denominator)
 
     def _MappedInterfaceScores(self, int1, int2):
-        if (int1, int2) not in self._mapped_cache:
-            self._mapped_cache[(int1, int2)] = self._InterfaceScores(int1, int2)
-        return self._mapped_cache[(int1, int2)] 
+        key_one = (int1, int2)
+        if key_one in self._mapped_cache:
+            return self._mapped_cache[key_one]
+        key_two = ((int1[1], int1[0]), (int2[1], int2[0]))
+        if key_two in self._mapped_cache:
+            return self._mapped_cache[key_two]
+        nominator, denominator = self._InterfaceScores(int1, int2)
+        self._mapped_cache[key_one] = (nominator, denominator)
+        return (nominator, denominator) 
 
     def _InterfaceScores(self, int1, int2):
 
@@ -366,12 +379,11 @@ class QSScorer:
 
         # given two chain names a and b: if a < b, shape of pairwise distances is
         # (len(a), len(b)). However, if b > a, its (len(b), len(a)) => transpose
+        if int1[0] > int1[1]:
+            d1 = d1.transpose()
         if int2[0] > int2[1]:
             d2 = d2.transpose()
 
-        # should be given by design => no need to transpose d1
-        assert(int1[0] < int1[1])
-
         # indices of the first chain in the two interfaces
         mapped_indices_1_1, mapped_indices_1_2 = \
         self._IndexMapping(int1[0], int2[0])
diff --git a/modules/mol/alg/tests/test_chain_mapping.py b/modules/mol/alg/tests/test_chain_mapping.py
index 6e41fad79..d9a9362eb 100644
--- a/modules/mol/alg/tests/test_chain_mapping.py
+++ b/modules/mol/alg/tests/test_chain_mapping.py
@@ -245,6 +245,7 @@ class TestChainMapper(unittest.TestCase):
     # This is not supposed to be in depth algorithm testing, we just check
     # whether the various algorithms return sensible chain mappings
 
+    # lDDT based chain mappings
     naive_lddt_res = mapper.GetlDDTMapping(mdl, strategy="naive")
     self.assertEqual(naive_lddt_res.mapping, [['X', 'Y'],[None],['Z']])
 
@@ -258,19 +259,31 @@ class TestChainMapper(unittest.TestCase):
     greedy_lddt_res = mapper.GetlDDTMapping(mdl, strategy="greedy_block")
     self.assertEqual(greedy_lddt_res.mapping, [['X', 'Y'],[None],['Z']])
 
-    greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_single")
+
+    # QS score based chain mappings
+    naive_qsscore_res = mapper.GetQSScoreMapping(mdl, strategy="naive")
+    self.assertEqual(naive_qsscore_res.mapping, [['X', 'Y'],[None],['Z']])
+
+    greedy_qsscore_res = mapper.GetQSScoreMapping(mdl, strategy="greedy_fast")
+    self.assertEqual(naive_qsscore_res.mapping, [['X', 'Y'],[None],['Z']])
+
+    greedy_qsscore_res = mapper.GetQSScoreMapping(mdl, strategy="greedy_full")
+    self.assertEqual(naive_qsscore_res.mapping, [['X', 'Y'],[None],['Z']])
+
+    greedy_qsscore_res = mapper.GetQSScoreMapping(mdl, strategy="greedy_block")
+    self.assertEqual(naive_qsscore_res.mapping, [['X', 'Y'],[None],['Z']])
+
+
+    # rigid chain mappings
+    greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_single_gdtts")
     self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']])
 
-    greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_iterative")
+    greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_iterative_gdtts")
     self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']])
 
     greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_iterative_rmsd")
     self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']])
 
-    naive_qsscore_res = mapper.GetQSScoreMapping(mdl, strategy="naive")
-    self.assertEqual(naive_qsscore_res.mapping, [['X', 'Y'],[None],['Z']])
-
-
 if __name__ == "__main__":
   from ost import testutils
   if testutils.SetDefaultCompoundLib():
-- 
GitLab