From 5c61238e5aa12659b0a01ab7985738f98bec133c Mon Sep 17 00:00:00 2001
From: Xavier Robin <xavalias-github@xavier.robin.name>
Date: Thu, 1 Feb 2024 14:52:32 +0100
Subject: [PATCH] feat: handle biounits in ligand scorer.

This is made possible now that we have a CreateBU function.
---
 actions/ost-compare-ligand-structures | 163 ++++++++++++++++++++------
 1 file changed, 125 insertions(+), 38 deletions(-)

diff --git a/actions/ost-compare-ligand-structures b/actions/ost-compare-ligand-structures
index 1bf297628..57d42cce3 100644
--- a/actions/ost-compare-ligand-structures
+++ b/actions/ost-compare-ligand-structures
@@ -8,8 +8,7 @@ Example: ost compare-ligand-structures \\
     --lddt-pli --rmsd
 
 Structures of polymer entities (proteins and nucleotides) can be given in PDB
-or mmCIF format. If the structure is given in mmCIF format, only the asymmetric
-unit (AU) is used for scoring.
+or mmCIF format.
 
 Ligands can be given as path to SDF files containing the ligand for both model
 (--model-ligands/-ml) and reference (--reference-ligands/-rl). If omitted,
@@ -117,18 +116,46 @@ def _ParseArgs():
         "--mdl-format",
         "--model-format",
         dest="model_format",
