# _reconstruct_sidechains.py
from . import _sidechain as sidechain
from ost import geom, mol, conop
from promod3 import core

###############################################################################
# helper functions
def _GetRotamerIDs(res_list):
    '''Map every residue in res_list to a rotamer ID.

    The ID is looked up from the residue name with sidechain.TLCToRotID.
    Names which resolve to sidechain.XXX are mapped to sidechain.ALA.

    Returns a list of rotamer IDs with one entry per residue (same order).
    '''
    rotamer_ids = list()
    for res in res_list:
        rot_id = sidechain.TLCToRotID(res.GetName())
        if rot_id == sidechain.XXX:
            # unknown residue type: handle it like ALA, i.e. no sidechain
            # will be modelled for it
            rot_id = sidechain.ALA
        rotamer_ids.append(rot_id)
    return rotamer_ids

def _GetPhiAngle(r):
    '''Return the phi angle of residue r.

    Order of preference:
    - angle of the residue's phi torsion, if that torsion is valid
    - dihedral C(prev)-N-CA-C computed from atom positions, if the previous
      residue is in sequence with r and all four atoms exist
    - -1.0472 (helical default) otherwise
    '''
    torsion = r.GetPhiTorsion()
    if torsion.IsValid():
        return torsion.GetAngle()

    # nothing better available => helical default
    fallback_phi = -1.0472

    prev_res = r.handle.prev
    if not prev_res.IsValid() or not mol.InSequence(prev_res, r.handle):
        return fallback_phi

    # C of the previous residue followed by N, CA, C of this one
    atoms = [prev_res.FindAtom("C"), r.FindAtom("N"),
             r.FindAtom("CA"), r.FindAtom("C")]
    if not all(atom.IsValid() for atom in atoms):
        return fallback_phi
    positions = [atom.GetPos() for atom in atoms]
    return geom.DihedralAngle(*positions)

def _GetPsiAngle(r):
    '''Return the psi angle of residue r.

    Order of preference:
    - angle of the residue's psi torsion, if that torsion is valid
    - dihedral N-CA-C-N(next) computed from atom positions, if the next
      residue is in sequence with r and all four atoms exist
    - -0.7854 (helical default) otherwise
    '''
    torsion = r.GetPsiTorsion()
    if torsion.IsValid():
        return torsion.GetAngle()

    # nothing better available => helical default
    fallback_psi = -0.7854

    next_res = r.handle.next
    if not next_res.IsValid() or not mol.InSequence(r.handle, next_res):
        return fallback_psi

    # N, CA, C of this residue followed by N of the next one
    atoms = [r.FindAtom("N"), r.FindAtom("CA"),
             r.FindAtom("C"), next_res.FindAtom("N")]
    if not all(atom.IsValid() for atom in atoms):
        return fallback_psi
    positions = [atom.GetPos() for atom in atoms]
    return geom.DihedralAngle(*positions)

def _GetDihedrals(res_list):
    '''Collect backbone dihedral angles for all residues in res_list.

    Returns two lists (phi angles, psi angles), each with one entry per
    residue in the order of res_list (see _GetPhiAngle / _GetPsiAngle).
    '''
    # keep a reference to the scoped profiler object for the whole call
    # (presumably it stops timing once it is destroyed - confirm in core)
    prof = core.StaticRuntimeProfiler.StartScoped('sidechain::_GetDihedrals')
    phi_angles = list()
    psi_angles = list()
    for res in res_list:
        phi_angles.append(_GetPhiAngle(res))
        psi_angles.append(_GetPsiAngle(res))
    return phi_angles, psi_angles

def _AddBackboneFrameResidues(frame_residues, res_list, rotamer_ids,
                              rotamer_settings, phi_angles):
    '''Update frame_residues (list) with BackboneFrameResidues for res_list.

    For res_list[i], a frame residue with residue index i is constructed from
    rotamer_ids[i] and phi_angles[i]. The terminal flags passed on are taken
    from the presence of the generic properties "n_ter" and "c_ter" on the
    residue. Residues for which the construction raises an exception are
    skipped silently (best effort).
    '''
    for i, r in enumerate(res_list):
        try:
            frame_residue = sidechain.ConstructBackboneFrameResidue(
                                r.handle, rotamer_ids[i], i, rotamer_settings,
                                phi_angles[i], r.HasProp("n_ter"),
                                r.HasProp("c_ter"))
        except Exception:
            # only regular errors are ignored; KeyboardInterrupt and
            # SystemExit are not derived from Exception and still propagate
            continue
        frame_residues.append(frame_residue)

