From efeb1b55d440aab18801fe6fbb7be2852881322a Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Fri, 26 Aug 2022 19:55:22 +0200
Subject: [PATCH] add GetRepr function to ChainMapper

Searches for representations of a target stubstructure in model. Main intent
for this function was to find most similar interfaces but in principle it can
also be used to implement the CAMEO lDDT-BS
---
 modules/mol/alg/pymod/chain_mapping.py | 403 ++++++++++++++++++++++++-
 1 file changed, 399 insertions(+), 4 deletions(-)

diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py
index 8ed4b2baa..e9a026203 100644
--- a/modules/mol/alg/pymod/chain_mapping.py
+++ b/modules/mol/alg/pymod/chain_mapping.py
@@ -18,6 +18,209 @@ from ost import geom
 
 from ost.mol.alg import lddt
 
+class ReprResult:
+
+    """ Result object for :func:`ChainMapper.GetRepr`
+
+    Constructor is directly called within the function, no need to construct
+    such object yourself.
+
+    :param lDDT: lDDT for this mapping. Depends on how you call
+                 :func:`ChainMapper.GetRepr` whether this is backbone only or
+                 full atom lDDT.
+    :type lDDT: :class:`float`
+    :param ref_residues: Qualified names of ref structure, i.e. return values 
+                         of :func:`ost.mol.ResidueHandle.GetQualifiedName`
+    :type ref_residues: :class:`list` of :class:`str`
+    :param mdl_residues: Same for mdl residues
+    :type mdl_residues: :class:`list` of :class:`str`
+    :param ref_bb_pos: Representative backbone positions for reference residues.
+                       Thats CA positions for peptides and C3' positions for
+                       Nucleotides.
+    :type ref_bb_pos: :class:`geom.Vec3List`
+    """
+    def __init__(self, lDDT, ref_residues, mdl_residues, ref_bb_pos,
+                 mdl_bb_pos):
+        self._lDDT = lDDT
+        self._ref_residues = ref_residues
+        self._mdl_residues = mdl_residues
+        self._ref_bb_pos = ref_bb_pos
+        self._mdl_bb_pos = mdl_bb_pos
+
+        # lazily evaluated attributes
+        self._transform = None
+        self._superposed_mdl_bb_pos = None
+        self._bb_rmsd = None
+        self._gdt_8 = None
+        self._gdt_4 = None
+        self._gdt_2 = None
+        self._gdt_1 = None
+        self._ost_query = None
+
+    @property
+    def lDDT(self):
+        """ lDDT of representation result
+
+        Depends on how you call :func:`ChainMapper.GetRepr` whether this is
+        backbone only or full atom lDDT.
+
+        :type: :class:`float`
+        """
+        return self._lDDT
+    
+    @property
+    def ref_residues(self):
+        """ Qualified names of ref structure residues
+
+        The return values of :func:`ost.mol.ResidueHandle.GetQualifiedName`
+
+        :type: :class:`list` of :class:`str`
+        """
+        return self._ref_residues
+    
+    @property
+    def mdl_residues(self):
+        """ Qualified names of mdl structure residues
+
+        The return values of :func:`ost.mol.ResidueHandle.GetQualifiedName`
+
+        :type: :class:`list` of :class:`str`
+        """
+        return self._mdl_residues
+    
+    @property
+    def ref_bb_pos(self):
+        """ Representative backbone positions for reference residues.
+
+        Thats CA positions for peptides and C3' positions for Nucleotides.
+
+        :type: :class:`geom.Vec3List`
+        """
+        return self._ref_bb_pos
+
+    @property
+    def mdl_bb_pos(self):
+        """ Representative backbone positions for model residues.
+
+        Thats CA positions for peptides and C3' positions for Nucleotides.
+
+        :type: :class:`geom.Vec3List`
+        """
+        return self._mdl_bb_pos
+
+    @property
+    def transform(self):
+        """ Transformation to superpose mdl residues onto ref residues
+
+        Superposition computed as minimal RMSD superposition on
+        :attr:`ref_bb_pos` and :attr:`mdl_bb_pos`
+
+        :type: :class:`ost.geom.Mat4`
+        """
+        if self._transform is None:
+            self._transform = _GetTransform(self.mdl_bb_pos, self.ref_bb_pos,
+                                            False)
+        return self._transform
+
+    @property
+    def superposed_mdl_bb_pos(self):
+        """ :attr:`mdl_bb_pos` with :attr:`transform applied`
+
+        :type: :class:`geom.Vec3List`
+        """
+        if self._superposed_mdl_bb_pos is None:
+            self._superposed_mdl_bb_pos = geom.Vec3List(self.mdl_bb_pos)
+            self._superposed_mdl_bb_pos.ApplyTransform(self.transform)
+        return self._superposed_mdl_bb_pos
+
+    @property
+    def bb_rmsd(self):
+        """ RMSD between :attr:`ref_bb_pos` and :attr:`superposed_mdl_bb_pos`
+
+        :type: :class:`float`
+        """
+        if self._bb_rmsd is None:
+            self._bb_rmsd = self.ref_bb_pos.GetRMSD(self.superposed_mdl_bb_pos)
+        return self._bb_rmsd
+
+    @property
+    def gdt_8(self):
+        """ GDT with one single threshold: 8.0
+
+        :type: :class:`float`
+        """
+        if self._gdt_8 is None:
+            self._gdt_8 = self.ref_bb_pos.GetGDT(self.superposed_mdl_bb_pos, 8.0)
+        return self._gdt_8
+
+    @property
+    def gdt_4(self):
+        """ GDT with one single threshold: 4.0
+
+        :type: :class:`float`
+        """
+        if self._gdt_4 is None:
+            self._gdt_4 = self.ref_bb_pos.GetGDT(self.superposed_mdl_bb_pos, 4.0)
+        return self._gdt_4
+
+    @property
+    def gdt_2(self):
+        """ GDT with one single threshold: 2.0
+
+        :type: :class:`float`
+        """
+        if self._gdt_2 is None:
+            self._gdt_2 = self.ref_bb_pos.GetGDT(self.superposed_mdl_bb_pos, 2.0)
+        return self._gdt_2
+
+    @property
+    def gdt_1(self):
+        """ GDT with one single threshold: 1.0
+
+        :type: :class:`float`
+        """
+        if self._gdt_1 is None:
+            self._gdt_1 = self.ref_bb_pos.GetGDT(self.superposed_mdl_bb_pos, 1.0)
+        return self._gdt_1
+
+    @property
+    def ost_query(self):
+        """ query for mdl residues in OpenStructure query language
+
+        Repr can be selected as ``full_mdl.Select(ost_query)``
+
+        :type: :class:`str`
+        """
+        if self._ost_query is None:
+            chain_rnums = dict()
+            for r in self.mdl_residues:
+                chname = r.split('.')[0]
+                rnum = r.split('.')[1][3:]
+                if chname not in chain_rnums:
+                    chain_rnums[chname] = list()
+                chain_rnums[chname].append(rnum)
+            chain_queries = list()
+            for k,v in chain_rnums.items():
+                chain_queries.append(f"(cname={k} and rnum={','.join(v)})")
+            self._ost_query = " or ".join(chain_queries)
+        return self._ost_query
+
+    def JSONSummary(self):
+        """ Returns JSON serializable summary of results
+        """
+        json_dict = dict()
+        json_dict["lDDT"] = self.lDDT
+        json_dict["ref_residues"] = self.ref_residues
+        json_dict["mdl_residues"] = self.mdl_residues
+        json_dict["transform"] = list(self.transform.data)
+        json_dict["bb_rmsd"] = self.bb_rmsd
+        json_dict["gdt_8"] = self.gdt_8
+        json_dict["gdt_4"] = self.gdt_4
+        json_dict["gdt_2"] = self.gdt_2
+        json_dict["gdt_1"] = self.gdt_1
+        json_dict["ost_query"] = self.ost_query
+        return json_dict
+
 
 class ChainMapper:
     """ Class to compute chain mappings
@@ -340,8 +543,8 @@ class ChainMapper:
                 # chain_mapping and alns as input for lDDT computation
                 lddt_chain_mapping = dict()
                 lddt_alns = dict()
-                for ref_chem_group, mdl_chem_group, ref_aln in \
-                zip(self.chem_groups, mapping, self.chem_group_alignments):
+                for ref_chem_group, mdl_chem_group in zip(self.chem_groups,
+                                                          mapping):
                     for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
                         # some mdl chains can be None
                         if mdl_ch is not None:
@@ -604,6 +807,199 @@ class ChainMapper:
         return final_mapping
 
 
+    def GetRepr(self, substructure, model, topn=1, inclusion_radius=15.0,
+                thresholds=[0.5, 1.0, 2.0, 4.0], bb_only=False,
+                only_interchain=False):
+        """ Identify *topn* representations of *substructure* in *model*
+
+        *substructure* defines a subset of :attr:`~target` for which one
+        wants the *topn* representations in *model*. Representations are scored
+        and sorted by lDDT.
+
+        :param substructure: A :class:`ost.mol.EntityView` which is a subset of
+                             :attr:`~target`. Should be selected with the
+                             OpenStructure query language. Example: if you're
+                             interested in residues with number 42,43 and 85 in
+                             chain A:
+                             ``substructure=mapper.target.Select("cname=A and rnum=42,43,85")``
+                             A :class:`RuntimeError` is raised if *substructure*
+                             does not refer to the same underlying
+                             :class:`ost.mol.EntityHandle` as :attr:`~target`.
+        :type substructure: :class:`ost.mol.EntityView`
+        :param model: Structure in which one wants to find representations for
+                      *substructure*
+        :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
+        :param topn: Max number of representations that are returned
+        :type topn: :class:`int`
+        :param inclusion_radius: Inclusion radius for lDDT
+        :type inclusion_radius: :class:`float`
+        :param thresholds: Thresholds for lDDT
+        :type thresholds: :class:`list` of :class:`float`
+        :param bb_only: Only consider backbone atoms in lDDT computation
+        :type bb_only: :class:`bool`
+        :param only_interchain: Only score interchain contacts in lDDT. Useful
+                                if you want to identify interface patches.
+        :type only_interchain: :class:`bool`
+
+        :returns: :class:`list` of :class:`ReprResult`
+        """
+
+        if topn < 1:
+            raise RuntimeError("topn must be >= 1")
+
+        # check whether substructure really is a subset of self.target
+        for r in substructure.residues:
+            ch_name = r.GetChain().GetName()
+            rnum = r.GetNumber()
+            target_r = self.target.FindResidue(ch_name, rnum)
+            if target_r is None:
+                raise RuntimeError(f"substructure has residue "
+                                   f"{r.GetQualifiedName()} which is not in "
+                                   f"self.target")
+            if target_r.handle.GetHashCode() != r.handle.GetHashCode():
+                raise RuntimeError(f"substructure has residue "
+                                   f"{r.GetQualifiedName()} which has an "
+                                   f"equivalent in self.target but it does "
+                                   f"not refer to the same underlying "
+                                   f"EntityHandle")
+            for a in r.atoms:
+                target_a = target_r.FindAtom(a.GetName())
+                if target_a is None:
+                    raise RuntimeError(f"substructure has atom "
+                                       f"{a.GetQualifiedName()} which is not "
+                                       f"in self.target")
+                if a.handle.GetHashCode() != target_a.handle.GetHashCode():
+                    raise RuntimeError(f"substructure has atom "
+                                       f"{a.GetQualifiedName()} which has an "
+                                       f"equivalent in self.target but it does "
+                                       f"not refer to the same underlying "
+                                       f"EntityHandle")
+
+            # check whether it contains either CA or C3'
+            ca = r.FindAtom("CA")
+            c3 = r.FindAtom("C3'") # FindAtom with prime in string is tested
+                                   # and works
+            if ca is None and c3 is None:
+                raise RuntimeError("All residues in substructure must contain "
+                                   "a backbone atom named CA or C3\'")
+
+        # perform mapping and alignments on full structures
+        chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model)
+        ref_mdl_alns =  _GetRefMdlAlns(self.chem_groups,
+                                       self.chem_group_alignments,
+                                       chem_mapping,
+                                       chem_group_alns)
+
+        # Get residue indices relative to full target chain 
+        substructure_res_indices = dict()
+        for ch in substructure.chains:
+            full_ch = self.target.FindChain(ch.GetName())
+            idx = [full_ch.GetResidueIndex(r.GetNumber()) for r in ch.residues]
+            substructure_res_indices[ch.GetName()] = idx
+
+        # strip down variables to make them specific to substructure
+        # keep only chem_groups which are present in substructure
+        substructure_chem_groups = list()
+        substructure_chem_mapping = list()
+        
+        chnames = set([ch.GetName() for ch in substructure.chains])
+        for chem_group, mapping in zip(self.chem_groups, chem_mapping):
+            substructure_chem_group = [ch for ch in chem_group if ch in chnames]
+            if len(substructure_chem_group) > 0:
+                substructure_chem_groups.append(substructure_chem_group)
+                substructure_chem_mapping.append(mapping)
+
+        # strip the reference sequence in alignments to only contain
+        # sequence from substructure
+        substructure_ref_mdl_alns = dict()
+        for chem_group, mapping in zip(substructure_chem_groups,
+                                       substructure_chem_mapping):
+            for ref_ch in chem_group:
+                for mdl_ch in mapping:
+                    full_aln = ref_mdl_alns[(ref_ch, mdl_ch)]
+                    ref_seq = full_aln.GetSequence(0)
+                    # the ref sequence is tricky... we start with a gap only
+                    # sequence and only add olcs as defined by the residue
+                    # indices that we extracted before...
+                    tmp = ['-'] * len(full_aln)
+                    for idx in substructure_res_indices[ref_ch]:
+                        idx_in_seq = ref_seq.GetPos(idx)
+                        tmp[idx_in_seq] = ref_seq[idx_in_seq]
+                    ref_seq = seq.CreateSequence(ref_ch, ''.join(tmp))
+                    ref_seq.AttachView(substructure.Select(f"cname={ref_ch}"))
+                    aln = seq.CreateAlignment()
+                    aln.AddSequence(ref_seq)
+                    aln.AddSequence(full_aln.GetSequence(1))
+                    substructure_ref_mdl_alns[(ref_ch, mdl_ch)] = aln
+
+        lddt_scorer = lddt.lDDTScorer(substructure,
+                                      inclusion_radius = inclusion_radius,
+                                      bb_only = bb_only)
+        scored_mappings = list()
+        for mapping in _ChainMappings(substructure_chem_groups,
+                                      substructure_chem_mapping,
+                                      self.n_max_naive):
+            # chain_mapping and alns as input for lDDT computation
+            lddt_chain_mapping = dict()
+            lddt_alns = dict()
+            for ref_chem_group, mdl_chem_group in zip(substructure_chem_groups,
+                                                      mapping):
+                for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
+                    # some mdl chains can be None
+                    if mdl_ch is not None:
+                        lddt_chain_mapping[mdl_ch] = ref_ch
+                        aln = substructure_ref_mdl_alns[(ref_ch, mdl_ch)]
+                        lddt_alns[mdl_ch] = aln
+            lDDT, _ = lddt_scorer.lDDT(mdl, thresholds=thresholds,
+                                       chain_mapping=lddt_chain_mapping,
+                                       residue_mapping = lddt_alns,
+                                       check_resnames = False,
+                                       no_intrachain = only_interchain)
+
+            if len(scored_mappings) == 0:
+                if lDDT > 0.0:
+                    scored_mappings.append((lDDT, mapping))
+            elif len(scored_mappings) < topn:
+                if lDDT > 0.0:
+                    scored_mappings.append((lDDT, mapping))
+                    scored_mappings.sort(reverse=True)
+            elif lDDT > scored_mappings[-1][0]:
+                scored_mappings.append((lDDT, mapping))
+                scored_mappings.sort(reverse=True)
+                scored_mappings = scored_mappings[:topn]
+
+        # finalize and return
+        results = list()
+        for scored_mapping in scored_mappings:
+            ref_residues = list()
+            mdl_residues = list()
+            ref_bb_pos = geom.Vec3List()
+            mdl_bb_pos = geom.Vec3List()
+
+            for ref_ch_group, mdl_ch_group in zip(substructure_chem_groups,
+                                                  scored_mapping[1]):
+                for ref_ch, mdl_ch in zip(ref_ch_group, mdl_ch_group):
+                    if ref_ch is not None and mdl_ch is not None:
+                        aln = substructure_ref_mdl_alns[(ref_ch, mdl_ch)]
+                        for col in aln:
+                            if col[0] != '-' and col[1] != '-':
+                                ref_r = col.GetResidue(0)
+                                mdl_r = col.GetResidue(1)
+                                ref_residues.append(ref_r.GetQualifiedName())
+                                mdl_residues.append(mdl_r.GetQualifiedName())
+                                ref_at = ref_r.FindAtom("CA")
+                                if ref_at is None:
+                                    ref_at = ref_r.FindAtom("C3'")
+                                mdl_at = mdl_r.FindAtom("CA")
+                                if mdl_at is None:
+                                    mdl_at = mdl_r.FindAtom("C3'")
+                                ref_bb_pos.append(ref_at.GetPos())
+                                mdl_bb_pos.append(mdl_at.GetPos())
+
+            results.append(ReprResult(scored_mapping[0], ref_residues,
+                                      mdl_residues, ref_bb_pos, mdl_bb_pos))
+        return results
+
     def GetNMappings(self, model):
         """ Returns number of possible mappings
 
@@ -2038,7 +2434,6 @@ def _ConcatIterators(iterators):
     for item in itertools.product(*iterators):
         yield list(item)
 
-
 def _ChainMappings(ref_chains, mdl_chains, n_max=None):
     """Returns all possible ways to map *mdl_chains* onto fixed *ref_chains*
 
@@ -2113,4 +2508,4 @@ def _GetTransform(pos_one, pos_two, iterative):
     return res.transformation
 
 # specify public interface
-__all__ = ('ChainMapper',)
+__all__ = ('ChainMapper', 'ReprResult')
-- 
GitLab