#! /usr/local/bin/ost
"""Translate models from Edward from PDB + extra data into ModelCIF."""

# EXAMPLES for running:
"""
GT test setup:
ost scripts/translate2modelcif.py "./InputFiles/sample_files" \
    "./InputFiles/ASFV-G_proteome_accessions.csv" \
    --out_dir="./modelcif"
For full translation (takes ~6min on laptop):
ost scripts/translate2modelcif.py "./InputFiles/AlphaFold-RENAME" \
    "./InputFiles/ASFV-G_proteome_accessions.csv" \
    --out_dir="./modelcif" > script_out.txt
"""

import argparse
import datetime
import os
import sys
import gzip, shutil, zipfile

from timeit import default_timer as timer
import numpy as np
import requests
import ujson as json
import pandas as pd
import xml.dom.minidom

import ihm
import ihm.citations
import modelcif
import modelcif.associated
import modelcif.dumper
import modelcif.model
import modelcif.protocol
import modelcif.reference

from ost import io


def _parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )

    parser.add_argument(
        "model_dir",
        type=str,
        metavar="<MODEL DIR>",
        help="Directory with model(s) to be translated.",
    )
    parser.add_argument(
        "metadata_file",
        type=str,
        metavar="<METADATA FILE>",
        help="Path to CSV file with metadata.",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        metavar="<OUTPUT DIR>",
        default="",
        help="Path to separate path to store results " \
             "(model_dir used, if none given).",
    )
    parser.add_argument(
        "--compress",
        default=False,
        action="store_true",
        help="Compress ModelCIF file with gzip " \
             "(note that QA file is zipped either way).",
    )

    opts = parser.parse_args()

    # check that model dir exists
    if opts.model_dir.endswith("/"):
        opts.model_dir = opts.model_dir[:-1]
    if not os.path.exists(opts.model_dir):
        _abort_msg(f"Model directory '{opts.model_dir}' does not exist.")
    if not os.path.isdir(opts.model_dir):
        _abort_msg(f"Path '{opts.model_dir}' does not point to a directory.")
    # check metadata_file
    if not os.path.exists(opts.metadata_file):
        _abort_msg(f"Metadata file '{opts.metadata_file}' does not exist.")
    if not os.path.isfile(opts.metadata_file):
        _abort_msg(f"Path '{opts.metadata_file}' does not point to a file.")
    # check out_dir
    if not opts.out_dir:
        opts.out_dir = opts.model_dir
    else:
        if not os.path.exists(opts.out_dir):
            _abort_msg(f"Output directory '{opts.out_dir}' does not exist.")
        if not os.path.isdir(opts.out_dir):
            _abort_msg(f"Path '{opts.out_dir}' does not point to a directory.")

    return opts


# pylint: disable=too-few-public-methods
# Global (whole-model) pLDDT QA metric; add_scores() fills it with the mean
# of the per-residue pLDDT values.
# NOTE(review): the class docstring below is presumably used by
# python-modelcif as the metric description written to the file -- do not
# edit it without checking the generated ModelCIF output.
class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT):
    """Predicted accuracy according to the CA-only lDDT in [0,100]"""
    name = "pLDDT"
    # no separate software entry is attached to this metric
    software = None

# Per-residue pLDDT QA metric; add_scores() creates one instance per residue.
# NOTE(review): the class docstring is presumably used by python-modelcif as
# the metric description -- keep it unchanged unless the output is checked.
class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
    """Predicted accuracy according to the CA-only lDDT in [0,100]"""
    name = "pLDDT"
    # no separate software entry is attached to this metric
    software = None

# Target sequence reference pointing to an NCBI protein entry (constructed in
# _get_modelcif_entities with the NCBI GI as code and the accession.version
# as accession).
class _NcbiTrgRef(modelcif.reference.TargetReference):
    """NCBI as target reference."""
    name = "NCBI"
    # explicitly no extra details for the database name
    other_details = None
# pylint: enable=too-few-public-methods