def _AddLigandFrameResidues(frame_residues, ent_lig, rotamer_settings, offset):
    '''Update frame_residues (list) with FrameResidues for res. in ent_lig.
    Set offset >= number of non-ligand residues (used for residue_index).

    The i-th residue of ent_lig gets residue index offset + i.
    Peptide-linking residues with a known rotamer ID are added as backbone
    frame residue plus (unless ALA / GLY) sidechain frame residue. If that is
    not applicable or fails, the residue is added with
    sidechain.ConstructFrameResidueHeuristic. Residues for which this fails
    too are skipped silently (best effort).
    '''
    for i, res in enumerate(ent_lig.residues):
        res_idx = offset + i
        is_done = False
        # special treatment for peptides
        if res.IsPeptideLinking():
            rot_id = sidechain.TLCToRotID(res.GetName())
            if rot_id != sidechain.XXX:
                # get more info
                phi = _GetPhiAngle(res)
                # a residue is terminal if it has no neighbour in sequence
                r_prev = res.handle.prev
                n_ter = not r_prev.IsValid() \
                        or not mol.InSequence(r_prev, res.handle)
                r_next = res.handle.next
                c_ter = not r_next.IsValid() \
                        or not mol.InSequence(res.handle, r_next)
                # try to build frame residues; they are only added if all of
                # them could be constructed
                try:
                    new_frame_residues = [
                        sidechain.ConstructBackboneFrameResidue(
                            res.handle, rot_id, res_idx, rotamer_settings,
                            phi, n_ter, c_ter)]
                    if rot_id != sidechain.ALA and rot_id != sidechain.GLY:
                        new_frame_residues.append(
                            sidechain.ConstructSidechainFrameResidue(
                                res.handle, rot_id, res_idx,
                                rotamer_settings))
                except Exception:
                    pass   # ignore peptide treatment and treat below
                else:
                    frame_residues.extend(new_frame_residues)
                    is_done = True
        # if it failed, treat it as an unknown entity
        if not is_done:
            try:
                # NOTES:
                # - ConstructFrameResidueHeuristic has fall back if res unknown
                # - it only deals with few possible ligand cases and has not
                #   been tested extensively!
                comp_lib = conop.GetDefaultLib()
                fr = sidechain.ConstructFrameResidueHeuristic(
                         res.handle, res_idx, rotamer_settings, comp_lib)
            except Exception:
                # KeyboardInterrupt / SystemExit are not caught here
                continue
            frame_residues.append(fr)

def _AddSidechainFrameResidues(frame_residues, incomplete_sidechains,
                               keep_sidechains, res_list, rotamer_ids,
                               rotamer_settings, cystein_indices=None):
    '''Update frame_residues (list) with SidechainFrameResidues for res_list,
    incomplete_sidechains (list of indices) with sidechains to be constructed,
    and (if given) cystein_indices (list of indices) with all CYS (appended).
    Each residue ends up in at most one of the 3 lists (ALA and GLY have no
    sidechain to model and end up in none of them).

    If keep_sidechains is True, a frame residue is constructed for every
    remaining residue and only residues for which the construction fails are
    added to incomplete_sidechains. Otherwise all remaining residues are added
    to incomplete_sidechains and frame_residues is left untouched.
    '''
    for i, r in enumerate(res_list):
        rot_id = rotamer_ids[i]

        # CYS are only collected if the caller asked for it
        if cystein_indices is not None and rot_id == sidechain.CYS:
            cystein_indices.append(i)
            continue

        if rot_id == sidechain.ALA or rot_id == sidechain.GLY:
            continue # no sidechain to model

        if not keep_sidechains:
            incomplete_sidechains.append(i)
            continue

        try:
            frame_residue = sidechain.ConstructSidechainFrameResidue(
                                r.handle, rot_id, i, rotamer_settings)
        except Exception:
            # existing sidechain cannot be used => needs to be modelled
            # (KeyboardInterrupt / SystemExit are not caught here)
            incomplete_sidechains.append(i)
        else:
            frame_residues.append(frame_residue)

