From d4835b37300ed2b4f97399c7f96435d496e39d0b Mon Sep 17 00:00:00 2001
From: Stefan Bienert <stefan.bienert@unibas.ch>
Date: Tue, 5 Dec 2023 14:12:45 +0100
Subject: [PATCH] Add early protocol

---
 projects/novelfams/translate2modelcif.py | 117 +++++++++++++++++++++--
 1 file changed, 109 insertions(+), 8 deletions(-)

diff --git a/projects/novelfams/translate2modelcif.py b/projects/novelfams/translate2modelcif.py
index d4392f0..49e1e72 100644
--- a/projects/novelfams/translate2modelcif.py
+++ b/projects/novelfams/translate2modelcif.py
@@ -80,6 +80,13 @@ def _parse_args():
         metavar="<OUTPUT DIR>",
         help="Path to directory to store results.",
     )
+    parser.add_argument(
+        "--af2-models",
+        default=None,
+        type=str,
+        metavar="<LIST FILE>",
+        help="Path to a txt file with models build with AF2, 1 ID per line.",
+    )
     parser.add_argument(
         "--compress",
         default=False,
@@ -415,12 +422,10 @@ def _get_modelcif_protocol_software(js_step):
     return None
 
 
-def _get_modelcif_protocol_data(data_label, target_entities, aln_data, model):
+def _get_modelcif_protocol_data(data_label, target_entities, model):
     """Assemble data for a ModelCIF protocol step."""
     if data_label == "target_sequences":
         data = modelcif.data.DataGroup(target_entities)
-    elif data_label == "MSA":
-        data = aln_data
     elif data_label == "target_sequences_and_MSA":
         data = modelcif.data.DataGroup(target_entities)
         data.append(aln_data)
@@ -431,16 +436,16 @@ def _get_modelcif_protocol_data(data_label, target_entities, aln_data, model):
     return data
 
 
-def _get_modelcif_protocol(protocol_steps, target_entities, aln_data, model):
+def _get_modelcif_protocol(protocol_steps, target_entities, model):
     """Create the protocol for the ModelCIF file."""
     protocol = modelcif.protocol.Protocol()
     for js_step in protocol_steps:
         sftwre = _get_modelcif_protocol_software(js_step)
         input_data = _get_modelcif_protocol_data(
-            js_step["input"], target_entities, aln_data, model
+            js_step["input"], target_entities, model
         )
         output_data = _get_modelcif_protocol_data(
-            js_step["output"], target_entities, aln_data, model
+            js_step["output"], target_entities, model
         )
 
         protocol.steps.append(
@@ -518,6 +523,11 @@ def _store_as_modelcif(
     model_group = modelcif.model.ModelGroup([model])
     system.model_groups.append(model_group)
 
+    protocol = _get_modelcif_protocol(
+        data_json["protocol"], system.target_entities, model
+    )
+    system.protocols.append(protocol)
+
     # write modelcif System to file (NOTE: no PAE here!)
     # NOTE: we change path and back while being exception-safe to handle zipfile
     oldpwd = os.getcwd()
@@ -533,10 +543,75 @@ def _store_as_modelcif(
         os.chdir(oldpwd)
 
 
+def _get_protocol_steps_and_software_colabfold(config_data):
+    """Get protocol steps for ColabFold models."""
+    protocol = []
+
+    # modelling step
+    step = {
+        "method_type": "modeling",
+        "name": None,
+        "details": config_data["description"],
+    }
+    # get input data
+    # Must refer to data already in the JSON, so we try keywords
+    step["input"] = "target_sequences"
+    # get output data
+    # Must refer to existing data, so we try keywords
+    step["output"] = "model"
+    # get software
+    step["software"] = [
+        {
+            "name": "ColabFold",
+            "classification": "model building",
+            "description": "Structure prediction",
+            "citation": ihm.citations.colabfold,
+            "location": "https://github.com/sokrypton/ColabFold",
+            "type": "package",
+            "version": None,
+        }
+    ]
+    step["software"].append(
+        {
+            "name": "AlphaFold",
+            "classification": "model building",
+            "description": "Structure prediction",
+            "citation": ihm.citations.alphafold2,
+            "location": "https://github.com/deepmind/alphafold",
+            "type": "package",
+            "version": None,
+        }
+    )
+    step["software_parameters"] = None
+    protocol.append(step)
+
+    return protocol
+
+
+def _get_config_colabfold():
+    """Get config variables for ColabFold"""
+    description = "Model generation using ColabFold."
+
+    return {"description": description}
+
+
+def _get_protocol_steps_and_software(mdl_id, af2_lst):
+    """Get protocol steps for this model, make a difference between AF2 and
+    ColabFold models."""
+    if mdl_id in af2_lst:
+        protocol = _get_protocol_steps_and_software_alphafold()
+    else:
+        config_data = _get_config_colabfold()
+        protocol = _get_protocol_steps_and_software_colabfold(config_data)
+
+    return protocol
+
+
 def _translate2modelcif_single(
     f_name,
     opts,
     mdl_details,
+    af2_lst,
 ):
     """Convert a single model with its accompanying data to ModelCIF."""
     # ToDo: re-enable Pylint
@@ -546,13 +621,16 @@ def _translate2modelcif_single(
     # gather data into JSON-like structure
     mdlcf_json = {}
     mdlcf_json["mdl_id"] = fam_name  # used for entry ID
+    mdlcf_json["protocol"] = _get_protocol_steps_and_software(
+        fam_name, af2_lst
+    )
 
     # process coordinates
     target_entities, ost_ent = _get_entities(f_name, fam_name)
     mdlcf_json["target_entities"] = target_entities
 
     # fill annotations
-    mdlcf_json["title"] = _get_title(f_name)
+    mdlcf_json["title"] = _get_title(fam_name)
     mdlcf_json["model_details"] = mdl_details
 
     # save ModelCIF
@@ -565,7 +643,7 @@ def _translate2modelcif_single(
     )
 
 
-def _translate2modelcif(f_name, opts):
+def _translate2modelcif(f_name, af2_lst, opts):
     """Convert a family of models with their accompanying data to ModelCIF."""
     # ToDo: re-enable Pylint
     # pylint: disable=too-many-locals
@@ -590,9 +668,28 @@ def _translate2modelcif(f_name, opts):
         f_name,
         opts,
         mdl_details,
+        af2_lst,
     )
 
 
+def _read_af2_model_list(path):
+    """Read a list of models build with AF2. One ID per line. Returns an empty
+    list if path is None."""
+    af2_lst = []
+
+    if path is None:
+        return af2_lst
+
+    with open(path, encoding="ascii") as lfh:
+        for line in lfh:
+            line = line.strip()
+            af2_lst.append(line)
+
+    print(f"Got a list of {len(af2_lst)} models built with AF2.")
+
+    return af2_lst
+
+
 def _main():
     """Run as script."""
     s_tmstmp = timer()
@@ -602,6 +699,9 @@ def _main():
     pdb_files = _get_pdb_files(opts.model_dir)
     n_mdls = len(pdb_files)
 
+    # read list of AF2 models
+    af2_mdls = _read_af2_model_list(opts.af2_models)
+
     # iterate over models
     print(f"Processing {n_mdls} models.")
     tmstmp = s_tmstmp
@@ -610,6 +710,7 @@ def _main():
         try:
             _translate2modelcif(
                 f_name,
+                af2_mdls,
                 opts,
             )
         except (_InvalidCoordinateError, _NoEntitiesError):
-- 
GitLab