class _OST2ModelCIF(modelcif.model.AbInitioModel):
    """Map OST entity elements to ihm.model"""
    # Extra keyword arguments consumed by __init__ (everything else is passed
    # on to modelcif.model.AbInitioModel):
    #   ost_entity:   OST entity providing the coordinates
    #   asym:         dict mapping OST chain names to modelcif.AsymUnit objects
    #   plddt_entity: OST entity whose B-factors hold the pLDDTs, or None to
    #                 read them from ost_entity
    # Per-residue pLDDTs are stored in self.plddts, indexed by residue number
    # minus one, so the pLDDT entity must be numbered 1..N without gaps (the
    # callers only ever pass single-chain entities).

    def __init__(self, *args, **kwargs):
        """Initialise a model and collect per-residue / per-atom pLDDTs.

        Raises:
            RuntimeError: if residue numbering of the pLDDT source is not
                consecutive from 1 or if atoms of one residue disagree on
                their pLDDT.
        """
        self.ost_entity = kwargs.pop("ost_entity")
        self.asym = kwargs.pop("asym")

        # fetch plddts per atom and per residue
        self.plddt_entity = kwargs.pop("plddt_entity")
        if self.plddt_entity is not None:
            bf_ent = self.plddt_entity
        else:
            bf_ent = self.ost_entity
        self.plddts = []
        self.atm_bfactors = {}
        for atm in bf_ent.atoms:
            res_num = atm.residue.number.num
            res_idx = res_num - 1
            # explicit checks instead of asserts (asserts vanish under -O)
            if res_idx < 0 or res_idx > len(self.plddts):
                raise RuntimeError(
                    f"Unexpected residue number {res_num} when reading "
                    + "pLDDTs (residues must be numbered consecutively "
                    + "from 1)."
                )
            if res_idx < len(self.plddts):
                if atm.b_factor != self.plddts[res_idx]:
                    raise RuntimeError(
                        f"Inconsistent pLDDT values within residue {res_num}."
                    )
            else:
                self.plddts.append(atm.b_factor)
            self.atm_bfactors[atm.qualified_name] = atm.b_factor

        super().__init__(*args, **kwargs)

    def get_atoms(self):
        """Yield one modelcif.model.Atom per atom of ``ost_entity``.

        The B-factor column carries the pLDDT; if a separate pLDDT entity was
        given, the value is looked up by the OST qualified atom name.
        """
        # ToDo [internal]: Take B-factor out since its not a B-factor?
        use_plddt_ent = self.plddt_entity is not None
        for atm in self.ost_entity.atoms:
            if use_plddt_ent:
                b_factor = self.atm_bfactors[atm.qualified_name]
            else:
                b_factor = atm.b_factor
            yield modelcif.model.Atom(
                asym_unit=self.asym[atm.chain.name],
                seq_id=atm.residue.number.num,
                atom_id=atm.name,
                type_symbol=atm.element,
                x=atm.pos[0],
                y=atm.pos[1],
                z=atm.pos[2],
                het=atm.is_hetatom,
                biso=b_factor,
                occupancy=atm.occupancy,
            )

    def add_scores(self):
        """Add QA metrics from AF2 scores (global mean and per-residue pLDDT)."""
        # global scores
        self.qa_metrics.append(
            _GlobalPLDDT(np.mean(self.plddts))
        )

        # local scores: look up by residue number, consistent with how
        # self.plddts was filled in __init__
        for chn in self.ost_entity.chains:
            asym_unit = self.asym[chn.name]
            for res in chn.residues:
                res_num = res.number.num
                self.qa_metrics.append(
                    _LocalPLDDT(
                        asym_unit.residue(res_num),
                        self.plddts[res_num - 1],
                    )
                )


def _abort_msg(msg, exit_code=1):
    """Write error message and exit with exit_code."""
    print(f"{msg}\nAborting.", file=sys.stderr)
    sys.exit(exit_code)


def _check_file(file_path):
    """Make sure a file exists and is actually a file."""
    if not os.path.exists(file_path):
        _abort_msg(f"File not found: '{file_path}'.")
    if not os.path.isfile(file_path):
        _abort_msg(f"File path does not point to file: '{file_path}'.")


def _get_audit_authors():
    """Return the list of authors that produced this model."""
    return (
        "Spinard, Edward",
        "Azzinaro, Paul",
        "Rai, Ayushi",
        "Espinoza, Nallely",
        "Ramirez-Medina, Elizabeth",
        "Valladares, Alyssa",
        "Borca, Manuel",
        "Gladue, Douglas"
    )


def _get_metadata(metadata_file):
    """Read csv file with metedata and prepare for next steps."""
    metadata = pd.read_csv(metadata_file)
    # make sure protein and PDB names are unique
    assert len(set(metadata.Protein)) == metadata.shape[0]
    assert len(set(metadata["Associated PDB"])) == metadata.shape[0]
    return metadata.set_index("Protein")


def _get_config(is_special=False):
    """Define AF setup (special case QP509L run with other settings)."""
    if is_special:
        description = "Model generated using the AlphaFold (v2.1.0) " \
                      "colab notebook producing 5 models with 3 recycles " \
                      "each, without model relaxation, without templates, " \
                      "ranked by pLDDT, starting from an MSA with " \
                      "reduced_dbs setting."
        description2 = "The unrelaxed model was minimized and subjected to " \
                       "molecular dynamics for 1 ns using GROMACS."
        descriptions = [description, description2]
        af_config = {
            "db_preset": "reduced_dbs",
            "run_relax": False
        }
    else:
        description = "Model generated using AlphaFold (v2.2.0) " \
                      "producing 5 models with 3 recycles each, with AMBER " \
                      "relaxation, using templates, ranked by pLDDT, " \
                      "starting from an MSA with full_dbs setting."
        descriptions = [description]
        af_config = {
            "model_preset": "monomer",
            "db_preset": "full_dbs",
            "use_gpu_relax": True,
            "max_template_date": "2020-05-14",
        }
    return {
        "af_config": af_config,
        "af_version": "2.1.0" if is_special else "2.2.0",
        "descriptions": descriptions,
        "has_gromacs_step": is_special,
        "use_templates": not is_special,
        "use_small_bfd": is_special
    }


