Skip to content
Snippets Groups Projects
translate2modelcif.py 35.61 KiB
#! /usr/local/bin/ost
"""Translate models from Tara/ Xabi from PDB + extra data into ModelCIF."""

import argparse
import datetime
import gzip
import os
import shutil
import sys
import zipfile

from timeit import default_timer as timer
import numpy as np
import requests
import ujson as json

import ihm
import ihm.citations
import modelcif
import modelcif.associated
import modelcif.dumper
import modelcif.model
import modelcif.protocol
import modelcif.reference

from ost import io


# EXAMPLES for running:
# ost scripts/translate2modelcif.py --selected_rank 1 --out_dir="./modelcif" \
# "A0A1B0GTU1-O75152"


def _parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )

    parser.add_argument(
        "model_dir",
        type=str,
        metavar="<MODEL DIR>",
        help="Directory with model(s) to be translated. Must be of form "
        + "'<UniProtKB AC>-<UniProtKB AC>'",
    )
    parser.add_argument(
        "--selected_rank",
        type=str,
        default=None,
        help="If a certain model of a modelling project is selected by rank, "
        + "the other models are still translated to ModelCIF but stored as "
        + "accompanying files to the selected model.",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        metavar="<OUTPUT DIR>",
        default="",
        help="Path to separate path to store results "
        "(<MODEL DIR> used, if none given).",
    )
    parser.add_argument(
        "--compress",
        default=False,
        action="store_true",
        help="Compress ModelCIF file with gzip "
        "(note that QA file is zipped either way).",
    )
    opts = parser.parse_args()

    # check that model dir exists
    if opts.model_dir.endswith("/"):
        opts.model_dir = opts.model_dir[:-1]
    if not os.path.exists(opts.model_dir):
        _abort_msg(f"Model directory '{opts.model_dir}' does not exist.")
    if not os.path.isdir(opts.model_dir):
        _abort_msg(f"Path '{opts.model_dir}' does not point to a directory.")
    # check out_dir
    if not opts.out_dir:
        opts.out_dir = opts.model_dir
    else:
        if not os.path.exists(opts.out_dir):
            _abort_msg(f"Output directory '{opts.out_dir}' does not exist.")
        if not os.path.isdir(opts.out_dir):
            _abort_msg(f"Path '{opts.out_dir}' does not point to a directory.")

    return opts


# pylint: disable=too-few-public-methods
class _GlobalPTM(modelcif.qa_metric.Global, modelcif.qa_metric.PTM):
    """Predicted accuracy according to the TM-score score in [0,1]"""

    # Whole-model pTM; _OST2ModelCIF.add_scores fills it from the "ptm"
    # entry of the scores JSON.
    # NOTE(review): modelcif presumably uses the class docstring as the
    # metric description written to the file, so its wording is left as-is
    # here -- confirm against modelcif.qa_metric before editing it.
    name = "pTM"
    # no dedicated software is recorded for this metric
    software = None


class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT):
    """Predicted accuracy according to the CA-only lDDT in [0,100]"""

    # Whole-model pLDDT; _OST2ModelCIF.add_scores stores the mean of all
    # per-residue "plddt" values of the scores JSON here.
    # NOTE(review): modelcif presumably uses the class docstring as the
    # metric description written to the file -- keep its text unchanged
    # unless confirmed otherwise.
    name = "pLDDT"
    # no dedicated software is recorded for this metric
    software = None


class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
    """Predicted accuracy according to the CA-only lDDT in [0,100]"""

    # Per-residue pLDDT; one instance per residue is created in
    # _OST2ModelCIF.add_scores from the "plddt" list of the scores JSON.
    # NOTE(review): modelcif presumably uses the class docstring as the
    # metric description written to the file -- keep its text unchanged
    # unless confirmed otherwise.
    name = "pLDDT"
    # no dedicated software is recorded for this metric
    software = None


class _PAE(modelcif.qa_metric.MetricType):
    """Predicted aligned error (in Angstroms)"""

    # Custom metric type used as mix-in by _LocalPairwisePAE.
    # NOTE(review): the docstring likely ends up as type description in the
    # output -- keep its text unchanged unless confirmed otherwise.
    type = "PAE"
    # no further details for this metric type
    other_details = None


class _LocalPairwisePAE(modelcif.qa_metric.LocalPairwise, _PAE):
    """Predicted aligned error (in Angstroms)"""

    # PAE between two residues; _OST2ModelCIF.add_scores creates one
    # instance per residue pair from the "pae" matrix of the scores JSON.
    # NOTE(review): modelcif presumably uses the class docstring as the
    # metric description written to the file -- keep its text unchanged
    # unless confirmed otherwise.
    name = "PAE"
    # no dedicated software is recorded for this metric
    software = None


# pylint: enable=too-few-public-methods


