From e0bc7f7b40542b953d1cee6861fe25d6bcd74139 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Mon, 22 Jul 2024 12:15:54 +0200
Subject: [PATCH] GDT: improve optimization and handle edge cases with low
 number of positions

Optimization has been improved by sampling several window sizes. The results
come really close to LGA (avg diff of 0.16 GDT points when evaluation all
CASP15 TS models).

In the previous implementation, an error was thrown when number of positions
was lower than the window size parameter. Now, the window size is just set
to the number of positions in that case and the algorithm runs through.
Special cases of only one or two positions are handled separately to produce
sensible output.
---
 modules/mol/alg/pymod/scoring.py |  41 +++++++++--
 modules/mol/alg/src/gdt.cc       | 120 ++++++++++++++++++++++++++++++-
 2 files changed, 155 insertions(+), 6 deletions(-)

diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index 4444e4637..d415b0aad 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -381,6 +381,7 @@ class Scorer:
         self._rigid_n_target_not_mapped = None
         self._rigid_transform = None
 
+        self._gdt_window_sizes = [5, 7, 9, 12, 24, 36, 48]
         self._gdt_05 = None
         self._gdt_1 = None
         self._gdt_2 = None
@@ -1399,7 +1400,13 @@ class Scorer:
         :type: :class:`float` 
         """
         if self._gdt_05 is None:
-            n, m = GDT(self.rigid_mapped_model_pos, self.rigid_mapped_target_pos, 7, 1000, 0.5)
+            N = list()
+            for window_size in self._gdt_window_sizes:
+                n = GDT(self.rigid_mapped_model_pos,
+                        self.rigid_mapped_target_pos,
+                        window_size, 1000, 0.5)[0]
+                N.append(n)
+            n = max(N)
             n_full = len(self.rigid_mapped_target_pos) + self.rigid_n_target_not_mapped
             if n_full > 0:
                 self._gdt_05 = float(n) / n_full
@@ -1417,7 +1424,13 @@ class Scorer:
         :type: :class:`float` 
         """
         if self._gdt_1 is None:
-            n, m = GDT(self.rigid_mapped_model_pos, self.rigid_mapped_target_pos, 7, 1000, 1.0)
+            N = list()
+            for window_size in self._gdt_window_sizes:
+                n = GDT(self.rigid_mapped_model_pos,
+                        self.rigid_mapped_target_pos,
+                        window_size, 1000, 1.0)[0]
+                N.append(n)
+            n = max(N)
             n_full = len(self.rigid_mapped_target_pos) + self.rigid_n_target_not_mapped
             if n_full > 0:
                 self._gdt_1 = float(n) / n_full
@@ -1436,7 +1449,13 @@ class Scorer:
         :type: :class:`float` 
         """
         if self._gdt_2 is None:
-            n, m = GDT(self.rigid_mapped_model_pos, self.rigid_mapped_target_pos, 7, 1000, 2.0)
+            N = list()
+            for window_size in self._gdt_window_sizes:
+                n = GDT(self.rigid_mapped_model_pos,
+                        self.rigid_mapped_target_pos,
+                        window_size, 1000, 2.0)[0]
+                N.append(n)
+            n = max(N)
             n_full = len(self.rigid_mapped_target_pos) + self.rigid_n_target_not_mapped
             if n_full > 0:
                 self._gdt_2 = float(n) / n_full
@@ -1454,7 +1473,13 @@ class Scorer:
         :type: :class:`float` 
         """
         if self._gdt_4 is None:
-            n, m = GDT(self.rigid_mapped_model_pos, self.rigid_mapped_target_pos, 7, 1000, 4.0)
+            N = list()
+            for window_size in self._gdt_window_sizes:
+                n = GDT(self.rigid_mapped_model_pos,
+                        self.rigid_mapped_target_pos,
+                        window_size, 1000, 4.0)[0]
+                N.append(n)
+            n = max(N)
             n_full = len(self.rigid_mapped_target_pos) + self.rigid_n_target_not_mapped
             if n_full > 0:
                 self._gdt_4 = float(n) / n_full
