From 656b852ba7cfacd8e81449538c31de3e835c0d0d Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Tue, 19 Apr 2022 14:10:45 +0200
Subject: [PATCH] lDDT: code cleanups

---
 modules/mol/alg/pymod/lddt.py | 93 ++++++++++++++++++++---------------
 1 file changed, 52 insertions(+), 41 deletions(-)

diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py
index c14ea43a2..7fba7a74c 100644
--- a/modules/mol/alg/pymod/lddt.py
+++ b/modules/mol/alg/pymod/lddt.py
@@ -638,10 +638,53 @@ class lDDTScorer:
 
         No distance related members - see _SetupDistances
         """
+        residue_numbers = self._GetTargetResidueNumbers(self.target,
+                                                        seqres_mapping)
+        current_idx = 0
+        for chain in self.target.chains:
+            ch_name = chain.GetName()
+            self.chain_names.append(ch_name)
+            self.chain_start_indices.append(current_idx)
+            self.chain_res_start_indices.append(len(self.compound_names))
+            for r, rnum in zip(chain.residues, residue_numbers[ch_name]):
+                if r.name not in self.compound_anames:
+                    # sets compound info in self.compound_anames and
+                    # self.compound_symmetric_atoms
+                    self._SetupCompound(r, compound_lib, symmetry_settings,
+                                        calpha)
+
+                self.res_start_indices.append(current_idx)
+                self.res_mapper[(ch_name, rnum)] = len(self.compound_names)
+                self.compound_names.append(r.name)
+                self.res_resnums.append(rnum)
+
+                atoms = [r.FindAtom(an) for an in self.compound_anames[r.name]]
+                for a in atoms:
+                    if a.IsValid():
+                        self.atom_indices[a.handle.GetHashCode()] = current_idx
+                    current_idx += 1
+                
+                if r.name in self.compound_symmetric_atoms:
+                    for sym_tuple in self.compound_symmetric_atoms[r.name]:
+                        for a_idx in sym_tuple:
+                            a = atoms[a_idx]
+                            if a.IsValid():
+                                hashcode = a.handle.GetHashCode()
+                                self.symmetric_atoms.add(
+                                    self.atom_indices[hashcode]
+                                )
+        self.n_atoms = current_idx
+
+
+    def _GetTargetResidueNumbers(self, target, seqres_mapping):
+        """Returns residue numbers for each chain in target as dict
+
+        They're either directly extracted from the raw residue number
+        from the structure or from user provided alignments
+        """
         residue_numbers = dict()
-        for ch in self.target.chains:
+        for ch in target.chains:
             ch_name = ch.GetName()
-            self.chain_names.append(ch_name)
             rnums = list()
             if ch_name in seqres_mapping:
                 seqres = seqres_mapping[ch_name].GetSequence(0).GetString()
@@ -685,39 +728,7 @@ class lDDTScorer:
                     )
             assert len(rnums) == len(ch.residues)
             residue_numbers[ch_name] = rnums
-
-        current_idx = 0
-        for chain in self.target.chains:
-            ch_name = chain.GetName()
-            self.chain_start_indices.append(current_idx)
-            self.chain_res_start_indices.append(len(self.compound_names))
-            for r, rnum in zip(chain.residues, residue_numbers[ch_name]):
-                if r.name not in self.compound_anames:
-                    # sets compound info in self.compound_anames and
-                    # self.compound_symmetric_atoms
-                    self._SetupCompound(r, compound_lib, symmetry_settings,
-                                        calpha)
-
-                self.res_start_indices.append(current_idx)
-                self.res_mapper[(ch_name, rnum)] = len(self.compound_names)
-                self.compound_names.append(r.name)
-                self.res_resnums.append(rnum)
-
-                atoms = [r.FindAtom(an) for an in self.compound_anames[r.name]]
-                for a in atoms:
-                    if a.IsValid():
-                        self.atom_indices[a.handle.GetHashCode()] = current_idx
-                    current_idx += 1
-
-                for sym_tuple in self.compound_symmetric_atoms[r.name]:
-                    for a_idx in sym_tuple:
-                        a = atoms[a_idx]
-                        if a.IsValid():
-                            hashcode = a.handle.GetHashCode()
-                            self.symmetric_atoms.add(
-                                self.atom_indices[hashcode]
-                            )
-        self.n_atoms = current_idx
+        return residue_numbers
 
     def _SetupCompound(self, r, compound_lib, symmetry_settings, calpha):
         """fill self.compound_anames/self.compound_symmetric_atoms
@@ -735,22 +746,22 @@ class lDDTScorer:
             for atom_spec in compound.GetAtomSpecs():
                 if atom_spec.element not in ["H", "D"]:
                     atom_names.append(atom_spec.name)
-            self.compound_anames[r.name] = atom_names
             if r.name in symmetry_settings.symmetric_compounds:
                 for pair in symmetry_settings.symmetric_compounds[r.name]:
                     try:
-                        a = self.compound_anames[r.name].index(pair[0])
-                        b = self.compound_anames[r.name].index(pair[1])
+                        a = atom_names.index(pair[0])
+                        b = atom_names.index(pair[1])
                     except:
                         msg = f"Could not find symmetric atoms "
                         msg += f"({pair[0]}, {pair[1]}) for {r.name} "
                         msg += f"as specified in SymmetrySettings in "
                         msg += f"compound from component dictionary. "
-                        msg += f"Atoms in compound: "
-                        msg += f"{self.compound_anames[r.name]}"
+                        msg += f"Atoms in compound: {atom_names}"
                         raise RuntimeError(msg)
                     symmetric_atoms.append((a, b))
-            self.compound_symmetric_atoms[r.name] = symmetric_atoms
+            self.compound_anames[r.name] = atom_names
+            if len(symmetric_atoms) > 0:
+                self.compound_symmetric_atoms[r.name] = symmetric_atoms
 
     def _SetupDistances(self):
         """Compute distance related members of lDDTScorer
-- 
GitLab