def _AddCysteinFrameResidues(frame_residues, incomplete_sidechains,
                             keep_sidechains, res_list, rotamer_ids,
                             rotamer_settings, cystein_indices,
                             disulfid_indices, disulfid_rotamers):
    '''Update frame_residues (list) with cysteins.
    Parameters as in _AddSidechainFrameResidues.
    Some cysteins (in disulfid_indices) get special treatment as disulfid
    bridges (disulfid_indices, disulfid_rotamers from _GetDisulfidBridges).

    For every bridged cystein, a frame residue is added and the chosen rotamer
    is directly applied on the residue in res_list (the structure is
    modified). All other cysteins in cystein_indices are handled like any
    other sidechain: added to the frame if keep_sidechains is True and the
    frame residue can be constructed, added to incomplete_sidechains otherwise.
    '''
    # handle cysteins participating in a disulfid bond
    for cys_idx, cys_rot in zip(disulfid_indices, disulfid_rotamers):
        # add FrameResidue
        frame_residue = sidechain.FrameResidue([cys_rot[0]], cys_idx)
        frame_residues.append(frame_residue)
        # set the position in the proteins residues
        res_handle = res_list[cys_idx].handle
        cys_rot.ApplyOnResidue(res_handle, consider_hydrogens=False)
        sidechain.ConnectSidechain(res_handle, sidechain.CYS)

    # add remaining ones according the given flags
    bridged = set(disulfid_indices)  # constant time lookup in the loop
    for idx in cystein_indices:
        if idx in bridged:
            continue # already handled
        if not keep_sidechains:
            incomplete_sidechains.append(idx)
            continue
        try:
            frame_residue = sidechain.ConstructSidechainFrameResidue(
                                res_list[idx].handle, rotamer_ids[idx],
                                idx, rotamer_settings)
        except Exception:
            # existing sidechain cannot be used => needs to be modelled
            # (KeyboardInterrupt / SystemExit are not caught here)
            incomplete_sidechains.append(idx)
        else:
            frame_residues.append(frame_residue)

def _GetRotamerGroup(res_handle, rot_id, res_idx, rot_lib, rot_settings,
                     phi, psi, use_frm, bbdep):
    '''Construct the RotamerGroup for res_handle.

    use_frm selects sidechain.ConstructFRMRotamerGroup (True) or
    sidechain.ConstructRRMRotamerGroup (False). phi and psi are only passed
    on to the chosen function if bbdep is True.
    '''
    if use_frm:
        construct = sidechain.ConstructFRMRotamerGroup
    else:
        construct = sidechain.ConstructRRMRotamerGroup

    args = [res_handle, rot_id, res_idx, rot_lib, rot_settings]
    if bbdep:
        # backbone dependent library => lookup additionally needs phi / psi
        args.extend([phi, psi])
    return construct(*args)

def _GetProlineRotamerID(res_list, i):
    '''Return the rotamer ID to use for the proline res_list[i].

    The ID depends on the omega angle: sidechain.CPR if abs(omega) < 1.57,
    sidechain.TPR otherwise. Omega is taken from the residue's omega torsion
    if valid, or else computed from the atom positions of res_list[i-1] and
    res_list[i] if those two are connected by a bond. If no omega can be
    determined, sidechain.PRO is returned.
    '''
    r = res_list[i]
    omega = None
    tor = r.GetOmegaTorsion()
    if tor.IsValid():
        omega = tor.GetAngle()
    elif i > 0:
        # fallback computation of omega as in OST-code
        prev = res_list[i-1]
        if prev.IsValid() and prev.IsPeptideLinking():
            ca_prev = prev.FindAtom("CA")
            c_prev = prev.FindAtom("C")
            n = r.FindAtom("N")
            ca = r.FindAtom("CA")
            valid = ca_prev.IsValid() and c_prev.IsValid() \
                    and n.IsValid() and ca.IsValid()
            # the bond check ensures that prev really precedes r
            if valid and mol.BondExists(c_prev.handle, n.handle):
                omega = geom.DihedralAngle(ca_prev.GetPos(),
                                           c_prev.GetPos(),
                                           n.GetPos(), ca.GetPos())
    # omega not set if prev. res. missing
    if omega is None:
        return sidechain.PRO
    # NOTE(review): 1.57 looks like pi/2, i.e. the cis / trans threshold
    if abs(omega) < 1.57:
        return sidechain.CPR
    return sidechain.TPR