def _get_protocol_steps_and_software(config_data):
    """Describe the modelling protocol as a list of step dictionaries.

    Each step has the keys ``method_type``, ``name``, ``details``,
    ``input``/``output`` (keywords resolved in ``_get_modelcif_protocol``),
    ``software`` (list of software dicts) and ``software_parameters``.

    Args:
        config_data: dict as returned by ``_get_config``.

    Returns:
        list with the AlphaFold modelling step and, if
        ``config_data["has_gromacs_step"]``, a GROMACS refinement step.
    """
    steps = [
        {
            "method_type": "modeling",
            "name": None,
            "details": config_data["descriptions"][0],
            # input/output must refer to data already known, hence keywords
            "input": "target_sequences",
            "output": "model",
            "software": [
                {
                    "name": "AlphaFold",
                    "classification": "model building",
                    "description": "Structure prediction",
                    "citation": ihm.citations.alphafold2,
                    "location": "https://github.com/deepmind/alphafold",
                    "type": "package",
                    "version": config_data["af_version"],
                }
            ],
            "software_parameters": config_data["af_config"],
        }
    ]

    if not config_data["has_gromacs_step"]:
        return steps

    gromacs_citation = ihm.Citation(
        pmid=None,
        title="GROMACS: High performance molecular simulations through "
        "multi-level parallelism from laptops to supercomputers.",
        journal="SoftwareX",
        volume=1,
        page_range=(19, 25),
        year=2015,
        authors=[
            "Abraham, M.J.",
            "Murtola, T.",
            "Schulz, R.",
            "Pall, S.",
            "Smith, J.C.",
            "Hess, B.",
            "Lindahl, E."
        ],
        doi="10.1016/j.softx.2015.06.001",
    )
    steps.append(
        {
            "method_type": "model refinement",
            "name": None,
            "details": config_data["descriptions"][1],
            "input": "model",
            "output": "model",
            "software": [
                {
                    "name": "GROMACS",
                    "classification": "refinement",
                    "description": "Model relaxation",
                    "citation": gromacs_citation,
                    "location": "https://www.gromacs.org",
                    "type": "package",
                    "version": None,
                }
            ],
            "software_parameters": {},
        }
    )
    return steps


def _get_title(mdl_title):
    """Get a title for this modelling experiment."""
    return f"AlphaFold model for {mdl_title}"


def _get_model_details(mdl_descs, mdl_notes):
    """Get the model description."""
    mdl_desc = '\n'.join(mdl_descs)
    if type(mdl_notes) == str:
        # fix typos...
        mdl_notes = mdl_notes.replace("hypthetical", "hypothetical") \
                             .replace("Uniport", "UniProt") \
                             .replace("Uniprot", "UniProt") \
                             .replace("Mislabled", "mislabeled")
        #
        return f"{mdl_desc}\n\nNote: {mdl_notes}."
    else:
        return mdl_desc


def _get_model_group_name():
    """Get a name for a model group."""
    return None


def _get_sequence(chn):
    """Get the sequence out of an OST chain."""
    # initialise
    lst_rn = chn.residues[0].number.num
    idx = 1
    sqe = chn.residues[0].one_letter_code
    if lst_rn != 1:
        sqe = "-"
        idx = 0

    for res in chn.residues[idx:]:
        lst_rn += 1
        while lst_rn != res.number.num:
            sqe += "-"
            lst_rn += 1
        sqe += res.one_letter_code

    return sqe


def _check_sequence(up_ac, sequence):
    """Verify sequence to only contain standard olc."""
    for res in sequence:
        if res not in "ACDEFGHIKLMNPQRSTVWY":
            raise RuntimeError(
                "Non-standard aa found in UniProtKB sequence "
                + f"for entry '{up_ac}': {res}"
            )


def _fetch_upkb_entry(up_ac):
    """Fetch data for an UniProtKB entry.

    Downloads the entry in UniProtKB flat-file (txt) format and extracts the
    fields needed for the ModelCIF target reference. Missing or inconsistent
    mandatory fields abort the script via ``_abort_msg``.

    Args:
        up_ac: UniProtKB accession.

    Returns:
        dict with keys ``up_ac``, ``up_id``, ``up_organism``,
        ``up_ncbi_taxid``, ``up_seqlen``, ``up_crc64``, ``up_sequence``,
        ``up_last_mod`` (datetime of the sequence version), ``up_gn`` and
        ``up_isoform`` (always None).

    Raises:
        requests.HTTPError: if the download fails.
        RuntimeError: if an RP line mentions an isoform or the sequence has
            non-standard amino acids.
    """
    # This is a simple parser for UniProtKB txt format, instead of breaking it up
    # into multiple functions, we just allow many many branches & statements,
    # here.
    # pylint: disable=too-many-branches,too-many-statements
    data = {}
    data["up_organism"] = ""
    data["up_sequence"] = ""
    data["up_ac"] = up_ac
    rspns = requests.get(f"https://www.uniprot.org/uniprot/{up_ac}.txt")
    # fail loudly instead of silently parsing an error page
    rspns.raise_for_status()
    for line in rspns.iter_lines(decode_unicode=True):
        if line.startswith("ID   "):
            sline = line.split()
            if len(sline) != 5:
                _abort_msg(f"Unusual UniProtKB ID line found:\n'{line}'")
            data["up_id"] = sline[1]
        elif line.startswith("OX   NCBI_TaxID="):
            # Following strictly the UniProtKB format: 'OX   NCBI_TaxID=<ID>;'
            data["up_ncbi_taxid"] = line[len("OX   NCBI_TaxID=") : -1]
            data["up_ncbi_taxid"] = data["up_ncbi_taxid"].split("{")[0].strip()
        elif line.startswith("OS   "):
            # the organism name may span several OS lines, only the last one
            # ends with '.'; continuation lines must be kept in full
            if line[-1] == ".":
                data["up_organism"] += line[len("OS   ") : -1]
            else:
                data["up_organism"] += line[len("OS   ") :] + " "
        elif line.startswith("SQ   "):
            sline = line.split()
            if len(sline) != 8:
                _abort_msg(f"Unusual UniProtKB SQ line found:\n'{line}'")
            data["up_seqlen"] = int(sline[2])
            data["up_crc64"] = sline[6]
        elif line.startswith("     "):
            sline = line.split()
            if len(sline) > 6:
                _abort_msg(
                    "Unusual UniProtKB sequence data line "
                    + f"found:\n'{line}'"
                )
            data["up_sequence"] += "".join(sline)
        elif line.startswith("RP   "):
            if "ISOFORM" in line.upper():
                # isoforms are not handled below (up_isoform is set to None),
                # so stop here instead of writing a wrong reference
                raise RuntimeError(
                    f"First ISOFORM found for '{up_ac}', needs handling."
                )
        elif line.startswith("DT   "):
            # 2012-10-03
            dt_flds = line[len("DT   ") :].split(", ")
            if dt_flds[1].upper().startswith("SEQUENCE VERSION "):
                data["up_last_mod"] = datetime.datetime.strptime(
                    dt_flds[0], "%d-%b-%Y"
                )
        elif line.startswith("GN   Name="):
            data["up_gn"] = line[len("GN   Name=") :].split(";")[0]
            data["up_gn"] = data["up_gn"].split("{")[0].strip()

    # we have not seen isoforms in the data set, yet, so we just set them to '.'
    data["up_isoform"] = None

    if "up_gn" not in data:
        _abort_msg(f"No gene name found for UniProtKB entry '{up_ac}'.")
    if "up_last_mod" not in data:
        _abort_msg(f"No sequence version found for UniProtKB entry '{up_ac}'.")
    if "up_crc64" not in data:
        _abort_msg(f"No CRC64 value found for UniProtKB entry '{up_ac}'.")
    if len(data["up_sequence"]) == 0:
        _abort_msg(f"No sequence found for UniProtKB entry '{up_ac}'.")
    # check that sequence length and CRC64 is correct
    if data["up_seqlen"] != len(data["up_sequence"]):
        _abort_msg(
            "Sequence length of SQ line and sequence data differ for "
            + f"UniProtKB entry '{up_ac}': {data['up_seqlen']} != "
            + f"{len(data['up_sequence'])}"
        )
    _check_sequence(data["up_ac"], data["up_sequence"])

    if "up_id" not in data:
        _abort_msg(f"No ID found for UniProtKB entry '{up_ac}'.")
    if "up_ncbi_taxid" not in data:
        _abort_msg(f"No NCBI taxonomy ID found for UniProtKB entry '{up_ac}'.")
    if len(data["up_organism"]) == 0:
        _abort_msg(f"No organism species found for UniProtKB entry '{up_ac}'.")

    return data


def _check_subset(s1, s2):
    # check if s2 is uniquely contained in s1
    # (and if so, returns values for seq_db_align_begin & seq_db_align_end)
    if s1.count(s2) == 1:
        align_begin = s1.find(s2) + 1
        align_end = align_begin + len(s2) - 1
        return align_begin, align_end
    else:
        return None


def _get_ncbi_sequence(ncbi_ac):
    """Fetch OST sequence object from NCBI web service.

    Uses E-utilities efetch (db=protein, FASTA text output).

    Args:
        ncbi_ac: NCBI protein accession (with version).

    Returns:
        OST sequence parsed from the FASTA response.

    Raises:
        requests.HTTPError: if the request fails.
    """
    # src: https://www.ncbi.nlm.nih.gov/books/NBK25500/#_chapter1_Downloading_Full_Records_
    rspns = requests.get(
        "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi",
        # let requests do the URL encoding of the query
        params={
            "db": "protein",
            "id": ncbi_ac,
            "rettype": "fasta",
            "retmode": "text",
        },
    )
    rspns.raise_for_status()
    return io.SequenceFromString(rspns.text, "fasta")


def _get_ncbi_info(ncbi_ac):
    """Fetch dict with info from NCBI web service.

    Queries E-utilities esummary (db=protein) and converts the ``Item``
    children of the single ``DocSum`` element into a dict: "String" items
    become str, "Integer" items int, empty items None.

    Args:
        ncbi_ac: NCBI protein accession (with version).

    Raises:
        requests.HTTPError: if the request fails.
        RuntimeError: if the response does not hold exactly one DocSum or an
            item has an unsupported type.
    """
    # src: https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESummary
    rspns = requests.get(
        "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi",
        params={"db": "protein", "id": ncbi_ac},
    )
    rspns.raise_for_status()
    # parse raw bytes so the parser honours the XML encoding declaration
    dom = xml.dom.minidom.parseString(rspns.content)
    docsums = dom.getElementsByTagName("DocSum")
    if len(docsums) != 1:
        raise RuntimeError(
            f"Expected exactly one DocSum for {ncbi_ac}, got {len(docsums)}"
        )
    ncbi_dict = {}
    for node in docsums[0].childNodes:
        if node.nodeName != "Item":
            continue
        item_name = node.getAttribute("Name")
        item_type = node.getAttribute("Type")
        if not node.childNodes:
            ncbi_dict[item_name] = None
            continue
        # check the type before touching .data: nested item types have
        # element children without text data
        if item_type == "String":
            ncbi_dict[item_name] = node.childNodes[0].data
        elif item_type == "Integer":
            ncbi_dict[item_name] = int(node.childNodes[0].data)
        else:
            raise RuntimeError(f"Unknown type {item_type} for {ncbi_ac}")
    return ncbi_dict


def _get_entities(pdb_file, mdl_title, up_ac, ncbi_ac):
    """Load a model and collect the data for its (single) target entity.

    The PDB sequence is cross-checked against UniProtKB (must be the full
    entry or a unique stretch of it) and the UniProtKB sequence/taxonomy
    against NCBI; the NCBI entry must be live and current.

    Args:
        pdb_file: path to the model PDB file.
        mdl_title: model title used in descriptions and error messages.
        up_ac: UniProtKB accession.
        ncbi_ac: NCBI protein accession (with version).

    Returns:
        ([entity dict], OST entity of the loaded model).

    Raises:
        RuntimeError: on oligomers or any of the consistency checks.
    """
    ost_ent = io.LoadPDB(pdb_file)
    # only monomers are expected
    if ost_ent.chain_count != 1:
        raise RuntimeError(f"Unexpected oligomer for {mdl_title}")
    chain = ost_ent.chains[0]
    pdb_seq = _get_sequence(chain)
    entity = {
        "pdb_sequence": pdb_seq,
        "pdb_chain_id": chain.name,
        "description": f"{mdl_title} protein",
    }

    # UniProtKB data; aligned range is the full entry or a unique stretch
    upkb = _fetch_upkb_entry(up_ac)
    entity.update(upkb)
    up_seq = upkb["up_sequence"]
    if up_seq == pdb_seq:
        aln_range = (1, entity["up_seqlen"])
    else:
        aln_range = _check_subset(up_seq, pdb_seq)
        if not aln_range:
            raise RuntimeError(f"Inconsistent UP/PDB sequences for {mdl_title}")
    entity["up_range"] = aln_range

    # NCBI must agree with UniProtKB and be current
    ncbi_seq = _get_ncbi_sequence(ncbi_ac)
    if up_seq != str(ncbi_seq):
        raise RuntimeError(f"Inconsistent UP/NCBI sequences for {mdl_title}")
    ncbi = _get_ncbi_info(ncbi_ac)
    if up_seq is not None and upkb["up_ncbi_taxid"] != str(ncbi["TaxId"]):
        raise RuntimeError(f"Inconsistent UP/NCBI taxid for {mdl_title}")
    if ncbi["Status"] != "live":
        raise RuntimeError(f"NCBI entry {ncbi_ac} for {mdl_title} not live")
    if ncbi["ReplacedBy"]:
        raise RuntimeError(f"Outdated NCBI entry {ncbi_ac} for {mdl_title}")
    if ncbi["AccessionVersion"] != ncbi_ac:
        raise RuntimeError(f"NCBI AC is not AC for {mdl_title}")

    entity["ncbi_ac"] = ncbi_ac
    entity["ncbi_gi"] = str(ncbi["Gi"])
    entity["ncbi_last_mod"] = datetime.datetime.strptime(
        ncbi["UpdateDate"], "%Y/%m/%d"
    )
    return [entity], ost_ent


def _get_modelcif_entities(target_ents, source, asym_units, system):
    """Create ModelCIF entities and asymmetric units.

    For every target entity dict (as built by ``_get_entities``) a
    ``modelcif.Entity`` with a UniProtKB and an NCBI reference is created,
    appended to ``system.target_entities`` and wrapped in a
    ``modelcif.AsymUnit`` stored in ``asym_units`` under the PDB chain id.
    ``asym_units`` and ``system`` are modified in place; returns None.
    """
    for ent_data in target_ents:
        aln_begin = ent_data["up_range"][0]
        aln_end = ent_data["up_range"][1]
        upkb_ref = modelcif.reference.UniProt(
            ent_data["up_id"],
            ent_data["up_ac"],
            align_begin=aln_begin,
            align_end=aln_end,
            isoform=ent_data["up_isoform"],
            ncbi_taxonomy_id=ent_data["up_ncbi_taxid"],
            organism_scientific=ent_data["up_organism"],
            sequence_version_date=ent_data["up_last_mod"],
            sequence_crc64=ent_data["up_crc64"],
        )
        # NOTE: assume that UP and NCBI match on most things, so taxonomy,
        # organism and aligned range are taken from UniProtKB
        ncbi_ref = _NcbiTrgRef(
            ent_data["ncbi_gi"],
            ent_data["ncbi_ac"],
            align_begin=aln_begin,
            align_end=aln_end,
            ncbi_taxonomy_id=ent_data["up_ncbi_taxid"],
            organism_scientific=ent_data["up_organism"],
            sequence_version_date=ent_data["ncbi_last_mod"],
        )
        entity = modelcif.Entity(
            ent_data["pdb_sequence"],
            description=ent_data["description"],
            source=source,
            references=[upkb_ref, ncbi_ref],
        )
        asym_units[ent_data["pdb_chain_id"]] = modelcif.AsymUnit(entity)
        system.target_entities.append(entity)


def _assemble_modelcif_software(soft_dict):
    """Turn a software dict (see ``_get_protocol_steps_and_software``) into
    a ``modelcif.Software`` object."""
    positional_keys = (
        "name", "classification", "description", "location", "type",
        "version",
    )
    positional = tuple(soft_dict[key] for key in positional_keys)
    return modelcif.Software(*positional, citation=soft_dict["citation"])


def _get_sequence_dbs(config_data):
    """List the sequence/template databases used by AlphaFold.

    Args:
        config_data: dict from ``_get_config``; ``use_small_bfd`` selects
            reduced BFD instead of full BFD, ``use_templates`` adds PDB70.

    Returns:
        list of ``modelcif.ReferenceDatabase`` objects.
    """
    afdb_url = "https://storage.googleapis.com/alphafold-databases/"
    ebi_kb_url = ("ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/"
                  "knowledgebase/complete/")
    # hard coded UniProt release (shared by TrEMBL, Swiss-Prot, UniRef90)
    up_release = {
        "version": "2022_01",
        "release_date": datetime.datetime(2022, 2, 23),
    }

    if config_data["use_small_bfd"]:
        bfd = modelcif.ReferenceDatabase(
            "Reduced BFD",
            afdb_url + "reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz",
        )
    else:
        bfd = modelcif.ReferenceDatabase(
            "BFD",
            afdb_url + "casp14_versions/"
            "bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz",
            version="6a634dc6eb105c2e9b4cba7bbae93412",
        )

    # NOTE(review): Uniclust30 is listed for reduced_dbs runs too; AlphaFold
    # may not use it with that preset -- confirm with the authors.
    dbs = [
        bfd,
        modelcif.ReferenceDatabase(
            "MGnify",
            afdb_url + "casp14_versions/mgy_clusters_2018_12.fa.gz",
            version="2018_12",
            release_date=datetime.datetime(2018, 12, 6),
        ),
        modelcif.ReferenceDatabase(
            "Uniclust30",
            afdb_url + "casp14_versions/uniclust30_2018_08_hhsuite.tar.gz",
            version="2018_08",
            release_date=None,
        ),
        modelcif.ReferenceDatabase(
            "TrEMBL", ebi_kb_url + "uniprot_trembl.fasta.gz", **up_release
        ),
        modelcif.ReferenceDatabase(
            "Swiss-Prot", ebi_kb_url + "uniprot_sprot.fasta.gz", **up_release
        ),
        modelcif.ReferenceDatabase(
            "UniRef90",
            "ftp://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref90/"
            "uniref90.fasta.gz",
            **up_release,
        ),
    ]
    if config_data["use_templates"]:
        dbs.append(modelcif.ReferenceDatabase(
            "PDB70",
            "http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/"
            "hhsuite_dbs/old-releases/pdb70_from_mmcif_200401.tar.gz",
            release_date=datetime.datetime(2020, 4, 1),
        ))
    return dbs


