From 767c35a6d00308c59810dee3adee22803438d083 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Fri, 12 Apr 2024 09:57:04 +0200
Subject: [PATCH] lddt: add experimental flag: add_mdl_contacts

Only using contacts that are within a certain distance threshold in the target
does not penalize for added model contacts. If set to True, this flag will
also consider target contacts that are within the specified distance threshold
in the model but not necessarily in the target. No contact will be added if
the respective atom pair is not resolved in the target.
---
 modules/mol/alg/pymod/chain_mapping.py |   2 +-
 modules/mol/alg/pymod/lddt.py          | 327 ++++++++++++++++---------
 modules/mol/alg/tests/test_lddt.py     |  16 ++
 3 files changed, 229 insertions(+), 116 deletions(-)

diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py
index 832f01ec0..41f7f6975 100644
--- a/modules/mol/alg/pymod/chain_mapping.py
+++ b/modules/mol/alg/pymod/chain_mapping.py
@@ -2085,7 +2085,7 @@ class _lDDTDecomposer:
                     s = lddt.lDDTScorer(dimer_ref, bb_only=True)
                     self.interface_scorer[k1] = s
                     self.interface_scorer[k2] = s
-                    self.n += self.interface_scorer[k1].n_distances_ic
+                    self.n += sum([len(x) for x in self.interface_scorer[k1].ref_indices_ic])
                     self.ref_interfaces.append(k1)
                     # single chain scorer are actually interface scorers to save
                     # some distance calculations
diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py
index b6b11dd9d..bdadd90c0 100644
--- a/modules/mol/alg/pymod/lddt.py
+++ b/modules/mol/alg/pymod/lddt.py
@@ -271,6 +271,9 @@ class lDDTScorer:
         # store indices of all atoms that have symmetry properties
         self.symmetric_atoms = set()
 
+        # the actual target positions in a numpy array of shape (self.n_atoms,3)
+        self.positions = None
+
         # setup members defined above
         self._SetupEnv(self.compound_lib, self.custom_compounds,
                        self.symmetry_settings, seqres_mapping, self.bb_only)
@@ -291,16 +294,12 @@ class lDDTScorer:
         self._sym_ref_indices = None
         self._sym_ref_distances = None
 
-        # total number of distances
-        self._n_distances = None
-
         # exactly the same as above but without interchain contacts
         # => single-chain (sc)
         self._ref_indices_sc = None
         self._ref_distances_sc = None
         self._sym_ref_indices_sc = None
         self._sym_ref_distances_sc = None
-        self._n_distances_sc = None
 
         # exactly the same as above but without intrachain contacts
         # => inter-chain (ic)
@@ -308,7 +307,6 @@ class lDDTScorer:
         self._ref_distances_ic = None
         self._sym_ref_indices_ic = None
         self._sym_ref_distances_ic = None
-        self._n_distances_ic = None
 
         # input parameter checking
         self._ProcessSequenceSeparation()
@@ -316,99 +314,123 @@ class lDDTScorer:
     @property
     def ref_indices(self):
         if self._ref_indices is None:
-            self._SetupDistances()
+            self._ref_indices, self._ref_distances = \
+            lDDTScorer._SetupDistances(self.target, self.n_atoms,
+                                       self.atom_indices,
+                                       self.inclusion_radius)
         return self._ref_indices
 
     @property
     def ref_distances(self):
         if self._ref_distances is None:
-            self._SetupDistances()
+            self._ref_indices, self._ref_distances = \
+            lDDTScorer._SetupDistances(self.target, self.n_atoms,
+                                       self.atom_indices,
+                                       self.inclusion_radius)
         return self._ref_distances
     
     @property
     def sym_ref_indices(self):
         if self._sym_ref_indices is None:
-            self._SetupDistances()
+            self._sym_ref_indices, self._sym_ref_distances = \
+            lDDTScorer._NonSymDistances(self.n_atoms, self.symmetric_atoms,
+                                        self.ref_indices, self.ref_distances)
         return self._sym_ref_indices
 
     @property
     def sym_ref_distances(self):
         if self._sym_ref_distances is None:
-            self._SetupDistances()
+            self._sym_ref_indices, self._sym_ref_distances = \
+            lDDTScorer._NonSymDistances(self.n_atoms, self.symmetric_atoms,
+                                        self.ref_indices, self.ref_distances)
         return self._sym_ref_distances
 
-    @property
-    def n_distances(self):
-        if self._n_distances is None:
-            self._n_distances = sum([len(x) for x in self.ref_indices])
-        return self._n_distances
-
     @property
     def ref_indices_sc(self):
         if self._ref_indices_sc is None:
-            self._SetupDistancesSC()
+            self._ref_indices_sc, self._ref_distances_sc = \
+            lDDTScorer._SetupDistancesSC(self.n_atoms,
+                                         self.chain_start_indices,
+                                         self.ref_indices,
+                                         self.ref_distances)
         return self._ref_indices_sc
 
     @property
     def ref_distances_sc(self):
         if self._ref_distances_sc is None:
-            self._SetupDistancesSC()
+            self._ref_indices_sc, self._ref_distances_sc = \
+            lDDTScorer._SetupDistancesSC(self.n_atoms,
+                                         self.chain_start_indices,
+                                         self.ref_indices,
+                                         self.ref_distances)
         return self._ref_distances_sc
     
     @property
     def sym_ref_indices_sc(self):
         if self._sym_ref_indices_sc is None:
-            self._SetupDistancesSC()
+            self._sym_ref_indices_sc, self._sym_ref_distances_sc = \
+            lDDTScorer._NonSymDistances(self.n_atoms,
+                                        self.symmetric_atoms,
+                                        self.ref_indices_sc,
+                                        self.ref_distances_sc)
         return self._sym_ref_indices_sc
 
     @property
     def sym_ref_distances_sc(self):
         if self._sym_ref_distances_sc is None:
-            self._SetupDistancesSC()
+            self._sym_ref_indices_sc, self._sym_ref_distances_sc = \
+            lDDTScorer._NonSymDistances(self.n_atoms,
+                                        self.symmetric_atoms,
+                                        self.ref_indices_sc,
+                                        self.ref_distances_sc)
         return self._sym_ref_distances_sc
 
-    @property
-    def n_distances_sc(self):
-        if self._n_distances_sc is None:
-            self._n_distances_sc = sum([len(x) for x in self.ref_indices_sc])
-        return self._n_distances_sc
-
     @property
     def ref_indices_ic(self):
         if self._ref_indices_ic is None:
-            self._SetupDistancesIC()
+            self._ref_indices_ic, self._ref_distances_ic = \
+            lDDTScorer._SetupDistancesIC(self.n_atoms,
+                                         self.chain_start_indices,
+                                         self.ref_indices,
+                                         self.ref_distances)
         return self._ref_indices_ic
 
     @property
     def ref_distances_ic(self):
         if self._ref_distances_ic is None:
-            self._SetupDistancesIC()
+            self._ref_indices_ic, self._ref_distances_ic = \
+            lDDTScorer._SetupDistancesIC(self.n_atoms,
+                                         self.chain_start_indices,
+                                         self.ref_indices,
+                                         self.ref_distances)
         return self._ref_distances_ic
     
     @property
     def sym_ref_indices_ic(self):
         if self._sym_ref_indices_ic is None:
-            self._SetupDistancesIC()
+            self._sym_ref_indices_ic, self._sym_ref_distances_ic = \
+            lDDTScorer._NonSymDistances(self.n_atoms,
+                                        self.symmetric_atoms,
+                                        self.ref_indices_ic,
+                                        self.ref_distances_ic)
         return self._sym_ref_indices_ic
 
     @property
     def sym_ref_distances_ic(self):
         if self._sym_ref_distances_ic is None:
-            self._SetupDistancesIC()
+            self._sym_ref_indices_ic, self._sym_ref_distances_ic = \
+            lDDTScorer._NonSymDistances(self.n_atoms,
+                                        self.symmetric_atoms,
+                                        self.ref_indices_ic,
+                                        self.ref_distances_ic)
         return self._sym_ref_distances_ic
 
-    @property
-    def n_distances_ic(self):
-        if self._n_distances_ic is None:
-            self._n_distances_ic = sum([len(x) for x in self.ref_indices_ic])
-        return self._n_distances_ic
-
     def lDDT(self, model, thresholds = [0.5, 1.0, 2.0, 4.0],
              local_lddt_prop=None, local_contact_prop=None,
              chain_mapping=None, no_interchain=False,
              no_intrachain=False, penalize_extra_chains=False,
              residue_mapping=None, return_dist_test=False,
-             check_resnames=True):
+             check_resnames=True, add_mdl_contacts=False):
         """Computes lDDT of *model* - globally and per-residue
 
         :param model: Model to be scored - models are preferably scored upon
@@ -488,6 +510,17 @@ class lDDTScorer:
         :param check_resnames: On by default. Enforces residue name matches
                                between mapped model and target residues.
         :type check_resnames: :class:`bool`
+        :param add_mdl_contacts: Adds model contacts - Only using contacts that
+                                 are within a certain distance threshold in the
+                                 target does not penalize for added model
+                                 contacts. If set to True, this flag will also
+                                 consider target contacts that are within the
+                                 specified distance threshold in the model but
+                                 not necessarily in the target. No contact will
+                                 be added if the respective atom pair is not
+                                 resolved in the target.
+        :type add_mdl_contacts: :class:`bool`
+
         :returns: global and per-residue lDDT scores as a tuple -
                   first element is global lDDT score (None if *target* has no
                   contacts) and second element a list of per-residue scores with
@@ -530,6 +563,10 @@ class lDDTScorer:
         # actually there
         res_atom_indices = list()
 
+        # and the respective hash codes
+        # this is required if add_mdl_contacts is set to True
+        res_atom_hashes = list()
+
         # indices of the scored residues
         res_indices = list()
 
@@ -566,6 +603,7 @@ class lDDTScorer:
                     list(range(res_start_idx, res_start_idx + len(anames)))
                 )
                 res_atom_indices.append(list())
+                res_atom_hashes.append(list())
                 res_indices.append(current_model_res_idx)
                 for a_idx, a in enumerate(atoms):
                     if a.IsValid():
@@ -574,6 +612,7 @@ class lDDTScorer:
                         pos[res_start_idx + a_idx][1] = p[1]
                         pos[res_start_idx + a_idx][2] = p[2]
                         res_atom_indices[-1].append(res_start_idx + a_idx)
+                        res_atom_hashes[-1].append(a.handle.GetHashCode())
                 if rname in self.compound_symmetric_atoms:
                     sym_indices = list()
                     for sym_tuple in self.compound_symmetric_atoms[rname]:
@@ -598,19 +637,23 @@ class lDDTScorer:
             sym_ref_distances = self.sym_ref_distances_sc
             ref_indices = self.ref_indices_sc
             ref_distances = self.ref_distances_sc
-            n_distances = self.n_distances_sc
         elif no_intrachain:
             sym_ref_indices = self.sym_ref_indices_ic
             sym_ref_distances = self.sym_ref_distances_ic
             ref_indices = self.ref_indices_ic
             ref_distances = self.ref_distances_ic
-            n_distances = self.n_distances_ic
         else:
             sym_ref_indices = self.sym_ref_indices
             sym_ref_distances = self.sym_ref_distances
             ref_indices = self.ref_indices
             ref_distances = self.ref_distances
-            n_distances = self.n_distances
+
+        if add_mdl_contacts:
+            ref_indices, ref_distances, \
+            sym_ref_indices, sym_ref_distances = \
+            self._AddMdlContacts(model, res_atom_indices, res_atom_hashes,
+                                 ref_indices, ref_distances,
+                                 no_interchain, no_intrachain)
 
         self._ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices,
                                 sym_ref_distances)
@@ -633,8 +676,8 @@ class lDDTScorer:
             else:
                 per_res_lDDT[res_indices[idx]] = 0.0
 
-
         # do full model score
+        n_distances = sum([len(x) for x in ref_indices])
         if penalize_extra_chains:
             n_distances += self._GetExtraModelChainPenalty(model, chain_mapping)
 
@@ -704,7 +747,7 @@ class lDDTScorer:
                                           symmetry_settings = sm,
                                           inclusion_radius = self.inclusion_radius,
                                           bb_only = self.bb_only)
-                penalty += dummy_scorer.n_distances
+                penalty += sum([len(x) for x in dummy_scorer.ref_indices])
         return penalty
 
     def _GetChainRNums(self, ch, residue_mapping, model_ch_name,
@@ -768,6 +811,7 @@ class lDDTScorer:
         residue_numbers = self._GetTargetResidueNumbers(self.target,
                                                         seqres_mapping)
         current_idx = 0
+        positions = list()
         for chain in self.target.chains:
             ch_name = chain.GetName()
             self.chain_names.append(ch_name)
@@ -789,6 +833,11 @@ class lDDTScorer:
                 for a in atoms:
                     if a.IsValid():
                         self.atom_indices[a.handle.GetHashCode()] = current_idx
+                        p = a.GetPos()
+                        positions.append(np.asarray([p[0], p[1], p[2]],
+                                                     dtype=np.float32))
+                    else:
+                        positions.append(np.zeros(3, dtype=np.float32))
                     current_idx += 1
                 
                 if r.name in self.compound_symmetric_atoms:
@@ -800,9 +849,9 @@ class lDDTScorer:
                                 self.symmetric_atoms.add(
                                     self.atom_indices[hashcode]
                                 )
+        self.positions = np.vstack(positions, dtype=np.float32)
         self.n_atoms = current_idx
 
-
     def _GetTargetResidueNumbers(self, target, seqres_mapping):
         """Returns residue numbers for each chain in target as dict
 
