From d4d24587ce7ef6e016fb5bc80e6bf122eefaf935 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Fri, 17 May 2024 18:24:39 +0200
Subject: [PATCH] lddt-pli: first implementation of added mdl contacts with
 heavy caching

Can be considered backup commit and still contains debug output
---
 modules/mol/alg/pymod/ligand_scoring.py | 488 +++++++++++++++++-------
 1 file changed, 348 insertions(+), 140 deletions(-)

diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py
index 538920534..cd014db30 100644
--- a/modules/mol/alg/pymod/ligand_scoring.py
+++ b/modules/mol/alg/pymod/ligand_scoring.py
@@ -11,6 +11,7 @@ from ost import seq
 from ost import LogError, LogWarning, LogScript, LogInfo, LogVerbose, LogDebug
 from ost.mol.alg import chain_mapping
 from ost.mol.alg import lddt
+import time
 
 
 class LigandScorer:
@@ -286,7 +287,7 @@ class LigandScorer:
                  binding_sites_topn=100000, global_chain_mapping=False,
                  rmsd_assignment=False, n_max_naive=12, max_symmetries=1e5,
                  custom_mapping=None, unassigned=False, full_bs_search=False,
-                 add_mdl_contacts=False,
+                 add_mdl_contacts=True,
                  lddt_pli_thresholds = [0.5, 1.0, 2.0, 4.0]):
 
         if isinstance(model, mol.EntityView):
@@ -739,18 +740,30 @@ class LigandScorer:
                                                   model_ligand)
 
 
+
     def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand,
                                            model_ligand):
 
+        t0 = time.time()
 
-        ##########################################################
-        # Get stuff from model/target from lazily computed cache #
-        ##########################################################
+        ###############################
+        # Get stuff from model/target #
+        ###############################
 
         trg_residues, trg_bs, trg_chains, trg_ligand_chain, \
         trg_ligand_res, scorer, chem_groups = \
         self._lddt_pli_get_trg_data(target_ligand)
 
+        # distance hacking... remove any interchain distance except the ones
+        # with the ligand
+        ref_indices = scorer.ref_indices_ic
+        ref_distances = scorer.ref_distances_ic
+        ligand_start_idx = scorer.chain_start_indices[-1]
+        for at_idx in range(ligand_start_idx):
+            mask = ref_indices[at_idx] >= ligand_start_idx
+            ref_indices[at_idx] = ref_indices[at_idx][mask]
+            ref_distances[at_idx] = ref_distances[at_idx][mask]
+
         mdl_residues, mdl_bs, mdl_chains, mdl_editor, mdl_ligand_chain,\
         mdl_ligand_res, chem_mapping = self._lddt_pli_get_mdl_data(model_ligand)
 
@@ -772,14 +785,149 @@ class LigandScorer:
         #    with ligand
         cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups,
                                                            chem_mapping,
