From 39e528b0e44256ab600ccc8a83b668a25184b1d3 Mon Sep 17 00:00:00 2001
From: Xavier Robin <xavier.robin@unibas.ch>
Date: Mon, 16 Jan 2023 15:35:00 +0100
Subject: [PATCH] feat: SCHWED-5783 add networkx dependency

---
 docker/Dockerfile                            |  1 +
 modules/mol/alg/pymod/ligand_scoring.py      | 82 +++++++++++++++++++-
 modules/mol/alg/tests/test_ligand_scoring.py | 12 ++-
 3 files changed, 91 insertions(+), 4 deletions(-)

diff --git a/docker/Dockerfile b/docker/Dockerfile
index 4a1a9bd26..1dce8f696 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -28,6 +28,7 @@ RUN apt-get update -y && apt-get install -y cmake \
                                             python3-numpy \
                                             python3-scipy \
                                             python3-pandas \
+                                            python3-networkx \
                                             doxygen \
                                             swig \
                                             clustalw \
diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py
index 879a75e4c..d38873290 100644
--- a/modules/mol/alg/pymod/ligand_scoring.py
+++ b/modules/mol/alg/pymod/ligand_scoring.py
@@ -1,13 +1,22 @@
 import os
+
+import numpy as np
+import networkx
+
 from ost import mol
 from ost.mol.alg import chain_mapping
-import numpy as np
 
 
 class LigandScorer:
     """ Helper class to access the various small molecule ligand (non polymer)
     scores available from ost.mol.alg.
 
+    .. note ::
+      Extra requirements:
+
+      - Python modules `numpy` and `networkx` must be available
+        (e.g. use ``pip install numpy networkx``)
+
     Mostly expects cleaned up structures (you can use the
     :class:`~ost.mol.alg.scoring.Scorer` outputs for that).
 
@@ -215,5 +224,74 @@ class LigandScorer:
             new_editor.UpdateICS()
         return extracted_ligands
 
+    def _get_binding_sites(self, ligand, topn=100000):
+        """Find representations of the binding site of *ligand* in the model.
 
-__all__ = ["LigandScorer"]
+        Ignore other ligands and waters that may be in proximity.
+
+        :param ligand: Defines the binding site to identify.
+        :type ligand: :class:`~ost.mol.ResidueHandle`
+        """
+        if ligand.hash_code not in self._binding_sites:
+
+            # create view of reference binding site
+            ref_residues_hashes = set()  # helper to keep track of added residues
+            for ligand_at in ligand.atoms:
+                close_atoms = self.target.FindWithin(ligand_at.GetPos(), self.radius)
+                for close_at in close_atoms:
+                    # Skip other ligands and waters.
+                    # This assumes that .IsLigand() is properly set on the entity's
+                    # residues.
+                    ref_res = close_at.GetResidue()
+                    if not (ref_res.is_ligand or
+                            ref_res.chem_type == mol.ChemType.WATERS):
+                        h = ref_res.handle.GetHashCode()
+                        if h not in ref_residues_hashes:
+                            ref_residues_hashes.add(h)
+
+            # reason for doing that separately is to guarantee same ordering of
+            # residues as in underlying entity. (Reorder by ResNum seems only
+            # available on ChainHandles)
+            ref_bs = self.target.CreateEmptyView()
+            for ch in self.target.chains:
+                for r in ch.residues:
+                    if r.handle.GetHashCode() in ref_residues_hashes:
+                        ref_bs.AddResidue(r, mol.ViewAddFlag.INCLUDE_ALL)
+
+            # Find the representations
+            self._binding_sites[ligand.hash_code] = self.chain_mapper.GetRepr(
+                ref_bs, self.model, inclusion_radius=self.lddt_bs_radius,
+                topn=topn)
+        return self._binding_sites[ligand.hash_code]
+
+    def _compute_rmsd(self):
+        """"""
+        # Create the matrix
+        self._rmsd_matrix = np.empty((len(self.target_ligands),
+                                      len(self.model_ligands)), dtype=float)
+        for trg_ligand in self.target_ligands:
+            LogDebug("Compute RMSD for target ligand %s" % trg_ligand)
+
+            for binding_site in self._get_binding_sites(trg_ligand):
+                import ipdb; ipdb.set_trace()
+
+                for mdl_ligand in self.model_ligands:
+                    LogDebug("Compute RMSD for model ligand %s" % mdl_ligand)
+
+                    # Get symmetry graphs
+                    model_graph = model_ligand.spyrmsd_mol.to_graph()
+                    target_graph = target_ligand.struct_spyrmsd_mol.to_graph()
+                    pass
+
+
+def ResidueToGraph(residue):
+    """Return a NetworkX graph representation of the residue."""
+    nxg = networkx.Graph()
+    nxg.add_nodes_from([a.name for a in residue.atoms], element=[a.element for a in residue.atoms])
+    # This will list all edges twice - once for every atom of the pair.
+    # But as of NetworkX 3.0 adding the same edge twice has no effect so we're good.
+    nxg.add_edges_from([(b.first.name, b.second.name) for a in residue.atoms for b in a.GetBondList()])
+    return nxg
+
+
+__all__ = ["LigandScorer", "ResidueToGraph"]
diff --git a/modules/mol/alg/tests/test_ligand_scoring.py b/modules/mol/alg/tests/test_ligand_scoring.py
index ed8cbf98b..baf8b2bf6 100644
--- a/modules/mol/alg/tests/test_ligand_scoring.py
+++ b/modules/mol/alg/tests/test_ligand_scoring.py
@@ -5,8 +5,8 @@ from ost import io, mol
 try:
     from ost.mol.alg.ligand_scoring import *
 except ImportError:
-    print("Failed to import ligand_scoring.py. Happens when numpy or scipy " \
-          "missing. Ignoring test_ligand_scoring.py tests.")
+    print("Failed to import ligand_scoring.py. Happens when numpy, scipy or "
+          "networkx is missing. Ignoring test_ligand_scoring.py tests.")
     sys.exit(0)
 
 
@@ -114,6 +114,14 @@ class TestLigandScoring(unittest.TestCase):
         with self.assertRaises(RuntimeError):
             sc = LigandScorer(mdl, trg, mdl_ligs, [lig0, lig1])
 
+    def test_ResidueToGraph(self):
+        """Test that ResidueToGraph works as expected
+        """
+        mdl_lig = io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf"))
+
+        graph = ResidueToGraph(mdl_lig.residues[0])
+        assert len(graph.edges) == 34
+        assert len(graph.nodes) == 32
 
 
 if __name__ == "__main__":
-- 
GitLab