From fb3fecfc65f106622d3c033826c4b226dfa64af7 Mon Sep 17 00:00:00 2001
From: Stefan Bienert <stefan.bienert@unibas.ch>
Date: Wed, 15 Nov 2023 16:22:02 +0100
Subject: [PATCH] Store non-selected models in archive upon command line option

---
 .spelling              |   3 +-
 convert_to_modelcif.py | 172 ++++++++++++++++++++++++++++++++---------
 2 files changed, 137 insertions(+), 38 deletions(-)

diff --git a/.spelling b/.spelling
index 0b5aca6..239bed7 100644
--- a/.spelling
+++ b/.spelling
@@ -1,5 +1,6 @@
 Biopython
 CIF
+DBs
 FastA
 Jupyter
 MSA
@@ -15,4 +16,4 @@ polypeptide
 pre
 repo
 reproducibility
-subdirectory
+subdirectory
\ No newline at end of file
diff --git a/convert_to_modelcif.py b/convert_to_modelcif.py
index c6dcd81..fbb6a84 100755
--- a/convert_to_modelcif.py
+++ b/convert_to_modelcif.py
@@ -13,6 +13,7 @@ import pickle
 import re
 import shutil
 import sys
+import tempfile
 import zipfile
 
 from Bio import SeqIO
@@ -45,7 +46,7 @@ from alphapulldown.utils import make_dir_monomer_dictionary
 # ToDo: Example 1 from the GitHub repo mentions MMseqs2
 # ToDo: Discuss input of protocol steps, feature creation has baits, sequences
 #       does modelling depend on mode?
-# ToDo: check that PAE files are written to an associated file
+# ToDo: Option to add remaining models with PAE files to archive
 # ToDo: deal with `--max_template_date`, beta-barrel project has it as software
 #       parameter
 flags.DEFINE_string(
@@ -59,16 +60,24 @@ flags.DEFINE_list(
 flags.DEFINE_integer(
     "model_selected",
     None,
-    "model to be converted into ModelCIF, use '--select_all' to convert all "
-    + "models found in '--af2_output'",
+    "model to be converted into ModelCIF, omit to convert all models found in "
+    + "'--af2_output'",
 )
-flags.DEFINE_bool("compress", False, "compress the ModelCIF file using Gzip")
+flags.DEFINE_bool(
+    "add_associated",
+    False,
+    "Add models not marked by "
+    + "'--model_selected' to the archive for associated files",
+)
+flags.DEFINE_bool("compress", False, "compress the ModelCIF file(s) using Gzip")
 flags.mark_flags_as_required(["ap_output", "monomer_objects_dir"])
 
 FLAGS = flags.FLAGS
 
 # ToDo: implement a flags.register_validator() checking that files/ directories
 #       exist as expected.
+# ToDo: implement a flags.register_validator() to make sure that
+#       --add_associated is only activated if --model_selected is used, too
 
 
 # pylint: disable=too-few-public-methods
@@ -145,7 +154,7 @@ class _Biopython2ModelCIF(modelcif.model.AbInitioModel):
                 occupancy=atm.occupancy,
             )
 
-    def add_scores(self, scores_json, entry_id, file_prefix, sw_dct):
+    def add_scores(self, scores_json, entry_id, file_prefix, sw_dct, add_files):
         """Add QA metrics"""
         _GlobalPLDDT.software = sw_dct["AlphaFold"]
         _GlobalPTM.software = sw_dct["AlphaFold"]
@@ -221,6 +230,9 @@ class _Biopython2ModelCIF(modelcif.model.AbInitioModel):
             )
         ]
 