-                                                           mdl_bs)
-
-        ###############################################################
-        # compute lDDT for all possible chain mappings and symmetries #
-        ###############################################################
-
-        best_score = -1.0
-        best_result = None
+                                                           mdl_bs, trg_bs)
+
+        ########################
+        # Setup model contacts #
+        ########################
+
+        # THATS A GLOBAL THING AND CAN BE LAZILY COMPUTED!!!
+        # store for each ref_ch,mdl_ch pair all mdl atoms that can be
+        # mapped. Don't store mappable atoms as hashes but rather as tuple
+        # (mdl_r.GetNumber(), mdl_a.GetName()). Reason for that is that
+        # the full model and the mdl_bs are different entity handles without
+        # corresponding hashes.
+        mappable_atoms = dict()
+        for (ref_cname, mdl_cname), aln in self.ref_mdl_alns.items():
+            mappable_atoms[(ref_cname, mdl_cname)] = set()
+            ref_ch = self.chain_mapper.target.Select(f"cname={ref_cname}")
+            mdl_ch = self.chain_mapping_mdl.Select(f"cname={mdl_cname}")
+            aln.AttachView(0, ref_ch)
+            aln.AttachView(1, mdl_ch)
+            for col in aln:
+                ref_r = col.GetResidue(0)
+                mdl_r = col.GetResidue(1)
+                if ref_r.IsValid() and mdl_r.IsValid():
+                    for mdl_a in mdl_r.atoms:
+                        if ref_r.FindAtom(mdl_a.name).IsValid():
+                            at_key = (mdl_r.GetNumber(), mdl_a.name)
+                            mappable_atoms[(ref_cname, mdl_cname)].add(at_key)
+
+        # get each chain mapping that we ever observe in scoring
+        chain_mappings = list(chain_mapping._ChainMappings(chem_groups, chem_mapping))
+
+        # for each mdl ligand atom, we collect all trg ligand atoms that are
+        # ever mapped onto it 
+        ligand_atom_mappings = [set() for a in mdl_ligand_res.atoms]
+        for (trg_sym, mdl_sym) in symmetries:
+            for trg_i, mdl_i in zip(trg_sym, mdl_sym):
+                ligand_atom_mappings[mdl_i].add(trg_i)
+
+        mdl_ligand_pos = np.zeros((mdl_ligand_res.GetAtomCount(), 3))
+        for a_idx, a in enumerate(mdl_ligand_res.atoms):
+            p = a.GetPos()
+            mdl_ligand_pos[a_idx, 0] = p[0]
+            mdl_ligand_pos[a_idx, 1] = p[1]
+            mdl_ligand_pos[a_idx, 2] = p[2]
+
+        trg_ligand_pos = np.zeros((trg_ligand_res.GetAtomCount(), 3))
+        for a_idx, a in enumerate(trg_ligand_res.atoms):
+            p = a.GetPos()
+            trg_ligand_pos[a_idx, 0] = p[0]
+            trg_ligand_pos[a_idx, 1] = p[1]
+            trg_ligand_pos[a_idx, 2] = p[2]
+
+        mdl_lig_hashes = [a.hash_code for a in mdl_ligand_res.atoms]
+
+        symmetric_atoms = np.asarray(sorted(list(scorer.symmetric_atoms)), dtype=np.int64)
+
+        scoring_cache = list()
+        penalty_cache = list()
+
+        for mapping in chain_mappings:
+
+            # flat mapping with mdl chain names as key
+            flat_mapping = dict()
+            for trg_chem_group, mdl_chem_group in zip(chem_groups, mapping):
+                for a,b in zip(trg_chem_group, mdl_chem_group):
+                    if a is not None and b is not None:
+                        flat_mapping[b] = a
+
+            # for each mdl bs atom (as atom hash), the trg bs atoms (as index in scorer)
+            # some caching could help here => same mdl_ch/ref_ch combination could occur
+            # in several mappings...
+            bs_atom_mapping = dict()
+            yolo_mapping = dict()
+            for mdl_cname, ref_cname in flat_mapping.items():
+                aln = cut_ref_mdl_alns[(ref_cname, mdl_cname)]
+                ref_ch = trg_bs.Select(f"cname={ref_cname}")
+                mdl_ch = mdl_bs.Select(f"cname={mdl_cname}")
+                aln.AttachView(0, ref_ch)
+                aln.AttachView(1, mdl_ch)
+                for col in aln:
+                    ref_r = col.GetResidue(0)
+                    mdl_r = col.GetResidue(1)
+                    if ref_r.IsValid() and mdl_r.IsValid():
+                        for mdl_a in mdl_r.atoms:
+                            ref_a = ref_r.FindAtom(mdl_a.GetName())
+                            if ref_a.IsValid():
+                                ref_h = ref_a.handle.hash_code
+                                if ref_h in scorer.atom_indices:
+                                    mdl_h = mdl_a.handle.hash_code
+                                    bs_atom_mapping[mdl_h] = scorer.atom_indices[ref_h]
+                                    yolo_mapping[mdl_h] = ref_a
+
+            cache = dict()
+            n_penalties = list()
+
+            for mdl_a_idx, mdl_a in enumerate(mdl_ligand_res.atoms):
+                n_penalty = 0
+                trg_bs_indices = list()
+                close_a = mdl_bs.FindWithin(mdl_a.GetPos(), self.lddt_pli_radius)
+                for a in close_a:
+                    mdl_a_hash_code = a.hash_code
+                    if mdl_a_hash_code in bs_atom_mapping:
+                        trg_bs_indices.append(bs_atom_mapping[mdl_a_hash_code])
+                    elif mdl_a_hash_code not in mdl_lig_hashes:
+                        at_key = (a.GetResidue().GetNumber(), a.name)
+                        cname = a.GetChain().name
+                        cname_key = (flat_mapping[cname], cname)
+                        if at_key in mappable_atoms[cname_key]:
+                            # Its a contact in the model but not part of trg_bs.
+                            # It can still be mapped using the global
+                            # mdl_ch/ref_ch alignment
+                            # dist in ref > self.lddt_pli_radius + max_thresh
+                            # => guaranteed to be non-fulfilled contact
+                            n_penalty += 1
+
+                n_penalties.append(n_penalty)
+
+                trg_bs_indices = np.asarray(sorted(trg_bs_indices))
+                for trg_a_idx in ligand_atom_mappings[mdl_a_idx]:
+                    mask = np.isin(trg_bs_indices, ref_indices[ligand_start_idx + trg_a_idx],
+                                   assume_unique=True, invert=True)
+                    added_indices = np.asarray([], dtype=np.int64)
+                    added_distances = np.asarray([], dtype=np.float64)
+                    if np.sum(mask) > 0:
+                        # compute ref distances on reference positions
+                        added_indices = trg_bs_indices[mask]
+                        tmp = scorer.positions.take(added_indices, axis=0)
+                        np.subtract(tmp, trg_ligand_pos[trg_a_idx][None, :], out=tmp)
+                        np.square(tmp, out=tmp)
+                        tmp = tmp.sum(axis=1)
+                        np.sqrt(tmp, out=tmp)  # distances against all relevant atoms
+                        added_distances = tmp
+
+                    # extract the distances towards bs atoms that are symmetric
+                    sym_mask = np.isin(added_indices, symmetric_atoms,
+                                       assume_unique=True)
+
+                    cache[(mdl_a_idx, trg_a_idx)] = (added_indices, added_distances,
+                                                     added_indices[sym_mask],
+                                                     added_distances[sym_mask])
+
+            scoring_cache.append(cache)
+            penalty_cache.append(n_penalties)
 
         # cache for model contacts towards non mapped trg chains
         # key: tuple in form (mdl_ch, trg_ch)
