From 115c411cf264f6883c564e43110ff6b5d397c77b Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Wed, 7 Sep 2022 13:19:29 +0200
Subject: [PATCH] lDDT: allow custom compounds which are not in default
 compound lib

---
 modules/mol/alg/doc/molalg.rst |  3 ++
 modules/mol/alg/pymod/lddt.py  | 65 +++++++++++++++++++++++++++-------
 2 files changed, 55 insertions(+), 13 deletions(-)

diff --git a/modules/mol/alg/doc/molalg.rst b/modules/mol/alg/doc/molalg.rst
index 23d670854..89b170b95 100644
--- a/modules/mol/alg/doc/molalg.rst
+++ b/modules/mol/alg/doc/molalg.rst
@@ -34,6 +34,9 @@ Local Distance Test scores (lDDT, DRMSD)
 
 .. autofunction:: ost.mol.alg.lddt.GetDefaultSymmetrySettings
 
+.. autoclass:: ost.mol.alg.lddt.CustomCompound
+  :members:
+
 .. function:: CheckStructure(ent, \
                              bond_table, \
                              angle_table, \
diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py
index 4f8565bc7..f905cf416 100644
--- a/modules/mol/alg/pymod/lddt.py
+++ b/modules/mol/alg/pymod/lddt.py
@@ -4,6 +4,34 @@ from ost import mol
 from ost import conop
 
 
+class CustomCompound:
+    """ Defines atoms for custom compounds
+
+    lDDT requires the reference atoms of a compound which are typically
+    extracted from a :class:`ost.conop.CompoundLib`. This lightweight
+    container allows handling arbitrary compounds which are not
+    necessarily in the compound library.
+
+    :param atom_names: Names of atoms of custom compound
+    :type atom_names: :class:`list` of :class:`str`
+    """
+    def __init__(self, atom_names):
+        self.atom_names = atom_names
+
+    @staticmethod
+    def FromResidue(res):
+        """ Construct custom compound from residue
+
+        :param res: Residue from which reference atom names are extracted
+        :type res: :class:`ost.mol.ResidueView`/:class:`ost.mol.ResidueHandle`
+        :returns: :class:`CustomCompound`
+        """
+        atom_names = [a.GetName() for a in res.atoms]
+        if len(atom_names) != len(set(atom_names)):
+            raise RuntimeError("Duplicate atoms detected in CustomCompound")
+        compound = CustomCompound(atom_names)
+        return compound
+
 class SymmetrySettings:
     """Container for symmetric compounds
 
@@ -99,6 +127,11 @@ class lDDTScorer:
                          ["A", "B"] but the compound has ["A", "B", "C"], "C" is
                          considered missing and does not influence scoring, even
                          if present in the model.
+    :param custom_compounds: Custom compounds defining reference atoms. If
+                             given, *custom_compounds* take precedence over
+                             *compound_lib*.
+    :type custom_compounds: :class:`dict` with residue names (:class:`str`) as
+                            key and :class:`CustomCompound` as value.
     :type compound_lib: :class:`ost.conop.CompoundLib`
     :param inclusion_radius: All pairwise distances < *inclusion_radius* are
                              considered for scoring
@@ -151,6 +184,7 @@ class lDDTScorer:
         self,
         target,
         compound_lib=None,
+        custom_compounds=None,
         inclusion_radius=15,
         sequence_separation=0,
         symmetry_settings=None,
@@ -167,6 +201,7 @@ class lDDTScorer:
             raise RuntimeError("No compound_lib given and conop.GetDefaultLib "
                                "returns no valid compound library")
         self.compound_lib = compound_lib
+        self.custom_compounds = custom_compounds
         if symmetry_settings is None:
             self.symmetry_settings = GetDefaultSymmetrySettings()
         else:
@@ -225,8 +260,8 @@ class lDDTScorer:
         self.symmetric_atoms = set()
 
         # setup members defined above
-        self._SetupEnv(self.compound_lib, self.symmetry_settings,
-                       seqres_mapping, self.bb_only)
+        self._SetupEnv(self.compound_lib, self.custom_compounds,
+                       self.symmetry_settings, seqres_mapping, self.bb_only)
 
         # distance related members are lazily computed as they're affected
         # by different flavours of lDDT (e.g. lDDT including inter-chain
@@ -709,8 +744,8 @@ class lDDTScorer:
         return rnums
 
 
-    def _SetupEnv(self, compound_lib, symmetry_settings, seqres_mapping,
-                  bb_only):
+    def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings,
+                  seqres_mapping, bb_only):
         """Sets target related lDDTScorer members defined in constructor
 
         No distance related members - see _SetupDistances
@@ -727,8 +762,8 @@ class lDDTScorer:
                 if r.name not in self.compound_anames:
                     # sets compound info in self.compound_anames and
                     # self.compound_symmetric_atoms
-                    self._SetupCompound(r, compound_lib, symmetry_settings,
-                                        bb_only)
+                    self._SetupCompound(r, compound_lib, custom_compounds,
+                                        symmetry_settings, bb_only)
 
                 self.res_start_indices.append(current_idx)
                 self.res_mapper[(ch_name, rnum)] = len(self.compound_names)
@@ -807,7 +842,8 @@ class lDDTScorer:
             residue_numbers[ch_name] = rnums
         return residue_numbers
 
-    def _SetupCompound(self, r, compound_lib, symmetry_settings, bb_only):
+    def _SetupCompound(self, r, compound_lib, custom_compounds,
+                       symmetry_settings, bb_only):
         """fill self.compound_anames/self.compound_symmetric_atoms
         """
         if bb_only:
@@ -823,12 +859,15 @@ class lDDTScorer:
         else:
             atom_names = list()
             symmetric_atoms = list()
-            compound = compound_lib.FindCompound(r.name)
-            if compound is None:
-                raise RuntimeError(f"no entry for {r} in compound_lib")
-            for atom_spec in compound.GetAtomSpecs():
-                if atom_spec.element not in ["H", "D"]:
-                    atom_names.append(atom_spec.name)
+            if custom_compounds is not None and r.GetName() in custom_compounds:
+                atom_names = list(custom_compounds[r.GetName()].atom_names)
+            else:
+                compound = compound_lib.FindCompound(r.name)
+                if compound is None:
+                    raise RuntimeError(f"no entry for {r} in compound_lib")
+                for atom_spec in compound.GetAtomSpecs():
+                    if atom_spec.element not in ["H", "D"]:
+                        atom_names.append(atom_spec.name)
             if r.name in symmetry_settings.symmetric_compounds:
                 for pair in symmetry_settings.symmetric_compounds[r.name]:
                     try:
-- 
GitLab