Skip to content
Snippets Groups Projects
Commit b13a66a3 authored by Studer Gabriel's avatar Studer Gabriel
Browse files

explorative (read hacky) code to get pocket detection in ligand modelling started

parent a31134c6
No related branches found
No related tags found
No related merge requests found
......@@ -10,6 +10,7 @@ set(MODELLING_CPP
export_score_container.cc
export_scoring_weights.cc
export_sidechain_reconstructor.cc
export_pocket_finder.cc
wrap_modelling.cc
)
......
// Copyright (c) 2013-2019, SIB - Swiss Institute of Bioinformatics and
// Biozentrum - University of Basel
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <boost/python.hpp>
#include <promod3/core/export_helper.hh>
#include <promod3/modelling/pocket_finder.hh>
using namespace promod3;
using namespace promod3::modelling;
using namespace boost::python;
void export_pocket_finder() {
def("PocketFinder", &promod3::modelling::PocketFinder, (arg("view_one"), arg("view_two")));
}
......@@ -27,6 +27,7 @@ void export_rigid_blocks();
void export_score_container();
void export_scoring_weights();
void export_SidechainReconstructor();
void export_pocket_finder();
BOOST_PYTHON_MODULE(_modelling)
{
......@@ -41,4 +42,5 @@ BOOST_PYTHON_MODULE(_modelling)
export_score_container();
export_scoring_weights();
export_SidechainReconstructor();
export_pocket_finder();
}
......@@ -18,6 +18,7 @@ set(MODELLING_SOURCES
scoring_weights.cc
sidechain_reconstructor.cc
sidechain_env_listener.cc
pocket_finder.cc
)
set(MODELLING_HEADERS
......@@ -40,6 +41,7 @@ set(MODELLING_HEADERS
scoring_weights.hh
sidechain_reconstructor.hh
sidechain_env_listener.hh
pocket_finder.hh
)
module(NAME modelling
......
// Copyright (c) 2013-2019, SIB - Swiss Institute of Bioinformatics and
// Biozentrum - University of Basel
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <promod3/modelling/pocket_finder.hh>
#include <promod3/core/runtime_profiling.hh>
#include <promod3/core/eigen_types.hh>
#include <promod3/core/superpose.hh>
#include <ost/mol/atom_view.hh>
#include <chrono>
#include <tuple>
#include <unordered_map>
#include <limits>
namespace {
// the first triplet of integers describe the triange (edge lengths)
// the second triplet of integers describe a position expressed on the
// basis of this triangle
typedef std::tuple<int, int, int, int, int, int> PocketHasherKey;
typedef std::tuple<int, int> PocketHasherValue;
struct PocketHasherKeyHasher : public std::unary_function<PocketHasherKey,
std::size_t> {
std::size_t operator()(const PocketHasherKey& k) const {
std::size_t seed = std::get<0>(k);
seed ^= std::get<1>(k) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
seed ^= std::get<2>(k) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
seed ^= std::get<3>(k) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
seed ^= std::get<4>(k) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
seed ^= std::get<5>(k) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
return seed;
}
};
typedef std::unordered_map<const PocketHasherKey, PocketHasherValue,
PocketHasherKeyHasher> PocketHasherMap;
} // anon ns
namespace promod3 { namespace modelling {
PocketQuery PocketQuery::FromEntity(const ost::mol::ResidueView& res,
const ost::mol::EntityView& env) {
PocketQuery query;
return query;
}
geom::Mat4 FindTransform(const geom::Vec3List& query_positions,
const geom::Vec3List& target_positions,
int max_query_samples = 20,
Real dist_thresh = 1.0) {
std::chrono::time_point<std::chrono::system_clock> start, end;
start = std::chrono::system_clock::now();
int n_query_total = query_positions.size();
int n_query = std::min(n_query_total, max_query_samples);
promod3::core::EMatXX query_dist = promod3::core::EMatXX::Zero(n_query,n_query);
std::vector<int> query_indices;
Real frac = static_cast<Real>(n_query_total)/n_query;
for(int i = 0; i < n_query; ++i) {
query_indices.push_back(std::floor(frac*i));
}
for(int i = 0; i < n_query; ++i) {
for(int j = i+1; j < n_query; ++j) {
Real d = geom::Distance(query_positions[query_indices[i]],
query_positions[query_indices[j]]);
query_dist(i,j) = d;
query_dist(j,i) = d;
}
}
int n_target = target_positions.size();
promod3::core::EMatXX target_dist = promod3::core::EMatXX::Zero(n_target,n_target);
for(int i = 0; i < n_target; ++i) {
for(int j = i+1; j < n_target; ++j) {
Real d = geom::Distance(target_positions[i], target_positions[j]);
target_dist(i,j) = d;
target_dist(j,i) = d;
}
}
end = std::chrono::system_clock::now();
int elapsed_microseconds = std::chrono::duration_cast<std::chrono::microseconds>(end-start).count();
std::cerr<<"microseconds to do pairwise distances: "<<elapsed_microseconds<<std::endl;
start = std::chrono::system_clock::now();
std::vector<geom::Mat4> transformations;
promod3::core::EMatX3 query_t_pos = promod3::core::EMatX3::Zero(3,3);
promod3::core::EMatX3 target_t_pos = promod3::core::EMatX3::Zero(3,3);
for(int edge_one_idx = 0; edge_one_idx < n_query-2; ++edge_one_idx) {
int p1_one = edge_one_idx;
int p2_one = edge_one_idx + 1;
int p3_one = edge_one_idx + 2;
Real p1p2_one = query_dist(p1_one, p2_one);
Real p1p3_one = query_dist(p1_one, p3_one);
Real p2p3_one = query_dist(p2_one, p3_one);
for(int p1_two = 0; p1_two < n_target; ++p1_two) {
for(int p2_two = p1_two + 1; p2_two < n_target; ++p2_two) {
// only sample p3_two if p1p2 distances match
if(std::abs(p1p2_one - target_dist(p1_two, p2_two)) < dist_thresh) {
for(int p3_two = 0; p3_two < n_target; ++p3_two) {
if(p3_two != p1_two || p3_two != p2_two) {
Real p1p3_two = target_dist(p1_two,p3_two);
Real p2p3_two = target_dist(p2_two,p3_two);
if(std::abs(p1p3_one - p1p3_two) < dist_thresh &&
std::abs(p2p3_one - p2p3_two) < dist_thresh) {
promod3::core::EMatFillRow(query_t_pos, 0, query_positions[query_indices[p1_one]]);
promod3::core::EMatFillRow(query_t_pos, 1, query_positions[query_indices[p2_one]]);
promod3::core::EMatFillRow(query_t_pos, 2, query_positions[query_indices[p3_one]]);
promod3::core::EMatFillRow(target_t_pos, 0, target_positions[p1_two]);
promod3::core::EMatFillRow(target_t_pos, 1, target_positions[p2_two]);
promod3::core::EMatFillRow(target_t_pos, 2, target_positions[p3_two]);
transformations.push_back(promod3::core::MinRMSDSuperposition(query_t_pos, target_t_pos));
}
if(std::abs(p1p3_one - p2p3_two) < dist_thresh &&
std::abs(p2p3_one - p1p3_two) < dist_thresh) {
promod3::core::EMatFillRow(query_t_pos, 0, query_positions[query_indices[p1_one]]);
promod3::core::EMatFillRow(query_t_pos, 1, query_positions[query_indices[p2_one]]);
promod3::core::EMatFillRow(query_t_pos, 2, query_positions[query_indices[p3_one]]);
promod3::core::EMatFillRow(target_t_pos, 0, target_positions[p2_two]);
promod3::core::EMatFillRow(target_t_pos, 1, target_positions[p1_two]);
promod3::core::EMatFillRow(target_t_pos, 2, target_positions[p3_two]);
transformations.push_back(promod3::core::MinRMSDSuperposition(query_t_pos, target_t_pos));
}
}
}
}
}
}
}
end = std::chrono::system_clock::now();
elapsed_microseconds = std::chrono::duration_cast<std::chrono::microseconds>(end-start).count();
std::cerr<<"microseconds to do triangle matching: "<<elapsed_microseconds<<std::endl;
typedef std::tuple<int, int, int> key_t;
struct key_hash : public std::unary_function<key_t, std::size_t> {
std::size_t operator()(const key_t& k) const {
std::size_t seed = std::get<0>(k);
seed ^= std::get<1>(k) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
seed ^= std::get<2>(k) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
return seed;
}
};
typedef std::unordered_map<const key_t,int,key_hash> map_t;
map_t target_map;
key_t pos_tuple;
for(int pos_idx = 0; pos_idx < n_target; ++pos_idx) {
int x = static_cast<int>(2*target_positions[pos_idx][0]);
int y = static_cast<int>(2*target_positions[pos_idx][1]);
int z = static_cast<int>(2*target_positions[pos_idx][2]);
for(int i = x-1; i<=x+1; ++i) {
for(int j = y-1; j<=y+1; ++j) {
for(int k = z-1; k<=z+1; ++k) {
pos_tuple = std::make_tuple(i,j,k);
target_map[pos_tuple] = 1;
}
}
}
}
int max_matches = 0;
int max_matches_idx = -1;
start = std::chrono::system_clock::now();
for(uint t_idx = 0; t_idx < transformations.size(); ++t_idx) {
int matches = 0;
for(int query_pos_idx = 0; query_pos_idx < n_query; ++query_pos_idx) {
geom::Mat4& t = transformations[t_idx];
const geom::Vec3& pos = query_positions[query_indices[query_pos_idx]];
geom::Vec3 transformed_pos(t(0,0)*pos[0]+t(0,1)*pos[1]+t(0,2)*pos[2]+t(0,3),
t(1,0)*pos[0]+t(1,1)*pos[1]+t(1,2)*pos[2]+t(1,3),
t(2,0)*pos[0]+t(2,1)*pos[1]+t(2,2)*pos[2]+t(2,3));
int x = static_cast<int>(2*transformed_pos[0]);
int y = static_cast<int>(2*transformed_pos[1]);
int z = static_cast<int>(2*transformed_pos[2]);
pos_tuple = std::make_tuple(x,y,z);
map_t::iterator it = target_map.find(pos_tuple);
if(it != target_map.end()) {
++matches;
}
}
if(matches > max_matches) {
max_matches = matches;
max_matches_idx = t_idx;
}
}
end = std::chrono::system_clock::now();
elapsed_microseconds = std::chrono::duration_cast<std::chrono::microseconds>(end-start).count();
std::cerr<<"microseconds to do match finding: "<<elapsed_microseconds<<std::endl;
std::cerr<<"number of triangles: "<<transformations.size()<<std::endl;
std::cerr<<"max num matches "<<max_matches<<std::endl;
return transformations[max_matches_idx];
}
geom::Mat4 FindTransformAwesome(const geom::Vec3List& query_positions,
const geom::Vec3List& target_positions) {
// fill everything in Eigen matrices
int n_query = query_positions.size();
int n_target = target_positions.size();
promod3::core::EMatX3 query_pos = promod3::core::EMatX3::Zero(n_query, 3);
promod3::core::EMatX3 target_pos = promod3::core::EMatX3::Zero(n_target, 3);
for(int i = 0; i < n_query; ++i) {
promod3::core::EMatFillRow(query_pos, i, query_positions[i]);
}
for(int i = 0; i < n_target; ++i) {
promod3::core::EMatFillRow(target_pos, i, target_positions[i]);
}
// estimate pairwise distances
promod3::core::EMatXX query_dist = promod3::core::EMatXX::Zero(n_query,
n_query);
promod3::core::EMatXX target_dist = promod3::core::EMatXX::Zero(n_target,
n_target);
for(int i = 0; i < n_query; ++i) {
for(int j = i+1; j < n_query; ++j) {
Real d = (query_pos.row(i) - query_pos.row(j)).norm();
query_dist(i,j) = d;
query_dist(j,i) = d;
}
}
for(int i = 0; i < n_target; ++i) {
for(int j = i+1; j < n_target; ++j) {
Real d = (target_pos.row(i) - target_pos.row(j)).norm();
target_dist(i,j) = d;
target_dist(j,i) = d;
}
}
// start to build up the hash map
PocketHasherMap map;
for(int p1 = 0; p1 < n_query; ++p1) {
std::cerr<<p1<<std::endl;
for(int p2 = 0; p2 < n_query; ++p2) {
if(p2 != p1) {
int a = static_cast<int>(query_dist(p1,p2));
if(a > 12) {
continue;
}
for(int p3 = 0; p3 < n_query; ++p3) {
if(p3 != p1 && p3 != p2) {
int b = static_cast<int>(query_dist(p1,p3));
int c = static_cast<int>(query_dist(p2,p3));
if(b > 12 || c > 12) {
continue;
}
// transform all other positions into this vector space and add
// stuff to the hash map
promod3::core::EMat3 base;
base.col(0) = (query_pos.row(p2) - query_pos.row(p1)).transpose();
base.col(1) = (query_pos.row(p3) - query_pos.row(p1)).transpose();
base.col(2) = base.col(0).cross(base.col(1));
base.col(0).normalize();
base.col(1).normalize();
base.col(2).normalize();
promod3::core::EMatX3 transformed_pos =
(base.inverse() * query_pos.transpose()).transpose();
for(int i = 0; i < n_query; ++i) {
if(i != p1 && i != p2 && i != p3) {
int d = static_cast<int>(transformed_pos(i,0));
int e = static_cast<int>(transformed_pos(i,1));
int f = static_cast<int>(transformed_pos(i,2));
int triangle_idx = p1*n_query*n_query + p2*n_query + p3;
int query_idx = 0; // will change as soon as we have several
// queries at once...
map[std::make_tuple(a,b,c,d,e,f)] =
std::make_tuple(triangle_idx, query_idx);
}
}
}
}
}
}
}
// the number of rows should actually be the maximal number of triangles in
// any of the queries the number of cols is currently one, we currently only
// have one query...
promod3::core::EMatXX accumulator = promod3::core::EMatXX::Zero(n_query*n_query*n_query,1);
for(int p1 = 0; p1 < n_target; ++p1) {
std::cerr<<p1<<std::endl;
for(int p2 = 0; p2 < n_target; ++p2) {
if(p2 != p1) {
int a = static_cast<int>(target_dist(p1,p2));
if(a > 12) {
continue;
}
promod3::core::EMat3 base;
base.col(0) = (target_pos.row(p2) -
target_pos.row(p1)).transpose();
for(int p3 = 0; p3 < n_target; ++p3) {
if(p3 != p1 && p3 != p2) {
int b = static_cast<int>(target_dist(p1,p3));
int c = static_cast<int>(target_dist(p2,p3));
if(b > 12 || c > 12) {
continue;
}
// transform all other positions into this vector space and add
// stuff to
// the hash map
base.col(1) = (target_pos.row(p3) -
target_pos.row(p1)).transpose();
base.col(2) = base.col(0).cross(base.col(1));
base.col(0).normalize();
base.col(1).normalize();
base.col(2).normalize();
promod3::core::EMatX3 transformed_pos =
(base.inverse() * target_pos.transpose()).transpose();
for(int i = 0; i < n_target; ++i) {
if(i != p1 && i != p2 && i != p3) {
int d = static_cast<int>(transformed_pos(i,0));
int e = static_cast<int>(transformed_pos(i,1));
int f = static_cast<int>(transformed_pos(i,2));
PocketHasherMap::iterator it = map.find(std::make_tuple(a,b,c,
d,e,f));
if(it != map.end()) {
accumulator(std::get<0>(it->second),
std::get<1>(it->second)) += 1.0;
}
}
}
// search for high vote numbers in accumulator
Eigen::Matrix<Real,1,1> w = accumulator.colwise().maxCoeff();
for(int i = 0; i < accumulator.cols(); ++i) {
if(w(0,i) > 10) {
//TODO do something awesome with the found hit
std::cerr<<"asdfsadfasdfasdf"<<std::endl;
}
}
accumulator.setZero();
}
}
}
}
}
return geom::Mat4();
}
void PocketFinder(const ost::mol::EntityView& query_view,
const ost::mol::EntityView& target_view) {
geom::Vec3List query_pos;
geom::Vec3List target_pos;
ost::mol::ResidueViewList query_res_list = query_view.GetResidueList();
ost::mol::ResidueViewList target_res_list = target_view.GetResidueList();
for(ost::mol::ResidueViewList::const_iterator it = query_res_list.begin();
it != query_res_list.end(); ++it) {
ost::mol::AtomView ca = it->FindAtom("CA");
if(!ca.IsValid()) {
throw "fuuuuck";
}
query_pos.push_back(ca.GetPos());
}
for(ost::mol::ResidueViewList::const_iterator it = target_res_list.begin();
it != target_res_list.end(); ++it) {
ost::mol::AtomView ca = it->FindAtom("CA");
if(!ca.IsValid()) {
throw "fuuuuck";
}
target_pos.push_back(ca.GetPos());
}
geom::Mat4 t = FindTransformAwesome(query_pos, target_pos);
std::cerr<<t<<std::endl;
}
}}
// Copyright (c) 2013-2019, SIB - Swiss Institute of Bioinformatics and
// Biozentrum - University of Basel
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PROMOD_MODELLING_POCKET_FINDER_HH
#define PROMOD_MODELLING_POCKET_FINDER_HH
#include <ost/mol/entity_view.hh>
#include <ost/mol/residue_view.hh>
#include <ost/geom/vec3.hh>
#include <vector>
namespace promod3 { namespace modelling {
class PocketQuery{
public:
static PocketQuery FromEntity(const ost::mol::ResidueView& res,
const ost::mol::EntityView& env);
private:
};
void PocketFinder(const ost::mol::EntityView& view_one,
const ost::mol::EntityView& view_two);
}} // ns
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment