From 0907fa3631a67ad3e86d9e37d255860bb84d20cf Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Wed, 8 May 2024 18:51:10 +0200
Subject: [PATCH] lddt-pli: cleanup

---
 modules/mol/alg/pymod/ligand_scoring.py | 741 ++++++++++++++----------
 1 file changed, 448 insertions(+), 293 deletions(-)

diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py
index c5637ca92..538920534 100644
--- a/modules/mol/alg/pymod/ligand_scoring.py
+++ b/modules/mol/alg/pymod/ligand_scoring.py
@@ -286,7 +286,8 @@ class LigandScorer:
                  binding_sites_topn=100000, global_chain_mapping=False,
                  rmsd_assignment=False, n_max_naive=12, max_symmetries=1e5,
                  custom_mapping=None, unassigned=False, full_bs_search=False,
-                 add_mdl_contacts=False):
+                 add_mdl_contacts=False,
+                 lddt_pli_thresholds = [0.5, 1.0, 2.0, 4.0]):
 
         if isinstance(model, mol.EntityView):
             self.model = mol.CreateEntityFromView(model, False)
@@ -342,6 +343,7 @@ class LigandScorer:
         self.coverage_delta = coverage_delta
         self.full_bs_search = full_bs_search
         self.add_mdl_contacts = add_mdl_contacts
+        self.lddt_pli_thresholds = lddt_pli_thresholds
 
         # scoring matrices
         self._rmsd_matrix = None
@@ -377,6 +379,10 @@ class LigandScorer:
         # value: list of repr results
         self._repr = dict()
 
+        # lazily precomputed variables to speed up lddt-pli computation
+        self._lddt_pli_target_data = dict()
+        self._lddt_pli_model_data = dict()
+
         # cache for rmsd values
         # rmsd is used as tie breaker in lddt-pli, we therefore need access to
         # the rmsd for each target_ligand/model_ligand/repr combination
@@ -721,92 +727,34 @@ class LigandScorer:
         return best_rmsd_result
 
 
-    def _compute_lddtpli(self, symmetries, target_ligand, model_ligand,
-                         thresholds = [0.5, 1.0, 2.0, 4.0]):
+    def _compute_lddtpli(self, symmetries, target_ligand, model_ligand):
 
-        # identify residues with contacts to ligands
-        trg = self.chain_mapper.target
-        mdl = self.chain_mapping_mdl
+        if self.add_mdl_contacts:
+            return self._compute_lddt_pli_add_mdl_contacts(symmetries,
+                                                           target_ligand,
+                                                           model_ligand)
+        else:
+            return self._compute_lddt_pli_classic(symmetries,
+                                                  target_ligand,
+                                                  model_ligand)
 
-        trg_residues = set()
-        for at in target_ligand.atoms:
-            close_atoms = trg.FindWithin(at.GetPos(), self.lddt_pli_radius)
-            for close_at in close_atoms:
-                trg_residues.add(close_at.GetResidue())
 
-        mdl_residues = set()
-        for at in model_ligand.atoms:
-            close_atoms = mdl.FindWithin(at.GetPos(), self.lddt_pli_radius)
-            for close_at in close_atoms:
-                mdl_residues.add(close_at.GetResidue())
+    def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand,
+                                           model_ligand):
 
-        #####################
-        # setup lDDT scorer #
-        #####################
 
-        # max dist for peptide/nucleotide atom towards ligand for which non-zero
-        # contribution is possible
-        max_r = self.lddt_pli_radius + max(thresholds)
+        ##########################################################
+        # Get stuff from model/target from lazily computed cache #
+        ##########################################################
 
-        trg_chains = set()
-        for at in target_ligand.atoms:
-            close_atoms = trg.FindWithin(at.GetPos(), max_r)
-            for close_at in close_atoms:
-                trg_chains.add(close_at.GetChain().GetName())
+        trg_residues, trg_bs, trg_chains, trg_ligand_chain, \
+        trg_ligand_res, scorer, chem_groups = \
+        self._lddt_pli_get_trg_data(target_ligand)
 
