From e85318b2d7a36f74a7c9df67efbfc9c8ad979eed Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Tue, 13 May 2025 23:45:14 +0200
Subject: [PATCH] drmsd: correctly add penalties for atoms missing in the model

---
 modules/mol/alg/pymod/lddt.py      | 21 +++++++++++++++------
 modules/mol/alg/tests/test_lddt.py |  2 +-
 2 files changed, 16 insertions(+), 7 deletions(-)

diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py
index da013fd8..a4f85c38 100644
--- a/modules/mol/alg/pymod/lddt.py
+++ b/modules/mol/alg/pymod/lddt.py
@@ -892,17 +892,26 @@ class lDDTScorer:
         per_res_drmsd = [None] * model.GetResidueCount()
         for r_idx in range(len(res_atom_indices)):
             end_idx = start_idx + len(res_atom_indices[r_idx])
-            n_exp = per_res_exp[r_idx]
-            if n_exp > 0:
+            n_tot = per_res_exp[r_idx]
+            if n_tot > 0:
                 ssd = np.sum(per_atom_ssd[start_idx:end_idx])
-                per_res_drmsd[res_indices[r_idx]] = np.sqrt(ssd/n_exp)
+                # add penalties from distances involving atoms that are not
+                # present in the model
+                n_missing = n_tot - np.sum(per_atom_exp[start_idx:end_idx])
+                ssd += n_missing*dist_cap*dist_cap
+                per_res_drmsd[res_indices[r_idx]] = np.sqrt(ssd/n_tot)
             start_idx = end_idx
 
         # do full model score
         drmsd = None
-        n_exp = np.sum(per_atom_exp)
-        if n_exp > 0:
-            drmsd = np.sqrt(np.sum(per_atom_ssd)/n_exp)
+        n_tot = sum([len(x) for x in ref_indices])
+        if n_tot > 0:
+            ssd = np.sum(per_atom_ssd)
+            # add penalties from distances involving atoms that are not
+            # present in the model
+            n_missing = n_tot - np.sum(per_atom_exp)
+            ssd += (dist_cap*dist_cap*n_missing)
+            drmsd = np.sqrt(ssd/n_tot)
 
         return drmsd, per_res_drmsd
 
diff --git a/modules/mol/alg/tests/test_lddt.py b/modules/mol/alg/tests/test_lddt.py
index 57b24252..a4781a3b 100644
--- a/modules/mol/alg/tests/test_lddt.py
+++ b/modules/mol/alg/tests/test_lddt.py
@@ -245,7 +245,7 @@ class TestlDDT(unittest.TestCase):
 
         # this value is just blindly copied in without checking whether it makes
         # any sense... it's sole purpose is to trigger DRMSD computation
-        self.assertAlmostEqual(drmsd, 1.895447711911706, places=5)
+        self.assertAlmostEqual(drmsd, 1.9765632785024412, places=5)
 
 class TestlDDTBS(unittest.TestCase):
 
-- 
GitLab