// svd_superpose.cc
//------------------------------------------------------------------------------
// This file is part of the OpenStructure project <www.openstructure.org>
//
// Copyright (C) 2008-2011 by the OpenStructure authors
//
// This library is free software; you can redistribute it and/or modify it under
// the terms of the GNU Lesser General Public License as published by the Free
// Software Foundation; either version 3.0 of the License, or (at your option)
// any later version.
// This library is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
// FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
// details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with this library; if not, write to the Free Software Foundation, Inc.,
// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
//------------------------------------------------------------------------------

#include <stdexcept>
#include <iostream>
#include <boost/bind.hpp>

#include <Eigen/Core>
#include <Eigen/Array>
#include <Eigen/SVD>
#include <Eigen/LU>
#include <Eigen/StdVector>

#include <ost/base.hh>
#include <ost/geom/vec3.hh>
#include <ost/geom/mat4.hh>
#include <ost/mol/alg/svd_superpose.hh>
#include <ost/mol/xcs_editor.hh>
#include <ost/mol/residue_handle.hh>
#include <ost/mol/residue_view.hh>
#include <ost/mol/chain_view.hh>
#include <ost/mol/chain_handle.hh>
#include <ost/mol/view_op.hh>
#include <ost/mol/atom_view.hh>

namespace ost { namespace mol { namespace alg {


// NOTE(review): boost::bind is brought into scope here, but no call to bind
// is visible in this part of the file -- confirm it is still needed.
using boost::bind;
// Fixed-size Eigen aliases based on the project's Real scalar type.
typedef Eigen::Matrix<Real,3,1> EVec3;   // 3x1 column vector
typedef Eigen::Matrix<Real,3,3> EMat3;   // 3x3 matrix
typedef Eigen::Matrix<Real,4,4> EMat4;   // 4x4 matrix
typedef Eigen::Matrix<Real,1,3> ERVec3;  // 1x3 row vector
// Dynamically sized aliases: matrix with run-time rows/columns and a row
// vector with run-time length.
typedef Eigen::Matrix<Real,Eigen::Dynamic,Eigen::Dynamic> EMatX;
typedef Eigen::Matrix<Real,1,Eigen::Dynamic> ERVecX;





// Root-mean-square deviation between two point lists after applying
// transformation to every point of points1.
//
// points1[i] is paired with points2[i]. Only the size of points1 controls the
// loop, so points2 must hold at least as many elements; this is not checked
// here. There is also no guard against an empty points1 (division by zero).
Real calc_rmsd_for_point_lists(const std::vector<geom::Vec3>& points1,
                              const std::vector<geom::Vec3>& points2,
                              const geom::Mat4& transformation)
{
  const size_t n_points=points1.size();
  Real sum_sq=0.0;
  for (size_t i=0; i<n_points; ++i) {
    // transform the point of the first list, then measure against its partner
    const geom::Vec3 moved(transformation*geom::Vec4(points1[i]));
    sum_sq+=geom::Length2(moved-points2[i]);
  }
  return sqrt(sum_sq/n_points);
}

// Root-mean-square deviation between the positions of two atom lists after
// applying transformation to the positions of atoms1.
//
// atoms1[i] is paired with atoms2[i]. Only the size of atoms1 controls the
// loop, so atoms2 must hold at least as many atoms; this is not checked here.
// There is also no guard against an empty atoms1 (division by zero).
Real calc_rmsd_for_atom_lists(const mol::AtomViewList& atoms1,
                              const mol::AtomViewList& atoms2,
                              const geom::Mat4& transformation)
{
  const size_t n_atoms=atoms1.size();
  Real sum_sq=0.0;
  for (size_t i=0; i<n_atoms; ++i) {
    // transformed position of the atom from the first list
    const geom::Vec3 moved(transformation*geom::Vec4(atoms1[i].GetPos()));
    sum_sq+=geom::Length2(moved-atoms2[i].GetPos());
  }
  return sqrt(sum_sq/n_atoms);
}

// Root-mean-square deviation between two coordinate matrices (one point per
// row, x/y/z in the first three columns) after applying transformation to
// every row of atoms1. atoms2 must have the same number of rows as atoms1.
Real calc_rmsd_for_ematx(const EMatX& atoms1,
                         const EMatX& atoms2,
                         const EMat4& transformation)
{
  const int n_rows = atoms1.rows();
  EMatX moved = EMatX::Zero(n_rows, 3);

  // homogeneous column vector; the fourth entry stays 1 so that the
  // translational part of the 4x4 matrix is applied as well
  EMatX point = EMatX::Zero(4,1);
  point(3,0)=1;
  EMatX image = EMatX::Zero(4,1);

  for (int row=0; row<n_rows; ++row) {
    point.block<3,1>(0,0)=atoms1.block<1,3>(row,0).transpose();
    image = transformation*point;
    moved.block<1,3>(row,0)=image.block<3,1>(0,0).transpose();
  }

  // per-row squared distance between transformed and reference coordinates
  EMatX delta = moved-atoms2;
  EMatX sq_dist = (delta.cwise()*delta).rowwise().sum();

  return sqrt(sq_dist.sum()/sq_dist.rows());
}

// RMSD between the atoms of two entity views after applying transformation
// to the atom positions of ev1. Atoms are paired by their order in the
// respective atom lists (see calc_rmsd_for_atom_lists).
Real CalculateRMSD(const mol::EntityView& ev1,
                   const mol::EntityView& ev2,
                   const geom::Mat4& transformation) {
  const mol::AtomViewList first_atoms = ev1.GetAtomList();
  const mol::AtomViewList second_atoms = ev2.GetAtomList();
  return calc_rmsd_for_atom_lists(first_atoms, second_atoms, transformation);
}

// Converts an Eigen 3-vector into a geom::Vec3 by handing the vector's
// coefficient array to the geom::Vec3 constructor.
geom::Vec3 EigenVec3ToVec3(const EVec3 &vec)
{
  const Real* coeffs = vec.data();
  return geom::Vec3(coeffs);
}

// Converts an Eigen 3x3 matrix into a geom::Mat3, copying element (r,c) to
// element (r,c).
// NOTE(review): the copy is done element-wise rather than via mat.data();
// presumably to stay independent of the storage order -- confirm.
geom::Mat3 EigenMat3ToMat3(const EMat3 &mat)
{
  geom::Mat3 converted;
  for (int k=0; k<9; ++k) {
    const int r=k/3;
    const int c=k%3;
    converted(r,c) = mat(r,c);
  }
  return converted;
}

// Converts an Eigen 4x4 matrix into a geom::Mat4, copying element (r,c) to
// element (r,c).
// NOTE(review): the copy is done element-wise rather than via mat.data();
// presumably to stay independent of the storage order -- confirm.
geom::Mat4 EigenMat4ToMat4(const EMat4 &mat)
{
  geom::Mat4 converted;
  for (int k=0; k<16; ++k) {
    const int r=k/4;
    const int c=k%4;
    converted(r,c) = mat(r,c);
  }
  return converted;
}

// Converts a geom::Mat4 into an Eigen matrix, copying element (r,c) to
// element (r,c).
// NOTE(review): the matrix is assembled as a fixed-size EMat4 but returned as
// the dynamically sized EMatX; the return type looks unintentional -- confirm
// before changing, as callers may depend on the declared signature.
EMatX Mat4ToEigenMat4(const geom::Mat4 &mat){
  EMat4 converted = EMat4::Zero();
  for (int k=0; k<16; ++k) {
    const int r=k/4;
    const int c=k%4;
    converted(r,c)=mat.At(r,c);
  }
  return converted;
}

// Builds an Eigen vector from the three coefficients of a geom::Vec3.
// Despite the "RVec" in the name, the return type is the column-vector alias
// EVec3, which makes this function equivalent to Vec3ToEigenVec.
EVec3 Vec3ToEigenRVec(const geom::Vec3 &vec)
{
  const Real* coeffs = &vec[0];
  return EVec3(coeffs);
}

// Builds an Eigen column vector from the three coefficients of a geom::Vec3.
EVec3 Vec3ToEigenVec(const geom::Vec3 &vec)
{
  const Real* coeffs = &vec[0];
  return EVec3(coeffs);
}

// Returns a copy of mat in which vec has been subtracted from every row.
// vec must have as many entries as mat has columns.
EMatX MatrixShiftedBy(EMatX mat, ERVecX vec)
{
  EMatX shifted(mat);
  const int n_rows = mat.rows();
  for (int r=0; r<n_rows; ++r) {
    shifted.row(r) = mat.row(r) - vec;
  }
  return shifted;
}


// Implementation object behind MeanSquareMinimizer.
//
// Holds two coordinate matrices with one point per row and three columns:
// atoms1_ (filled via SetPos) and atoms2_ (the reference, filled via
// SetRefPos).
class MeanSquareMinimizerImpl {
public:
  // n_atoms:     number of rows of the coordinate matrices
  // alloc_atoms: if true, both matrices are created as zero-filled
  //              n_atoms x 3 matrices; otherwise they stay default-constructed
  MeanSquareMinimizerImpl(int n_atoms, bool alloc_atoms):
    n_atoms_(n_atoms), alloc_atoms_(alloc_atoms)
  {
    if (!alloc_atoms_) {
      return;
    }
    atoms1_=EMatX::Zero(n_atoms_, 3);
    atoms2_=EMatX::Zero(n_atoms_, 3);
  }

