Skip to content
Snippets Groups Projects
ost-compare-structures 23.83 KiB
"""Evaluate model structure against reference.

e.g.:

  ost compare-structures \
    -m <MODEL> \
    -r <REF> \
    -o output.json \
    -l \
    -sc \
    -cc \
    -ml \
    -rm oxt hyd \
    -mn

CAMEO calls lddt binary as follows:

  lddt \
    -p <PARAMETER FILE> \
    -f \
    -a 15 \
    -b 15 \
    -r 15 \
    <MODEL> \
    <REF>

Only model structures are "Molck-ed" in CAMEO. The call to molck is as follows:

  molck \
    --complib=<COMPOUND LIB> \
    --rm=hyd,oxt,unk \
    --fix-ele \
    --map-nonstd <FILEPATH> \
    --out=<OUTPUT>

To be as compatible with CAMEO as possible, one should call
compare-structures as follows (the comment lines only group the options
and must be removed before running the command):

  ost compare-structures \
    # General parameters
    ####################
    --model <MODEL> \
    --reference <REF> \
    --output output.json \
    # QS-score parameters
    #####################
    --qs-score \
    --residue-number-alignment \
    # lDDT parameters
    #################
    --lddt \
    --structural-checks \
    --consistency-checks \
    --inclusion-radius 15.0 \
    --bond-tolerance 15.0 \
    --angle-tolerance 15.0 \
    # Molck parameters
    ##################
    --molck \
    --remove oxt hyd unk \
    --clean-element-column \
    --map-nonstandard-residues
"""

import os
import sys
import json
import argparse

import ost
from ost.io import (LoadPDB, LoadMMCIF, MMCifInfoBioUnit, MMCifInfo,
                    MMCifInfoTransOp, ReadStereoChemicalPropsFile)
from ost import PushVerbosityLevel
from ost.mol.alg import (qsscoring, Molck, MolckSettings, lDDTSettings,
                         CheckStructure)
from ost.conop import CompoundLib


class _DefaultStereochemicalParamAction(argparse.Action):
    """Argparse action for the stereochemical parameter file.

    The default is ``stereo_chemical_props.txt`` from the current working
    directory if it exists there, otherwise the copy in the OpenStructure
    shared data directory if it exists there. If no existing file is found,
    the ``default`` passed in is kept and a warning is appended to the
    option's help text. A path given explicitly on the command line must
    exist, otherwise the parser reports an error.
    """

    def __init__(self, default=None, required=False, **kwargs):
        msg = ""
        # Try to set default: first CWD, then the OST shared data directory.
        cwd = os.path.abspath(os.getcwd())
        parameter_file_path = os.path.join(cwd, "stereo_chemical_props.txt")
        if not os.path.exists(parameter_file_path):
            try:
                parameter_file_path = os.path.join(
                    ost.GetSharedDataPath(),
                    "stereo_chemical_props.txt")
            except RuntimeError:
                parameter_file_path = None
        if (parameter_file_path is not None
                and os.path.exists(parameter_file_path)):
            default = parameter_file_path
        else:
            # Only advertise a default that actually exists; a missing file
            # would otherwise only fail later, when it is read.
            msg = (
                "Could not set default stereochemical parameter file. In "
                "order to use the default one please set $OST_ROOT "
                "environmental variable, run the script with OST binary or"
                " provide a local copy of 'stereo_chemical_props.txt' in "
                "CWD. Alternatively provide the path to the local copy")
        super(_DefaultStereochemicalParamAction, self).__init__(
            default=default,
            required=required,
            **kwargs)
        if msg:
            # help may be None if the caller did not pass one.
            self.help = (self.help or "") + " (WARNING: %s)" % (msg,)

    def __call__(self, parser, namespace, values, option_string=None):
        """Store ``values`` after checking that the file exists."""
        if not os.path.exists(values):
            parser.error(
                "Parameter file %s does not exist." % values)
        setattr(namespace, self.dest, values)