-        if len(trg_chains) == 0:
-            # It's a spaceship!
-            return {"lddt_pli": 0.0,
-                    "target_ligand": target_ligand,
-                    "model_ligand": model_ligand,
-                    "bs_ref_res": trg_residues,
-                    "bs_mdl_res": mdl_residues,
-                    "inconsisntent_residues": list()}
-
-        chem_groups = list()
-        for g in self.chain_mapper.chem_groups:
-            chem_groups.append([x for x in g if x in trg_chains])
-
-        query = "cname="
-        query += ','.join([mol.QueryQuoteName(x) for x in trg_chains])
-        trg_bs = mol.CreateEntityFromView(trg.Select(query), True)
-        trg_editor = trg_bs.EditXCS(mol.BUFFERED_EDIT)
-        trg_ligand_chain = None
-        for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
-            try:
-                # I'm pretty sure, one of these chain names is not there yet
-                trg_ligand_chain = trg_editor.InsertChain(cname)
-                break
-            except:
-                pass
-        if trg_ligand_chain is None:
-            raise RuntimeError("Fuck this, I'm out...")
-
-        trg_ligand_res = trg_editor.AppendResidue(trg_ligand_chain, target_ligand,
-                                                  deep=True)
-        compound_name = trg_ligand_res.name
-        compound = lddt.CustomCompound.FromResidue(trg_ligand_res)
-        custom_compounds = {compound_name: compound}
-
-        scorer = lddt.lDDTScorer(trg_bs,
-                                 custom_compounds = custom_compounds,
-                                 inclusion_radius = self.lddt_pli_radius)
-
-        ###############
-        # setup model #
-        ###############
-        for r in mdl.residues:
-            r.SetIntProp("bs", 0)
-        for at in model_ligand.atoms:
-            close_atoms = mdl.FindWithin(at.GetPos(), max_r)
-            for close_at in close_atoms:
-                close_at.GetResidue().SetIntProp("bs", 1)
-
-
-        mdl_bs = mol.CreateEntityFromView(mdl.Select("grbs:0=1"), True)
-        mdl_chains = set([ch.name for ch in mdl_bs.chains])
-
-        if len(mdl_chains) == 0:
+        mdl_residues, mdl_bs, mdl_chains, mdl_editor, mdl_ligand_chain,\
+        mdl_ligand_res, chem_mapping = self._lddt_pli_get_mdl_data(model_ligand)
+
+        if len(mdl_chains) == 0 or len(trg_chains) == 0:
             # It's a spaceship!
             return {"lddt_pli": 0.0,
                     "target_ligand": target_ligand,
@@ -815,52 +763,21 @@ class LigandScorer:
                     "bs_mdl_res": mdl_residues,
                     "inconsistent_residues": list()}
 
-        mdl_editor = mdl_bs.EditXCS(mol.BUFFERED_EDIT)
-        mdl_ligand_chain = None
-        for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
-            try:
-                # I'm pretty sure, one of these chain names is not there yet
-                mdl_ligand_chain = mdl_editor.InsertChain(cname)
-                break
-            except:
-                pass
-        if mdl_ligand_chain is None:
-            raise RuntimeError("Fuck this, I'm out...")
-        mdl_ligand_res = mdl_editor.AppendResidue(mdl_ligand_chain, model_ligand,
-                                                  deep=True)
-
         ####################
         # Setup alignments #
         ####################
-        chem_mapping = list()
-        for m in self.chem_mapping:
-            chem_mapping.append([x for x in m if x in mdl_chains])
 
         # ref_mdl_alns refers to full chain mapper trg and mdl structures
         # => need to adapt mdl sequence that only contain residues in contact
         #    with ligand
-        cut_ref_mdl_alns = dict()
-        for ref_chem_group, mdl_chem_group in zip(chem_groups, chem_mapping):
-            for ref_ch in ref_chem_group:
-                for mdl_ch in mdl_chem_group:
-                    aln = self.ref_mdl_alns[(ref_ch, mdl_ch)]
-                    mdl_bs_chain = mdl_bs.FindChain(mdl_ch)
-                    aln.AttachView(1, mdl.Select("cname=" + mol.QueryQuoteName(mdl_ch)))
-                    cut_mdl_seq = ['-'] * aln.GetLength()
-                    for i, col in enumerate(aln):
-                        r = col.GetResidue(1)
-                        if r.IsValid():
-                            bs_r = mdl_bs_chain.FindResidue(r.GetNumber())
-                            if bs_r.IsValid():
-                                cut_mdl_seq[i] = col[1]
-                    cut_aln = seq.CreateAlignment()
-                    cut_aln.AddSequence(aln.GetSequence(0))
-                    cut_aln.AddSequence(seq.CreateSequence(mdl_ch, ''.join(cut_mdl_seq)))
-                    cut_ref_mdl_alns[(ref_ch, mdl_ch)] = cut_aln
+        cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups,
+                                                           chem_mapping,
+                                                           mdl_bs)
 
         ###############################################################
         # compute lDDT for all possible chain mappings and symmetries #
         ###############################################################
+
         best_score = -1.0
         best_result = None
 
@@ -898,152 +815,66 @@ class LigandScorer:
             ligand_aln.AddSequence(mdl_s)
             lddt_alns[mdl_ligand_chain.name] = ligand_aln
 
