# Copyright (c) 2013-2018, SIB - Swiss Institute of Bioinformatics and 
#                          Biozentrum - University of Basel
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


'''High-level functionality for the modelling module to build pipelines. The
functions are exported via the module's __init__.py file and are meant to be
used directly by passing a ModellingHandle instance as argument.
'''

# internal
from promod3 import loop, sidechain, core
from _modelling import *
from _reconstruct_sidechains import *
from _closegaps import *
from _ring_punches import *
# external
import ost
from ost import mol, conop
from ost.mol import mm
import os, math

###############################################################################
# helper functions
def _RemoveHydrogens(ent):
    '''Strip every hydrogen atom from *ent* (in place).

    Hydrogen naming depends on the force field in use, so the safe option is
    to drop all of them. The pipeline ignores hydrogens anyway when building
    the raw model and when rebuilding sidechains.

    :param ent: Entity to modify in place.
    :type ent:  :class:`ost.mol.EntityHandle`
    '''
    editor = ent.EditXCS(mol.BUFFERED_EDIT)
    hydrogen_atoms = ent.Select('ele=H').atoms
    for hydrogen in hydrogen_atoms:
        # views are selections; deletion works on the underlying handle
        editor.DeleteAtom(hydrogen.handle)
    editor.UpdateICS()

def _AddHeuristicHydrogens(ent, ff):
    '''Add hydrogens to the non-peptide-linking residues of *ent*.

    For every residue that is not peptide-linking, the building block that
    force field *ff* provides for the residue name is handed to
    :class:`ost.mol.mm.HeuristicHydrogenConstructor`, which is applied to that
    residue. Peptide-linking residues are left untouched.

    :param ent: Entity to modify in place.
    :type ent:  :class:`ost.mol.EntityHandle`
    :param ff:  Force field providing the building blocks.
    :type ff:   :class:`ost.mol.mm.Forcefield`
    '''
    for residue in ent.residues:
        if residue.IsPeptideLinking():
            continue
        building_block = ff.GetBuildingBlock(residue.name)
        editor = ent.EditXCS(mol.BUFFERED_EDIT)
        constructor = mm.HeuristicHydrogenConstructor(building_block)
        constructor.ApplyOnResidue(residue, editor)
        editor.UpdateICS()

def _GetTopology(ent, settings, force_fields, add_heuristic_hydrogens=False):
    '''Return topology or None if no topology could be created.

    The force fields are tried in the given order; the first one for which
    :meth:`ost.mol.mm.TopologyCreator.Create` succeeds is used. As a side
    effect, ``settings.forcefield`` is left set to the last force field tried.

    Note: if successful, this will update ent (adding hydrogens).
    Set add_heuristic_hydrogens to True for ligands.

    :param ent: Entity to parametrize (may be modified, see above).
    :type ent:  :class:`ost.mol.EntityHandle`
    :param settings: Settings passed to the topology creator.
    :type settings:  :class:`ost.mol.mm.Settings`
    :param force_fields: Force fields to try in order.
    :type force_fields:  :class:`list` of :class:`ost.mol.mm.Forcefield`
    :param add_heuristic_hydrogens: If True, hydrogens are added with
                                    :func:`_AddHeuristicHydrogens` before
                                    creating the topology.
    :type add_heuristic_hydrogens:  :class:`bool`

    :return: The topology or None if all force fields failed.
    :rtype:  :class:`ost.mol.mm.Topology`
    '''
    for i_ff, ff in enumerate(force_fields):
        settings.forcefield = ff
        try:
            # NOTE(review): hydrogens added for a force field that then fails
            # are not removed before trying the next one - confirm this is OK
            if add_heuristic_hydrogens:
                _AddHeuristicHydrogens(ent, ff)
            return mm.TopologyCreator.Create(ent, settings)
        except Exception as ex:
            # report only for debugging; str(ex) works for all exception types
            # (unlike the Python-2-only / deprecated ex.message)
            ost.LogVerbose("Could not create mm topology for ff %d. %s: %s"
                           % (i_ff, type(ex).__name__, ex))
    # if we got here, nothing worked
    return None