@@ -889,7 +938,70 @@ class lDDTScorer:
             if len(symmetric_atoms) > 0:
                 self.compound_symmetric_atoms[r.name] = symmetric_atoms
 
-    def _SetupDistances(self):
+    def _AddMdlContacts(self, model, res_atom_indices, res_atom_hashes,
+                        ref_indices, ref_distances, no_interchain,
+                        no_intrachain):
+
+        # buildup an index map for mdl atoms that are also present in target
+        in_target = np.zeros(self.n_atoms, dtype=bool)
+        for i in self.atom_indices.values():
+            in_target[i] = True
+        mdl_atom_indices = dict()
+        for at_indices, at_hashes in zip(res_atom_indices, res_atom_hashes):
+            for i, h in zip(at_indices, at_hashes):
+                if in_target[i]:
+                    mdl_atom_indices[h] = i
+
+        # get contacts for mdl - the contacts are only from atom pairs that
+        # are also present in target, as we only provide the respective
+        # hashes in mdl_atom_indices
+        mdl_ref_indices, mdl_ref_distances = \
+        lDDTScorer._SetupDistances(model, self.n_atoms, mdl_atom_indices,
+                                   self.inclusion_radius)
+        if no_interchain:
+            mdl_ref_indices, mdl_ref_distances = \
+            lDDTScorer._SetupDistancesSC(self.n_atoms,
+                                         self.chain_start_indices,
+                                         mdl_ref_indices,
+                                         mdl_ref_distances)
+
+        if no_intrachain:
+            mdl_ref_indices, mdl_ref_distances = \
+            lDDTScorer._SetupDistancesIC(self.n_atoms,
+                                         self.chain_start_indices,
+                                         mdl_ref_indices,
+                                         mdl_ref_distances)
+
+        # update ref_indices/ref_distances => add mdl contacts
+        for i in range(self.n_atoms):
+            mask = np.isin(mdl_ref_indices[i], ref_indices[i],
+                           assume_unique=True, invert=True)
+            if np.sum(mask) > 0:
+                added_mdl_indices = mdl_ref_indices[i][mask]
+                ref_indices[i] = np.append(ref_indices[i],
+                                           added_mdl_indices)
+
+                # distances need to be recomputed from ref positions
+                tmp = self.positions.take(added_mdl_indices, axis=0)
+                np.subtract(tmp, self.positions[i][None, :], out=tmp)
+                np.square(tmp, out=tmp)
+                tmp = tmp.sum(axis=1)
+                np.sqrt(tmp, out=tmp)  # distances against all relevant atoms
+                ref_distances[i] = np.append(ref_distances[i], tmp)
+
+        # recompute symmetry related indices/distances
+        sym_ref_indices, sym_ref_distances = \
+        lDDTScorer._NonSymDistances(self.n_atoms, self.symmetric_atoms,
+                                    ref_indices, ref_distances)
+
+        return (ref_indices, ref_distances, sym_ref_indices, sym_ref_distances)
+
+
+
+    @staticmethod
+    def _SetupDistances(structure, n_atoms, atom_index_mapping,
+                        inclusion_radius):
+
         """Compute distance related members of lDDTScorer
 
         Brute force all vs all distance computation kills lDDT for large
@@ -902,19 +1014,16 @@ class lDDTScorer:
         - process potentially interacting chain pairs
         - concatenate distances from all processing steps
         """