class _DefaultCompoundLibraryAction(argparse.Action):
    """Argparse action for the compound library file.

    The default is ``compounds.chemlib`` from the current working directory
    if it exists there, otherwise the copy in the OpenStructure shared data
    directory if it exists there. If no existing file is found, the
    ``default`` passed in is kept and a warning is appended to the option's
    help text. A path given explicitly on the command line must exist,
    otherwise the parser reports an error.
    """

    def __init__(self, default=None, required=False, **kwargs):
        msg = ""
        # Try to set default: first CWD, then the OST shared data directory.
        cwd = os.path.abspath(os.getcwd())
        compound_library_path = os.path.join(cwd, "compounds.chemlib")
        if not os.path.exists(compound_library_path):
            try:
                compound_library_path = os.path.join(
                    ost.GetSharedDataPath(),
                    "compounds.chemlib")
            except RuntimeError:
                compound_library_path = None
        if (compound_library_path is not None
                and os.path.exists(compound_library_path)):
            default = compound_library_path
        else:
            # Only advertise a default that actually exists; a missing file
            # would otherwise only fail later, when it is loaded.
            msg = (
                "Could not set default compounds library path. In "
                "order to use the default one please set $OST_ROOT "
                "environmental variable, run the script with OST binary or"
                " provide a local copy of 'compounds.chemlib' in CWD"
                ". Alternatively provide the path to the local copy")
        super(_DefaultCompoundLibraryAction, self).__init__(
            default=default,
            required=required,
            **kwargs)
        if msg:
            # help may be None if the caller did not pass one.
            self.help = (self.help or "") + " (WARNING: %s)" % (msg,)

    def __call__(self, parser, namespace, values, option_string=None):
        """Store ``values`` after checking that the file exists."""
        if not os.path.exists(values):
            parser.error(
                "Compounds library file %s does not exist." % values)
        setattr(namespace, self.dest, values)


def _ParseArgs():
    """Parse and validate command-line arguments.

    Prints the full help to stderr and exits with status 1 when the script is
    called without any argument. After parsing, the chain mapping is turned
    into a :class:`dict`, and option combinations that require a data file
    are checked; violations are reported through ``parser.error``.

    :returns: Parsed options.
    :rtype: :class:`argparse.Namespace`
    """
    #
    # General options
    #
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description=__doc__,
        prog="ost compare-structures")

    parser.add_argument(
        '-v',
        '--verbosity',
        type=int,
        default=3,
        help="Set verbosity level.")
    parser.add_argument(
        "-m",
        "--model",
        dest="model",
        required=True,
        help=("Path to the model file."))
    parser.add_argument(
        "-r",
        "--reference",
        dest="reference",
        required=True,
        help=("Path to the reference file."))
    parser.add_argument(
        "-o",
        "--output",
        dest="output",
        help=("Output file name. The output will be saved as a JSON file."))
    #
    # QS-score options
    #
    parser.add_argument(
        "-qs",
        "--qs-score",
        dest="qs_score",
        default=False,
        action="store_true",
        help=("Calculate QS-score."))
    parser.add_argument(
        "-c",
        "--chain-mapping",
        nargs="+",
        type=lambda x: x.split(":"),
        dest="chain_mapping",
        # The dict is assigned as-is to QSscorer(reference, model), whose
        # chain mapping goes from reference chains to model chains.
        help=("Mapping of chains between the reference and the model. "
              "Each separate mapping consists of key:value pairs where key "
              "is the chain name in reference and value is the chain name in "
              "model."))
    parser.add_argument(
        "-rna",
        "--residue-number-alignment",
        dest="residue_number_alignment",
        default=False,
        action="store_true",
        help=("Make alignment based on residue number instead of using "
              "Clustal."))
    #
    # lDDT options
    #
    parser.add_argument(
        "-l",
        "--lddt",
        dest="lddt",
        default=False,
        action="store_true",
        help=("Calculate lDDT."))
    parser.add_argument(
        "-s",
        "--selection",
        dest="selection",
        default="",
        help=("Selection performed on reference."))
    parser.add_argument(
        "-ca",
        "--c-alpha-only",
        dest="c_alpha_only",
        default=False,
        action="store_true",
        help=("Use C-alpha atoms only."))
    parser.add_argument(
        "-sc",
        "--structural-checks",
        dest="structural_checks",
        default=False,
        action="store_true",
        help=("Perform structural checks and filter input data."))
    parser.add_argument(
        "-ft",
        "--fault-tolerant",
        dest="fault_tolerant",
        default=False,
        action="store_true",
        help=("Fault tolerant parsing."))
    parser.add_argument(
        "-p",
        "--parameter-file",
        dest="parameter_file",
        action=_DefaultStereochemicalParamAction,
        help=("Location of the stereochemical parameter file "
              "(stereo_chemical_props.txt). "
              "If not provided, the following locations are searched in this "
              "order: 1. Working directory, 2. OpenStructure standard library "
              "location."))
    parser.add_argument(
        "-bt",
        "--bond-tolerance",
        dest="bond_tolerance",
        type=float,
        default=12.0,
        help=("Tolerance in STD for bonds."))
    parser.add_argument(
        "-at",
        "--angle-tolerance",
        dest="angle_tolerance",
        type=float,
        default=12.0,
        help=("Tolerance in STD for angles."))
    parser.add_argument(
        "-ir",
        "--inclusion-radius",
        dest="inclusion_radius",
        type=float,
        default=15.0,
        help=("Distance inclusion radius."))
    parser.add_argument(
        "-ss",
        "--sequence-separation",
        dest="sequence_separation",
        type=int,
        default=0,
        help=("Sequence separation. Only distances between residues whose "
              "separation is higher than the provided parameter are "
              "considered when computing the score"))
    parser.add_argument(
        "-cc",
        "--consistency-checks",
        dest="consistency_checks",
        default=False,
        action="store_true",
        help=("Residue name consistency checks."))
    #
    # Molck parameters
    #
    parser.add_argument(
        "-ml",
        "--molck",
        dest="molck",
        default=False,
        action="store_true",
        help=("Run molecular checker to clean up input."))
    parser.add_argument(
        "-cl",
        "--compound-library",
        dest="compound_library",
        action=_DefaultCompoundLibraryAction,
        help=("Location of the compound library file (compounds.chemlib). "
              "If not provided, the following locations are searched in this "
              "order: 1. Working directory, 2. OpenStructure standard library "
              "location."))
    parser.add_argument(
        "-rm",
        "--remove",
        dest="remove",
        nargs="+",  # *, +, ?, N
        required=False,
        default=["hyd"],
        # Reject typos instead of silently ignoring unknown criteria.
        choices=["zeroocc", "hyd", "oxt", "nonstd", "unk"],
        help=("Remove atoms and residues matching some criteria: "
              "zeroocc - Remove atoms with zero occupancy, "
              "hyd - remove hydrogen atoms, "
              "oxt - remove terminal oxygens, "
              "nonstd - remove all residues not one of the 20 "
              "standard amino acids, "
              "unk - Remove unknown and atoms not following the nomenclature"))
    parser.add_argument(
        "-ce",
        "--clean-element-column",
        dest="clean_element_column",
        default=False,
        action="store_true",
        help=("Clean up element column"))
    parser.add_argument(
        "-mn",
        "--map-nonstandard-residues",
        dest="map_nonstandard_residues",
        default=False,
        action="store_true",
        help=("Map modified residues back to the parent amino acid, for "
              "example MSE -> MET, SEP -> SER."))

    # Print full help if no arguments are provided
    if len(sys.argv) == 1:
        parser.print_help(sys.stderr)
        sys.exit(1)

    opts = parser.parse_args()
    # Set chain mapping; dict() raises ValueError for entries that do not
    # split into exactly two parts on ":".
    if opts.chain_mapping is not None:
        try:
            opts.chain_mapping = dict(opts.chain_mapping)
        except ValueError:
            parser.error(
                "Cannot parse chain mapping into dictionary. The "
                "correct format is: key:value [key2:value2 ...].")
    # The stereochemical parameter file is required for lDDT and is read
    # when structural checks are performed.
    if (opts.lddt or opts.structural_checks) and opts.parameter_file is None:
        parser.error(
            "argument -p/--parameter-file is required when --lddt or "
            "--structural-checks option is selected.")

    # Check compound library if molck is on
    if opts.molck and opts.compound_library is None:
        parser.error(
            "argument -cl/--compound-library is required when --molck "
            "option is selected.")

    return opts


