From b1058e54e0af6d5cb196297d51332642d5ecaa5a Mon Sep 17 00:00:00 2001 From: Gerardo Tauriello <gerardo.tauriello@unibas.ch> Date: Mon, 30 Oct 2023 11:45:31 +0100 Subject: [PATCH] SCHWED-6036: Implement full conversion Needed additions/changes: - Handle PAE matrix input - Allow for storage of all models - Add options to keep MSA and PAE only for top ranked model - Relax comparison with web PDB - Simplified folder structure for input data - Added primary citation for models - Minor changes in metadata and descriptions --- .../translate2modelcif.py | 549 ++++++++++++++---- 1 file changed, 422 insertions(+), 127 deletions(-) diff --git a/projects/dark-matter-metagenomics/translate2modelcif.py b/projects/dark-matter-metagenomics/translate2modelcif.py index 68455db..e7f5102 100644 --- a/projects/dark-matter-metagenomics/translate2modelcif.py +++ b/projects/dark-matter-metagenomics/translate2modelcif.py @@ -12,6 +12,7 @@ import sys import zipfile import pandas as pd +import numpy as np import ihm import ihm.citations @@ -22,15 +23,138 @@ import modelcif.model import modelcif.protocol import modelcif.reference -from ost import io +from ost import io, geom -# EXAMPLES for running: -# ost translate2modelcif.py ./raw_data ./raw_data/all_ptm_plddt.txt \ -# ./web_dloads/pivot ./modelcif --prefix=F000347 \ +# EXAMPLE for running: +# ost translate2modelcif.py ./raw_data/all_ptm_plddt.txt ./raw_data/all_pdb \ +# ./all_pae ./web_dloads/pivot ./modelcif --prefix=F000347 \ # --pdb-web-path=./web_dloads/pdb \ # --refseq-path=./web_dloads/consensus_all.fasta -# NOTE: add "--compress" for final runs +# NOTE: add "--compress --all-models --all-pae" for final runs + + +################################################################################ +# HELPERS (mostly copied from from ost/modules/io/tests/test_io_omf.py) +# to compare PDBs +def _compare_atoms(a1, a2, occupancy_thresh=0.01, bfactor_thresh=0.01, + dist_thresh=0.001): + if abs(a1.occupancy - a2.occupancy) > occupancy_thresh: + return False + if abs(a1.b_factor - a2.b_factor) > bfactor_thresh: + return False + # modification: look at x,y,z spearately + if abs(a1.pos.x - a2.pos.x) > dist_thresh: + return False + if abs(a1.pos.y - a2.pos.y) > dist_thresh: + return False + if abs(a1.pos.z - a2.pos.z) > dist_thresh: + return False + if a1.is_hetatom != a2.is_hetatom: + return False + if a1.element != a2.element: + return False + return True + +def _compare_residues(r1, r2, at_occupancy_thresh=0.01, + at_bfactor_thresh=0.01, at_dist_thresh=0.001, + skip_ss=False, skip_rnums=False): + if r1.GetName() != r2.GetName(): + return False + if skip_rnums is False: + if r1.GetNumber() != r2.GetNumber(): + return False + if skip_ss is False: + if str(r1.GetSecStructure()) != str(r2.GetSecStructure()): + return False + if r1.one_letter_code != r2.one_letter_code: + return False + if r1.chem_type != r2.chem_type: + return False + if r1.chem_class != r2.chem_class: + return False + anames1 = [a.GetName() for a in r1.atoms] + anames2 = [a.GetName() for a in r2.atoms] + if sorted(anames1) != sorted(anames2): + return False + anames = anames1 + for aname in anames: + a1 = r1.FindAtom(aname) + a2 = r2.FindAtom(aname) + if not _compare_atoms(a1, a2, + occupancy_thresh=at_occupancy_thresh, + bfactor_thresh=at_bfactor_thresh, + dist_thresh=at_dist_thresh): + return False + return True + +def _compare_chains(ch1, ch2, at_occupancy_thresh=0.01, + at_bfactor_thresh=0.01, at_dist_thresh=0.001, + skip_ss=False, skip_rnums=False): + if len(ch1.residues) != len(ch2.residues): + 
return False + for r1, r2 in zip(ch1.residues, ch2.residues): + if not _compare_residues(r1, r2, + at_occupancy_thresh=at_occupancy_thresh, + at_bfactor_thresh=at_bfactor_thresh, + at_dist_thresh=at_dist_thresh, + skip_ss=skip_ss, skip_rnums=skip_rnums): + return False + return True + +def _compare_bonds(ent1, ent2): + bonds1 = list() + for b in ent1.bonds: + bond_partners = [str(b.first), str(b.second)] + bonds1.append([min(bond_partners), max(bond_partners), b.bond_order]) + bonds2 = list() + for b in ent2.bonds: + bond_partners = [str(b.first), str(b.second)] + bonds2.append([min(bond_partners), max(bond_partners), b.bond_order]) + return sorted(bonds1) == sorted(bonds2) + +def _compare_ent(ent1, ent2, at_occupancy_thresh=0.01, + at_bfactor_thresh=0.01, at_dist_thresh=0.001, + skip_ss=False, skip_cnames=False, skip_bonds=False, + skip_rnums=False, bu_idx=None): + if bu_idx is not None: + if ent1.GetName() + ' ' + str(bu_idx) != ent2.GetName(): + return False + else: + if ent1.GetName() != ent2.GetName(): + return False + chain_names_one = [ch.GetName() for ch in ent1.chains] + chain_names_two = [ch.GetName() for ch in ent2.chains] + if skip_cnames: + # only check whether we have the same number of chains + if len(chain_names_one) != len(chain_names_two): + return False + else: + if chain_names_one != chain_names_two: + return False + for ch1, ch2 in zip(ent1.chains, ent2.chains): + if not _compare_chains(ch1, ch2, + at_occupancy_thresh=at_occupancy_thresh, + at_bfactor_thresh=at_bfactor_thresh, + at_dist_thresh=at_dist_thresh, + skip_ss=skip_ss, skip_rnums=skip_rnums): + return False + if not skip_bonds: + if not _compare_bonds(ent1, ent2): + return False + return True + +def _compare_pdbs(f1, f2): + """Use atom-by-atom comparison on PDB files allowing num. errors.""" + # first do simple file diff. + if filecmp.cmp(f1, f2): + return True + else: + ent1 = io.LoadPDB(f1) + ent2 = io.LoadPDB(f2) + # allow a bit more errors as input files can have rounding errors + return _compare_ent(ent1, ent2, 0.011, 0.011, 0.0011, True, False, True) +################################################################################ def _abort_msg(msg, exit_code=1): @@ -76,16 +200,25 @@ def _parse_args(): ) parser.add_argument( - "model_base_dir", + "metadata_file", type=str, - metavar="<MODEL BASE DIR>", - help="Directory with pub_data* directories with model PDBs.", + metavar="<METADATA FILE>", + help='Path to table with metadata. Excpected columns: "ID" (Family ID ' + + 'F...), "mdl" (AF2-model-name), "pTM", "pLDDT".', ) parser.add_argument( - "metadata_file", + "model_dir", type=str, - metavar="<METADATA FILE>", - help="Path to table with metadata.", + metavar="<MODEL DIR>", + help='Directory with model PDBs named "{ID}_{mdl}.pdb" (with ID and ' + + 'mdl matching info in metadata).', + ) + parser.add_argument( + "pae_dir", + type=str, + metavar="<PAE DIR>", + help='Directory with PAE text files named "{ID}_{mdl}.txt.gz" (with ID ' + + 'and mdl matching info in metadata).', ) parser.add_argument( "msa_data_dir", @@ -107,12 +240,31 @@ def _parse_args(): help="Only process families starting with given prefix. 
By default " + "all families are processed.", ) + parser.add_argument( + "--all-models", + default=False, + action="store_true", + help="Process all 5 models for each family (top ranked as main, others " + + "as accompanying data).", + ) + parser.add_argument( + "--all-pae", + default=False, + action="store_true", + help="Store PAE for all models instead of just top ranked one.", + ) + parser.add_argument( + "--all-msa", + default=False, + action="store_true", + help="Store MSA with all models instead of just once.", + ) parser.add_argument( "--compress", default=False, action="store_true", help="Compress ModelCIF file with gzip " - "(note that acc. data is zipped either way).", + + "(note that acc. data is zipped either way).", ) parser.add_argument( "--pdb-web-path", @@ -133,8 +285,9 @@ def _parse_args(): opts = parser.parse_args() # check input - opts.model_base_dir = _check_opts_folder(opts.model_base_dir) _check_file(opts.metadata_file) + opts.model_dir = _check_opts_folder(opts.model_dir) + opts.pae_dir = _check_opts_folder(opts.pae_dir) opts.msa_data_dir = _check_opts_folder(opts.msa_data_dir) # check out_dir if opts.out_dir.endswith("/"): @@ -173,8 +326,15 @@ class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT): software = None +class _LocalPairwisePAE(modelcif.qa_metric.LocalPairwise, modelcif.qa_metric.PAE): + """Predicted aligned error (in Angstroms)""" + + name = "PAE" + software = None + + class _NmpfamsdbTrgRef(modelcif.reference.TargetReference): - """PDB as target reference.""" + """NMPFamsDB as target reference.""" name = "Other" other_details = "NMPFamsDB" @@ -220,6 +380,8 @@ class _OST2ModelCIF(modelcif.model.AbInitioModel): # use auth IDs for res. nums and chain names self.use_auth = False + # what accuracy to use for PAE? (writer uses 3 anyway) + self.pae_digits = 3 # fetch plddts per residue self.plddts = [] @@ -259,20 +421,39 @@ class _OST2ModelCIF(modelcif.model.AbInitioModel): ) # local scores + lpae = [] i = 0 for chn_i in self.ost_entity.chains: ch_name = _get_ch_name(chn_i, self.use_auth) for res_i in chn_i.residues: # local pLDDT - res_num = _get_res_num(res_i, self.use_auth) + res_num_i = _get_res_num(res_i, self.use_auth) self.qa_metrics.append( _LocalPLDDT( - self.asym[ch_name].residue(res_num), + self.asym[ch_name].residue(res_num_i), self.plddts[i], ) ) i += 1 + # PAE needs to go by residue index as it also stores ones + # for missing residues (i.e. X) + if "pae" in self.scores_json: + pae_i = self.scores_json["pae"][res_num_i - 1] + for chn_j in self.ost_entity.chains: + for res_j in chn_j.residues: + res_num_j = _get_res_num(res_j, self.use_auth) + pae_ij = pae_i[res_num_j - 1] + lpae.append( + _LocalPairwisePAE( + self.asym[chn_i.name].residue(res_num_i), + self.asym[chn_j.name].residue(res_num_j), + round(pae_ij, self.pae_digits), + ) + ) + + self.qa_metrics.extend(lpae) + def _get_audit_authors(): """Return the list of authors that produced this model.""" @@ -312,40 +493,22 @@ def _get_metadata(metadata_file): return metadata -def _get_pdb_files(model_base_dir, model_dir_prfx="all_pdb"): - """Collect PDB files from pub_data_* folders. - - model_dir_prfx was "pub_data" for Sergey's old data. +def _get_pdb_files(model_dir): + """Collect PDB files from model_dir. Returns dict with key = family name and value = list of paths to PDB files. 
""" pdb_files_split = {} # to return - pdb_files_raw = set() # to check for duplicates - pub_paths = [ - f for f in os.listdir(model_base_dir) if f.startswith(model_dir_prfx) + pdb_files = [ + f for f in os.listdir(model_dir) if not f.startswith(".") ] - # NOTE: we sort pub_paths to ensure that pub_data_02 is before _03 - for pub_path in sorted(pub_paths): - sub_path = os.path.join(model_base_dir, pub_path) - pdb_files_new = [ - f for f in os.listdir(sub_path) if not f.startswith(".") - ] - for f in pdb_files_new: - f_path = os.path.join(sub_path, f) - f_name = f.split("_")[0] - if f_name in pdb_files_split: - pdb_files_split[f_name].append(f_path) - else: - pdb_files_split[f_name] = [f_path] - # check global list - pdb_files_new_set = set(pdb_files_new) - new_duplicates = pdb_files_raw.intersection(pdb_files_new_set) - if new_duplicates: - _warn_msg( - f"{len(new_duplicates)} duplicated files found in " - f"{sub_path}." - ) - pdb_files_raw = pdb_files_raw.union(pdb_files_new_set) + for f in pdb_files: + f_path = os.path.join(model_dir, f) + f_name = f.split("_")[0] + if f_name in pdb_files_split: + pdb_files_split[f_name].append(f_path) + else: + pdb_files_split[f_name] = [f_path] return pdb_files_split @@ -358,7 +521,7 @@ def _get_config(): ) mdl_description = ( "Model generated using AlphaFold (v2.0.0 with models " - "finetuned to return ptm weights) producing 5 models, " + "fine-tuned to return pTM weights) producing 5 models, " "without model relaxation, without templates, ranked " "by pLDDT, starting from a custom MSA." ) @@ -471,16 +634,15 @@ def _get_title(fam_name): return f"AlphaFold model for NMPFamsDB Family {fam_name}" -def _get_model_details(fam_name): +def _get_model_details(fam_name, max_pLDDT, max_pTM): """Get the model description.""" db_url = f"https://bib.fleming.gr/NMPFamsDB/family?id={fam_name}" - # TODO: check if ok to use HTML for the URL - db_url = f'<a href="{db_url}" target="_blank">{db_url}</a>' return ( - f"Model generated using AlphaFold (v2.0.0) for the " - f'"Representative Sequence" of NMPFamsDB Metagenome / ' - f"Metatranscriptome Family {fam_name}.\n\nSee {db_url} for " - f"additional details." + f'Model generated using AlphaFold (v2.0.0) for the "Representative ' + f'Sequence" of NMPFamsDB Metagenome / Metatranscriptome Family ' + f"{fam_name}.\n\nThe 5 produced models reached a max. global pLDDT of " + f"{round(max_pLDDT, 3)} and max. pTM of {round(max_pTM, 3)}.\n\n" + f"See {db_url} for additional details." ) @@ -523,10 +685,8 @@ def _get_entities(pdb_file, ref_seq, fam_name): if exp_seq != ref_seq.string: raise RuntimeError(f"Sequence in {pdb_file} does not match ref_seq") - # TODO: waiting for input on whether they change handling of "X" in seq - # -> HERE: assuming that "X" were in modelled sequence and PDB can have gaps cif_ent = { - "seqres": ref_seq.string, # HACK for testing: .replace('X', 'A') + "seqres": ref_seq.string, "pdb_sequence": sqe_gaps, "pdb_chain_id": [_get_ch_name(chn, False)], "fam_name": fam_name, @@ -565,6 +725,29 @@ def _get_modelcif_entities(target_ents, asym_units, system): system.target_entities.append(mdlcif_ent) +def _get_assoc_pae_file(entry_id, mdl_name): + """Generate a associated file object to extract PAE to extra file.""" + return modelcif.associated.LocalPairwiseQAScoresFile( + f"{mdl_name}_local_pairwise_qa.cif", + categories=["_ma_qa_metric_local_pairwise"], + copy_categories=["_ma_qa_metric"], + entry_id=entry_id, + entry_details="This file is an associated file consisting " + + "of local pairwise QA metrics. 
This is a partial mmCIF " + + "file and can be validated by merging with the main " + + "mmCIF file containing the model coordinates and other " + + "associated data.", + details="Predicted aligned error", + ) + + +def _get_aln_data(): + """Generate Data object for ALN.""" + aln_data = modelcif.data.Data("Custom MSA for modelling") + aln_data.data_content_type = "coevolution MSA" + return aln_data + + def _get_assoc_aln_file(fle_path): """Generate a modelcif.associated.File object pointing to FASTA formatted file containing MSA. @@ -572,7 +755,7 @@ def _get_assoc_aln_file(fle_path): cfile = modelcif.associated.File( fle_path, details="Custom MSA for modelling", - data=modelcif.data.Data("Custom MSA for modelling"), + data=_get_aln_data(), ) cfile.file_format = "fasta" cfile.file_content = "multiple sequence alignments" @@ -685,7 +868,36 @@ def _package_associated_files(repo): os.remove(zfile.path) -def _store_as_modelcif(data_json, ost_ent, out_dir, mdl_name, compress): +def _get_assoc_mdl_file(fle_path, data_json): + """Generate a modelcif.associated.File object that looks like a CIF file. + The dedicated CIFFile functionality in modelcif would also try to write it. + """ + cfile = modelcif.associated.File( + fle_path, + details=f"#{data_json['mdl_rank']} ranked model; " + + f"pTM {round(data_json['ptm'], 3)}, " + + f"pLDDT {round(data_json['plddt_global'], 3)}", + ) + cfile.file_format = "cif" + return cfile + + +def _get_assoc_zip_file(fle_path, data_json): + """Create a modelcif.associated.File object that looks like a ZIP file. + This is NOT the archive ZIP file for the PAEs but to store that in the + ZIP archive of the selected model.""" + zfile = modelcif.associated.File( + fle_path, + details="archive with multiple files for " + + f"#{data_json['mdl_rank']} ranked model", + ) + zfile.file_format = "other" + return zfile + + +def _store_as_modelcif( + data_json, ost_ent, out_dir, mdl_name, compress, add_pae, add_aln, add_files +): """Mix all the data into a ModelCIF file.""" print(" generating ModelCIF objects...", end="") pstart = timer() @@ -696,6 +908,28 @@ def _store_as_modelcif(data_json, ost_ent, out_dir, mdl_name, compress): model_details=data_json["model_details"], ) + # add primary citation (not using from_pubmed_id to ensure that author names + # have no special chars) + system.citations.append(ihm.Citation( + pmid="37821698", + title="Unraveling the functional dark matter through global " + + "metagenomics.", + journal="Nature", + volume=622, + page_range=(594, 602), + year=2023, + authors=[ + 'Pavlopoulos, G.A.', 'Baltoumas, F.A.', 'Liu, S.', 'Selvitopi, O.', + 'Camargo, A.P.', 'Nayfach, S.', 'Azad, A.', 'Roux, S.', 'Call, L.', + 'Ivanova, N.N.', 'Chen, I.M.', 'Paez-Espino, D.', 'Karatzas, E.', + 'Iliopoulos, I.', 'Konstantinidis, K.', 'Tiedje, J.M.', + 'Pett-Ridge, J.', 'Baker, D.', 'Visel, A.', 'Ouzounis, C.A.', + 'Ovchinnikov, S.', 'Buluc, A.', 'Kyrpides, N.C.' 
+ ], + doi="10.1038/s41586-023-06583-7", + is_primary=True, + )) + # create an asymmetric unit and an entity per target sequence asym_units = {} _get_modelcif_entities(data_json["target_entities"], asym_units, system) @@ -723,14 +957,25 @@ def _store_as_modelcif(data_json, ost_ent, out_dir, mdl_name, compress): system.model_groups.append(model_group) # handle additional files - aln_file = _get_assoc_aln_file(data_json["aln_file_name"]) - system.repositories.append(_get_associated_files(mdl_name, [aln_file])) + arc_files = [] + if add_pae: + arc_files.append(_get_assoc_pae_file(system.id, mdl_name)) + if add_aln: + aln_file = _get_assoc_aln_file(data_json["aln_file_name"]) + arc_files.append(aln_file) + aln_data = aln_file.data + else: + aln_data = _get_aln_data() + aln_data.data_other_details = "MSA stored with parent entry" + arc_files.extend(add_files) + if arc_files: + system.repositories.append(_get_associated_files(mdl_name, arc_files)) # get data and steps protocol = _get_modelcif_protocol( data_json["protocol"], system.target_entities, - aln_file.data, + aln_data, model, ) system.protocols.append(protocol) @@ -739,10 +984,11 @@ def _store_as_modelcif(data_json, ost_ent, out_dir, mdl_name, compress): print(" write to disk...", end="", flush=True) pstart = timer() # copy aln file to compress them - shutil.copyfile( - data_json["aln_file_path"], - os.path.join(out_dir, data_json["aln_file_name"]), - ) + if add_aln: + shutil.copyfile( + data_json["aln_file_path"], + os.path.join(out_dir, data_json["aln_file_name"]), + ) # NOTE: we change path and back while being exception-safe to handle zipfile oldpwd = os.getcwd() os.chdir(out_dir) @@ -750,60 +996,34 @@ def _store_as_modelcif(data_json, ost_ent, out_dir, mdl_name, compress): try: with open(mdl_fle, "w", encoding="ascii") as mmcif_fh: modelcif.dumper.write(mmcif_fh, [system]) - _package_associated_files(system.repositories[0]) + if arc_files: + _package_associated_files(system.repositories[0]) if compress: _compress_cif_file(mdl_fle) mdl_fle += ".gz" finally: os.chdir(oldpwd) print(f" ({timer()-pstart:.2f}s)") - - -def _translate2modelcif(f_name, opts, metadata_fam, pdb_files, ref_seq_check): - """Convert a model with its accompanying data to ModelCIF.""" - - # TODO: unclear what to do with such cases; skipped for now - if len(metadata_fam) != 5: - _warn_msg( - f"Unexpected number of {len(metadata_fam)} models in " - f"metadata for family {f_name}. Skipping..." + assoc_files = [_get_assoc_mdl_file(mdl_fle, data_json)] + if arc_files: + assoc_files.append( + _get_assoc_zip_file(system.repositories[0].files[0].path, data_json) ) - return - # + return assoc_files + +def _translate2modelcif_single( + f_name, opts, metadata, pdb_files, mdl_rank, aln_file, aln_path, ref_seq, + mdl_details, add_files=[] +): + """Convert a single model with its accompanying data to ModelCIF.""" mdl_id = f_name - # skip if done already - if opts.compress: - cifext = "cif.gz" - else: - cifext = "cif" - mdl_path = os.path.join(opts.out_dir, f"{mdl_id}.{cifext}") - if os.path.exists(mdl_path): - print(f" {mdl_id} already done...") - return + if mdl_rank > 1: + mdl_id += f"_rank_{mdl_rank}_{metadata.mdl}" - # go for it... print(f" translating {mdl_id}...") pdb_start = timer() - # get aln_data and ref. seq. for this entry - aln_file = f"{f_name}.fasta" - aln_path = os.path.join(opts.msa_data_dir, aln_file) - # TODO: if we need to handle files without ALN, this needs fixing - # -> e.g. 
11 extra models in pub_data_* cannot be handled right now - if not os.path.exists(aln_path): - _warn_msg( - f"Cannot deal with missing MSA for {f_name} (yet). " f"Skipping..." - ) - return - - aln = io.LoadAlignment( - aln_path - ) # note: this checks that it's an actual MSA - ref_seq = aln.sequences[0] - if ref_seq_check is not None and ref_seq_check.string != ref_seq.string: - raise RuntimeError(f"Sequence mismatch for {f_name}") - # gather data into JSON-like structure print(" preparing data...", end="") pstart = timer() @@ -813,50 +1033,68 @@ def _translate2modelcif(f_name, opts, metadata_fam, pdb_files, ref_seq_check): mdlcf_json["audit_authors"] = _get_audit_authors() mdlcf_json["protocol"] = _get_protocol_steps_and_software(config_data) mdlcf_json["config_data"] = config_data - mdlcf_json["mdl_id"] = mdl_id + mdlcf_json["mdl_id"] = mdl_id # used for entry ID + mdlcf_json["mdl_rank"] = mdl_rank mdlcf_json["aln_file_name"] = aln_file mdlcf_json["aln_file_path"] = aln_path # find model to process - # TODO: here just top pLDDT model processed; extend for more if needed... - top_metadata = metadata_fam.loc[metadata_fam.pLDDT.idxmax()] - pdb_list_sel = [f for f in pdb_files if top_metadata.mdl in f] + pdb_list_sel = [f for f in pdb_files if metadata.mdl in f] if len(pdb_list_sel) != 1: - # this should only happen if duplicated file in pub_data_* - # TODO: for now no warning shown and we just pick first hit - # -> first hit is from lowest "pub_data_*" - # -> unclear if we should worry about it... - pass - mdlcf_json["mdl_name"] = f"Top ranked model ({top_metadata.mdl})" - - # get scores for this entry - mdlcf_json["plddt_global"] = top_metadata.pLDDT - mdlcf_json["ptm"] = top_metadata.pTM + # this should never happen + raise RuntimeError( + f"Multiple file matches found for {metadata.mdl} in {f_name}" + ) + if mdl_rank == 1: + mdlcf_json["mdl_name"] = f"Top ranked model ({metadata.mdl})" + else: + mdlcf_json["mdl_name"] = f"#{mdl_rank} ranked model ({metadata.mdl})" # process coordinates pdb_file = pdb_list_sel[0] target_entities, ost_ent = _get_entities(pdb_file, ref_seq, f_name) mdlcf_json["target_entities"] = target_entities # sanity check (only for top ranked model!) 
- if opts.pdb_web_path is not None: + if mdl_rank == 1 and opts.pdb_web_path is not None: pdb_file_web = os.path.join(opts.pdb_web_path, f"{f_name}.pdb") - if not filecmp.cmp(pdb_file, pdb_file_web): - raise RuntimeError( + if not _compare_pdbs(pdb_file, pdb_file_web): + # for now just as warning...(TODO: CHECK) + _warn_msg( f"PDB file mismatch web vs top-ranked for " f"{f_name}" ) + # get scores for this entry + mdlcf_json["plddt_global"] = metadata.pLDDT + mdlcf_json["ptm"] = metadata.pTM + add_pae = (mdl_rank == 1 or opts.all_pae) + if add_pae: + pdb_basename = os.path.basename(pdb_file) + pae_basename = os.path.splitext(pdb_basename)[0] + ".txt.gz" + pae_file = os.path.join(opts.pae_dir, pae_basename) + _check_file(pae_file) + mdlcf_json["pae"] = np.loadtxt(pae_file) + exp_num_res = len(ref_seq.string) + if mdlcf_json["pae"].shape != (exp_num_res, exp_num_res): + raise RuntimeError(f"Unexpected PAE shape in {pae_file}") + + # fill annotations mdlcf_json["title"] = _get_title(f_name) - mdlcf_json["model_details"] = _get_model_details(f_name) + if mdl_rank != 1: + mdlcf_json["title"] += f" (#{mdl_rank} ranked model)" + mdlcf_json["model_details"] = mdl_details mdlcf_json["model_group_name"] = _get_model_group_name() print(f" ({timer()-pstart:.2f}s)") # save ModelCIF - _store_as_modelcif( - mdlcf_json, ost_ent, opts.out_dir, mdl_id, opts.compress + add_aln = (mdl_rank == 1 or opts.all_msa) + assoc_files = _store_as_modelcif( + mdlcf_json, ost_ent, opts.out_dir, mdl_id, opts.compress, add_pae, + add_aln, add_files ) # check if result can be read and has expected seq. + mdl_path = os.path.join(opts.out_dir, assoc_files[0].path) ent, ss = io.LoadMMCIF(mdl_path, seqres=True) exp_seqs = [ trg_ent["pdb_sequence"] for trg_ent in mdlcf_json["target_entities"] @@ -877,6 +1115,63 @@ def _translate2modelcif(f_name, opts, metadata_fam, pdb_files, ref_seq_check): print(f" ... done with {mdl_id} ({timer()-pdb_start:.2f}s).") + return assoc_files + + +def _translate2modelcif(f_name, opts, metadata_fam, pdb_files, ref_seq_check): + """Convert a family of models with their accompanying data to ModelCIF.""" + + # expected to have exactly 5 models per family + if len(metadata_fam) != 5: + raise RuntimeError( + f"Unexpected number of {len(metadata_fam)} models in " + f"metadata for family {f_name}." + ) + + # skip if done already + if opts.compress: + cifext = "cif.gz" + else: + cifext = "cif" + mdl_path = os.path.join(opts.out_dir, f"{f_name}.{cifext}") + if os.path.exists(mdl_path): + print(f" {f_name} already done...") + return + + # get aln_data and ref. seq. for this entry + aln_file = f"{f_name}.fasta" + aln_path = os.path.join(opts.msa_data_dir, aln_file) + # expected 11 extra families compared to web data but those don't have MSAs + # -> skipped for consistency and to keep code simple here + if not os.path.exists(aln_path): + _warn_msg(f"Missing MSA for {f_name}. 
Skipping...") + return + + aln = io.LoadAlignment(aln_path) # note: this checks that it's an actual MSA + ref_seq = aln.sequences[0] + if ref_seq_check is not None and ref_seq_check.string != ref_seq.string: + raise RuntimeError(f"Sequence mismatch for {f_name}") + + # get global model details + mdl_details = _get_model_details( + f_name, metadata_fam.pLDDT.max(), metadata_fam.pTM.max() + ) + # rank available models + metadata_sorted = metadata_fam.sort_values("pLDDT", ascending=False) + add_files = [] + if opts.all_models: + for idx in range(1, 5): + assoc_files = _translate2modelcif_single( + f_name, opts, metadata_sorted.iloc[idx], pdb_files, idx + 1, + aln_file, aln_path, ref_seq, mdl_details + ) + add_files.extend(assoc_files) + # process top ranked one + _translate2modelcif_single( + f_name, opts, metadata_sorted.iloc[0], pdb_files, 1, + aln_file, aln_path, ref_seq, mdl_details, add_files + ) + def _main(): """Run as script.""" @@ -884,7 +1179,7 @@ def _main(): # parse/fetch global data metadata_full = _get_metadata(opts.metadata_file) - pdb_files_split = _get_pdb_files(opts.model_base_dir) + pdb_files_split = _get_pdb_files(opts.model_dir) if opts.refseq_path is not None: refseqs = io.LoadSequenceList(opts.refseq_path) else: -- GitLab
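
A minimal sketch of the metadata input this converter expects (the help text above asks for a table with columns "ID", "mdl", "pTM" and "pLDDT", five rows per family, ranked by pLDDT). The whitespace-separated format and the helper name below are illustrative assumptions; the actual parsing is done by _get_metadata() and the ranking by _translate2modelcif():

    # Illustrative sketch only: group the metadata table per family and rank
    # the five models by pLDDT, mirroring what _translate2modelcif() expects.
    # The whitespace-separated table format is an assumption.
    import pandas as pd

    def rank_family_models(metadata_file, fam_id):
        """Return the five models of one family, highest pLDDT first."""
        metadata = pd.read_csv(metadata_file, sep=r"\s+")  # assumed separator
        fam = metadata[metadata["ID"] == fam_id]
        if len(fam) != 5:
            raise RuntimeError(f"Expected 5 models for {fam_id}, got {len(fam)}")
        # the top ranked model becomes the main ModelCIF entry; the other four
        # are stored as accompanying data when --all-models is given
        return fam.sort_values("pLDDT", ascending=False)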
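
The PAE inputs follow the same naming scheme: each "{ID}_{mdl}.txt.gz" holds a square N x N matrix of predicted aligned errors in Angstroms, with N equal to the length of the family's representative sequence (including "X" positions). A minimal sketch of the shape check applied in _translate2modelcif_single(), assuming whitespace-separated values (np.loadtxt decompresses ".gz" files directly):

    # Illustrative sketch only: load one PAE matrix and verify its shape
    # against the reference sequence length.
    import numpy as np

    def check_pae_file(pae_file, num_residues):
        """Load a PAE matrix and check it is square and matches the sequence."""
        pae = np.loadtxt(pae_file)  # handles .gz transparently
        if pae.shape != (num_residues, num_residues):
            raise RuntimeError(f"Unexpected PAE shape in {pae_file}")
        return pae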