@@ -793,7 +941,27 @@ class LigandScorer:
         # value: list of mdl atom handles that are within self.lddt_pli_radius
         close_atom_cache = dict()
 
-        for mapping in chain_mapping._ChainMappings(chem_groups, chem_mapping):
+        ###############################################################
+        # compute lDDT for all possible chain mappings and symmetries #
+        ###############################################################
+
+        best_score = -1.0
+        best_result = None
+
+        # dummy alignment for ligand chains which is needed as input later on
+        ligand_aln = seq.CreateAlignment()
+        trg_s = seq.CreateSequence(trg_ligand_chain.name,
+                                   trg_ligand_res.GetOneLetterCode())
+        mdl_s = seq.CreateSequence(mdl_ligand_chain.name,
+                                   mdl_ligand_res.GetOneLetterCode())
+        ligand_aln.AddSequence(trg_s)
+        ligand_aln.AddSequence(mdl_s)
+        ligand_at_indices =  list(range(ligand_start_idx, scorer.n_atoms))
+
+        sym_idx_collector = [None] * scorer.n_atoms
+        sym_dist_collector = [None] * scorer.n_atoms
+
+        for mapping, s_cache, p_cache in zip(chain_mappings, scoring_cache, penalty_cache):
 
             lddt_chain_mapping = dict()
             lddt_alns = dict()
@@ -806,15 +974,16 @@ class LigandScorer:
 
             # add ligand to lddt_chain_mapping/lddt_alns
             lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
-            ligand_aln = seq.CreateAlignment()
-            trg_s = seq.CreateSequence(trg_ligand_chain.name,
-                                       trg_ligand_res.GetOneLetterCode())
-            mdl_s = seq.CreateSequence(mdl_ligand_chain.name,
-                                       mdl_ligand_res.GetOneLetterCode())
-            ligand_aln.AddSequence(trg_s)
-            ligand_aln.AddSequence(mdl_s)
             lddt_alns[mdl_ligand_chain.name] = ligand_aln
 
+            # already process model, positions will be manually hacked for each
+            # symmetry - small overhead for variables that are thrown away here
+            pos, _, _, _, _, _, lddt_symmetries = \
+            scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
+                                 residue_mapping = lddt_alns,
+                                 thresholds = self.lddt_pli_thresholds,
+                                 check_resnames = self.check_resnames)
+
             # estimate a penalty for unsatisfied model contacts from chains
             # that are not in the local trg binding site, but can be mapped in
             # the target.
@@ -848,58 +1017,55 @@ class LigandScorer:
                     if closest_ch is not None:
                         unmapped_chains.append((mdl_ch, closest_ch))
 