def _GetRotamerGroups(res_list, rot_ids, indices, rot_lib, rot_settings,
                      phi_angles, psi_angles, use_frm, bbdep, frame_residues):
    '''Get list of rotamer groups from subset of res_list.
    Residues are chosen as res_list[i] for i in indices and only if a rotamer
    group can be created (e.g. no ALA, GLY).
    Rotamer groups are filtered to keep only best ones (given frame).
    Returns list of rotamer groups and list of res. indices they belong to.

    CYS are built with rotamer ID sidechain.CYH and PRO with the ID returned
    by _GetProlineRotamerID. Residues for which the construction of the
    rotamer group raises an exception are skipped silently.
    '''
    prof_name = 'sidechain::_GetRotamerGroups'
    # keep a reference to the scoped profiler object for the whole call
    prof = core.StaticRuntimeProfiler.StartScoped(prof_name)

    # res.index (res_list[i]) for each modelled sc
    residues_with_rotamer_group = list()
    #  linked to residue in residues_with_rotamer_group
    rotamer_groups = list()
    # get frame for score evaluation
    frame = sidechain.Frame(frame_residues)
    # build rotamers for chosen sidechains
    for i in indices:
        r = res_list[i]
        rot_id = rot_ids[i]

        if rot_id == sidechain.ALA or rot_id == sidechain.GLY:
            continue # no sidechain to model

        if rot_id == sidechain.CYS:
            rot_id = sidechain.CYH
        elif rot_id == sidechain.PRO:
            rot_id = _GetProlineRotamerID(res_list, i)

        # get RotamerGroup
        try:
            rot_group = _GetRotamerGroup(r.handle, rot_id, i, rot_lib,
                                         rot_settings, phi_angles[i],
                                         psi_angles[i], use_frm, bbdep)
        except Exception:
            # no rotamer group for this residue => leave it as it is
            # (KeyboardInterrupt / SystemExit are not caught here)
            continue
        # keep best ones
        rot_group.CalculateInternalEnergies()
        frame.SetFrameEnergy(rot_group)
        rot_group.ApplySelfEnergyThresh()
        rotamer_groups.append(rot_group)
        residues_with_rotamer_group.append(i)

    return rotamer_groups, residues_with_rotamer_group

