From 520a1e9044761414cd3e3e93f67f490fb1842cde Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Tue, 17 Jan 2023 18:16:33 +0100
Subject: [PATCH] OST specific DockQ implementation

---
 modules/geom/pymod/export_vec3.cc    |   2 +
 modules/geom/src/vec3.cc             |  22 +++
 modules/geom/src/vec3.hh             |   2 +
 modules/mol/alg/doc/molalg.rst       |   8 +
 modules/mol/alg/pymod/CMakeLists.txt |   1 +
 modules/mol/alg/pymod/dockq.py       | 234 +++++++++++++++++++++++++++
 6 files changed, 269 insertions(+)
 create mode 100644 modules/mol/alg/pymod/dockq.py

diff --git a/modules/geom/pymod/export_vec3.cc b/modules/geom/pymod/export_vec3.cc
index 7691c59ac..445b8aef7 100644
--- a/modules/geom/pymod/export_vec3.cc
+++ b/modules/geom/pymod/export_vec3.cc
@@ -127,5 +127,7 @@ void export_Vec3()
     .def("GetGDTHA", &Vec3List::GetGDTHA, (arg("other"), arg("norm")=true))
     .def("GetGDTTS", &Vec3List::GetGDTTS, (arg("other"), arg("norm")=true))
     .def("GetGDT", &Vec3List::GetGDT, (arg("other"), arg("thresh"), arg("norm")=true))
+    .def("GetMinDist", &Vec3List::GetMinDist, (arg("other")))
+    .def("IsWithin", &Vec3List::IsWithin, (arg("other"), arg("dist")))
   ;
 }
diff --git a/modules/geom/src/vec3.cc b/modules/geom/src/vec3.cc
index 0adbc1a65..bb3a0617e 100644
--- a/modules/geom/src/vec3.cc
+++ b/modules/geom/src/vec3.cc
@@ -198,6 +198,28 @@ Real Vec3List::GetGDT(const Vec3List& other, Real thresh, bool norm) const
   return norm && !this->empty() ? static_cast<Real>(n)/(this->size()) : n;
 }
 