-            for i, (trg_sym, mdl_sym) in enumerate(symmetries):
-                # remove assert after proper testing - testing assumption made
-                # during development
-                assert(sorted(trg_sym)==list(range(len(trg_ligand_res.atoms))))
-                for a in mdl_ligand_res.atoms:
-                    mdl_editor.RenameAtom(a, "asdf")
-                for mdl_anum, trg_anum in zip(mdl_sym, trg_sym):
-                    # Rename model atoms according to symmetry
-                    trg_atom = trg_ligand_res.atoms[trg_anum]
-                    mdl_atom = mdl_ligand_res.atoms[mdl_anum]
-                    mdl_editor.RenameAtom(mdl_atom, trg_atom.name)
-
-                pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
-                res_indices, ref_res_indices, lddt_symmetries = \
-                scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
-                                     residue_mapping = lddt_alns,
-                                     thresholds = self.lddt_pli_thresholds,
-                                     check_resnames = self.check_resnames)
-                ref_indices, ref_distances = \
-                scorer._AddMdlContacts(mdl_bs, res_atom_indices,
-                                       res_atom_hashes,
-                                       scorer.ref_indices_ic,
-                                       scorer.ref_distances_ic,
-                                       False, True)
-
-                # distance hacking... remove any interchain distance except the
-                # ones with the ligand
-                ligand_start_idx = scorer.chain_start_indices[-1]
-                for at_idx in range(ligand_start_idx):
-                    mask = ref_indices[at_idx] >= ligand_start_idx
-                    ref_indices[at_idx] = ref_indices[at_idx][mask]
-                    ref_distances[at_idx] = ref_distances[at_idx][mask]
-
-                # compute lddt symmetry related indices/distances
-                sym_ref_indices, sym_ref_distances = \
-                lddt.lDDTScorer._NonSymDistances(scorer.n_atoms,
-                                                 scorer.symmetric_atoms,
-                                                 ref_indices, ref_distances)
-
+            for (trg_sym, mdl_sym) in symmetries:
+                # update positions
+
+                t0_sym = time.time()
+
+                for mdl_i, trg_i in zip(mdl_sym, trg_sym):
+                    pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]
+
+                # start new ref_indices/ref_distances from original values
+                funky_ref_indices = [np.copy(a) for a in ref_indices]
+                funky_ref_distances = [np.copy(a) for a in ref_distances]
+
+                # The only distances from the binding site towards the ligand
+                # we care about are the ones from the symmetric atoms.
+                # We collect them while updating distances from added mdl
+                # contacts
+                for idx in symmetric_atoms:
+                    sym_idx_collector[idx] = list()
+                    sym_dist_collector[idx] = list()
+
+                # and add the ones from added mdl contacts
+                added_penalty = 0
+                for mdl_i, trg_i in zip(mdl_sym, trg_sym):
+                    added_penalty += p_cache[mdl_i]
+                    cache = s_cache[mdl_i, trg_i]
+                    full_trg_i = ligand_start_idx + trg_i
+                    funky_ref_indices[full_trg_i] = np.append(funky_ref_indices[full_trg_i], cache[0])
+                    funky_ref_distances[full_trg_i] = np.append(funky_ref_distances[full_trg_i], cache[1])
+                    for idx, d in zip(cache[2], cache[3]):
+                        sym_idx_collector[idx].append(full_trg_i)
+                        sym_dist_collector[idx].append(d)
+
+                for idx in symmetric_atoms:
+                    funky_ref_indices[idx] = np.append(funky_ref_indices[idx],
+                                                       np.asarray(sym_idx_collector[idx], dtype=np.int64))
+                    funky_ref_distances[idx] = np.append(funky_ref_distances[idx],
+                                                         np.asarray(sym_dist_collector[idx], dtype=np.float64))
+
+                # we can pass funky_ref_indices/funky_ref_distances as
+                # sym_ref_indices/sym_ref_distances in
+                # scorer._ResolveSymmetries as we only have distances of the bs
+                # to the ligand and ligand atoms are "non-symmetric"
                 scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds,
                                           lddt_symmetries,
-                                          sym_ref_indices,
-                                          sym_ref_distances)
+                                          funky_ref_indices,
+                                          funky_ref_distances)
 
