diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py index 466679ec2a7793d67976b456303e717d2ec8457c..111696df92b0f1ca64e819879359e5306a247049 100644 --- a/modules/mol/alg/pymod/chain_mapping.py +++ b/modules/mol/alg/pymod/chain_mapping.py @@ -1121,7 +1121,8 @@ class ChainMapper: """ strategies = ["greedy_single_gdtts", "greedy_iterative_gdtts", - "greedy_single_rmsd", "greedy_iterative_rmsd"] + "greedy_single_rmsd", "greedy_iterative_rmsd", + "naive_rmsd"] if strategy not in strategies: raise RuntimeError(f"strategy must be {strategies}") @@ -1153,57 +1154,71 @@ class ChainMapper: chem_group_alns, max_pos = subsampling) - # get transforms of any mdl chain onto any trg chain in same chem group - # that fulfills gdtts threshold - initial_transforms = list() - initial_mappings = list() - for trg_pos, trg_chains, mdl_pos, mdl_chains in zip(trg_group_pos, - self.chem_groups, - mdl_group_pos, - chem_mapping): - for t_pos, t in zip(trg_pos, trg_chains): - for m_pos, m in zip(mdl_pos, mdl_chains): - if len(t_pos) >= 3 and len(m_pos) >= 3: - transform = _GetTransform(m_pos, t_pos, - iterative_superposition) - t_m_pos = geom.Vec3List(m_pos) - t_m_pos.ApplyTransform(transform) - gdt = t_pos.GetGDTTS(t_m_pos) - if gdt >= single_chain_gdtts_thresh: - initial_transforms.append(transform) - initial_mappings.append((t,m)) - - if strategy == "greedy_single_gdtts": - mapping = _SingleRigidGDTTS(initial_transforms, initial_mappings, - self.chem_groups, chem_mapping, - trg_group_pos, mdl_group_pos, - single_chain_gdtts_thresh, - iterative_superposition, first_complete, - len(self.target.chains), - len(mdl.chains)) - - elif strategy == "greedy_iterative_gdtts": - mapping = _IterativeRigidGDTTS(initial_transforms, initial_mappings, - self.chem_groups, chem_mapping, - trg_group_pos, mdl_group_pos, - single_chain_gdtts_thresh, - iterative_superposition, - first_complete, - len(self.target.chains), - len(mdl.chains)) - - elif strategy == "greedy_single_rmsd": - mapping = _SingleRigidRMSD(initial_transforms, initial_mappings, - self.chem_groups, chem_mapping, - trg_group_pos, mdl_group_pos, - iterative_superposition) - - - elif strategy == "greedy_iterative_rmsd": - mapping = _IterativeRigidRMSD(initial_transforms, initial_mappings, + mapping = None + + if strategy.startswith("greedy"): + # get transforms of any mdl chain onto any trg chain in same chem + # group that fulfills gdtts threshold + initial_transforms = list() + initial_mappings = list() + for trg_pos, trg_chains, mdl_pos, mdl_chains in zip(trg_group_pos, + self.chem_groups, + mdl_group_pos, + chem_mapping): + for t_pos, t in zip(trg_pos, trg_chains): + for m_pos, m in zip(mdl_pos, mdl_chains): + if len(t_pos) >= 3 and len(m_pos) >= 3: + transform = _GetTransform(m_pos, t_pos, + iterative_superposition) + t_m_pos = geom.Vec3List(m_pos) + t_m_pos.ApplyTransform(transform) + gdt = t_pos.GetGDTTS(t_m_pos) + if gdt >= single_chain_gdtts_thresh: + initial_transforms.append(transform) + initial_mappings.append((t,m)) + + if strategy == "greedy_single_gdtts": + mapping = _SingleRigidGDTTS(initial_transforms, + initial_mappings, self.chem_groups, chem_mapping, trg_group_pos, mdl_group_pos, - iterative_superposition) + single_chain_gdtts_thresh, + iterative_superposition, + first_complete, + len(self.target.chains), + len(mdl.chains)) + + elif strategy == "greedy_iterative_gdtts": + mapping = _IterativeRigidGDTTS(initial_transforms, + initial_mappings, + self.chem_groups, chem_mapping, + trg_group_pos, mdl_group_pos, + single_chain_gdtts_thresh, + iterative_superposition, + first_complete, + len(self.target.chains), + len(mdl.chains)) + + elif strategy == "greedy_single_rmsd": + mapping = _SingleRigidRMSD(initial_transforms, + initial_mappings, + self.chem_groups, + chem_mapping, + trg_group_pos, + mdl_group_pos, + iterative_superposition) + + + elif strategy == "greedy_iterative_rmsd": + mapping = _IterativeRigidRMSD(initial_transforms, + initial_mappings, + self.chem_groups, chem_mapping, + trg_group_pos, mdl_group_pos, + iterative_superposition) + elif strategy == "naive_rmsd": + mapping = _NaiveRMSD(self.chem_groups, chem_mapping, + trg_group_pos, mdl_group_pos, + iterative_superposition, self.n_max_naive) # translate mapping format and return final_mapping = list() @@ -3313,6 +3328,51 @@ def _IterativeRigidRMSD(initial_transforms, initial_mappings, chem_groups, return best_mapping +def _NaiveRMSD(chem_groups, chem_mapping, trg_group_pos, mdl_group_pos, + iterative_superposition, n_max_naive): + + # to directly retrieve positions using chain names + trg_pos_dict = dict() + for trg_pos, trg_chains in zip(trg_group_pos, chem_groups): + for t_pos, t in zip(trg_pos, trg_chains): + trg_pos_dict[t] = t_pos + mdl_pos_dict = dict() + for mdl_pos, mdl_chains in zip(mdl_group_pos, chem_mapping): + for m_pos, m in zip(mdl_pos, mdl_chains): + mdl_pos_dict[m] = m_pos + + best_mapping = dict() + best_rmsd = float("inf") + + for mapping in _ChainMappings(chem_groups, chem_mapping, n_max_naive): + trg_pos = geom.Vec3List() + mdl_pos = geom.Vec3List() + for trg_group, mdl_group in zip(chem_groups, mapping): + for trg_ch, mdl_ch in zip(trg_group, mdl_group): + if trg_ch is not None and mdl_ch is not None: + trg_pos.extend(trg_pos_dict[trg_ch]) + mdl_pos.extend(mdl_pos_dict[mdl_ch]) + superpose_res = None + if iterative_superposition: + try: + superpose_res = mol.alg.IterativeSuperposeSVD(mdl_pos, trg_pos) + except: + pass # triggers fallback below + if superpose_res is None: + superpose_res = mol.alg.SuperposeSVD(mdl_pos, trg_pos) + + if superpose_res.rmsd < best_rmsd: + best_rmsd = superpose_res.rmsd + best_mapping = mapping + + # this is stupid... + tmp = dict() + for chem_group, mapping in zip(chem_groups, best_mapping): + for trg_ch, mdl_ch in zip(chem_group, mapping): + tmp[trg_ch] = mdl_ch + + return tmp + def _GetRefPos(trg, mdl, trg_msas, mdl_alns, max_pos = None): """ Extracts reference positions which are present in trg and mdl