From 29b2180b79546414448246cde1731ac30bf5ac50 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Fri, 3 Mar 2023 09:27:46 +0100
Subject: [PATCH] enable custom chain mappings in Scorer object

---
 modules/mol/alg/pymod/chain_mapping.py |   7 +-
 modules/mol/alg/pymod/scoring.py       | 102 ++++++++++++++++++++++++-
 2 files changed, 107 insertions(+), 2 deletions(-)

diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py
index 874f83d12..7ee8ff5fc 100644
--- a/modules/mol/alg/pymod/chain_mapping.py
+++ b/modules/mol/alg/pymod/chain_mapping.py
@@ -1814,7 +1814,7 @@ def _MapSequence(ref_seqs, ref_types, s, s_type, aligner):
     return (scored_alns[0][1], scored_alns[0][2])
 
 def _GetRefMdlAlns(ref_chem_groups, ref_chem_group_msas, mdl_chem_groups,
-                   mdl_chem_group_alns):
+                   mdl_chem_group_alns, pairs=None):
     """ Get all possible ref/mdl chain alignments given chem group mapping
 
     :param ref_chem_groups: :attr:`ChainMapper.chem_groups`
@@ -1831,6 +1831,9 @@ def _GetRefMdlAlns(ref_chem_groups, ref_chem_group_msas, mdl_chem_groups,
                                 Return values of
                                 :func:`ChainMapper.GetChemMapping`.
     :type mdl_chem_group_alns: :class:`list` of :class:`ost.seq.AlignmentList`
+    :param pairs: Pro param - restrict return dict to specified pairs. A set of
+                  tuples in form (<trg_ch>, <mdl_ch>)
+    :type pairs: :class:`set`
     :returns: A dictionary holding all possible ref/mdl chain alignments. Keys
               in that dictionary are tuples of the form (ref_ch, mdl_ch) and
               values are the respective pairwise alignments with first sequence
@@ -1850,6 +1853,8 @@ def _GetRefMdlAlns(ref_chem_groups, ref_chem_group_msas, mdl_chem_groups,
                                                ref_chem_group_msas):
         for ref_ch in ref_chains:
             for mdl_ch in mdl_chains:
+                if pairs is not None and (ref_ch, mdl_ch) not in pairs:
+                    continue
                 # obtain alignments of mdl and ref chains towards chem
                 # group ref sequence and merge them
                 aln_list = seq.AlignmentList()
diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index babef6014..3bb9a762a 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -121,10 +121,14 @@ class Scorer:
                            not given, voronota-cadscore must be in PATH if any
                            of the CAD score related attributes is requested.
     :type cad_score_exec: :class:`str`
+    :param custom_mapping: Provide custom chain mapping between *model* and
+                           *target*. Dictionary with target chain names as key
+                           and model chain names as value.
+    :type custom_mapping: :class:`dict`
     """
     def __init__(self, model, target, resnum_alignments=False,
                  molck_settings = None, naive_chain_mapping_thresh=12,
-                 cad_score_exec = None):
+                 cad_score_exec = None, custom_mapping=None):
 
         model = model.Select("peptide=True or nucleotide=True")
         self._model = mol.CreateEntityFromView(model, False)
@@ -244,6 +248,9 @@ class Scorer:
         self._patch_qs = None
         self._patch_dockq = None
 