-        self._ref_indices = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)]
-        self._ref_distances = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)]
-        self._sym_ref_indices = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)]
-        self._sym_ref_distances = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)]
+        ref_indices = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)]
+        ref_distances = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)]
 
-        indices = [list() for _ in range(self.n_atoms)]
-        distances = [list() for _ in range(self.n_atoms)]
+        indices = [list() for _ in range(n_atoms)]
+        distances = [list() for _ in range(n_atoms)]
         per_chain_pos = list()
         per_chain_indices = list()
 
         # Process individual chains
-        for ch_idx, ch in enumerate(self.target.chains):
-            ch_start_idx = self.chain_start_indices[ch_idx]
+        for ch in structure.chains:
             pos_list = list()
             atom_indices = list()
             mask_start = list()
@@ -924,10 +1033,10 @@ class lDDTScorer:
                 n_valid_atoms = 0
                 for a in r.atoms:
                     hash_code = a.handle.GetHashCode()
-                    if hash_code in self.atom_indices:
+                    if hash_code in atom_index_mapping:
                         p = a.GetPos()
                         pos_list.append(np.asarray([p[0], p[1], p[2]]))
-                        atom_indices.append(self.atom_indices[hash_code])
+                        atom_indices.append(atom_index_mapping[hash_code])
                         n_valid_atoms += 1
                 mask_start.extend([r_start_idx] * n_valid_atoms)
                 mask_end.extend([r_start_idx + n_valid_atoms] * n_valid_atoms)
@@ -937,12 +1046,12 @@ class lDDTScorer:
             dists = cdist(pos, pos)
 
             # apply masks
-            far_away = 2 * self.inclusion_radius
+            far_away = 2 * inclusion_radius
             for idx in range(atom_indices.shape[0]):
                 dists[idx, range(mask_start[idx], mask_end[idx])] = far_away
 
             # fish out and store close atoms within inclusion radius
-            within_mask = dists < self.inclusion_radius
+            within_mask = dists < inclusion_radius
             for idx in range(atom_indices.shape[0]):
                 indices_to_append = atom_indices[within_mask[idx,:]]
                 if indices_to_append.shape[0] > 0:
@@ -957,18 +1066,18 @@ class lDDTScorer:
         min_pos = [p.min(0) for p in per_chain_pos]
         max_pos = [p.max(0) for p in per_chain_pos]
         chain_pairs = list()
-        for idx_one in range(len(self.chain_start_indices)):
-            for idx_two in range(idx_one + 1, len(self.chain_start_indices)):
-                if np.max(min_pos[idx_one] - max_pos[idx_two]) > self.inclusion_radius:
+        for idx_one in range(len(per_chain_pos)):
+            for idx_two in range(idx_one + 1, len(per_chain_pos)):
+                if np.max(min_pos[idx_one] - max_pos[idx_two]) > inclusion_radius:
                     continue