-                # only compute lDDT on ligand residue
-                n_exp = \
-                sum([len(ref_indices[i]) for i in range(ligand_start_idx,
-                                                        scorer.n_atoms)])
-                conserved = np.sum(scorer._EvalAtoms(pos, res_atom_indices[-1],
-                                                     self.lddt_pli_thresholds,
-                                                     ref_indices,ref_distances),
-                                   axis=0)
+                n_exp = sum([len(funky_ref_indices[i]) for i in ligand_at_indices])
+                n_exp += added_penalty
 
                 # collect number of expected contacts which can be mapped
                 if len(unmapped_chains) > 0:
@@ -910,15 +1076,20 @@ class LigandScorer:
                                                           mdl_bs,
                                                           mdl_ligand_res,
                                                           mdl_sym)
-                
+
+                conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
+                                                     self.lddt_pli_thresholds,
+                                                     funky_ref_indices,
+                                                     funky_ref_distances), axis=0)
+                print(conserved)
+                print(n_exp, added_penalty)
                 score = np.mean(conserved/n_exp)
 
                 if score > best_score:
                     best_score = score
-                    # do not yet add actual bs_ref_res_mapped and bs_mdl_res_mapped
-                    # do this at the very end...
                     best_result = {"lddt_pli": score}
 
+                print("sym time:", time.time() - t0_sym)
 
         # fill misc info to result object
         best_result["target_ligand"] = target_ligand
@@ -927,21 +1098,44 @@ class LigandScorer:
         best_result["bs_mdl_res"] = mdl_residues
         best_result["inconsistent_residues"] = list()
 
+        print("full time", time.time() - t0)
+
         return best_result
 
 
     def _compute_lddt_pli_classic(self, symmetries, target_ligand,
                                   model_ligand):
 
+        ###############################
+        # Get stuff from model/target #
+        ###############################
 
-        ##########################################################
-        # Get stuff from model/target from lazily computed cache #
-        ##########################################################
+        t0 = time.time()
 
         trg_residues, trg_bs, trg_chains, trg_ligand_chain, \
         trg_ligand_res, scorer, chem_groups = \
         self._lddt_pli_get_trg_data(target_ligand)
 
+        # distance hacking... remove any interchain distance except the ones
+        # with the ligand
+        ref_indices = scorer.ref_indices_ic
+        ref_distances = scorer.ref_distances_ic
+        ligand_start_idx = scorer.chain_start_indices[-1]
+        for at_idx in range(ligand_start_idx):
+            mask = ref_indices[at_idx] >= ligand_start_idx
+            ref_indices[at_idx] = ref_indices[at_idx][mask]
+            ref_distances[at_idx] = ref_distances[at_idx][mask]
+
+        # no matter what mapping/symmetries, the number of expected
+        # contacts stays the same
+        ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms))
+        n_exp = sum([len(ref_indices[i]) for i in ligand_at_indices])
+
+        # compute lddt symmetry related indices/distances
+        sym_ref_indices, sym_ref_distances = \
+        lddt.lDDTScorer._NonSymDistances(scorer.n_atoms, scorer.symmetric_atoms,
+                                         ref_indices, ref_distances)
+
         mdl_residues, mdl_bs, mdl_chains, mdl_editor, mdl_ligand_chain,\
         mdl_ligand_res, chem_mapping = self._lddt_pli_get_mdl_data(model_ligand)
 
@@ -961,8 +1155,9 @@ class LigandScorer:
         # ref_mdl_alns refers to full chain mapper trg and mdl structures
         # => need to adapt mdl sequence that only contain residues in contact
         #    with ligand
-        cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups, chem_mapping,
-                                                           mdl_bs)
+        cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups,
+                                                           chem_mapping,
+                                                           mdl_bs, trg_bs)
 
         ###############################################################
         # compute lDDT for all possible chain mappings and symmetries #
@@ -971,6 +1166,22 @@ class LigandScorer:
         best_score = -1.0
         best_result = None
 
+        # dummy alignment for ligand chains which is needed as input later on
+        ligand_aln = seq.CreateAlignment()
+        trg_s = seq.CreateSequence(trg_ligand_chain.name,
+                                   trg_ligand_res.GetOneLetterCode())
+        mdl_s = seq.CreateSequence(mdl_ligand_chain.name,
+                                   mdl_ligand_res.GetOneLetterCode())
+        ligand_aln.AddSequence(trg_s)
+        ligand_aln.AddSequence(mdl_s)
+
+        mdl_ligand_pos = np.zeros((model_ligand.GetAtomCount(), 3))
+        for a_idx, a in enumerate(model_ligand.atoms):
+            p = a.GetPos()
+            mdl_ligand_pos[a_idx, 0] = p[0]
+            mdl_ligand_pos[a_idx, 1] = p[1]
+            mdl_ligand_pos[a_idx, 2] = p[2]
+
         for mapping in chain_mapping._ChainMappings(chem_groups, chem_mapping):
 
             lddt_chain_mapping = dict()
