diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py index d9d3c96d9d8c50361edb3ce91567cb86a7bfbc9a..7a7627509ac9109171fdbc8ef4c236d6c608ceeb 100644 --- a/modules/mol/alg/pymod/chain_mapping.py +++ b/modules/mol/alg/pymod/chain_mapping.py @@ -19,8 +19,6 @@ from ost import geom from ost.mol.alg import lddt from ost.mol.alg import qsscore -import time - def _CSel(ent, cnames): """ Returns view with specified chains @@ -788,8 +786,8 @@ class ChainMapper: def GetlDDTMapping(self, model, inclusion_radius=15.0, thresholds=[0.5, 1.0, 2.0, 4.0], strategy="naive", - steep_opt_rate = None, full_n_mdl_chains = None, - block_seed_size = 5, block_blocks_per_chem_group = 5, + steep_opt_rate = None, block_seed_size = 5, + block_blocks_per_chem_group = 5, chem_mapping_result = None): """ Identify chain mapping by optimizing lDDT score @@ -816,9 +814,7 @@ class ChainMapper: * **greedy_full**: try multiple seeds for greedy extension, i.e. try all ref/mdl chain combinations within the respective chem groups and - retain the mapping leading to the best lDDT. Optionally, you can - reduce the number of mdl chains per ref chain to the - *full_n_mdl_chains* best scoring ones. + retain the mapping leading to the best lDDT. * **greedy_block**: try multiple seeds for greedy extension, i.e. try all ref/mdl chain combinations within the respective chem groups and @@ -846,10 +842,6 @@ class ChainMapper: swaps that improve lDDT score. Iteration stops as soon as no improvement can be achieved anymore. :type steep_opt_rate: :class:`int` - :param full_n_mdl_chains: Param for *greedy_full* strategy - Max number of - mdl chains that are tried per ref chain. The - default (None) tries all of them. - :type full_n_mdl_chains: :class:`int` :param block_seed_size: Param for *greedy_block* strategy - Initial seeds are extended by that number of chains. :type block_seed_size: :class:`int` @@ -912,7 +904,7 @@ class ChainMapper: if strategy == "greedy_fast": mapping = _lDDTGreedyFast(the_greed) elif strategy == "greedy_full": - mapping = _lDDTGreedyFull(the_greed, full_n_mdl_chains) + mapping = _lDDTGreedyFull(the_greed) elif strategy == "greedy_block": mapping = _lDDTGreedyBlock(the_greed, block_seed_size, block_blocks_per_chem_group) @@ -933,17 +925,14 @@ class ChainMapper: def GetQSScoreMapping(self, model, contact_d = 12.0, strategy = "naive", - full_n_mdl_chains = None, block_seed_size = 5, - block_blocks_per_chem_group = 5, + block_seed_size = 5, block_blocks_per_chem_group = 5, steep_opt_rate = None, chem_mapping_result = None, greedy_prune_contact_map = False): """ Identify chain mapping based on QSScore Scoring is based on CA/C3' positions which are present in all chains of a :attr:`chem_groups` as well as the *model* chains which are mapped to - that respective chem group. QS score is not defined for single chains. - The greedy strategies that require to identify starting seeds thus - often rely on single chain lDDTs. + that respective chem group. The following strategies are available: @@ -953,13 +942,11 @@ class ChainMapper: * **greedy_fast**: perform all vs. all single chain lDDTs within the respective ref/mdl chem groups. The mapping with highest number of conserved contacts is selected as seed for greedy extension. - Extension is based on QS score. + Extension is based on QS-score. * **greedy_full**: try multiple seeds for greedy extension, i.e. try all ref/mdl chain combinations within the respective chem groups and - retain the mapping leading to the best QS score. Optionally, you can - reduce the number of mdl chains per ref chain to the - *full_n_mdl_chains* best scoring with respect to single chain lDDT. + retain the mapping leading to the best QS-score. * **greedy_block**: try multiple seeds for greedy extension, i.e. try all ref/mdl chain combinations within the respective chem groups and @@ -1040,7 +1027,7 @@ class ChainMapper: if strategy == "greedy_fast": mapping = _QSScoreGreedyFast(the_greed) elif strategy == "greedy_full": - mapping = _QSScoreGreedyFull(the_greed, full_n_mdl_chains) + mapping = _QSScoreGreedyFull(the_greed) elif strategy == "greedy_block": mapping = _QSScoreGreedyBlock(the_greed, block_seed_size, block_blocks_per_chem_group) @@ -2242,12 +2229,6 @@ class _lDDTGreedySearcher(_lDDTDecomposer): if len(mapping) == 0: raise RuntimError("Mapping must contain a starting point") - for ref_ch, mdl_ch in mapping.items(): - assert(ref_ch in self.ref_ch_group_mapper) - assert(mdl_ch in self.mdl_ch_group_mapper) - assert(self.ref_ch_group_mapper[ref_ch] == \ - self.mdl_ch_group_mapper[mdl_ch]) - # Ref chains onto which we can map. The algorithm starts with a mapping # on ref_ch. From there we can start to expand to connected neighbors. # All neighbors that we can reach from the already mapped chains are @@ -2417,6 +2398,19 @@ def _lDDTNaive(trg, mdl, inclusion_radius, thresholds, chem_groups, return (best_mapping, best_lddt) +def _GetSeeds(ref_chem_groups, mdl_chem_groups, mapped_ref_chains = set(), + mapped_mdl_chains = set()): + seeds = list() + for ref_chains, mdl_chains in zip(ref_chem_groups, + mdl_chem_groups): + for ref_ch in ref_chains: + if ref_ch not in mapped_ref_chains: + for mdl_ch in mdl_chains: + if mdl_ch not in mapped_mdl_chains: + seeds.append((ref_ch, mdl_ch)) + return seeds + + def _lDDTGreedyFast(the_greed): something_happened = True @@ -2424,21 +2418,18 @@ def _lDDTGreedyFast(the_greed): while something_happened: something_happened = False + seeds = _GetSeeds(the_greed.ref_chem_groups, + the_greed.mdl_chem_groups, + mapped_ref_chains = set(mapping.keys()), + mapped_mdl_chains = set(mapping.values())) # search for best scoring starting point n_best = 0 best_seed = None - mapped_ref_chains = set(mapping.keys()) - mapped_mdl_chains = set(mapping.values()) - for ref_chains, mdl_chains in zip(the_greed.ref_chem_groups, - the_greed.mdl_chem_groups): - for ref_ch in ref_chains: - if ref_ch not in mapped_ref_chains: - for mdl_ch in mdl_chains: - if mdl_ch not in mapped_mdl_chains: - n = the_greed.SCCounts(ref_ch, mdl_ch) - if n > n_best: - n_best = n - best_seed = (ref_ch, mdl_ch) + for seed in seeds: + n = the_greed.SCCounts(seed[0], seed[1]) + if n > n_best: + n_best = n + best_seed = seed if n_best == 0: break # no proper seed found anymore... # add seed to mapping and start the greed @@ -2446,7 +2437,6 @@ def _lDDTGreedyFast(the_greed): mapping = the_greed.ExtendMapping(mapping) something_happened = True - # translate mapping format and return final_mapping = list() for ref_chains in the_greed.ref_chem_groups: @@ -2461,53 +2451,49 @@ def _lDDTGreedyFast(the_greed): return final_mapping -def _lDDTGreedyFull(the_greed, n_mdl_chains): +def _lDDTGreedyFull(the_greed): """ Uses each reference chain as starting point for expansion - - However, not all mdl chain are mapped onto these reference chains, - that's controlled by *n_mdl_chains* """ - if n_mdl_chains is not None and n_mdl_chains < 1: - raise RuntimeError("n_mdl_chains must be None or >= 1") - - something_happened = True - mapping = dict() + seeds = _GetSeeds(the_greed.ref_chem_groups, the_greed.mdl_chem_groups) + best_overall_score = -1.0 + best_overall_mapping = dict() - while something_happened: - something_happened = False - # Try all possible starting points and keep the one giving the best lDDT - best_lddt = 0.0 - best_mapping = None - mapped_ref_chains = set(mapping.keys()) - mapped_mdl_chains = set(mapping.values()) - for ref_chains, mdl_chains in zip(the_greed.ref_chem_groups, - the_greed.mdl_chem_groups): - for ref_ch in ref_chains: - if ref_ch not in mapped_ref_chains: - seeds = list() - for mdl_ch in mdl_chains: - if mdl_ch not in mapped_mdl_chains: - seeds.append((ref_ch, mdl_ch)) - if n_mdl_chains is not None and n_mdl_chains < len(seeds): - counts = [the_greed.SCCounts(s[0], s[1]) for s in seeds] - tmp = [(a,b) for a,b in zip(counts, seeds)] - tmp.sort(reverse=True) - seeds = [item[1] for item in tmp[:n_mdl_chains]] - for seed in seeds: - tmp_mapping = dict(mapping) - tmp_mapping[seed[0]] = seed[1] - tmp_mapping = the_greed.ExtendMapping(tmp_mapping) - tmp_lddt = the_greed.lDDTFromFlatMap(tmp_mapping) - if tmp_lddt > best_lddt: - best_lddt = tmp_lddt - best_mapping = tmp_mapping + for seed in seeds: - if best_lddt == 0.0: - break # no proper mapping found anymore... + # do initial extension + mapping = the_greed.ExtendMapping({seed[0]: seed[1]}) + # repeat the process until we have a full mapping something_happened = True - mapping = best_mapping + while something_happened: + something_happened = False + remnant_seeds = _GetSeeds(the_greed.ref_chem_groups, + the_greed.mdl_chem_groups, + mapped_ref_chains = set(mapping.keys()), + mapped_mdl_chains = set(mapping.values())) + if len(remnant_seeds) > 0: + # still more mapping to be done + best_score = -1.0 + best_mapping = None + for remnant_seed in remnant_seeds: + tmp_mapping = dict(mapping) + tmp_mapping[remnant_seed[0]] = remnant_seed[1] + tmp_mapping = the_greed.ExtendMapping(tmp_mapping) + score = the_greed.lDDTFromFlatMap(tmp_mapping) + if score > best_score: + best_score = score + best_mapping = tmp_mapping + if best_mapping is not None: + something_happened = True + mapping = best_mapping + + score = the_greed.lDDTFromFlatMap(mapping) + if score > best_overall_score: + best_overall_score = score + best_overall_mapping = mapping + + mapping = best_overall_mapping # translate mapping format and return final_mapping = list() @@ -2543,7 +2529,6 @@ def _lDDTGreedyBlock(the_greed, seed_size, blocks_per_chem_group): mdl_chem_groups = copy.deepcopy(the_greed.mdl_chem_groups) mapping = dict() - something_happened = True while something_happened: something_happened = False @@ -2694,12 +2679,6 @@ class _QSScoreGreedySearcher(qsscore.QSScorer): if len(mapping) == 0: raise RuntimError("Mapping must contain a starting point") - for ref_ch, mdl_ch in mapping.items(): - assert(ref_ch in self.ref_ch_group_mapper) - assert(mdl_ch in self.mdl_ch_group_mapper) - assert(self.ref_ch_group_mapper[ref_ch] == \ - self.mdl_ch_group_mapper[mdl_ch]) - # Ref chains onto which we can map. The algorithm starts with a mapping # on ref_ch. From there we can start to expand to connected neighbors. # All neighbors that we can reach from the already mapped chains are @@ -2859,19 +2838,6 @@ def _QSScoreNaive(trg, mdl, chem_groups, chem_mapping, ref_mdl_alns, contact_d, return (best_mapping, best_score) -def _GetSeeds(ref_chem_groups, mdl_chem_groups, mapped_ref_chains = set(), - mapped_mdl_chains = set()): - seeds = list() - for ref_chains, mdl_chains in zip(ref_chem_groups, - mdl_chem_groups): - for ref_ch in ref_chains: - if ref_ch not in mapped_ref_chains: - for mdl_ch in mdl_chains: - if mdl_ch not in mapped_mdl_chains: - seeds.append((ref_ch, mdl_ch)) - return seeds - - def _QSScoreGreedyFast(the_greed): something_happened = True @@ -2911,16 +2877,10 @@ def _QSScoreGreedyFast(the_greed): return final_mapping -def _QSScoreGreedyFull(the_greed, n_mdl_chains): +def _QSScoreGreedyFull(the_greed): """ Uses each reference chain as starting point for expansion - - However, not all mdl chain are mapped onto these reference chains, - that's controlled by *n_mdl_chains* """ - if n_mdl_chains is not None and n_mdl_chains < 1: - raise RuntimeError("n_mdl_chains must be None or >= 1") - seeds = _GetSeeds(the_greed.ref_chem_groups, the_greed.mdl_chem_groups) best_overall_score = -1.0 best_overall_mapping = dict()