-            if self.add_mdl_contacts:
-
-                # estimate a penalty for unsatisfied model contacts from chains
-                # that are not in the local trg binding site, but can be mapped in
-                # the target.
-                # We're using the trg chain with the closest geometric center that
-                # can be mapped to the mdl chain according the chem mapping.
-                # An alternative would be to search for the target chain with
-                # the minimal number of additional contacts.
-                # There is not good solution for this problem...
-                unmapped_chains = list()
-                for mdl_ch in mdl_chains:
-                    if mdl_ch not in lddt_chain_mapping:
-                        # check which chain in trg is closest
-                        chem_group_idx = None
-                        for i, m in self.chem_mapping:
-                            if mdl_ch in m:
-                                chem_group_idx = i
-                                break
-                        if chem_group_idx is None:
-                            raise RuntimeError("This should never happen... "
-                                               "ask Gabriel...")
-                        mdl_center = mdl.FindChain(mdl_ch).geometric_center
-                        closest_trg_ch = None
-                        closest_trg_ch_dist = None
-                        for trg_ch in self.chem_groups[chem_group_idx]:
-                            if trg_ch not in lddt_mapping.values():
-                                c = self.target.FindChain(trg_ch).geometric_center
-                                d = geom.Distance(mdl_center, c)
-                                if closest_trg_ch_dist is None or d < closest_trg_ch_dist:
-                                    closest_trg_ch_dist = d
-                                    closest_trg_ch = trg_ch
-                        if closest_trg_ch is not None:
-                            unmapped_chains.append((mdl_ch, closest_trg_ch))
-
-                for i, (trg_sym, mdl_sym) in enumerate(symmetries):
-                    # remove assert after proper testing - testing assumption made during development
-                    assert(sorted(trg_sym) == list(range(len(trg_ligand_res.atoms))))
-                    for a in mdl_ligand_res.atoms:
-                        mdl_editor.RenameAtom(a, "asdf")
-                    for mdl_anum, trg_anum in zip(mdl_sym, trg_sym):
-                        # Rename model atoms according to symmetry
-                        trg_atom = trg_ligand_res.atoms[trg_anum]
-                        mdl_atom = mdl_ligand_res.atoms[mdl_anum]
-                        mdl_editor.RenameAtom(mdl_atom, trg_atom.name)
-
-                    pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
-                    res_indices, ref_res_indices, lddt_symmetries = \
-                    scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
-                                         residue_mapping = lddt_alns,
-                                         thresholds = thresholds,
-                                         check_resnames = self.check_resnames)
-                    ref_indices, ref_distances = \
-                    scorer._AddMdlContacts(mdl_bs, res_atom_indices, res_atom_hashes,
-                                           scorer.ref_indices_ic, scorer.ref_distances_ic,
-                                           False, True)
-
-                    # distance hacking... remove any interchain distance except the ones
-                    # with the ligand
-                    ligand_start_idx = scorer.chain_start_indices[-1]
-                    for at_idx in range(ligand_start_idx):
-                        mask = ref_indices[at_idx] >= ligand_start_idx
-                        ref_indices[at_idx] = ref_indices[at_idx][mask]
-                        ref_distances[at_idx] = ref_distances[at_idx][mask]
-
-                    # compute lddt symmetry related indices/distances
-                    sym_ref_indices, sym_ref_distances = \
-                    lddt.lDDTScorer._NonSymDistances(scorer.n_atoms, scorer.symmetric_atoms,
-                                                     ref_indices, ref_distances)
-
-                    scorer._ResolveSymmetries(pos, thresholds, lddt_symmetries,
-                                              sym_ref_indices, sym_ref_distances)
-
-                    # only compute lDDT on ligand residue
-                    n_exp = sum([len(ref_indices[i]) for i in range(ligand_start_idx, scorer.n_atoms)])
-                    conserved = np.sum(scorer._EvalAtoms(pos, res_atom_indices[-1], thresholds,
-                                                         ref_indices, ref_distances), axis=0)
-
-                    # collect number of expected contacts which can be mapped
-                    if len(unmapped_chains) > 0:
-                        for ch_tuple in unmapped_chains:
-                            if ch_tuple not in non_mapped_cache:
-
-                                # identify each atom in given mdl chain from mdl_bs
-                                # which can be mapped to a trg atom in given trg
-                                # chain
-                                mappable_atoms = set()
-                                aln = self.ref_mdl_alns[(ch_tuple[1], ch_tuple[0])]
-                                mdl_bs_chain = mdl_bs.FindChain(ch_tuple[0])
-                                aln.AttachView(0, trg.Select("cname=" + mol.QueryQuoteName(ch_tuple[1])))
-                                aln.AttachView(1, mdl.Select("cname=" + mol.QueryQuoteName(ch_tuple[0])))
-                                for i, col in enumerate(aln):
-                                    r = col.GetResidue(1)
-                                    if r.IsValid():
-                                        bs_r = mdl_bs_chain.FindResidue(r.GetNumber())
-                                        if bs_r.IsValid():
-                                            trg_r = col.GetResidue(0)
-                                            if trg_r.IsValid():
-                                                for bs_a in bs_r.atoms:
-                                                    trg_a = trg_r.FindAtom(bs_a.GetName())
-                                                    if trg_a.IsValid():
-                                                        mappable_atoms.add(bs_a.hash_code)
-
-                                # for each ligand atom, we count the number of mappable
-                                # atoms
-                                counts = dict()
-                                for lig_a in mdl_ligand_res.atoms:
-                                    close_atoms = None
-                                    if lig_a.hash_code not in close_atom_cache:
-                                        tmp = mdl_bs.FindWithin(lig_a.GetPos(), self.lddt_pli_radius)
-                                        lig_hash = mdl_ligand_res.hash_code
-                                        close_atoms = [x for x in tmp if x.GetResidue().GetHashCode() != lig_hash]
-                                        close_atom_cache[lig_a.hash_code] = close_atoms
-                                    else:
-                                        close_atoms = close_atom_cache[lig_a.hash_code]
-
-                                    N = 0
-                                    for close_a in close_atoms:
-                                        if close_a.hash_code in mappable_atoms:
-                                            N += 1
-
-                                    counts[lig_a.hash_code] = N
-
-                                # fill cache
-                                non_mapped_cache[ch_tuple] = counts
-
-                            # add number of mdl contacts which can be mapped to target
-                            # as non-fulfilled contacts
-                            counts = non_mapped_cache[ch_tuple]
-                            for i in mdl_sym:
-                                n_exp += counts[mdl_ligand_res.atoms[i].hash_code]
-                
-                    score = np.mean(conserved/n_exp)
-
-                    if score > best_score:
-                        best_score = score
-                        # do not yet add actual bs_ref_res_mapped and bs_mdl_res_mapped
-                        # do this at the very end...
-                        best_result = {"lddt_pli": score}
-
-            else:
-                ref_indices = scorer.ref_indices_ic
-                ref_distances = scorer.ref_distances_ic
-
-                # distance hacking... remove any interchain distance except the ones
-                # with the ligand
+            # estimate a penalty for unsatisfied model contacts from chains
+            # that are not in the local trg binding site, but can be mapped in
+            # the target.
+            # We're using the trg chain with the closest geometric center that
+            # can be mapped to the mdl chain according to the chem mapping.
+            # An alternative would be to search for the target chain with
+            # the minimal number of additional contacts.
+            # There is no good solution for this problem...
+            unmapped_chains = list()
+            for mdl_ch in mdl_chains:
+                if mdl_ch not in lddt_chain_mapping:
+                    # check which chain in trg is closest
+                    chem_group_idx = None
+                    for i, m in self.chem_mapping:
+                        if mdl_ch in m:
+                            chem_group_idx = i
+                            break
+                    if chem_group_idx is None:
+                        raise RuntimeError("This should never happen... "
+                                           "ask Gabriel...")
+                    mdl_center = mdl.FindChain(mdl_ch).geometric_center
+                    closest_ch = None
+                    closest_dist = None
+                    for trg_ch in self.chem_groups[chem_group_idx]:
+                        if trg_ch not in lddt_mapping.values():
+                            c = self.target.FindChain(trg_ch).geometric_center
+                            d = geom.Distance(mdl_center, c)
+                            if closest_dist is None or d < closest_dist:
+                                closest_dist = d
+                                closest_ch = trg_ch
+                    if closest_ch is not None:
+                        unmapped_chains.append((mdl_ch, closest_ch))
+
+            for i, (trg_sym, mdl_sym) in enumerate(symmetries):
+                # remove assert after proper testing - testing assumption made
+                # during development
+                assert(sorted(trg_sym)==list(range(len(trg_ligand_res.atoms))))
+                for a in mdl_ligand_res.atoms:
+                    mdl_editor.RenameAtom(a, "asdf")
+                for mdl_anum, trg_anum in zip(mdl_sym, trg_sym):
+                    # Rename model atoms according to symmetry
+                    trg_atom = trg_ligand_res.atoms[trg_anum]
+                    mdl_atom = mdl_ligand_res.atoms[mdl_anum]
+                    mdl_editor.RenameAtom(mdl_atom, trg_atom.name)
+
+                pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
+                res_indices, ref_res_indices, lddt_symmetries = \
+                scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
+                                     residue_mapping = lddt_alns,
+                                     thresholds = self.lddt_pli_thresholds,
+                                     check_resnames = self.check_resnames)
+                ref_indices, ref_distances = \
+                scorer._AddMdlContacts(mdl_bs, res_atom_indices,
+                                       res_atom_hashes,
+                                       scorer.ref_indices_ic,
+                                       scorer.ref_distances_ic,
+                                       False, True)
+
+                # distance hacking... remove any interchain distance except the
+                # ones with the ligand
                 ligand_start_idx = scorer.chain_start_indices[-1]
                 for at_idx in range(ligand_start_idx):
                     mask = ref_indices[at_idx] >= ligand_start_idx