def _GetDisulfidBridges(frame_residues, cystein_indices, res_list, rotamer_library,
                        use_frm, bbdep, rotamer_settings, phi_angles, psi_angles):
    '''Get disulfid bridges for CYS and according rotamers.
    CYS are identified by items in cystein_indices (into res_list).
    CYS without valid CA or CB atom are ignored.

    Pairs of CYS are processed in the order of cystein_indices (first come,
    first served): a pair is a bridge if the CA atoms are at most 8.0 apart,
    none of the two is part of a bridge yet and the lowest
    sidechain.DisulfidScore over all rotamer pairs is below 45.0.

    Returns: disulfid_indices: list of res. index in bridge,
             disulfid_rotamers: list of rotamers (best one for bridge).
             Both lists have the same length; two consecutive entries belong
             to the same bridge.
    '''
    # NOTE(review): units are not visible here (presumably Angstrom for the
    #               distance) - confirm with DisulfidScore documentation
    max_ca_distance = 8.0
    max_disulfid_score = 45.0

    # this is required for the disulfid score evaluation
    frame = sidechain.Frame(frame_residues)

    # some info we have to keep track of when evaluating disulfid bonds
    # -> parallel lists: entry k belongs to residue res_list[cys_res_indices[k]]
    # -> cys_res_indices can be shorter than cystein_indices as CYS without
    #    CA / CB are skipped, so cystein_indices must not be indexed with k
    cys_res_indices = list()
    cystein_rotamers = list()
    cys_ca_positions = list()
    cys_cb_positions = list()

    for res_idx in cystein_indices:
        # check ca, cb
        r = res_list[res_idx]
        ca = r.FindAtom("CA")
        cb = r.FindAtom("CB")
        if not (ca.IsValid() and cb.IsValid()):
            continue
        ca_pos = ca.GetPos()
        cb_pos = cb.GetPos()
        # get RotamerGroup
        rot_group = _GetRotamerGroup(r.handle, sidechain.CYD, res_idx,
                                     rotamer_library, rotamer_settings,
                                     phi_angles[res_idx], psi_angles[res_idx],
                                     use_frm, bbdep)
        frame.AddFrameEnergy(rot_group)
        cys_res_indices.append(res_idx)
        cys_ca_positions.append(ca_pos)
        cys_cb_positions.append(cb_pos)
        cystein_rotamers.append(rot_group)

    # get CYS with disulfid bonds and the chosen rotamers
    disulfid_indices = list()
    disulfid_rotamers = list()
    bonded = set()  # same content as disulfid_indices for fast lookup
    num_cys = len(cystein_rotamers)
    for i in range(num_cys):
        for j in range(i+1, num_cys):

            # too far for a disulfid bond?
            if geom.Distance(cys_ca_positions[i],
                             cys_ca_positions[j]) > max_ca_distance:
                continue

            # already done? NOTE: new one might be better, but here we do
            #                     first come, first served
            if cys_res_indices[i] in bonded or cys_res_indices[j] in bonded:
                continue

            # find pair of rotamers with lowest score
            # (rotamer groups are only accessed by len and index)
            min_score = float("inf")
            min_index_k = -1
            min_index_l = -1
            for k in range(len(cystein_rotamers[i])):
                for l in range(len(cystein_rotamers[j])):
                    score = sidechain.DisulfidScore(cystein_rotamers[i][k],
                                                    cystein_rotamers[j][l],
                                                    cys_ca_positions[i],
                                                    cys_cb_positions[i],
                                                    cys_ca_positions[j],
                                                    cys_cb_positions[j])
                    if score < min_score:
                        min_index_k = k
                        min_index_l = l
                        min_score = score

            if min_score < max_disulfid_score:
                # update indices
                for pos, rot_idx in ((i, min_index_k), (j, min_index_l)):
                    disulfid_indices.append(cys_res_indices[pos])
                    disulfid_rotamers.append(cystein_rotamers[pos][rot_idx])
                    bonded.add(cys_res_indices[pos])

    return disulfid_indices, disulfid_rotamers

###############################################################################

