diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py index 3c2cb99201e2fff723cb90063b1b226f07ff2629..6b66f8aea305aba6d9a42077d54cc4f9d33575ca 100644 --- a/modules/mol/alg/pymod/chain_mapping.py +++ b/modules/mol/alg/pymod/chain_mapping.py @@ -426,18 +426,18 @@ class ChainMapper: Several strategies exist to identify the start seed(s): - * fast: perform all vs. all single chain lDDTs within the respective + * **fast**: perform all vs. all single chain lDDTs within the respective ref/mdl chem groups. The mapping with highest number of conserved contacts is selected as seed for extension - * full: try multiple seeds, i.e. try all ref/mdl chain combinations + * **full**: try multiple seeds, i.e. try all ref/mdl chain combinations within the respective chem groups and retain the mapping leading to - the best lDDT. Optionally, you can reduce the number of mdl chains - per ref chain to the *full_n_mdl_chains* best scoring ones. + the best lDDT. Optionally, you can reduce the number of mdl chains per + ref chain to the *full_n_mdl_chains* best scoring ones. - * block: try multiple seeds, i.e. try all ref/mdl chain combinations - within the respective chem groups and compute single chain lDDTs. - The *block_blocks_per_chem_group* best scoring ones are extend by + * **block**: try multiple seeds, i.e. try all ref/mdl chain combinations + within the respective chem groups and compute single chain lDDTs. The + *block_blocks_per_chem_group* best scoring ones are extend by *block_seed_size* chains and the best scoring one is exhaustively extended. @@ -447,7 +447,9 @@ class ChainMapper: :type inclusion_radius: :class:`float` :param thresholds: Thresholds for lDDT :type thresholds: :class:`list` of :class:`float` - :param seed_strategy: Strategy to pick starting seeds for expansion + :param seed_strategy: Strategy to pick starting seeds for expansion. + Must be in ["fast", "full", "block"] + :type seed_strategy: :class:`str` :param steep_opt_rate: If set, every *steep_opt_rate* mappings, a simple optimization is executed with the goal of @@ -511,9 +513,9 @@ class ChainMapper: return _BlockGreedy(the_greed, block_seed_size, block_blocks_per_chem_group) - def GetRigidMapping(self, model, single_chain_gdtts_thresh=0.4, - subsampling=None, first_complete=False, - iterative_superposition=False): + def GetGreedyRigidMapping(self, model, strategy = "single", + single_chain_gdtts_thresh=0.4, subsampling=None, + first_complete=False, iterative_superposition=False): """Identify chain mapping based on rigid superposition Superposition and scoring is based on CA/C3' positions which are present @@ -522,15 +524,25 @@ class ChainMapper: Transformations to superpose *model* onto :attr:`ChainMapper.target` are estimated using all possible combinations of target and model chains - within the same chem groups. For each transformation, the mapping is - extended by iteratively adding the model/target chain pair that adds - the most conserved contacts based on the GDT-TS metric (Number of - CA/C3' atoms within [8, 4, 2, 1] Angstrom). The mapping with highest - GDT-TS score is returned. However, that mapping is not guaranteed to be - complete (see *single_chain_gdtts_thresh*). + within the same chem groups and build the basis for further extension. + + There are two extension strategies: + + * **single**: Iteratively add the model/target chain pair that adds the + most conserved contacts based on the GDT-TS metric (Number of CA/C3' + atoms within [8, 4, 2, 1] Angstrom). The mapping with highest GDT-TS + score is returned. However, that mapping is not guaranteed to be + complete (see *single_chain_gdtts_thresh*). + + * **iterative**: Same as single except that the transformation gets + updated with each added chain pair. :param model: Model to map :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle` + :param strategy: Strategy to extend mappings from initial transforms, + see description above. Must be in ["single", + "iterative"] + :type strategy: :class:`str` :param single_chain_gdtts_thresh: Minimal GDT-TS score for model/target chain pair to be added to mapping. Mapping extension for a given @@ -556,6 +568,10 @@ class ChainMapper: chains. Target chains without mapped model chains are set to None. """ + + if strategy not in ["single", "iterative"]: + raise RuntimeError("strategy must be \"single\" or \"iterative\"") + chem_mapping, chem_group_alns, mdl = self.GetChemMapping(model) trg_group_pos, mdl_group_pos = _GetRefPos(self.target, mdl, @@ -565,7 +581,8 @@ class ChainMapper: # get transforms of any mdl chain onto any trg chain in same chem group # that fulfills gdtts threshold - transforms = list() + initial_transforms = list() + initial_mappings = list() for trg_pos, trg_chains, mdl_pos, mdl_chains in zip(trg_group_pos, self.chem_groups, mdl_group_pos, @@ -573,87 +590,45 @@ class ChainMapper: for t_pos, t in zip(trg_pos, trg_chains): for m_pos, m in zip(mdl_pos, mdl_chains): if len(t_pos) >= 3 and len(m_pos) >= 3: - if iterative_superposition: - try: - res = mol.alg.IterativeSuperposeSVD(m_pos,t_pos) - except: - # potentially fails if an iteration tries to - # superpose with < 3 positions => skip - continue - else: - res = mol.alg.SuperposeSVD(m_pos,t_pos) + transform = _GetTransform(m_pos, t_pos, + iterative_superposition) t_m_pos = geom.Vec3List(m_pos) - t_m_pos.ApplyTransform(res.transformation) - gdt = t_pos.GetGDTTS(t_m_pos) - if gdt >= single_chain_gdtts_thresh: - transforms.append(res.transformation) - - best_mapping = dict() - best_gdt = 0 - for transform in transforms: - mapping = dict() - mapped_mdl_chains = set() - gdt_cache = dict() # cache for non-normalized gdt scores - - for trg_pos, trg_chains, mdl_pos, mdl_chains in zip(trg_group_pos, - self.chem_groups, - mdl_group_pos, - chem_mapping): - - if len(trg_pos) == 0 or len(trg_pos[0]) == 0: - continue # cannot compute valid gdt - - n_gdt_contacts = 4 * len(trg_pos[0]) - gdt_scores = list() - - t_mdl_pos = list() - for m_pos in mdl_pos: - t_m_pos = geom.Vec3List(m_pos) - t_m_pos.ApplyTransform(transform) - t_mdl_pos.append(t_m_pos) - - for t_pos, t in zip(trg_pos, trg_chains): - for t_m_pos, m in zip(t_mdl_pos, mdl_chains): + t_m_pos.ApplyTransform(transform) gdt = t_pos.GetGDTTS(t_m_pos) if gdt >= single_chain_gdtts_thresh: - gdt_scores.append((gdt, (t,m))) - gdt_cache[(t,m)] = n_gdt_contacts * gdt - - gdt_scores.sort(reverse=True) - sorted_pairs = [item[1] for item in gdt_scores] - for p in sorted_pairs: - if p[0] not in mapping and p[1] not in mapped_mdl_chains: - mapping[p[0]] = p[1] - mapped_mdl_chains.add(p[1]) - - # compute overall gdt for this transform (non-normalized gdt!!!) - gdt = 0 - for t,m in mapping.items(): - gdt += gdt_cache[(t,m)] - - if gdt > best_gdt: - best_gdt = gdt - best_mapping = mapping - - if first_complete: - mdl_complete = len(mapping) == len(mdl.chains) - trg_complete = len(mapping) == len(self.target.chains) - if mdl_complete or trg_complete: - break + initial_transforms.append(transform) + initial_mappings.append((t,m)) + + if strategy == "single": + mapping = _SingleRigid(initial_transforms, initial_mappings, + self.chem_groups, chem_mapping, + trg_group_pos, mdl_group_pos, + single_chain_gdtts_thresh, + iterative_superposition, first_complete, + len(self.target.chains), len(mdl.chains)) + + elif strategy == "iterative": + mapping = _IterativeRigid(initial_transforms, initial_mappings, + self.chem_groups, chem_mapping, + trg_group_pos, mdl_group_pos, + single_chain_gdtts_thresh, + iterative_superposition, first_complete, + len(self.target.chains), len(mdl.chains)) # translate mapping format and return final_mapping = list() for ref_chains in self.chem_groups: mapped_mdl_chains = list() for ref_ch in ref_chains: - if ref_ch in best_mapping: - mapped_mdl_chains.append(best_mapping[ref_ch]) + if ref_ch in mapping: + mapped_mdl_chains.append(mapping[ref_ch]) else: mapped_mdl_chains.append(None) final_mapping.append(mapped_mdl_chains) return final_mapping + def GetNMappings(self, model): """ Returns number of possible mappings @@ -1629,6 +1604,165 @@ def _BlockGreedy(the_greed, seed_size, blocks_per_chem_group): return final_mapping + +def _SingleRigid(initial_transforms, initial_mappings, chem_groups, + chem_mapping, trg_group_pos, mdl_group_pos, + single_chain_gdtts_thresh, iterative_superposition, + first_complete, n_trg_chains, n_mdl_chains): + """ Takes initial transforms and sequentially adds chain pairs with + best scoring gdtts that fulfill single_chain_gdtts_thresh. The mapping + from the transform that leads to best overall gdtts score is returned. + Optionally, the first complete mapping, i.e. a mapping that covers all + target chains or all model chains, is returned. + """ + best_mapping = dict() + best_gdt = 0 + for transform in initial_transforms: + mapping = dict() + mapped_mdl_chains = set() + gdt_cache = dict() # cache for non-normalized gdt scores + + for trg_chains, mdl_chains, trg_pos, mdl_pos, in zip(chem_groups, + chem_mapping, + trg_group_pos, + mdl_group_pos): + + if len(trg_pos) == 0 or len(mdl_pos) == 0: + continue # cannot compute valid gdt + + n_gdt_contacts = 4 * len(trg_pos[0]) + gdt_scores = list() + + t_mdl_pos = list() + for m_pos in mdl_pos: + t_m_pos = geom.Vec3List(m_pos) + t_m_pos.ApplyTransform(transform) + t_mdl_pos.append(t_m_pos) + + for t_pos, t in zip(trg_pos, trg_chains): + for t_m_pos, m in zip(t_mdl_pos, mdl_chains): + gdt = t_pos.GetGDTTS(t_m_pos) + if gdt >= single_chain_gdtts_thresh: + gdt_scores.append((gdt, (t,m))) + gdt_cache[(t,m)] = n_gdt_contacts * gdt + + gdt_scores.sort(reverse=True) + sorted_pairs = [item[1] for item in gdt_scores] + for p in sorted_pairs: + if p[0] not in mapping and p[1] not in mapped_mdl_chains: + mapping[p[0]] = p[1] + mapped_mdl_chains.add(p[1]) + + # compute overall gdt for this transform (non-normalized gdt!!!) + gdt = 0 + for t,m in mapping.items(): + gdt += gdt_cache[(t,m)] + + if gdt > best_gdt: + best_gdt = gdt + best_mapping = mapping + if first_complete: + n = len(mapping) + if n == n_mdl_chains or n == n_trg_chains: + break + + return best_mapping + + +def _IterativeRigid(initial_transforms, initial_mappings, chem_groups, + chem_mapping, trg_group_pos, mdl_group_pos, + single_chain_gdtts_thresh, iterative_superposition, + first_complete, n_trg_chains, n_mdl_chains): + """ Takes initial transforms and sequentially adds chain pairs with + best scoring gdtts that fulfill single_chain_gdtts_thresh. With each + added chain pair, the transform gets updated. Thus the naming iterative. + The mapping from the initial transform that leads to best overall gdtts + score is returned. Optionally, the first complete mapping, i.e. a mapping + that covers all target chains or all model chains, is returned. + """ + + # to directly retrieve positions using chain names + trg_pos_dict = dict() + for trg_pos, trg_chains in zip(trg_group_pos, chem_groups): + for t_pos, t in zip(trg_pos, trg_chains): + trg_pos_dict[t] = t_pos + mdl_pos_dict = dict() + for mdl_pos, mdl_chains in zip(mdl_group_pos, chem_mapping): + for m_pos, m in zip(mdl_pos, mdl_chains): + mdl_pos_dict[m] = m_pos + + best_mapping = dict() + best_gdt = 0 + for initial_transform, initial_mapping in zip(initial_transforms, + initial_mappings): + mapping = {initial_mapping[0]: initial_mapping[1]} + transform = geom.Mat4(initial_transform) + mapped_trg_pos = geom.Vec3List(trg_pos_dict[initial_mapping[0]]) + mapped_mdl_pos = geom.Vec3List(mdl_pos_dict[initial_mapping[1]]) + + # the following variables contain the chains which are + # available for mapping + trg_chain_groups = [set(group) for group in chem_groups] + mdl_chain_groups = [set(group) for group in chem_mapping] + + # search and kick out inital mapping + for group in trg_chain_groups: + if initial_mapping[0] in group: + group.remove(initial_mapping[0]) + break + for group in mdl_chain_groups: + if initial_mapping[1] in group: + group.remove(initial_mapping[1]) + break + + something_happened = True + while something_happened: + # search for best mapping given current transform + something_happened=False + best_sc_mapping = None + best_sc_group_idx = None + best_sc_gdt = 0.0 + group_idx = 0 + for trg_chains, mdl_chains in zip(trg_chain_groups, mdl_chain_groups): + for t in trg_chains: + t_pos = trg_pos_dict[t] + for m in mdl_chains: + m_pos = mdl_pos_dict[m] + t_m_pos = geom.Vec3List(m_pos) + t_m_pos.ApplyTransform(transform) + gdt = t_pos.GetGDTTS(t_m_pos) + if gdt > single_chain_gdtts_thresh and gdt > best_sc_gdt: + best_sc_gdt = gdt + best_sc_mapping = (t,m) + best_sc_group_idx = group_idx + group_idx += 1 + + if best_sc_mapping is not None: + something_happened = True + mapping[best_sc_mapping[0]] = best_sc_mapping[1] + mapped_trg_pos.extend(trg_pos_dict[best_sc_mapping[0]]) + mapped_mdl_pos.extend(mdl_pos_dict[best_sc_mapping[1]]) + trg_chain_groups[best_sc_group_idx].remove(best_sc_mapping[0]) + mdl_chain_groups[best_sc_group_idx].remove(best_sc_mapping[1]) + + transform = _GetTransform(mapped_mdl_pos, mapped_trg_pos, + iterative_superposition) + + # compute overall gdt for current transform (non-normalized gdt!!!) + mapped_mdl_pos.ApplyTransform(transform) + gdt = mapped_trg_pos.GetGDTTS(mapped_mdl_pos, norm=False) + + if gdt > best_gdt: + best_gdt = gdt + best_mapping = mapping + if first_complete: + n = len(mapping) + if n == n_mdl_chains or n == n_trg_chains: + break + + return best_mapping + + def _GetRefPos(trg, mdl, trg_msas, mdl_alns, max_pos = None): """ Extracts reference positions which are present in trg and mdl """ @@ -1890,5 +2024,29 @@ def _ChainMappings(ref_chains, mdl_chains, n_max=None): return _ConcatIterators(iterators) + +def _GetTransform(pos_one, pos_two, iterative): + """ Computes minimal RMSD superposition for pos_one onto pos_two + + :param pos_one: Positions that should be superposed onto *pos_two* + :type pos_one: :class:`geom.Vec3List` + :param pos_two: Reference positions + :type pos_two: :class:`geom.Vec3List` + :iterative: Whether iterative superposition should be used. Iterative + potentially raises, uses standard superposition as fallback. + :type iterative: :class:`bool` + :returns: Transformation matrix to superpose *pos_one* onto *pos_two* + :rtype: :class:`geom.Mat4` + """ + res = None + if iterative: + try: + res = mol.alg.IterativeSuperposeSVD(pos_one, pos_two) + except: + pass # triggers fallback below + if res is None: + res = mol.alg.SuperposeSVD(pos_one, pos_two) + return res.transformation + # specify public interface __all__ = ('ChainMapper',) diff --git a/modules/mol/alg/tests/test_chain_mapping.py b/modules/mol/alg/tests/test_chain_mapping.py index 2f73f441e32732815c19dd0a82e19577351802ed..65873b2a5deb8b89d2bb246eee4d37c3532b58ff 100644 --- a/modules/mol/alg/tests/test_chain_mapping.py +++ b/modules/mol/alg/tests/test_chain_mapping.py @@ -261,6 +261,12 @@ class TestChainMapper(unittest.TestCase): greedy_lddt_mapping = mapper.GetGreedylDDTMapping(mdl, seed_strategy="block") self.assertEqual(greedy_lddt_mapping, [['X', 'Y'],[None],['Z']]) + greedy_rigid_mapping = mapper.GetGreedyRigidMapping(mdl, strategy="single") + self.assertEqual(greedy_rigid_mapping, [['X', 'Y'],[None],['Z']]) + + greedy_rigid_mapping = mapper.GetGreedyRigidMapping(mdl, strategy="iterative") + self.assertEqual(greedy_rigid_mapping, [['X', 'Y'],[None],['Z']]) + if __name__ == "__main__": from ost import testutils