From e2c3585a9595ce82bae4bf9249c74355f43217b4 Mon Sep 17 00:00:00 2001 From: Gabriel Studer <gabriel.studer@unibas.ch> Date: Tue, 6 Sep 2022 14:56:50 +0200 Subject: [PATCH] implement QS score based greedy chain mapping strategies --- modules/mol/alg/pymod/chain_mapping.py | 687 ++++++++++++++++---- modules/mol/alg/pymod/qsscore.py | 32 +- modules/mol/alg/tests/test_chain_mapping.py | 25 +- 3 files changed, 618 insertions(+), 126 deletions(-) diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py index d4ab0e9db..7573286c8 100644 --- a/modules/mol/alg/pymod/chain_mapping.py +++ b/modules/mol/alg/pymod/chain_mapping.py @@ -690,11 +690,11 @@ class ChainMapper: self.n_max_naive) else: # its one of the greedy strategies - setup greedy searcher - the_greed = _GreedySearcher(self.target, mdl, self.chem_groups, - chem_mapping, ref_mdl_alns, - inclusion_radius=inclusion_radius, - thresholds=thresholds, - steep_opt_rate=steep_opt_rate) + the_greed = _lDDTGreedySearcher(self.target, mdl, self.chem_groups, + chem_mapping, ref_mdl_alns, + inclusion_radius=inclusion_radius, + thresholds=thresholds, + steep_opt_rate=steep_opt_rate) if strategy == "greedy_fast": mapping = _lDDTGreedyFast(the_greed) elif strategy == "greedy_full": @@ -716,7 +716,106 @@ class ChainMapper: alns) - def GetRigidMapping(self, model, strategy = "greedy_single", + def GetQSScoreMapping(self, model, contact_d = 12.0, strategy = "naive", + full_n_mdl_chains = None, block_seed_size = 5, + block_blocks_per_chem_group = 5, + steep_opt_rate = None): + """ Identify chain mapping based on QSScore + + Scoring is based on CA/C3' positions which are present in all chains of + a :attr:`chem_groups` as well as the *model* chains which are mapped to + that respective chem group. QS score is not defined for single chains. + The greedy strategies that require to identify starting seeds thus + often rely on single chain lDDTs. 
+ + The following strategies are available: + + * **naive**: Naively iterate all possible mappings and return best based + on QS score. + + * **greedy_fast**: perform all vs. all single chain lDDTs within the + respective ref/mdl chem groups. The mapping with highest number of + conserved contacts is selected as seed for greedy extension. + Extension is based on QS score. + + * **greedy_full**: try multiple seeds for greedy extension, i.e. try + all ref/mdl chain combinations within the respective chem groups and + retain the mapping leading to the best QS score. Optionally, you can + reduce the number of mdl chains per ref chain to the + *full_n_mdl_chains* best scoring with respect to single chain lDDT. + + * **greedy_block**: try multiple seeds for greedy extension, i.e. try + all ref/mdl chain combinations within the respective chem groups and + compute single chain lDDTs. The *block_blocks_per_chem_group* best + scoring ones are extended by *block_seed_size* chains and the block with + the best QS score is exhaustively extended. 
+ + :param model: Model to map + :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` + :param contact_d: Max distance between two residues to be considered as + contact in qs scoring + :type contact_d: :class:`float` + :param strategy: Strategy for sampling, must be in ["naive"] + :type strategy: :class:`str` + :returns: A :class:`MappingResult` + """ + + strategies = ["naive", "greedy_fast", "greedy_full", "greedy_block"] + if strategy not in strategies: + raise RuntimeError(f"strategy must be {strategies}") + + chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model) + ref_mdl_alns = _GetRefMdlAlns(self.chem_groups, + self.chem_group_alignments, + chem_mapping, + chem_group_alns) + + # check for the simplest case + one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping) + if one_to_one is not None: + alns = dict() + for ref_group, mdl_group in zip(self.chem_groups, one_to_one): + for ref_ch, mdl_ch in zip(ref_group, mdl_group): + if ref_ch is not None and mdl_ch is not None: + aln = ref_mdl_alns[(ref_ch, mdl_ch)] + aln.AttachView(0, self.target.Select(f"cname={ref_ch}")) + aln.AttachView(1, mdl.Select(f"cname={mdl_ch}")) + alns[(ref_ch, mdl_ch)] = aln + return MappingResult(self.target, mdl, self.chem_groups, one_to_one, + alns) + + if strategy == "naive": + mapping = _QSScoreNaive(self.target, mdl, self.chem_groups, + chem_mapping, ref_mdl_alns, contact_d, + self.n_max_naive) + else: + # its one of the greedy strategies - setup greedy searcher + + the_greed = _QSScoreGreedySearcher(self.target, mdl, self.chem_groups, + chem_mapping, ref_mdl_alns, + contact_d = contact_d, + steep_opt_rate=steep_opt_rate) + if strategy == "greedy_fast": + mapping = _QSScoreGreedyFast(the_greed) + elif strategy == "greedy_full": + mapping = _QSScoreGreedyFull(the_greed, full_n_mdl_chains) + elif strategy == "greedy_block": + mapping = _QSScoreGreedyBlock(the_greed, block_seed_size, + block_blocks_per_chem_group) + + alns = dict() + for 
ref_group, mdl_group in zip(self.chem_groups, mapping): + for ref_ch, mdl_ch in zip(ref_group, mdl_group): + if ref_ch is not None and mdl_ch is not None: + aln = ref_mdl_alns[(ref_ch, mdl_ch)] + aln.AttachView(0, self.target.Select(f"cname={ref_ch}")) + aln.AttachView(1, mdl.Select(f"cname={mdl_ch}")) + alns[(ref_ch, mdl_ch)] = aln + + return MappingResult(self.target, mdl, self.chem_groups, mapping, + alns) + + def GetRigidMapping(self, model, strategy = "greedy_single_gdtts", single_chain_gdtts_thresh=0.4, subsampling=None, first_complete=False, iterative_superposition=False): """Identify chain mapping based on rigid superposition @@ -731,14 +830,14 @@ class ChainMapper: There are three extension strategies: - * **greedy_single**: Iteratively add the model/target chain pair that - adds the most conserved contacts based on the GDT-TS metric + * **greedy_single_gdtts**: Iteratively add the model/target chain pair + that adds the most conserved contacts based on the GDT-TS metric (Number of CA/C3' atoms within [8, 4, 2, 1] Angstrom). The mapping with highest GDT-TS score is returned. However, that mapping is not guaranteed to be complete (see *single_chain_gdtts_thresh*). - * **greedy_iterative**: Same as single except that the transformation - gets updated with each added chain pair. + * **greedy_iterative_gdtts**: Same as single except that the + transformation gets updated with each added chain pair. * **greedy_iterative_rmsd**: Same as iterative, i.e. the transformation gets updated with each added chain pair. 
However, @@ -776,7 +875,8 @@ class ChainMapper: :returns: A :class:`MappingResult` """ - strategies = ["greedy_single", "greedy_iterative", "greedy_iterative_rmsd"] + strategies = ["greedy_single_gdtts", "greedy_iterative_gdtts", + "greedy_iterative_rmsd"] if strategy not in strategies: raise RuntimeError(f"strategy must be {strategies}") @@ -825,22 +925,24 @@ class ChainMapper: initial_transforms.append(transform) initial_mappings.append((t,m)) - - if strategy == "greedy_single": - mapping = _SingleRigid(initial_transforms, initial_mappings, - self.chem_groups, chem_mapping, - trg_group_pos, mdl_group_pos, - single_chain_gdtts_thresh, - iterative_superposition, first_complete, - len(self.target.chains), len(mdl.chains)) - - elif strategy == "greedy_iterative": - mapping = _IterativeRigid(initial_transforms, initial_mappings, - self.chem_groups, chem_mapping, - trg_group_pos, mdl_group_pos, - single_chain_gdtts_thresh, - iterative_superposition, first_complete, - len(self.target.chains), len(mdl.chains)) + if strategy == "greedy_single_gdtts": + mapping = _SingleRigidGDTTS(initial_transforms, initial_mappings, + self.chem_groups, chem_mapping, + trg_group_pos, mdl_group_pos, + single_chain_gdtts_thresh, + iterative_superposition, first_complete, + len(self.target.chains), + len(mdl.chains)) + + elif strategy == "greedy_iterative_gdtts": + mapping = _IterativeRigidGDTTS(initial_transforms, initial_mappings, + self.chem_groups, chem_mapping, + trg_group_pos, mdl_group_pos, + single_chain_gdtts_thresh, + iterative_superposition, + first_complete, + len(self.target.chains), + len(mdl.chains)) elif strategy == "greedy_iterative_rmsd": mapping = _IterativeRigidRMSD(initial_transforms, initial_mappings, @@ -872,69 +974,6 @@ class ChainMapper: alns) - def GetQSScoreMapping(self, model, contact_d = 12.0, strategy = "naive"): - """ Identify chain mapping based on QSScore - - Scoring is based on CA/C3' positions which are present in all chains of - a :attr:`chem_groups` as 
well as the *model* chains which are mapped to - that respective chem group. - - There is currently one sampling strategy: - - * **naive**: Naively iterate all possible mappings and return best based - on QS score. - - :param model: Model to map - :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` - :param contact_d: Max distance between two residues to be considered as - contact in qs scoring - :type contact_d: :class:`float` - :param strategy: Strategy for sampling, must be in ["naive"] - :type strategy: :class:`str` - :returns: A :class:`MappingResult` - """ - - strategies = ["naive"] - if strategy not in strategies: - raise RuntimeError(f"strategy must be {strategies}") - - chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model) - ref_mdl_alns = _GetRefMdlAlns(self.chem_groups, - self.chem_group_alignments, - chem_mapping, - chem_group_alns) - - # check for the simplest case - one_to_one = _CheckOneToOneMapping(self.chem_groups, chem_mapping) - if one_to_one is not None: - alns = dict() - for ref_group, mdl_group in zip(self.chem_groups, one_to_one): - for ref_ch, mdl_ch in zip(ref_group, mdl_group): - if ref_ch is not None and mdl_ch is not None: - aln = ref_mdl_alns[(ref_ch, mdl_ch)] - aln.AttachView(0, self.target.Select(f"cname={ref_ch}")) - aln.AttachView(1, mdl.Select(f"cname={mdl_ch}")) - alns[(ref_ch, mdl_ch)] = aln - return MappingResult(self.target, mdl, self.chem_groups, one_to_one, - alns) - - if strategy == "naive": - mapping = _QSScoreNaive(self.target, mdl, self.chem_groups, - chem_mapping, ref_mdl_alns, contact_d, - self.n_max_naive) - - alns = dict() - for ref_group, mdl_group in zip(self.chem_groups, mapping): - for ref_ch, mdl_ch in zip(ref_group, mdl_group): - if ref_ch is not None and mdl_ch is not None: - aln = ref_mdl_alns[(ref_ch, mdl_ch)] - aln.AttachView(0, self.target.Select(f"cname={ref_ch}")) - aln.AttachView(1, mdl.Select(f"cname={mdl_ch}")) - alns[(ref_ch, mdl_ch)] = aln - - return 
MappingResult(self.target, mdl, self.chem_groups, mapping, - alns) - def GetRepr(self, substructure, model, topn=1, inclusion_radius=15.0, thresholds=[0.5, 1.0, 2.0, 4.0], bb_only=False, only_interchain=False): @@ -1742,7 +1781,7 @@ class _lDDTDecomposer: self.interface_cache[k2] = conserved return self.interface_cache[k1] -class _GreedySearcher(_lDDTDecomposer): +class _lDDTGreedySearcher(_lDDTDecomposer): def __init__(self, ref, mdl, ref_chem_groups, mdl_chem_groups, ref_mdl_alns, inclusion_radius = 15.0, thresholds = [0.5, 1.0, 2.0, 4.0], @@ -2157,10 +2196,449 @@ def _lDDTGreedyBlock(the_greed, seed_size, blocks_per_chem_group): return final_mapping -def _SingleRigid(initial_transforms, initial_mappings, chem_groups, - chem_mapping, trg_group_pos, mdl_group_pos, - single_chain_gdtts_thresh, iterative_superposition, - first_complete, n_trg_chains, n_mdl_chains): +class _QSScoreGreedySearcher(qsscore.QSScorer): + def __init__(self, ref, mdl, ref_chem_groups, mdl_chem_groups, + ref_mdl_alns, contact_d = 12.0, + steep_opt_rate = None): + """ Greedy extension of already existing but incomplete chain mappings + """ + super().__init__(ref, ref_chem_groups, mdl, ref_mdl_alns, + contact_d = contact_d) + self.ref = ref + self.mdl = mdl + self.ref_mdl_alns = ref_mdl_alns + self.steep_opt_rate = steep_opt_rate + + self.neighbors = {k: set() for k in self.qsent1.chain_names} + for p in self.qsent1.interacting_chains: + self.neighbors[p[0]].add(p[1]) + self.neighbors[p[1]].add(p[0]) + + self.mdl_neighbors = {k: set() for k in self.qsent2.chain_names} + for p in self.qsent2.interacting_chains: + self.mdl_neighbors[p[0]].add(p[1]) + self.mdl_neighbors[p[1]].add(p[0]) + + assert(len(ref_chem_groups) == len(mdl_chem_groups)) + self.ref_chem_groups = ref_chem_groups + self.mdl_chem_groups = mdl_chem_groups + self.ref_ch_group_mapper = dict() + self.mdl_ch_group_mapper = dict() + for g_idx, (ref_g, mdl_g) in enumerate(zip(ref_chem_groups, + mdl_chem_groups)): + for ch in ref_g: + 
self.ref_ch_group_mapper[ch] = g_idx + for ch in mdl_g: + self.mdl_ch_group_mapper[ch] = g_idx + + # cache for lDDT based single chain conserved contacts + # used to identify starting points for further extension by QS score + # key: tuple (ref_ch, mdl_ch) value: number of conserved contacts + self.single_chain_scorer = dict() + self.single_chain_cache = dict() + for ch in self.ref.chains: + single_chain_ref = self.ref.Select(f"cname={ch.GetName()}") + self.single_chain_scorer[ch.GetName()] = \ + lddt.lDDTScorer(single_chain_ref, bb_only = True) + + def SCCounts(self, ref_ch, mdl_ch): + if not (ref_ch, mdl_ch) in self.single_chain_cache: + alns = dict() + alns[mdl_ch] = self.ref_mdl_alns[(ref_ch, mdl_ch)] + mdl_sel = self.mdl.Select(f"cname={mdl_ch}") + s = self.single_chain_scorer[ref_ch] + _,_,_,conserved,_,_,_ = s.lDDT(mdl_sel, + residue_mapping=alns, + return_dist_test=True, + no_interchain=True, + chain_mapping={mdl_ch: ref_ch}, + check_resnames=False) + self.single_chain_cache[(ref_ch, mdl_ch)] = conserved + return self.single_chain_cache[(ref_ch, mdl_ch)] + + def ExtendMapping(self, mapping, max_ext = None): + + if len(mapping) == 0: + raise RuntimError("Mapping must contain a starting point") + + for ref_ch, mdl_ch in mapping.items(): + assert(ref_ch in self.ref_ch_group_mapper) + assert(mdl_ch in self.mdl_ch_group_mapper) + assert(self.ref_ch_group_mapper[ref_ch] == \ + self.mdl_ch_group_mapper[mdl_ch]) + + # Ref chains onto which we can map. The algorithm starts with a mapping + # on ref_ch. From there we can start to expand to connected neighbors. 
+ # All neighbors that we can reach from the already mapped chains are + # stored in this set which will be updated during runtime + map_targets = set() + for ref_ch in mapping.keys(): + map_targets.update(self.neighbors[ref_ch]) + + # remove the already mapped chains + for ref_ch in mapping.keys(): + map_targets.discard(ref_ch) + + if len(map_targets) == 0: + return mapping # nothing to extend + + # keep track of what model chains are not yet mapped for each chem group + free_mdl_chains = list() + for chem_group in self.mdl_chem_groups: + tmp = [x for x in chem_group if x not in mapping.values()] + free_mdl_chains.append(set(tmp)) + + # keep track of what ref chains got a mapping + newly_mapped_ref_chains = list() + + something_happened = True + while something_happened: + something_happened=False + + if self.steep_opt_rate is not None: + n_chains = len(newly_mapped_ref_chains) + if n_chains > 0 and n_chains % self.steep_opt_rate == 0: + mapping = self._SteepOpt(mapping, newly_mapped_ref_chains) + + if max_ext is not None and len(newly_mapped_ref_chains) >= max_ext: + break + + # nominator and denominator to determine current QS score + nominator, denominator = self._FromFlatMapping(mapping) + old_score = 0.0 + if denominator != 0.0: + old_score = nominator/denominator + + max_diff = 0.0 + max_mapping = None + for ref_ch in map_targets: + chem_group_idx = self.ref_ch_group_mapper[ref_ch] + for mdl_ch in free_mdl_chains[chem_group_idx]: + nominator_diff = 0.0 + denominator_diff = 0.0 + for neighbor in self.neighbors[ref_ch]: + if neighbor in mapping and mapping[neighbor] in \ + self.mdl_neighbors[mdl_ch]: + # it's a newly added interface if (ref_ch, mdl_ch) + # are added to mapping + int1 = (ref_ch, neighbor) + int2 = (mdl_ch, mapping[neighbor]) + a, b = self._MappedInterfaceScores(int1, int2) + nominator_diff += a + denominator_diff += b + # the respective interface penalties are subtracted + # from denominator + denominator_diff -= self._InterfacePenalty1(int1) + 
denominator_diff -= self._InterfacePenalty2(int2) + + if nominator_diff > 0: + # Only accept a new solution if its actually connected + # i.e. nominator_diff > 0. + new_nominator = nominator + nominator_diff + new_denominator = denominator + denominator_diff + new_score = 0.0 + if new_denominator != 0.0: + new_score = new_nominator/new_denominator + diff = new_score - old_score + if diff > max_diff: + max_diff = diff + max_mapping = (ref_ch, mdl_ch) + + if max_mapping is not None: + something_happened = True + # assign new found mapping + mapping[max_mapping[0]] = max_mapping[1] + + # add all neighboring chains to map targets as they are now + # reachable + for neighbor in self.neighbors[max_mapping[0]]: + if neighbor not in mapping: + map_targets.add(neighbor) + + # remove the ref chain from map targets + map_targets.remove(max_mapping[0]) + + # remove the mdl chain from free_mdl_chains - its taken... + chem_group_idx = self.ref_ch_group_mapper[max_mapping[0]] + free_mdl_chains[chem_group_idx].remove(max_mapping[1]) + + # keep track of what ref chains got a mapping + newly_mapped_ref_chains.append(max_mapping[0]) + + return mapping + + def _SteepOpt(self, mapping, chains_to_optimize=None): + + # just optimize ALL ref chains if nothing specified + if chains_to_optimize is None: + chains_to_optimize = mapping.keys() + + # make sure that we only have ref chains which are actually mapped + ref_chains = [x for x in chains_to_optimize if mapping[x] is not None] + + # group ref chains to be optimized into chem groups + tmp = dict() + for ch in ref_chains: + chem_group_idx = self.ref_ch_group_mapper[ch] + if chem_group_idx in tmp: + tmp[chem_group_idx].append(ch) + else: + tmp[chem_group_idx] = [ch] + chem_groups = list(tmp.values()) + + # try all possible mapping swaps. 
Swaps that improve the score are + # immediately accepted and we start all over again + nominator, denominator = self._fromFlatMapping(mapping) + current_score = 0.0 + if denominator != 0.0: + current_score = nominator / denominator + something_happened = True + while something_happened: + something_happened = False + for chem_group in chem_groups: + if something_happened: + break + for ch1, ch2 in itertools.combinations(chem_group, 2): + swapped_mapping = dict(mapping) + swapped_mapping[ch1] = mapping[ch2] + swapped_mapping[ch2] = mapping[ch1] + a, b = self._FromFlatMapping(swapped_mapping) + score = 0.0 + if b != 0.0: + score = a/b + if score > current_score: + something_happened = True + mapping = swapped_mapping + current_score = score + break + + return mapping + + +def _QSScoreNaive(trg, mdl, chem_groups, chem_mapping, ref_mdl_alns, contact_d, + n_max_naive): + best_mapping = None + best_score = -1.0 + # qs_scorer implements caching, score calculation is thus as fast as it gets + # you'll just hit a wall when the number of possible mappings becomes large + qs_scorer = qsscore.QSScorer(trg, chem_groups, mdl, ref_mdl_alns) + for mapping in _ChainMappings(chem_groups, chem_mapping, n_max_naive): + score = qs_scorer.GetQSScore(mapping, check=False) + if score > best_score: + best_mapping = mapping + best_score = score + return best_mapping + + +def _QSScoreGreedyFast(the_greed): + + something_happened = True + mapping = dict() + + while something_happened: + something_happened = False + # search for best scoring starting point, we're using lDDT here + n_best = 0 + best_seed = None + mapped_ref_chains = set(mapping.keys()) + mapped_mdl_chains = set(mapping.values()) + for ref_chains, mdl_chains in zip(the_greed.ref_chem_groups, + the_greed.mdl_chem_groups): + for ref_ch in ref_chains: + if ref_ch not in mapped_ref_chains: + for mdl_ch in mdl_chains: + if mdl_ch not in mapped_mdl_chains: + n = the_greed.SCCounts(ref_ch, mdl_ch) + if n > n_best: + n_best = n + 
best_seed = (ref_ch, mdl_ch) + if n_best == 0: + break # no proper seed found anymore... + # add seed to mapping and start the greed + mapping[best_seed[0]] = best_seed[1] + mapping = the_greed.ExtendMapping(mapping) + something_happened = True + + + # translate mapping format and return + final_mapping = list() + for ref_chains in the_greed.ref_chem_groups: + mapped_mdl_chains = list() + for ref_ch in ref_chains: + if ref_ch in mapping: + mapped_mdl_chains.append(mapping[ref_ch]) + else: + mapped_mdl_chains.append(None) + final_mapping.append(mapped_mdl_chains) + + return final_mapping + + +def _QSScoreGreedyFull(the_greed, n_mdl_chains): + """ Uses each reference chain as starting point for expansion + + However, not all mdl chain are mapped onto these reference chains, + that's controlled by *n_mdl_chains* + """ + + if n_mdl_chains is not None and n_mdl_chains < 1: + raise RuntimeError("n_mdl_chains must be None or >= 1") + + something_happened = True + mapping = dict() + + while something_happened: + something_happened = False + # Try all possible starting points and keep the one giving the best QS score + best_score = 0.0 + best_mapping = None + mapped_ref_chains = set(mapping.keys()) + mapped_mdl_chains = set(mapping.values()) + for ref_chains, mdl_chains in zip(the_greed.ref_chem_groups, + the_greed.mdl_chem_groups): + for ref_ch in ref_chains: + if ref_ch not in mapped_ref_chains: + seeds = list() + for mdl_ch in mdl_chains: + if mdl_ch not in mapped_mdl_chains: + seeds.append((ref_ch, mdl_ch)) + if n_mdl_chains is not None and n_mdl_chains < len(seeds): + counts = [the_greed.SCCounts(s[0], s[1]) for s in seeds] + tmp = [(a,b) for a,b in zip(counts, seeds)] + tmp.sort(reverse=True) + seeds = [item[1] for item in tmp[:n_mdl_chains]] + for seed in seeds: + tmp_mapping = dict(mapping) + tmp_mapping[seed[0]] = seed[1] + tmp_mapping = the_greed.ExtendMapping(tmp_mapping) + a, b = the_greed._FromFlatMapping(tmp_mapping) + tmp_score = 0.0 + if b != 0.0: + 
tmp_score = a/b + if tmp_score > best_score: + best_score = tmp_score + best_mapping = tmp_mapping + + if best_score == 0.0: + break # no proper mapping found anymore... + + something_happened = True + mapping = best_mapping + + # translate mapping format and return + final_mapping = list() + for ref_chains in the_greed.ref_chem_groups: + mapped_mdl_chains = list() + for ref_ch in ref_chains: + if ref_ch in mapping: + mapped_mdl_chains.append(mapping[ref_ch]) + else: + mapped_mdl_chains.append(None) + final_mapping.append(mapped_mdl_chains) + + return final_mapping + + +def _QSScoreGreedyBlock(the_greed, seed_size, blocks_per_chem_group): + """ try multiple seeds, i.e. try all ref/mdl chain combinations within the + respective chem groups and compute single chain lDDTs. The + *blocks_per_chem_group* best scoring ones are extend by *seed_size* chains + and the best scoring one with respect to QS score is exhaustively extended. + """ + + if seed_size is None or seed_size < 1: + raise RuntimeError(f"seed_size must be an int >= 1 (got {seed_size})") + + if blocks_per_chem_group is None or blocks_per_chem_group < 1: + raise RuntimeError(f"blocks_per_chem_group must be an int >= 1 " + f"(got {blocks_per_chem_group})") + + max_ext = seed_size - 1 # -1 => start seed already has size 1 + + ref_chem_groups = copy.deepcopy(the_greed.ref_chem_groups) + mdl_chem_groups = copy.deepcopy(the_greed.mdl_chem_groups) + + mapping = dict() + + something_happened = True + while something_happened: + something_happened = False + starting_blocks = list() + for ref_chains, mdl_chains in zip(ref_chem_groups, mdl_chem_groups): + if len(mdl_chains) == 0: + continue # nothing to map + + # Identify starting seeds for *blocks_per_chem_group* blocks + # thats done with lDDT + seeds = list() + for ref_ch in ref_chains: + seeds += [(ref_ch, mdl_ch) for mdl_ch in mdl_chains] + counts = [the_greed.SCCounts(s[0], s[1]) for s in seeds] + tmp = [(a,b) for a,b in zip(counts, seeds)] + 
tmp.sort(reverse=True) + seeds = [item[1] for item in tmp[:blocks_per_chem_group]] + + # extend starting seeds to *seed_size* and retain best scoring block + # for further extension + best_score = 0.0 + best_mapping = None + for s in seeds: + seed = dict(mapping) + seed.update({s[0]: s[1]}) + seed = the_greed.ExtendMapping(seed, max_ext = max_ext) + a, b = the_greed._FromFlatMapping(seed) + seed_score = 0.0 + if b != 0.0: + seed_score = a/b + if seed_score > best_score: + best_score = seed_score + best_mapping = seed + if best_mapping != None: + starting_blocks.append(best_mapping) + + # fully expand initial starting blocks + best_score = 0.0 + best_mapping = None + for seed in starting_blocks: + seed = the_greed.ExtendMapping(seed) + a, b = the_greed._FromFlatMapping(seed) + seed_score = 0.0 + if b != 0.0: + seed_score = a/b + if seed_score > best_score: + best_score = seed_score + best_mapping = seed + + if best_score == 0.0: + break # no proper mapping found anymore + + something_happened = True + mapping.update(best_mapping) + for ref_ch, mdl_ch in best_mapping.items(): + for group_idx in range(len(ref_chem_groups)): + if ref_ch in ref_chem_groups[group_idx]: + ref_chem_groups[group_idx].remove(ref_ch) + if mdl_ch in mdl_chem_groups[group_idx]: + mdl_chem_groups[group_idx].remove(mdl_ch) + + # translate mapping format and return + final_mapping = list() + for ref_chains in the_greed.ref_chem_groups: + mapped_mdl_chains = list() + for ref_ch in ref_chains: + if ref_ch in mapping: + mapped_mdl_chains.append(mapping[ref_ch]) + else: + mapped_mdl_chains.append(None) + final_mapping.append(mapped_mdl_chains) + + return final_mapping + + +def _SingleRigidGDTTS(initial_transforms, initial_mappings, chem_groups, + chem_mapping, trg_group_pos, mdl_group_pos, + single_chain_gdtts_thresh, iterative_superposition, + first_complete, n_trg_chains, n_mdl_chains): """ Takes initial transforms and sequentially adds chain pairs with best scoring gdtts that fulfill 
single_chain_gdtts_thresh. The mapping from the transform that leads to best overall gdtts score is returned. @@ -2221,10 +2699,10 @@ def _SingleRigid(initial_transforms, initial_mappings, chem_groups, return best_mapping -def _IterativeRigid(initial_transforms, initial_mappings, chem_groups, - chem_mapping, trg_group_pos, mdl_group_pos, - single_chain_gdtts_thresh, iterative_superposition, - first_complete, n_trg_chains, n_mdl_chains): +def _IterativeRigidGDTTS(initial_transforms, initial_mappings, chem_groups, + chem_mapping, trg_group_pos, mdl_group_pos, + single_chain_gdtts_thresh, iterative_superposition, + first_complete, n_trg_chains, n_mdl_chains): """ Takes initial transforms and sequentially adds chain pairs with best scoring gdtts that fulfill single_chain_gdtts_thresh. With each added chain pair, the transform gets updated. Thus the naming iterative. @@ -2403,17 +2881,6 @@ def _IterativeRigidRMSD(initial_transforms, initial_mappings, chem_groups, return best_mapping -def _QSScoreNaive(trg, mdl, chem_groups, chem_mapping, ref_mdl_alns, contact_d, - n_max_naive): - best_mapping = None - best_score = -1.0 - qs_scorer = qsscore.QSScorer(trg, chem_groups, mdl, ref_mdl_alns) - for mapping in _ChainMappings(chem_groups, chem_mapping, n_max_naive): - score = qs_scorer.GetQSScore(mapping, check=False) - if score > best_score: - best_mapping = mapping - best_score = score - return best_mapping def _GetRefPos(trg, mdl, trg_msas, mdl_alns, max_pos = None): """ Extracts reference positions which are present in trg and mdl diff --git a/modules/mol/alg/pymod/qsscore.py b/modules/mol/alg/pymod/qsscore.py index 54c4a3d5a..391499ae5 100644 --- a/modules/mol/alg/pymod/qsscore.py +++ b/modules/mol/alg/pymod/qsscore.py @@ -326,6 +326,16 @@ class QSScorer: for a, b in zip(self.chem_groups, mapping): flat_mapping.update({x: y for x, y in zip(a, b) if y is not None}) + # refers to equation 6 in Bertoni et al., 2017 + nominator, denominator = self._FromFlatMapping(flat_mapping) 
+ + if denominator > 0.0: + return nominator / denominator + else: + return 0.0 + + def _FromFlatMapping(self, flat_mapping): + # refers to equation 6 in Bertoni et al., 2017 nominator = 0.0 denominator= 0.0 @@ -349,15 +359,18 @@ class QSScorer: if int2 not in processed_qsent2_interfaces: denominator += self._InterfacePenalty2(int2) - if denominator > 0.0: - return nominator / denominator - else: - return 0.0 + return (nominator, denominator) def _MappedInterfaceScores(self, int1, int2): - if (int1, int2) not in self._mapped_cache: - self._mapped_cache[(int1, int2)] = self._InterfaceScores(int1, int2) - return self._mapped_cache[(int1, int2)] + key_one = (int1, int2) + if key_one in self._mapped_cache: + return self._mapped_cache[key_one] + key_two = ((int1[1], int1[0]), (int2[1], int2[0])) + if key_two in self._mapped_cache: + return self._mapped_cache[key_two] + nominator, denominator = self._InterfaceScores(int1, int2) + self._mapped_cache[key_one] = (nominator, denominator) + return (nominator, denominator) def _InterfaceScores(self, int1, int2): @@ -366,12 +379,11 @@ class QSScorer: # given two chain names a and b: if a < b, shape of pairwise distances is # (len(a), len(b)). 
However, if b > a, its (len(b), len(a)) => transpose + if int1[0] > int1[1]: + d1 = d1.transpose() if int2[0] > int2[1]: d2 = d2.transpose() - # should be given by design => no need to transpose d1 - assert(int1[0] < int1[1]) - # indices of the first chain in the two interfaces mapped_indices_1_1, mapped_indices_1_2 = \ self._IndexMapping(int1[0], int2[0]) diff --git a/modules/mol/alg/tests/test_chain_mapping.py b/modules/mol/alg/tests/test_chain_mapping.py index 6e41fad79..d9a9362eb 100644 --- a/modules/mol/alg/tests/test_chain_mapping.py +++ b/modules/mol/alg/tests/test_chain_mapping.py @@ -245,6 +245,7 @@ class TestChainMapper(unittest.TestCase): # This is not supposed to be in depth algorithm testing, we just check # whether the various algorithms return sensible chain mappings + # lDDT based chain mappings naive_lddt_res = mapper.GetlDDTMapping(mdl, strategy="naive") self.assertEqual(naive_lddt_res.mapping, [['X', 'Y'],[None],['Z']]) @@ -258,19 +259,31 @@ class TestChainMapper(unittest.TestCase): greedy_lddt_res = mapper.GetlDDTMapping(mdl, strategy="greedy_block") self.assertEqual(greedy_lddt_res.mapping, [['X', 'Y'],[None],['Z']]) - greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_single") + + # QS score based chain mappings + naive_qsscore_res = mapper.GetQSScoreMapping(mdl, strategy="naive") + self.assertEqual(naive_qsscore_res.mapping, [['X', 'Y'],[None],['Z']]) + + greedy_qsscore_res = mapper.GetQSScoreMapping(mdl, strategy="greedy_fast") + self.assertEqual(naive_qsscore_res.mapping, [['X', 'Y'],[None],['Z']]) + + greedy_qsscore_res = mapper.GetQSScoreMapping(mdl, strategy="greedy_full") + self.assertEqual(naive_qsscore_res.mapping, [['X', 'Y'],[None],['Z']]) + + greedy_qsscore_res = mapper.GetQSScoreMapping(mdl, strategy="greedy_block") + self.assertEqual(naive_qsscore_res.mapping, [['X', 'Y'],[None],['Z']]) + + + # rigid chain mappings + greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_single_gdtts") 
self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']]) - greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_iterative") + greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_iterative_gdtts") self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']]) greedy_rigid_res = mapper.GetRigidMapping(mdl, strategy="greedy_iterative_rmsd") self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']]) - naive_qsscore_res = mapper.GetQSScoreMapping(mdl, strategy="naive") - self.assertEqual(naive_qsscore_res.mapping, [['X', 'Y'],[None],['Z']]) - - if __name__ == "__main__": from ost import testutils if testutils.SetDefaultCompoundLib(): -- GitLab