def _get_modelcif_protocol(protocol_steps, target_entities, model, ref_dbs):
    """Create the protocol for the ModelCIF file.

    Args:
        protocol_steps: step dicts from ``_get_protocol_steps_and_software``.
        target_entities: ModelCIF target entities (input of the modelling
            step together with ``ref_dbs``).
        model: the ModelCIF model (output of every step, input of
            refinement steps).
        ref_dbs: reference databases used for the modelling step.

    Returns:
        modelcif.protocol.Protocol with one Step per entry.

    Raises:
        RuntimeError: on unknown input/output keywords.
    """
    protocol = modelcif.protocol.Protocol()
    for step_data in protocol_steps:
        # software: a single Software or a SoftwareGroup; parameters can only
        # be attached to a group, so a lone Software gets wrapped if needed
        software = None
        software_dicts = step_data["software"]
        if software_dicts:
            assembled = [
                _assemble_modelcif_software(sd) for sd in software_dicts
            ]
            if len(assembled) == 1:
                software = assembled[0]
            else:
                software = modelcif.SoftwareGroup(elements=assembled)
            parameters = step_data["software_parameters"]
            if parameters:
                param_objs = [
                    modelcif.SoftwareParameter(key, value)
                    for key, value in parameters.items()
                ]
                if isinstance(software, modelcif.SoftwareGroup):
                    software.parameters = param_objs
                else:
                    software = modelcif.SoftwareGroup(
                        elements=(software,), parameters=param_objs
                    )

        input_kind = step_data["input"]
        if input_kind == "target_sequences":
            input_data = modelcif.data.DataGroup(target_entities)
            input_data.extend(ref_dbs)
        elif input_kind == "model":
            input_data = model
        else:
            raise RuntimeError(f"Unknown protocol input: '{input_kind}'")

        output_kind = step_data["output"]
        if output_kind != "model":
            raise RuntimeError(
                f"Unknown protocol output: '{output_kind}'"
            )

        step = modelcif.protocol.Step(
            input_data=input_data,
            output_data=model,
            name=step_data["name"],
            details=step_data["details"],
            software=software,
        )
        protocol.steps.append(step)
        step.method_type = step_data["method_type"]

    return protocol