+        if add_files:
+            arc_files.extend([x[1] for x in add_files.values()])
+
         return modelcif.associated.Repository(
             "",
             [
@@ -348,7 +360,7 @@ def _get_modelcif_protocol(
 
 
 def _cast_release_date(release_date):
-    """Type cast a date into datetime.date"""
+    """Type cast a date into `datetime.date`"""
     # "AF2" has a special meaning, those DBs did not change since the first
     # release of AF2. This information is needed in the model-producing
     # pipeline.
@@ -414,7 +426,8 @@ def _store_as_modelcif(
     mdl_file: str,
     out_dir: str,
     compress: bool = False,
-    # file_prfx, add_files
+    add_files: list = None,
+    # file_prfx
 ) -> None:
     """Create the actual ModelCIF file."""
     system = modelcif.System(
@@ -451,7 +464,7 @@ def _store_as_modelcif(
     # process scores
     mdl_file = os.path.splitext(os.path.basename(mdl_file))[0]
     system.repositories.append(
-        model.add_scores(data_json, system.id, mdl_file, sw_dct)
+        model.add_scores(data_json, system.id, mdl_file, sw_dct, add_files)
     )
 
     system.model_groups.append(modelcif.model.ModelGroup([model]))
@@ -493,33 +506,78 @@ def _store_as_modelcif(
     # -> hence we cheat by changing path and back while being exception-safe...
     oldpwd = os.getcwd()
     os.chdir(out_dir)
+    created_files = {}
     try:
+        mdl_file = f"{mdl_file}.cif"
         with open(
-            f"{mdl_file}.cif",
+            mdl_file,
             "w",
             encoding="ascii",
         ) as mmcif_fh:
             modelcif.dumper.write(mmcif_fh, [system])
         if compress:
-            _compress_cif_file(f"{mdl_file}.cif")
+            mdl_file = _compress_cif_file(mdl_file)
+        created_files[mdl_file] = (
+            os.path.join(out_dir, mdl_file),
+            _get_assoc_mdl_file(mdl_file, data_json),
+        )
         # Create associated archive
         for archive in system.repositories[0].files:
             with zipfile.ZipFile(
                 archive.path, "w", zipfile.ZIP_BZIP2
             ) as cif_zip:
                 for zfile in archive.files:
-                    cif_zip.write(zfile.path, arcname=zfile.path)
-                    os.remove(zfile.path)
+                    try:
+                        # Regardless of error, fall back to `zfile.path`, the
+                        # other path is only needed as a special case.
+                        # pylint: disable=bare-except
+                        sys_path = add_files[zfile.path][0]
+                    except:
+                        sys_path = zfile.path
+                    cif_zip.write(sys_path, arcname=zfile.path)
+                    os.remove(sys_path)
+            created_files[archive.path] = (
+                os.path.join(out_dir, archive.path),
+                _get_assoc_zip_file(archive.path, data_json),
+            )
     finally:
         os.chdir(oldpwd)
 
+    return created_files
+
+
+def _get_assoc_mdl_file(fle_path, data_json):
+    """Generate a `modelcif.associated.File` object that looks like a CIF
+    file."""
+    cfile = modelcif.associated.File(
+        fle_path,
+        details=data_json["_ma_model_list.model_name"],
+    )
+    cfile.file_format = "cif"
+    return cfile
+
+
+def _get_assoc_zip_file(fle_path, data_json):
+    """Create a `modelcif.associated.File` object that looks like a ZIP file.
+    This is NOT the PAE archive ZIP file itself, but an entry used to store
+    that archive inside the ZIP archive of the selected model."""
+    zfile = modelcif.associated.File(
+        fle_path,
+        details="archive with multiple files for "
+        + data_json["_ma_model_list.model_name"],
+    )
+    zfile.file_format = "other"
+    return zfile
+
 
 def _compress_cif_file(cif_file):
     """Compress CIF file and delete original."""
+    cif_gz_file = cif_file + ".gz"
     with open(cif_file, "rb") as f_in:
-        with gzip.open(cif_file + ".gz", "wb") as f_out:
+        with gzip.open(cif_gz_file, "wb") as f_out:
             shutil.copyfileobj(f_in, f_out)
     os.remove(cif_file)
+    return cif_gz_file
 
 
 def _get_model_details(cmplx_name: str, data_json: dict) -> str:
@@ -855,6 +913,7 @@ def alphapulldown_model_to_modelcif(
     prj_dir: str,
     monomer_objects_dir: list,
     compress: bool = False,
+    additional_assoc_files: list = None,
 ) -> None:
     """Convert an AlphaPulldown model into a ModelCIF formatted mmCIF file.
 
@@ -881,12 +940,36 @@ def alphapulldown_model_to_modelcif(
     _get_scores(modelcif_json, mdl[1])
 
     modelcif_json["ma_protocol_step"] = _get_protocol_steps(modelcif_json)
-
-    _store_as_modelcif(modelcif_json, structure, mdl[0], out_dir, compress)
+    cfs = _store_as_modelcif(
+        modelcif_json,
+        structure,
+        mdl[0],
+        out_dir,
+        compress,
+        additional_assoc_files,
+    )
     # ToDo: ENABLE logging.info(f"... done with '{mdl[0]}'")
+    return cfs
 
 
-def _get_model_list(ap_dir: str, model_selected: str) -> Tuple[str, str, list]:
+def _add_mdl_to_list(mdl, model_list, mdl_path, score_files):
+    """Fetch info from file name to add to list"""
+    rank = re.match(r"ranked_(\d+)\.pdb", mdl)
+    if rank is not None:
+        rank = int(rank.group(1))
+        model_list.append(
+            (
+                os.path.join(mdl_path, mdl),
+                score_files[rank][0],
+                score_files[rank][1],  # model ID
+                score_files[rank][2],  # model rank
+            )
+        )
+
+
+def _get_model_list(
+    ap_dir: str, model_selected: str, get_non_selected: bool
+) -> Tuple[str, str, list, list]:
     """Get the list of models to be converted.
 
     If `model_selected` is none, all models will be marked for conversion."""
@@ -924,28 +1007,25 @@ def _get_model_list(ap_dir: str, model_selected: str) -> Tuple[str, str, list]:
             i,
         )
     # match PDB files with pickle files
+    not_selected_models = []
     if model_selected is not None:
-        models.append(
-            (
-                os.path.join(mdl_path, f"ranked_{model_selected}.pdb"),
-                score_files[model_selected][0],
-                score_files[model_selected][1],  # model ID
-                score_files[model_selected][2],  # model rank
-            )
+        if model_selected not in score_files:
+            logging.info(f"Model of rank {model_selected} not found.")
+            sys.exit()
+        _add_mdl_to_list(
+            f"ranked_{model_selected}.pdb", models, mdl_path, score_files
         )
+
+        if get_non_selected:
+            for mdl in os.listdir(mdl_path):
+                if mdl == f"ranked_{model_selected}.pdb":
+                    continue
+                _add_mdl_to_list(
+                    mdl, not_selected_models, mdl_path, score_files
+                )
     else:
         for mdl in os.listdir(mdl_path):
-            rank = re.match(r"ranked_(\d+)\.pdb", mdl)
-            if rank is not None:
-                rank = int(rank.group(1))
-                models.append(
-                    (
-                        os.path.join(mdl_path, mdl),
-                        score_files[rank][0],
-                        score_files[rank][1],  # model ID
-                        score_files[rank][2],  # model rank
-                    )
-                )
+            _add_mdl_to_list(mdl, models, mdl_path, score_files)
 
     # check that files actually exist
     for mdl, scrs, *_ in models:
@@ -960,7 +1040,7 @@ def _get_model_list(ap_dir: str, model_selected: str) -> Tuple[str, str, list]:
             )
             sys.exit()
 
-    return cmplx, mdl_path, models
+    return cmplx, mdl_path, models, not_selected_models
 
 
 def main(argv):
@@ -986,9 +1066,26 @@ def main(argv):
     del argv  # Unused.
 
     # get list of selected models and assemble ModelCIF files + associated data
-    complex_name, model_dir, model_list = _get_model_list(
-        FLAGS.ap_output, FLAGS.model_selected
+    complex_name, model_dir, model_list, not_selected = _get_model_list(
+        FLAGS.ap_output,
+        FLAGS.model_selected,
+        FLAGS.add_associated,
     )
+    add_assoc_files = {}
+    if len(not_selected) > 0:
+        # pylint: disable=consider-using-with
+        ns_tmpdir = tempfile.TemporaryDirectory(suffix="_modelcif")
+        for mdl in not_selected:
+            add_assoc_files.update(
+                alphapulldown_model_to_modelcif(
+                    complex_name,
+                    mdl,
+                    ns_tmpdir.name,
+                    FLAGS.ap_output,
+                    FLAGS.monomer_objects_dir,
+                    FLAGS.compress,
+                )
+            )
     for mdl in model_list:
         alphapulldown_model_to_modelcif(
             complex_name,
@@ -997,6 +1094,7 @@ def main(argv):
             FLAGS.ap_output,
             FLAGS.monomer_objects_dir,
             FLAGS.compress,
+            add_assoc_files,
         )
 
 
-- 
GitLab