From 3410aec89cf94d780b68c830c3d16c3a15affa19 Mon Sep 17 00:00:00 2001 From: Gabriel Studer <gabriel.studer@unibas.ch> Date: Tue, 17 Mar 2015 16:00:06 +0100 Subject: [PATCH] Rewrite of the distance rmsd test. The old implementation segfaulted for unknown reasons. The new implementation has been tested on thousands of structures and seems to be stable and roughly 3 times faster. --- modules/mol/alg/pymod/wrap_mol_alg.cc | 5 +- modules/mol/alg/src/distance_rmsd_test.cc | 317 +++++++++++++--------- 2 files changed, 187 insertions(+), 135 deletions(-) diff --git a/modules/mol/alg/pymod/wrap_mol_alg.cc b/modules/mol/alg/pymod/wrap_mol_alg.cc index 87e5be6a6..acd951f1a 100644 --- a/modules/mol/alg/pymod/wrap_mol_alg.cc +++ b/modules/mol/alg/pymod/wrap_mol_alg.cc @@ -124,8 +124,6 @@ BOOST_PYTHON_MODULE(_ost_mol_alg) def("CreateDistanceList",&mol::alg::CreateDistanceList); def("CreateDistanceListFromMultipleReferences",&create_distance_list_from_multiple_references); - def("DistanceRMSDTest", &mol::alg::DistanceRMSDTest, (arg("sequence_separation")=0,arg("local_lddt_property_string")="")); - def("SuperposeFrames", superpose_frames1, (arg("source"), arg("sel")=ost::mol::EntityView(), arg("begin")=0, arg("end")=-1, arg("ref")=-1)); @@ -265,6 +263,7 @@ BOOST_PYTHON_MODULE(_ost_mol_alg) .def(map_indexing_suite<mol::alg::GlobalRDMap,true>()) ; - def("DRMSD",&mol::alg::DRMSD); + def("DRMSD",&mol::alg::DRMSD,(arg("view"),arg("distance_list"), + arg("cap_distance")=5.0,arg("sequence_separation")=0)); } diff --git a/modules/mol/alg/src/distance_rmsd_test.cc b/modules/mol/alg/src/distance_rmsd_test.cc index 1411c1806..6bb500e36 100644 --- a/modules/mol/alg/src/distance_rmsd_test.cc +++ b/modules/mol/alg/src/distance_rmsd_test.cc @@ -8,117 +8,176 @@ namespace ost { namespace mol { namespace alg { namespace { -std::pair<Real, long int> calc_rmsd(const ResidueRDMap& res_distance_list, const ResNum& rnum, - ChainView mdl_chain, int sequence_separation, Real cap_distance, - bool
only_fixed, bool swap, std::vector<std::pair<Real, long int> >& local_rmsd_data_list, bool log ) -{ - std::pair<Real, long int> rmsd_data(0, 0); - ResidueView mdl_res=mdl_chain.FindResidue(rnum); - for (ResidueRDMap::const_iterator ai=res_distance_list.begin(); ai!=res_distance_list.end(); ++ai) { - const UAtomIdentifiers& uais = ai->first; - const std::pair <Real,Real>& values = ai->second; - const UniqueAtomIdentifier& first_atom=uais.first; - const UniqueAtomIdentifier& second_atom=uais.second; - String name=swap ? SwappedName(first_atom.GetAtomName()) : first_atom.GetAtomName(); - AtomView av1=mdl_res ? mdl_res.FindAtom(name) : AtomView(); - - if (only_fixed) { - if (std::abs(first_atom.GetResNum().GetNum()-second_atom.GetResNum().GetNum())<=sequence_separation) { - continue; - } - if (Swappable(second_atom.GetResidueName(), second_atom.GetAtomName())) { - continue; - } - } - if (!only_fixed) { - if (first_atom.GetResNum().GetNum()<=(second_atom.GetResNum().GetNum()+sequence_separation)) { - continue; + +Real calc_squared_diff(const std::vector<geom::Vec3>& pos_one, + const std::vector<geom::Vec3>& pos_two, + const std::vector<Real>& exp_dist, + Real cap){ + //expect all all vectors to have the same length... 
+ Real squared_diff = 0.0; + Real diff; + for(uint i = 0; i < pos_one.size(); ++i){ + diff = std::min(std::abs(geom::Distance(pos_one[i],pos_two[i]) - exp_dist[i]),cap); + squared_diff += diff*diff; + } + return squared_diff; +} + +void fill_residue_positions(const ResidueRDMap& res_dist_list, const ResidueView& res, + const ChainView& mdl_chain, Real cap_distance, int sequence_separation, + std::map<String,geom::Vec3>& positions){ + + AtomViewList atom_list = res.GetAtomList(); + //let's already fill in the positions and check later on whether a swap is necessary + for(AtomViewList::const_iterator i = atom_list.begin(); + i != atom_list.end(); ++i){ + positions[i->GetName()] = i->GetPos(); + } + + String rname = res.GetName(); + if(!(rname=="GLU" || rname=="ASP" || rname=="VAL" || rname=="TYR" || + rname=="PHE" || rname=="LEU" || rname=="ARG")){ + return; //no ambiguity check necessary for all other amino acids + } + + std::map<String,geom::Vec3> original_positions; + std::map<String,geom::Vec3> swapped_positions; + AtomView atom; + for(AtomViewList::const_iterator i = atom_list.begin(); i !=atom_list.end(); ++i){ + if(Swappable(rname,i->GetName())){ + atom = res.FindAtom(SwappedName(i->GetName())); + if(atom){ //both atoms are valid => note, that only in this case a swap is possible + original_positions[i->GetName()] = i->GetPos(); + swapped_positions[i->GetName()] = atom.GetPos(); } } - ResidueView rv2=mdl_chain.FindResidue(second_atom.GetResNum()); - rmsd_data.second+=1; - int rindex2=0; - int rindex1=mdl_res ? 
mdl_res.GetIndex() : -1; - if (!only_fixed && rindex1!=-1) - local_rmsd_data_list[rindex1].second+=1; - if (!rv2) { - rmsd_data.first+=(cap_distance*cap_distance); - local_rmsd_data_list[rindex1].first+=(cap_distance*cap_distance); - continue; - } - rindex2=rv2.GetIndex(); - if (!only_fixed) - local_rmsd_data_list[rindex2].second+=1; - AtomView av2=rv2.FindAtom(second_atom.GetAtomName()); - if (!(av1 && av2)) { - rmsd_data.first+=(cap_distance*cap_distance); - local_rmsd_data_list[rindex1].first+=(cap_distance*cap_distance); - local_rmsd_data_list[rindex2].first+=(cap_distance*cap_distance); - continue; - } - Real mdl_dist=geom::Length(av1.GetPos()-av2.GetPos()); - Real diff = mdl_dist-values.first; - if (diff > cap_distance) { - diff = cap_distance; - } - if (log) { - LOG_VERBOSE("drmsd:" << " " << av1.GetResidue().GetChain() << " " << av1.GetResidue().GetName() << " " << av1.GetResidue().GetNumber() << " " << av1.GetName() - << " " << av2.GetResidue().GetChain() << " " << av2.GetResidue().GetName() << " " << av2.GetResidue().GetNumber() << " " << av2.GetName() << " " - << mdl_dist << " " << values.first << " " << diff) - } - rmsd_data.first+=(diff*diff); - if (!only_fixed) { - local_rmsd_data_list[rindex1].first+=(diff*diff); - local_rmsd_data_list[rindex2].first+=(diff*diff); - } } - return rmsd_data; -} + std::vector<geom::Vec3> or_positions; + std::vector<geom::Vec3> sw_positions; + std::vector<geom::Vec3> other_positions; + std::vector<Real> expected_dist; + ResidueView other_residue; + + for(ResidueRDMap::const_iterator i = res_dist_list.begin(); + i != res_dist_list.end(); ++i){ + //check whether first atom is one of the swappables + const UniqueAtomIdentifier& first_atom = i->first.first; + if(original_positions.find(first_atom.GetAtomName()) == original_positions.end()) continue; + //we only check the second atom of the distance if it is NOT swappable + const UniqueAtomIdentifier& second_atom = i->first.second; + 
if(Swappable(second_atom.GetResidueName(),second_atom.GetAtomName())) continue; + //check whether other atom is valid and check the sequence separation + other_residue = mdl_chain.FindResidue(second_atom.GetResNum()); + if(!other_residue) continue; + if(std::abs(res.GetNumber().GetNum() - other_residue.GetNumber().GetNum()) <= sequence_separation) continue; + atom = other_residue.FindAtom(second_atom.GetAtomName()); + if(!atom) continue; + + or_positions.push_back(original_positions[first_atom.GetAtomName()]); + sw_positions.push_back(swapped_positions[first_atom.GetAtomName()]); + expected_dist.push_back(i->second.first); + other_positions.push_back(atom.GetPos()); + } -void drmsdt_check_and_swap(const GlobalRDMap& glob_dist_list, const EntityView& mdl, int sequence_separation, Real cap_distance, std::vector<std::pair<Real, long int> > local_rmsd_data_list) -{ - ChainView mdl_chain=mdl.GetChainList()[0]; - XCSEditor edi=mdl.GetHandle().EditXCS(BUFFERED_EDIT); - for (GlobalRDMap::const_iterator i=glob_dist_list.begin(); i!=glob_dist_list.end(); ++i) { - ResNum rnum = i->first; - if (i->second.size()==0) { - continue; + Real diff_orig = calc_squared_diff(or_positions,other_positions,expected_dist, cap_distance); + Real diff_swapped = calc_squared_diff(sw_positions,other_positions,expected_dist, cap_distance); + //finally swap the stuff if necessary + if(diff_swapped < diff_orig){ + for(std::map<String,geom::Vec3>::iterator i = swapped_positions.begin(); + i != swapped_positions.end(); ++i){ + positions[i->first] = i->second; } - ResidueView mdl_res=mdl_chain.FindResidue(rnum); - if (!mdl_res) { - continue; + } +} + +void fill_positions(const GlobalRDMap& glob_dist_list, const ChainView& mdl_chain, + Real cap_distance, int sequence_separation, + std::map<ResNum, std::map<String, geom::Vec3> >& positions){ + + ResidueView actual_res; + for(GlobalRDMap::const_iterator i = glob_dist_list.begin(); + i != glob_dist_list.end(); ++i){ + actual_res = 
mdl_chain.FindResidue(i->first); + if(actual_res){ + std::map<String,geom::Vec3> residue_positions; + fill_residue_positions(i->second, actual_res, mdl_chain, + cap_distance, sequence_separation, + residue_positions); + if(!residue_positions.empty()) positions[i->first] = residue_positions; } - String rname = mdl_res.GetName(); - if (!(rname=="GLU" || rname=="ASP" || rname=="VAL" || rname=="TYR" || - rname=="PHE" || rname=="LEU" || rname=="ARG")) { + } +} + +void fill_values(const GlobalRDMap& glob_dist_list, + const std::map<ResNum, std::map<String, geom::Vec3> >& positions, + Real cap_distance, int sequence_separation, + std::map<ResNum, std::pair<Real,long int> >& drmsd_values){ + + ResNum rnum_one, rnum_two; + std::vector<geom::Vec3> positions_one, positions_two; + std::vector<Real> expected_distances; + std::map<ResNum, std::map<String, geom::Vec3> >::const_iterator res_positions_one, res_positions_two; + std::map<String,geom::Vec3>::const_iterator pos_one, pos_two; + int penalties; + Real squared_diff; + int num_interactions; + Real squared_cap_distance = cap_distance * cap_distance; + + for(GlobalRDMap::const_iterator i = glob_dist_list.begin(); + i != glob_dist_list.end(); ++i){ + rnum_one = i->first; + res_positions_one = positions.find(rnum_one); + if(res_positions_one == positions.end()){ + //this residue is missing, so we give it the full penalty + drmsd_values[rnum_one] = std::make_pair(i->second.size()*squared_cap_distance,i->second.size()); continue; } - std::pair<long int, long int> ov1=calc_rmsd(i->second, rnum,mdl_chain, sequence_separation, - cap_distance, true, false, local_rmsd_data_list,true); - - std::pair<long int, long int> ov2=calc_rmsd(i->second, rnum, mdl_chain, sequence_separation, - cap_distance, true, true, local_rmsd_data_list,true); - - if (std::sqrt(ov1.first/static_cast<Real>(ov1.second))>(std::sqrt(ov2.first/static_cast<Real>(ov2.second)))) { - - AtomViewList atoms=mdl_res.GetAtomList(); - for (AtomViewList::iterator 
j=atoms.begin(), e2=atoms.end(); j!=e2; ++j) { - if (Swappable(rname, j->GetName())) { - edi.RenameAtom(j->GetHandle(), SwappedName(j->GetName())); - } + positions_one.clear(); + positions_two.clear(); + expected_distances.clear(); + penalties = 0; + + for(ResidueRDMap::const_iterator j = i->second.begin(); j != i->second.end(); ++j){ + //check sequence separation + rnum_two = j->first.second.GetResNum(); + if(std::abs(rnum_one.GetNum() - rnum_two.GetNum()) <= sequence_separation) continue; + //if residue is not present in the positions map, the positions are also missing in the + //model => penalty + res_positions_two = positions.find(rnum_two); + if(res_positions_two == positions.end()){ + ++penalties; + continue; } + //check whether both positions are present in the positions map + pos_one = res_positions_one->second.find(j->first.first.GetAtomName()); + pos_two = res_positions_two->second.find(j->first.second.GetAtomName()); + if(pos_one == res_positions_one->second.end() || pos_two == res_positions_two->second.end()){ + ++penalties; //one of the atoms is missing + continue; + } + positions_one.push_back(pos_one->second); + positions_two.push_back(pos_two->second); + expected_distances.push_back(j->second.first); + } + squared_diff = calc_squared_diff(positions_one, positions_two, + expected_distances, cap_distance); + num_interactions = positions_one.size(); + //add the penalty values + if(penalties > 0){ + squared_diff += squared_cap_distance*penalties; + num_interactions += penalties; } + drmsd_values[rnum_one] = std::make_pair(squared_diff,num_interactions); } } - -} +}//ns std::pair<Real,long int> DistanceRMSDTest(const EntityView& mdl, const GlobalRDMap& glob_dist_list, - Real cap_distance, int sequence_separation, const String& local_drmsdt_property_string) -{ + Real cap_distance, int sequence_separation, + const String& local_drmsdt_property_string){ if (!mdl.GetResidueCount()) { LOG_WARNING("model structures doesn't contain any residues"); return 
std::make_pair<long int,long int>(0,0); @@ -127,53 +186,47 @@ std::pair<Real,long int> DistanceRMSDTest(const EntityView& mdl, const GlobalRDM LOG_WARNING("global reference list is empty"); return std::make_pair<long int,long int>(0,0); } - std::vector<std::pair<Real, long int> > local_rmsd_data_list(mdl.GetResidueCount(), std::pair<Real, long int>(0, 0)); - drmsdt_check_and_swap(glob_dist_list,mdl,sequence_separation, cap_distance, local_rmsd_data_list); + + //every residue gets a map of positions + std::map<ResNum, std::map<String, geom::Vec3> > positions; + //every residue will have a sum of squared distances and number of distances + std::map<ResNum, std::pair<Real,long int> > drmsd_values; + //currently we're only looking at the first chain ChainView mdl_chain=mdl.GetChainList()[0]; - local_rmsd_data_list.clear(); - std::pair<Real, long int> total_ov(0, 0); - for (GlobalRDMap::const_iterator i=glob_dist_list.begin(), e=glob_dist_list.end(); i!=e; ++i) { - ResNum rn = i->first; - if (i->second.size()!=0) { - std::pair<Real, long int> ov1=calc_rmsd(i->second, rn, mdl_chain, sequence_separation, cap_distance, false, false, local_rmsd_data_list,true); - total_ov.first+=ov1.first; - total_ov.second+=ov1.second; - } - } - - for (GlobalRDMap::const_iterator i=glob_dist_list.begin(), - e=glob_dist_list.end();i!=e; ++i) { - ResNum rn = i->first; - if(local_drmsdt_property_string!="") { - ResidueView mdlr=mdl_chain.FindResidue(rn); - if (mdlr.IsValid()) { - int mdl_res_index =mdlr.GetIndex(); - Real local_rmsd=sqrt(local_rmsd_data_list[mdl_res_index].first/(static_cast<Real>(local_rmsd_data_list[mdl_res_index].second) ? 
static_cast<Real>(local_rmsd_data_list[mdl_res_index].second) : 1)); - mdlr.SetFloatProp(local_drmsdt_property_string, local_rmsd); - mdlr.SetFloatProp(local_drmsdt_property_string+"_sum", local_rmsd_data_list[mdl_res_index].first); - mdlr.SetIntProp(local_drmsdt_property_string+"_count", local_rmsd_data_list[mdl_res_index].second); - } + fill_positions(glob_dist_list, mdl_chain, cap_distance, + sequence_separation,positions); + fill_values(glob_dist_list, positions, cap_distance, + sequence_separation, drmsd_values); + + Real total_squared_diff = 0.0; + long int total_num_interactions = 0; + Real drmsd; + std::pair<Real, long int> residue_data; + ResidueView actual_residue; + for(GlobalRDMap::const_iterator i = glob_dist_list.begin(); + i != glob_dist_list.end(); ++i){ + residue_data = drmsd_values[i->first]; + total_squared_diff += residue_data.first; + total_num_interactions += residue_data.second; + actual_residue = mdl_chain.FindResidue(i->first); + if(actual_residue){ + drmsd = residue_data.second > 0 ? std::sqrt(residue_data.first/residue_data.second) : 0; + actual_residue.SetFloatProp(local_drmsdt_property_string, drmsd); + actual_residue.SetFloatProp(local_drmsdt_property_string+"_sum",residue_data.first); + actual_residue.SetFloatProp(local_drmsdt_property_string+"_count",residue_data.second); } } - local_rmsd_data_list.clear(); - return std::make_pair<Real,long int>(total_ov.first,total_ov.second); + drmsd = total_num_interactions > 0 ? std::sqrt(total_squared_diff/total_num_interactions) : 0; + return std::make_pair(drmsd,total_num_interactions); } Real DRMSD(const EntityView& v, const GlobalRDMap& global_dist_list, Real cap_distance, int sequence_separation) { String label="localdrmsd"; - std::pair<Real,long int> total_ov=alg::DistanceRMSDTest(v, global_dist_list, cap_distance, sequence_separation, label); - - Real calcdrmsd = static_cast<Real>(total_ov.first)/(static_cast<Real>(total_ov.second) ? 
static_cast<Real>(total_ov.second) : 1); - - return std::sqrt(calcdrmsd); + std::pair<Real,long int> result=alg::DistanceRMSDTest(v, global_dist_list, cap_distance, sequence_separation, label); + return result.first; } - - - - - -}}} +}}} //ns -- GitLab