def _compress_cif_file(cif_file):
    """Compress cif file and delete original."""
    with open(cif_file, 'rb') as f_in:
        with gzip.open(cif_file + '.gz', 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    os.remove(cif_file)


def _store_as_modelcif(data_json, ost_ent, plddt_entity, out_dir, mdl_name,
                       compress):
    """Mix all the data into a ModelCIF file.

    Args:
        data_json: dict filled by ``_create_json`` and
            ``_create_model_json`` (title, model_details, target_entities,
            audit_authors, mdl_num, model_group_name, config_data,
            protocol).
        ost_ent: OST entity with the model coordinates.
        plddt_entity: OST entity to read pLDDTs from, or None to use
            ``ost_ent``.
        out_dir: output directory.
        mdl_name: used for the file name ``<mdl_name>.cif`` and, upper-cased
            with spaces replaced by '_', as entry ID.
        compress: gzip the written file (result: ``<mdl_name>.cif.gz``).
    """
    # flush so progress shows up before the (slow) work starts
    print("    generating ModelCIF objects...", end="", flush=True)
    pstart = timer()
    # create system to gather all the data
    system = modelcif.System(
        title=data_json["title"],
        id=mdl_name.replace(' ', '_').upper(),
        model_details=data_json["model_details"],
    )
    # create target entities, references, source, asymmetric units & assembly
    # for source we assume all chains come from the same taxon
    source = ihm.source.Natural(
        ncbi_taxonomy_id=data_json["target_entities"][0]["up_ncbi_taxid"],
        scientific_name=data_json["target_entities"][0]["up_organism"],
    )

    # create an asymmetric unit and an entity per target sequence
    asym_units = {}
    _get_modelcif_entities(
        data_json["target_entities"], source, asym_units, system
    )

    assembly = modelcif.Assembly(
        asym_units.values()
    )

    # audit_authors
    system.authors.extend(data_json["audit_authors"])

    # set up the model to produce coordinates; mdl_num is None if the AF
    # model number could not be determined
    if data_json['mdl_num'] is not None:
        mdl_list_name = f"Model {data_json['mdl_num']} (top ranked model)"
    else:
        mdl_list_name = "Top ranked model"
    model = _OST2ModelCIF(
        assembly=assembly,
        asym=asym_units,
        ost_entity=ost_ent,
        plddt_entity=plddt_entity,
        name=mdl_list_name,
    )
    print(f" ({timer()-pstart:.2f}s)")
    print("    processing QA scores...", end="", flush=True)
    pstart = timer()
    model.add_scores()
    print(f" ({timer()-pstart:.2f}s)")

    model_group = modelcif.model.ModelGroup(
        [model], name=data_json["model_group_name"]
    )
    system.model_groups.append(model_group)

    ref_dbs = _get_sequence_dbs(data_json["config_data"])
    protocol = _get_modelcif_protocol(
        data_json["protocol"], system.target_entities, model, ref_dbs
    )
    system.protocols.append(protocol)

    # write modelcif System to file (NOTE: no PAE here!)
    print("    write to disk...", end="", flush=True)
    pstart = timer()
    out_path = os.path.join(out_dir, f"{mdl_name}.cif")
    with open(out_path, "w", encoding="ascii") as mmcif_fh:
        modelcif.dumper.write(mmcif_fh, [system])
    if compress:
        _compress_cif_file(out_path)
    print(f" ({timer()-pstart:.2f}s)")


def _create_json(config_data):
    """Build the JSON-like dict shared by all models of one setup.

    Holds ``audit_authors``, ``protocol`` (step dicts) and ``config_data``
    (passed through unchanged).
    """
    return {
        "audit_authors": _get_audit_authors(),
        "protocol": _get_protocol_steps_and_software(config_data),
        "config_data": config_data,
    }


def _create_model_json(data, pdb_file, md_row):
    """Add the per-model entries to ``data`` (modified in place).

    Sets ``target_entities``, ``title``, ``model_details`` and
    ``model_group_name``; ``data`` must already hold ``mdl_title`` and
    ``config_data``. ``md_row`` is the model's metadata row (uses the
    "UniProt_ID", "NCBI_Accession" and "notes" columns).

    Returns:
        The OST entity loaded from ``pdb_file``.
    """
    mdl_title = data["mdl_title"]
    entities, ost_ent = _get_entities(
        pdb_file, mdl_title, md_row["UniProt_ID"], md_row["NCBI_Accession"]
    )
    data["target_entities"] = entities
    data["title"] = _get_title(data["mdl_title"])
    descriptions = data["config_data"]["descriptions"]
    data["model_details"] = _get_model_details(descriptions, md_row["notes"])
    data["model_group_name"] = _get_model_group_name()
    return ost_ent


def _is_special(file_prfx):
    """Check if there is an unrelaxed file."""
    # if special case, we need separate file to fetch pLDDT and add extra 
    # GROMACS step to protocol
    plddt_path = f"{file_prfx}-unrelaxed.pdb"
    if os.path.exists(plddt_path):
        return plddt_path, True
    else:
        return None, False


def _get_mdl_num(mdl_id):
    """Fetch model number from filename used by AF."""
    # mdl_id example model_4_pred_0 -> fetch 4
    mdl_num = None
    if type(mdl_id) == str:
        mdl_id_split = mdl_id.split('_')
        if len(mdl_id_split) == 4:
            mdl_num = int(mdl_id_split[1])
    return mdl_num


def _verify_cif_output(cif_path, mdl_name, target_entity):
    """Reload a written ModelCIF file and assert it has one chain whose
    sequence equals the aligned stretch of the UniProtKB sequence."""
    ent = io.LoadMMCIF(cif_path)
    assert ent.chain_count == 1, f"Bad chain count {mdl_name}"
    ent_seq = "".join(res.one_letter_code for res in ent.residues)
    aln_begin = target_entity["up_range"][0]
    aln_end = target_entity["up_range"][1]
    exp_seq = target_entity["up_sequence"][aln_begin - 1:aln_end]
    assert ent_seq == exp_seq, f"Bad seq. {mdl_name}"


def _main():
    """Run as script: translate every model listed in the metadata file."""
    opts = _parse_args()

    # parse/fetch global data
    metadata = _get_metadata(opts.metadata_file)
    cifext = "cif.gz" if opts.compress else "cif"

    print(f"Working on {opts.model_dir}...")

    for pdb_name in sorted(os.listdir(opts.model_dir)):
        if not pdb_name.endswith(".pdb"):
            continue
        mdl_name = os.path.splitext(pdb_name)[0]
        # models without a metadata row are skipped silently
        if mdl_name not in metadata.index:
            continue
        md_row = metadata.loc[mdl_name]
        assert md_row["Associated PDB"] == pdb_name
        file_prfx = os.path.join(opts.model_dir, mdl_name)
        pdb_path = os.path.join(opts.model_dir, pdb_name)
        cif_path = os.path.join(opts.out_dir, f"{mdl_name}.{cifext}")
        if os.path.exists(cif_path):
            print(f"  {mdl_name} already done...")
            continue

        print(f"  translating {mdl_name}...")
        pdb_start = timer()
        plddt_path, is_special = _is_special(file_prfx)
        config_data = _get_config(is_special)
        mdlcf_json = _create_json(config_data)
        # NOTE: the title column name in the metadata has a trailing space
        mdlcf_json["mdl_title"] = md_row["_struct.title "]
        mdlcf_json["mdl_num"] = _get_mdl_num(md_row["ranking debugg model ID"])

        # gather data into JSON-like structure
        print("    preparing data...", end="")
        pstart = timer()
        ost_ent = _create_model_json(mdlcf_json, pdb_path, md_row)
        plddt_entity = io.LoadPDB(plddt_path) if is_special else None
        print(f" ({timer()-pstart:.2f}s)")

        _store_as_modelcif(mdlcf_json, ost_ent, plddt_entity, opts.out_dir,
                           mdl_name, opts.compress)
        print(f"  ... done with {mdl_name} ({timer()-pdb_start:.2f}s).")

        # check if result can be read and has expected seq.
        _verify_cif_output(cif_path, mdl_name,
                           mdlcf_json["target_entities"][0])

    print(f"... done with {opts.model_dir}.")


if __name__ == "__main__":
    _main()