class _OST2ModelCIF(modelcif.model.AbInitioModel):
    """Map OST entity elements to ihm.model"""

    def __init__(self, *args, **kwargs):
        """Initialise a model.

        Two extra keyword arguments are required and removed before all
        remaining arguments are passed on to modelcif.model.AbInitioModel:
        ``ost_entity`` (the OST entity providing chains, residues and atoms)
        and ``asym`` (dictionary mapping OST chain names to ModelCIF
        asymmetric units).
        """
        self.ost_entity = kwargs.pop("ost_entity")
        self.asym = kwargs.pop("asym")

        super().__init__(*args, **kwargs)

    def get_atoms(self):
        """Yield one modelcif.model.Atom per atom of the OST entity."""
        for atm in self.ost_entity.atoms:
            yield modelcif.model.Atom(
                asym_unit=self.asym[atm.chain.name],
                seq_id=atm.residue.number.num,
                atom_id=atm.name,
                type_symbol=atm.element,
                x=atm.pos[0],
                y=atm.pos[1],
                z=atm.pos[2],
                het=atm.is_hetatom,
                biso=atm.b_factor,
                occupancy=atm.occupancy,
            )

    def add_scores(self, scores_json, entry_id, mdl_name, add_files):
        """Add QA metrics from AF2 scores.

        Appends to ``self.qa_metrics``, in this order: the global pLDDT
        (mean of all per-residue values), the global pTM, one local pLDDT
        per residue and one PAE value per residue pair. Per-residue scores
        are looked up by the position of a residue when iterating all chains
        and their residues of the OST entity in order.

        :param scores_json: dictionary with "plddt" (per-residue list),
            "ptm" (single value) and "pae" (indexed as pae[i][j] over
            residue pairs).
        :param entry_id: ID of the main entry, stored in the associated
            local pairwise QA file.
        :param mdl_name: base name used for the associated file and archive.
        :param add_files: further associated files to put into the same ZIP
            archive (may be empty or None).
        :returns: modelcif.associated.Repository with one ZIP archive.
        """
        plddt = scores_json["plddt"]

        # global scores
        self.qa_metrics.extend(
            (
                _GlobalPLDDT(np.mean(plddt)),
                _GlobalPTM(scores_json["ptm"]),
            )
        )

        # Residues in the order the per-residue scores refer to. Built once
        # and shared by all local metrics instead of creating a new residue
        # object for each of the n*n pairwise entries.
        residues = [
            self.asym[chn.name].residue(res.number.num)
            for chn in self.ost_entity.chains
            for res in chn.residues
        ]

        # local pLDDT
        self.qa_metrics.extend(
            _LocalPLDDT(res, plddt[i]) for i, res in enumerate(residues)
        )

        # pairwise alignment error
        pae = scores_json["pae"]
        self.qa_metrics.extend(
            _LocalPairwisePAE(res_i, res_j, pae[i][j])
            for i, res_i in enumerate(residues)
            for j, res_j in enumerate(residues)
        )

        ac_file = f"{mdl_name}_local_pairwise_qa.cif"
        arc_files = [
            modelcif.associated.LocalPairwiseQAScoresFile(
                ac_file,
                categories=["_ma_qa_metric_local_pairwise"],
                copy_categories=["_ma_qa_metric"],
                entry_id=entry_id,
                entry_details="This file is an associated file consisting "
                + "of local pairwise QA metrics. This is a partial mmCIF "
                + "file and can be validated by merging with the main "
                + "mmCIF file containing the model coordinates and other "
                + "associated data.",
                details="Predicted aligned error",
            )
        ]
        if add_files:
            arc_files.extend(add_files)

        # NOTE: by convention MA expects zip file with same name as model-cif
        return modelcif.associated.Repository(
            "",
            [modelcif.associated.ZipFile(f"{mdl_name}.zip", files=arc_files)],
        )


def _abort_msg(msg, exit_code=1):
    """Write error message and exit with exit_code."""
    print(f"{msg}\nAborting.", file=sys.stderr)
    sys.exit(exit_code)


def _check_file(file_path):
    """Make sure a file exists and is actually a file."""
    if not os.path.exists(file_path):
        _abort_msg(f"File not found: '{file_path}'.")
    if not os.path.isfile(file_path):
        _abort_msg(f"File path does not point to file: '{file_path}'.")


def _check_interaction_extra_files_present(model_dir):
    """Ensure the ColabFold 'config.json' of an interaction is present.

    :param model_dir: directory holding the models of one interaction.
    :returns: path to 'config.json' in *model_dir* (aborts via _check_file
        if it is missing).
    """
    config_path = os.path.join(model_dir, "config.json")
    _check_file(config_path)
    return config_path


def _check_model_extra_files_present(model_dir, pdb_file):
    """Ensure the '<name>_scores.json' belonging to *pdb_file* exists.

    :param model_dir: directory holding *pdb_file*.
    :param pdb_file: file name of the model (an extension is stripped).
    :returns: tuple (path prefix '<model_dir>/<name>', '<name>'); aborts via
        _check_file if the scores file is missing.
    """
    stem = os.path.splitext(pdb_file)[0]
    path_prefix = os.path.join(model_dir, stem)
    _check_file(f"{path_prefix}_scores.json")
    return path_prefix, stem


def _get_audit_authors():
    """Return the list of authors that produced this model."""
    return (
        "Bartolec, T.K.",
        "Vazquez-Campos, X.",
        "Norman, A.",
        "Luong, C.",
        "Payne, R.J.",
        "Wilkins, M.R.",
        "Mackay, J.P.",
        "Low, J.K.K.",
    )