@@ -1052,40 +883,42 @@ class LigandScorer:
 
                 # compute lddt symmetry related indices/distances
                 sym_ref_indices, sym_ref_distances = \
-                lddt.lDDTScorer._NonSymDistances(scorer.n_atoms, scorer.symmetric_atoms,
+                lddt.lDDTScorer._NonSymDistances(scorer.n_atoms,
+                                                 scorer.symmetric_atoms,
                                                  ref_indices, ref_distances)
 
-                for i, (trg_sym, mdl_sym) in enumerate(symmetries):
-                    # remove assert after proper testing - testing assumption made during development
-                    assert(sorted(trg_sym) == list(range(len(trg_ligand_res.atoms))))
-                    for a in mdl_ligand_res.atoms:
-                        mdl_editor.RenameAtom(a, "asdf")
-                    for mdl_anum, trg_anum in zip(mdl_sym, trg_sym):
-                        # Rename model atoms according to symmetry
-                        trg_atom = trg_ligand_res.atoms[trg_anum]
-                        mdl_atom = mdl_ligand_res.atoms[mdl_anum]
-                        mdl_editor.RenameAtom(mdl_atom, trg_atom.name)
-
-                    pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
-                    res_indices, ref_res_indices, lddt_symmetries = \
-                    scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
-                                         residue_mapping = lddt_alns,
-                                         thresholds = thresholds,
-                                         check_resnames = self.check_resnames)
-
-                    scorer._ResolveSymmetries(pos, thresholds, lddt_symmetries,
-                                              sym_ref_indices, sym_ref_distances)
-
-                    # only compute lDDT on ligand residue
-                    n_exp = sum([len(ref_indices[i]) for i in range(ligand_start_idx, scorer.n_atoms)])
-                    conserved = np.sum(scorer._EvalAtoms(pos, res_atom_indices[-1], thresholds,
-                                                         ref_indices, ref_distances), axis=0)
-
-                    score = np.mean(conserved/n_exp)
-
-                    if score > best_score:
-                        best_score = score
-                        best_result = {"lddt_pli": score}
+                scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds,
+                                          lddt_symmetries,
+                                          sym_ref_indices,
+                                          sym_ref_distances)
+
+                # only compute lDDT on ligand residue
+                n_exp = \
+                sum([len(ref_indices[i]) for i in range(ligand_start_idx,
+                                                        scorer.n_atoms)])
+                conserved = np.sum(scorer._EvalAtoms(pos, res_atom_indices[-1],
+                                                     self.lddt_pli_thresholds,
+                                                     ref_indices,ref_distances),
+                                   axis=0)
+
+                # collect number of expected contacts which can be mapped
+                if len(unmapped_chains) > 0:
+                    n_exp += \
+                    self._lddt_pli_unmapped_chain_penalty(unmapped_chains,
+                                                          non_mapped_cache,
+                                                          close_atom_cache,
+                                                          mdl_bs,
+                                                          mdl_ligand_res,
+                                                          mdl_sym)
+                
+                score = np.mean(conserved/n_exp)
+
+                if score > best_score:
+                    best_score = score
+                    # do not yet add actual bs_ref_res_mapped and bs_mdl_res_mapped
+                    # do this at the very end...
+                    best_result = {"lddt_pli": score}
+
 
         # fill misc info to result object
         best_result["target_ligand"] = target_ligand
