From 258f724fabb3f59cb503214d972256198dc5b277 Mon Sep 17 00:00:00 2001
From: Xavier Robin <xavalias-github@xavier.robin.name>
Date: Thu, 1 Feb 2024 13:37:38 +0100
Subject: [PATCH] refactor: use new CreateBU and SaveMMCIF functions

---
 actions/ost-compare-structures | 140 ++++++++++++++-------------------
 1 file changed, 57 insertions(+), 83 deletions(-)

diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures
index 3f91c3de7..56c6087a8 100644
--- a/actions/ost-compare-structures
+++ b/actions/ost-compare-structures
@@ -80,11 +80,10 @@ ost compare-structures -m model.pdb -r reference.cif -c A:B B:A
 import argparse
 import os
 import json
-import time
-import sys
 import traceback
 import math
 
+import ost
 from ost import io
 from ost.mol.alg import scoring
 
@@ -186,17 +185,17 @@ def _ParseArgs():
         dest="dump_structures",
         default=False,
         action="store_true",
-        help=("Dump cleaned structures used to calculate all the scores as "
-              "PDB files using specified suffix. Files will be dumped to the "
-              "same location as original files."))
+        help=("Dump cleaned structures used to calculate all the scores as PDB"
+              " or mmCIF files using specified suffix. Files will be dumped to"
+              " the same location and in the same format as original files."))
 
     parser.add_argument(
         "-ds",
         "--dump-suffix",
         dest="dump_suffix",
-        default=".compare.structures.pdb",
+        default="_compare_structures",
         help=("Use this suffix to dump structures.\n"
-              "Defaults to .compare.structures.pdb."))
+              "Defaults to _compare_structures"))
 
     parser.add_argument(
         "-ft",
@@ -534,50 +533,19 @@ def _RoundOrNone(num, decimals = 3):
         return None
     return round(num, decimals)
 
-def _Rename(ent):
-    """Revert chain names to original names.
-
-    PDBize assigns chain name in order A,B,C,D... which does not allow to infer
-    the original chain name. We do a renaming here:
-    if there are two chains mapping to chain A the resulting
-    chain names will be: A and A2.
+def _AddSuffix(filename, dump_suffix):
+    """Add dump_suffix to the file name.
     """
-    new_chain_names = list()
-    chain_indices = list() # the chains where we actually change the name
-    suffix_indices = dict() # keep track of whats the current suffix index
-                            # for each original chain name
-
-    for ch_idx, ch in enumerate(ent.chains):
-        if not ch.HasProp("original_name"):
-            # pdbize doesnt set this property for chain names in ['_', '-']
-            continue
-        original_name = ch.GetStringProp("original_name")
-        if original_name in new_chain_names:
-            new_name = original_name + str(suffix_indices[original_name])
-            new_chain_names.append(new_name)
-            suffix_indices[original_name] = suffix_indices[original_name] + 1
-        else:
-            new_chain_names.append(original_name)
-            suffix_indices[original_name] = 2
-        chain_indices.append(ch_idx)
-    editor = ent.EditXCS()
-    # rename to nonsense to avoid clashing chain names
-    for ch_idx in chain_indices:
-        editor.RenameChain(ent.chains[ch_idx], ent.chains[ch_idx].name+"_yolo")
-    # and do final renaming
-    for new_name, ch_idx in zip(new_chain_names, chain_indices):
-        editor.RenameChain(ent.chains[ch_idx], new_name)
-
-def _LoadStructure(structure_path, sformat=None, fault_tolerant=False,
-                   bu_idx=None):
-    """Read OST entity either from mmCIF or PDB.
-
-    The returned structure has structure_path attached as structure name
+    root, ext = os.path.splitext(filename)
+    if ext == ".gz":
+        root, ext2 = os.path.splitext(root)
+        ext = ext2 + ext
+    return root + dump_suffix + ext
+
+def _GetStructureFormat(structure_path, sformat=None):
+    """Get the structure format and return it as "pdb" or "mmcif".
     """
 
-    if not os.path.exists(structure_path):
-        raise Exception(f"file not found: {structure_path}")
-
     if sformat is None:
         # Determine file format from suffix.
         ext = structure_path.split(".")
@@ -587,11 +555,26 @@ def _LoadStructure(structure_path, sformat=None, fault_tolerant=False,
             raise Exception(f"Could not determine format of file "
                             f"{structure_path}.")
         sformat = ext[-1].lower()
+    if sformat in ["mmcif", "cif"]:
+        return "mmcif"
+    elif sformat == "pdb":
+        return sformat
+    else:
+        raise Exception(f"Unknown/unsupported file format found for "
+                        f"file {structure_path}.")
+
+def _LoadStructure(structure_path, sformat, fault_tolerant, bu_idx):
+    """Read OST entity either from mmCIF or PDB.
+
+    The returned structure has structure_path attached as structure name
+    """
+    if not os.path.exists(structure_path):
+        raise Exception(f"file not found: {structure_path}")
 
     # increase loglevel, as we would pollute the info log with weird stuff
     ost.PushVerbosityLevel(ost.LogLevel.Error)
     # Load the structure
-    if sformat in ["mmcif", "cif"]:
+    if sformat == "mmcif":
         if bu_idx is not None:
             cif_entity, cif_seqres, cif_info = \
             io.LoadMMCIF(structure_path, info=True, seqres=True,
@@ -600,28 +583,31 @@ def _LoadStructure(structure_path, sformat=None, fault_tolerant=False,
                 raise RuntimeError(f"Invalid biounit index - requested {bu_idx} "
                                    f"must be < {len(cif_info.biounits)}.")
             biounit = cif_info.biounits[bu_idx]
-            entity = biounit.PDBize(cif_entity, min_polymer_size=0)
+            entity = ost.mol.alg.CreateBU(cif_entity, biounit)
             if not entity.IsValid():
                 raise IOError(
                     "Provided file does not contain valid entity.")
-            _Rename(entity)
         else:
             entity = io.LoadMMCIF(structure_path,
                                   fault_tolerant = fault_tolerant)
         if len(entity.residues) == 0:
             raise Exception(f"No residues found in file: {structure_path}")
-    elif sformat == "pdb":
+    else:
         entity = io.LoadPDB(structure_path, fault_tolerant = fault_tolerant)
         if len(entity.residues) == 0:
             raise Exception(f"No residues found in file: {structure_path}")
-    else:
-        raise Exception(f"Unknown/ unsupported file extension found for "
-                        f"file {structure_path}.")
+
     # restore old loglevel and return
     ost.PopVerbosityLevel()
     entity.SetName(structure_path)
     return entity
 
+def _DumpStructure(entity, structure_path, sformat):
+    """Save entity to structure_path as mmCIF ("mmcif") or PDB (otherwise).
+    """
+    saver = io.SaveMMCIF if sformat == "mmcif" else io.SavePDB
+    saver(entity, structure_path)
+
 def _AlnToFastaStr(aln):
     """ Returns alignment as fasta formatted string
     """
@@ -714,7 +700,7 @@ def _GetAlignedResidues(aln):
                                  "reference": ref_dct})
     return aligned_residues
 
-def _Process(model, reference, args):
+def _Process(model, reference, args, model_format, reference_format):
 
     mapping = None
     if args.chain_mapping is not None:
@@ -855,32 +841,16 @@ def _Process(model, reference, args):
         out["usalign_mapping"] = scorer.usalign_mapping
 
     if args.dump_structures:
-        try:
-            io.SavePDB(scorer.model, model.GetName() + args.dump_suffix)
-        except Exception as e:
-            if "single-letter" in str(e) and args.model_biounit is not None:
-                raise RuntimeError("Failed to dump processed model. PDB "
-                                   "format only supports single character "
-                                   "chain names. This is likely the result of "
-                                   "chain renaming when constructing a user "
-                                   "specified biounit. Dumping structures "
-                                   "fails in this case.")
-            else:
-                raise
-        try:
-            io.SavePDB(scorer.target, reference.GetName() + args.dump_suffix)
-        except Exception as e:
-            if "single-letter" in str(e) and args.reference_biounit is not None:
-                raise RuntimeError("Failed to dump processed reference. PDB "
-                                   "format only supports single character "
-                                   "chain names. This is likely the result of "
-                                   "chain renaming when constructing a user "
-                                   "specified biounit. Dumping structures "
-                                   "fails in this case.")
-            else:
-                raise
+        # Dump the cleaned model that was actually scored (not the raw input)
+        model_dump_filename = _AddSuffix(model.GetName(), args.dump_suffix)
+        _DumpStructure(scorer.model, model_dump_filename, model_format)
+        # Dump the cleaned reference that was actually scored
+        reference_dump_filename = _AddSuffix(reference.GetName(), args.dump_suffix)
+        _DumpStructure(scorer.target, reference_dump_filename, reference_format)
+
     return out
 
+
 def _Main():
 
     args = _ParseArgs()
@@ -890,15 +860,19 @@ def _Main():
             raise RuntimeError("Only support CAD score when residue numbers in "
                                "model and reference match. Use -rna flag if "
                                "this is the case.")
+        reference_format = _GetStructureFormat(args.reference,
+                                               sformat=args.reference_format)
         reference = _LoadStructure(args.reference,
-                                   sformat=args.reference_format,
+                                   sformat=reference_format,
                                    bu_idx=args.reference_biounit,
                                    fault_tolerant = args.fault_tolerant)
+        model_format = _GetStructureFormat(args.model,
+                                           sformat=args.model_format)
         model = _LoadStructure(args.model,
-                               sformat=args.model_format,
+                               sformat=model_format,
                                bu_idx=args.model_biounit,
                                fault_tolerant = args.fault_tolerant)
-        out = _Process(model, reference, args)
+        out = _Process(model, reference, args, model_format, reference_format)
 
         # append input arguments
         out["model"] = args.model
-- 
GitLab