def Reconstruct(ent, keep_sidechains=False, build_disulfids=True,
                rotamer_model="frm", consider_hbonds=True,
                consider_ligands=True, rotamer_library=None):
    '''Reconstruct sidechains for the given structure.

    :param ent:          Structure for sidechain reconstruction. Note, that the
                         sidechain reconstruction gets directly applied on the
                         structure itself.

    :param keep_sidechains: Flag, whether complete sidechains in *ent* (i.e. 
                            containing all required atoms) should be kept rigid
                            and directly be added to the frame.

    :param build_disulfids: Flag, whether possible disulfid bonds should be 
                            searched. If a disulfid bond is found, the two
                            participating cysteins are fixed and added to
                            the frame.

    :param rotamer_model: Rotamer model to be used, can either be "frm" or "rrm"
                          (case insensitive)

    :param consider_hbonds: Flag, whether hbonds should be evaluated in the
                            energy function. If set to False, no hydrogens will
                            be built when building rotamers and frame.

    :param consider_ligands: Flag, whether to add ligands (anything in chain
                             '_') as static objects.

    :param rotamer_library: A rotamer library to extract the rotamers from. The
                            default is the :meth:`Dunbrack <LoadDunbrackLib>`
                            library.


    :type ent:            :class:`ost.mol.EntityHandle`
    :type keep_sidechains: :class:`bool`
    :type build_disulfids: :class:`bool`
    :type rotamer_model:   :class:`str`
    :type consider_hbonds: :class:`bool`
    :type consider_ligands: :class:`bool`
    :type rotamer_library: :class:`BBDepRotamerLib` / :class:`RotamerLib`

    :raises: :exc:`RuntimeError` if *rotamer_model* is neither "frm" nor "rrm"
    '''
    prof_name = 'sidechain::Reconstruct'
    # keep a reference to the scoped profiler object for the whole call
    prof = core.StaticRuntimeProfiler.StartScoped(prof_name)

    # setup settings
    model = rotamer_model.lower()
    if model == "frm":
        use_frm = True
    elif model == "rrm":
        use_frm = False
    else:
        raise RuntimeError("Only \"rrm\" and \"frm\" allowed for rotamer_model!")

    rotamer_settings = sidechain.RotamerSettings()
    rotamer_settings.consider_hbonds = consider_hbonds
    if rotamer_library is None:
        rotamer_library = sidechain.LoadDunbrackLib()
    # backbone dependent libraries additionally need phi / psi for lookups
    bbdep = isinstance(rotamer_library, sidechain.BBDepRotamerLib)

    # take out ligand chain and any non-peptides
    prot = ent.Select("peptide=true and cname!='_'")
    # fetch residue list only once; all indices below refer to this list
    residues = prot.residues

    # parse residues (all lists of length len(residues))
    rotamer_ids = _GetRotamerIDs(residues)
    phi_angles, psi_angles = _GetDihedrals(residues)

    # set nter and cter (needed in _AddBackboneFrameResidues)
    for c in prot.chains:
        chain_residues = c.residues
        if len(chain_residues) == 0:
            continue # nothing to flag in an empty chain
        chain_residues[0].SetIntProp("n_ter",1)
        chain_residues[-1].SetIntProp("c_ter",1)

    # build up frame
    frame_residues = list()         # list of frame residues connected to frame
    incomplete_sidechains = list()  # residue indices
    _AddBackboneFrameResidues(frame_residues, residues, rotamer_ids,
                              rotamer_settings, phi_angles)

    # add ligands?
    if consider_ligands:
        ligs = ent.Select("cname='_'")
        # ligand residue indices start after the protein residues
        offset = len(residues)
        _AddLigandFrameResidues(frame_residues, ligs, rotamer_settings, offset)

    # check special handling of cysteins
    if build_disulfids:
        # residue indices of cysteins
        cystein_indices = list()
        # update frame_residues, incomplete_sidechains, cystein_indices
        _AddSidechainFrameResidues(frame_residues, incomplete_sidechains,
                                   keep_sidechains, residues, rotamer_ids,
                                   rotamer_settings, cystein_indices)
        # update frame_residues, incomplete_sidechains with cysteins (if needed)
        if cystein_indices:
            # get disulfid bridges and according rotamers
            disulfid_indices, disulfid_rotamers = \
                _GetDisulfidBridges(frame_residues, cystein_indices, residues,
                                    rotamer_library, use_frm, bbdep,
                                    rotamer_settings, phi_angles, psi_angles)
            # update frame_residues, incomplete_sidechains
            _AddCysteinFrameResidues(frame_residues, incomplete_sidechains,
                                     keep_sidechains, residues, rotamer_ids,
                                     rotamer_settings, cystein_indices,
                                     disulfid_indices, disulfid_rotamers)
    else:
        # update frame_residues, incomplete_sidechains
        _AddSidechainFrameResidues(frame_residues, incomplete_sidechains,
                                   keep_sidechains, residues, rotamer_ids,
                                   rotamer_settings)

    # get rotamer groups and residues they're linked to
    rotamer_groups, residues_with_rotamer_group = \
        _GetRotamerGroups(residues, rotamer_ids, incomplete_sidechains,
                          rotamer_library, rotamer_settings, phi_angles,
                          psi_angles, use_frm, bbdep, frame_residues)

    # set up graph and solve to get best rotamers
    if use_frm:
        graph = sidechain.RotamerGraph.CreateFromFRMList(rotamer_groups)
    else:
        graph = sidechain.RotamerGraph.CreateFromRRMList(rotamer_groups)

    # first item of the result holds one rotamer index per rotamer group
    # NOTE(review): meaning of the two numeric arguments is not visible
    #               here - see RotamerGraph.TreeSolve documentation
    solution = graph.TreeSolve(100000000,0.02)[0]

    # update structure
    for i,rot_group,sol in zip(residues_with_rotamer_group,rotamer_groups,solution):
        # looked up outside of try, so it is always defined for the message
        res_handle = residues[i].handle
        try:
            rot_group[sol].ApplyOnResidue(res_handle, consider_hydrogens=False)
            sidechain.ConnectSidechain(res_handle, rotamer_ids[i])
        except Exception:
            # best effort: report and go on with the next residue
            # (single argument print works for Python 2 and 3, formatting
            #  reproduces the output of the former print statement)
            print("%s %s" % ("there is a backbone atom missing... ",
                             res_handle.GetQualifiedName()))

# public interface of this module: only the names listed here are exported by
# a star-import, all other functions in this file are internal helpers
__all__ = ('Reconstruct',)