def _parse_colabfold_config(cnfg_file):
    """Read config.json and fetch relevant data from it."""
    # NOTE: following code from
    # https://github.com/sokrypton/ColabFold/blob/main/colabfold/batch.py to
    # understand config

    # fetch and drop fields which are not relevant for model building
    with open(cnfg_file, encoding="utf8") as jfh:
        cf_config = json.load(jfh)
    if "num_queries" in cf_config:
        del cf_config["num_queries"]
    # fetch relevant data
    # -> MSA mode
    if cf_config["msa_mode"] == "MMseqs2 (UniRef+Environmental)":
        seq_dbs = ["UniRef", "Environmental"]
        use_mmseqs = True
        use_msa = True
    elif cf_config["msa_mode"] == "MMseqs2 (UniRef only)":
        seq_dbs = ["UniRef"]
        use_mmseqs = True
        use_msa = True
    elif cf_config["msa_mode"] == "single_sequence":
        seq_dbs = []
        use_mmseqs = False
        use_msa = False
    elif cf_config["msa_mode"] == "custom":
        print(
            "WARNING: Custom MSA mode used. Not clear from config what to do "
            + "here!"
        )
        seq_dbs = []
        use_mmseqs = False
        use_msa = True
    else:
        raise ValueError(f"Unknown msa_mode {cf_config['msa_mode']}")
    # -> model type
    if cf_config["model_type"] == "AlphaFold2-multimer-v1":
        # AF-Multimer as introduced in AlphaFold v2.1.0
        use_multimer = True
        multimer_version = 1
    elif cf_config["model_type"] == "AlphaFold2-multimer-v2":
        # AF-Multimer as introduced in AlphaFold v2.2.0
        use_multimer = True
        multimer_version = 2
    elif cf_config["model_type"] == "AlphaFold2-ptm":
        use_multimer = False
        multimer_version = None
    else:
        raise ValueError(f"Unknown model_type {cf_config['model_type']}")

    # write modeling description
    mdl_description = f"Model generated using ColabFold v{cf_config['version']}"
    if use_multimer:
        mdl_description += f" with AlphaFold-Multimer (v{multimer_version})"
    else:
        mdl_description += " with AlphaFold"
    if cf_config["stop_at_score"] < 100:
        # early stopping feature of ColabFold
        upto = "up to "
    else:
        upto = ""
    mdl_description += (
        f" producing {upto}{cf_config['num_models']} models"
        f" with {upto}{cf_config['num_recycles']} recycles each"
    )
    if cf_config["use_amber"]:
        mdl_description += ", with AMBER relaxation"
    else:
        mdl_description += ", without model relaxation"
    if cf_config["use_templates"]:
        print(
            "WARNING: ColabFold may use PDB70 or custom templates. "
            "Not clear from config!"
        )
        mdl_description += ", using templates"
    else:
        mdl_description += ", without templates"
    if cf_config["rank_by"] == "plddt":
        mdl_description += ", ranked by pLDDT"
    elif cf_config["rank_by"] == "ptmscore":
        mdl_description += ", ranked by pTM"
    elif cf_config["rank_by"] == "multimer":
        mdl_description += ", ranked by ipTM*0.8+pTM*0.2"
    else:
        raise ValueError(f"Unknown rank_by {cf_config['rank_by']}")
    if use_msa:
        mdl_description += ", starting from"
        if use_mmseqs:
            msa_type = "MSA"
        else:
            msa_type = "custom MSA"
        if use_multimer:
            if cf_config["pair_mode"] == "unpaired+paired":
                mdl_description += f" paired and unpaired {msa_type}s"
            elif cf_config["pair_mode"] == "paired":
                mdl_description += f" paired {msa_type}s"
            elif cf_config["pair_mode"] == "unpaired":
                mdl_description += f" unpaired {msa_type}s"
            else:
                raise ValueError(f"Unknown pair_mode {cf_config['pair_mode']}")
        else:
            mdl_description += f" an {msa_type}"
        if use_mmseqs:
            mdl_description += f" from MMseqs2 ({'+'.join(seq_dbs)})"
    else:
        mdl_description += " without an MSA"
    mdl_description += "."

    # write selection description
    slct_description = (
        "Select best model, which is either the top-ranked model as "
        + "determined by the ColabFold pipeline (ipTM*0.8+pTM*0.2), or else "
        + "the model with best congruence with crosslinks reported in the "
        + "related study."
    )
    return {
        "config": cf_config,
        "seq_dbs": seq_dbs,
        "use_mmseqs": use_mmseqs,
        "use_msa": use_msa,
        "use_multimer": use_multimer,
        "multimer_version": multimer_version,
        "modeling_description": mdl_description,
        "selection_description": slct_description,
    }