+        if custom_mapping is not None:
+            self._set_custom_mapping(custom_mapping)
+
     @property
     def model(self):
         """ Model with Molck cleanup
@@ -1449,3 +1456,96 @@ class Scorer:
             for a in r.atoms:
                 ed.InsertAtom(added_r, a.handle)
         return ent
+
+    def _set_custom_mapping(self, mapping):
+        """ sets self._mapping with a full blown MappingResult object
+
+        The mapping is validated against the chem grouping of target and
+        model before alignments for the mapped chain pairs are assembled.
+
+        :param mapping: mapping with trg chains as key and mdl ch as values
+        :type mapping: :class:`dict`
+        :raises: :class:`RuntimeError` if mdl chain names are not unique, if
+                 a chain is not available in the processed target/model or
+                 if mapped chains do not belong to the same chem group.
+        """
+
+        chain_mapper = self.chain_mapper
+        chem_mapping, chem_group_alns, mdl = \
+        chain_mapper.GetChemMapping(self.model)
+
+        # now that we have a chem mapping, lets do consistency checks
+        # - check whether chain names are unique and available in structures
+        # - check whether the mapped chains actually map to the same chem groups
+        # (trg chain names are dict keys and therefore unique by construction)
+        if len(mapping) != len(set(mapping.values())):
+            raise RuntimeError(f"Expect unique mdl chain names in mapping. Got "
+                               f"{mapping.values()}")
+
+        trg_chains = set([ch.GetName() for ch in chain_mapper.target.chains])
+        mdl_chains = set([ch.GetName() for ch in mdl.chains])
+        for k,v in mapping.items():
+            if k not in trg_chains:
+                raise RuntimeError(f"Target chain \"{k}\" is not available "
+                                   f"in target processed for chain mapping "
+                                   f"({trg_chains})")
+            if v not in mdl_chains:
+                raise RuntimeError(f"Model chain \"{v}\" is not available "
+                                   f"in model processed for chain mapping "
+                                   f"({mdl_chains})")
+
+        for trg_ch, mdl_ch in mapping.items():
+            trg_group_idx = None
+            mdl_group_idx = None
+            for idx, group in enumerate(chain_mapper.chem_groups):
+                if trg_ch in group:
+                    trg_group_idx = idx
+                    break
+            for idx, group in enumerate(chem_mapping):
+                if mdl_ch in group:
+                    mdl_group_idx = idx
+                    break
+            if trg_group_idx is None or mdl_group_idx is None:
+                raise RuntimeError("Could not establish a valid chem grouping "
+                                   "of chain names provided in custom mapping.")
+
+            if trg_group_idx != mdl_group_idx:
+                raise RuntimeError(f"Chem group mismatch in custom mapping: "
+                                   f"target chain \"{trg_ch}\" groups with the "
+                                   f"following chemically equivalent target "
+                                   f"chains: "
+                                   f"{chain_mapper.chem_groups[trg_group_idx]} "
+                                   f"but model chain \"{mdl_ch}\" maps to the "
+                                   f"following target chains: "
+                                   f"{chain_mapper.chem_groups[mdl_group_idx]}")
+
+        pairs = set(mapping.items())
+        ref_mdl_alns =  \
+        chain_mapping._GetRefMdlAlns(chain_mapper.chem_groups,
+                                     chain_mapper.chem_group_alignments,
+                                     chem_mapping,
+                                     chem_group_alns,
+                                     pairs = pairs)
+
+        # translate mapping format: None marks trg chains without mdl chain
+        final_mapping = [[mapping.get(ref_ch) for ref_ch in ref_chains]
+                         for ref_chains in chain_mapper.chem_groups]
+
+        alns = dict()
+        for ref_group, mdl_group in zip(chain_mapper.chem_groups,
+                                        final_mapping):
+            for ref_ch, mdl_ch in zip(ref_group, mdl_group):
+                if ref_ch is not None and mdl_ch is not None:
+                    # quote names: unquoted special chars break the query
+                    trg_q = f"cname={mol.QueryQuoteName(ref_ch)}"
+                    mdl_q = f"cname={mol.QueryQuoteName(mdl_ch)}"
+                    aln = ref_mdl_alns[(ref_ch, mdl_ch)]
+                    trg_view = chain_mapper.target.Select(trg_q)
+                    mdl_view = mdl.Select(mdl_q)
+                    aln.AttachView(0, trg_view)
+                    aln.AttachView(1, mdl_view)
+                    alns[(ref_ch, mdl_ch)] = aln
+
+        self._mapping = chain_mapping.MappingResult(chain_mapper.target, mdl,
+                                                    chain_mapper.chem_groups,
+                                                    final_mapping, alns)
-- 
GitLab