From 014ebc2e0cd4fa4c7f7318f5d8f568d84dc5f005 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Tue, 25 Jun 2024 22:56:22 +0200
Subject: [PATCH] chain mapping: Use lDDT for mappings involving nucleotides

- introduces an optimized backbone lDDT score. The QS-score used as the
  default target function for chain mappings is protein-specific; as soon as
  nucleotides are involved, lDDT with an increased inclusion radius of 30A is
  now used in ChainMapper.GetMapping.

It was a bit embarrassing that lDDT mappings were about an order of magnitude
slower than QS-score mappings. A specialized backbone-only lDDT has therefore
been introduced that uses matrix operations. This is not a replacement for the
lDDTScorer, but it doesn't need to produce per-residue scores and doesn't need
to deal with symmetries etc.
---
 actions/ost-compare-ligand-structures  | 137 +++---
 modules/mol/alg/pymod/CMakeLists.txt   |   1 +
 modules/mol/alg/pymod/bb_lddt.py       | 550 +++++++++++++++++++++++++
 modules/mol/alg/pymod/chain_mapping.py | 220 ++--------
 modules/mol/alg/tests/CMakeLists.txt   |   1 +
 modules/mol/alg/tests/test_bblddt.py   | 117 ++++++
 6 files changed, 795 insertions(+), 231 deletions(-)
 create mode 100644 modules/mol/alg/pymod/bb_lddt.py
 create mode 100644 modules/mol/alg/tests/test_bblddt.py

diff --git a/actions/ost-compare-ligand-structures b/actions/ost-compare-ligand-structures
index 7c5fc8776..a8491fea3 100644
--- a/actions/ost-compare-ligand-structures
+++ b/actions/ost-compare-ligand-structures
@@ -47,10 +47,23 @@ options, this is a dictionary with three keys:
 content of the JSON output will be \"status\" set to FAILURE and an
 additional key: "traceback".
 
-Each score is opt-in and must be enabled with optional arguments. The scores
-perform a model/reference ligand assignment and report a score for each assigned
-model ligand. Optionally, unassigned model ligands are reported with a null
-score and a reason why no assignment has been performed (--unassigned/-u).
+Each score is opt-in and the respective results are available in two keys:
+
+ * "model_ligands": Model-ligand-centric scoring based on the model/reference
+   ligand assignment. A score including metadata is reported for each assigned
+   model ligand given the assigned target ligand. Unassigned model ligands are
+   reported with a null score and a reason why no assignment has been performed.
+
+ * "full": The full all vs. all scoring results. Yet another dictionary with
+   keys:
+
+   * "assignment": List of pairs in form (ref_lig_idx, mdl_lig_idx) specifying
+     the ligands in "reference_ligands"/"model_ligands".
+   * "scores": A dictionary with key in form (ref_lig_idx, mdl_lig_idx) and
+     value yet another dict with score information for each possible
+     reference/model ligand pair. The respective score is None if no score
+     could be computed; this can be as simple as a mismatch between the two
+     ligands. This or other reasons are reported.
""" import argparse @@ -191,16 +204,6 @@ def _ParseArgs(): default=0.2, help=("Coverage delta for partial ligand assignment.")) - parser.add_argument( - "-u", - "--unassigned", - dest="unassigned", - default=False, - action="store_true", - help=("Report unassigned model ligands in the output together with " - "assigned ligands, with a null score, and reason for not being " - "assigned.")) - parser.add_argument( '-v', '--verbosity', @@ -511,31 +514,48 @@ def _Process(model, model_ligands, reference, reference_ligands, args): if args.lddt_pli: out["lddt_pli"] = dict() + out["lddt_pli"]["model_ligands"] = dict() + out["lddt_pli"]["full"] = dict() for lig_pair in lddtpli_scorer.assignment: score = float(lddtpli_scorer.score_matrix[lig_pair[0], lig_pair[1]]) coverage = float(lddtpli_scorer.coverage_matrix[lig_pair[0], lig_pair[1]]) aux_data = lddtpli_scorer.aux_matrix[lig_pair[0], lig_pair[1]] target_key = out["reference_ligands"][lig_pair[0]] model_key = out["model_ligands"][lig_pair[1]] - out["lddt_pli"][model_key] = {"lddt_pli": score, - "coverage": coverage, - "lddt_pli_n_contacts": aux_data["lddt_pli_n_contacts"], - "model_ligand": model_key, - "reference_ligand": target_key, - "bs_ref_res": [_QualifiedResidueNotation(r) for r in - aux_data["bs_ref_res"]], - "bs_mdl_res": [_QualifiedResidueNotation(r) for r in - aux_data["bs_mdl_res"]]} - if args.unassigned: - for i in lddtpli_scorer.unassigned_model_ligands: - model_key = out["model_ligands"][i] - reason = lddtpli_scorer.guess_model_ligand_unassigned_reason(i) - out["lddt_pli"][model_key] = {"lddt_pli": None, - "unassigned_reason": reason} + out["lddt_pli"]["model_ligands"][model_key] = {"lddt_pli": score, + "coverage": coverage, + "lddt_pli_n_contacts": aux_data["lddt_pli_n_contacts"], + "model_ligand": model_key, + "reference_ligand": target_key, + "bs_ref_res": [_QualifiedResidueNotation(r) for r in + aux_data["bs_ref_res"]], + "bs_mdl_res": [_QualifiedResidueNotation(r) for r in + aux_data["bs_mdl_res"]]} + + for i in lddtpli_scorer.unassigned_model_ligands: + model_key = out["model_ligands"][i] + reason = lddtpli_scorer.guess_model_ligand_unassigned_reason(i) + out["lddt_pli"]["model_ligands"][model_key] = {"lddt_pli": None, + "unassigned_reason": reason} + + out["lddt_pli"]["full"]["assignment"] = lddtpli_scorer.assignment + out["lddt_pli"]["full"]["scores"] = dict() + + shape = lddtpli_scorer.score_matrix.shape + for ref_lig_idx in range(shape[0]): + for mdl_lig_idx in range(shape[1]): + score = float(lddtpli_scorer.score_matrix[(ref_lig_idx, mdl_lig_idx)]) + state = int(lddtpli_scorer.state_matrix[(ref_lig_idx, mdl_lig_idx)]) + desc = lddtpli_scorer.state_decoding[state] + pair_key = [ref_lig_idx, mdl_lig_idx] + out["lddt_pli"]["full"]["scores"][pair_key] = {"score": score, + "state": desc} if args.rmsd: out["rmsd"] = dict() + out["rmsd"]["model_ligands"] = dict() + out["rmsd"]["full"] = dict() for lig_pair in scrmsd_scorer.assignment: score = float(scrmsd_scorer.score_matrix[lig_pair[0], lig_pair[1]]) coverage = float(scrmsd_scorer.coverage_matrix[lig_pair[0], lig_pair[1]]) @@ -543,29 +563,42 @@ def _Process(model, model_ligands, reference, reference_ligands, args): target_key = out["reference_ligands"][lig_pair[0]] model_key = out["model_ligands"][lig_pair[1]] transform_data = aux_data["transform"].data - out["rmsd"][model_key] = {"rmsd": score, - "coverage": coverage, - "lddt_lp": aux_data["lddt_lp"], - "bb_rmsd": aux_data["bb_rmsd"], - "model_ligand": model_key, - "reference_ligand": target_key, - "chain_mapping": 
aux_data["chain_mapping"], - "bs_ref_res": [_QualifiedResidueNotation(r) for r in - aux_data["bs_ref_res"]], - "bs_ref_res_mapped": [_QualifiedResidueNotation(r) for r in - aux_data["bs_ref_res_mapped"]], - "bs_mdl_res_mapped": [_QualifiedResidueNotation(r) for r in - aux_data["bs_mdl_res_mapped"]], - "inconsistent_residues": [_QualifiedResidueNotation(r) for r in - aux_data["inconsistent_residues"]], - "transform": [transform_data[i:i + 4] - for i in range(0, len(transform_data), 4)]} - if args.unassigned: - for i in scrmsd_scorer.unassigned_model_ligands: - model_key = out["model_ligands"][i] - reason = scrmsd_scorer.guess_model_ligand_unassigned_reason(i) - out["rmsd"][model_key] = {"rmsd": None, - "unassigned_reason": reason} + out["rmsd"]["model_ligands"][model_key] = {"rmsd": score, + "coverage": coverage, + "lddt_lp": aux_data["lddt_lp"], + "bb_rmsd": aux_data["bb_rmsd"], + "model_ligand": model_key, + "reference_ligand": target_key, + "chain_mapping": aux_data["chain_mapping"], + "bs_ref_res": [_QualifiedResidueNotation(r) for r in + aux_data["bs_ref_res"]], + "bs_ref_res_mapped": [_QualifiedResidueNotation(r) for r in + aux_data["bs_ref_res_mapped"]], + "bs_mdl_res_mapped": [_QualifiedResidueNotation(r) for r in + aux_data["bs_mdl_res_mapped"]], + "inconsistent_residues": [_QualifiedResidueNotation(r) for r in + aux_data["inconsistent_residues"]], + "transform": [transform_data[i:i + 4] + for i in range(0, len(transform_data), 4)]} + for i in scrmsd_scorer.unassigned_model_ligands: + model_key = out["model_ligands"][i] + reason = scrmsd_scorer.guess_model_ligand_unassigned_reason(i) + out["rmsd"]["model_ligands"][model_key] = {"rmsd": None, + "unassigned_reason": reason} + + out["rmsd"]["full"]["assignment"] = scrmsd_scorer.assignment + out["rmsd"]["full"]["scores"] = dict() + + shape = scrmsd_scorer.score_matrix.shape + for ref_lig_idx in range(shape[0]): + for mdl_lig_idx in range(shape[1]): + score = float(scrmsd_scorer.score_matrix[(ref_lig_idx, mdl_lig_idx)]) + state = int(scrmsd_scorer.state_matrix[(ref_lig_idx, mdl_lig_idx)]) + desc = scrmsd_scorer.state_decoding[state] + pair_key = [ref_lig_idx, mdl_lig_idx] + out["rmsd"]["full"]["scores"][pair_key] = {"score": score, + "state": desc} + return out diff --git a/modules/mol/alg/pymod/CMakeLists.txt b/modules/mol/alg/pymod/CMakeLists.txt index ecf4ca0fa..436871b5d 100644 --- a/modules/mol/alg/pymod/CMakeLists.txt +++ b/modules/mol/alg/pymod/CMakeLists.txt @@ -34,6 +34,7 @@ set(OST_MOL_ALG_PYMOD_MODULES ligand_scoring_base.py ligand_scoring_scrmsd.py ligand_scoring_lddtpli.py + bb_lddt.py ) if (NOT ENABLE_STATIC) diff --git a/modules/mol/alg/pymod/bb_lddt.py b/modules/mol/alg/pymod/bb_lddt.py new file mode 100644 index 000000000..36da48b94 --- /dev/null +++ b/modules/mol/alg/pymod/bb_lddt.py @@ -0,0 +1,550 @@ +import itertools +import numpy as np +from scipy.spatial import distance + +import time +from ost import mol + +class BBlDDTEntity: + """ Helper object for BBlDDT computation + + Holds structural information and getters for interacting chains, i.e. + interfaces. Peptide residues are represented by their CA position + and nucleotides by C3'. 
+ + :param ent: Structure for BBlDDT score computation + :type ent: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` + :param contact_d: Pairwise distance of residues to be considered as contacts + :type contact_d: :class:`float` + """ + def __init__(self, ent, dist_thresh = 15.0, + dist_diff_thresholds = [0.5, 1.0, 2.0, 4.0]): + pep_query = "(peptide=true and aname=\"CA\")" + nuc_query = "(nucleotide=True and aname=\"C3'\")" + self._view = ent.Select(" or ".join([pep_query, nuc_query])) + self._dist_thresh = dist_thresh + self._dist_diff_thresholds = dist_diff_thresholds + + # the following attributes will be lazily evaluated + self._chain_names = None + self._interacting_chains = None + self._potentially_contributing_chains = None + self._sequence = dict() + self._pos = dict() + self._pair_dist = dict() + self._sc_dist = dict() + self._n_pair_contacts = None + self._n_sc_contacts = None + self._n_contacts = None + # min and max xyz for elements in pos used for fast collision + # detection + self._min_pos = dict() + self._max_pos = dict() + + @property + def view(self): + """ Processed structure + + View that only contains representative atoms. That's CA for peptide + residues and C3' for nucleotides. + + :type: :class:`ost.mol.EntityView` + """ + return self._view + + @property + def dist_thresh(self): + """ Pairwise distance of residues to be considered as contacts + + Given at :class:`BBlDDTEntity` construction + + :type: :class:`float` + """ + return self._dist_thresh + + @property + def dist_diff_thresholds(self): + """ Distance difference thresholds for lDDT computation + + Given at :class:`BBlDDTEntity` construction + + :type: :class:`list` of :class:`float` + """ + return self._dist_diff_thresholds + + @property + def chain_names(self): + """ Chain names in :attr:`~view` + + Names are sorted + + :type: :class:`list` of :class:`str` + """ + if self._chain_names is None: + self._chain_names = sorted([ch.name for ch in self.view.chains]) + return self._chain_names + + @property + def interacting_chains(self): + """ Pairs of chains in :attr:`~view` with at least one contact + + :type: :class:`list` of :class:`tuples` + """ + if self._interacting_chains is None: + # ugly hack: also computes self._n_pair_contacts + # this assumption is made when computing the n_pair_contacts + # attribute + self._interacting_chains = list() + self._n_pair_contacts = list() + for x in itertools.combinations(self.chain_names, 2): + if self.PotentialInteraction(x[0], x[1]): + n = np.count_nonzero(self.PairDist(x[0], x[1]) < self.dist_thresh) + if n > 0: + self._interacting_chains.append(x) + self._n_pair_contacts.append(n) + return self._interacting_chains + + @property + def potentially_contributing_chains(self): + """ Pairs of chains in :attr:`view` with potential contribution to lDDT + + That are pairs of chains that have at least one interaction within + :attr:`~dist_thresh` + max(:attr:`~dist_diff_thresholds`) + """ + if self._potentially_contributing_chains is None: + self._potentially_contributing_chains = list() + max_dist_diff_thresh = max(self.dist_diff_thresholds) + thresh = self.dist_thresh + max_dist_diff_thresh + for x in itertools.combinations(self.chain_names, 2): + if self.PotentialInteraction(x[0], x[1], + slack = max_dist_diff_thresh): + n = np.count_nonzero(self.PairDist(x[0], x[1]) < thresh) + if n > 0: + self._potentially_contributing_chains.append(x) + + return self._potentially_contributing_chains + + @property + def n_pair_contacts(self): + """ Number of contacts in 
:attr:`~interacting_chains` + + :type: :class:`list` of :class:`int` + """ + if self._n_pair_contacts: + # ugly hack: assumption that computing self.interacting_chains + # also triggers computation of n_pair_contacts + int_chains = self.interacting_chains + return self._n_pair_contacts + + @property + def n_sc_contacts(self): + """ Number of contacts for single chains in :attr:`~chain_names` + + :type: :class:`list` of :class:`int` + """ + if self._n_sc_contacts is None: + self._n_sc_contacts = list() + for cname in self.chain_names: + dist_mat = self.Dist(cname) + n = np.count_nonzero(dist_mat < self.dist_thresh) + # dist_mat is symmetric => first remove the diagonal from n + # as these are distances with itself, i.e. zeroes. + # Division by two then removes the symmetric component. + self._n_sc_contacts.append(int((n-dist_mat.shape[0])/2)) + return self._n_sc_contacts + + @property + def n_contacts(self): + """ Total number of contacts + + That's the sum of all :attr:`~n_pair_contacts` and + :attr:`~n_sc_contacts`. + + :type: :class:`int` + """ + if self._n_contacts is None: + self._n_contacts = sum(self.n_pair_contacts) + sum(self.n_sc_contacts) + return self._n_contacts + + def GetChain(self, chain_name): + """ Get chain by name + + :param chain_name: Chain in :attr:`~view` + :type chain_name: :class:`str` + """ + chain = self.view.FindChain(chain_name) + if not chain.IsValid(): + raise RuntimeError(f"view has no chain named \"{chain_name}\"") + return chain + + def GetSequence(self, chain_name): + """ Get sequence of chain + + Returns sequence of specified chain as raw :class:`str` + + :param chain_name: Chain in :attr:`~view` + :type chain_name: :class:`str` + """ + if chain_name not in self._sequence: + ch = self.GetChain(chain_name) + s = ''.join([r.one_letter_code for r in ch.residues]) + self._sequence[chain_name] = s + return self._sequence[chain_name] + + def GetPos(self, chain_name): + """ Get representative positions of chain + + That's CA positions for peptide residues and C3' for + nucleotides. Returns positions as a numpy array of shape + (n_residues, 3). + + :param chain_name: Chain in :attr:`~view` + :type chain_name: :class:`str` + """ + if chain_name not in self._pos: + ch = self.GetChain(chain_name) + pos = np.zeros((ch.GetResidueCount(), 3)) + for i, r in enumerate(ch.residues): + pos[i,:] = r.atoms[0].GetPos().data + self._pos[chain_name] = pos + return self._pos[chain_name] + + def Dist(self, chain_name): + """ Get pairwise distance of residues within same chain + + Returns distances as square numpy array of shape (a,a) + where a is the number of residues in specified chain. + """ + if chain_name not in self._sc_dist: + self._sc_dist[chain_name] = distance.cdist(self.GetPos(chain_name), + self.GetPos(chain_name), + 'euclidean') + return self._sc_dist[chain_name] + + def PairDist(self, chain_name_one, chain_name_two): + """ Get pairwise distances between two chains + + Returns distances as numpy array of shape (a, b). 
+
+        Where a is the number of residues of the chain that comes BEFORE the
+        other in :attr:`~chain_names` and b the number of residues of the
+        other chain.
+        """
+        key = (min(chain_name_one, chain_name_two),
+               max(chain_name_one, chain_name_two))
+        if key not in self._pair_dist:
+            self._pair_dist[key] = distance.cdist(self.GetPos(key[0]),
+                                                  self.GetPos(key[1]),
+                                                  'euclidean')
+        return self._pair_dist[key]
+
+    def GetMinPos(self, chain_name):
+        """ Get min x,y,z coordinates for given chain
+
+        Based on positions that are extracted with GetPos
+
+        :param chain_name: Chain in :attr:`~view`
+        :type chain_name: :class:`str`
+        """
+        if chain_name not in self._min_pos:
+            self._min_pos[chain_name] = self.GetPos(chain_name).min(0)
+        return self._min_pos[chain_name]
+
+    def GetMaxPos(self, chain_name):
+        """ Get max x,y,z coordinates for given chain
+
+        Based on positions that are extracted with GetPos
+
+        :param chain_name: Chain in :attr:`~view`
+        :type chain_name: :class:`str`
+        """
+        if chain_name not in self._max_pos:
+            self._max_pos[chain_name] = self.GetPos(chain_name).max(0)
+        return self._max_pos[chain_name]
+
+    def PotentialInteraction(self, chain_name_one, chain_name_two,
+                             slack=0.0):
+        """ Returns True if chains potentially interact
+
+        Based on crude collision detection. There is no guarantee
+        that they actually interact if True is returned. However,
+        if False is returned, they don't interact for sure.
+
+        :param chain_name_one: Chain in :attr:`~view`
+        :type chain_name_one: :class:`str`
+        :param chain_name_two: Chain in :attr:`~view`
+        :type chain_name_two: :class:`str`
+        :param slack: Add slack to interaction distance threshold
+        :type slack: :class:`float`
+        """
+        min_one = self.GetMinPos(chain_name_one)
+        max_one = self.GetMaxPos(chain_name_one)
+        min_two = self.GetMinPos(chain_name_two)
+        max_two = self.GetMaxPos(chain_name_two)
+        if np.max(min_one - max_two) > (self.dist_thresh + slack):
+            return False
+        if np.max(min_two - max_one) > (self.dist_thresh + slack):
+            return False
+        return True
+
+
+class BBlDDTScorer:
+    """ Helper object to compute backbone-only lDDT score
+
+    Tightly integrated into the mechanisms from the chain_mapping module.
+    The preferred way to derive an object of type :class:`BBlDDTScorer` is
+    through the static constructor: :func:`~FromMappingResult`.
+
+    lDDT computation in :func:`BBlDDTScorer.Score` implements caching.
+    Repeated computations with alternative chain mappings thus become faster.
+
+    :param target: Structure designated as "target". Can be fetched from
+                   :class:`ost.mol.alg.chain_mapping.MappingResult`
+    :type target: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
+    :param chem_groups: Groups of chemically equivalent chains in *target*.
+                        Can be fetched from
+                        :class:`ost.mol.alg.chain_mapping.MappingResult`
+    :type chem_groups: :class:`list` of :class:`list` of :class:`str`
+    :param model: Structure designated as "model". Can be fetched from
+                  :class:`ost.mol.alg.chain_mapping.MappingResult`
+    :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
+    :param alns: Each alignment is accessible with ``alns[(t_chain,m_chain)]``.
+                 First sequence is the sequence of the respective chain in
+                 :attr:`~trg`, second sequence the one from :attr:`~mdl`.
+                 Can be fetched from
+                 :class:`ost.mol.alg.chain_mapping.MappingResult`
+    :type alns: :class:`dict` with key: :class:`tuple` of :class:`str`, value:
+                :class:`ost.seq.AlignmentHandle`
+    :param dist_thresh: Max distance of a pairwise interaction in target
+                        to be considered as contact in lDDT
+    :type dist_thresh: :class:`float`
+    :param dist_diff_thresholds: Distance difference thresholds for
+                                 lDDT computations
+    :type dist_diff_thresholds: :class:`list` of :class:`float`
+    """
+    def __init__(self, target, chem_groups, model, alns, dist_thresh = 15.0,
+                 dist_diff_thresholds = [0.5, 1.0, 2.0, 4.0]):
+
+        self._trg = BBlDDTEntity(target, dist_thresh = dist_thresh,
+                                 dist_diff_thresholds=dist_diff_thresholds)
+
+        # ensure that target chain names match the ones in chem_groups
+        chem_group_ch_names = list(itertools.chain.from_iterable(chem_groups))
+        if self._trg.chain_names != sorted(chem_group_ch_names):
+            raise RuntimeError(f"Expect exact same chain names in chem_groups "
+                               f"and in target (which is processed to only "
+                               f"contain peptides/nucleotides). target: "
+                               f"{self._trg.chain_names}, chem_groups: "
+                               f"{chem_group_ch_names}")
+
+        self._chem_groups = chem_groups
+        self._mdl = BBlDDTEntity(model, dist_thresh = dist_thresh,
+                                 dist_diff_thresholds=dist_diff_thresholds)
+        self._alns = alns
+        self._dist_diff_thresholds = dist_diff_thresholds
+        self._dist_thresh = dist_thresh
+
+        # cache for mapped interface scores
+        # key: tuple of tuple ((trg_ch1, trg_ch2),
+        #                     ((mdl_ch1, mdl_ch2))
+        # where the first tuple is lexicographically sorted
+        # the second tuple is mapping dependent
+        # value: numpy array of len(dist_diff_thresholds) representing the
+        # respective numbers of fulfilled contacts
+        self._pairwise_cache = dict()
+
+        # cache for mapped single chain scores
+        # key: tuple (trg_ch, mdl_ch)
+        # value: numpy array of len(dist_diff_thresholds) representing the
+        # respective numbers of fulfilled contacts
+        self._sc_cache = dict()
+
+    @staticmethod
+    def FromMappingResult(mapping_result, dist_thresh = 15.0,
+                          dist_diff_thresholds = [0.5, 1.0, 2.0, 4.0]):
+        """ The preferred way to get a :class:`BBlDDTScorer`
+
+        Static constructor that derives an object of type
+        :class:`BBlDDTScorer` using a
+        :class:`ost.mol.alg.chain_mapping.MappingResult`
+
+        :param mapping_result: Data source
+        :type mapping_result: :class:`ost.mol.alg.chain_mapping.MappingResult`
+        :param dist_thresh: The lDDT distance threshold
+        :type dist_thresh: :class:`float`
+        :param dist_diff_thresholds: The lDDT distance difference thresholds
+        :type dist_diff_thresholds: :class:`list` of :class:`float`
+        """
+        scorer = BBlDDTScorer(mapping_result.target, mapping_result.chem_groups,
+                              mapping_result.model, alns = mapping_result.alns,
+                              dist_thresh = dist_thresh,
+                              dist_diff_thresholds = dist_diff_thresholds)
+        return scorer
+
+    @property
+    def trg(self):
+        """ The :class:`BBlDDTEntity` representing target
+
+        :type: :class:`BBlDDTEntity`
+        """
+        return self._trg
+
+    @property
+    def mdl(self):
+        """ The :class:`BBlDDTEntity` representing model
+
+        :type: :class:`BBlDDTEntity`
+        """
+        return self._mdl
+
+    @property
+    def alns(self):
+        """ Alignments between chains in :attr:`~trg` and :attr:`~mdl`
+
+        Provided at object construction. Each alignment is accessible with
+        ``alns[(t_chain,m_chain)]``. First sequence is the sequence of the
+        respective chain in :attr:`~trg`, second sequence the one from
+        :attr:`~mdl`.
+ + :type: :class:`dict` with key: :class:`tuple` of :class:`str`, value: + :class:`ost.seq.AlignmentHandle` + """ + return self._alns + + @property + def chem_groups(self): + """ Groups of chemically equivalent chains in :attr:`~trg` + + Provided at object construction + + :type: :class:`list` of :class:`list` of :class:`str` + """ + return self._chem_groups + + def Score(self, mapping, check=True): + """ Computes Backbone lDDT given chain mapping + + Again, the preferred way is to get *mapping* is from an object + of type :class:`ost.mol.alg.chain_mapping.MappingResult`. + + :param mapping: see + :attr:`ost.mol.alg.chain_mapping.MappingResult.mapping` + :type mapping: :class:`list` of :class:`list` of :class:`str` + :param check: Perform input checks, can be disabled for speed purposes + if you know what you're doing. + :type check: :class:`bool` + :returns: The score + """ + if check: + # ensure that dimensionality of mapping matches self.chem_groups + if len(self.chem_groups) != len(mapping): + raise RuntimeError("Dimensions of self.chem_groups and mapping " + "must match") + for a,b in zip(self.chem_groups, mapping): + if len(a) != len(b): + raise RuntimeError("Dimensions of self.chem_groups and " + "mapping must match") + # ensure that chain names in mapping are all present in qsent2 + for name in itertools.chain.from_iterable(mapping): + if name is not None and name not in self.mdl.chain_names: + raise RuntimeError(f"Each chain in mapping must be present " + f"in self.mdl. No match for " + f"\"{name}\"") + + flat_mapping = dict() + for a, b in zip(self.chem_groups, mapping): + flat_mapping.update({x: y for x, y in zip(a, b) if y is not None}) + + return self.FromFlatMapping(flat_mapping) + + def FromFlatMapping(self, flat_mapping): + """ Same as :func:`Score` but with flat mapping + + :param flat_mapping: Dictionary with target chain names as keys and + the mapped model chain names as value + :type flat_mapping: :class:`dict` with :class:`str` as key and value + :returns: :class:`float` representing lDDT + """ + n_conserved = np.zeros(len(self._dist_diff_thresholds), dtype=np.int32) + + # process single chains + for cname in self.trg.chain_names: + if cname in flat_mapping: + n_conserved += self._NSCConserved(cname, flat_mapping[cname]) + + # process interfaces + for interface in self.trg.interacting_chains: + if interface[0] in flat_mapping and interface[1] in flat_mapping: + mdl_interface = (flat_mapping[interface[0]], + flat_mapping[interface[1]]) + n_conserved += self._NPairConserved(interface, mdl_interface) + + return np.mean(n_conserved / self.trg.n_contacts) + + def _NSCConserved(self, trg_ch, mdl_ch): + if (trg_ch, mdl_ch) in self._sc_cache: + return self._sc_cache[(trg_ch, mdl_ch)] + trg_dist = self.trg.Dist(trg_ch) + mdl_dist = self.mdl.Dist(mdl_ch) + trg_indices, mdl_indices = self._IndexMapping(trg_ch, mdl_ch) + trg_dist = trg_dist[np.ix_(trg_indices, trg_indices)] + mdl_dist = mdl_dist[np.ix_(mdl_indices, mdl_indices)] + # mask to select relevant distances (dist in trg < dist_thresh) + # np.triu zeroes the values below the diagonal + mask = np.triu(trg_dist < self._dist_thresh) + n_diag = trg_dist.shape[0] + trg_dist = trg_dist[mask] + mdl_dist = mdl_dist[mask] + dist_diffs = np.absolute(trg_dist - mdl_dist) + n_conserved = np.zeros(len(self._dist_diff_thresholds), dtype=np.int32) + for thresh_idx, thresh in enumerate(self._dist_diff_thresholds): + N = (dist_diffs < thresh).sum() + # still need to consider the 0.0 dist diffs on the diagonal + n_conserved[thresh_idx] = 
int((N - n_diag)) + self._sc_cache[(trg_ch, mdl_ch)] = n_conserved + return n_conserved + + def _NPairConserved(self, trg_int, mdl_int): + key_one = (trg_int, mdl_int) + if key_one in self._pairwise_cache: + return self._pairwise_cache[key_one] + key_two = ((trg_int[1], trg_int[0]), (mdl_int[1], mdl_int[0])) + if key_two in self._pairwise_cache: + return self._pairwise_cache[key_two] + trg_dist = self.trg.PairDist(trg_int[0], trg_int[1]) + mdl_dist = self.mdl.PairDist(mdl_int[0], mdl_int[1]) + if trg_int[0] > trg_int[1]: + trg_dist = trg_dist.transpose() + if mdl_int[0] > mdl_int[1]: + mdl_dist = mdl_dist.transpose() + trg_indices_1, mdl_indices_1 = self._IndexMapping(trg_int[0], mdl_int[0]) + trg_indices_2, mdl_indices_2 = self._IndexMapping(trg_int[1], mdl_int[1]) + trg_dist = trg_dist[np.ix_(trg_indices_1, trg_indices_2)] + mdl_dist = mdl_dist[np.ix_(mdl_indices_1, mdl_indices_2)] + # reduce to relevant distances (dist in trg < dist_thresh) + mask = trg_dist < self._dist_thresh + trg_dist = trg_dist[mask] + mdl_dist = mdl_dist[mask] + dist_diffs = np.absolute(trg_dist - mdl_dist) + n_conserved = np.zeros(len(self._dist_diff_thresholds), dtype=np.int32) + for thresh_idx, thresh in enumerate(self._dist_diff_thresholds): + n_conserved[thresh_idx] = (dist_diffs < thresh).sum() + self._pairwise_cache[key_one] = n_conserved + return n_conserved + + def _IndexMapping(self, ch1, ch2): + """ Fetches aln and returns indices of aligned residues + + returns 2 numpy arrays containing the indices of residues in + ch1 and ch2 which are aligned + """ + mapped_indices_1 = list() + mapped_indices_2 = list() + idx_1 = 0 + idx_2 = 0 + for col in self.alns[(ch1, ch2)]: + if col[0] != '-' and col[1] != '-': + mapped_indices_1.append(idx_1) + mapped_indices_2.append(idx_2) + if col[0] != '-': + idx_1 +=1 + if col[1] != '-': + idx_2 +=1 + return (np.array(mapped_indices_1), np.array(mapped_indices_2)) diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py index 15ba5b153..189bbae63 100644 --- a/modules/mol/alg/pymod/chain_mapping.py +++ b/modules/mol/alg/pymod/chain_mapping.py @@ -18,6 +18,7 @@ from ost import mol from ost import geom from ost.mol.alg import lddt +from ost.mol.alg import bb_lddt from ost.mol.alg import qsscore def _CSel(ent, cnames): @@ -925,7 +926,7 @@ class ChainMapper: mapping = _lDDTGreedyBlock(the_greed, block_seed_size, block_blocks_per_chem_group) # cached => lDDT computation is fast here - opt_lddt = the_greed.lDDT(self.chem_groups, mapping) + opt_lddt = the_greed.Score(mapping) alns = dict() for ref_group, mdl_group in zip(self.chem_groups, mapping): @@ -1063,7 +1064,7 @@ class ChainMapper: mapping = _QSScoreGreedyBlock(the_greed, block_seed_size, block_blocks_per_chem_group) # cached => QSScore computation is fast here - opt_qsscore = the_greed.Score(mapping, check=False) + opt_qsscore = the_greed.Score(mapping, check=False).QS_global alns = dict() for ref_group, mdl_group in zip(self.chem_groups, mapping): @@ -1230,10 +1231,19 @@ class ChainMapper: performed (greedy_prune_contact_map = True). The default for *n_max_naive* of 40320 corresponds to an octamer (8!=40320). A structure with stoichiometry A6B2 would be 6!*2!=1440 etc. + + If :attr:`~target` has nucleotide sequences, the QS-score target + function is replaced with a backbone only lDDT score that has + an inclusion radius of 30A. 
""" - return self.GetQSScoreMapping(model, strategy="heuristic", - greedy_prune_contact_map=True, - heuristic_n_max_naive = n_max_naive) + if len(self.polynuc_seqs) > 0: + return self.GetlDDTMapping(model, strategy = "heuristic", + inclusion_radius = 30.0, + heuristic_n_max_naive = n_max_naive) + else: + return self.GetQSScoreMapping(model, strategy="heuristic", + greedy_prune_contact_map=True, + heuristic_n_max_naive = n_max_naive) def GetRepr(self, substructure, model, topn=1, inclusion_radius=15.0, thresholds=[0.5, 1.0, 2.0, 4.0], bb_only=False, @@ -2028,166 +2038,25 @@ def _CheckOneToOneMapping(ref_chains, mdl_chains): else: return None -class _lDDTDecomposer: - - def __init__(self, ref, mdl, ref_mdl_alns, inclusion_radius = 15.0, - thresholds = [0.5, 1.0, 2.0, 4.0]): - """ Compute backbone only lDDT scores for ref/mdl - - Uses the pairwise decomposable property of backbone only lDDT and - implements a caching mechanism to efficiently enumerate different - chain mappings. - """ - - self.ref = ref - self.mdl = mdl - self.ref_mdl_alns = ref_mdl_alns - self.inclusion_radius = inclusion_radius - self.thresholds = thresholds - - # keep track of single chains and interfaces in ref - self.ref_chains = list() # e.g. ['A', 'B', 'C'] - self.ref_interfaces = list() # e.g. [('A', 'B'), ('A', 'C')] - - # holds lDDT scorer for each chain in ref - # key: chain name, value: scorer - self.single_chain_scorer = dict() - - # cache for single chain conserved contacts - # key: tuple (ref_ch, mdl_ch) value: number of conserved contacts - self.single_chain_cache = dict() - - # holds lDDT scorer for each pairwise interface in target - # key: tuple (ref_ch1, ref_ch2), value: scorer - self.interface_scorer = dict() - - # cache for interface conserved contacts - # key: tuple of tuple ((ref_ch1, ref_ch2),((mdl_ch1, mdl_ch2)) - # value: number of conserved contacts - self.interface_cache = dict() - - self.n = 0 - - self._SetupScorer() - - def _SetupScorer(self): - for ch in self.ref.chains: - # Select everything close to that chain - query = f"{self.inclusion_radius} <> " - query += f"[cname={mol.QueryQuoteName(ch.GetName())}] " - query += f"and cname!={mol.QueryQuoteName(ch.GetName())}" - for close_ch in self.ref.Select(query).chains: - k1 = (ch.GetName(), close_ch.GetName()) - k2 = (close_ch.GetName(), ch.GetName()) - if k1 not in self.interface_scorer and \ - k2 not in self.interface_scorer: - dimer_ref = _CSel(self.ref, [k1[0], k1[1]]) - s = lddt.lDDTScorer(dimer_ref, bb_only=True) - self.interface_scorer[k1] = s - self.interface_scorer[k2] = s - self.n += sum([len(x) for x in self.interface_scorer[k1].ref_indices_ic]) - self.ref_interfaces.append(k1) - # single chain scorer are actually interface scorers to save - # some distance calculations - if ch.GetName() not in self.single_chain_scorer: - self.single_chain_scorer[ch.GetName()] = s - self.n += s.GetNChainContacts(ch.GetName(), - no_interchain=True) - self.ref_chains.append(ch.GetName()) - if close_ch.GetName() not in self.single_chain_scorer: - self.single_chain_scorer[close_ch.GetName()] = s - self.n += s.GetNChainContacts(close_ch.GetName(), - no_interchain=True) - self.ref_chains.append(close_ch.GetName()) - - # add any missing single chain scorer - for ch in self.ref.chains: - if ch.GetName() not in self.single_chain_scorer: - single_chain_ref = _CSel(self.ref, [ch.GetName()]) - self.single_chain_scorer[ch.GetName()] = \ - lddt.lDDTScorer(single_chain_ref, bb_only = True) - self.n += self.single_chain_scorer[ch.GetName()].n_distances - 
self.ref_chains.append(ch.GetName()) - - def lDDT(self, ref_chain_groups, mdl_chain_groups): - - flat_map = dict() - for ref_chains, mdl_chains in zip(ref_chain_groups, mdl_chain_groups): - for ref_ch, mdl_ch in zip(ref_chains, mdl_chains): - flat_map[ref_ch] = mdl_ch - - return self.lDDTFromFlatMap(flat_map) - - - def lDDTFromFlatMap(self, flat_map): - conserved = 0 - - # do single chain scores - for ref_ch in self.ref_chains: - if ref_ch in flat_map and flat_map[ref_ch] is not None: - conserved += self.SCCounts(ref_ch, flat_map[ref_ch]) - - # do interfaces - for ref_ch1, ref_ch2 in self.ref_interfaces: - if ref_ch1 in flat_map and ref_ch2 in flat_map: - mdl_ch1 = flat_map[ref_ch1] - mdl_ch2 = flat_map[ref_ch2] - if mdl_ch1 is not None and mdl_ch2 is not None: - conserved += self.IntCounts(ref_ch1, ref_ch2, mdl_ch1, - mdl_ch2) - - return conserved / (len(self.thresholds) * self.n) - - def SCCounts(self, ref_ch, mdl_ch): - if not (ref_ch, mdl_ch) in self.single_chain_cache: - alns = dict() - alns[mdl_ch] = self.ref_mdl_alns[(ref_ch, mdl_ch)] - mdl_sel = _CSel(self.mdl, [mdl_ch]) - s = self.single_chain_scorer[ref_ch] - _,_,_,conserved,_,_,_ = s.lDDT(mdl_sel, - residue_mapping=alns, - return_dist_test=True, - no_interchain=True, - chain_mapping={mdl_ch: ref_ch}, - check_resnames=False) - self.single_chain_cache[(ref_ch, mdl_ch)] = conserved - return self.single_chain_cache[(ref_ch, mdl_ch)] - - def IntCounts(self, ref_ch1, ref_ch2, mdl_ch1, mdl_ch2): - k1 = ((ref_ch1, ref_ch2),(mdl_ch1, mdl_ch2)) - k2 = ((ref_ch2, ref_ch1),(mdl_ch2, mdl_ch1)) - if k1 not in self.interface_cache and k2 not in self.interface_cache: - alns = dict() - alns[mdl_ch1] = self.ref_mdl_alns[(ref_ch1, mdl_ch1)] - alns[mdl_ch2] = self.ref_mdl_alns[(ref_ch2, mdl_ch2)] - mdl_sel = _CSel(self.mdl, [mdl_ch1, mdl_ch2]) - s = self.interface_scorer[(ref_ch1, ref_ch2)] - _,_,_,conserved,_,_,_ = s.lDDT(mdl_sel, - residue_mapping=alns, - return_dist_test=True, - no_intrachain=True, - chain_mapping={mdl_ch1: ref_ch1, - mdl_ch2: ref_ch2}, - check_resnames=False) - self.interface_cache[k1] = conserved - self.interface_cache[k2] = conserved - return self.interface_cache[k1] - -class _lDDTGreedySearcher(_lDDTDecomposer): +class _lDDTGreedySearcher(bb_lddt.BBlDDTScorer): def __init__(self, ref, mdl, ref_chem_groups, mdl_chem_groups, ref_mdl_alns, inclusion_radius = 15.0, thresholds = [0.5, 1.0, 2.0, 4.0], steep_opt_rate = None): + """ Greedy extension of already existing but incomplete chain mappings """ - super().__init__(ref, mdl, ref_mdl_alns, - inclusion_radius = inclusion_radius, - thresholds = thresholds) + super().__init__(ref, ref_chem_groups, mdl, ref_mdl_alns, + dist_thresh=inclusion_radius, + dist_diff_thresholds=thresholds) + + self.mdl_chem_groups = mdl_chem_groups self.steep_opt_rate = steep_opt_rate - self.neighbors = {k: set() for k in self.ref_chains} - for k in self.interface_scorer.keys(): - self.neighbors[k[0]].add(k[1]) - self.neighbors[k[1]].add(k[0]) + + self.neighbors = {k: set() for k in self.trg.chain_names} + for interface in self.trg.interacting_chains: + self.neighbors[interface[0]].add(interface[1]) + self.neighbors[interface[1]].add(interface[0]) assert(len(ref_chem_groups) == len(mdl_chem_groups)) self.ref_chem_groups = ref_chem_groups @@ -2203,16 +2072,10 @@ class _lDDTGreedySearcher(_lDDTDecomposer): # keep track of mdl chains that potentially give lDDT contributions, # i.e. 
they have locations within inclusion_radius + max(thresholds) - self.mdl_neighbors = dict() - d = self.inclusion_radius + max(self.thresholds) - for ch in self.mdl.chains: - ch_name = ch.GetName() - self.mdl_neighbors[ch_name] = set() - query = f"{d} <> [cname={mol.QueryQuoteName(ch_name)}]" - query += f" and cname !={mol.QueryQuoteName(ch_name)}" - for close_ch in self.mdl.Select(query).chains: - self.mdl_neighbors[ch_name].add(close_ch.GetName()) - + self.mdl_neighbors = {k: set() for k in self.mdl.chain_names} + for interface in self.mdl.potentially_contributing_chains: + self.mdl_neighbors[interface[0]].add(interface[1]) + self.mdl_neighbors[interface[1]].add(interface[0]) def ExtendMapping(self, mapping, max_ext = None): @@ -2261,14 +2124,14 @@ class _lDDTGreedySearcher(_lDDTDecomposer): chem_group_idx = self.ref_ch_group_mapper[ref_ch] for mdl_ch in free_mdl_chains[chem_group_idx]: # single chain score - n_single = self.SCCounts(ref_ch, mdl_ch) + n_single = self._NSCConserved(ref_ch, mdl_ch).sum() # scores towards neighbors that are already mapped n_inter = 0 for neighbor in self.neighbors[ref_ch]: if neighbor in mapping and mapping[neighbor] in \ self.mdl_neighbors[mdl_ch]: - n_inter += self.IntCounts(ref_ch, neighbor, mdl_ch, - mapping[neighbor]) + n_inter += self._NPairConserved((ref_ch, neighbor), + (mdl_ch, mapping[neighbor])).sum() n = n_single + n_inter if n_inter > 0 and n > max_n: @@ -2322,7 +2185,7 @@ class _lDDTGreedySearcher(_lDDTDecomposer): # try all possible mapping swaps. Swaps that improve the score are # immediately accepted and we start all over again - current_lddt = self.lDDTFromFlatMap(mapping) + current_lddt = self.FromFlatMapping(mapping) something_happened = True while something_happened: something_happened = False @@ -2333,13 +2196,12 @@ class _lDDTGreedySearcher(_lDDTDecomposer): swapped_mapping = dict(mapping) swapped_mapping[ch1] = mapping[ch2] swapped_mapping[ch2] = mapping[ch1] - score = self.lDDTFromFlatMap(swapped_mapping) + score = self.FromFlatMapping(swapped_mapping) if score > current_lddt: something_happened = True mapping = swapped_mapping current_lddt = score break - return mapping @@ -2416,7 +2278,7 @@ def _lDDTGreedyFast(the_greed): n_best = 0 best_seed = None for seed in seeds: - n = the_greed.SCCounts(seed[0], seed[1]) + n = the_greed._NSCConserved(seed[0], seed[1]).sum() if n > n_best: n_best = n best_seed = seed @@ -2470,7 +2332,7 @@ def _lDDTGreedyFull(the_greed): tmp_mapping = dict(mapping) tmp_mapping[remnant_seed[0]] = remnant_seed[1] tmp_mapping = the_greed.ExtendMapping(tmp_mapping) - score = the_greed.lDDTFromFlatMap(tmp_mapping) + score = the_greed.FromFlatMapping(tmp_mapping) if score > best_score: best_score = score best_mapping = tmp_mapping @@ -2478,7 +2340,7 @@ def _lDDTGreedyFull(the_greed): something_happened = True mapping = best_mapping - score = the_greed.lDDTFromFlatMap(mapping) + score = the_greed.FromFlatMapping(mapping) if score > best_overall_score: best_overall_score = score best_overall_mapping = mapping @@ -2542,7 +2404,7 @@ def _lDDTGreedyBlock(the_greed, seed_size, blocks_per_chem_group): seed = dict(mapping) seed.update({s[0]: s[1]}) seed = the_greed.ExtendMapping(seed, max_ext = max_ext) - seed_lddt = the_greed.lDDTFromFlatMap(seed) + seed_lddt = the_greed.FromFlatMapping(seed) if seed_lddt > best_score: best_score = seed_lddt best_mapping = seed @@ -2558,7 +2420,7 @@ def _lDDTGreedyBlock(the_greed, seed_size, blocks_per_chem_group): best_mapping = None for seed in starting_blocks: seed = 
the_greed.ExtendMapping(seed) - seed_lddt = the_greed.lDDTFromFlatMap(seed) + seed_lddt = the_greed.FromFlatMapping(seed) if seed_lddt > best_lddt: best_lddt = seed_lddt best_mapping = seed diff --git a/modules/mol/alg/tests/CMakeLists.txt b/modules/mol/alg/tests/CMakeLists.txt index e1a1aaf82..e21f07468 100644 --- a/modules/mol/alg/tests/CMakeLists.txt +++ b/modules/mol/alg/tests/CMakeLists.txt @@ -14,6 +14,7 @@ set(OST_MOL_ALG_UNIT_TESTS test_contact_score.py test_biounit.py test_ost_dockq.py + test_bblddt.py ) if (COMPOUND_LIB) diff --git a/modules/mol/alg/tests/test_bblddt.py b/modules/mol/alg/tests/test_bblddt.py new file mode 100644 index 000000000..38be34c90 --- /dev/null +++ b/modules/mol/alg/tests/test_bblddt.py @@ -0,0 +1,117 @@ +import unittest, os, sys +import ost +from ost import conop +from ost import io, mol, seq, settings +import time +# check if we can import: fails if numpy or scipy not available +try: + import numpy as np + from ost.mol.alg.bb_lddt import * + from ost.mol.alg.lddt import * + from ost.mol.alg.chain_mapping import * +except ImportError: + print("Failed to import bb_lddt.py. Happens when numpy or scipy "\ + "missing. Ignoring qsscore.py tests.") + sys.exit(0) + +def _LoadFile(file_name): + """Helper to avoid repeating input path over and over.""" + return io.LoadPDB(os.path.join('testfiles', file_name)) + +class TestBBlDDT(unittest.TestCase): + + def test_bblddtentity(self): + ent = _LoadFile("3l1p.1.pdb") + ent = BBlDDTEntity(ent) + self.assertEqual(len(ent.view.chains), 4) + self.assertEqual(ent.GetChain("A").GetName(), "A") + self.assertEqual(ent.GetChain("B").GetName(), "B") + self.assertEqual(ent.GetChain("C").GetName(), "C") + self.assertEqual(ent.GetChain("D").GetName(), "D") + self.assertRaises(Exception, ent.GetChain, "E") + self.assertEqual(ent.chain_names, ["A", "B", "C", "D"]) + self.assertEqual(ent.GetSequence("A"), "DMKALQKELEQFAKLLKQKRITLGYTQADVGLTLGVLFGKVFSQTTISRFEALQLSLKNMSKLRPLLEKWVEEADNNENLQEISKSVQARKRKRTSIENRVRWSLETMFLKSPKPSLQQITHIANQLGLEKDVVRVWFSNRRQKGKR") + self.assertEqual(ent.GetSequence("B"), "KALQKELEQFAKLLKQKRITLGYTQADVGLTLGVLFGKVFSQTTISRFEALQLSLKNMSKLRPLLEKWVEEADNNENLQEISKSQARKRKRTSIENRVRWSLETMFLKSPKPSLQQITHIANQLGLEKDVVRVWFSNRRQKGKRS") + self.assertEqual(ent.GetSequence("C"), "TCCACATTTGAAAGGCAAATGGA") + self.assertEqual(ent.GetSequence("D"), "ATCCATTTGCCTTTCAAATGTGG") + + # check for a couple of positions with manually extracted values + + # GLU + pos = ent.GetPos("B") + self.assertAlmostEqual(pos[5,0], -0.901, places=3) + self.assertAlmostEqual(pos[5,1], 28.167, places=3) + self.assertAlmostEqual(pos[5,2], 13.955, places=3) + + # GLY + pos = ent.GetPos("A") + self.assertAlmostEqual(pos[23,0], 17.563, places=3) + self.assertAlmostEqual(pos[23,1], -4.082, places=3) + self.assertAlmostEqual(pos[23,2], 29.005, places=3) + + # Cytosine + pos = ent.GetPos("C") + self.assertAlmostEqual(pos[4,0], 14.796, places=3) + self.assertAlmostEqual(pos[4,1], 24.653, places=3) + self.assertAlmostEqual(pos[4,2], 59.318, places=3) + + + # check pairwise dist, chain names are always sorted => + # A is rows, C is cols + dist_one = ent.PairDist("A", "C") + dist_two = ent.PairDist("C", "A") + self.assertTrue(np.array_equal(dist_one, dist_two)) + self.assertEqual(dist_one.shape[0], len(ent.GetSequence("A"))) + self.assertEqual(dist_one.shape[1], len(ent.GetSequence("C"))) + + # check some random distance between the Gly and Cytosine that we already + # checked above + self.assertAlmostEqual(dist_one[23,4], 41.86, places=2) + + # all chains interact with 
each other... but hey, check nevertheless + self.assertEqual(ent.interacting_chains, [("A", "B"), ("A", "C"), + ("A", "D"), ("B", "C"), + ("B", "D"), ("C", "D")]) + + def test_bb_lddt_scorer(self): + + target = _LoadFile("3l1p.1.pdb") + model = _LoadFile("3l1p.1_model.pdb") + + # we need to derive a chain mapping prior to scoring + mapper = ChainMapper(target) + res = mapper.GetRMSDMapping(model, strategy="greedy_iterative") + + # lets compare with lddt reference implementation + + reference_lddt_scorer = lDDTScorer(target, bb_only=True) + + # make alignments accessible by mdl seq name + alns = dict() + for aln in res.alns.values(): + mdl_seq = aln.GetSequence(1) + alns[mdl_seq.name] = aln + + # lDDT requires a flat mapping with mdl_ch as key and trg_ch as value + flat_mapping = res.GetFlatMapping(mdl_as_key=True) + lddt_chain_mapping = dict() + for mdl_ch, trg_ch in flat_mapping.items(): + if mdl_ch in alns: + lddt_chain_mapping[mdl_ch] = trg_ch + + reference_lddt_score = reference_lddt_scorer.lDDT(model, + chain_mapping = lddt_chain_mapping, + residue_mapping = alns, + check_resnames=False)[0] + + bb_lddt_scorer = BBlDDTScorer.FromMappingResult(res) + bb_lddt_score = bb_lddt_scorer.Score(res.mapping) + + self.assertAlmostEqual(reference_lddt_score, bb_lddt_score, places = 4) + +if __name__ == "__main__": + from ost import testutils + if testutils.DefaultCompoundLibIsSet(): + testutils.RunTests() + else: + print('No compound lib available. Ignoring test_bblddt.py tests.') -- GitLab
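Below is a minimal usage sketch of the backbone-only lDDT scorer introduced in
this patch together with the updated ChainMapper.GetMapping. It is illustrative
only: the input file names are placeholders, and only calls added or touched by
this patch (ChainMapper.GetMapping, BBlDDTScorer.FromMappingResult,
BBlDDTScorer.Score) are used.

# Minimal usage sketch (illustrative only; file names are placeholders).
from ost import io
from ost.mol.alg.chain_mapping import ChainMapper
from ost.mol.alg.bb_lddt import BBlDDTScorer

ref = io.LoadPDB("reference.pdb")
mdl = io.LoadPDB("model.pdb")

# GetMapping now switches from the QS-score target function to lDDT with a
# 30A inclusion radius as soon as the reference contains nucleotide chains.
mapper = ChainMapper(ref)
mapping_result = mapper.GetMapping(mdl)

# Score the mapped model with the cached backbone-only lDDT.
scorer = BBlDDTScorer.FromMappingResult(mapping_result)
print("bb-lDDT:", scorer.Score(mapping_result.mapping))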