def _get_protocol_steps_and_software(config_data):
    """Describe the modelling protocol as a list of step dictionaries.

    The first entry is the modelling step. Its software list always starts
    with ColabFold, adds MMseqs2 if config_data["use_mmseqs"] is set and
    ends with AlphaFold-Multimer or AlphaFold depending on
    config_data["use_multimer"]. The full ColabFold config is stored as
    software parameters. A model selection step without software or
    parameters follows only if config_data holds a non-empty
    "selection_description".
    """
    modeling_details = config_data["modeling_description"]

    software = [
        {
            "name": "ColabFold",
            "classification": "model building",
            "description": "Structure prediction",
            "citation": ihm.citations.colabfold,
            "location": "https://github.com/sokrypton/ColabFold",
            "type": "package",
            "version": "1.2.0",
        }
    ]
    if config_data["use_mmseqs"]:
        software.append(
            {
                "name": "MMseqs2",
                "classification": "data collection",
                "description": "Many-against-Many sequence searching",
                "citation": ihm.citations.mmseqs2,
                "location": "https://github.com/soedinglab/mmseqs2",
                "type": "package",
                "version": None,
            }
        )
    if not config_data["use_multimer"]:
        software.append(
            {
                "name": "AlphaFold",
                "classification": "model building",
                "description": "Structure prediction",
                "citation": ihm.citations.alphafold2,
                "location": "https://github.com/deepmind/alphafold",
                "type": "package",
                "version": None,
            }
        )
    else:
        af_multimer_citation = ihm.Citation(
            pmid=None,
            title="Protein complex prediction with AlphaFold-Multimer.",
            journal="bioRxiv",
            volume=None,
            page_range=None,
            year=2021,
            authors=[
                "Evans, R.",
                "O'Neill, M.",
                "Pritzel, A.",
                "Antropova, N.",
                "Senior, A.",
                "Green, T.",
                "Zidek, A.",
                "Bates, R.",
                "Blackwell, S.",
                "Yim, J.",
                "Ronneberger, O.",
                "Bodenstein, S.",
                "Zielinski, M.",
                "Bridgland, A.",
                "Potapenko, A.",
                "Cowie, A.",
                "Tunyasuvunakool, K.",
                "Jain, R.",
                "Clancy, E.",
                "Kohli, P.",
                "Jumper, J.",
                "Hassabis, D.",
            ],
            doi="10.1101/2021.10.04.463034",
        )
        software.append(
            {
                "name": "AlphaFold-Multimer",
                "classification": "model building",
                "description": "Structure prediction",
                "citation": af_multimer_citation,
                "location": "https://github.com/deepmind/alphafold",
                "type": "package",
                "version": None,
            }
        )

    protocol = [
        {
            "method_type": "modeling",
            "name": None,
            "details": modeling_details,
            # input/output must refer to data already in the JSON, so
            # keywords are used
            "input": "target_sequences",
            "output": "model",
            "software": software,
            "software_parameters": config_data["config"],
        }
    ]

    # model selection step, only if there is something to describe
    if len(config_data.get("selection_description", "")) == 0:
        return protocol
    protocol.append(
        {
            "method_type": "model selection",
            "name": None,
            "details": config_data["selection_description"],
            "input": "model",
            "output": "model",
            "software": [],
            "software_parameters": {},
        }
    )
    return protocol


def _get_title(gene_names):
    """Get a title for this modelling experiment."""
    return f"Predicted interaction between {' and '.join(gene_names)}"


def _get_model_details(gene_names):
    """Get the model description."""
    return (
        f"Dimer model generated for {' and '.join(gene_names)}, produced "
        + "using AlphaFold-Multimer (AlphaFold v2.2.0) as implemented by "
        + "ColabFold (v1.2.0) which uses MMseqs2 for MSA generation (UniRef30 "
        + "+ Environmental)."
    )


def _get_model_group_name():
    """Get a name for a model group."""
    return "Crosslinked Heterodimer AlphaFold-Multimer v2 Models"


def _get_sequence(chn):
    """Get the sequence out of an OST chain."""
    # initialise
    lst_rn = chn.residues[0].number.num
    idx = 1
    sqe = chn.residues[0].one_letter_code
    if lst_rn != 1:
        sqe = "-"
        idx = 0

    for res in chn.residues[idx:]:
        lst_rn += 1
        while lst_rn != res.number.num:
            sqe += "-"
            lst_rn += 1
        sqe += res.one_letter_code
    return sqe


def _check_sequence(up_ac, sequence):
    """Verify sequence to only contain standard olc."""
    for res in sequence:
        if res not in "ACDEFGHIKLMNPQRSTVWY":
            raise RuntimeError(
                "Non-standard aa found in UniProtKB sequence "
                + f"for entry '{up_ac}': {res}"
            )