-                if np.max(min_pos[idx_two] - max_pos[idx_one]) > self.inclusion_radius:
+                if np.max(min_pos[idx_two] - max_pos[idx_one]) > inclusion_radius:
                     continue
                 chain_pairs.append((idx_one, idx_two))
 
         # process potentially interacting chains
         for pair in chain_pairs:
             dists = cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]])
-            within = dists <= self.inclusion_radius
+            within = dists <= inclusion_radius
 
             # process pair[0]
             tmp = within.sum(axis=1)
@@ -1007,93 +1116,81 @@ class lDDTScorer:
                     distances[at_idx].insert(insertion_idx, distances_to_insert)
 
         # concatenate distances from all processing steps
-        for at_idx in range(self.n_atoms):
+        for at_idx in range(n_atoms):
             if len(indices[at_idx]) > 0:
-                self._ref_indices[at_idx] = np.hstack(indices[at_idx])
-                self._ref_distances[at_idx] = np.hstack(distances[at_idx])
+                ref_indices[at_idx] = np.hstack(indices[at_idx])
+                ref_distances[at_idx] = np.hstack(distances[at_idx])
 
-        self._NonSymDistances(self._ref_indices, self._ref_distances,
-                              self._sym_ref_indices,
-                              self._sym_ref_distances)
+        return (ref_indices, ref_distances)
 
-    def _SetupDistancesSC(self):
+    @staticmethod
+    def _SetupDistancesSC(n_atoms, chain_start_indices,
+                          ref_indices, ref_distances):
         """Select subset of contacts only covering intra-chain contacts
         """
         # init
-        self._ref_indices_sc = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)]
-        self._ref_distances_sc = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)]
-        self._sym_ref_indices_sc = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)]
-        self._sym_ref_distances_sc = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)]
-
-        # start from overall contacts
-        ref_indices = self.ref_indices
-        ref_distances = self.ref_distances
-        sym_ref_indices = self.sym_ref_indices
-        sym_ref_distances = self.sym_ref_distances
-
-        n_chains = len(self.chain_start_indices)
-        for ch_idx, ch in enumerate(self.target.chains):
-            chain_s = self.chain_start_indices[ch_idx]
-            chain_e = self.n_atoms
+        ref_indices_sc = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)]
+        ref_distances_sc = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)]
+
+        n_chains = len(chain_start_indices)
+        for ch_idx in range(n_chains):
+            chain_s = chain_start_indices[ch_idx]
+            chain_e = n_atoms
             if ch_idx + 1 < n_chains:
-                chain_e = self.chain_start_indices[ch_idx+1]
+                chain_e = chain_start_indices[ch_idx+1]
             for i in range(chain_s, chain_e):
                 if len(ref_indices[i]) > 0:
                     intra_idx = np.where(np.logical_and(ref_indices[i]>=chain_s,
                                                   ref_indices[i]<chain_e))[0]
-                    self._ref_indices_sc[i] = ref_indices[i][intra_idx]
-                    self._ref_distances_sc[i] = ref_distances[i][intra_idx]
+                    ref_indices_sc[i] = ref_indices[i][intra_idx]
+                    ref_distances_sc[i] = ref_distances[i][intra_idx]
 
-        self._NonSymDistances(self._ref_indices_sc, self._ref_distances_sc,
-                              self._sym_ref_indices_sc,
-                              self._sym_ref_distances_sc)
+        return (ref_indices_sc, ref_distances_sc)
 
-    def _SetupDistancesIC(self):
+    @staticmethod
+    def _SetupDistancesIC(n_atoms, chain_start_indices,
+                          ref_indices, ref_distances):
         """Select subset of contacts only covering inter-chain contacts
         """
         # init
