From fa58ebaefe6bcfe9a7a99642a76a4fc810a08845 Mon Sep 17 00:00:00 2001 From: Gabriel Studer <gabriel.studer@unibas.ch> Date: Wed, 19 Jul 2023 08:52:15 +0200 Subject: [PATCH] Chain Mapping: QSscore mapping code cleanup and speedups for large structures --- modules/mol/alg/pymod/chain_mapping.py | 172 +++++++++++++++---------- 1 file changed, 101 insertions(+), 71 deletions(-) diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py index c4a017f3d..d9d3c96d9 100644 --- a/modules/mol/alg/pymod/chain_mapping.py +++ b/modules/mol/alg/pymod/chain_mapping.py @@ -19,6 +19,8 @@ from ost import geom from ost.mol.alg import lddt from ost.mol.alg import qsscore +import time + def _CSel(ent, cnames): """ Returns view with specified chains @@ -2710,14 +2712,7 @@ class _QSScoreGreedySearcher(qsscore.QSScorer): for ref_ch in mapping.keys(): map_targets.discard(ref_ch) - # same for model - mdl_map_targets = set() - for mdl_ch in mapping.values(): - mdl_map_targets.update(self.mdl_neighbors[mdl_ch]) - for mdl_ch in mapping.values(): - mdl_map_targets.discard(mdl_ch) - - if len(map_targets) == 0 or len(mdl_map_targets) == 0: + if len(map_targets) == 0: return mapping # nothing to extend # keep track of what model chains are not yet mapped for each chem group @@ -2741,19 +2736,49 @@ class _QSScoreGreedySearcher(qsscore.QSScorer): if max_ext is not None and len(newly_mapped_ref_chains) >= max_ext: break + score_result = self.FromFlatMapping(mapping) + old_score = score_result.QS_global + nominator = score_result.weighted_scores + denominator = score_result.weight_sum + score_result.weight_extra_all + + max_diff = 0.0 max_mapping = None - best_score = self.FromFlatMapping(mapping).QS_global for ref_ch in map_targets: chem_group_idx = self.ref_ch_group_mapper[ref_ch] for mdl_ch in free_mdl_chains[chem_group_idx]: - if mdl_ch in mdl_map_targets: - new_mapping = dict(mapping) - new_mapping[ref_ch] = mdl_ch - new_score = self.FromFlatMapping(new_mapping).QS_global - if new_score > best_score: - best_score = new_score + # we're not computing full QS-score here, we directly hack + # into the QS-score formula to compute a diff + nominator_diff = 0.0 + denominator_diff = 0.0 + for neighbor in self.neighbors[ref_ch]: + if neighbor in mapping and mapping[neighbor] in \ + self.mdl_neighbors[mdl_ch]: + # it's a newly added interface if (ref_ch, mdl_ch) + # are added to mapping + int1 = (ref_ch, neighbor) + int2 = (mdl_ch, mapping[neighbor]) + a, b, c, d = self._MappedInterfaceScores(int1, int2) + nominator_diff += a # weighted_scores + denominator_diff += b # weight_sum + denominator_diff += d # weight_extra_all + # the respective interface penalties are subtracted + # from denominator + denominator_diff -= self._InterfacePenalty1(int1) + denominator_diff -= self._InterfacePenalty2(int2) + + if nominator_diff > 0: + # Only accept a new solution if its actually connected + # i.e. nominator_diff > 0. + new_nominator = nominator + nominator_diff + new_denominator = denominator + denominator_diff + new_score = 0.0 + if new_denominator != 0.0: + new_score = new_nominator/new_denominator + diff = new_score - old_score + if diff > max_diff: + max_diff = diff max_mapping = (ref_ch, mdl_ch) - + if max_mapping is not None: something_happened = True # assign new found mapping @@ -2765,13 +2790,8 @@ class _QSScoreGreedySearcher(qsscore.QSScorer): if neighbor not in mapping: map_targets.add(neighbor) - for neighbor in self.mdl_neighbors[max_mapping[1]]: - if neighbor not in mapping.values(): - mdl_map_targets.add(neighbor) - - # remove chains from map targets + # remove the ref chain from map targets map_targets.remove(max_mapping[0]) - mdl_map_targets.remove(max_mapping[1]) # remove the mdl chain from free_mdl_chains - its taken... chem_group_idx = self.ref_ch_group_mapper[max_mapping[0]] @@ -2839,28 +2859,37 @@ def _QSScoreNaive(trg, mdl, chem_groups, chem_mapping, ref_mdl_alns, contact_d, return (best_mapping, best_score) +def _GetSeeds(ref_chem_groups, mdl_chem_groups, mapped_ref_chains = set(), + mapped_mdl_chains = set()): + seeds = list() + for ref_chains, mdl_chains in zip(ref_chem_groups, + mdl_chem_groups): + for ref_ch in ref_chains: + if ref_ch not in mapped_ref_chains: + for mdl_ch in mdl_chains: + if mdl_ch not in mapped_mdl_chains: + seeds.append((ref_ch, mdl_ch)) + return seeds + + def _QSScoreGreedyFast(the_greed): something_happened = True mapping = dict() - while something_happened: something_happened = False # search for best scoring starting point, we're using lDDT here n_best = 0 best_seed = None - mapped_ref_chains = set(mapping.keys()) - mapped_mdl_chains = set(mapping.values()) - for ref_chains, mdl_chains in zip(the_greed.ref_chem_groups, - the_greed.mdl_chem_groups): - for ref_ch in ref_chains: - if ref_ch not in mapped_ref_chains: - for mdl_ch in mdl_chains: - if mdl_ch not in mapped_mdl_chains: - n = the_greed.SCCounts(ref_ch, mdl_ch) - if n > n_best: - n_best = n - best_seed = (ref_ch, mdl_ch) + seeds = _GetSeeds(the_greed.ref_chem_groups, + the_greed.mdl_chem_groups, + mapped_ref_chains = set(mapping.keys()), + mapped_mdl_chains = set(mapping.values())) + for seed in seeds: + n = the_greed.SCCounts(seed[0], seed[1]) + if n > n_best: + n_best = n + best_seed = seed if n_best == 0: break # no proper seed found anymore... # add seed to mapping and start the greed @@ -2868,7 +2897,6 @@ def _QSScoreGreedyFast(the_greed): mapping = the_greed.ExtendMapping(mapping) something_happened = True - # translate mapping format and return final_mapping = list() for ref_chains in the_greed.ref_chem_groups: @@ -2893,43 +2921,45 @@ def _QSScoreGreedyFull(the_greed, n_mdl_chains): if n_mdl_chains is not None and n_mdl_chains < 1: raise RuntimeError("n_mdl_chains must be None or >= 1") - something_happened = True - mapping = dict() + seeds = _GetSeeds(the_greed.ref_chem_groups, the_greed.mdl_chem_groups) + best_overall_score = -1.0 + best_overall_mapping = dict() - while something_happened: - something_happened = False - # Try all possible starting points and keep the one giving the best QS score - best_score = -1.0 - best_mapping = None - mapped_ref_chains = set(mapping.keys()) - mapped_mdl_chains = set(mapping.values()) - for ref_chains, mdl_chains in zip(the_greed.ref_chem_groups, - the_greed.mdl_chem_groups): - for ref_ch in ref_chains: - if ref_ch not in mapped_ref_chains: - seeds = list() - for mdl_ch in mdl_chains: - if mdl_ch not in mapped_mdl_chains: - seeds.append((ref_ch, mdl_ch)) - if n_mdl_chains is not None and n_mdl_chains < len(seeds): - counts = [the_greed.SCCounts(s[0], s[1]) for s in seeds] - tmp = [(a,b) for a,b in zip(counts, seeds)] - tmp.sort(reverse=True) - seeds = [item[1] for item in tmp[:n_mdl_chains]] - for seed in seeds: - tmp_mapping = dict(mapping) - tmp_mapping[seed[0]] = seed[1] - tmp_mapping = the_greed.ExtendMapping(tmp_mapping) - score_result = the_greed.FromFlatMapping(tmp_mapping) - if score_result.QS_global > best_score: - best_score = score_result.QS_global - best_mapping = tmp_mapping + for seed in seeds: - if best_mapping is not None and len(best_mapping) > len(mapping): - # this even accepts extensions that lead to no increase in QS-score - # at least they make sense from an lDDT perspective - something_happened = True - mapping = best_mapping + # do initial extension + mapping = the_greed.ExtendMapping({seed[0]: seed[1]}) + + # repeat the process until we have a full mapping + something_happened = True + while something_happened: + something_happened = False + remnant_seeds = _GetSeeds(the_greed.ref_chem_groups, + the_greed.mdl_chem_groups, + mapped_ref_chains = set(mapping.keys()), + mapped_mdl_chains = set(mapping.values())) + if len(remnant_seeds) > 0: + # still more mapping to be done + best_score = -1.0 + best_mapping = None + for remnant_seed in remnant_seeds: + tmp_mapping = dict(mapping) + tmp_mapping[remnant_seed[0]] = remnant_seed[1] + tmp_mapping = the_greed.ExtendMapping(tmp_mapping) + score_result = the_greed.FromFlatMapping(tmp_mapping) + if score_result.QS_global > best_score: + best_score = score_result.QS_global + best_mapping = tmp_mapping + if best_mapping is not None: + something_happened = True + mapping = best_mapping + + score_result = the_greed.FromFlatMapping(mapping) + if score_result.QS_global > best_overall_score: + best_overall_score = score_result.QS_global + best_overall_mapping = mapping + + mapping = best_overall_mapping # translate mapping format and return final_mapping = list() -- GitLab