def _fetch_upkb_entry(up_ac):
    """Fetch data for an UniProtKB entry."""
    # This is a simple parser for UniProtKB txt format, instead of breaking it
    # up into multiple functions, we just allow many many branches & statements,
    # here.
    # pylint: disable=too-many-branches,too-many-statements
    data = {}
    data["up_organism"] = ""
    data["up_sequence"] = ""
    data["up_ac"] = up_ac
    rspns = requests.get(
        f"https://www.uniprot.org/uniprot/{up_ac}.txt", timeout=180
    )
    for line in rspns.iter_lines(decode_unicode=True):
        if line.startswith("ID   "):
            sline = line.split()
            if len(sline) != 5:
                _abort_msg(f"Unusual UniProtKB ID line found:\n'{line}'")
            data["up_id"] = sline[1]
        elif line.startswith("OX   NCBI_TaxID="):
            # Following strictly the UniProtKB format: 'OX   NCBI_TaxID=<ID>;'
            data["up_ncbi_taxid"] = line[len("OX   NCBI_TaxID=") : -1]
            data["up_ncbi_taxid"] = data["up_ncbi_taxid"].split("{")[0].strip()
        elif line.startswith("OS   "):
            if line[-1] == ".":
                data["up_organism"] += line[len("OS   ") : -1]
            else:
                data["up_organism"] += line[len("OS   ") : -1] + " "
        elif line.startswith("SQ   "):
            sline = line.split()
            if len(sline) != 8:
                _abort_msg(f"Unusual UniProtKB SQ line found:\n'{line}'")
            data["up_seqlen"] = int(sline[2])
            data["up_crc64"] = sline[6]
        elif line.startswith("     "):
            sline = line.split()
            if len(sline) > 6:
                _abort_msg(
                    "Unusual UniProtKB sequence data line "
                    + f"found:\n'{line}'"
                )
            data["up_sequence"] += "".join(sline)
        elif line.startswith("RP   "):
            if "ISOFORM" in line.upper():
                RuntimeError(
                    f"First ISOFORM found for '{up_ac}', needs " + "handling."
                )
        elif line.startswith("DT   "):
            # 2012-10-03
            dt_flds = line[len("DT   ") :].split(", ")
            if dt_flds[1].upper().startswith("SEQUENCE VERSION "):
                data["up_last_mod"] = datetime.datetime.strptime(
                    dt_flds[0], "%d-%b-%Y"
                )
        elif line.startswith("GN   Name="):
            data["up_gn"] = line[len("GN   Name=") :].split(";")[0]
            data["up_gn"] = data["up_gn"].split("{")[0].strip()

    # we have not seen isoforms in the data set, yet, so we just set them to '.'
    data["up_isoform"] = None

    if "up_gn" not in data:
        _abort_msg(f"No gene name found for UniProtKB entry '{up_ac}'.")
    if "up_last_mod" not in data:
        _abort_msg(f"No sequence version found for UniProtKB entry '{up_ac}'.")
    if "up_crc64" not in data:
        _abort_msg(f"No CRC64 value found for UniProtKB entry '{up_ac}'.")
    if len(data["up_sequence"]) == 0:
        _abort_msg(f"No sequence found for UniProtKB entry '{up_ac}'.")
    # check that sequence length and CRC64 is correct
    if data["up_seqlen"] != len(data["up_sequence"]):
        _abort_msg(
            "Sequence length of SQ line and sequence data differ for "
            + f"UniProtKB entry '{up_ac}': {data['up_seqlen']} != "
            + f"{len(data['up_sequence'])}"
        )
    _check_sequence(data["up_ac"], data["up_sequence"])

    if "up_id" not in data:
        _abort_msg(f"No ID found for UniProtKB entry '{up_ac}'.")
    if "up_ncbi_taxid" not in data:
        _abort_msg(f"No NCBI taxonomy ID found for UniProtKB entry '{up_ac}'.")
    if len(data["up_organism"]) == 0:
        _abort_msg(f"No organism species found for UniProtKB entry '{up_ac}'.")
    return data


def _get_upkb_for_sequence(sqe, up_ac):
    """Fetch UniProtKB data for *up_ac* and require an exact sequence match.

    :raises RuntimeError: if *sqe* differs from the UniProtKB sequence.
    """
    up_data = _fetch_upkb_entry(up_ac)
    upkb_sqe = up_data["up_sequence"]
    if sqe == upkb_sqe:
        return up_data
    raise RuntimeError(
        f"Sequences not equal from file: {sqe}, from UniProtKB: "
        + f"{upkb_sqe}"
    )


def _get_entities(pdb_file, up_acs):
    """Load *pdb_file* with OST and describe every chain as target entity.

    The i-th chain is matched to the i-th accession in *up_acs*, and its
    sequence must equal the UniProtKB sequence.

    :returns: tuple (list of entity dictionaries, OST entity).
    """
    ost_ent = io.LoadPDB(pdb_file)
    entities = []
    for idx, chn in enumerate(ost_ent.chains):
        sqe = _get_sequence(chn)
        upkb = _get_upkb_for_sequence(sqe, up_acs[idx])
        cif_ent = {
            "pdb_sequence": sqe,
            "pdb_chain_id": chn.name,
            "description": (
                f"{upkb['up_organism']} {upkb['up_gn']} ({upkb['up_ac']})"
            ),
        }
        # all UniProtKB fields are carried along for the references
        cif_ent.update(upkb)
        entities.append(cif_ent)
    return entities, ost_ent


def _get_scores(data, prfx):
    """Check that all files needed to process this model are present."""
    scrs_fle = f"{prfx}_scores.json"
    with open(scrs_fle, encoding="utf8") as jfh:
        scrs_json = json.load(jfh)

    # NOTE for reuse of data when iterating multiple models: this will overwrite
    # scores in data but will not delete any scores if prev. models had more...
    data.update(scrs_json)


def _get_modelcif_entities(target_ents, source, asym_units, system):
    """Create ModelCIF entities and asymmetric units.

    For every target entity dictionary, a modelcif.Entity with a UniProt
    reference is created and appended to system.target_entities. An
    AsymUnit wrapping it is stored in *asym_units* under the PDB chain ID.
    """
    for ent_data in target_ents:
        # Careful: alignments are just that easy because of equal sequences!
        upkb_ref = modelcif.reference.UniProt(
            ent_data["up_id"],
            ent_data["up_ac"],
            align_begin=1,
            align_end=ent_data["up_seqlen"],
            isoform=ent_data["up_isoform"],
            ncbi_taxonomy_id=ent_data["up_ncbi_taxid"],
            organism_scientific=ent_data["up_organism"],
            sequence_version_date=ent_data["up_last_mod"],
            sequence_crc64=ent_data["up_crc64"],
        )
        entity = modelcif.Entity(
            ent_data["pdb_sequence"],
            description=ent_data["description"],
            source=source,
            references=[upkb_ref],
        )
        asym_units[ent_data["pdb_chain_id"]] = modelcif.AsymUnit(entity)
        system.target_entities.append(entity)


