From 53d59b37b8c571fe239966c08611fe24275f5b8f Mon Sep 17 00:00:00 2001
From: Xavier Robin <xavalias-github@xavier.robin.name>
Date: Thu, 2 Mar 2023 16:41:20 +0100
Subject: [PATCH] feat: SCHWED-5481 initial compare-ligand-structures action

---
 actions/CMakeLists.txt                  |   1 +
 actions/ost-compare-ligand-structures   | 350 ++++++++++++++++++++++++
 modules/mol/alg/pymod/ligand_scoring.py |   2 +-
 3 files changed, 352 insertions(+), 1 deletion(-)
 create mode 100644 actions/ost-compare-ligand-structures

diff --git a/actions/CMakeLists.txt b/actions/CMakeLists.txt
index 6f37c8ec2..abd002cc2 100644
--- a/actions/CMakeLists.txt
+++ b/actions/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_custom_target(actions ALL)
 
 ost_action_init()
+ost_action(ost-compare-ligand-structures actions)
 ost_action(ost-compare-structures actions)
 ost_action(ost-compare-structures-new actions)
diff --git a/actions/ost-compare-ligand-structures b/actions/ost-compare-ligand-structures
new file mode 100644
index 000000000..e24258c64
--- /dev/null
+++ b/actions/ost-compare-ligand-structures
@@ -0,0 +1,350 @@
+"""
+Evaluate model with non-polymer/small molecule ligands against reference
+
+Example: ost compare-ligand-structures \\
+    -m model.pdb \\
+    -ml ligand.sdf \\
+    -r reference.cif \\
+    --lddt-pli --rmsd
+
+Only minimal cleanup steps are performed (remove hydrogens, and for structures
+only, remove unknown atoms and cleanup element column).
+
+Ligands can be given as path to SDF files containing the ligand for both model
+(--model-ligands/-ml) and reference (--reference-ligands/-rl). If omitted,
+ligands will be detected in the model and reference structures. For structures
+given in mmCIF format, this is based on the annotation as "non polymer entity"
+and works reliably. For structures given in PDB format, this is based on the
+HET records and is normally not what you want.
+
+Output is written in JSON format (default: out.json). In case of no additional
+options, this is a dictionary with three keys:
+
+ * "model_ligands": A list of ligands in the model. If ligands were provided
+   explicitly with --model-ligands, elements of the list will be the paths to
+   the ligand SDF file(s). Otherwise, they will be the chain name and residue
+   number of the ligand, separated by a dot.
+ * "reference_ligands": A list of ligands in the reference. If ligands were
+   provided explicitly with --reference-ligands, elements of the list will be
+   the paths to the ligand SDF file(s). Otherwise, they will be the chain name
+   and residue number of the ligand, separated by a dot.
+ * "status": SUCCESS if everything ran through. In case of failure, the only
+   content of the JSON output will be \"status\" set to FAILURE and an
+   additional key: "traceback".
+
+Each score is opt-in and can be enabled with optional arguments and is added
+to the output. Keys correspond to the values in `"model_ligands"` above.
+Only mapped ligands are reported.
+
+
+"""
+
+import argparse
+import json
+import os
+import traceback
+import warnings
+
+import ost
+from ost import conop
+from ost import io
+from ost.mol.alg import ligand_scoring
+from ost.mol.alg import Molck, MolckSettings
+
+
+
+def _ParseArgs():
+    """Parse the command line arguments of the action.
+
+    Returns the argparse namespace; argparse exits on invalid arguments.
+    """
+    parser = argparse.ArgumentParser(description = __doc__,
+                                     formatter_class=argparse.RawDescriptionHelpFormatter,
+                                     prog="ost compare-ligand-structures")
+
+    parser.add_argument(
+        "-m",
+        "--mdl",
+        "--model",
+        dest="model",
+        required=True,
+        help=("Path to model file."))
+
+    parser.add_argument(
+        "-ml",
+        "--mdl-ligands",
+        "--model-ligands",
+        dest="model_ligands",
+        nargs="*",
+        default=None,
+        help=("Path to model ligand files."))
+
+    parser.add_argument(
+        "-r",
+        "--ref",
+        "--reference",
+        dest="reference",
+        required=True,
+        help=("Path to reference file."))
+
+    parser.add_argument(
+        "-rl",
+        "--ref-ligands",
+        "--reference-ligands",
+        dest="reference_ligands",
+        nargs="*",
+        default=None,
+        help=("Path to reference ligand files."))
+
+    parser.add_argument(
+        "-o",
+        "--out",
+        "--output",
+        dest="output",
+        default="out.json",
+        help=("Output file name. The output will be saved as a JSON file. "
+              "default: out.json"))
+
+    parser.add_argument(
+        "-mf",
+        "--mdl-format",
+        "--model-format",
+        dest="model_format",
+        default="auto",
+        choices=["pdb", "cif", "mmcif"],
+        help=("Format of model file. Inferred from path if not given."))
+
+    parser.add_argument(
+        "-rf",
+        "--reference-format",
+        "--ref-format",
+        dest="reference_format",
+        default="auto",
+        choices=["pdb", "cif", "mmcif"],
+        help=("Format of reference file. Inferred from path if not given."))
+
+    parser.add_argument(
+        "-ft",
+        "--fault-tolerant",
+        dest="fault_tolerant",
+        default=False, action="store_true",
+        help=("Fault tolerant parsing."))
+
+    parser.add_argument(
+        "--residue-number-alignment",
+        "-rna",
+        dest="residue_number_alignment",
+        default=False, action="store_true",
+        help=("Make alignment based on residue number instead of using "
+              "a global BLOSUM62-based alignment (NUC44 for nucleotides)."))
+
+    parser.add_argument(
+        "--lddt-pli",
+        dest="lddt_pli",
+        default=False, action="store_true",
+        help=("Compute lDDT-PLI score and store as key \"lddt_pli\"."))
+
+    parser.add_argument(
+        "--rmsd",
+        dest="rmsd",
+        default=False, action="store_true",
+        help=("Compute RMSD score and store as key \"rmsd\"."))
+
+    parser.add_argument(
+        '-v',
+        '--verbosity',
+        dest="verbosity",
+        type=int,
+        default=3,
+        help="Set verbosity level. Defaults to 3 (INFO).")
+
+    return parser.parse_args()
+
+
+def _LoadStructure(structure_path, format="auto", fault_tolerant=False):
+    """Read OST entity either from mmCIF or PDB and clean it up with Molck.
+
+    With format "auto", the format is derived from the file suffix, ignoring
+    a trailing ".gz". Raises an Exception if the file does not exist, holds
+    no residues, or if its format is unknown. The previous OST log level is
+    restored even when loading fails.
+
+    The returned structure has structure_path attached as structure name
+    """
+    if not os.path.exists(structure_path):
+        raise Exception(f"file not found: {structure_path}")
+
+    if format == "auto":
+        # Determine file format from suffix.
+        ext = structure_path.split(".")
+        if ext[-1] == "gz":
+            ext = ext[:-1]
+        if len(ext) <= 1:
+            raise Exception(f"Could not determine format of file "
+                            f"{structure_path}.")
+        format = ext[-1].lower()
+
+    # Only log errors while loading, the info log gets polluted otherwise
+    ost.PushVerbosityLevel(ost.LogLevel.Error)
+    try:
+        if format in ["mmcif", "cif"]:
+            entity, _ = io.LoadMMCIF(structure_path, seqres=True,
+                                     fault_tolerant=fault_tolerant)
+        elif format == "pdb":
+            entity = io.LoadPDB(structure_path, fault_tolerant=fault_tolerant)
+        else:
+            raise Exception(f"Unknown/ unsupported file extension found for "
+                            f"file {structure_path}.")
+        if len(entity.residues) == 0:
+            raise Exception(f"No residues found in file: {structure_path}")
+        # Molck cleans up the loaded entity in place
+        molck_settings = MolckSettings(rm_unk_atoms=True,
+                                       rm_non_std=False,
+                                       rm_hyd_atoms=True,
+                                       rm_oxt_atoms=False,
+                                       rm_zero_occ_atoms=False,
+                                       colored=False,
+                                       map_nonstd_res=False,
+                                       assign_elem=True)
+        Molck(entity, conop.GetDefaultLib(), molck_settings)
+    finally:
+        # Restore the old loglevel, also if loading or Molck failed
+        ost.PopVerbosityLevel()
+    entity.SetName(structure_path)
+    return entity
+
+
+def _LoadLigands(ligands):
+    """
+    Load the ligand behind every file name in `ligands`, preserving order.
+    Return the list of entities, or None if `ligands` is None.
+    """
+    if ligands is not None:
+        # One entity per SDF path
+        return [_LoadLigand(sdf_path) for sdf_path in ligands]
+    return None
+
+
+def _LoadLigand(file):
+    """
+    Load a single ligand from an SDF file. Return a hydrogen-free entity.
+    """
+    ligand_ent = ost.io.LoadEntity(file, format="sdf")
+    # Select() only returns a view of ligand_ent: build the entity from it.
+    return ost.mol.CreateEntityFromView(ligand_ent.Select("ele != H"), False)
+
+
+def _Validate(structure, ligands, legend):
+    """Warn about ligand residues left inside a polymer structure.
+
+    Nothing is checked unless `ligands` were supplied separately. Each residue
+    of `structure` flagged as ligand triggers a UserWarning naming the residue
+    and `legend` ("model" or "reference"). Never raises and returns None.
+    """
+    if ligands is None:
+        return
+    flagged = (res for res in structure.residues if res.is_ligand)
+    for res in flagged:
+        message = "Ligand residue %s found in %s polymer structure" % (
+            res.qualified_name, legend)
+        warnings.warn(message, UserWarning)
+
+
+def _Process(model, model_ligands, reference, reference_ligands, args):
+    """Run the ligand scoring and assemble the output dictionary.
+
+    model_ligands and reference_ligands are the lists of ligand entities
+    loaded from args.model_ligands and args.reference_ligands, or None to
+    use the ligands the scorer reports for the structures. Scores are only
+    added if enabled in args (lddt_pli, rmsd).
+
+    Raises RuntimeError if the scorer reports a different number of ligands
+    than were given explicitly.
+    """
+    scorer = ligand_scoring.LigandScorer(
+        model=model,
+        target=reference,
+        model_ligands=model_ligands,
+        target_ligands=reference_ligands,
+        resnum_alignments=args.residue_number_alignment,
+    )
+
+    def _LabelLigands(given, paths, scored, legend):
+        # Label every ligand of the scorer for the output: its SDF path when
+        # given explicitly, otherwise "<chain name>.<residue number>".
+        if given is None:
+            labels = {l: "%s.%s" % (l.chain.name, l.number) for l in scored}
+            return labels, list(labels.values())
+        if len(given) != len(scored):
+            # Not an assert (stripped by -O): zip() would mislabel silently
+            raise RuntimeError("Expected %d %s ligands, scorer has %d" % (
+                len(given), legend, len(scored)))
+        return dict(zip(scored, paths)), paths
+
+    def _Collect(details):
+        # Flatten per-chain score details to {model ligand label: details}
+        collected = {}
+        for per_chain in details.values():
+            for entry in per_chain.values():
+                entry = dict(entry)  # do not modify the scorer's results
+                model_key = model_ligands_map[entry["model_ligand"]]
+                entry["reference_ligand"] = reference_ligands_map[
+                    entry.pop("target_ligand")]
+                entry["model_ligand"] = model_key
+                entry["transform"] = str(entry["transform"])
+                collected[model_key] = entry
+        return collected
+
+    out = dict()
+    model_ligands_map, out["model_ligands"] = _LabelLigands(
+        model_ligands, args.model_ligands, scorer.model_ligands, "model")
+    reference_ligands_map, out["reference_ligands"] = _LabelLigands(
+        reference_ligands, args.reference_ligands, scorer.target_ligands,
+        "reference")
+
+    if args.lddt_pli:
+        out["lddt_pli"] = _Collect(scorer.lddt_pli_details)
+
+    if args.rmsd:
+        out["rmsd"] = _Collect(scorer.rmsd_details)
+
+    return out
+
+
+def _Main():
+    """Entry point: load the inputs, score the ligands, write the JSON.
+
+    If anything fails, the output file only holds the FAILURE status and
+    the traceback, and the exception is raised again.
+    """
+    args = _ParseArgs()
+    ost.PushVerbosityLevel(args.verbosity)
+
+    def _Dump(content):
+        with open(args.output, 'w') as fh:
+            json.dump(content, fh, indent=4, sort_keys=False)
+
+    try:
+        # Structures: reference first, then model
+        reference = _LoadStructure(args.reference,
+                                   format=args.reference_format,
+                                   fault_tolerant=args.fault_tolerant)
+        model = _LoadStructure(args.model,
+                               format=args.model_format,
+                               fault_tolerant=args.fault_tolerant)
+        # Explicit ligands (None if no SDF files were passed)
+        model_ligands = _LoadLigands(args.model_ligands)
+        reference_ligands = _LoadLigands(args.reference_ligands)
+
+        _Validate(model, model_ligands, "model")
+        _Validate(reference, reference_ligands, "reference")
+        out = _Process(model, model_ligands, reference, reference_ligands, args)
+        out["status"] = "SUCCESS"
+        _Dump(out)
+    except Exception:
+        _Dump({"status": "FAILURE", "traceback": traceback.format_exc()})
+        raise
+
+
+if __name__ == "__main__":
+    _Main()
diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py
index 96d2a6b51..866a71cc5 100644
--- a/modules/mol/alg/pymod/ligand_scoring.py
+++ b/modules/mol/alg/pymod/ligand_scoring.py
@@ -68,7 +68,7 @@ class LigandScorer:
         molck_settings = MolckSettings(rm_unk_atoms=True,
                                        rm_non_std=False,
                                        rm_hyd_atoms=True,
-                                       rm_oxt_atoms=True,
+                                       rm_oxt_atoms=False,
                                        rm_zero_occ_atoms=False,
                                        colored=False,
                                        map_nonstd_res=False,
-- 
GitLab