@@ -1097,6 +930,328 @@ class LigandScorer:
         return best_result
 
 
+    def _compute_lddt_pli_classic(self, symmetries, target_ligand,
+                                  model_ligand):
+        """ Compute lDDT-PLI for one target/model ligand pair.
+
+        Iterates all chain mappings and all ligand *symmetries* and keeps the
+        best score. Returns a dict with key "lddt_pli" plus the two ligands,
+        the binding site residues and an empty "inconsistent_residues" list.
+        """
+        trg_residues, trg_bs, trg_chains, trg_ligand_chain, \
+        trg_ligand_res, scorer, chem_groups = \
+        self._lddt_pli_get_trg_data(target_ligand)
+
+        mdl_residues, mdl_bs, mdl_chains, mdl_editor, mdl_ligand_chain,\
+        mdl_ligand_res, chem_mapping = self._lddt_pli_get_mdl_data(model_ligand)
+
+        if len(mdl_chains) == 0 or len(trg_chains) == 0:
+            # no chains in the model or target binding site: nothing to map
+            return {"lddt_pli": 0.0,
+                    "target_ligand": target_ligand,
+                    "model_ligand": model_ligand,
+                    "bs_ref_res": trg_residues,
+                    "bs_mdl_res": mdl_residues,
+                    "inconsistent_residues": list()}
+
+        ####################
+        # Setup alignments #
+        ####################
+
+        # ref_mdl_alns refers to full chain mapper trg and mdl structures
+        # => need to adapt mdl sequence that only contain residues in contact
+        #    with ligand
+        cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups, chem_mapping,
+                                                           mdl_bs)
+
+        ###############################################################
+        # compute lDDT for all possible chain mappings and symmetries #
+        ###############################################################
+
+        best_score = -1.0
+        best_result = None
+
+        for mapping in chain_mapping._ChainMappings(chem_groups, chem_mapping):
+
+            lddt_chain_mapping = dict()
+            lddt_alns = dict()
+            for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
+                for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
+                    # some mdl chains can be None
+                    if mdl_ch is not None:
+                        lddt_chain_mapping[mdl_ch] = ref_ch
+                        lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]
+
+            # add ligand to lddt_chain_mapping/lddt_alns
+            lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
+            ligand_aln = seq.CreateAlignment()
+            trg_s = seq.CreateSequence(trg_ligand_chain.name,
+                                       trg_ligand_res.GetOneLetterCode())
+            mdl_s = seq.CreateSequence(mdl_ligand_chain.name,
+                                       mdl_ligand_res.GetOneLetterCode())
+            ligand_aln.AddSequence(trg_s)
+            ligand_aln.AddSequence(mdl_s)
+            lddt_alns[mdl_ligand_chain.name] = ligand_aln
+
+            ref_indices = scorer.ref_indices_ic
+            ref_distances = scorer.ref_distances_ic
+
+            # distance hacking... remove any interchain distance except the ones
+            # with the ligand
+            ligand_start_idx = scorer.chain_start_indices[-1]
+            for at_idx in range(ligand_start_idx):
+                mask = ref_indices[at_idx] >= ligand_start_idx
+                ref_indices[at_idx] = ref_indices[at_idx][mask]
+                ref_distances[at_idx] = ref_distances[at_idx][mask]
+
+            # compute lddt symmetry related indices/distances
+            sym_ref_indices, sym_ref_distances = \
+            lddt.lDDTScorer._NonSymDistances(scorer.n_atoms, scorer.symmetric_atoms,
+                                             ref_indices, ref_distances)
+
+            for trg_sym, mdl_sym in symmetries:
+                # remove assert after proper testing - testing assumption made during development
+                assert(sorted(trg_sym) == list(range(len(trg_ligand_res.atoms))))
+                for a in mdl_ligand_res.atoms:
+                    mdl_editor.RenameAtom(a, "asdf")
+                for mdl_anum, trg_anum in zip(mdl_sym, trg_sym):
+                    # Rename model atoms according to symmetry
+                    trg_atom = trg_ligand_res.atoms[trg_anum]
+                    mdl_atom = mdl_ligand_res.atoms[mdl_anum]
+                    mdl_editor.RenameAtom(mdl_atom, trg_atom.name)
+
+                pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
+                res_indices, ref_res_indices, lddt_symmetries = \
+                scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
+                                     residue_mapping = lddt_alns,
+                                     thresholds = self.lddt_pli_thresholds,
+                                     check_resnames = self.check_resnames)
+
+                scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds, lddt_symmetries,
+                                          sym_ref_indices, sym_ref_distances)
+
+                # only compute lDDT on ligand residue
+                n_exp = sum([len(ref_indices[i]) for i in range(ligand_start_idx, scorer.n_atoms)])
+                conserved = np.sum(scorer._EvalAtoms(pos, res_atom_indices[-1], self.lddt_pli_thresholds,
+                                                     ref_indices, ref_distances), axis=0)
+                # no expected contacts: score 0.0 instead of nan from 0/0
+                score = np.mean(conserved/n_exp) if n_exp > 0 else 0.0
+
+                if score > best_score:
+                    best_score = score
+                    best_result = {"lddt_pli": score}
+
+        # fill misc info to result object
+        best_result["target_ligand"] = target_ligand
+        best_result["model_ligand"] = model_ligand
+        best_result["bs_ref_res"] = trg_residues
+        best_result["bs_mdl_res"] = mdl_residues
+        best_result["inconsistent_residues"] = list()
+
+        return best_result
+
+    def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains,
+                                         non_mapped_cache, close_atom_cache,
+                                         mdl_bs, mdl_ligand_res, mdl_sym):
+        """ Count ligand contacts to unmapped but mappable model chains.
+
+        *unmapped_chains* holds (mdl_cname, trg_cname) tuples. Returns the
+        summed count for ligand atoms in *mdl_sym*; both caches get filled."""
+        n_exp = 0
+        for ch_tuple in unmapped_chains:
+            if ch_tuple not in non_mapped_cache:
+
+                # identify each atom in given mdl chain from mdl_bs which can be
+                # mapped to a trg atom in given trg chain
+                mappable_atoms = set()
+                aln = self.ref_mdl_alns[(ch_tuple[1], ch_tuple[0])]
+                mdl_bs_chain = mdl_bs.FindChain(ch_tuple[0])
+                trg_query = "cname=" + mol.QueryQuoteName(ch_tuple[1])
+                trg_view = self.chain_mapper.target.Select(trg_query)
+                aln.AttachView(0, trg_view)
+                mdl_query = "cname=" + mol.QueryQuoteName(ch_tuple[0])
+                mdl_view = self.chain_mapping_mdl.Select(mdl_query)
+                aln.AttachView(1, mdl_view)
+                for col in aln:
+                    r = col.GetResidue(1)
+                    if r.IsValid():
+                        bs_r = mdl_bs_chain.FindResidue(r.GetNumber())
+                        if bs_r.IsValid():
+                            trg_r = col.GetResidue(0)
+                            if trg_r.IsValid():
+                                for a in bs_r.atoms:
+                                    trg_a = trg_r.FindAtom(a.GetName())
+                                    if trg_a.IsValid():
+                                        mappable_atoms.add(a.handle.hash_code)
+
+                # for each ligand atom, we count the number of mappable atoms
+                # within lddt_pli_radius
+                counts = dict()
+                for lig_a in mdl_ligand_res.atoms:
+                    # atoms close to lig_a (ligand excluded), computed once
+                    if lig_a.hash_code not in close_atom_cache:
+                        tmp = mdl_bs.FindWithin(lig_a.GetPos(),
+                                                self.lddt_pli_radius)
+                        h = mdl_ligand_res.hash_code
+                        tmp = [x for x in tmp if x.GetResidue().hash_code != h]
+                        close_atom_cache[lig_a.hash_code] = tmp
+                    # read from cache, also right after filling it above
+                    close_atoms = close_atom_cache[lig_a.hash_code]
+
+                    N = 0
+                    for close_a in close_atoms:
+                        if close_a.handle.hash_code in mappable_atoms:
+                            N += 1
+
+                    counts[lig_a.hash_code] = N
+
+                # fill cache
+                non_mapped_cache[ch_tuple] = counts
+
+            # add number of mdl contacts which can be mapped to target
+            # as non-fulfilled contacts
+            counts = non_mapped_cache[ch_tuple]
+            mdl_ligand_res_atoms = mdl_ligand_res.atoms
+            for i in mdl_sym:
+                n_exp += counts[mdl_ligand_res_atoms[i].hash_code]
+
+        return n_exp
+
+
+    def _lddt_pli_get_mdl_data(self, model_ligand):
+        """ Lazily build and cache lDDT-PLI data for *model_ligand*.
+
+        :returns: tuple (mdl_residues, mdl_bs, mdl_chains, mdl_editor,
+                  mdl_ligand_chain, mdl_ligand_res, chem_mapping)
+        """
+        if model_ligand not in self._lddt_pli_model_data:
+
+            mdl = self.chain_mapping_mdl
+
+            mdl_residues = set()
+            for at in model_ligand.atoms:
+                close_atoms = mdl.FindWithin(at.GetPos(), self.lddt_pli_radius)
+                for close_at in close_atoms:
+                    mdl_residues.add(close_at.GetResidue())
+
+            max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds)
+            for r in mdl.residues:
+                r.SetIntProp("bs", 0)
+            for at in model_ligand.atoms:
+                close_atoms = mdl.FindWithin(at.GetPos(), max_r)
+                for close_at in close_atoms:
+                    close_at.GetResidue().SetIntProp("bs", 1)
+
+            mdl_bs = mol.CreateEntityFromView(mdl.Select("grbs:0=1"), True)
+            mdl_chains = {ch.name for ch in mdl_bs.chains}
+
+            mdl_editor = mdl_bs.EditXCS(mol.BUFFERED_EDIT)
+            mdl_ligand_chain = None
+            for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
+                try:
+                    # presumably fails if the name is already taken => try next
+                    mdl_ligand_chain = mdl_editor.InsertChain(cname)
+                    break
+                except Exception:
+                    pass
+            if mdl_ligand_chain is None:
+                raise RuntimeError("Could not add ligand chain to model")
+            mdl_ligand_res = mdl_editor.AppendResidue(mdl_ligand_chain,
+                                                      model_ligand,
+                                                      deep=True)
+
+            chem_mapping = [[x for x in m if x in mdl_chains]
+                            for m in self.chem_mapping]
+
+            self._lddt_pli_model_data[model_ligand] = \
+            (mdl_residues, mdl_bs, mdl_chains, mdl_editor,
+             mdl_ligand_chain, mdl_ligand_res, chem_mapping)
+
+        return self._lddt_pli_model_data[model_ligand]
+
+
+    def _lddt_pli_get_trg_data(self, target_ligand):
+        """ Lazily build and cache lDDT-PLI data for *target_ligand*.
+
+        :returns: tuple (trg_residues, trg_bs, trg_chains, trg_ligand_chain,
+                  trg_ligand_res, scorer, chem_groups)
+        """
+        if target_ligand not in self._lddt_pli_target_data:
+
+            trg = self.chain_mapper.target
+
+            trg_residues = set()
+            for at in target_ligand.atoms:
+                close_atoms = trg.FindWithin(at.GetPos(), self.lddt_pli_radius)
+                for close_at in close_atoms:
+                    trg_residues.add(close_at.GetResidue())
+
+            max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds)
+
+            trg_chains = set()
+            for at in target_ligand.atoms:
+                close_atoms = trg.FindWithin(at.GetPos(), max_r)
+                for close_at in close_atoms:
+                    trg_chains.add(close_at.GetChain().GetName())
+
+            query = "cname="
+            query += ','.join([mol.QueryQuoteName(x) for x in trg_chains])
+            trg_bs = mol.CreateEntityFromView(trg.Select(query), True)
+            trg_editor = trg_bs.EditXCS(mol.BUFFERED_EDIT)
+            trg_ligand_chain = None
+            for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
+                try:
+                    # presumably fails if the name is already taken => try next
+                    trg_ligand_chain = trg_editor.InsertChain(cname)
+                    break
+                except Exception:
+                    pass
+            if trg_ligand_chain is None:
+                raise RuntimeError("Could not add ligand chain to target")
+
+            trg_ligand_res = trg_editor.AppendResidue(trg_ligand_chain,
+                                                      target_ligand,
+                                                      deep=True)
+            compound_name = trg_ligand_res.name
+            compound = lddt.CustomCompound.FromResidue(trg_ligand_res)
+            custom_compounds = {compound_name: compound}
+
+            scorer = lddt.lDDTScorer(trg_bs,
+                                     custom_compounds = custom_compounds,
+                                     inclusion_radius = self.lddt_pli_radius)
+
+            chem_groups = [[x for x in g if x in trg_chains]
+                           for g in self.chain_mapper.chem_groups]
+
+            self._lddt_pli_target_data[target_ligand] = \
+            (trg_residues, trg_bs, trg_chains, trg_ligand_chain,
+             trg_ligand_res, scorer, chem_groups)
+
+        return self._lddt_pli_target_data[target_ligand]
+
+    def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs):
+        """ Return ref/mdl alignments keyed by (ref_ch, mdl_ch) in which model
+        residues that are not part of *mdl_bs* are replaced by gaps. """
+        gapped_alns = {}
+        for trg_group, mdl_group in zip(chem_groups, chem_mapping):
+            for ref_ch in trg_group:
+                for mdl_ch in mdl_group:
+                    full_aln = self.ref_mdl_alns[(ref_ch, mdl_ch)]
+                    bs_ch = mdl_bs.FindChain(mdl_ch)
+                    mdl_sel = "cname=" + mol.QueryQuoteName(mdl_ch)
+                    full_aln.AttachView(1, self.chain_mapping_mdl.Select(mdl_sel))
+                    masked = ['-'] * full_aln.GetLength()
+                    for idx, column in enumerate(full_aln):
+                        res = column.GetResidue(1)
+                        if res.IsValid() and bs_ch.FindResidue(res.GetNumber()).IsValid():
+                            masked[idx] = column[1]
+                    trimmed = seq.CreateAlignment()
+                    trimmed.AddSequence(full_aln.GetSequence(0))
+                    trimmed.AddSequence(seq.CreateSequence(mdl_ch, ''.join(masked)))
+                    gapped_alns[(ref_ch, mdl_ch)] = trimmed
+        return gapped_alns
+
+
     @staticmethod
     def _find_ligand_assignment(mat1, mat2=None, coverage=None, coverage_delta=None):
         """ Find the ligand assignment based on mat1. If mat2 is provided, it
-- 
GitLab