def _AddLigands(ent, top, lig_ent, settings, force_fields):
    '''Add ligands from lig_ent to topology top and update entity ent.

    The ligand residues are first connected with a rule-based processor. Runs
    of consecutive residues (as judged by :func:`ost.mol.InSequence`) form
    one connected component each. Every component is copied into its own
    entity, its chain is renamed to ``_1``, ``_2``, ... and a topology is
    created with :func:`_GetTopology` (adding heuristic hydrogens). Components
    without any working force field are logged as errors and skipped; the
    others inherit the LJ/QQ fudge factors of *top* and are merged into
    *top* and *ent*.

    :param ent: Entity belonging to *top*; updated by the merge.
    :param top: Topology to extend.
    :param lig_ent: Entity holding the ligands. Must contain at least one
                    residue (its first residue is accessed unconditionally).
    :param settings: Settings passed on to :func:`_GetTopology`.
    :param force_fields: Force fields to try in order.
    '''
    conop.RuleBasedProcessor(conop.GetDefaultLib()).Process(lig_ent)
    residue = lig_ent.residues[0]
    component_count = 0
    while residue.IsValid():
        # collect one connected component of consecutive residues
        component = lig_ent.CreateEmptyView()
        while True:
            component.AddResidue(residue, mol.INCLUDE_ATOMS)
            residue = residue.next
            if not (residue.IsValid() and mol.InSequence(residue.prev,
                                                         residue)):
                break
        # copy the component into an entity with a uniquely named chain
        component_count += 1
        component_ent = mol.CreateEntityFromView(component, True)
        editor = component_ent.EditXCS()
        editor.RenameChain(component_ent.chains[0], '_' + str(component_count))
        component_top = _GetTopology(component_ent, settings, force_fields,
                                     True)
        if component_top is None:
            residue_names = str([str(r) for r in component.residues])
            ost.LogError("Failed to add ligands " + residue_names + \
                         " for energy minimization! Skipping...")
            continue
        # merge into main topology using its fudge factors
        component_top.SetFudgeLJ(top.GetFudgeLJ())
        component_top.SetFudgeQQ(top.GetFudgeQQ())
        top.Merge(ent, component_top, component_ent)

def _SetupMmSimulation(model, force_fields):
    '''Get mm simulation object for the model (incl. ligands in chain "_").

    Hydrogens are first removed from *model* (in place). A topology is then
    generated for all chains except "_" and, separately, for each connected
    component of the ligand chain "_", evaluating the force fields in the
    given order. Ligands without force fields are skipped.

    The CPU platform is used if available, with the number of threads taken
    from the env. variable ``PM3_OPENMM_CPU_THREADS`` (default: "1").
    Otherwise the slower Reference platform is used and a warning is logged.

    :param model: Model to simulate (hydrogens get removed in place).
    :type model:  :class:`ost.mol.EntityHandle`
    :param force_fields: Force fields to try in order.
    :type force_fields:  :class:`list` of :class:`ost.mol.mm.Forcefield`

    :raises: :exc:`RuntimeError` if no topology can be created for the
             protein part.
    :return: The simulation object.
    :rtype:  :class:`ost.mol.mm.Simulation`
    '''
    prof = core.StaticRuntimeProfiler.StartScoped('pipeline::_SetupMmSimulation')

    # general simulation settings
    settings = mm.Settings()
    settings.integrator = mm.LangevinIntegrator(310, 1, 0.002)
    settings.init_temperature = 0
    settings.nonbonded_method = mm.NonbondedMethod.CutoffNonPeriodic
    settings.keep_ff_specific_naming = False

    # protein: everything outside of the ligand chain
    _RemoveHydrogens(model)
    protein = mol.CreateEntityFromView(model.Select("cname!='_'"), True)
    topology = _GetTopology(protein, settings, force_fields)
    if topology is None:
        raise RuntimeError("Failed to setup protein for energy minimization!")

    # ligands are copied into their own entity and reprocessed there so that
    # connectivity is guaranteed
    ligands = mol.CreateEntityFromView(model.Select("cname='_'"), True)
    if ligands.residue_count > 0:
        _AddLigands(protein, topology, ligands, settings, force_fields)

    # prefer the fast CPU platform
    # NOTE: settings.platform only relevant for mm.Simulation!
    settings.platform = mm.Platform.CPU
    if not mm.Simulation.IsPlatformAvailable(settings):
        settings.platform = mm.Platform.Reference
        ost.LogWarning("Switched to slower reference platform of OpenMM!")
    else:
        settings.cpu_properties["CpuThreads"] = \
            os.getenv('PM3_OPENMM_CPU_THREADS', "1")

    return mm.Simulation(topology, protein, settings)

