diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py index aa314e7df7e8bbbaa2becf3e4bd438aa828e2bbc..2318f6a2a979714e00742ab644b60fd3691c05e9 100644 --- a/modules/mol/alg/pymod/chain_mapping.py +++ b/modules/mol/alg/pymod/chain_mapping.py @@ -1014,7 +1014,7 @@ class ChainMapper: are estimated using all possible combinations of target and model chains within the same chem groups and build the basis for further extension. - There are three extension strategies: + There are four extension strategies: * **greedy_single_gdtts**: Iteratively add the model/target chain pair that adds the most conserved contacts based on the GDT-TS metric @@ -1022,12 +1022,19 @@ class ChainMapper: with highest GDT-TS score is returned. However, that mapping is not guaranteed to be complete (see *single_chain_gdtts_thresh*). - * **greedy_iterative_gdtts**: Same as single except that the - transformation gets updated with each added chain pair. + * **greedy_iterative_gdtts**: Same as greedy_single_gdtts except that + the transformation gets updated with each added chain pair. - * **greedy_iterative_rmsd**: Same as iterative, i.e. the transformation - gets updated with each added chain pair. However, - **single_chain_gdtts_thresh** is only applied to derive the initial + * **greedy_single_rmsd**: Conceptually similar to greedy_single_gdtts + but the added chain pairs are the ones with lowest RMSD. + The mapping with lowest overall RMSD gets returned. + *single_chain_gdtts_thresh* is only applied to derive the initial + transformations. After that, the minimal RMSD chain pair gets + iteratively added without applying any threshold. + + * **greedy_iterative_rmsd**: Same as greedy_single_rmsd except that + the transformation gets updated with each added chain pair. + *single_chain_gdtts_thresh* is only applied to derive the initial transformations. After that, the minimal RMSD chain pair gets iteratively added without applying any threshold. 
def _SingleRigidRMSD(initial_transforms, initial_mappings, chem_groups,
                     chem_mapping, trg_group_pos, mdl_group_pos,
                     iterative_superposition):
    """
    Takes initial transforms and sequentially adds chain pairs with lowest RMSD.
    The mapping from the transform that leads to lowest overall RMSD is
    returned.

    For each transform, a copy of every model position list is transformed.
    Within each chem group, all target/model chain pairs are then ranked by
    their summed squared distance (SSD) and assigned greedily: a pair is
    accepted if neither its target chain nor its model chain has been mapped
    yet. The SSDs of the accepted pairs are summed over all chem groups and
    the transform with the lowest total wins.

    :param initial_transforms: Transformations to evaluate. Each one is passed
                               to ``ApplyTransform`` on copies of the model
                               positions.
    :param initial_mappings: Not used by this function.
                             NOTE(review): presumably kept so the signature
                             matches the other rigid strategies - confirm.
    :param chem_groups: Per chem group, the target chains. Its entries become
                        the keys of the returned mapping.
    :param chem_mapping: Per chem group, the model chains. Its entries become
                         the values of the returned mapping.
    :param trg_group_pos: Per chem group, one position list per target chain,
                          in the same order as *chem_groups*. Each must
                          provide ``GetSummedSquaredDistances``.
    :param mdl_group_pos: Per chem group, one position list per model chain,
                          in the same order as *chem_mapping*. Each must be
                          accepted by the ``geom.Vec3List`` constructor.
    :param iterative_superposition: Not used by this function.
    :returns: :class:`dict` with target chains as keys and the mapped model
              chains as values. Empty if *initial_transforms* is empty.
    """
    best_mapping = dict()
    best_ssd = float("inf") # we're actually going for summed squared distances
                            # Since all positions have same lengths and we do a
                            # full mapping, lowest SSD has a guarantee of also
                            # being lowest RMSD
    for transform in initial_transforms:
        mapping = dict()
        mapped_mdl_chains = set()
        # Running total over ALL chem groups for this transform. It must not
        # share a name with the per-pair value computed below, otherwise the
        # total gets clobbered in every chem group.
        total_ssd = 0.0
        for trg_chains, mdl_chains, trg_pos, mdl_pos in zip(chem_groups,
                                                            chem_mapping,
                                                            trg_group_pos,
                                                            mdl_group_pos):
            if len(trg_pos) == 0 or len(mdl_pos) == 0:
                continue # cannot compute valid rmsd

            # transform copies, the input positions stay untouched
            t_mdl_pos = list()
            for m_pos in mdl_pos:
                t_m_pos = geom.Vec3List(m_pos)
                t_m_pos.ApplyTransform(transform)
                t_mdl_pos.append(t_m_pos)

            # all-vs-all SSDs within this chem group
            pair_ssds = list()
            for t_pos, t in zip(trg_pos, trg_chains):
                for t_m_pos, m in zip(t_mdl_pos, mdl_chains):
                    pair_ssd = t_pos.GetSummedSquaredDistances(t_m_pos)
                    pair_ssds.append((pair_ssd, (t, m)))

            # greedy assignment, lowest SSD first; ties are resolved by the
            # (target, model) tuple as it is part of the sort key
            pair_ssds.sort()
            for pair_ssd, (t, m) in pair_ssds:
                if t not in mapping and m not in mapped_mdl_chains:
                    mapping[t] = m
                    mapped_mdl_chains.add(m)
                    total_ssd += pair_ssd

        if total_ssd < best_ssd:
            best_ssd = total_ssd
            best_mapping = mapping

    return best_mapping
self.assertEqual(greedy_rigid_res.mapping, [['X', 'Y'],[None],['Z']])