-        default="auto",
-        choices=["pdb", "mmcif", "cif"],
-        help=("Format of model file. Inferred from path if not given."))
+        required=False,
+        default=None,
+        choices=["pdb", "cif", "mmcif"],
+        help=("Format of model file. pdb reads pdb but also pdb.gz, same "
+              "applies to cif/mmcif. Inferred from filepath if not given."))
 
     parser.add_argument(
         "-rf",
         "--reference-format",
         "--ref-format",
         dest="reference_format",
-        default="auto",
-        choices=["pdb", "mmcif", "cif"],
-        help=("Format of reference file. Inferred from path if not given."))
+        required=False,
+        default=None,
+        choices=["pdb", "cif", "mmcif"],
+        help=("Format of reference file. pdb reads pdb but also pdb.gz, same "
+              "applies to cif/mmcif. Inferred from filepath if not given."))
+
+    parser.add_argument(
+        "-mb",
+        "--model-biounit",
+        dest="model_biounit",
+        required=False,
+        default=None,
+        type=int,
+        help=("Only has an effect if model is in mmcif format. By default, "
+              "the asymmetric unit (AU) is used for scoring. If there are "
+              "biounits defined in the mmcif file, you can specify the "
+              "(0-based) index of the one which should be used."))
+
+    parser.add_argument(
+        "-rb",
+        "--reference-biounit",
+        dest="reference_biounit",
+        required=False,
+        default=None,
+        type=int,
+        help=("Only has an effect if reference is in mmcif format. By default, "
+              "the asymmetric unit (AU) is used for scoring. If there are "
+              "biounits defined in the mmcif file, you can specify the "
+              "(0-based) index of the one which should be used."))
 
     parser.add_argument(
         "-ft",
@@ -268,16 +295,11 @@ def _ParseArgs():
     return parser.parse_args()
 
 
-def _LoadStructure(structure_path, format="auto", fault_tolerant=False):
-    """Read OST entity either from mmCIF or PDB.
-
-    The returned structure has structure_path attached as structure name
+def _GetStructureFormat(structure_path, sformat=None):
+    """Get the structure format and return it as "pdb" or "mmcif".
     """
 
-    if not os.path.exists(structure_path):
-        raise Exception(f"file not found: {structure_path}")
-
-    if format == "auto":
+    if sformat is None:
         # Determine file format from suffix.
         ext = structure_path.split(".")
         if ext[-1] == "gz":
@@ -285,26 +307,60 @@ def _LoadStructure(structure_path, format="auto", fault_tolerant=False):
         if len(ext) <= 1:
             raise Exception(f"Could not determine format of file "
                             f"{structure_path}.")
-        format = ext[-1].lower()
+        sformat = ext[-1].lower()
+    if sformat in ["mmcif", "cif"]:
+        return "mmcif"
+    elif sformat == "pdb":
+        return sformat
+    else:
+        raise Exception(f"Unknown/unsupported file format found for "
+                        f"file {structure_path}.")
+
 
+def _LoadStructure(structure_path, sformat, fault_tolerant, bu_idx):
+    """Read OST entity either from mmCIF or PDB.
+
+    The returned structure has structure_path attached as structure name
+    """
+    if not os.path.exists(structure_path):
+        raise Exception(f"file not found: {structure_path}")
+
+    # increase loglevel, as we would pollute the info log with weird stuff
+    ost.PushVerbosityLevel(ost.LogLevel.Error)
     # Load the structure
-    if format in ["mmcif", "cif"]:
-        entity = io.LoadMMCIF(structure_path, fault_tolerant=fault_tolerant)
+    if sformat == "mmcif":
+        if bu_idx is not None:
+            cif_entity, cif_seqres, cif_info = \
+            io.LoadMMCIF(structure_path, info=True, seqres=True,
+                         fault_tolerant=fault_tolerant)
+            if len(cif_info.biounits) == 0:
+                raise RuntimeError(f"No biounit found - requested index"
+                                   f" {bu_idx}.")
+            elif bu_idx < 0:
+                raise RuntimeError(f"Invalid biounit - requested index {bu_idx}, "
+                                   f"must be a positive integer or 0.")
+            elif bu_idx >= len(cif_info.biounits):
+                raise RuntimeError(f"Invalid biounit - requested index {bu_idx}, "
+                                   f"must be < {len(cif_info.biounits)}.")
+            biounit = cif_info.biounits[bu_idx]
+            entity = ost.mol.alg.CreateBU(cif_entity, biounit)
+            if not entity.IsValid():
+                raise IOError(
+                    "Provided file does not contain valid entity.")
+        else:
+            entity = io.LoadMMCIF(structure_path,
+                                  fault_tolerant = fault_tolerant)
         if len(entity.residues) == 0:
             raise Exception(f"No residues found in file: {structure_path}")
-    elif format == "pdb":
-        entity = io.LoadPDB(structure_path, fault_tolerant=fault_tolerant)
+    else:
+        entity = io.LoadPDB(structure_path, fault_tolerant = fault_tolerant)
         if len(entity.residues) == 0:
             raise Exception(f"No residues found in file: {structure_path}")
-    else:
-        raise Exception(f"Unknown/ unsupported file extension found for "
-                        f"file {structure_path}.")
 
-    # Remove hydrogens
-    cleaned_entity = ost.mol.CreateEntityFromView(entity.Select(
-        "ele != H and ele != D"), include_exlusive_atoms=False)
-    cleaned_entity.SetName(structure_path)
-    return cleaned_entity
+    # restore old loglevel and return
+    ost.PopVerbosityLevel()
+    entity.SetName(structure_path)
+    return entity
 
 
 def _LoadLigands(ligands):
@@ -322,9 +378,24 @@ def _LoadLigand(file):
     """
     Load a single ligand from file names. Return an entity.
     """
-    ligand_ent = ost.io.LoadEntity(file, format="sdf")
-    ligand_view = ligand_ent.Select("ele != H")
-    return ost.mol.CreateEntityFromView(ligand_view, False)
+    return ost.io.LoadEntity(file, format="sdf")
+
+
+def _CleanupStructure(entity):
+    """Cleans up the structure.
+    Currently only removes hydrogens (and deuterium atoms).
+    """
+    return ost.mol.CreateEntityFromView(entity.Select(
+        "ele != H and ele != D"), include_exlusive_atoms=False)
+
+
+def _CleanupLigands(ligands):
+    """Clean up a list of structures.
+    """
+    if ligands is None:
+        return None
+    else:
+        return [_CleanupStructure(lig) for lig in ligands]
 
 
 def _Validate(structure, ligands, legend, fault_tolerant=False):
@@ -333,7 +404,8 @@ def _Validate(structure, ligands, legend, fault_tolerant=False):
     If fault_tolerant is True, only warns in case of problems. If False,
     raise them as ValueErrors.
 
-    At the moment this chiefly checks for ligands in polymers
+    At the moment this chiefly checks if ligands are in the structure and are
+    given explicitly at the same time.
     """
     if ligands is not None:
         for residue in structure.residues:
@@ -518,23 +590,38 @@ def _Main():
     ost.PushVerbosityLevel(args.verbosity)
     try:
         # Load structures
+        reference_format = _GetStructureFormat(args.reference,
+                                               sformat=args.reference_format)
         reference = _LoadStructure(args.reference,
-                                   format=args.reference_format,
+                                   sformat=reference_format,
+                                   bu_idx=args.reference_biounit,
                                    fault_tolerant = args.fault_tolerant)
-        model = _LoadStructure(args.model, format=args.model_format,
+        model_format = _GetStructureFormat(args.model,
+                                           sformat=args.model_format)
+        model = _LoadStructure(args.model,
+                               sformat=model_format,
+                               bu_idx=args.model_biounit,
                                fault_tolerant = args.fault_tolerant)
 
         # Load ligands
         model_ligands = _LoadLigands(args.model_ligands)
         reference_ligands = _LoadLigands(args.reference_ligands)
 
+        # Cleanup
+        cleaned_reference = _CleanupStructure(reference)
+        cleaned_model = _CleanupStructure(model)
+        cleaned_reference_ligands = _CleanupLigands(reference_ligands)
+        cleaned_model_ligands = _CleanupLigands(model_ligands)
+
         # Validate
-        _Validate(model, model_ligands, "model",
+        _Validate(cleaned_model, cleaned_model_ligands, "model",
                   fault_tolerant = args.fault_tolerant)
-        _Validate(reference, reference_ligands, "reference",
+        _Validate(cleaned_reference, cleaned_reference_ligands, "reference",
                   fault_tolerant = args.fault_tolerant)
 
-        out = _Process(model, model_ligands, reference, reference_ligands, args)
+        out = _Process(cleaned_model, cleaned_model_ligands,
+                       cleaned_reference, cleaned_reference_ligands,
+                       args)
 
         out["status"] = "SUCCESS"
         with open(args.output, 'w') as fh:
-- 
GitLab