@@ -984,62 +1195,34 @@ class LigandScorer:
 
             # add ligand to lddt_chain_mapping/lddt_alns
             lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
-            ligand_aln = seq.CreateAlignment()
-            trg_s = seq.CreateSequence(trg_ligand_chain.name,
-                                       trg_ligand_res.GetOneLetterCode())
-            mdl_s = seq.CreateSequence(mdl_ligand_chain.name,
-                                       mdl_ligand_res.GetOneLetterCode())
-            ligand_aln.AddSequence(trg_s)
-            ligand_aln.AddSequence(mdl_s)
             lddt_alns[mdl_ligand_chain.name] = ligand_aln
 
-            ref_indices = scorer.ref_indices_ic
-            ref_distances = scorer.ref_distances_ic
-
-            # distance hacking... remove any interchain distance except the ones
-            # with the ligand
-            ligand_start_idx = scorer.chain_start_indices[-1]
-            for at_idx in range(ligand_start_idx):
-                mask = ref_indices[at_idx] >= ligand_start_idx
-                ref_indices[at_idx] = ref_indices[at_idx][mask]
-                ref_distances[at_idx] = ref_distances[at_idx][mask]
-
-            # compute lddt symmetry related indices/distances
-            sym_ref_indices, sym_ref_distances = \
-            lddt.lDDTScorer._NonSymDistances(scorer.n_atoms, scorer.symmetric_atoms,
-                                             ref_indices, ref_distances)
-
-            for i, (trg_sym, mdl_sym) in enumerate(symmetries):
-                # remove assert after proper testing - testing assumption made during development
-                assert(sorted(trg_sym) == list(range(len(trg_ligand_res.atoms))))
-                for a in mdl_ligand_res.atoms:
-                    mdl_editor.RenameAtom(a, "asdf")
-                for mdl_anum, trg_anum in zip(mdl_sym, trg_sym):
-                    # Rename model atoms according to symmetry
-                    trg_atom = trg_ligand_res.atoms[trg_anum]
-                    mdl_atom = mdl_ligand_res.atoms[mdl_anum]
-                    mdl_editor.RenameAtom(mdl_atom, trg_atom.name)
-
-                pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
-                res_indices, ref_res_indices, lddt_symmetries = \
-                scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
-                                     residue_mapping = lddt_alns,
-                                     thresholds = self.lddt_pli_thresholds,
-                                     check_resnames = self.check_resnames)
-
-                scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds, lddt_symmetries,
-                                          sym_ref_indices, sym_ref_distances)
-
-                # only compute lDDT on ligand residue
-                n_exp = sum([len(ref_indices[i]) for i in range(ligand_start_idx, scorer.n_atoms)])
-                conserved = np.sum(scorer._EvalAtoms(pos, res_atom_indices[-1], self.lddt_pli_thresholds,
-                                                     ref_indices, ref_distances), axis=0)
-
+            # already process model, positions will be manually hacked for each
+            # symmetry - small overhead of variables that are thrown away here
+            pos, _, _, _, _, _, lddt_symmetries = \
+            scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
+                                 residue_mapping = lddt_alns,
+                                 thresholds = self.lddt_pli_thresholds,
+                                 check_resnames = self.check_resnames)
+
+            for (trg_sym, mdl_sym) in symmetries:
+                t0_sym = time.time()
+                for mdl_i, trg_i in zip(mdl_sym, trg_sym):
+                    pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]
+                scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds,
+                                          lddt_symmetries,
+                                          sym_ref_indices,
+                                          sym_ref_distances)
+                conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
+                                                     self.lddt_pli_thresholds,
+                                                     ref_indices,
+                                                     ref_distances), axis=0)
                 score = np.mean(conserved/n_exp)
 
                 if score > best_score:
                     best_score = score
                     best_result = {"lddt_pli": score}
+                print("sym time:", time.time() - t0_sym)
 
         # fill misc info to result object
         best_result["target_ligand"] = target_ligand