@@ -1471,7 +1496,13 @@ class Scorer:
         :type: :class:`float` 
         """
         if self._gdt_8 is None:
-            n, m = GDT(self.rigid_mapped_model_pos, self.rigid_mapped_target_pos, 7, 1000, 8.0)
+            N = list()
+            for window_size in self._gdt_window_sizes:
+                n = GDT(self.rigid_mapped_model_pos,
+                        self.rigid_mapped_target_pos,
+                        window_size, 1000, 8.0)[0]
+                N.append(n)
+            n = max(N)
             n_full = len(self.rigid_mapped_target_pos) + self.rigid_n_target_not_mapped
             if n_full > 0:
                 self._gdt_8 = float(n) / n_full
diff --git a/modules/mol/alg/src/gdt.cc b/modules/mol/alg/src/gdt.cc
index 93b6e8b17..8eedd7796 100644
--- a/modules/mol/alg/src/gdt.cc
+++ b/modules/mol/alg/src/gdt.cc
@@ -302,8 +302,126 @@ void GDT(const geom::Vec3List& mdl_pos, const geom::Vec3List& ref_pos,
 
   int n_pos = mdl_pos.size();
 
+  // deal with special cases that don't produce valid transforms
+  if(n_pos == 1) {
+      transform = geom::Mat4::Identity();
+      transform.PasteTranslation(ref_pos[0] - mdl_pos[0]);
+      n_superposed = 1;
+    return;
+  }
+
+  if(n_pos == 2) {
+    Real mdl_d = geom::Distance(mdl_pos[0], mdl_pos[1]);
+    Real ref_d = geom::Distance(ref_pos[0], ref_pos[1]);
+    Real dd = std::abs(mdl_d - ref_d);
+    if(dd/2 <= distance_thresh) {
+      // the two can be superposed within specified distance threshold
+      // BUT: cannot construct valid transformation from two positions
+      // => Construct matrix with four positions 
+      // Two are constructed starting from the center point +- some direction
+      // vector that is orthogonal to the vector connecting the original two
+      // points. 
+      geom::Vec3 mdl_center = (mdl_pos[0] + mdl_pos[1])*0.5;
+      geom::Vec3 ref_center = (ref_pos[0] + ref_pos[1])*0.5;
+      Eigen::Matrix<double, 4, 3> eigen_mdl_pos = \
+      Eigen::Matrix<double, 4, 3>::Zero(4, 3);
+      Eigen::Matrix<double, 4, 3> eigen_ref_pos = \
+      Eigen::Matrix<double, 4, 3>::Zero(4, 3);
+      Eigen::Matrix<double,3,3> eigen_rot = \
+      Eigen::Matrix<double,3,3>::Identity();
+
+      geom::Vec3 mdl_dir = geom::Normalize(mdl_pos[1] - mdl_pos[0]);
+      geom::Vec3 ref_dir = geom::Normalize(ref_pos[1] - ref_pos[0]);
+      geom::Vec3 mdl_normal;
+      geom::Vec3 ref_normal;
+
+      // Use cross product to get some normal on mdl_dir
+      // The direction of the second vector doesn't really matter, but shouldnt
+      // be collinear with mdl_dir
+      if(mdl_dir[0] < 0.999) {
+        mdl_normal = geom::Cross(geom::Vec3(1,0,0), mdl_dir);
+      } else {
+        mdl_normal = geom::Cross(geom::Vec3(0,1,0), mdl_dir);
+      }
+
+      // same for ref_dir
+      if(ref_dir[0] < 0.999) {
+        ref_normal = geom::Cross(geom::Vec3(1,0,0), ref_dir);
+      } else {
+        ref_normal = geom::Cross(geom::Vec3(0,1,0), ref_dir);
+      }
+
+      eigen_mdl_pos(0, 0) = mdl_pos[0][0] - mdl_center[0];
+      eigen_mdl_pos(0, 1) = mdl_pos[0][1] - mdl_center[1];
+      eigen_mdl_pos(0, 2) = mdl_pos[0][2] - mdl_center[2];
+      eigen_mdl_pos(1, 0) = mdl_pos[1][0] - mdl_center[0];
+      eigen_mdl_pos(1, 1) = mdl_pos[1][1] - mdl_center[1];
+      eigen_mdl_pos(1, 2) = mdl_pos[1][2] - mdl_center[2];
+      eigen_mdl_pos(2, 0) = mdl_normal[0];
+      eigen_mdl_pos(2, 1) = mdl_normal[1];
+      eigen_mdl_pos(2, 2) = mdl_normal[2];
+      eigen_mdl_pos(3, 0) = -mdl_normal[0];
+      eigen_mdl_pos(3, 1) = -mdl_normal[1];
+      eigen_mdl_pos(3, 2) = -mdl_normal[2];
+      eigen_ref_pos(0, 0) = ref_pos[0][0] - ref_center[0];
+      eigen_ref_pos(0, 1) = ref_pos[0][1] - ref_center[1];
+      eigen_ref_pos(0, 2) = ref_pos[0][2] - ref_center[2];
+      eigen_ref_pos(1, 0) = ref_pos[1][0] - ref_center[0];
+      eigen_ref_pos(1, 1) = ref_pos[1][1] - ref_center[1];
+      eigen_ref_pos(1, 2) = ref_pos[1][2] - ref_center[2];
+      eigen_ref_pos(2, 0) = ref_normal[0];
+      eigen_ref_pos(2, 1) = ref_normal[1];
+      eigen_ref_pos(2, 2) = ref_normal[2];
+      eigen_ref_pos(3, 0) = -ref_normal[0];
+      eigen_ref_pos(3, 1) = -ref_normal[1];
+      eigen_ref_pos(3, 2) = -ref_normal[2];
+
+      Real tmp; // no need to store RMSD
+      TheobaldRMSD(eigen_mdl_pos, eigen_ref_pos, tmp, eigen_rot);
+      transform = geom::Mat4();
+      transform(0, 0) = eigen_rot(0, 0);
+      transform(0, 1) = eigen_rot(0, 1);
+      transform(0, 2) = eigen_rot(0, 2);
+      transform(1, 0) = eigen_rot(1, 0);
+      transform(1, 1) = eigen_rot(1, 1);
+      transform(1, 2) = eigen_rot(1, 2);
+      transform(2, 0) = eigen_rot(2, 0);
+      transform(2, 1) = eigen_rot(2, 1);
+      transform(2, 2) = eigen_rot(2, 2);
+
+      // there are three transformation to be applied to reach ref_pos from
+      // mdl_pos:
+      // 1: shift mdl_pos to center
+      // 2: apply estimated rotation
+      // 3: shift onto average of ref_pos
+      Eigen::Matrix<double,1,3> eigen_avg_mdl = Eigen::Matrix<double,1,3>::Zero();
+      Eigen::Matrix<double,1,3> eigen_avg_ref = Eigen::Matrix<double,1,3>::Zero();
+      eigen_avg_mdl(0,0) = mdl_center[0]; 
+      eigen_avg_mdl(0,1) = mdl_center[1]; 
+      eigen_avg_mdl(0,2) = mdl_center[2]; 
+      eigen_avg_ref(0,0) = ref_center[0]; 
+      eigen_avg_ref(0,1) = ref_center[1]; 
+      eigen_avg_ref(0,2) = ref_center[2]; 
+      Eigen::Matrix<double,1,3> translation = eigen_rot *
+                                              (-eigen_avg_mdl.transpose()) + 
+                                              eigen_avg_ref.transpose();
+      transform(0, 3) = translation(0, 0);
+      transform(1, 3) = translation(0, 1);
+      transform(2, 3) = translation(0, 2);
+      n_superposed = 2;
+    } else {
+      // the two cannot be superposed within specified distance threshold
+      // => just set n_superposed to 1 and generate the same transformation
+      // as in n_pos == 1
+      transform = geom::Mat4::Identity();
+      transform.PasteTranslation(ref_pos[0] - mdl_pos[0]);
+      n_superposed = 1;
+    }
+    return;
+  }
+
   if(window_size > n_pos) {
-    throw ost::Error("Window size in GDT algorithm is larger than positions");
+    window_size = n_pos;
   }
 
   Eigen::Matrix<double, Eigen::Dynamic, 3> eigen_mdl_pos = \
-- 
GitLab