def _GetSimEntity(sim):
    '''Get Entity from mm sim and reverse ligand chain naming.

    Returns a copy of the simulation entity in which every chain whose name
    starts with '_' (the separately parametrized ligand components) is merged
    into a single chain named '_'.

    :param sim: Simulation to extract the entity from.
    :type sim:  :class:`ost.mol.mm.Simulation`
    :return: Copy of the simulation entity with merged ligand chain.
    :rtype:  :class:`ost.mol.EntityHandle`
    '''
    ent = sim.GetEntity().Copy()
    ent_ed = ent.EditXCS(mol.BUFFERED_EDIT)
    # snapshot names first: chains are inserted/deleted inside the loop
    chain_names = [ch.name for ch in ent.chains]
    for chain_name in chain_names:
        # only separate ligand chains (starting with '_') are merged;
        # startswith avoids IndexError on empty names and a chain already
        # named '_' must not be merged into itself (it would get deleted)
        if not chain_name.startswith('_') or chain_name == '_':
            continue
        if not ent.FindChain('_').IsValid():
            ent_ed.InsertChain('_')
        lig_ch = ent.FindChain('_')
        cur_ch = ent.FindChain(chain_name)
        for res in cur_ch.residues:
            ent_ed.AppendResidue(lig_ch, res, deep=True)
        # remove old chain
        ent_ed.DeleteChain(cur_ch)
    ent_ed.UpdateICS()
    return ent
###############################################################################


def BuildSidechains(mhandle, merge_distance=4, fragment_db=None,
                    structure_db=None, torsion_sampler=None,
                    rotamer_library=None):
    '''Build sidechains for model.

    This is a wrapper for :func:`promod3.modelling.ReconstructSidechains`, 
    followed by a check for ring punches. If ring punches are found it 
    introduces gaps for the residues with punched rings and tries to fill them 
    with :func:`FillLoopsByDatabase` with *ring_punch_detection=2*.
    The original gaps of *mhandle* are restored afterwards, also if filling
    the ring punch gaps raises an exception.

    :param mhandle: Modelling handle on which to apply change.
    :type mhandle:  :class:`ModellingHandle`

    :param merge_distance:  Used as parameter for :func:`MergeGapsByDistance`
                            if ring punches are found.
    :type merge_distance:   :class:`int`
    :param fragment_db:     Used as parameter for :func:`FillLoopsByDatabase`
                            if ring punches are found. A default one is loaded
                            if None.
    :type fragment_db:      :class:`~promod3.loop.FragDB`
    :param structure_db:    Used as parameter for :func:`FillLoopsByDatabase`
                            if ring punches are found. A default one is loaded
                            if None.
    :type structure_db:     :class:`~promod3.loop.StructureDB`
    :param torsion_sampler: Used as parameter for :func:`FillLoopsByDatabase`
                            if ring punches are found. A default one is loaded
                            if None.
    :type torsion_sampler:  :class:`~promod3.loop.TorsionSampler`
    :param rotamer_library: Used as parameter for 
                            :func:`modelling.ReconstructSidechains`, a default 
                            one is loaded if None.
    :type rotamer_library:  :class:`~promod3.sidechain.RotamerLib` or
                            :class:`~promod3.sidechain.BBDepRotamerLib` 
    '''
    prof = core.StaticRuntimeProfiler.StartScoped('pipeline::BuildSidechains')
    ost.LogInfo("Rebuilding sidechains.")
    ReconstructSidechains(mhandle.model, keep_sidechains=True, 
                          rotamer_library=rotamer_library)
    # check for ring punches
    rings = GetRings(mhandle.model)
    ring_punches = GetRingPunches(rings, mhandle.model)
    if len(ring_punches) == 0:
        return

    # try to fix them
    ost.LogInfo("Trying to fix %d ring punch(es)." % len(ring_punches))
    # load stuff if needed (done before touching mhandle.gaps)
    if fragment_db is None:
        fragment_db = loop.LoadFragDB()
    if structure_db is None:
        structure_db = loop.LoadStructureDB()
    if torsion_sampler is None:
        torsion_sampler = loop.LoadTorsionSamplerCoil()
    # backup old gaps
    old_gaps = [g.Copy() for g in mhandle.gaps]
    # new gaps for mhandle
    # NOTE: we currently do not delete the punched residues here
    #       BUT they could be deleted when merging gaps below...
    mhandle.gaps = StructuralGapList()
    try:
        for res in ring_punches:
            mygap = StructuralGap(res.prev, res.next, res.one_letter_code)
            mhandle.gaps.append(mygap)
        # fix it
        MergeGapsByDistance(mhandle, merge_distance)
        FillLoopsByDatabase(mhandle, fragment_db, structure_db,
                            torsion_sampler, ring_punch_detection=2)
        # re-build sidechains
        ReconstructSidechains(mhandle.model, keep_sidechains=True,
                              rotamer_library=rotamer_library)
    finally:
        # restore gaps, even if the fixing above failed
        mhandle.gaps = StructuralGapList()
        for g in old_gaps:
            mhandle.gaps.append(g)
            

