# Copyright (c) 2013-2020, SIB - Swiss Institute of Bioinformatics and
#                          Biozentrum - University of Basel
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest
from ost import io, mol, geom
from promod3 import loop, sidechain, modelling
import os

class SidechainTests(unittest.TestCase):
    '''Tests for sidechain reconstruction in promod3.modelling.'''

    @classmethod
    def setUpClass(cls):
        '''Load rotamer libraries once; they are shared by all tests.'''
        cls.bbdep_rotamer_library = sidechain.LoadBBDepLib() # DEFAULT
        cls.rotamer_library = sidechain.LoadLib()

    #######################################################################
    # HELPERs
    #######################################################################
    def CheckDistances(self, ent, ent_ref, max_dist=1e-6):
        '''Assert that *ent* reproduces the atoms of *ent_ref*.

        Every atom of *ent_ref* except OXT must be found in *ent* (looked up
        by chain name, residue number and atom name) and lie within
        *max_dist* of its position in *ent_ref*.

        :param ent: Entity to check.
        :param ent_ref: Reference entity.
        :param max_dist: Maximal allowed distance between matching atoms.
        '''
        for a_ref in ent_ref.atoms:
            # OXT atoms are not required to be present in ent
            if a_ref.name == "OXT":
                continue
            a = ent.FindAtom(a_ref.chain.name, a_ref.residue.number, a_ref.name)
            self.assertTrue(a.IsValid(), "missing atom: " + str(a_ref))
            self.assertLessEqual(geom.Length(a.pos - a_ref.pos), max_dist,
                                 "atom position differs: " + str(a_ref))

    def MakeEnv(self, ent):
        '''Return a new loop.AllAtomEnv initialised with *ent*.

        The SEQRES string is built from the one-letter codes of all residues
        of *ent*, in order.
        '''
        seqres_str = ''.join(r.one_letter_code for r in ent.residues)
        env = loop.AllAtomEnv(seqres_str)
        env.SetInitialEnvironment(ent)
        return env

    def CheckEnvVsPy(self, ent, env, keep_sidechains, build_disulfids,
                     optimize_subrotamers,
                     rotamer_model, rotamer_library):
        '''Check that env-based and entity-based reconstruction agree.

        Sidechains are reconstructed (a) on a copy of *ent* with
        modelling.ReconstructSidechains and (b) with a
        modelling.SidechainReconstructor attached to *env*. Both use the
        same settings. All non-OXT atoms of (a) must be found in (b) at
        (numerically) the same positions.

        :param ent: Input entity. It is not modified here; a copy is
                    reconstructed.
        :param env: Environment the SidechainReconstructor is attached to.
                    NOTE(review): the callers create a fresh env whenever
                    settings change, which suggests env keeps state between
                    reconstructions.
        :param rotamer_model: Passed as is to ReconstructSidechains. For
                              the SidechainReconstructor, only "frm" versus
                              anything else is distinguished.
        '''
        # reconstruct sidechains for full OST entity
        ent_py = ent.Copy()
        modelling.ReconstructSidechains(ent_py, keep_sidechains=keep_sidechains,
                                        build_disulfids=build_disulfids,
                                        optimize_subrotamers=optimize_subrotamers,
                                        rotamer_model=rotamer_model,
                                        rotamer_library=rotamer_library)

        # same with SidechainReconstructor
        sc_rec = modelling.SidechainReconstructor(
                               keep_sidechains=keep_sidechains,
                               build_disulfids=build_disulfids,
                               optimize_subrotamers=optimize_subrotamers)
        sc_rec.AttachEnvironment(env, use_frm=(rotamer_model == "frm"),
                                 rotamer_library=rotamer_library)
        # NOTE(review): assumes ent is one chain with residues numbered
        # consecutively from 1 -- TODO confirm
        res = sc_rec.Reconstruct(1, ent.residue_count)
        ent_cc = res.env_pos.all_pos.ToEntity()
        self.CheckDistances(ent_cc, ent_py)
    #######################################################################

    def testReconstruct(self):
        '''Reconstruct all sidechains of 1eye and compare with a reference.'''
        infile = os.path.join('data', '1eye.pdb')
        outfile = os.path.join('data', '1eye_rec.pdb')
        # get and reconstruct 1eye
        prot = io.LoadPDB(infile)
        modelling.ReconstructSidechains(prot, keep_sidechains=False,
                                        rotamer_library=self.bbdep_rotamer_library)
        # compare with reference solution
        prot_rec = io.LoadPDB(outfile)
        self.assertEqual(prot.GetAtomCount(), prot_rec.GetAtomCount())
        # NOTE: ignore rmsd for now (fails too easily)
        #diff = mol.alg.Superpose(prot_rec, prot)
        #self.assertLess(diff.rmsd, 0.01)

    def testReconstructEnvVsPy(self):
        '''Compare env-based with entity-based sidechain reconstruction.'''
        # modified 1eye with no gaps and some sidechains missing (1 per AA-type)
        ent = io.LoadPDB(os.path.join('data', '1eye_sc_test.pdb'))
        # start with full reconstruction with RRM
        env = self.MakeEnv(ent)
        self.CheckEnvVsPy(ent, env, keep_sidechains=False,
                          build_disulfids=False, optimize_subrotamers=False,
                          rotamer_model="rrm",
                          rotamer_library=self.rotamer_library)
        # reuse env with keep_sidechains=True
        self.CheckEnvVsPy(ent, env, keep_sidechains=True,
                          build_disulfids=False, optimize_subrotamers=False,
                          rotamer_model="rrm",
                          rotamer_library=self.rotamer_library)
        # vary one setting at a time relative to the call above
        # (need to reset env to get new stuff)
        # -> rotamer_model
        self.CheckEnvVsPy(ent, self.MakeEnv(ent), keep_sidechains=True,
                          build_disulfids=False, optimize_subrotamers=False,
                          rotamer_model="frm",
                          rotamer_library=self.rotamer_library)
        # -> optimize_subrotamers
        self.CheckEnvVsPy(ent, self.MakeEnv(ent), keep_sidechains=True,
                          build_disulfids=False, optimize_subrotamers=True,
                          rotamer_model="rrm",
                          rotamer_library=self.rotamer_library)
        # -> rotamer_library
        self.CheckEnvVsPy(ent, self.MakeEnv(ent), keep_sidechains=True,
                          build_disulfids=False, optimize_subrotamers=False,
                          rotamer_model="rrm",
                          rotamer_library=self.bbdep_rotamer_library)
        # several settings changed at once
        self.CheckEnvVsPy(ent, self.MakeEnv(ent), keep_sidechains=False,
                          build_disulfids=True, optimize_subrotamers=True,
                          rotamer_model="frm",
                          rotamer_library=self.bbdep_rotamer_library)

        # crn needed to check for disulfid bridges
        ent = io.LoadPDB(os.path.join('data', '1crn_sc_test.pdb'))
        self.CheckEnvVsPy(ent, self.MakeEnv(ent), keep_sidechains=True,
                          build_disulfids=True, optimize_subrotamers=False,
                          rotamer_model="rrm",
                          rotamer_library=self.rotamer_library)

if __name__ == '__main__':
    # executed directly as a script: hand off to OST's test runner
    import ost.testutils
    ost.testutils.RunTests()