def _assemble_modelcif_software(soft_dict):
    """Turn a software description dictionary into a modelcif.Software."""
    positional_keys = (
        "name",
        "classification",
        "description",
        "location",
        "type",
        "version",
    )
    positional = [soft_dict[key] for key in positional_keys]
    return modelcif.Software(*positional, citation=soft_dict["citation"])


def _get_sequence_dbs(seq_dbs):
    """Get ColabFold seq. DBs.

    Maps every key of *seq_dbs* ("UniRef" or "Environmental") to a
    modelcif.ReferenceDatabase, keeping the order. Unknown keys raise
    KeyError.
    """
    # NOTE: hard coded for ColabFold versions before 2022/07/13
    # -> afterwards UniRef30 updated to 2022_02 (and maybe more changes)
    uniref30 = modelcif.ReferenceDatabase(
        "UniRef30",
        "http://wwwuser.gwdg.de/~compbiol/colabfold/uniref30_2103.tar.gz",
        version="2021_03",
    )
    colabfold_db = modelcif.ReferenceDatabase(
        "ColabFold DB",
        "http://wwwuser.gwdg.de/~compbiol/colabfold/"
        + "colabfold_envdb_202108.tar.gz",
        version="2021_08",
    )
    known_dbs = {"UniRef": uniref30, "Environmental": colabfold_db}
    selected = []
    for db_key in seq_dbs:
        selected.append(known_dbs[db_key])
    return selected


def _get_modelcif_protocol_software(js_step):
    """Assemble software entries for a ModelCIF protocol step.

    :returns: None if the step lists no software. A single modelcif.Software
        for exactly one entry without parameters. Otherwise a
        modelcif.SoftwareGroup, which also carries the parameters, if any.
    """
    soft_list = js_step["software"]
    if not soft_list:
        return None

    if len(soft_list) == 1:
        sftwre = _assemble_modelcif_software(soft_list[0])
    else:
        sftwre = modelcif.SoftwareGroup(
            elements=[_assemble_modelcif_software(sft) for sft in soft_list]
        )

    if not js_step["software_parameters"]:
        return sftwre

    params = [
        modelcif.SoftwareParameter(key, val)
        for key, val in js_step["software_parameters"].items()
    ]
    if isinstance(sftwre, modelcif.SoftwareGroup):
        sftwre.parameters = params
        return sftwre
    # parameters can only be attached to a group -> wrap the single software
    return modelcif.SoftwareGroup(elements=(sftwre,), parameters=params)


def _get_modelcif_protocol_input(js_step, target_entities, ref_dbs, model):
    """Assemble input data for a ModelCIF protocol step."""
    if js_step["input"] == "target_sequences":
        input_data = modelcif.data.DataGroup(target_entities)
        input_data.extend(ref_dbs)
    elif js_step["input"] == "model":
        input_data = model
    else:
        raise RuntimeError(f"Unknown protocol input: '{js_step['input']}'")
    return input_data


def _get_modelcif_protocol_output(js_step, model):
    """Assemble output data for a ModelCIF protocol step."""
    if js_step["output"] == "model":
        output_data = model
    else:
        raise RuntimeError(f"Unknown protocol output: '{js_step['output']}'")
    return output_data


def _get_modelcif_protocol(protocol_steps, target_entities, model, ref_dbs):
    """Create the protocol for the ModelCIF file.

    Every step dictionary of *protocol_steps* becomes a
    modelcif.protocol.Step with software, input and output resolved, and
    with the step's method_type set.
    """
    protocol = modelcif.protocol.Protocol()
    for js_step in protocol_steps:
        software = _get_modelcif_protocol_software(js_step)
        step = modelcif.protocol.Step(
            input_data=_get_modelcif_protocol_input(
                js_step, target_entities, ref_dbs, model
            ),
            output_data=_get_modelcif_protocol_output(js_step, model),
            name=js_step["name"],
            details=js_step["details"],
            software=software,
        )
        step.method_type = js_step["method_type"]
        protocol.steps.append(step)
    return protocol


