From 7ec8180c3e1338049a4848b3ad82f789f2b7949c Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Mon, 9 May 2022 17:11:08 +0200
Subject: [PATCH] lDDT: generalize calpha flag to also deal with nucleotides
 (C3' as ref atom)

---
 modules/mol/alg/pymod/lddt.py      | 39 ++++++++++++++++++------------
 modules/mol/alg/tests/test_lddt.py |  6 ++---
 2 files changed, 26 insertions(+), 19 deletions(-)

diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py
index e58784bd4..3f3c7d1eb 100644
--- a/modules/mol/alg/pymod/lddt.py
+++ b/modules/mol/alg/pymod/lddt.py
@@ -130,11 +130,12 @@ class lDDTScorer:
                            what the original residue numbers were. 
     :type seqres_mapping: :class:`dict` (key: :class:`str`, value:
                           :class:`ost.seq.AlignmentHandle`)
-    :param calpha: Only consider atoms with name "CA". Technically this sets
-                   the expected atom names for each residue name to ["CA"], thus
-                   invalidating *compound_lib*. No check whether the target
-                   residues are actually amino acids!
-    :type calpha: :class:`bool`
+    :param bb_only: Only consider atoms with name "CA" in case of amino acids and
+                    "C3'" for Nucleotides. this invalidates *compound_lib*.
+                    Raises if any residue in *target* is not
+                    `r.chem_class.IsPeptideLinking()` or
+                    `r.chem_class.IsNucleotideLinking()`
+    :type bb_only: :class:`bool`
     :raises: :class:`RuntimeError` if *target* contains compound which is not in
              *compound_lib*, :class:`RuntimeError` if *symmetry_settings*
              specifies symmetric atoms that are not present in the according
@@ -154,7 +155,7 @@ class lDDTScorer:
         sequence_separation=0,
         symmetry_settings=None,
         seqres_mapping=dict(),
-        calpha=False
+        bb_only=False
     ):
 
         self.target = target
@@ -171,9 +172,9 @@ class lDDTScorer:
         else:
             self.symmetry_settings = symmetry_settings
 
-        # whether to only consider atoms with name "CA", invalidates
-        # *compound_lib*
-        self.calpha=calpha
+        # whether to only consider atoms with name "CA" (amino acids) or C3'
+        # (nucleotides), invalidates *compound_lib*
+        self.bb_only=bb_only
 
         # names of heavy atoms of each unique compound present in *target* as
         # extracted from *compound_lib*, e.g.
@@ -225,7 +226,7 @@ class lDDTScorer:
 
         # setup members defined above
         self._SetupEnv(self.compound_lib, self.symmetry_settings,
-                       seqres_mapping, self.calpha)
+                       seqres_mapping, self.bb_only)
 
         # distance related members are lazily computed as they're affected
         # by different flavours of lDDT (e.g. lDDT including inter-chain
@@ -644,7 +645,7 @@ class lDDTScorer:
                                           self.compound_lib,
                                           symmetry_settings = sm,
                                           inclusion_radius = self.inclusion_radius,
-                                          calpha = self.calpha)
+                                          bb_only = self.bb_only)
                 penalty += dummy_scorer.n_distances
         return penalty
 
@@ -707,7 +708,7 @@ class lDDTScorer:
 
 
     def _SetupEnv(self, compound_lib, symmetry_settings, seqres_mapping,
-                  calpha):
+                  bb_only):
         """Sets target related lDDTScorer members defined in constructor
 
         No distance related members - see _SetupDistances
@@ -725,7 +726,7 @@ class lDDTScorer:
                     # sets compound info in self.compound_anames and
                     # self.compound_symmetric_atoms
                     self._SetupCompound(r, compound_lib, symmetry_settings,
-                                        calpha)
+                                        bb_only)
 
                 self.res_start_indices.append(current_idx)
                 self.res_mapper[(ch_name, rnum)] = len(self.compound_names)
@@ -804,12 +805,18 @@ class lDDTScorer:
             residue_numbers[ch_name] = rnums
         return residue_numbers
 
-    def _SetupCompound(self, r, compound_lib, symmetry_settings, calpha):
+    def _SetupCompound(self, r, compound_lib, symmetry_settings, bb_only):
         """fill self.compound_anames/self.compound_symmetric_atoms
         """
-        if calpha:
+        if bb_only:
             # throw away compound_lib info
-            self.compound_anames[r.name] = ["CA"]
+            if r.chem_class.IsPeptideLinking():
+                self.compound_anames[r.name] = ["CA"]
+            elif r.chem_type.IsNucleotideLinking():
+                self.compound_anames[r.name] = ["C3'"]
+            else:
+                raise RuntimeError(f"Only support amino acids and nucleotides "
+                                   f"if bb_only is True, failed with {str(r)}")
             self.compound_symmetric_atoms[r.name] = list()
         else:
             atom_names = list()
diff --git a/modules/mol/alg/tests/test_lddt.py b/modules/mol/alg/tests/test_lddt.py
index 6f3e3be47..a1f8bfca7 100644
--- a/modules/mol/alg/tests/test_lddt.py
+++ b/modules/mol/alg/tests/test_lddt.py
@@ -147,7 +147,7 @@ class TestlDDT(unittest.TestCase):
             scorer = lDDTScorer(target, sequence_separation=42)
         scorer = lDDTScorer(target, sequence_separation=0)
 
-    def test_calpha(self):
+    def test_bb_only(self):
         model = _LoadFile("7SGN_C_model.pdb")
         target = _LoadFile("7SGN_C_target.pdb")
 
@@ -156,8 +156,8 @@ class TestlDDT(unittest.TestCase):
         score_one, per_res_scores_one = scorer.lDDT(model)
         score_two, per_res_scores_two = scorer.lDDT(model.Select("aname=CA"))
 
-        # no selection, just setting calpha flag should give the same
-        scorer = lDDTScorer(target, calpha=True)
+        # no selection, just setting bb_only flag should give the same
+        scorer = lDDTScorer(target, bb_only=True)
         score_three, per_res_scores_three = scorer.lDDT(model)
 
         # check
-- 
GitLab