def MinimizeModelEnergy(mhandle, max_iterations=12, max_iter_sd=20,
                        max_iter_lbfgs=10, use_amber_ff=False,
                        extra_force_fields=None):
    '''Minimize energy of final model using molecular mechanics.

    Uses :mod:`ost.mol.mm` to perform energy minimization.
    It will iteratively (at most *max_iterations* times):
    
    - run up to *max_iter_sd* minimization iter. of a steepest descent method
    - run up to *max_iter_lbfgs* minimization iter. of a Limited-memory 
      Broyden-Fletcher-Goldfarb-Shanno method
    - abort if no stereochemical problems found

    The idea is that we don't want to minimize "too much". So, we iteratively
    minimize until there are no stereochemical problems and not more.

    Hydrogens are removed from *mhandle.model* before the simulation is set
    up. On success, *mhandle.model* is replaced by the minimized model
    without hydrogens.

    To speed things up, this can run on multiple CPU threads by setting the
    env. variable ``PM3_OPENMM_CPU_THREADS`` to the number of desired threads.
    If the variable is not set, 1 thread will be used by default.

    :param mhandle: Modelling handle on which to apply change.
    :type mhandle:  :class:`ModellingHandle`

    :param max_iterations: Max. number of iterations for SD+LBFGS
    :type max_iterations:  :class:`int`

    :param max_iter_sd: Max. number of iterations within SD method
    :type max_iter_sd:  :class:`int`

    :param max_iter_lbfgs: Max. number of iterations within LBFGS method
    :type max_iter_lbfgs:  :class:`int`

    :param use_amber_ff: if True, use the AMBER force field instead of the def.
                         CHARMM one (see :meth:`BuildFromRawModel`).
    :type use_amber_ff:  :class:`bool`

    :param extra_force_fields: Additional list of force fields to use (see
                               :meth:`BuildFromRawModel`). None (default)
                               means no additional force fields.
    :type extra_force_fields:  :class:`list` of :class:`ost.mol.mm.Forcefield`

    :return: The model including all hydrogens as used in the minimizer, or
             None if the minimization was aborted because the initial energy
             is NaN/infinite (atoms on top of each other). In that case
             *mhandle.model* is not updated by the minimizer.
    :rtype:  :class:`Entity <ost.mol.EntityHandle>`

    :raises: :exc:`RuntimeError` if the protein cannot be parametrized with
             any of the force fields.
    '''
    prof = core.StaticRuntimeProfiler.StartScoped('pipeline::MinimizeModelEnergy')
    ost.LogInfo("Minimize energy.")
    # ignore LogInfo in stereochemical problems if output up to info done
    ignore_stereo_log = (ost.GetVerbosityLevel() == 3)

    # setup force fields (default force field first, extra ones as fallback)
    if use_amber_ff:
        force_fields = [mm.LoadAMBERForcefield()]
    else:
        force_fields = [mm.LoadCHARMMForcefield()]
    if extra_force_fields is not None:
        force_fields.extend(extra_force_fields)
    # setup mm simulation
    sim = _SetupMmSimulation(mhandle.model, force_fields)

    # check for certain failure -> we get NaN/Inf if atoms are on top each other
    cur_energy = sim.GetEnergy()
    if math.isnan(cur_energy):
        ost.LogError("OpenMM could not minimize energy as atoms are on top of "
                     "each other!")
        return None
    if math.isinf(cur_energy):
        ost.LogError("OpenMM could not minimize energy as atoms are almost "
                     "on top of each other!")
        return None
        
    # settings to check for stereochemical problems
    clashing_distances = mol.alg.DefaultClashingDistances()
    bond_stereo_chemical_param = mol.alg.DefaultBondStereoChemicalParams()
    angle_stereo_chemical_param = mol.alg.DefaultAngleStereoChemicalParams()

    for i in range(max_iterations):
        # update atoms
        ost.LogInfo("Perform energy minimization "
                    "(iteration %d, energy: %g)" % (i+1, cur_energy))
        sim.ApplySD(tolerance=1.0, max_iterations=max_iter_sd)
        sim.ApplyLBFGS(tolerance=1.0, max_iterations=max_iter_lbfgs)
        sim.UpdatePositions()

        # check for stereochemical problems (verbosity always restored)
        if ignore_stereo_log:
            ost.PushVerbosityLevel(2)
        try:
            temp_ent = sim.GetEntity().Select("aname!=OXT")
            temp_ent_clash_filtered = mol.alg.FilterClashes(
                temp_ent, clashing_distances)[0]
            # note: 10,10 parameters below are hard coded bond-/angle-tolerances
            temp_ent_stereo_checked = mol.alg.CheckStereoChemistry(
                temp_ent_clash_filtered,
                bond_stereo_chemical_param,
                angle_stereo_chemical_param,
                10, 10)[0]
        finally:
            if ignore_stereo_log:
                ost.PopVerbosityLevel()
        # checks above would remove bad atoms -> same heavy atom count means
        # no problems left
        cur_energy = sim.GetEnergy()
        if len(temp_ent_stereo_checked.Select("ele!=H").atoms) \
           == len(temp_ent.Select("ele!=H").atoms):
            ost.LogInfo("No more stereo-chemical problems "
                        "-> final energy: %g" % cur_energy)
            break

    # update model (without hydrogens)
    simulation_ent = _GetSimEntity(sim)
    mhandle.model = mol.CreateEntityFromView(simulation_ent.Select("ele!=H"),
                                             True)
    # return model with hydrogens
    return simulation_ent

