From 5395c0eb0e52b8d68450ef07a743a2fd8bc2529a Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Mon, 28 Nov 2022 13:04:26 +0100
Subject: [PATCH] scoring: add interface patch scores

---
 modules/mol/alg/pymod/scoring.py | 295 ++++++++++++++++++++++++++++++-
 1 file changed, 294 insertions(+), 1 deletion(-)

diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index a31acaec4..9da83733d 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -233,6 +233,9 @@ class Scorer:
         self._cad_score = None
         self._local_cad_score = None
 
+        self._patch_qs = None
+        self._patch_dockq = None
+
     @property
     def model(self):
         """ Model with Molck cleanup
@@ -729,7 +732,47 @@ class Scorer:
         if self._local_cad_score is None:
             self._compute_cad_score()
         return self._local_cad_score
+
+    @property
+    def patch_qs(self):
+        """ Patch QS-scores for each residue in :attr:`model_interface_residues`
+
+        Representative patches for each residue r in chain c are computed as
+        follows:
     
+        * mdl_patch_one: All residues in c with CB (CA for GLY) positions within
+          8A of r and within 12A of residues from any other chain.
+        * mdl_patch_two: Closest residue x to r in any other chain gets
+          identified. Patch is then constructed by selecting all residues from
+          any other chain within 8A of x and within 12A from any residue in c.
+        * trg_patch_one: Chain name and residue number based mapping from
+          mdl_patch_one
+        * trg_patch_two: Chain name and residue number based mapping from
+          mdl_patch_two
+
+        Results are stored in the same manner as
+        :attr:`model_interface_residues`, with corresponding scores instead of
+        residue numbers. Scores for residues which are not
+        :class:`mol.ChemType.AMINOACIDS` are set to None. Additionally,
+        interface patches are derived from :attr:`model`. If they contain
+        residues which are not covered by :attr:`target`, the score is set to
+        None too.
+
+        :type: :class:`dict` with chain names as key and and :class:`list`
+                with scores of the respective interface residues.
+        """
+        if self._patch_qs is None:
+            self._compute_patchqs_scores()
+        return self._patch_qs
+
+    @property
+    def patch_dockq(self):
+        """ Same as :attr:`patch_qs` but for DockQ scores
+        """
+        if self._patch_dockq is None:
+            self._compute_patchdockq_scores()
+        return self._patch_dockq
+
     def _compute_lddt(self):
         # lDDT requires a flat mapping with mdl_ch as key and trg_ch as value
         flat_mapping = self.mapping.GetFlatMapping(mdl_as_key=True)
@@ -959,7 +1002,7 @@ class Scorer:
         self._local_cad_score = local_cad
 
     def _get_repr_view(self, ent):
-        """ Returns view with representative atoms => CB, CA for GLY
+        """ Returns view with representative peptide atoms => CB, CA for GLY
     
         Ensures that each residue has exactly one atom with assertions
     
@@ -1011,3 +1054,253 @@ class Scorer:
         self._target_clashes = b
         self._target_bad_bonds = c
         self._target_bad_angles = d
+
+
+    def _get_interface_patches(self, mdl_ch, mdl_rnum):
+        """ Select interface patches representative for specified residue
+    
+        The patches for specified residue r in chain c are selected as follows:
+    
+        * mdl_patch_one: All residues in c with CB (CA for GLY) positions within 8A
+          of r and within 12A of residues from any other chain.
+        * mdl_patch_two: Closest residue x to r in any other chain gets identified.
+          Patch is then constructed by selecting all residues from any other chain
+          within 8A of x and within 12A from any residue in c.
+        * trg_patch_one: Chain name and residue number based mapping from
+          mdl_patch_one
+        * trg_patch_two: Chain name and residue number based mapping from
+          mdl_patch_two
+    
+        :param mdl_ch: Name of chain in *self.model* of residue of interest
+        :type mdl_ch: :class:`str`
+        :param mdl_rnum: Residue number of residue of interest
+        :type mdl_rnum: :class:`int`
+        :returns: Tuple with 5 elements: 1) :class:`bool` flag whether all residues
+                  in *mdl* patches are covered in *trg* 2) mtl_patch_one
+                  3) mdl_patch_two 4) trg_patch_one 5) trg_patch_two
+        """
+        # select for representative positions => CB, CA for GLY 
+        repr_mdl = self._get_repr_view(self.model.Select("peptide=true"))
+    
+        # get position for specified residue
+        r = self.model.FindResidue(mdl_ch, mol.ResNum(mdl_rnum))
+        if not r.IsValid():
+            raise RuntimeError(f"Cannot find residue {mdl_rnum} in chain {mdl_ch}")
+        if r.GetName() == "GLY":
+            at = r.FindAtom("CA")
+        else:
+            at = r.FindAtom("CB")
+        if not at.IsValid():
+            raise RuntimeError("Cannot find interface views for res without CB/CA")
+        r_pos = at.GetPos()
+    
+        # mdl_patch_one contains residues from the same chain as r
+        # => all residues within 8A of r and within 12A of any other chain
+    
+        # q1 selects for everything in same chain and within 8A of r_pos
+        q1 = f"(cname={mdl_ch} and 8 <> {{{r_pos[0]},{r_pos[1]},{r_pos[2]}}})"
+        # q2 selects for everything within 12A of any other chain
+        q2 = f"(12 <> [cname!={mdl_ch}])"
+        mdl_patch_one = self.model.CreateEmptyView()
+        sel = repr_mdl.Select(" and ".join([q1, q2]))
+        for r in sel.residues:
+            mdl_r = self.model.FindResidue(r.GetChain().GetName(), r.GetNumber())
+            mdl_patch_one.AddResidue(mdl_r, mol.ViewAddFlag.INCLUDE_ALL)
+    
+        # mdl_patch_two contains residues from all other chains. In detail:
+        # the closest residue to r is identified in any other chain, and the
+        # patch is filled with residues that are within 8A of that residue and
+        # within 12A of chain from r
+        sel = repr_mdl.Select(f"(cname!={mdl_ch})")
+        close_stuff = sel.FindWithin(r_pos, 8)
+        min_pos = None
+        min_dist = 42.0
+        for close_at in close_stuff:
+            dist = geom.Distance(r_pos, close_at.GetPos())
+            if dist < min_dist:
+                min_pos = close_at.GetPos()
+                min_dist = dist
+    
+        # q1 selects for everything not in mdl_ch but within 8A of min_pos
+        q1 = f"(cname!={mdl_ch} and 8 <> {{{min_pos[0]},{min_pos[1]},{min_pos[2]}}})"
+        # q2 selects for everything within 12A of mdl_ch
+        q2 = f"(12 <> [cname={mdl_ch}])"
+        mdl_patch_two = self.model.CreateEmptyView()
+        sel = repr_mdl.Select(" and ".join([q1, q2]))
+        for r in sel.residues:
+            mdl_r = self.model.FindResidue(r.GetChain().GetName(), r.GetNumber())
+            mdl_patch_two.AddResidue(mdl_r, mol.ViewAddFlag.INCLUDE_ALL)
+    
+        # transfer mdl residues to trg
+        flat_mapping = self.mapping.GetFlatMapping(mdl_as_key=True)
+        full_trg_coverage = True
+        trg_patch_one = self.target.CreateEmptyView()
+        for r in mdl_patch_one.residues:
+            mdl_cname = r.GetChain().GetName()
+            aln = self.mapping.alns[(flat_mapping[mdl_cname], mdl_cname)]
+            trg_r = None
+            for col in aln:
+                if col[0] != '-' and col[1] != '-':
+                    if col.GetResidue(1).GetNumber() == r.GetNumber():
+                        trg_r = col.GetResidue(0)
+                        break
+            if trg_r is not None:
+                trg_patch_one.AddResidue(trg_r.handle,
+                                         mol.ViewAddFlag.INCLUDE_ALL)
+            else:
+                full_trg_coverage = False
+    
+        trg_patch_two = self.target.CreateEmptyView()
+        for r in mdl_patch_two.residues:
+            mdl_cname = r.GetChain().GetName()
+            aln = self.mapping.alns[(flat_mapping[mdl_cname], mdl_cname)]
+            trg_r = None
+            for col in aln:
+                if col[0] != '-' and col[1] != '-':
+                    if col.GetResidue(1).GetNumber() == r.GetNumber():
+                        trg_r = col.GetResidue(0)
+                        break
+            if trg_r is not None:
+                trg_patch_two.AddResidue(trg_r.handle,
+                                         mol.ViewAddFlag.INCLUDE_ALL)
+            else:
+                full_trg_coverage = False
+    
+        return (full_trg_coverage, mdl_patch_one, mdl_patch_two,
+                trg_patch_one, trg_patch_two)
+
+    def _compute_patchqs_scores(self):
+        self._patch_qs = dict()
+        for cname, rnums in self.model_interface_residues.items():
+            scores = list()
+            for rnum in rnums:
+                score = None
+                r = self.model.FindResidue(cname, mol.ResNum(rnum))
+                if r.IsValid() and r.GetChemType() == mol.ChemType.AMINOACIDS:
+                    full_trg_coverage, mdl_patch_one, mdl_patch_two, \
+                    trg_patch_one, trg_patch_two = \
+                    self._get_interface_patches(cname, rnum)
+                    if full_trg_coverage:
+                        score = self._patchqs(mdl_patch_one, mdl_patch_two,
+                                              trg_patch_one, trg_patch_two)
+                scores.append(score)
+            self._patch_qs[cname] = scores
+
+    def _compute_patchdockq_scores(self):
+        self._patch_dockq = dict()
+        for cname, rnums in self.model_interface_residues.items():
+            scores = list()
+            for rnum in rnums:
+                score = None
+                r = self.model.FindResidue(cname, mol.ResNum(rnum))
+                if r.IsValid() and r.GetChemType() == mol.ChemType.AMINOACIDS:
+                    full_trg_coverage, mdl_patch_one, mdl_patch_two, \
+                    trg_patch_one, trg_patch_two = \
+                    self._get_interface_patches(cname, rnum)
+                    if full_trg_coverage:
+                        score = self._patchdockq(mdl_patch_one, mdl_patch_two,
+                                                 trg_patch_one, trg_patch_two)
+                scores.append(score)
+            self._patch_dockq[cname] = scores
+
+    def _patchqs(self, mdl_patch_one, mdl_patch_two, trg_patch_one, trg_patch_two):
+        """ Score interface residue patches with QS-score
+    
+        In detail: Construct two entities with two chains each. First chain
+        consists of residues from <x>_patch_one and second chain consists of
+        <x>_patch_two. The returned score is the QS-score between the two
+        entities
+    
+        :param mdl_patch_one: Interface patch representing scored residue
+        :type mdl_patch_one: :class:`ost.mol.EntityView`
+        :param mdl_patch_two: Interface patch representing scored residue
+        :type mdl_patch_two: :class:`ost.mol.EntityView`
+        :param trg_patch_one: Interface patch representing scored residue
+        :type trg_patch_one: :class:`ost.mol.EntityView`
+        :param trg_patch_two: Interface patch representing scored residue
+        :type trg_patch_two: :class:`ost.mol.EntityView`
+        :returns: PatchQS score
+        """
+        qs_ent_mdl = self._qs_ent_from_patches(mdl_patch_one, mdl_patch_two)
+        qs_ent_trg = self._qs_ent_from_patches(trg_patch_one, trg_patch_two)
+    
+        alnA = seq.CreateAlignment()
+        s = ''.join([r.one_letter_code for r in mdl_patch_one.residues])
+        alnA.AddSequence(seq.CreateSequence("A", s))
+        s = ''.join([r.one_letter_code for r in trg_patch_one.residues])
+        alnA.AddSequence(seq.CreateSequence("A", s))
+    
+        alnB = seq.CreateAlignment()
+        s = ''.join([r.one_letter_code for r in mdl_patch_two.residues])
+        alnB.AddSequence(seq.CreateSequence("B", s))
+        s = ''.join([r.one_letter_code for r in trg_patch_two.residues])
+        alnB.AddSequence(seq.CreateSequence("B", s))
+        alns = {("A", "A"): alnA, ("B", "B"): alnB}
+    
+        scorer = QSScorer(qs_ent_mdl, [["A"], ["B"]], qs_ent_trg, alns)
+        score_result = scorer.Score([["A"], ["B"]])
+    
+        return score_result.QS_global
+    
+    def _patchdockq(self, mdl_patch_one, mdl_patch_two, trg_patch_one,
+                    trg_patch_two):
+        """ Score interface residue patches with DockQ
+    
+        In detail: Construct two entities with two chains each. First chain
+        consists of residues from <x>_patch_one and second chain consists of
+        <x>_patch_two. The returned score is the QS-score between the two
+        entities
+    
+        :param mdl_patch_one: Interface patch representing scored residue
+        :type mdl_patch_one: :class:`ost.mol.EntityView`
+        :param mdl_patch_two: Interface patch representing scored residue
+        :type mdl_patch_two: :class:`ost.mol.EntityView`
+        :param trg_patch_one: Interface patch representing scored residue
+        :type trg_patch_one: :class:`ost.mol.EntityView`
+        :param trg_patch_two: Interface patch representing scored residue
+        :type trg_patch_two: :class:`ost.mol.EntityView`
+        :returns: DockQ score
+        """
+        if not self.resnum_alignments:
+            raise RuntimeError("DockQ computations rely on residue numbers "
+                               "that are consistent between target and model "
+                               "chains, i.e. only work if resnum_alignments "
+                               "is True at Scorer construction.")
+        try:
+            dockq_exec = settings.Locate("DockQ.py",
+                                         explicit_file_name=self.dockq_exec)
+        except Exception as e:
+            raise RuntimeError("DockQ.py must be in PATH for DockQ "
+                               "scoring") from e
+        m = self._qs_ent_from_patches(mdl_patch_one, mdl_patch_two)
+        t = self._qs_ent_from_patches(trg_patch_one, trg_patch_two)
+        try:
+            dockq_result = dockq.DockQ(dockq_exec, t, m, "A", "B", "A", "B")
+        except Exception as e:
+            if "AssertionError: length of native is zero" in str(e):
+                return 0.0
+            else:
+                raise
+        return dockq_result.DockQ
+
+    def _qs_ent_from_patches(self, patch_one, patch_two):
+        """ Constructs Entity with two chains named "A" and "B""
+    
+        Blindly adds all residues from *patch_one* to chain A and residues from
+        patch_two to chain B.
+        """
+        ent = mol.CreateEntity()
+        ed = ent.EditXCS()
+        added_ch = ed.InsertChain("A")
+        for r in patch_one.residues:
+            added_r = ed.AppendResidue(added_ch, r.GetName())
+            added_r.SetChemClass(str(r.GetChemClass()))
+            for a in r.atoms:
+                ed.InsertAtom(added_r, a.handle)
+        added_ch = ed.InsertChain("B")
+        for r in patch_two.residues:
+            added_r = ed.AppendResidue(added_ch, r.GetName())
+            added_r.SetChemClass(str(r.GetChemClass()))
+            for a in r.atoms:
+                ed.InsertAtom(added_r, a.handle)
+        return ent
-- 
GitLab