  // Stores pos as row `index` of the reference coordinates. The index is not
  // range-checked here.
  void SetRefPos(size_t index, const geom::Vec3& pos) {
    const ERVec3 ref_row(Vec3ToEigenVec(pos));
    atoms2_.row(index) = ref_row;
  }

  // Stores pos as row `index` of the coordinates to be superposed. The index
  // is not range-checked here.
  void SetPos(size_t index, const geom::Vec3& pos) {
    const ERVec3 pos_row(Vec3ToEigenVec(pos));
    atoms1_.row(index) = pos_row;
  }

  // Single superposition of atoms onto atoms_ref.
  SuperpositionResult Minimize(const EMatX& atoms, const EMatX& atoms_ref) const;

  // Applies a 4x4 transformation to every row of mat.
  EMatX TransformEMatX(const EMatX& mat, const EMat4& transformation) const;

  // Extracts the row pairs of atoms/atoms_ref that are at most
  // distance_threshold apart.
  std::pair<EMatX,EMatX> CreateMatchingSubsets(const EMatX& atoms, const EMatX& atoms_ref, Real distance_threshold) const;

  // Repeated superposition on the stored coordinates.
  SuperpositionResult IterativeMinimize(int ncycles, Real distance_threshold) const;

  // Single superposition on the stored coordinates.
  SuperpositionResult MinimizeOnce() const;

private:
  int   n_atoms_;      // row count passed to the constructor
  bool  alloc_atoms_;  // whether the constructor allocated the matrices
  EMatX atoms1_;       // coordinates to be superposed
  EMatX atoms2_;       // reference coordinates
};


// Creates a minimizer from two atom lists; atoms[i] is paired with
// atoms_ref[i] and the atom positions are copied into the minimizer.
//
// Throws Error if the lists differ in length or contain fewer than 3 atoms.
MeanSquareMinimizer MeanSquareMinimizer::FromAtomLists(const AtomViewList& atoms,
                                                       const AtomViewList& atoms_ref) 
{
  const int count = atoms.size();
  const int count_ref = atoms_ref.size();
  if (count != count_ref) {
    throw Error("atom counts do not match");
  }
  if (count<3) {
    throw Error("at least 3 atoms are required for superposition");
  }
  MeanSquareMinimizer minimizer;
  minimizer.impl_ = new MeanSquareMinimizerImpl(count, true);

  const size_t total = atoms.size();
  for (size_t idx = 0; idx != total; ++idx) {
    minimizer.impl_->SetRefPos(idx, atoms_ref[idx].GetPos());
    minimizer.impl_->SetPos(idx, atoms[idx].GetPos());
  }
  return minimizer;
}

// Creates a minimizer from two point lists; points[i] is paired with
// points_ref[i] and the coordinates are copied into the minimizer.
//
// Throws Error if the lists differ in length or contain fewer than 3 points.
MeanSquareMinimizer MeanSquareMinimizer::FromPointLists(const std::vector<geom::Vec3>& points,
                                                        const std::vector<geom::Vec3>& points_ref) 
{
  const int count = points.size();
  const int count_ref = points_ref.size();
  if (count != count_ref) {
    throw Error("point counts do not match");
  }
  if (count<3) {
    throw Error("at least 3 points are required for superposition");
  }
  MeanSquareMinimizer minimizer;
  minimizer.impl_ = new MeanSquareMinimizerImpl(count, true);

  const size_t total = points.size();
  for (size_t idx = 0; idx != total; ++idx) {
    minimizer.impl_->SetRefPos(idx, points_ref[idx]);
    minimizer.impl_->SetPos(idx, points[idx]);
  }
  return minimizer;
}

// Performs a single superposition of the stored coordinates (atoms1_) onto
// the stored reference coordinates (atoms2_).
SuperpositionResult MeanSquareMinimizerImpl::MinimizeOnce() const{
  SuperpositionResult single_pass = Minimize(atoms1_, atoms2_);
  return single_pass;
}

// Iteratively refines the superposition of atoms1_ onto atoms2_.
//
// The first cycle superposes all atom pairs. Every further cycle applies the
// most recent transformation to the working copy of atoms1_, keeps only the
// pairs that are at most distance_threshold apart (CreateMatchingSubsets)
// and superposes that subset again. Iteration stops once the cycle counter
// reaches max_cycles, or as soon as the newest transformation differs from
// the identity by less than 0.001 (sum of absolute element differences).
//
// Fields set in the returned result:
//   transformation        - product of all per-cycle transformations
//   ncycles               - value of the cycle counter when iteration stopped
//   rmsd_superposed_atoms - RMSD over the atom pairs used in the last cycle
//   fraction_superposed   - size of that set relative to all atoms
//
// If no refinement cycle is run (max_cycles <= 1), the statistics refer to
// the full atom sets of the initial superposition.
//
// Throws Error if a refinement cycle is left with fewer than 3 atom pairs
// within distance_threshold, since no meaningful superposition can be
// computed from them.
SuperpositionResult MeanSquareMinimizerImpl::IterativeMinimize(int max_cycles, Real distance_threshold) const{

  // see http://eigen.tuxfamily.org/dox/TopicStlContainers.html
  std::vector<EMat4,Eigen::aligned_allocator<EMat4> > transformation_matrices;
  EMat4 transformation_matrix;
  EMatX atoms = atoms1_;
  // Start from the full atom sets: if the refinement loop below never runs,
  // the statistics are then computed on the atoms that were actually
  // superposed instead of on empty matrices.
  std::pair<EMatX,EMatX> subsets(atoms1_, atoms2_);
  EMat4 identity_matrix = EMat4::Identity();

  //do initial superposition
  SuperpositionResult res = this->Minimize(atoms, atoms2_);
  transformation_matrices.push_back(Mat4ToEigenMat4(res.transformation));

  //note, that the initial superposition is the first cycle...
  int cycles=1;

  for(;cycles<max_cycles;++cycles){
    atoms = this->TransformEMatX(atoms, transformation_matrices.back());
    subsets = this->CreateMatchingSubsets(atoms, atoms2_, distance_threshold);
    // Superposing an empty subset divides by zero rows and yields a NaN
    // transformation; enforce the same minimum as for the full atom sets.
    if(subsets.first.rows()<3){
      throw Error("less than 3 atoms are within the distance threshold, "
                  "cannot continue iterative superposition");
    }
    res = this->Minimize(subsets.first,subsets.second);
    transformation_matrix = Mat4ToEigenMat4(res.transformation);
    transformation_matrices.push_back(transformation_matrix);

    // converged: the latest refinement step is (almost) the identity
    EMat4 diff = transformation_matrix-identity_matrix;
    if(diff.cwise().abs().sum()<0.001){
      break;
    }

  }

  // subsets holds coordinates from before the last transformation was
  // determined, so that transformation is applied for the RMSD
  res.rmsd_superposed_atoms = calc_rmsd_for_ematx(subsets.first, subsets.second, transformation_matrices.back());
  res.fraction_superposed = float(subsets.first.rows())/atoms1_.rows();

  //combine the transformations into one transformation
  //(newest on the left, i.e. T_n * ... * T_2 * T_1)
  transformation_matrix = transformation_matrices.back();
  transformation_matrices.pop_back();
  while(!transformation_matrices.empty()){
    transformation_matrix*=transformation_matrices.back();
    transformation_matrices.pop_back();
  }

  res.transformation = EigenMat4ToMat4(transformation_matrix);
  res.ncycles=cycles;

  return res;

}

// Applies a 4x4 transformation to every row of mat (one point per row, x/y/z
// in the first three columns) and returns the transformed points as a new
// matrix with three columns.
EMatX MeanSquareMinimizerImpl::TransformEMatX(const EMatX& mat, const EMat4& transformation) const {

  const int n_rows = mat.rows();
  EMatX result = EMatX::Zero(n_rows, 3);

  // homogeneous column vector; the fourth entry stays 1 so that the
  // translational part of the 4x4 matrix is applied as well
  EMatX point = EMatX::Zero(4,1);
  point(3,0)=1;
  EMatX image = EMatX::Zero(4,1);

  for (int row=0; row<n_rows; ++row) {
    point.block<3,1>(0,0)=mat.block<1,3>(row,0).transpose();
    image = transformation*point;
    result.block<1,3>(row,0)=image.block<3,1>(0,0).transpose();
  }

  return result;
}

// Selects the row pairs of atoms / atoms_ref whose Euclidean distance is at
// most distance_threshold.
//
// Returns a pair of matrices with three columns each: first holds the
// selected rows of atoms, second the corresponding rows of atoms_ref, both in
// the original row order. Both are empty (0 rows) if no pair qualifies.
std::pair<EMatX, EMatX> MeanSquareMinimizerImpl::CreateMatchingSubsets(const EMatX& atoms, const EMatX& atoms_ref, Real distance_threshold) const{

  // per-row distance between the two coordinate sets
  EMatX delta = atoms-atoms_ref;
  EMatX dist = (delta.cwise()*delta).rowwise().sum();
  dist = dist.cwise().sqrt();

  // remember which rows are close enough
  std::vector<int> kept_rows;
  for (int i = 0; i < dist.rows(); ++i) {
    if (dist(i,0) <= distance_threshold) {
      kept_rows.push_back(i);
    }
  }

  const int n_kept = static_cast<int>(kept_rows.size());
  EMatX atoms_subset = EMatX::Zero(n_kept,3);
  EMatX atoms_ref_subset = EMatX::Zero(n_kept,3);

  for (int target = 0; target < n_kept; ++target) {
    const int source = kept_rows[target];
    atoms_subset.row(target) = atoms.row(source);
    atoms_ref_subset.row(target) = atoms_ref.row(source);
  }
  return std::make_pair(atoms_subset, atoms_ref_subset);
}


// Computes the transformation that superposes atoms onto atoms_ref using a
// singular value decomposition.
//
// atoms, atoms_ref: coordinate matrices with one point per row and three
//                   columns; row i of atoms is paired with row i of atoms_ref.
//
// Returns a SuperpositionResult in which only the transformation field is
// set by this function. There is no guard against matrices with zero rows
// (the centroid computation divides by the row count).
SuperpositionResult MeanSquareMinimizerImpl::Minimize(const EMatX& atoms, const EMatX& atoms_ref) const {
  // centroids (column-wise means) of both point sets
  ERVec3 avg = atoms.colwise().sum()/atoms.rows();
  ERVec3 avg_ref = atoms_ref.colwise().sum()/atoms_ref.rows();

  // SVD only determines the rotational component of the superposition,
  // so we need to manually shift the centers of the two point sets onto
  // the origin

  EMatX atoms_shifted = MatrixShiftedBy(atoms, avg);
  // transposed (3 x N) so that the product below is a 3x3 matrix
  EMatX atoms_ref_shifted = MatrixShiftedBy(atoms_ref, avg_ref).transpose();

  // determine rotational component
  Eigen::SVD<EMat3> svd(atoms_ref_shifted*atoms_shifted);
  EMatX matrixVT=svd.matrixV().transpose();

  //determine rotation
  Real detv=matrixVT.determinant();
  Real dett=svd.matrixU().determinant();
  Real det=detv*dett;
  EMat3 rotation;
  if (det<0) {
    // U*V^T would have a negative determinant; flipping the sign of the last
    // axis makes the determinant of the product positive
    EMat3 tmat=EMat3::Identity();
    tmat(2,2)=-1;
    rotation = (svd.matrixU()*tmat)*matrixVT;
  } else {
    rotation = svd.matrixU()*matrixVT;
  }
  
  SuperpositionResult res;

  geom::Vec3 shift = EigenVec3ToVec3(avg_ref);
  geom::Vec3 com_vec = -EigenVec3ToVec3(avg);
  //geom::Mat3 rot = EigenMat3ToMat3(rotation.transpose());
  geom::Mat3 rot = EigenMat3ToMat3(rotation);
  // NOTE(review): assumes default-constructed geom::Mat4 objects are identity
  // matrices, so each Paste* call yields a pure rotation/translation -- confirm.
  geom::Mat4 mat4_com, mat4_rot, mat4_shift;
  mat4_rot.PasteRotation(rot);
  mat4_shift.PasteTranslation(shift);
  mat4_com.PasteTranslation(com_vec);
  //save whole transformation
  // applied right to left: translate by -avg, rotate, translate by avg_ref
  res.transformation = geom::Mat4(mat4_shift*mat4_rot*mat4_com);
  return res;
}

// Copy assignment: copies rhs into a temporary and exchanges its contents
// with this object; the previous contents are released when the temporary
// goes out of scope.
MeanSquareMinimizer& MeanSquareMinimizer::operator=(const MeanSquareMinimizer& rhs) {
  MeanSquareMinimizer copy(rhs);
  swap(copy);
  return *this;
}

// Copy constructor: duplicates the implementation object of rhs, or leaves
// impl_ null if rhs has none.
MeanSquareMinimizer::MeanSquareMinimizer(const MeanSquareMinimizer& rhs):
  impl_(rhs.impl_ ? new MeanSquareMinimizerImpl(*rhs.impl_) : NULL)
{
}


// Releases the implementation object; deleting a null pointer is a no-op.
MeanSquareMinimizer::~MeanSquareMinimizer() {
  delete impl_;
}

// Forwards to the implementation object's MinimizeOnce(). impl_ must be
// non-null; this is not checked here.
SuperpositionResult MeanSquareMinimizer::MinimizeOnce() const {
  const MeanSquareMinimizerImpl& impl = *impl_;
  return impl.MinimizeOnce();
}

// Forwards to the implementation object's IterativeMinimize(). impl_ must be
// non-null; this is not checked here.
SuperpositionResult MeanSquareMinimizer::IterativeMinimize(int ncycles, Real distance_threshold) const {
  const MeanSquareMinimizerImpl& impl = *impl_;
  return impl.IterativeMinimize(ncycles, distance_threshold);
}

// Superposes atoms1 onto atoms2 in a single cycle; atoms are paired by their
// position in the lists.
//
// The result carries the transformation, ncycles=1, the RMSD over all atom
// pairs (stored in both rmsd and rmsd_superposed_atoms) and
// fraction_superposed=1.0. If apply_transform is true, the transformation is
// applied through an XCS editor of the entity handle that the first atom of
// atoms1 belongs to.
//
// Errors thrown by MeanSquareMinimizer::FromAtomLists (mismatching counts,
// fewer than 3 atoms) propagate to the caller.
SuperpositionResult SuperposeAtoms(const mol::AtomViewList& atoms1,
                                   const mol::AtomViewList& atoms2,
                                   bool apply_transform=true)
{
  MeanSquareMinimizer minimizer = MeanSquareMinimizer::FromAtomLists(atoms1, atoms2);
  SuperpositionResult superposition = minimizer.MinimizeOnce();

  const Real overall_rmsd = calc_rmsd_for_atom_lists(atoms1, atoms2,
                                                     superposition.transformation);
  superposition.ncycles=1;
  superposition.rmsd = overall_rmsd;
  superposition.rmsd_superposed_atoms = superposition.rmsd;
  superposition.fraction_superposed = 1.0;

  if (apply_transform) {
    mol::AtomView first_atom=atoms1.front();
    mol::XCSEditor editor=first_atom.GetEntity().GetHandle().EditXCS();
    editor.ApplyTransform(superposition.transformation);
  }
  return superposition;
}

// Superposes the atoms of ev1 onto the atoms of ev2 (see SuperposeAtoms) and
// records both views in the result.
SuperpositionResult SuperposeSVD(const mol::EntityView& ev1,
                                 const mol::EntityView& ev2,
                                 bool apply_transform=true) {
  const AtomViewList moving_atoms = ev1.GetAtomList();
  const AtomViewList reference_atoms = ev2.GetAtomList();

  SuperpositionResult superposition = SuperposeAtoms(moving_atoms,
                                                     reference_atoms,
                                                     apply_transform);
  superposition.entity_view1 = ev1;
  superposition.entity_view2 = ev2;
  return superposition;
}

// Superposes the point list pl1 onto pl2 in a single cycle; points are paired
// by their position in the lists.
//
// The result carries the transformation, ncycles=1, the RMSD over all point
// pairs (stored in both rmsd and rmsd_superposed_atoms) and
// fraction_superposed=1.0. No coordinates are modified.
//
// Errors thrown by MeanSquareMinimizer::FromPointLists (mismatching counts,
// fewer than 3 points) propagate to the caller.
SuperpositionResult SuperposeSVD(const std::vector<geom::Vec3>& pl1,
                                 const std::vector<geom::Vec3>& pl2)
{
  MeanSquareMinimizer msm = MeanSquareMinimizer::FromPointLists(pl1, pl2);
  SuperpositionResult result = msm.MinimizeOnce();
  
  result.ncycles=1;
  result.rmsd = calc_rmsd_for_point_lists(pl1, pl2, result.transformation);
  // every point takes part in a single-cycle superposition; report this the
  // same way SuperposeAtoms does instead of leaving the fields untouched
  result.rmsd_superposed_atoms = result.rmsd;
  result.fraction_superposed = 1.0;
  return result;
}

// Iteratively superposes the atoms of ev onto the atoms of ev_ref (see
// MeanSquareMinimizer::IterativeMinimize); atoms are paired by their position
// in the atom lists.
//
// On top of the fields set by IterativeMinimize, the result carries the RMSD
// over all atom pairs (rmsd) and both entity views. If apply_transform is
// true, the transformation is applied through an XCS editor of the entity
// handle that the first atom of ev belongs to.
SuperpositionResult IterativeSuperposeSVD(const mol::EntityView& ev,
                                          const mol::EntityView& ev_ref,
                                          int max_cycles,
                                          Real distance_threshold,
                                          bool apply_transform){

  AtomViewList moving_atoms = ev.GetAtomList();
  AtomViewList reference_atoms = ev_ref.GetAtomList();

  MeanSquareMinimizer minimizer = MeanSquareMinimizer::FromAtomLists(moving_atoms,
                                                                     reference_atoms);
  SuperpositionResult superposition = minimizer.IterativeMinimize(max_cycles,
                                                                  distance_threshold);
  // overall RMSD over all atoms, not only the iteratively selected subset
  superposition.rmsd = calc_rmsd_for_atom_lists(moving_atoms, reference_atoms,
                                                superposition.transformation);
  if (apply_transform) {
    mol::AtomView first_atom=moving_atoms.front();
    mol::XCSEditor editor=first_atom.GetEntity().GetHandle().EditXCS();
    editor.ApplyTransform(superposition.transformation);
  }

  superposition.entity_view1 = ev;
  superposition.entity_view2 = ev_ref;
  return superposition;
}


}}} //ns