def CheckFinalModel(mhandle):
    '''Performs sanity checks on final models and reports problems.

    All findings are reported through the OST logger:

    - chains with fewer than 3 residues (warning)
    - unclosed gaps (warning); if there are none, each non-empty chain is
      compared with the matching range of *mhandle.seqres* (warning on
      mismatch)
    - punched rings (warning)
    - residues with clashes or bond/angle stereo-chemical violations (info).
      Affected residues get the bool property "stereo_chemical_problem" and,
      if a backbone atom (CA, N, O, C) is involved,
      "stereo_chemical_problem_backbone".
    
    :param mhandle: Modelling handle for which to perform checks.
    :type mhandle:  :class:`ModellingHandle`
    '''
    prof_name = 'pipeline::CheckFinalModel'
    prof = core.StaticRuntimeProfiler.StartScoped(prof_name)
    # report incomplete models
    for chain in mhandle.model.chains:
        if chain.residue_count < 3:
            ost.LogWarning("Chain %s of returned model contains only %d "\
                           "residues! This typically indicates that the "\
                           "template is mostly a Calpha trace or contains "\
                           "too many D-peptides."\
                           % (chain.name, chain.residue_count))
    if len(mhandle.gaps) > 0:
        ost.LogWarning("Failed to close %d gap(s). Returning incomplete model!"\
                       % len(mhandle.gaps))
    else:
        # check sequences (chains and seqres are paired by order)
        for chain, seq in zip(mhandle.model.chains, mhandle.seqres):
            # empty chains were already reported above and have no residue
            # numbers to compare against
            if chain.residue_count == 0:
                continue
            a = chain.residues[0].GetNumber().GetNum()
            b = chain.residues[-1].GetNumber().GetNum()
            expected_seq = seq[a-1:b]
            actual_seq = ''.join([r.one_letter_code for r in chain.residues])
            if expected_seq != actual_seq:
                ost.LogWarning("Sequence mismatch in chain %s!"\
                               " Expected '%s'. Got '%s'" \
                               % (chain.name, expected_seq, actual_seq))
    
    # report ring punchings
    rings = GetRings(mhandle.model)
    ring_punches = GetRingPunches(rings, mhandle.model)
    for res in ring_punches:
        ost.LogWarning("Ring of " + str(res) + " has been punched!")
    
    # report stereo-chemical problems (checkers silenced, verbosity restored
    # even if a check raises)
    model_src = mhandle.model.Select("aname!=OXT")
    ost.PushVerbosityLevel(0)
    try:
        clashing_distances = mol.alg.DefaultClashingDistances()
        bond_stereo_chemical_param = mol.alg.DefaultBondStereoChemicalParams()
        angle_stereo_chemical_param = mol.alg.DefaultAngleStereoChemicalParams()
        # extract problems
        clash_info = mol.alg.FilterClashes(model_src, clashing_distances)[1]
        # note: 10,10 parameters below are hard coded bond-/angle-tolerances
        stereo_info = mol.alg.CheckStereoChemistry(model_src,
                                                   bond_stereo_chemical_param,
                                                   angle_stereo_chemical_param,
                                                   10, 10)[1]
    finally:
        ost.PopVerbosityLevel()
    # collect involved atoms (for angles only the second = central atom)
    clash_list = clash_info.GetClashList()
    bond_violations = stereo_info.GetBondViolationList()
    angle_violations = stereo_info.GetAngleViolationList()
    atoms = [e.GetFirstAtom() for e in clash_list]\
          + [e.GetSecondAtom() for e in clash_list]\
          + [e.GetFirstAtom() for e in bond_violations]\
          + [e.GetSecondAtom() for e in bond_violations]\
          + [e.GetSecondAtom() for e in angle_violations]
    # set bool props in model-residues
    backbone_names = ("CA", "N", "O", "C")
    for atomui in atoms:
        res = model_src.FindResidue(atomui.GetChainName(), atomui.GetResNum())
        res.SetBoolProp("stereo_chemical_problem", True)
        if atomui.GetAtomName() in backbone_names:
            res.SetBoolProp("stereo_chemical_problem_backbone", True)
    # report bad residues
    for res in model_src.residues:
        if res.HasProp("stereo_chemical_problem_backbone") and\
           res.GetBoolProp("stereo_chemical_problem_backbone"):
            ost.LogInfo("Stereo-chemical problem in backbone " + \
                        "of residue " + str(res))
        elif res.HasProp("stereo_chemical_problem") and\
             res.GetBoolProp("stereo_chemical_problem"):
            ost.LogInfo("Stereo-chemical problem in sidechain " + \
                        "of residue " + str(res))