@@ -1048,6 +1231,8 @@ class LigandScorer:
         best_result["bs_mdl_res"] = mdl_residues
         best_result["inconsistent_residues"] = list()
 
+        print("full time:", time.time() - t0)
+
         return best_result
 
     def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains,
@@ -1175,23 +1360,23 @@ class LigandScorer:
 
             trg = self.chain_mapper.target
 
+            max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds)
+
             trg_residues = set()
             for at in target_ligand.atoms:
-                close_atoms = trg.FindWithin(at.GetPos(), self.lddt_pli_radius)
+                close_atoms = trg.FindWithin(at.GetPos(), max_r)
                 for close_at in close_atoms:
                     trg_residues.add(close_at.GetResidue())
 
-            max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds)
+            for r in trg.residues:
+                r.SetIntProp("bs", 0)
 
-            trg_chains = set()
-            for at in target_ligand.atoms:
-                close_atoms = trg.FindWithin(at.GetPos(), max_r)
-                for close_at in close_atoms:
-                    trg_chains.add(close_at.GetChain().GetName())
+            for r in trg_residues:
+                r.SetIntProp("bs", 1)
+
+            trg_bs = mol.CreateEntityFromView(trg.Select("grbs:0=1"), True)
+            trg_chains = set([ch.name for ch in trg_bs.chains])
 
-            query = "cname="
-            query += ','.join([mol.QueryQuoteName(x) for x in trg_chains])
-            trg_bs = mol.CreateEntityFromView(trg.Select(query), True)
             trg_editor = trg_bs.EditXCS(mol.BUFFERED_EDIT)
             trg_ligand_chain = None
             for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
@@ -1229,29 +1414,52 @@ class LigandScorer:
 
         return self._lddt_pli_target_data[target_ligand]
 
-    def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs):
+
+    def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs,
+                                   ref_bs):
         cut_ref_mdl_alns = dict()
         for ref_chem_group, mdl_chem_group in zip(chem_groups, chem_mapping):
             for ref_ch in ref_chem_group:
+
+                ref_bs_chain = ref_bs.FindChain(ref_ch)
+                query = "cname=" + mol.QueryQuoteName(ref_ch)
+                ref_view = self.chain_mapper.target.Select(query)
+
                 for mdl_ch in mdl_chem_group:
                     aln = self.ref_mdl_alns[(ref_ch, mdl_ch)]
+
+                    aln.AttachView(0, ref_view)
+
                     mdl_bs_chain = mdl_bs.FindChain(mdl_ch)
                     query = "cname=" + mol.QueryQuoteName(mdl_ch)
                     aln.AttachView(1, self.chain_mapping_mdl.Select(query))
+
                     cut_mdl_seq = ['-'] * aln.GetLength()
+                    cut_ref_seq = ['-'] * aln.GetLength()
                     for i, col in enumerate(aln):
+
+                       # check ref residue
+                        r = col.GetResidue(0)
+                        if r.IsValid():
+                            bs_r = ref_bs_chain.FindResidue(r.GetNumber())
+                            if bs_r.IsValid():
+                                cut_ref_seq[i] = col[1]
+
+                        # check mdl residue
                         r = col.GetResidue(1)
                         if r.IsValid():
                             bs_r = mdl_bs_chain.FindResidue(r.GetNumber())
                             if bs_r.IsValid():
                                 cut_mdl_seq[i] = col[1]
+
+                    cut_ref_seq = ''.join(cut_ref_seq)         
+                    cut_mdl_seq = ''.join(cut_mdl_seq)         
                     cut_aln = seq.CreateAlignment()
-                    cut_aln.AddSequence(aln.GetSequence(0))
-                    cut_aln.AddSequence(seq.CreateSequence(mdl_ch, ''.join(cut_mdl_seq)))
+                    cut_aln.AddSequence(seq.CreateSequence(ref_ch, cut_ref_seq))
+                    cut_aln.AddSequence(seq.CreateSequence(mdl_ch, cut_mdl_seq))
                     cut_ref_mdl_alns[(ref_ch, mdl_ch)] = cut_aln
         return cut_ref_mdl_alns
 
-
     @staticmethod
     def _find_ligand_assignment(mat1, mat2=None, coverage=None, coverage_delta=None):
         """ Find the ligand assignment based on mat1. If mat2 is provided, it
-- 
GitLab