diff --git a/actions/ost-compare-ligand-structures b/actions/ost-compare-ligand-structures index 7c5fc8776e88f5fa153bb8bdddbad4a6847676db..a8491fea3dc7c6d780333f36731c4fbdc60e61e4 100644 --- a/actions/ost-compare-ligand-structures +++ b/actions/ost-compare-ligand-structures @@ -47,10 +47,23 @@ options, this is a dictionary with three keys: content of the JSON output will be \"status\" set to FAILURE and an additional key: "traceback". -Each score is opt-in and must be enabled with optional arguments. The scores -perform a model/reference ligand assignment and report a score for each assigned -model ligand. Optionally, unassigned model ligands are reported with a null -score and a reason why no assignment has been performed (--unassigned/-u). +Each score is opt-in and the respective results are available in two keys: + + * "model_ligands": Model ligand centric scoring based on model/reference + ligand assignment. A score including meta data is reported for each assigned + model ligand given the assigned target ligand. Unassigned model ligands are + reported with a null score and a reason why no assignment has been performed. + + * "full": The full all vs. all scoring results. Yet another dictionary with + keys: + + * "assignment": List of pairs in form (ref_lig_idx, mdl_lig_idx) specifying + the ligands in "reference_ligands"/"model_ligands". + * "scores": A dictionary with key in form (ref_lig_idx, mdl_lig_idx) and + value yet another dict with score information for each possible + reference/model ligand pair. The respective score is None if no score could + be computed. This can simply be a mismatch between the two ligands. This or + other reasons are reported. """ import argparse @@ -191,16 +204,6 @@ def _ParseArgs(): default=0.2, help=("Coverage delta for partial ligand assignment.")) - parser.add_argument( - "-u", - "--unassigned", - dest="unassigned", - default=False, - action="store_true", - help=("Report unassigned model ligands in the output together with " - "assigned ligands, with a null score, and reason for not being " - "assigned.")) - parser.add_argument( '-v', '--verbosity', @@ -511,31 +514,48 @@ def _Process(model, model_ligands, reference, reference_ligands, args): if args.lddt_pli: out["lddt_pli"] = dict() + out["lddt_pli"]["model_ligands"] = dict() + out["lddt_pli"]["full"] = dict() for lig_pair in lddtpli_scorer.assignment: score = float(lddtpli_scorer.score_matrix[lig_pair[0], lig_pair[1]]) coverage = float(lddtpli_scorer.coverage_matrix[lig_pair[0], lig_pair[1]]) aux_data = lddtpli_scorer.aux_matrix[lig_pair[0], lig_pair[1]] target_key = out["reference_ligands"][lig_pair[0]] model_key = out["model_ligands"][lig_pair[1]] - out["lddt_pli"][model_key] = {"lddt_pli": score, - "coverage": coverage, - "lddt_pli_n_contacts": aux_data["lddt_pli_n_contacts"], - "model_ligand": model_key, - "reference_ligand": target_key, - "bs_ref_res": [_QualifiedResidueNotation(r) for r in - aux_data["bs_ref_res"]], - "bs_mdl_res": [_QualifiedResidueNotation(r) for r in - aux_data["bs_mdl_res"]]} - if args.unassigned: - for i in lddtpli_scorer.unassigned_model_ligands: - model_key = out["model_ligands"][i] - reason = lddtpli_scorer.guess_model_ligand_unassigned_reason(i) - out["lddt_pli"][model_key] = {"lddt_pli": None, - "unassigned_reason": reason} + out["lddt_pli"]["model_ligands"][model_key] = {"lddt_pli": score, + "coverage": coverage, + "lddt_pli_n_contacts": aux_data["lddt_pli_n_contacts"], + "model_ligand": model_key, + "reference_ligand": target_key, + "bs_ref_res": 
[_QualifiedResidueNotation(r) for r in + aux_data["bs_ref_res"]], + "bs_mdl_res": [_QualifiedResidueNotation(r) for r in + aux_data["bs_mdl_res"]]} + + for i in lddtpli_scorer.unassigned_model_ligands: + model_key = out["model_ligands"][i] + reason = lddtpli_scorer.guess_model_ligand_unassigned_reason(i) + out["lddt_pli"]["model_ligands"][model_key] = {"lddt_pli": None, + "unassigned_reason": reason} + + out["lddt_pli"]["full"]["assignment"] = lddtpli_scorer.assignment + out["lddt_pli"]["full"]["scores"] = dict() + + shape = lddtpli_scorer.score_matrix.shape + for ref_lig_idx in range(shape[0]): + for mdl_lig_idx in range(shape[1]): + score = float(lddtpli_scorer.score_matrix[(ref_lig_idx, mdl_lig_idx)]) + state = int(lddtpli_scorer.state_matrix[(ref_lig_idx, mdl_lig_idx)]) + desc = lddtpli_scorer.state_decoding[state] + pair_key = [ref_lig_idx, mdl_lig_idx] + out["lddt_pli"]["full"]["scores"][pair_key] = {"score": score, + "state": desc} if args.rmsd: out["rmsd"] = dict() + out["rmsd"]["model_ligands"] = dict() + out["rmsd"]["full"] = dict() for lig_pair in scrmsd_scorer.assignment: score = float(scrmsd_scorer.score_matrix[lig_pair[0], lig_pair[1]]) coverage = float(scrmsd_scorer.coverage_matrix[lig_pair[0], lig_pair[1]]) @@ -543,29 +563,42 @@ def _Process(model, model_ligands, reference, reference_ligands, args): target_key = out["reference_ligands"][lig_pair[0]] model_key = out["model_ligands"][lig_pair[1]] transform_data = aux_data["transform"].data - out["rmsd"][model_key] = {"rmsd": score, - "coverage": coverage, - "lddt_lp": aux_data["lddt_lp"], - "bb_rmsd": aux_data["bb_rmsd"], - "model_ligand": model_key, - "reference_ligand": target_key, - "chain_mapping": aux_data["chain_mapping"], - "bs_ref_res": [_QualifiedResidueNotation(r) for r in - aux_data["bs_ref_res"]], - "bs_ref_res_mapped": [_QualifiedResidueNotation(r) for r in - aux_data["bs_ref_res_mapped"]], - "bs_mdl_res_mapped": [_QualifiedResidueNotation(r) for r in - aux_data["bs_mdl_res_mapped"]], - "inconsistent_residues": [_QualifiedResidueNotation(r) for r in - aux_data["inconsistent_residues"]], - "transform": [transform_data[i:i + 4] - for i in range(0, len(transform_data), 4)]} - if args.unassigned: - for i in scrmsd_scorer.unassigned_model_ligands: - model_key = out["model_ligands"][i] - reason = scrmsd_scorer.guess_model_ligand_unassigned_reason(i) - out["rmsd"][model_key] = {"rmsd": None, - "unassigned_reason": reason} + out["rmsd"]["model_ligands"][model_key] = {"rmsd": score, + "coverage": coverage, + "lddt_lp": aux_data["lddt_lp"], + "bb_rmsd": aux_data["bb_rmsd"], + "model_ligand": model_key, + "reference_ligand": target_key, + "chain_mapping": aux_data["chain_mapping"], + "bs_ref_res": [_QualifiedResidueNotation(r) for r in + aux_data["bs_ref_res"]], + "bs_ref_res_mapped": [_QualifiedResidueNotation(r) for r in + aux_data["bs_ref_res_mapped"]], + "bs_mdl_res_mapped": [_QualifiedResidueNotation(r) for r in + aux_data["bs_mdl_res_mapped"]], + "inconsistent_residues": [_QualifiedResidueNotation(r) for r in + aux_data["inconsistent_residues"]], + "transform": [transform_data[i:i + 4] + for i in range(0, len(transform_data), 4)]} + for i in scrmsd_scorer.unassigned_model_ligands: + model_key = out["model_ligands"][i] + reason = scrmsd_scorer.guess_model_ligand_unassigned_reason(i) + out["rmsd"]["model_ligands"][model_key] = {"rmsd": None, + "unassigned_reason": reason} + + out["rmsd"]["full"]["assignment"] = scrmsd_scorer.assignment + out["rmsd"]["full"]["scores"] = dict() + + shape = scrmsd_scorer.score_matrix.shape 
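+        # score_matrix/coverage_matrix/state_matrix have shape
+        # (n_reference_ligands, n_model_ligands) and cover every possible
+        # reference/model ligand pair, not only the assigned ones;
+        # state_decoding translates the integer state into a human readable
+        # reason whenever no score could be computed.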
+ for ref_lig_idx in range(shape[0]): + for mdl_lig_idx in range(shape[1]): + score = float(scrmsd_scorer.score_matrix[(ref_lig_idx, mdl_lig_idx)]) + state = int(scrmsd_scorer.state_matrix[(ref_lig_idx, mdl_lig_idx)]) + desc = scrmsd_scorer.state_decoding[state] + pair_key = [ref_lig_idx, mdl_lig_idx] + out["rmsd"]["full"]["scores"][pair_key] = {"score": score, + "state": desc} + return out diff --git a/modules/mol/alg/pymod/CMakeLists.txt b/modules/mol/alg/pymod/CMakeLists.txt index ecf4ca0fa0a70ca1ed658b6d4b71bef12830731e..436871b5d26b2abef7e9507022e9f074fde84413 100644 --- a/modules/mol/alg/pymod/CMakeLists.txt +++ b/modules/mol/alg/pymod/CMakeLists.txt @@ -34,6 +34,7 @@ set(OST_MOL_ALG_PYMOD_MODULES ligand_scoring_base.py ligand_scoring_scrmsd.py ligand_scoring_lddtpli.py + bb_lddt.py ) if (NOT ENABLE_STATIC) diff --git a/modules/mol/alg/pymod/bb_lddt.py b/modules/mol/alg/pymod/bb_lddt.py new file mode 100644 index 0000000000000000000000000000000000000000..36da48b9416acc3f14e4fadd95db0ff5580dfd4b --- /dev/null +++ b/modules/mol/alg/pymod/bb_lddt.py @@ -0,0 +1,550 @@ +import itertools +import numpy as np +from scipy.spatial import distance + +import time +from ost import mol + +class BBlDDTEntity: + """ Helper object for BBlDDT computation + + Holds structural information and getters for interacting chains, i.e. + interfaces. Peptide residues are represented by their CA position + and nucleotides by C3'. + + :param ent: Structure for BBlDDT score computation + :type ent: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` + :param contact_d: Pairwise distance of residues to be considered as contacts + :type contact_d: :class:`float` + """ + def __init__(self, ent, dist_thresh = 15.0, + dist_diff_thresholds = [0.5, 1.0, 2.0, 4.0]): + pep_query = "(peptide=true and aname=\"CA\")" + nuc_query = "(nucleotide=True and aname=\"C3'\")" + self._view = ent.Select(" or ".join([pep_query, nuc_query])) + self._dist_thresh = dist_thresh + self._dist_diff_thresholds = dist_diff_thresholds + + # the following attributes will be lazily evaluated + self._chain_names = None + self._interacting_chains = None + self._potentially_contributing_chains = None + self._sequence = dict() + self._pos = dict() + self._pair_dist = dict() + self._sc_dist = dict() + self._n_pair_contacts = None + self._n_sc_contacts = None + self._n_contacts = None + # min and max xyz for elements in pos used for fast collision + # detection + self._min_pos = dict() + self._max_pos = dict() + + @property + def view(self): + """ Processed structure + + View that only contains representative atoms. That's CA for peptide + residues and C3' for nucleotides. 
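+
+        A minimal usage sketch (assumes *ent* is an already loaded
+        :class:`ost.mol.EntityHandle`)::
+
+          bb_ent = BBlDDTEntity(ent)
+          # only the representative CA/C3' atoms are left
+          print(bb_ent.view.GetAtomCount())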
+
+        :type: :class:`ost.mol.EntityView`
+        """
+        return self._view
+
+    @property
+    def dist_thresh(self):
+        """ Pairwise distance of residues to be considered as contacts
+
+        Given at :class:`BBlDDTEntity` construction
+
+        :type: :class:`float`
+        """
+        return self._dist_thresh
+
+    @property
+    def dist_diff_thresholds(self):
+        """ Distance difference thresholds for lDDT computation
+
+        Given at :class:`BBlDDTEntity` construction
+
+        :type: :class:`list` of :class:`float`
+        """
+        return self._dist_diff_thresholds
+
+    @property
+    def chain_names(self):
+        """ Chain names in :attr:`~view`
+
+        Names are sorted
+
+        :type: :class:`list` of :class:`str`
+        """
+        if self._chain_names is None:
+            self._chain_names = sorted([ch.name for ch in self.view.chains])
+        return self._chain_names
+
+    @property
+    def interacting_chains(self):
+        """ Pairs of chains in :attr:`~view` with at least one contact
+
+        :type: :class:`list` of :class:`tuple`
+        """
+        if self._interacting_chains is None:
+            # ugly hack: also computes self._n_pair_contacts;
+            # the n_pair_contacts property relies on this side effect
+            self._interacting_chains = list()
+            self._n_pair_contacts = list()
+            for x in itertools.combinations(self.chain_names, 2):
+                if self.PotentialInteraction(x[0], x[1]):
+                    n = np.count_nonzero(self.PairDist(x[0], x[1]) < self.dist_thresh)
+                    if n > 0:
+                        self._interacting_chains.append(x)
+                        self._n_pair_contacts.append(n)
+        return self._interacting_chains
+
+    @property
+    def potentially_contributing_chains(self):
+        """ Pairs of chains in :attr:`~view` with potential contribution to lDDT
+
+        These are pairs of chains that have at least one interaction within
+        :attr:`~dist_thresh` + max(:attr:`~dist_diff_thresholds`)
+
+        :type: :class:`list` of :class:`tuple`
+        """
+        if self._potentially_contributing_chains is None:
+            self._potentially_contributing_chains = list()
+            max_dist_diff_thresh = max(self.dist_diff_thresholds)
+            thresh = self.dist_thresh + max_dist_diff_thresh
+            for x in itertools.combinations(self.chain_names, 2):
+                if self.PotentialInteraction(x[0], x[1],
+                                             slack = max_dist_diff_thresh):
+                    n = np.count_nonzero(self.PairDist(x[0], x[1]) < thresh)
+                    if n > 0:
+                        self._potentially_contributing_chains.append(x)
+
+        return self._potentially_contributing_chains
+
+    @property
+    def n_pair_contacts(self):
+        """ Number of contacts in :attr:`~interacting_chains`
+
+        :type: :class:`list` of :class:`int`
+        """
+        if self._n_pair_contacts is None:
+            # ugly hack: computing self.interacting_chains also fills
+            # self._n_pair_contacts
+            int_chains = self.interacting_chains
+        return self._n_pair_contacts
+
+    @property
+    def n_sc_contacts(self):
+        """ Number of contacts for single chains in :attr:`~chain_names`
+
+        :type: :class:`list` of :class:`int`
+        """
+        if self._n_sc_contacts is None:
+            self._n_sc_contacts = list()
+            for cname in self.chain_names:
+                dist_mat = self.Dist(cname)
+                n = np.count_nonzero(dist_mat < self.dist_thresh)
+                # dist_mat is symmetric => first remove the diagonal from n
+                # as these are distances with itself, i.e. zeroes.
+                # Division by two then removes the symmetric component.
+                self._n_sc_contacts.append(int((n-dist_mat.shape[0])/2))
+        return self._n_sc_contacts
+
+    @property
+    def n_contacts(self):
+        """ Total number of contacts
+
+        That's the sum of all :attr:`~n_pair_contacts` and
+        :attr:`~n_sc_contacts`.
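+
+        A small sketch of the relationship (assumes *bb_ent* is a
+        :class:`BBlDDTEntity`)::
+
+          total = sum(bb_ent.n_pair_contacts) + sum(bb_ent.n_sc_contacts)
+          assert total == bb_ent.n_contacts  # lDDT denominator in the scorer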
+ + :type: :class:`int` + """ + if self._n_contacts is None: + self._n_contacts = sum(self.n_pair_contacts) + sum(self.n_sc_contacts) + return self._n_contacts + + def GetChain(self, chain_name): + """ Get chain by name + + :param chain_name: Chain in :attr:`~view` + :type chain_name: :class:`str` + """ + chain = self.view.FindChain(chain_name) + if not chain.IsValid(): + raise RuntimeError(f"view has no chain named \"{chain_name}\"") + return chain + + def GetSequence(self, chain_name): + """ Get sequence of chain + + Returns sequence of specified chain as raw :class:`str` + + :param chain_name: Chain in :attr:`~view` + :type chain_name: :class:`str` + """ + if chain_name not in self._sequence: + ch = self.GetChain(chain_name) + s = ''.join([r.one_letter_code for r in ch.residues]) + self._sequence[chain_name] = s + return self._sequence[chain_name] + + def GetPos(self, chain_name): + """ Get representative positions of chain + + That's CA positions for peptide residues and C3' for + nucleotides. Returns positions as a numpy array of shape + (n_residues, 3). + + :param chain_name: Chain in :attr:`~view` + :type chain_name: :class:`str` + """ + if chain_name not in self._pos: + ch = self.GetChain(chain_name) + pos = np.zeros((ch.GetResidueCount(), 3)) + for i, r in enumerate(ch.residues): + pos[i,:] = r.atoms[0].GetPos().data + self._pos[chain_name] = pos + return self._pos[chain_name] + + def Dist(self, chain_name): + """ Get pairwise distance of residues within same chain + + Returns distances as square numpy array of shape (a,a) + where a is the number of residues in specified chain. + """ + if chain_name not in self._sc_dist: + self._sc_dist[chain_name] = distance.cdist(self.GetPos(chain_name), + self.GetPos(chain_name), + 'euclidean') + return self._sc_dist[chain_name] + + def PairDist(self, chain_name_one, chain_name_two): + """ Get pairwise distances between two chains + + Returns distances as numpy array of shape (a, b). + Where a is the number of residues of the chain that comes BEFORE the + other in :attr:`~chain_names` + """ + key = (min(chain_name_one, chain_name_two), + max(chain_name_one, chain_name_two)) + if key not in self._pair_dist: + self._pair_dist[key] = distance.cdist(self.GetPos(key[0]), + self.GetPos(key[1]), + 'euclidean') + return self._pair_dist[key] + + def GetMinPos(self, chain_name): + """ Get min x,y,z cooridnates for given chain + + Based on positions that are extracted with GetPos + + :param chain_name: Chain in :attr:`~view` + :type chain_name: :class:`str` + """ + if chain_name not in self._min_pos: + self._min_pos[chain_name] = self.GetPos(chain_name).min(0) + return self._min_pos[chain_name] + + def GetMaxPos(self, chain_name): + """ Get max x,y,z cooridnates for given chain + + Based on positions that are extracted with GetPos + + :param chain_name: Chain in :attr:`~view` + :type chain_name: :class:`str` + """ + if chain_name not in self._max_pos: + self._max_pos[chain_name] = self.GetPos(chain_name).max(0) + return self._max_pos[chain_name] + + def PotentialInteraction(self, chain_name_one, chain_name_two, + slack=0.0): + """ Returns True if chains potentially interact + + Based on crude collision detection. There is no guarantee + that they actually interact if True is returned. However, + if False is returned, they don't interact for sure. 
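+
+        Typical usage, mirroring :attr:`~interacting_chains` (assumes
+        *bb_ent* is a :class:`BBlDDTEntity`; chain names are placeholders)::
+
+          if bb_ent.PotentialInteraction("A", "B"):
+              # only now compute the actual pairwise distances
+              n = np.count_nonzero(bb_ent.PairDist("A", "B") <
+                                   bb_ent.dist_thresh)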
+ + :param chain_name_one: Chain in :attr:`~view` + :type chain_name_one: class:`str` + :param chain_name_two: Chain in :attr:`~view` + :type chain_name_two: class:`str` + :param slack: Add slack to interaction distance threshold + :type slack: :class:`float` + """ + min_one = self.GetMinPos(chain_name_one) + max_one = self.GetMaxPos(chain_name_one) + min_two = self.GetMinPos(chain_name_two) + max_two = self.GetMaxPos(chain_name_two) + if np.max(min_one - max_two) > (self.dist_thresh + slack): + return False + if np.max(min_two - max_one) > (self.dist_thresh + slack): + return False + return True + + +class BBlDDTScorer: + """ Helper object to compute Backbone only lDDT score + + Tightly integrated into the mechanisms from the chain_mapping module. + The prefered way to derive an object of type :class:`BBlDDTScorer` is + through the static constructor: :func:`~FromMappingResult`. + + lDDT computation in :func:`BBlDDTScorer.Score` implements caching. + Repeated computations with alternative chain mappings thus become faster. + + :param target: Structure designated as "target". Can be fetched from + :class:`ost.mol.alg.chain_mapping.MappingResult` + :type target: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` + :param chem_groups: Groups of chemically equivalent chains in *target*. + Can be fetched from + :class:`ost.mol.alg.chain_mapping.MappingResult` + :type chem_groups: :class:`list` of :class:`list` of :class:`str` + :param model: Structure designated as "model". Can be fetched from + :class:`ost.mol.alg.chain_mapping.MappingResult` + :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` + :param alns: Each alignment is accessible with ``alns[(t_chain,m_chain)]``. + First sequence is the sequence of the respective chain in + :attr:`~qsent1`, second sequence the one from :attr:`~qsent2`. + Can be fetched from + :class:`ost.mol.alg.chain_mapping.MappingResult` + :type alns: :class:`dict` with key: :class:`tuple` of :class:`str`, value: + :class:`ost.seq.AlignmentHandle` + :param dist_thresh: Max distance of a pairwise interaction in target + to be considered as contact in lDDT + :type dist_thresh: :class:`float` + :param dist_diff_thresholds: Distance difference thresholds for + lDDT computations + :type dist_diff_thresholds: :class:`list` of :class:`float` + """ + def __init__(self, target, chem_groups, model, alns, dist_thresh = 15.0, + dist_diff_thresholds = [0.5, 1.0, 2.0, 4.0]): + + self._trg = BBlDDTEntity(target, dist_thresh = dist_thresh, + dist_diff_thresholds=dist_diff_thresholds) + + # ensure that target chain names match the ones in chem_groups + chem_group_ch_names = list(itertools.chain.from_iterable(chem_groups)) + if self._trg.chain_names != sorted(chem_group_ch_names): + raise RuntimeError(f"Expect exact same chain names in chem_groups " + f"and in target (which is processed to only " + f"contain peptides/nucleotides). 
target: " + f"{self._trg.chain_names}, chem_groups: " + f"{chem_group_ch_names}") + + self._chem_groups = chem_groups + self._mdl = BBlDDTEntity(model, dist_thresh = dist_thresh, + dist_diff_thresholds=dist_diff_thresholds) + self._alns = alns + self._dist_diff_thresholds = dist_diff_thresholds + self._dist_thresh = dist_thresh + + # cache for mapped interface scores + # key: tuple of tuple ((trg_ch1, trg_ch2), + # ((mdl_ch1, mdl_ch2)) + # where the first tuple is lexicographically sorted + # the second tuple is mapping dependent + # value: numpy array of len(dist_thresholds) representing the + # respective numbers of fulfilled contacts + self._pairwise_cache = dict() + + # cache for mapped single chain scores + # key: tuple (trg_ch, mdl_ch) + # value: numpy array of len(dist_thresholds) representing the + # respective numbers of fulfilled contacts + self._sc_cache = dict() + + @staticmethod + def FromMappingResult(mapping_result, dist_thresh = 15.0, + dist_diff_thresholds = [0.5, 1.0, 2.0, 4.0]): + """ The preferred way to get a :clas:`BBlDDTScorer` + + Static constructor that derives an object of type :class:`QSScorer` + using a :class:`ost.mol.alg.chain_mapping.MappingResult` + + :param mapping_result: Data source + :type mapping_result: :class:`ost.mol.alg.chain_mapping.MappingResult` + :param dist_thresh: The lDDT distance threshold + :type dist_thresh: :class:`float` + :param dist_diff_thresholds: The lDDT distance difference thresholds + :type dist_diff_thresholds: :class:`list` of :class:`float` + """ + scorer = BBlDDTScorer(mapping_result.target, mapping_result.chem_groups, + mapping_result.model, alns = mapping_result.alns, + dist_thresh = dist_thresh, + dist_diff_thresholds = dist_diff_thresholds) + return scorer + + @property + def trg(self): + """ The :class:`BBlDDTEntity` representing target + + :type: :class:`BBlDDTEntity` + """ + return self._trg + + @property + def mdl(self): + """ The :class:`BBlDDTEntity` representing model + + :type: :class:`BBlDDTEntity` + """ + return self._mdl + + @property + def alns(self): + """ Alignments between chains in :attr:`~trg` and :attr:`~mdl` + + Provided at object construction. Each alignment is accessible with + ``alns[(t_chain,m_chain)]``. First sequence is the sequence of the + respective chain in :attr:`~trg`, second sequence the one from + :attr:`~mdl`. + + :type: :class:`dict` with key: :class:`tuple` of :class:`str`, value: + :class:`ost.seq.AlignmentHandle` + """ + return self._alns + + @property + def chem_groups(self): + """ Groups of chemically equivalent chains in :attr:`~trg` + + Provided at object construction + + :type: :class:`list` of :class:`list` of :class:`str` + """ + return self._chem_groups + + def Score(self, mapping, check=True): + """ Computes Backbone lDDT given chain mapping + + Again, the preferred way is to get *mapping* is from an object + of type :class:`ost.mol.alg.chain_mapping.MappingResult`. + + :param mapping: see + :attr:`ost.mol.alg.chain_mapping.MappingResult.mapping` + :type mapping: :class:`list` of :class:`list` of :class:`str` + :param check: Perform input checks, can be disabled for speed purposes + if you know what you're doing. 
+ :type check: :class:`bool` + :returns: The score + """ + if check: + # ensure that dimensionality of mapping matches self.chem_groups + if len(self.chem_groups) != len(mapping): + raise RuntimeError("Dimensions of self.chem_groups and mapping " + "must match") + for a,b in zip(self.chem_groups, mapping): + if len(a) != len(b): + raise RuntimeError("Dimensions of self.chem_groups and " + "mapping must match") + # ensure that chain names in mapping are all present in qsent2 + for name in itertools.chain.from_iterable(mapping): + if name is not None and name not in self.mdl.chain_names: + raise RuntimeError(f"Each chain in mapping must be present " + f"in self.mdl. No match for " + f"\"{name}\"") + + flat_mapping = dict() + for a, b in zip(self.chem_groups, mapping): + flat_mapping.update({x: y for x, y in zip(a, b) if y is not None}) + + return self.FromFlatMapping(flat_mapping) + + def FromFlatMapping(self, flat_mapping): + """ Same as :func:`Score` but with flat mapping + + :param flat_mapping: Dictionary with target chain names as keys and + the mapped model chain names as value + :type flat_mapping: :class:`dict` with :class:`str` as key and value + :returns: :class:`float` representing lDDT + """ + n_conserved = np.zeros(len(self._dist_diff_thresholds), dtype=np.int32) + + # process single chains + for cname in self.trg.chain_names: + if cname in flat_mapping: + n_conserved += self._NSCConserved(cname, flat_mapping[cname]) + + # process interfaces + for interface in self.trg.interacting_chains: + if interface[0] in flat_mapping and interface[1] in flat_mapping: + mdl_interface = (flat_mapping[interface[0]], + flat_mapping[interface[1]]) + n_conserved += self._NPairConserved(interface, mdl_interface) + + return np.mean(n_conserved / self.trg.n_contacts) + + def _NSCConserved(self, trg_ch, mdl_ch): + if (trg_ch, mdl_ch) in self._sc_cache: + return self._sc_cache[(trg_ch, mdl_ch)] + trg_dist = self.trg.Dist(trg_ch) + mdl_dist = self.mdl.Dist(mdl_ch) + trg_indices, mdl_indices = self._IndexMapping(trg_ch, mdl_ch) + trg_dist = trg_dist[np.ix_(trg_indices, trg_indices)] + mdl_dist = mdl_dist[np.ix_(mdl_indices, mdl_indices)] + # mask to select relevant distances (dist in trg < dist_thresh) + # np.triu zeroes the values below the diagonal + mask = np.triu(trg_dist < self._dist_thresh) + n_diag = trg_dist.shape[0] + trg_dist = trg_dist[mask] + mdl_dist = mdl_dist[mask] + dist_diffs = np.absolute(trg_dist - mdl_dist) + n_conserved = np.zeros(len(self._dist_diff_thresholds), dtype=np.int32) + for thresh_idx, thresh in enumerate(self._dist_diff_thresholds): + N = (dist_diffs < thresh).sum() + # still need to consider the 0.0 dist diffs on the diagonal + n_conserved[thresh_idx] = int((N - n_diag)) + self._sc_cache[(trg_ch, mdl_ch)] = n_conserved + return n_conserved + + def _NPairConserved(self, trg_int, mdl_int): + key_one = (trg_int, mdl_int) + if key_one in self._pairwise_cache: + return self._pairwise_cache[key_one] + key_two = ((trg_int[1], trg_int[0]), (mdl_int[1], mdl_int[0])) + if key_two in self._pairwise_cache: + return self._pairwise_cache[key_two] + trg_dist = self.trg.PairDist(trg_int[0], trg_int[1]) + mdl_dist = self.mdl.PairDist(mdl_int[0], mdl_int[1]) + if trg_int[0] > trg_int[1]: + trg_dist = trg_dist.transpose() + if mdl_int[0] > mdl_int[1]: + mdl_dist = mdl_dist.transpose() + trg_indices_1, mdl_indices_1 = self._IndexMapping(trg_int[0], mdl_int[0]) + trg_indices_2, mdl_indices_2 = self._IndexMapping(trg_int[1], mdl_int[1]) + trg_dist = trg_dist[np.ix_(trg_indices_1, 
trg_indices_2)] + mdl_dist = mdl_dist[np.ix_(mdl_indices_1, mdl_indices_2)] + # reduce to relevant distances (dist in trg < dist_thresh) + mask = trg_dist < self._dist_thresh + trg_dist = trg_dist[mask] + mdl_dist = mdl_dist[mask] + dist_diffs = np.absolute(trg_dist - mdl_dist) + n_conserved = np.zeros(len(self._dist_diff_thresholds), dtype=np.int32) + for thresh_idx, thresh in enumerate(self._dist_diff_thresholds): + n_conserved[thresh_idx] = (dist_diffs < thresh).sum() + self._pairwise_cache[key_one] = n_conserved + return n_conserved + + def _IndexMapping(self, ch1, ch2): + """ Fetches aln and returns indices of aligned residues + + returns 2 numpy arrays containing the indices of residues in + ch1 and ch2 which are aligned + """ + mapped_indices_1 = list() + mapped_indices_2 = list() + idx_1 = 0 + idx_2 = 0 + for col in self.alns[(ch1, ch2)]: + if col[0] != '-' and col[1] != '-': + mapped_indices_1.append(idx_1) + mapped_indices_2.append(idx_2) + if col[0] != '-': + idx_1 +=1 + if col[1] != '-': + idx_2 +=1 + return (np.array(mapped_indices_1), np.array(mapped_indices_2)) diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py index 15ba5b1535c11eb8ad60da457e02a9a8a20de4ea..189bbae636e4fd9da822ddc91265eacf15409899 100644 --- a/modules/mol/alg/pymod/chain_mapping.py +++ b/modules/mol/alg/pymod/chain_mapping.py @@ -18,6 +18,7 @@ from ost import mol from ost import geom from ost.mol.alg import lddt +from ost.mol.alg import bb_lddt from ost.mol.alg import qsscore def _CSel(ent, cnames): @@ -925,7 +926,7 @@ class ChainMapper: mapping = _lDDTGreedyBlock(the_greed, block_seed_size, block_blocks_per_chem_group) # cached => lDDT computation is fast here - opt_lddt = the_greed.lDDT(self.chem_groups, mapping) + opt_lddt = the_greed.Score(mapping) alns = dict() for ref_group, mdl_group in zip(self.chem_groups, mapping): @@ -1063,7 +1064,7 @@ class ChainMapper: mapping = _QSScoreGreedyBlock(the_greed, block_seed_size, block_blocks_per_chem_group) # cached => QSScore computation is fast here - opt_qsscore = the_greed.Score(mapping, check=False) + opt_qsscore = the_greed.Score(mapping, check=False).QS_global alns = dict() for ref_group, mdl_group in zip(self.chem_groups, mapping): @@ -1230,10 +1231,19 @@ class ChainMapper: performed (greedy_prune_contact_map = True). The default for *n_max_naive* of 40320 corresponds to an octamer (8!=40320). A structure with stoichiometry A6B2 would be 6!*2!=1440 etc. + + If :attr:`~target` has nucleotide sequences, the QS-score target + function is replaced with a backbone only lDDT score that has + an inclusion radius of 30A. 
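+
+        Example (sketch; *target* and *model* as described above)::
+
+          mapper = ChainMapper(target)
+          res = mapper.GetMapping(model)
+          # res.mapping is grouped like mapper.chem_groups
+          print(res.mapping)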
""" - return self.GetQSScoreMapping(model, strategy="heuristic", - greedy_prune_contact_map=True, - heuristic_n_max_naive = n_max_naive) + if len(self.polynuc_seqs) > 0: + return self.GetlDDTMapping(model, strategy = "heuristic", + inclusion_radius = 30.0, + heuristic_n_max_naive = n_max_naive) + else: + return self.GetQSScoreMapping(model, strategy="heuristic", + greedy_prune_contact_map=True, + heuristic_n_max_naive = n_max_naive) def GetRepr(self, substructure, model, topn=1, inclusion_radius=15.0, thresholds=[0.5, 1.0, 2.0, 4.0], bb_only=False, @@ -2028,166 +2038,25 @@ def _CheckOneToOneMapping(ref_chains, mdl_chains): else: return None -class _lDDTDecomposer: - - def __init__(self, ref, mdl, ref_mdl_alns, inclusion_radius = 15.0, - thresholds = [0.5, 1.0, 2.0, 4.0]): - """ Compute backbone only lDDT scores for ref/mdl - - Uses the pairwise decomposable property of backbone only lDDT and - implements a caching mechanism to efficiently enumerate different - chain mappings. - """ - - self.ref = ref - self.mdl = mdl - self.ref_mdl_alns = ref_mdl_alns - self.inclusion_radius = inclusion_radius - self.thresholds = thresholds - - # keep track of single chains and interfaces in ref - self.ref_chains = list() # e.g. ['A', 'B', 'C'] - self.ref_interfaces = list() # e.g. [('A', 'B'), ('A', 'C')] - - # holds lDDT scorer for each chain in ref - # key: chain name, value: scorer - self.single_chain_scorer = dict() - - # cache for single chain conserved contacts - # key: tuple (ref_ch, mdl_ch) value: number of conserved contacts - self.single_chain_cache = dict() - - # holds lDDT scorer for each pairwise interface in target - # key: tuple (ref_ch1, ref_ch2), value: scorer - self.interface_scorer = dict() - - # cache for interface conserved contacts - # key: tuple of tuple ((ref_ch1, ref_ch2),((mdl_ch1, mdl_ch2)) - # value: number of conserved contacts - self.interface_cache = dict() - - self.n = 0 - - self._SetupScorer() - - def _SetupScorer(self): - for ch in self.ref.chains: - # Select everything close to that chain - query = f"{self.inclusion_radius} <> " - query += f"[cname={mol.QueryQuoteName(ch.GetName())}] " - query += f"and cname!={mol.QueryQuoteName(ch.GetName())}" - for close_ch in self.ref.Select(query).chains: - k1 = (ch.GetName(), close_ch.GetName()) - k2 = (close_ch.GetName(), ch.GetName()) - if k1 not in self.interface_scorer and \ - k2 not in self.interface_scorer: - dimer_ref = _CSel(self.ref, [k1[0], k1[1]]) - s = lddt.lDDTScorer(dimer_ref, bb_only=True) - self.interface_scorer[k1] = s - self.interface_scorer[k2] = s - self.n += sum([len(x) for x in self.interface_scorer[k1].ref_indices_ic]) - self.ref_interfaces.append(k1) - # single chain scorer are actually interface scorers to save - # some distance calculations - if ch.GetName() not in self.single_chain_scorer: - self.single_chain_scorer[ch.GetName()] = s - self.n += s.GetNChainContacts(ch.GetName(), - no_interchain=True) - self.ref_chains.append(ch.GetName()) - if close_ch.GetName() not in self.single_chain_scorer: - self.single_chain_scorer[close_ch.GetName()] = s - self.n += s.GetNChainContacts(close_ch.GetName(), - no_interchain=True) - self.ref_chains.append(close_ch.GetName()) - - # add any missing single chain scorer - for ch in self.ref.chains: - if ch.GetName() not in self.single_chain_scorer: - single_chain_ref = _CSel(self.ref, [ch.GetName()]) - self.single_chain_scorer[ch.GetName()] = \ - lddt.lDDTScorer(single_chain_ref, bb_only = True) - self.n += self.single_chain_scorer[ch.GetName()].n_distances - 
self.ref_chains.append(ch.GetName()) - - def lDDT(self, ref_chain_groups, mdl_chain_groups): - - flat_map = dict() - for ref_chains, mdl_chains in zip(ref_chain_groups, mdl_chain_groups): - for ref_ch, mdl_ch in zip(ref_chains, mdl_chains): - flat_map[ref_ch] = mdl_ch - - return self.lDDTFromFlatMap(flat_map) - - - def lDDTFromFlatMap(self, flat_map): - conserved = 0 - - # do single chain scores - for ref_ch in self.ref_chains: - if ref_ch in flat_map and flat_map[ref_ch] is not None: - conserved += self.SCCounts(ref_ch, flat_map[ref_ch]) - - # do interfaces - for ref_ch1, ref_ch2 in self.ref_interfaces: - if ref_ch1 in flat_map and ref_ch2 in flat_map: - mdl_ch1 = flat_map[ref_ch1] - mdl_ch2 = flat_map[ref_ch2] - if mdl_ch1 is not None and mdl_ch2 is not None: - conserved += self.IntCounts(ref_ch1, ref_ch2, mdl_ch1, - mdl_ch2) - - return conserved / (len(self.thresholds) * self.n) - - def SCCounts(self, ref_ch, mdl_ch): - if not (ref_ch, mdl_ch) in self.single_chain_cache: - alns = dict() - alns[mdl_ch] = self.ref_mdl_alns[(ref_ch, mdl_ch)] - mdl_sel = _CSel(self.mdl, [mdl_ch]) - s = self.single_chain_scorer[ref_ch] - _,_,_,conserved,_,_,_ = s.lDDT(mdl_sel, - residue_mapping=alns, - return_dist_test=True, - no_interchain=True, - chain_mapping={mdl_ch: ref_ch}, - check_resnames=False) - self.single_chain_cache[(ref_ch, mdl_ch)] = conserved - return self.single_chain_cache[(ref_ch, mdl_ch)] - - def IntCounts(self, ref_ch1, ref_ch2, mdl_ch1, mdl_ch2): - k1 = ((ref_ch1, ref_ch2),(mdl_ch1, mdl_ch2)) - k2 = ((ref_ch2, ref_ch1),(mdl_ch2, mdl_ch1)) - if k1 not in self.interface_cache and k2 not in self.interface_cache: - alns = dict() - alns[mdl_ch1] = self.ref_mdl_alns[(ref_ch1, mdl_ch1)] - alns[mdl_ch2] = self.ref_mdl_alns[(ref_ch2, mdl_ch2)] - mdl_sel = _CSel(self.mdl, [mdl_ch1, mdl_ch2]) - s = self.interface_scorer[(ref_ch1, ref_ch2)] - _,_,_,conserved,_,_,_ = s.lDDT(mdl_sel, - residue_mapping=alns, - return_dist_test=True, - no_intrachain=True, - chain_mapping={mdl_ch1: ref_ch1, - mdl_ch2: ref_ch2}, - check_resnames=False) - self.interface_cache[k1] = conserved - self.interface_cache[k2] = conserved - return self.interface_cache[k1] - -class _lDDTGreedySearcher(_lDDTDecomposer): +class _lDDTGreedySearcher(bb_lddt.BBlDDTScorer): def __init__(self, ref, mdl, ref_chem_groups, mdl_chem_groups, ref_mdl_alns, inclusion_radius = 15.0, thresholds = [0.5, 1.0, 2.0, 4.0], steep_opt_rate = None): + """ Greedy extension of already existing but incomplete chain mappings """ - super().__init__(ref, mdl, ref_mdl_alns, - inclusion_radius = inclusion_radius, - thresholds = thresholds) + super().__init__(ref, ref_chem_groups, mdl, ref_mdl_alns, + dist_thresh=inclusion_radius, + dist_diff_thresholds=thresholds) + + self.mdl_chem_groups = mdl_chem_groups self.steep_opt_rate = steep_opt_rate - self.neighbors = {k: set() for k in self.ref_chains} - for k in self.interface_scorer.keys(): - self.neighbors[k[0]].add(k[1]) - self.neighbors[k[1]].add(k[0]) + + self.neighbors = {k: set() for k in self.trg.chain_names} + for interface in self.trg.interacting_chains: + self.neighbors[interface[0]].add(interface[1]) + self.neighbors[interface[1]].add(interface[0]) assert(len(ref_chem_groups) == len(mdl_chem_groups)) self.ref_chem_groups = ref_chem_groups @@ -2203,16 +2072,10 @@ class _lDDTGreedySearcher(_lDDTDecomposer): # keep track of mdl chains that potentially give lDDT contributions, # i.e. 
they have locations within inclusion_radius + max(thresholds) - self.mdl_neighbors = dict() - d = self.inclusion_radius + max(self.thresholds) - for ch in self.mdl.chains: - ch_name = ch.GetName() - self.mdl_neighbors[ch_name] = set() - query = f"{d} <> [cname={mol.QueryQuoteName(ch_name)}]" - query += f" and cname !={mol.QueryQuoteName(ch_name)}" - for close_ch in self.mdl.Select(query).chains: - self.mdl_neighbors[ch_name].add(close_ch.GetName()) - + self.mdl_neighbors = {k: set() for k in self.mdl.chain_names} + for interface in self.mdl.potentially_contributing_chains: + self.mdl_neighbors[interface[0]].add(interface[1]) + self.mdl_neighbors[interface[1]].add(interface[0]) def ExtendMapping(self, mapping, max_ext = None): @@ -2261,14 +2124,14 @@ class _lDDTGreedySearcher(_lDDTDecomposer): chem_group_idx = self.ref_ch_group_mapper[ref_ch] for mdl_ch in free_mdl_chains[chem_group_idx]: # single chain score - n_single = self.SCCounts(ref_ch, mdl_ch) + n_single = self._NSCConserved(ref_ch, mdl_ch).sum() # scores towards neighbors that are already mapped n_inter = 0 for neighbor in self.neighbors[ref_ch]: if neighbor in mapping and mapping[neighbor] in \ self.mdl_neighbors[mdl_ch]: - n_inter += self.IntCounts(ref_ch, neighbor, mdl_ch, - mapping[neighbor]) + n_inter += self._NPairConserved((ref_ch, neighbor), + (mdl_ch, mapping[neighbor])).sum() n = n_single + n_inter if n_inter > 0 and n > max_n: @@ -2322,7 +2185,7 @@ class _lDDTGreedySearcher(_lDDTDecomposer): # try all possible mapping swaps. Swaps that improve the score are # immediately accepted and we start all over again - current_lddt = self.lDDTFromFlatMap(mapping) + current_lddt = self.FromFlatMapping(mapping) something_happened = True while something_happened: something_happened = False @@ -2333,13 +2196,12 @@ class _lDDTGreedySearcher(_lDDTDecomposer): swapped_mapping = dict(mapping) swapped_mapping[ch1] = mapping[ch2] swapped_mapping[ch2] = mapping[ch1] - score = self.lDDTFromFlatMap(swapped_mapping) + score = self.FromFlatMapping(swapped_mapping) if score > current_lddt: something_happened = True mapping = swapped_mapping current_lddt = score break - return mapping @@ -2416,7 +2278,7 @@ def _lDDTGreedyFast(the_greed): n_best = 0 best_seed = None for seed in seeds: - n = the_greed.SCCounts(seed[0], seed[1]) + n = the_greed._NSCConserved(seed[0], seed[1]).sum() if n > n_best: n_best = n best_seed = seed @@ -2470,7 +2332,7 @@ def _lDDTGreedyFull(the_greed): tmp_mapping = dict(mapping) tmp_mapping[remnant_seed[0]] = remnant_seed[1] tmp_mapping = the_greed.ExtendMapping(tmp_mapping) - score = the_greed.lDDTFromFlatMap(tmp_mapping) + score = the_greed.FromFlatMapping(tmp_mapping) if score > best_score: best_score = score best_mapping = tmp_mapping @@ -2478,7 +2340,7 @@ def _lDDTGreedyFull(the_greed): something_happened = True mapping = best_mapping - score = the_greed.lDDTFromFlatMap(mapping) + score = the_greed.FromFlatMapping(mapping) if score > best_overall_score: best_overall_score = score best_overall_mapping = mapping @@ -2542,7 +2404,7 @@ def _lDDTGreedyBlock(the_greed, seed_size, blocks_per_chem_group): seed = dict(mapping) seed.update({s[0]: s[1]}) seed = the_greed.ExtendMapping(seed, max_ext = max_ext) - seed_lddt = the_greed.lDDTFromFlatMap(seed) + seed_lddt = the_greed.FromFlatMapping(seed) if seed_lddt > best_score: best_score = seed_lddt best_mapping = seed @@ -2558,7 +2420,7 @@ def _lDDTGreedyBlock(the_greed, seed_size, blocks_per_chem_group): best_mapping = None for seed in starting_blocks: seed = 
the_greed.ExtendMapping(seed) - seed_lddt = the_greed.lDDTFromFlatMap(seed) + seed_lddt = the_greed.FromFlatMapping(seed) if seed_lddt > best_lddt: best_lddt = seed_lddt best_mapping = seed diff --git a/modules/mol/alg/tests/CMakeLists.txt b/modules/mol/alg/tests/CMakeLists.txt index e1a1aaf827cc2e84a9f9d4ec174258ad41159f77..e21f0746812554f0b3a6e33145efbd7eb3d9e254 100644 --- a/modules/mol/alg/tests/CMakeLists.txt +++ b/modules/mol/alg/tests/CMakeLists.txt @@ -14,6 +14,7 @@ set(OST_MOL_ALG_UNIT_TESTS test_contact_score.py test_biounit.py test_ost_dockq.py + test_bblddt.py ) if (COMPOUND_LIB) diff --git a/modules/mol/alg/tests/test_bblddt.py b/modules/mol/alg/tests/test_bblddt.py new file mode 100644 index 0000000000000000000000000000000000000000..38be34c90b6436d0e73fa6a8bbbf3d3ac859dbdd --- /dev/null +++ b/modules/mol/alg/tests/test_bblddt.py @@ -0,0 +1,117 @@ +import unittest, os, sys +import ost +from ost import conop +from ost import io, mol, seq, settings +import time +# check if we can import: fails if numpy or scipy not available +try: + import numpy as np + from ost.mol.alg.bb_lddt import * + from ost.mol.alg.lddt import * + from ost.mol.alg.chain_mapping import * +except ImportError: + print("Failed to import bb_lddt.py. Happens when numpy or scipy "\ + "missing. Ignoring qsscore.py tests.") + sys.exit(0) + +def _LoadFile(file_name): + """Helper to avoid repeating input path over and over.""" + return io.LoadPDB(os.path.join('testfiles', file_name)) + +class TestBBlDDT(unittest.TestCase): + + def test_bblddtentity(self): + ent = _LoadFile("3l1p.1.pdb") + ent = BBlDDTEntity(ent) + self.assertEqual(len(ent.view.chains), 4) + self.assertEqual(ent.GetChain("A").GetName(), "A") + self.assertEqual(ent.GetChain("B").GetName(), "B") + self.assertEqual(ent.GetChain("C").GetName(), "C") + self.assertEqual(ent.GetChain("D").GetName(), "D") + self.assertRaises(Exception, ent.GetChain, "E") + self.assertEqual(ent.chain_names, ["A", "B", "C", "D"]) + self.assertEqual(ent.GetSequence("A"), "DMKALQKELEQFAKLLKQKRITLGYTQADVGLTLGVLFGKVFSQTTISRFEALQLSLKNMSKLRPLLEKWVEEADNNENLQEISKSVQARKRKRTSIENRVRWSLETMFLKSPKPSLQQITHIANQLGLEKDVVRVWFSNRRQKGKR") + self.assertEqual(ent.GetSequence("B"), "KALQKELEQFAKLLKQKRITLGYTQADVGLTLGVLFGKVFSQTTISRFEALQLSLKNMSKLRPLLEKWVEEADNNENLQEISKSQARKRKRTSIENRVRWSLETMFLKSPKPSLQQITHIANQLGLEKDVVRVWFSNRRQKGKRS") + self.assertEqual(ent.GetSequence("C"), "TCCACATTTGAAAGGCAAATGGA") + self.assertEqual(ent.GetSequence("D"), "ATCCATTTGCCTTTCAAATGTGG") + + # check for a couple of positions with manually extracted values + + # GLU + pos = ent.GetPos("B") + self.assertAlmostEqual(pos[5,0], -0.901, places=3) + self.assertAlmostEqual(pos[5,1], 28.167, places=3) + self.assertAlmostEqual(pos[5,2], 13.955, places=3) + + # GLY + pos = ent.GetPos("A") + self.assertAlmostEqual(pos[23,0], 17.563, places=3) + self.assertAlmostEqual(pos[23,1], -4.082, places=3) + self.assertAlmostEqual(pos[23,2], 29.005, places=3) + + # Cytosine + pos = ent.GetPos("C") + self.assertAlmostEqual(pos[4,0], 14.796, places=3) + self.assertAlmostEqual(pos[4,1], 24.653, places=3) + self.assertAlmostEqual(pos[4,2], 59.318, places=3) + + + # check pairwise dist, chain names are always sorted => + # A is rows, C is cols + dist_one = ent.PairDist("A", "C") + dist_two = ent.PairDist("C", "A") + self.assertTrue(np.array_equal(dist_one, dist_two)) + self.assertEqual(dist_one.shape[0], len(ent.GetSequence("A"))) + self.assertEqual(dist_one.shape[1], len(ent.GetSequence("C"))) + + # check some random distance between the Gly and 
Cytosine that we already + # checked above + self.assertAlmostEqual(dist_one[23,4], 41.86, places=2) + + # all chains interact with each other... but hey, check nevertheless + self.assertEqual(ent.interacting_chains, [("A", "B"), ("A", "C"), + ("A", "D"), ("B", "C"), + ("B", "D"), ("C", "D")]) + + def test_bb_lddt_scorer(self): + + target = _LoadFile("3l1p.1.pdb") + model = _LoadFile("3l1p.1_model.pdb") + + # we need to derive a chain mapping prior to scoring + mapper = ChainMapper(target) + res = mapper.GetRMSDMapping(model, strategy="greedy_iterative") + + # lets compare with lddt reference implementation + + reference_lddt_scorer = lDDTScorer(target, bb_only=True) + + # make alignments accessible by mdl seq name + alns = dict() + for aln in res.alns.values(): + mdl_seq = aln.GetSequence(1) + alns[mdl_seq.name] = aln + + # lDDT requires a flat mapping with mdl_ch as key and trg_ch as value + flat_mapping = res.GetFlatMapping(mdl_as_key=True) + lddt_chain_mapping = dict() + for mdl_ch, trg_ch in flat_mapping.items(): + if mdl_ch in alns: + lddt_chain_mapping[mdl_ch] = trg_ch + + reference_lddt_score = reference_lddt_scorer.lDDT(model, + chain_mapping = lddt_chain_mapping, + residue_mapping = alns, + check_resnames=False)[0] + + bb_lddt_scorer = BBlDDTScorer.FromMappingResult(res) + bb_lddt_score = bb_lddt_scorer.Score(res.mapping) + + self.assertAlmostEqual(reference_lddt_score, bb_lddt_score, places = 4) + +if __name__ == "__main__": + from ost import testutils + if testutils.DefaultCompoundLibIsSet(): + testutils.RunTests() + else: + print('No compound lib available. Ignoring test_bblddt.py tests.')
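
The new test exercises the complete workflow. For reference, a minimal usage
sketch of the scorer outside the unit test (file names are placeholders; the
default key direction of ``GetFlatMapping``, target chain names as keys, is
assumed from the ``mdl_as_key`` usage above)::

  from ost import io
  from ost.mol.alg.chain_mapping import ChainMapper
  from ost.mol.alg.bb_lddt import BBlDDTScorer

  trg = io.LoadPDB("target.pdb")
  mdl = io.LoadPDB("model.pdb")

  mapper = ChainMapper(trg)
  res = mapper.GetRMSDMapping(mdl, strategy="greedy_iterative")
  scorer = BBlDDTScorer.FromMappingResult(res)

  # score with the grouped mapping stored on the MappingResult
  print(scorer.Score(res.mapping))

  # equivalent call with a flat target -> model chain mapping
  flat = res.GetFlatMapping()
  print(scorer.FromFlatMapping(flat))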