def BuildFromRawModel(mhandle, use_amber_ff=False, extra_force_fields=None,
                      model_termini=False):
    '''Build a model starting with a raw model (see :func:`BuildRawModel`).

    This function implements a recommended pipeline to generate complete models
    from a raw model. The steps are shown in detail in the code example
    :ref:`above <modelling_steps_example>`. If you wish to use your own
    pipeline, you can use that code as a starting point for your own custom
    modelling pipeline. For reproducibility, we recommend that you keep copies
    of custom pipelines.

    To adapt the scoring used during loop closing, you can call
    :func:`SetupDefaultBackboneScoring` and :func:`SetupDefaultAllAtomScoring`
    and adapt the default scoring members. Alternatively, you can setup the
    scoring manually, but you must ensure consistency yourself!

    By default, a simple backbone dihedral sampling is performed when entering 
    Monte Carlo. If *mhandle* has a list of :class:`FraggerHandle` objects 
    attached as "fragger_handles" attribute, the sampling will be performed with 
    structural fragments. To ensure consistency, the fragger handles should be 
    attached using :meth:`SetFraggerHandles`. 
    But be aware of increased runtime due to the fragment search step.

    If the function fails to close all gaps, it will produce a warning and
    return an incomplete model.

    :param mhandle: The prepared template coordinates loaded with the input
                    alignment.
    :type mhandle:  :class:`ModellingHandle`

    :param use_amber_ff: if True, use the AMBER force field instead of the def.
                         CHARMM one (see :func:`ost.mol.mm.LoadAMBERForcefield`
                         and :func:`ost.mol.mm.LoadCHARMMForcefield`).
                         Both do a similarly good job without ligands (CHARMM
                         slightly better), but you will want to be consistent
                         with the optional force fields in `extra_force_fields`.
    :type use_amber_ff:  :class:`bool`

    :param extra_force_fields: Additional list of force fields to use if a 
                               (ligand) residue cannot be parametrized with the
                               default force field. The force fields are tried
                               in the order as given and ligands without an
                               existing parametrization are skipped. None
                               (default) means no additional force fields.
    :type extra_force_fields:  :class:`list` of :class:`ost.mol.mm.Forcefield`

    :param model_termini: The default modelling pipeline in ProMod3 is optimized
                          to generate a gap-free model of the region in the 
                          target sequence(s) that is covered with template 
                          information. Terminal extensions without template 
                          coverage are neglected. 
                          You can activate this flag to enforce a model of the
                          full target sequence(s). The terminal parts will be 
                          modelled with a crude Monte Carlo approach. Be aware
                          that the accuracy of those termini is likely to be
                          limited. Termini of length 1 won't be modelled.
    :type model_termini:  :class:`bool`
                          

    :return: Delivers the model as an |ost_s| entity.
    :rtype: :class:`Entity <ost.mol.EntityHandle>`
    '''
    prof_name = 'pipeline::BuildFromRawModel'
    prof = core.StaticRuntimeProfiler.StartScoped(prof_name)
    # ignore empty models
    if mhandle.model.residue_count == 0:
        ost.LogError("Cannot perform modelling with an empty raw model.")
        return mhandle.model
    ost.LogInfo("Starting modelling based on a raw model.")

    # avoid a shared mutable default argument
    if extra_force_fields is None:
        extra_force_fields = []

    # a bit of setup
    fragment_db = loop.LoadFragDB()
    structure_db = loop.LoadStructureDB()
    torsion_sampler = loop.LoadTorsionSamplerCoil()
    rotamer_library = sidechain.LoadBBDepLib()
    merge_distance = 4

    if not model_termini:
        # remove terminal gaps
        RemoveTerminalGaps(mhandle)

    # check whether we have fragger handles
    fragger_handles = None
    if hasattr(mhandle, "fragger_handles"):
        fragger_handles = mhandle.fragger_handles
        ost.LogInfo("Use fragments for Monte Carlo sampling")

    # close gaps
    CloseGaps(mhandle, merge_distance=merge_distance, 
              fragment_db=fragment_db, structure_db=structure_db, 
              torsion_sampler=torsion_sampler, fragger_handles=fragger_handles)

    if model_termini:
        ModelTermini(mhandle, torsion_sampler, fragger_handles=fragger_handles)
        RemoveTerminalGaps(mhandle)  # length=1 ignored above

    # build sidechains
    BuildSidechains(mhandle, merge_distance, fragment_db,
                    structure_db, torsion_sampler, rotamer_library)

    # minimize energy of final model using molecular mechanics
    MinimizeModelEnergy(mhandle, use_amber_ff=use_amber_ff,
                        extra_force_fields=extra_force_fields)

    # sanity checks
    CheckFinalModel(mhandle)

    # done
    return mhandle.model

# Public functions of this file. __all__ restricts what a
# `from <this module> import *` exports (presumably done by the modelling
# package's __init__.py, per the module docstring); the underscore helpers
# above stay private.
__all__ = ('BuildFromRawModel', 'BuildSidechains', 
           'MinimizeModelEnergy', 'CheckFinalModel')