def _ReadStructureFile(path):
    """Safely read structure file into OST entities.

    The file is first read as PDB. If that fails for any reason, it is read
    as mmCIF instead, and every biounit is turned into its own entity via
    ``PDBize``. If the mmCIF file defines no biounit, a biounit covering the
    asymmetric unit with a single default transformation operation is used.

    :param path: Path to the file.
    :type path: :class:`str`
    :returns: One dict per entity with keys ``"entity"`` (the loaded
              :class:`~ost.mol.EntityHandle`) and ``"chain_mapping"``
              (chain name in the entity -> chain name in the input file).
    :rtype: :class:`list`
    :raises IOError: if ``path`` is not a file or a PDBized entity is not
                     valid. Errors from the mmCIF reader propagate unchanged.
    """
    if not os.path.isfile(path):
        raise IOError("%s is not a file" % path)
    basename = os.path.basename(path)
    try:
        entity = LoadPDB(path)
        if not entity.IsValid():
            raise IOError("Provided file does not contain valid entity.")
        entity.SetName(basename)
        return [{"entity": entity,
                 "chain_mapping": {c.name: c.name for c in entity.chains}}]
    except Exception:
        # Not usable as PDB: fall back to mmCIF below.
        pass

    tmp_entity, cif_info = LoadMMCIF(path, info=True)
    biounits = cif_info.biounits
    if len(biounits) == 0:
        # No biounit defined: build one for the asymmetric unit. The id is
        # derived from the file name (there is no PDB entity to take it from).
        asu = MMCifInfoBioUnit()
        asu.id = 'ASU of ' + basename
        asu.details = 'asymmetric unit'
        for chain in tmp_entity.chains:
            asu.AddChain(str(chain))
        tinfo = MMCifInfo()
        tinfo.AddOperation(MMCifInfoTransOp())
        asu.AddOperations(tinfo.GetOperations())
        named_biounits = [(asu, basename + ".au")]
    elif len(biounits) == 1:
        named_biounits = [(biounits[0], basename)]
    else:
        named_biounits = [(biounit, basename + "." + str(i))
                          for i, biounit in enumerate(biounits)]

    entities = []
    for biounit, name in named_biounits:
        entity = biounit.PDBize(tmp_entity, min_polymer_size=0)
        if not entity.IsValid():
            raise IOError("Provided file does not contain valid entity.")
        entity.SetName(name)
        # PDBize renames chains; keep the link back to the file's names.
        chain_mapping = {c.name: c.GetStringProp("original_name")
                         for c in entity.chains}
        entities.append({"entity": entity, "chain_mapping": chain_mapping})
    return entities


def _MolckEntity(entity, options):
    """Clean up ``entity`` in place with Molck.

    The compound library is loaded from ``options.compound_library``. The
    removal criteria come from ``options.remove`` ("unk", "nonstd", "hyd",
    "oxt", "zeroocc"). Residue mapping and element reassignment follow
    ``options.map_nonstandard_residues`` and
    ``options.clean_element_column``.
    """
    compound_lib = CompoundLib.Load(options.compound_library)
    criteria = tuple(options.remove)

    settings = MolckSettings(
        rm_unk_atoms=("unk" in criteria),
        rm_non_std=("nonstd" in criteria),
        rm_hyd_atoms=("hyd" in criteria),
        rm_oxt_atoms=("oxt" in criteria),
        rm_zero_occ_atoms=("zeroocc" in criteria),
        colored=False,
        map_nonstd_res=options.map_nonstandard_residues,
        assign_elem=options.clean_element_column,
    )
    Molck(entity, compound_lib, settings)


def _RunStructuralChecks(entities, stereochemical_parameters, opts):
    """Run :func:`CheckStructure` on every ``"entity"`` in ``entities``."""
    for entity_data in entities:
        CheckStructure(entity_data["entity"],
                       stereochemical_parameters.bond_table,
                       stereochemical_parameters.angle_table,
                       stereochemical_parameters.nonbonded_table,
                       opts.bond_tolerance,
                       opts.angle_tolerance)