-        self._ref_indices_ic = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)]
-        self._ref_distances_ic = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)]
-        self._sym_ref_indices_ic = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)]
-        self._sym_ref_distances_ic = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)]
-
-        # start from overall contacts
-        ref_indices = self.ref_indices
-        ref_distances = self.ref_distances
-        sym_ref_indices = self.sym_ref_indices
-        sym_ref_distances = self.sym_ref_distances
-
-        n_chains = len(self.chain_start_indices)
-        for ch_idx, ch in enumerate(self.target.chains):
-            chain_s = self.chain_start_indices[ch_idx]
-            chain_e = self.n_atoms
+        ref_indices_ic = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)]
+        ref_distances_ic = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)]
+
+        n_chains = len(chain_start_indices)
+        for ch_idx in range(n_chains):
+            chain_s = chain_start_indices[ch_idx]
+            chain_e = n_atoms
             if ch_idx + 1 < n_chains:
-                chain_e = self.chain_start_indices[ch_idx+1]
+                chain_e = chain_start_indices[ch_idx+1]
             for i in range(chain_s, chain_e):
                 if len(ref_indices[i]) > 0:
                     inter_idx = np.where(np.logical_or(ref_indices[i]<chain_s,
                                                   ref_indices[i]>=chain_e))[0]
-                    self._ref_indices_ic[i] = ref_indices[i][inter_idx]
-                    self._ref_distances_ic[i] = ref_distances[i][inter_idx]
+                    ref_indices_ic[i] = ref_indices[i][inter_idx]
+                    ref_distances_ic[i] = ref_distances[i][inter_idx]
 
-        self._NonSymDistances(self._ref_indices_ic, self._ref_distances_ic,
-                              self._sym_ref_indices_ic,
-                              self._sym_ref_distances_ic)
+        return (ref_indices_ic, ref_distances_ic)
 
-    def _NonSymDistances(self, ref_indices, ref_distances,
-                         sym_ref_indices, sym_ref_distances):
-        """Transfer indices/distances of non-symmetric atoms in place
+    @staticmethod
+    def _NonSymDistances(n_atoms, symmetric_atoms, ref_indices, ref_distances):
+        """Transfer indices/distances of non-symmetric atoms and return
         """
-        for idx in self.symmetric_atoms:
+
+        sym_ref_indices = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)]
+        sym_ref_distances = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)]
+
+        for idx in symmetric_atoms:
             indices = list()
             distances = list()
             for i, d in zip(ref_indices[idx], ref_distances[idx]):
-                if i not in self.symmetric_atoms:
+                if i not in symmetric_atoms:
                     indices.append(i)
                     distances.append(d)
             sym_ref_indices[idx] = indices
             sym_ref_distances[idx] = np.asarray(distances)
 
+        return (sym_ref_indices, sym_ref_distances)
+
     def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances):
         """Computes number of distance differences within given thresholds
 
diff --git a/modules/mol/alg/tests/test_lddt.py b/modules/mol/alg/tests/test_lddt.py
index 0e31f6cac..755717f59 100644
--- a/modules/mol/alg/tests/test_lddt.py
+++ b/modules/mol/alg/tests/test_lddt.py
@@ -221,6 +221,22 @@ class TestlDDT(unittest.TestCase):
         # same for the conserved contacts
         self.assertEqual(lDDT_cons_ic + lDDT_cons_sc, lDDT_cons)
 
+    def test_add_mdl_contacts(self):
+        model = _LoadFile("7SGN_C_model.pdb")
+        target = _LoadFile("7SGN_C_target.pdb")
+
+        lddt_scorer = lDDTScorer(target)
+        lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, per_res_exp, \
+        per_res_conserved = lddt_scorer.lDDT(model,
+                                             return_dist_test = True,
+                                             add_mdl_contacts=True)
+
+        # this value is just blindly copied in without checking whether it makes
+        # any sense... it's sole purpose is to trigger the respective flag
+        # in lDDT computation
+        self.assertEqual(lDDT, 0.6171511842396518)
+
+
 
 class TestlDDTBS(unittest.TestCase):
 
-- 
GitLab