def _compress_cif_file(cif_file):
    """Compress cif file and delete original."""
    with open(cif_file, "rb") as f_in:
        with gzip.open(cif_file + ".gz", "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
    os.remove(cif_file)


def _package_associated_files(repo):
    """Compress associated files into single zip file and delete original."""
    # zip settings tested for good speed vs compression
    for archive in repo.files:
        with zipfile.ZipFile(archive.path, "w", zipfile.ZIP_BZIP2) as cif_zip:
            for zfile in archive.files:
                cif_zip.write(zfile.path, arcname=zfile.path)
                os.remove(zfile.path)


def _store_as_modelcif(
    data_json, ost_ent, out_dir, file_prfx, compress, add_files
):
    """Mix all the data into a ModelCIF file.

    Builds the modelcif.System from *data_json* and *ost_ent*: entities,
    authors, the model with its QA scores, the model group and the protocol.
    It then writes '<name>.cif' (gzipped to '<name>.cif.gz' if *compress* is
    set) and the ZIP archive of associated files into *out_dir*. <name> is
    the base name of *file_prfx*. *add_files* are put into that ZIP archive.

    :returns: tuple (model file, ZIP archive) as modelcif.associated.File
        objects whose paths are relative to *out_dir*.
    """
    print("    generating ModelCIF objects...", end="", flush=True)
    pstart = timer()
    # create system to gather all the data
    system = modelcif.System(
        title=data_json["title"],
        id=data_json["data_block_id"].upper(),
        model_details=data_json["model_details"],
    )

    # create target entities, references, source, asymmetric units & assembly
    # for source we assume all chains come from the same taxon
    # create an asymmetric unit and an entity per target sequence
    asym_units = {}
    _get_modelcif_entities(
        data_json["target_entities"],
        ihm.source.Natural(
            ncbi_taxonomy_id=data_json["target_entities"][0]["up_ncbi_taxid"],
            scientific_name=data_json["target_entities"][0]["up_organism"],
        ),
        asym_units,
        system,
    )

    # audit_authors
    system.authors.extend(data_json["audit_authors"])

    # set up the model to produce coordinates
    if data_json["rank_num"] == 1:
        mdl_list_name = f"Model {data_json['mdl_num']} (top ranked model)"
    else:
        mdl_list_name = (
            f"Model {data_json['mdl_num']} "
            f"(#{data_json['rank_num']} ranked model)"
        )
    model = _OST2ModelCIF(
        assembly=modelcif.Assembly(asym_units.values()),
        asym=asym_units,
        ost_entity=ost_ent,
        name=mdl_list_name,
    )
    print(f" ({timer()-pstart:.2f}s)")
    print("    processing QA scores...", end="", flush=True)
    pstart = timer()
    mdl_name = os.path.basename(file_prfx)
    system.repositories.append(
        model.add_scores(data_json, system.id, mdl_name, add_files)
    )
    print(f" ({timer()-pstart:.2f}s)")

    system.model_groups.append(
        modelcif.model.ModelGroup([model], name=data_json["model_group_name"])
    )

    system.protocols.append(
        _get_modelcif_protocol(
            data_json["protocol"],
            system.target_entities,
            model,
            _get_sequence_dbs(data_json["config_data"]["seq_dbs"]),
        )
    )

    # write modelcif System to file
    print("    write to disk...", end="", flush=True)
    pstart = timer()
    # NOTE: this will dump PAE on path provided in add_scores
    # -> hence we cheat by changing path and back while being exception-safe...
    oldpwd = os.getcwd()
    os.chdir(out_dir)
    mdl_fle = f"{mdl_name}.cif"
    try:
        with open(mdl_fle, "w", encoding="ascii") as mmcif_fh:
            modelcif.dumper.write(mmcif_fh, [system])
        _package_associated_files(system.repositories[0])
        if compress:
            _compress_cif_file(mdl_fle)
            # the plain CIF file was removed, refer to the compressed one
            mdl_fle += ".gz"
    finally:
        os.chdir(oldpwd)
    print(f" ({timer()-pstart:.2f}s)")

    mdl_fle = _get_assoc_mdl_file(mdl_fle, data_json)
    zip_fle = _get_assoc_zip_file(
        system.repositories[0].files[0].path, data_json
    )
    return mdl_fle, zip_fle


def _get_assoc_mdl_file(fle_path, data_json):
    """Wrap ``fle_path`` in a generic associated-file object tagged as CIF.

    A plain ``modelcif.associated.File`` is used on purpose: the dedicated
    CIFFile functionality in modelcif would also try to write the file.
    The details string records model number and rank from ``data_json``.
    """
    details = f"model {data_json['mdl_num']}; rank {data_json['rank_num']}"
    assoc = modelcif.associated.File(fle_path, details=details)
    # mark it as CIF without using the self-writing CIFFile class
    assoc.file_format = "cif"
    return assoc


def _get_assoc_zip_file(fle_path, data_json):
    """Wrap ``fle_path`` in a generic associated-file object for a ZIP file.

    This is NOT the archive ZIP file holding the PAEs. It describes a model's
    own ZIP archive so that it can be stored in the ZIP archive of the
    selected model. The file format is set to "other".
    """
    mdl_num = data_json["mdl_num"]
    rank_num = data_json["rank_num"]
    details = (
        "archive with multiple files for model "
        f"{mdl_num}; rank {rank_num}"
    )
    zip_assoc = modelcif.associated.File(fle_path, details=details)
    zip_assoc.file_format = "other"
    return zip_assoc


def _create_interaction_json(config_data):
    """Build the part of the JSON-like data dictionary shared by all models.

    Returns a new dict with the audit authors, the protocol steps/software
    derived from ``config_data``, and ``config_data`` itself.
    """
    # dict literal entries are evaluated in order, so the helpers run in the
    # same sequence as before
    return {
        "audit_authors": _get_audit_authors(),
        "protocol": _get_protocol_steps_and_software(config_data),
        "config_data": config_data,
    }


def _create_model_json(data, pdb_file, up_acs, block_id):
    """Add the model-specific entries to ``data`` (modified in place).

    Sets "target_entities", "title", "data_block_id", "model_details" and
    "model_group_name". Title and model details are derived from the "up_gn"
    value of every target entity.

    Returns the OST entity obtained from ``_get_entities`` for ``pdb_file``.
    """
    targets, ost_ent = _get_entities(pdb_file, up_acs)
    data["target_entities"] = targets
    # one list object is handed to both helpers below
    gene_names = [trg["up_gn"] for trg in targets]
    data["title"] = _get_title(gene_names)
    data["data_block_id"] = block_id
    data["model_details"] = _get_model_details(gene_names)
    data["model_group_name"] = _get_model_group_name()
    return ost_ent


def _translate2modelcif(up_acs, pdb_fle, config_data, opts, add_files):
    """Convert a PDB file with its accompanying data to ModelCIF.

    Args:
        up_acs: UniProtKB ACs taken from the model directory name.
        pdb_fle: name of the PDB file, relative to ``opts.model_dir``.
        config_data: parsed ColabFold config, shared by all models.
        opts: parsed command line options (uses ``model_dir``, ``out_dir``
            and ``compress``).
        add_files: files to attach to this model (e.g. the other models of
            the project), or None.

    Returns:
        Tuple ``(pdb_start, pdb_fle, mdlcf_fle, zip_fle)``: the timer value
        taken when processing started, the PDB file path joined with
        ``opts.model_dir``, and the two objects returned by
        ``_store_as_modelcif`` for the written CIF file and its ZIP archive.

    Raises:
        ValueError: if the model name does not end in ``rank_<X>_model_<Y>``.
    """
    pdb_start = timer()
    file_prfx, uid = _check_model_extra_files_present(opts.model_dir, pdb_fle)
    pdb_fle = os.path.join(opts.model_dir, pdb_fle)

    # gather data into JSON-like structure
    # flush so the progress message shows before the (slow) work starts,
    # consistent with the other progress prints
    print("    preparing data...", end="", flush=True)
    pstart = timer()

    mdlcf_json = _create_interaction_json(config_data)

    # uid = ..._rank_X_model_Y
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let malformed names through silently.
    mdl_name_parts = uid.split("_")
    if (
        len(mdl_name_parts) < 4
        or mdl_name_parts[-4] != "rank"
        or mdl_name_parts[-2] != "model"
    ):
        raise ValueError(
            f"Unexpected model name '{uid}', expected '..._rank_X_model_Y'."
        )
    mdlcf_json["rank_num"] = int(mdl_name_parts[-3])
    mdlcf_json["mdl_num"] = int(mdl_name_parts[-1])

    ost_ent = _create_model_json(mdlcf_json, pdb_fle, up_acs, uid)

    # read quality scores from JSON file
    _get_scores(mdlcf_json, file_prfx)
    print(f" ({timer()-pstart:.2f}s)")
    mdlcf_fle, zip_fle = _store_as_modelcif(
        mdlcf_json,
        ost_ent,
        opts.out_dir,
        file_prfx,
        opts.compress,
        add_files,
    )
    return pdb_start, pdb_fle, mdlcf_fle, zip_fle


def _main():
    """Run as script.

    Translates every PDB model in ``opts.model_dir`` to ModelCIF. If
    ``--selected_rank`` is given, all other models are translated first. Their
    resulting CIF and ZIP files are then passed as additional files when the
    selected model is translated. The script aborts if no model of the
    requested rank is found.
    """
    opts = _parse_args()

    interaction = os.path.split(opts.model_dir)[-1]
    print(f"Working on {interaction}...")

    # get UniProtKB ACs from directory name
    up_acs = interaction.split("-")

    cnfg = _check_interaction_extra_files_present(opts.model_dir)
    config_data = _parse_colabfold_config(cnfg)

    # Model names look like "..._rank_X_model_Y". Include the trailing "_" in
    # the rank tag so that e.g. rank 1 does not also match rank 10, 11, ...
    # (those would otherwise be skipped and never translated).
    slctd_rank_tag = None
    if opts.selected_rank is not None:
        slctd_rank_tag = f"rank_{opts.selected_rank}_"

    # iterate model directory
    # There is 1 representative for a modelling project, the other models are
    # stored in its ZIP archive.
    not_slctd_mdls = []
    slctd_mdl = None
    for fle in sorted(os.listdir(opts.model_dir)):
        # iterate PDB files
        if not fle.endswith(".pdb"):
            continue
        if slctd_rank_tag is not None and slctd_rank_tag in fle:
            slctd_mdl = fle
            continue
        print(f"  translating {fle}...")
        pdb_start, fle, mdlcf_fle, zip_fle = _translate2modelcif(
            up_acs,
            fle,
            config_data,
            opts,
            None,
        )
        print(f"  ... done with {fle} ({timer()-pdb_start:.2f}s).")
        not_slctd_mdls.append(mdlcf_fle)
        not_slctd_mdls.append(zip_fle)
    if opts.selected_rank is not None:
        if slctd_mdl is None:
            _abort_msg(
                f"Could not find model of requested rank '{opts.selected_rank}'"
            )
        print(
            f"  translating selected model {opts.selected_rank} "
            + f"({slctd_mdl})..."
        )
        # Use this translation's own start time. Reusing `pdb_start` from the
        # loop gave a wrong duration, or an UnboundLocalError if the selected
        # model was the only PDB file.
        slctd_start = _translate2modelcif(
            up_acs,
            slctd_mdl,
            config_data,
            opts,
            not_slctd_mdls,
        )[0]
        print(f"  ... done with {slctd_mdl} ({timer()-slctd_start:.2f}s).")

    print(f"... done with {opts.model_dir}.")


# Only run the translation when executed as a script (e.g. via ``ost``), not
# when the module is imported.
if __name__ == "__main__":
    _main()