def _OriginalChainMapping(qs_scorer, model_data, reference_data):
    """Translate the QS-scorer chain mapping to input-file chain names.

    ``QSscorer(reference, model).chain_mapping`` maps reference chain names
    (keys) to model chain names (values). The returned dict maps original
    model chain names to original reference chain names.
    """
    original_chain_mapping = dict()
    for ref_cname, mdl_cname in qs_scorer.chain_mapping.items():
        orig_mdl_cname = model_data["chain_mapping"][mdl_cname]
        orig_ref_cname = reference_data["chain_mapping"][ref_cname]
        original_chain_mapping[orig_mdl_cname] = orig_ref_cname
    return original_chain_mapping


def _QSScoreResult(qs_scorer, model_name, reference_name,
                   original_chain_mapping):
    """Compute the QS-score record; failures are reported, not raised."""
    ost.LogInfo("Computing QS-score")
    try:
        return {
            "status": "SUCCESS",
            "error": "",
            "model_name": model_name,
            "reference_name": reference_name,
            "global_score": qs_scorer.global_score,
            "best_score": qs_scorer.best_score,
            "chain_mapping": qs_scorer.chain_mapping,
            "original_chain_mapping": original_chain_mapping
        }
    except qsscoring.QSscoreError as ex:
        # default handling: report failure and set score to 0
        ost.LogError('QSscore failed:', str(ex))
        return {
            "status": "FAILURE",
            "error": str(ex),
            "model_name": model_name,
            "reference_name": reference_name,
            "global_score": 0.0,
            "best_score": 0.0,
            "chain_mapping": qs_scorer.chain_mapping,
            "original_chain_mapping": original_chain_mapping
        }


def _SingleChainLDDTResult(lddt_scorer, model_data, reference_data):
    """Build the JSON record for one single-chain lDDT scorer."""
    # Initialised up front so the failure record never refers to unbound
    # names or to the chains of a previous scorer.
    model_chain = None
    reference_chain = None
    try:
        model_chain = lddt_scorer.model.chains[0].GetName()
        reference_chain = lddt_scorer.references[0].chains[0].GetName()
        return {
            "status": "SUCCESS",
            "error": "",
            "original_model_chain": model_data["chain_mapping"][model_chain],
            "original_reference_chain":
                reference_data["chain_mapping"][reference_chain],
            "model_chain": model_chain,
            "reference_chain": reference_chain,
            "global_score": lddt_scorer.global_score,
            "conserved_contacts": lddt_scorer.conserved_contacts,
            "total_contacts": lddt_scorer.total_contacts}
    except Exception as ex:
        ost.LogError('Single chain lDDT failed:', str(ex))
        # .get: the lookup itself may be what failed above.
        return {
            "status": "FAILURE",
            "error": str(ex),
            "original_model_chain":
                model_data["chain_mapping"].get(model_chain),
            "original_reference_chain":
                reference_data["chain_mapping"].get(reference_chain),
            "model_chain": model_chain,
            "reference_chain": reference_chain,
            "global_score": 0.0,
            "conserved_contacts": 0.0,
            "total_contacts": 0.0}