+Real Vec3List::GetMinDist(const Vec3List& other) const {
+  Real min = std::numeric_limits<Real>::max();
+  for(size_t i = 0; i < this->size(); ++i) {
+    for(size_t j = 0; j < other.size(); ++j) {
+      min = std::min(min, geom::Length2((*this)[i] - other[j]));
+    }
+  }
+  return std::sqrt(min);
+}
+
+bool Vec3List::IsWithin(const Vec3List& other, Real dist) const {
+  Real squared_dist = dist*dist;
+  for(size_t i = 0; i < this->size(); ++i) {
+    for(size_t j = 0; j < other.size(); ++j) {
+      if(geom::Length2((*this)[i] - other[j]) <= squared_dist) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
 std::pair<Line3, Real> Vec3List::FitCylinder(const Vec3& initial_direction) const
   { 
     Vec3 center=this->GetCenter();
diff --git a/modules/geom/src/vec3.hh b/modules/geom/src/vec3.hh
index 830272906..2c0e2bbb2 100644
--- a/modules/geom/src/vec3.hh
+++ b/modules/geom/src/vec3.hh
@@ -343,6 +343,8 @@ public:
   Real GetGDTHA(const Vec3List& other, bool norm=true) const;
   Real GetGDTTS(const Vec3List& other, bool norm=true) const;
   Real GetGDT(const Vec3List& other, Real thresh, bool norm=true) const;
+  Real GetMinDist(const Vec3List& other) const;
+  bool IsWithin(const Vec3List& other, Real dist) const;
 
   //This function fits a cylinder to the positions in Vec3List
   //It takes as argument an initial guess for the direction.
diff --git a/modules/mol/alg/doc/molalg.rst b/modules/mol/alg/doc/molalg.rst
index 759903efc..57f43ddb9 100644
--- a/modules/mol/alg/doc/molalg.rst
+++ b/modules/mol/alg/doc/molalg.rst
@@ -145,6 +145,14 @@ Local Distance Test scores (lDDT, DRMSD)
 .. currentmodule:: ost.mol.alg
 
 
+:mod:`DockQ <ost.mol.alg.dockq>` -- DockQ implementation
+--------------------------------------------------------------------------------
+
+.. autofunction:: ost.mol.alg.dockq.DockQ
+
+.. currentmodule:: ost.mol.alg
+
+
 .. _steric-clashes:
 
 Steric Clashes
diff --git a/modules/mol/alg/pymod/CMakeLists.txt b/modules/mol/alg/pymod/CMakeLists.txt
index e50232556..9c65a6b87 100644
--- a/modules/mol/alg/pymod/CMakeLists.txt
+++ b/modules/mol/alg/pymod/CMakeLists.txt
@@ -28,6 +28,7 @@ set(OST_MOL_ALG_PYMOD_MODULES
   chain_mapping.py
   stereochemistry.py
   ligand_scoring.py
+  dockq.py
 )
 
 if (NOT ENABLE_STATIC)
diff --git a/modules/mol/alg/pymod/dockq.py b/modules/mol/alg/pymod/dockq.py
new file mode 100644
index 000000000..c22ca78d2
--- /dev/null
+++ b/modules/mol/alg/pymod/dockq.py
@@ -0,0 +1,234 @@
+from ost import geom
+from ost import mol
+
+def _PreprocessStructures(mdl, ref, mdl_ch1, mdl_ch2, ref_ch1, ref_ch2,
+                          ch1_aln = None, ch2_aln = None):
+    """ Preprocesses *mdl* and *ref*
+
+    Returns two entity views with the exact same number of residues. I.e. the
+    residues correspond to a one-to-one mapping. Additionally, each residue gets
+    the int property "dockq_map" assigned, which corresponds to the residue
+    index in the respective chain of the processed structures.
+    """
+    mdl_residues_1 = list()
+    mdl_residues_2 = list()
+    ref_residues_1 = list()
+    ref_residues_2 = list()
+
+    if ch1_aln is None and ch2_aln is None:
+        # go by residue numbers
+        for mdl_r in mdl.Select(f"cname={mdl_ch1}").residues:
+            ref_r = ref.FindResidue(ref_ch1, mdl_r.GetNumber())
+            if ref_r.IsValid():
+                mdl_residues_1.append(mdl_r)
+                ref_residues_1.append(ref_r)
+        for mdl_r in mdl.Select(f"cname={mdl_ch2}").residues:
+            ref_r = ref.FindResidue(ref_ch2, mdl_r.GetNumber())
+            if ref_r.IsValid():
+                mdl_residues_2.append(mdl_r)
+                ref_residues_2.append(ref_r)
+    else:
+        raise NotImplementedError("No aln mapping implemented yet")
+
+    new_mdl = mdl.handle.CreateEmptyView()
+    new_ref = ref.handle.CreateEmptyView()
+    for r in mdl_residues_1:
+        new_mdl.AddResidue(r.handle, mol.INCLUDE_ALL)
+    for r in mdl_residues_2:
+        new_mdl.AddResidue(r.handle, mol.INCLUDE_ALL)
+    for r in ref_residues_1:
+        new_ref.AddResidue(r.handle, mol.INCLUDE_ALL)
+    for r in ref_residues_2:
+        new_ref.AddResidue(r.handle, mol.INCLUDE_ALL)
+
+    # set dockq_map property
+    ch = new_mdl.FindChain(mdl_ch1)
+    for r_idx, r in enumerate(ch.residues):
+        r.SetIntProp("dockq_map", r_idx)
+    ch = new_mdl.FindChain(mdl_ch2)
+    for r_idx, r in enumerate(ch.residues):
+        r.SetIntProp("dockq_map", r_idx)
+    ch = new_ref.FindChain(ref_ch1)
+    for r_idx, r in enumerate(ch.residues):
+        r.SetIntProp("dockq_map", r_idx)
+    ch = new_ref.FindChain(ref_ch2)
+    for r_idx, r in enumerate(ch.residues):
+        r.SetIntProp("dockq_map", r_idx)
+
+    return (new_mdl, new_ref)
+
+def _GetContacts(ent, ch1, ch2, dist_thresh):
+    int1 = ent.Select(f"cname={ch1} and {dist_thresh} <> [cname={ch2}]")
+    int2 = ent.Select(f"cname={ch2} and {dist_thresh} <> [cname={ch1}]")
+    contacts = set()
+    int1_p = [geom.Vec3List([a.pos for a in r.atoms]) for r in int1.residues]
+    int2_p = [geom.Vec3List([a.pos for a in r.atoms]) for r in int2.residues]
+    for r1, p1 in zip(int1.residues, int1_p):
+        for r2, p2 in zip(int2.residues, int2_p):
+            if p1.IsWithin(p2, dist_thresh):
+                contacts.add((r1.GetIntProp("dockq_map"), r2.GetIntProp("dockq_map")))
+    return contacts
+
+def _ContactScores(mdl, ref, mdl_ch1, mdl_ch2, ref_ch1, ref_ch2, dist_thresh=5.0):
+    ref_contacts = _GetContacts(ref, ref_ch1, ref_ch2, dist_thresh)
+    mdl_contacts = _GetContacts(mdl, mdl_ch1, mdl_ch2, dist_thresh)
+
+    nnat = len(ref_contacts)
+    nmdl = len(mdl_contacts)
+
+    fnat = len(ref_contacts.intersection(mdl_contacts))
+    if nnat > 0:
+        fnat /= nnat
+
+    fnonnat = len(mdl_contacts.difference(ref_contacts))
+    if len(mdl_contacts) > 0:
+        fnonnat /= len(mdl_contacts)
+
+    return (nnat, nmdl, fnat, fnonnat)
+
+def _RMSDScores(mdl, ref, mdl_ch1, mdl_ch2, ref_ch1, ref_ch2, dist_thresh=10.0):
+
+    mdl_ch1_residues = mdl.FindChain(mdl_ch1).residues
+    mdl_ch2_residues = mdl.FindChain(mdl_ch2).residues
+    ref_ch1_residues = ref.FindChain(ref_ch1).residues
+    ref_ch2_residues = ref.FindChain(ref_ch2).residues
+
+    # iRMSD
+    #######
+    int1 = ref.Select(f"cname={ref_ch1} and {dist_thresh} <> [cname={ref_ch2}]")
+    int2 = ref.Select(f"cname={ref_ch2} and {dist_thresh} <> [cname={ref_ch1}]")
+    int1_indices = [r.GetIntProp("dockq_map") for r in int1.residues]
+    int2_indices = [r.GetIntProp("dockq_map") for r in int2.residues]
+    ref_pos = geom.Vec3List()
+    mdl_pos = geom.Vec3List()
+    atom_names = ['CA','C','N','O']
+    for idx in int1_indices:
+        ref_r = ref_ch1_residues[idx]
+        mdl_r = mdl_ch1_residues[idx]
+        for aname in atom_names:
+            ref_a = ref_r.FindAtom(aname)
+            mdl_a = mdl_r.FindAtom(aname)
+            if ref_a.IsValid() and mdl_a.IsValid():
+                ref_pos.append(ref_a.pos)
+                mdl_pos.append(mdl_a.pos)
+
+    for idx in int2_indices:
+        ref_r = ref_ch2_residues[idx]
+        mdl_r = mdl_ch2_residues[idx]
+        for aname in atom_names:
+            ref_a = ref_r.FindAtom(aname)
+            mdl_a = mdl_r.FindAtom(aname)
+            if ref_a.IsValid() and mdl_a.IsValid():
+                ref_pos.append(ref_a.pos)
+                mdl_pos.append(mdl_a.pos)
+
+    if len(mdl_pos) >= 3:
+        sup_result = mol.alg.SuperposeSVD(mdl_pos, ref_pos)
+        irmsd = sup_result.rmsd
+    else:
+        irmsd = 0.0
+
+    # lRMSD
+    #######
+    # receptor is by definition the larger chain
+    if len(ref_ch1_residues) > len(ref_ch2_residues):
+        ref_receptor_residues = ref_ch1_residues
+        ref_ligand_residues = ref_ch2_residues
+        mdl_receptor_residues = mdl_ch1_residues
+        mdl_ligand_residues = mdl_ch2_residues
+    else:
+        ref_receptor_residues = ref_ch2_residues
+        ref_ligand_residues = ref_ch1_residues
+        mdl_receptor_residues = mdl_ch2_residues
+        mdl_ligand_residues = mdl_ch1_residues
+
+    ref_receptor_positions = geom.Vec3List()
+    mdl_receptor_positions = geom.Vec3List()
+    ref_ligand_positions = geom.Vec3List()
+    mdl_ligand_positions = geom.Vec3List()
+
+    for ref_r, mdl_r in zip(ref_receptor_residues, mdl_receptor_residues):
+        for aname in atom_names:
+            ref_a = ref_r.FindAtom(aname)
+            mdl_a = mdl_r.FindAtom(aname)
+            if ref_a.IsValid() and mdl_a.IsValid():
+                ref_receptor_positions.append(ref_a.pos)
+                mdl_receptor_positions.append(mdl_a.pos)
+
+    for ref_r, mdl_r in zip(ref_ligand_residues, mdl_ligand_residues):
+        for aname in atom_names:
+            ref_a = ref_r.FindAtom(aname)
+            mdl_a = mdl_r.FindAtom(aname)
+            if ref_a.IsValid() and mdl_a.IsValid():
+                ref_ligand_positions.append(ref_a.pos)
+                mdl_ligand_positions.append(mdl_a.pos)
+
+    if len(mdl_receptor_positions) >= 3:
+        sup_result = mol.alg.SuperposeSVD(mdl_receptor_positions,
+                                          ref_receptor_positions)
+        mdl_ligand_positions.ApplyTransform(sup_result.transformation)
+        lrmsd = mdl_ligand_positions.GetRMSD(ref_ligand_positions)
+    else:
+        lrmsd = 0.0
+
+    return (irmsd, lrmsd)
+
+def _ScaleRMSD(rmsd, d):
+    return 1.0/(1+(rmsd/d)**2)
+
+def _DockQ(fnat, lrmsd, irmsd, d1, d2):
+    """ The final number chrunching as described in the DockQ manuscript
+    """
+    return (fnat + _ScaleRMSD(lrmsd, d1) + _ScaleRMSD(irmsd, d2))/3
+
+def DockQ(mdl, ref, mdl_ch1, mdl_ch2, ref_ch1, ref_ch2,
+          ch1_aln=None, ch2_aln=None):
+    """ Computes DockQ for specified interface
+
+    DockQ is described in: Sankar Basu and Bjoern Wallner (2016), "DockQ: A
+    Quality Measure for Protein-Protein Docking Models", PLOS one 
+
+    Residues are mapped based on residue numbers by default. If you provide
+    *ch1_aln* and *ch2_aln* you can enforce an arbitrary mapping.
+
+    :param mdl: Model structure
+    :type mdl: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
+    :param ref: Reference structure, i.e. native structure
+    :type ref: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
+    :param mdl_ch1: Specifies chain in model constituting first part of
+                    interface
+    :type mdl_ch1: :class:`str`
+    :param mdl_ch2: Specifies chain in model constituting second part of
+                    interface
+    :type mdl_ch2: :class:`str`
+    :param ref_ch1: ref equivalent of mdl_ch1
+    :type ref_ch1: :class:`str`
+    :param ref_ch2: ref equivalent of mdl_ch2
+    :type ref_ch2: :class:`str`
+    :param ch1_aln: Alignment with two sequences to map *ref_ch1* and *mdl_ch1*.
+                    The first sequence must match the sequence in *ref_ch1* and
+                    the second to *mdl_ch1*.
+    :type ch1_aln: :class:`ost.seq.AlignmentHandle`
+    :param ch2_aln: Alignment with two sequences to map *ref_ch2* and *mdl_ch2*.
+                    The first sequence must match the sequence in *ref_ch2* and
+                    the second to *mdl_ch2*.
+    :type ch2_aln: :class:`ost.seq.AlignmentHandle`
+    :returns: :class:`dict` with keys nnat, nmdl, fnat, fnonnat, irmsd, lrmsd,
+              DockQ which corresponds to the equivalent values in the original
+              DockQ implementation.
+    """
+    mapped_model, mapped_ref = _PreprocessStructures(mdl, ref, mdl_ch1, mdl_ch2,
+                                                     ref_ch1, ref_ch2,
+                                                     ch1_aln = ch1_aln,
+                                                     ch2_aln = ch2_aln)
+    nnat, nmdl, fnat, fnonnat = _ContactScores(mapped_model, mapped_ref,
+                                         mdl_ch1, mdl_ch2, ref_ch1, ref_ch2)
+    irmsd, lrmsd = _RMSDScores(mapped_model, mapped_ref,
+                               mdl_ch1, mdl_ch2, ref_ch1, ref_ch2)
+    return {"nnat": nnat,
+            "nmdl": nmdl,
+            "fnat": fnat,
+            "fnonnat": fnonnat,
+            "irmsd": round(irmsd, 3),
+            "lrmsd": round(lrmsd, 3),
+            "DockQ": round(_DockQ(fnat, lrmsd, irmsd, 8.5, 1.5), 3)}
-- 
GitLab