def _LDDTResults(qs_scorer, model_data, reference_data, opts):
    """Compute single-chain, oligo and weighted lDDT records."""
    ost.LogInfo("Computing lDDT")
    lddt_settings = lDDTSettings(
        bond_tolerance=opts.bond_tolerance,
        angle_tolerance=opts.angle_tolerance,
        radius=opts.inclusion_radius,
        sequence_separation=opts.sequence_separation,
        sel=opts.selection,
        structural_checks=False,
        consistency_checks=opts.consistency_checks,
        label="lddt")
    if opts.verbosity > 3:
        lddt_settings.PrintParameters()

    oligo_lddt_scorer = qsscoring.OligoLDDTScorer(
        qs_scorer.qs_ent_1.ent,
        qs_scorer.qs_ent_2.ent,
        qs_scorer.alignments,
        qs_scorer.calpha_only,
        lddt_settings)
    lddt_results = {
        "single_chain_lddt": [
            _SingleChainLDDTResult(lddt_scorer, model_data, reference_data)
            for lddt_scorer in oligo_lddt_scorer.sc_lddt_scorers]
    }
    # perform oligo lddt scoring
    try:
        lddt_results["oligo_lddt"] = {
            "status": "SUCCESS",
            "error": "",
            "global_score": oligo_lddt_scorer.oligo_lddt}
    except Exception as ex:
        ost.LogError('Oligo lDDT failed:', str(ex))
        lddt_results["oligo_lddt"] = {
            "status": "FAILURE",
            "error": str(ex),
            "global_score": 0.0}
    try:
        lddt_results["weighted_lddt"] = {
            "status": "SUCCESS",
            "error": "",
            "global_score": oligo_lddt_scorer.weighted_lddt}
    except Exception as ex:
        ost.LogError('Weighted lDDT failed:', str(ex))
        lddt_results["weighted_lddt"] = {
            "status": "FAILURE",
            "error": str(ex),
            "global_score": 0.0}
    return lddt_results


def _Main():
    """Read model and reference, score every pair and optionally dump JSON."""

    opts = _ParseArgs()
    PushVerbosityLevel(opts.verbosity)
    #
    # Read the input files
    ost.LogInfo("Reading model from %s" % opts.model)
    models = _ReadStructureFile(opts.model)
    ost.LogInfo("Reading reference from %s" % opts.reference)
    references = _ReadStructureFile(opts.reference)
    # Optionally clean up the handles with Molck, then work on full views
    # (references first, then models).
    for entity_data in references + models:
        if opts.molck:
            _MolckEntity(entity_data["entity"], opts)
        entity_data["entity"] = entity_data["entity"].CreateFullView()
    if opts.structural_checks:
        stereochemical_parameters = ReadStereoChemicalPropsFile(
            opts.parameter_file)
        ost.LogInfo("Performing structural checks for reference(s)")
        _RunStructuralChecks(references, stereochemical_parameters, opts)
        ost.LogInfo("Performing structural checks for model(s)")
        _RunStructuralChecks(models, stereochemical_parameters, opts)

    if len(models) > 1 or len(references) > 1:
        ost.LogInfo(
            "Multiple complexes detected. All combinations will be tried.")

    result = {
        "result": {},
        "options": vars(opts)}
    result["options"]["cwd"] = os.path.abspath(os.getcwd())
    #
    # Perform scoring
    for model_data in models:
        model = model_data["entity"]
        model_name = model.GetName()
        model_results = dict()
        for reference_data in references:
            reference = reference_data["entity"]
            reference_name = reference.GetName()
            reference_results = dict()
            ost.LogInfo("#\nComparing %s to %s" % (
                model_name,
                reference_name))
            qs_scorer = qsscoring.QSscorer(reference,
                                           model,
                                           opts.residue_number_alignment)
            if opts.chain_mapping is not None:
                ost.LogInfo(
                    "Using custom chain mapping: %s" % str(
                        opts.chain_mapping))
                qs_scorer.chain_mapping = opts.chain_mapping
            original_chain_mapping = _OriginalChainMapping(
                qs_scorer, model_data, reference_data)
            if opts.qs_score:
                reference_results["qs_score"] = _QSScoreResult(
                    qs_scorer, model_name, reference_name,
                    original_chain_mapping)
            # Calculate lDDT
            if opts.lddt:
                reference_results["lddt"] = _LDDTResults(
                    qs_scorer, model_data, reference_data, opts)
            model_results[reference_name] = reference_results
        result["result"][model_name] = model_results

    if opts.output is not None:
        with open(opts.output, "w") as outfile:
            outfile.write(json.dumps(result, indent=4))


if __name__ == '__main__':
    # make script 'hot': unbuffered stdout so output appears immediately
    try:
        unbuffered = os.fdopen(sys.stdout.fileno(), 'w', 0)
    except ValueError:
        # Python 3 refuses unbuffered text streams; line-buffer instead.
        unbuffered = os.fdopen(sys.stdout.fileno(), 'w', 1)
    sys.stdout = unbuffered
    _Main()