#! /usr/local/bin/ost
# -*- coding: utf-8 -*-
"""Translate models for Miguel from PDB + extra data into ModelCIF."""

# EXAMPLES for running:
"""
ost translate2modelcif.py ./structures ./accessions.csv ./modelcif \
    --compress > script_out.txt
"""

import argparse
import datetime
import gzip
import os
import shutil
import sys
import zipfile
import pickle
import filecmp
import re
from timeit import default_timer as timer

import numpy as np
import requests
import ujson as json
import gemmi
import pandas as pd

import ihm
import ihm.citations

import modelcif
import modelcif.associated
import modelcif.data
import modelcif.dumper
import modelcif.model
import modelcif.protocol
import modelcif.qa_metric
import modelcif.reference

from ost import io, seq


def _parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )

    parser.add_argument(
        "model_dir",
        type=str,
        metavar="<MODEL DIR>",
        help="Directory with PDB files to be translated.",
    )
    parser.add_argument(
        "metadata_file",
        type=str,
        metavar="<METADATA FILE>",
        help="Path to CSV file with metadata.",
    )
    parser.add_argument(
        "out_dir",
        type=str,
        metavar="<OUTPUT DIR>",
        help="Path to directory to store results.",
    )
    parser.add_argument(
        "--compress",
        default=False,
        action="store_true",
        help="Compress ModelCIF file with gzip.",
    )

    opts = parser.parse_args()

    # check that model dir exists
    if opts.model_dir.endswith("/"):
        opts.model_dir = opts.model_dir[:-1]
    if not os.path.exists(opts.model_dir):
        _abort_msg(f"Model directory '{opts.model_dir}' does not exist.")
    if not os.path.isdir(opts.model_dir):
        _abort_msg(f"Path '{opts.model_dir}' does not point to a directory.")
    # check metadata_file
    if not os.path.exists(opts.metadata_file):
        _abort_msg(f"Metadata file '{opts.metadata_file}' does not exist.")
    if not os.path.isfile(opts.metadata_file):
        _abort_msg(f"Path '{opts.metadata_file}' does not point to a file.")
    # check out_dir
    if opts.out_dir.endswith("/"):
        opts.out_dir = opts.out_dir[:-1]
    if not os.path.exists(opts.out_dir):
        os.makedirs(opts.out_dir)
    if not os.path.isdir(opts.out_dir):
        _abort_msg(f"Path '{opts.out_dir}' does not point to a directory.")

    return opts


# pylint: disable=too-few-public-methods
class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT):
    """Predicted accuracy according to the CA-only lDDT in [0,100]"""

    name = "pLDDT"
    software = None


class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
    """Predicted accuracy according to the CA-only lDDT in [0,100]"""

    name = "pLDDT"
    software = None


# pylint: enable=too-few-public-methods


def _get_res_num(r, use_auth=False):
    """Get res. num. from auth. IDs if reading from mmCIF files."""
    if use_auth:
        return int(r.GetStringProp("pdb_auth_resnum"))
    else:
        return r.number.num


def _get_ch_name(ch, use_auth=False):
    """Get chain name from auth. IDs if reading from mmCIF files."""
    if use_auth:
        return ch.GetStringProp("pdb_auth_chain_name")
    else:
        return ch.name


class _OST2ModelCIF(modelcif.model.AbInitioModel):
    """Map OST entity elements to ihm.model"""

    def __init__(self, *args, **kwargs):
        """Initialise a model"""
        self.ost_entity = kwargs.pop("ost_entity")
        self.asym = kwargs.pop("asym")

        # use auth IDs for res. nums and chain names
        self.use_auth = False

        # fetch plddts per residue
        self.plddts = []
        for res in self.ost_entity.residues:
            b_factors = [a.b_factor for a in res.atoms]
            assert len(set(b_factors)) == 1  # must all be equal!
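            # AlphaFold writes the per-residue pLDDT into the B-factor column
            # of every atom, hence all atoms of a residue are expected to
            # carry the same value and we keep a single value per residue.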
            self.plddts.append(b_factors[0])

        super().__init__(*args, **kwargs)

    def get_atoms(self):
        # ToDo [internal]: Take B-factor out since it's not a B-factor?
        # NOTE: this assumes that _get_res_num maps residue to pos. in seqres
        #       within asym
        for atm in self.ost_entity.atoms:
            yield modelcif.model.Atom(
                asym_unit=self.asym[_get_ch_name(atm.chain, self.use_auth)],
                seq_id=_get_res_num(atm.residue, self.use_auth),
                atom_id=atm.name,
                type_symbol=atm.element,
                x=atm.pos[0],
                y=atm.pos[1],
                z=atm.pos[2],
                het=atm.is_hetatom,
                biso=atm.b_factor,
                occupancy=atm.occupancy,
            )

    def add_scores(self):
        """Add QA metrics from AF2 scores."""
        # global scores
        self.qa_metrics.append(
            _GlobalPLDDT(np.mean(self.plddts))
        )

        # local scores
        i = 0
        for chn_i in self.ost_entity.chains:
            ch_name = _get_ch_name(chn_i, self.use_auth)
            for res_i in chn_i.residues:
                # local pLDDT
                res_num = _get_res_num(res_i, self.use_auth)
                self.qa_metrics.append(
                    _LocalPLDDT(
                        self.asym[ch_name].residue(res_num),
                        self.plddts[i],
                    )
                )
                i += 1


def _abort_msg(msg, exit_code=1):
    """Write error message and exit with exit_code."""
    print(f"{msg}\nAborting.", file=sys.stderr)
    sys.exit(exit_code)


def _warn_msg(msg):
    """Write a warning message to stdout."""
    print(f"WARNING: {msg}")


def _check_file(file_path):
    """Make sure a file exists and is actually a file."""
    if not os.path.exists(file_path):
        _abort_msg(f"File not found: '{file_path}'.")
    if not os.path.isfile(file_path):
        _abort_msg(f"File path does not point to file: '{file_path}'.")


def _get_audit_authors():
    """Return the list of authors that produced this model."""
    return (
        "Correa Marrero, Miguel",
        "Capdevielle, Sylvain",
        "Huang, Weijie",
        "Al-Subhi, Ali M.",
        "Busscher, Marco",
        "Busscher-Lange, Jacqueline",
        "van der Wal, Froukje",
        "de Ridder, Dick",
        "van Dijk, Aalt D.J.",
        "Hogenhout, Saskia A.",
        "Immink, Richard",
    )


def _get_metadata(metadata_file):
    """Read CSV file with metadata and prepare for next steps."""
    metadata = pd.read_csv(metadata_file)
    # make sure names (== PDB file names) are unique
    assert len(set(metadata.Name)) == metadata.shape[0]
    # columns: [Accession, Name]
    return metadata


def _get_config():
    """Define AF setup."""
    description = "Model generated using AlphaFold (v2.2.0) producing 5 " \
                  "monomer models with 3 recycles each, without model " \
                  "relaxation, using templates (up to Aug. 4 2022), ranked " \
                  "by pLDDT, starting from an MSA with reduced_dbs setting."
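    # The keys below mirror the AlphaFold run options used for this data set;
    # they are attached verbatim as software parameters to the modelling
    # protocol step (see _get_protocol_steps_and_software).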
    af_config = {
        "model_preset": "monomer",
        "db_preset": "reduced_dbs",
        "max_template_date": "2022-08-04",
        "run_relax": False,
    }
    return {
        "af_config": af_config,
        "af_version": "2.2.0",
        "description": description,
        "use_templates": True,
        "use_small_bfd": True,
        "use_multimer": False,
    }


def _get_protocol_steps_and_software(config_data):
    """Create the list of protocol steps with software and parameters used."""
    protocol = []

    # modelling step
    step = {
        "method_type": "modeling",
        "name": None,
        "details": config_data["description"],
    }

    # get input data
    # Must refer to data already in the JSON, so we try keywords
    step["input"] = "target_sequences"

    # get output data
    # Must refer to existing data, so we try keywords
    step["output"] = "model"

    # get software
    if config_data["use_multimer"]:
        step["software"] = [{
            "name": "AlphaFold-Multimer",
            "classification": "model building",
            "description": "Structure prediction",
            "citation": ihm.Citation(
                pmid=None,
                title="Protein complex prediction with "
                + "AlphaFold-Multimer.",
                journal="bioRxiv",
                volume=None,
                page_range=None,
                year=2021,
                authors=[
                    "Evans, R.",
                    "O'Neill, M.",
                    "Pritzel, A.",
                    "Antropova, N.",
                    "Senior, A.",
                    "Green, T.",
                    "Zidek, A.",
                    "Bates, R.",
                    "Blackwell, S.",
                    "Yim, J.",
                    "Ronneberger, O.",
                    "Bodenstein, S.",
                    "Zielinski, M.",
                    "Bridgland, A.",
                    "Potapenko, A.",
                    "Cowie, A.",
                    "Tunyasuvunakool, K.",
                    "Jain, R.",
                    "Clancy, E.",
                    "Kohli, P.",
                    "Jumper, J.",
                    "Hassabis, D.",
                ],
                doi="10.1101/2021.10.04.463034",
            ),
            "location": "https://github.com/deepmind/alphafold",
            "type": "package",
            "version": config_data["af_version"],
        }]
    else:
        step["software"] = [{
            "name": "AlphaFold",
            "classification": "model building",
            "description": "Structure prediction",
            "citation": ihm.citations.alphafold2,
            "location": "https://github.com/deepmind/alphafold",
            "type": "package",
            "version": config_data["af_version"],
        }]
    step["software_parameters"] = config_data["af_config"]

    protocol.append(step)

    return protocol


def _get_title(prot_name):
    """Get a title for this modelling experiment."""
    return f"AlphaFold2 model of Candidatus Phytoplasma ({prot_name})"


def _get_model_details(prot_name):
    """Get the model description."""
    return f"The AlphaFold2 model of Candidatus Phytoplasma ({prot_name}) is " \
           f"part of a larger structural dataset. The complete dataset " \
           f"comprises AlphaFold2 models of 21 different phytoplasma " \
           f"effectors studied in a protein-protein interaction assay."


def _get_model_group_name():
    """Get a name for a model group."""
    return None


def _get_sequence(chn, use_auth=False):
    """Get the sequence out of an OST chain incl. '-' for gaps in resnums."""
    # initialise (add gaps if first is not at num. 1)
    lst_rn = _get_res_num(chn.residues[0], use_auth)
    idx = 1
    sqe = "-" * (lst_rn - 1) + chn.residues[0].one_letter_code

    for res in chn.residues[idx:]:
        lst_rn += 1
        while lst_rn != _get_res_num(res, use_auth):
            sqe += "-"
            lst_rn += 1
        sqe += res.one_letter_code

    return sqe


def _check_sequence(up_ac, sequence):
    """Verify that the sequence only contains standard one-letter codes."""
    ns_aa_pos = []  # positions of non-standard amino acids
    for i, res in enumerate(sequence):
        if res not in "ACDEFGHIKLMNPQRSTVWY":
            if res == "U":
                _warn_msg(
                    f"Selenocysteine found at position {i+1} of entry "
                    + f"'{up_ac}', this residue may be missing in the "
                    + "model."
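                    # Positions of selenocysteines are tolerated here and
                    # collected in ns_aa_pos; _cmp_sequences later substitutes
                    # a "U" at those (gap) positions of the model sequence
                    # before comparing against the UniProtKB sequence.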
                )
                ns_aa_pos.append(i)
                continue
            raise RuntimeError(
                "Non-standard aa found in UniProtKB sequence "
                + f"for entry '{up_ac}': {res}, position {i+1}"
            )
    return ns_aa_pos


def _get_n_parse_up_entry(up_ac, up_url):
    """Get data for a UniProtKB entry and parse it."""
    # This is a simple parser for the UniProtKB txt format; instead of
    # breaking it up into multiple functions, we just allow many branches &
    # statements here.
    # pylint: disable=too-many-branches,too-many-statements
    data = {}
    data["up_organism"] = ""
    data["up_sequence"] = ""
    data["up_ac"] = up_ac
    rspns = requests.get(up_url, timeout=180)
    for line in rspns.iter_lines(decode_unicode=True):
        if line.startswith("ID   "):
            sline = line.split()
            if len(sline) != 5:
                raise RuntimeError(f"Unusual UniProtKB ID line found:\n"
                                   f"'{line}'")
            data["up_id"] = sline[1]
        elif line.startswith("OX   NCBI_TaxID="):
            # Following strictly the UniProtKB format: 'OX   NCBI_TaxID=<ID>;'
            data["up_ncbi_taxid"] = line[len("OX   NCBI_TaxID=") : -1]
            data["up_ncbi_taxid"] = data["up_ncbi_taxid"].split("{")[0].strip()
        elif line.startswith("OS   "):
            if line[-1] == ".":
                data["up_organism"] += line[len("OS   ") : -1]
            else:
                data["up_organism"] += line[len("OS   ") : -1] + " "
        elif line.startswith("SQ   "):
            sline = line.split()
            if len(sline) != 8:
                raise RuntimeError(f"Unusual UniProtKB SQ line found:\n"
                                   f"'{line}'")
            data["up_seqlen"] = int(sline[2])
            data["up_crc64"] = sline[6]
        elif line.startswith("     "):
            sline = line.split()
            if len(sline) > 6:
                raise RuntimeError(
                    "Unusual UniProtKB sequence data line "
                    + f"found:\n'{line}'"
                )
            data["up_sequence"] += "".join(sline)
        elif line.startswith("RP   "):
            if "ISOFORM" in line.upper():
                raise RuntimeError(
                    f"First ISOFORM found for '{up_ac}', needs handling."
                )
        elif line.startswith("DT   "):
            # e.g. 'DT   03-OCT-2012, sequence version 2.'
            dt_flds = line[len("DT   ") :].split(", ")
            if dt_flds[1].upper().startswith("SEQUENCE VERSION "):
                data["up_last_mod"] = datetime.datetime.strptime(
                    dt_flds[0], "%d-%b-%Y"
                )
            elif dt_flds[1].upper().startswith("ENTRY VERSION "):
                data["up_entry_version"] = dt_flds[1][len("ENTRY VERSION ") :]
                if data["up_entry_version"][-1] == ".":
                    data["up_entry_version"] = data["up_entry_version"][:-1]
                data["up_entry_version"] = int(data["up_entry_version"])
        elif line.startswith("GN   Name="):
            data["up_gn"] = line[len("GN   Name=") :].split(";")[0]
            data["up_gn"] = data["up_gn"].split("{")[0].strip()

    # we have not seen isoforms in the data set yet, so we just set them to
    # None
    data["up_isoform"] = None
    # NOTE: no gene names in this set (use provided names instead)
    # if "up_gn" not in data:
    #     _warn_msg(
    #         f"No gene name found for UniProtKB entry '{up_ac}', using "
    #         + "UniProtKB AC instead."
    #     )
    #     data["up_gn"] = up_ac

    if "up_last_mod" not in data:
        raise RuntimeError(f"No sequence version found for UniProtKB entry "
                           f"'{up_ac}'.")
    if "up_crc64" not in data:
        raise RuntimeError(f"No CRC64 value found for UniProtKB entry "
                           f"'{up_ac}'.")
    if len(data["up_sequence"]) == 0:
        raise RuntimeError(f"No sequence found for UniProtKB entry '{up_ac}'.")
    # check that sequence length and CRC64 is correct
    if data["up_seqlen"] != len(data["up_sequence"]):
        raise RuntimeError(
            "Sequence length of SQ line and sequence data differ for "
            + f"UniProtKB entry '{up_ac}': {data['up_seqlen']} != "
            + f"{len(data['up_sequence'])}"
        )
    data["up_ns_aa"] = _check_sequence(data["up_ac"], data["up_sequence"])

    if "up_id" not in data:
        raise RuntimeError(f"No ID found for UniProtKB entry '{up_ac}'.")
    if "up_ncbi_taxid" not in data:
        raise RuntimeError(f"No NCBI taxonomy ID found for UniProtKB entry "
                           f"'{up_ac}'.")
    if len(data["up_organism"]) == 0:
        raise RuntimeError(f"No organism species found for UniProtKB entry "
                           f"'{up_ac}'.")
    return data


def _fetch_upkb_entry(up_ac):
    """Get a UniProtKB entry."""
    return _get_n_parse_up_entry(
        up_ac, f"https://rest.uniprot.org/uniprotkb/{up_ac}.txt"
    )


def _fetch_unisave_entry(up_ac, version):
    """Get a UniSave entry which, in contrast to a UniProtKB entry, allows us
    to specify a version."""
    return _get_n_parse_up_entry(
        up_ac,
        f"https://rest.uniprot.org/unisave/{up_ac}?format=txt&"
        + f"versions={version}",
    )


def _cmp_sequences(mdl, upkb, ns_aa_pos, deletion_mismatches=True):
    """Compare sequences while paying attention to non-standard amino acids.

    Returns list of mismatches (up_pos, olc_up, olc_mdl) and covered UP range.
    UniProt positions and ranges are 1-indexed.
    Negative "up_pos" relates to (1-indexed) pos. in model for added residues.
    If deletion_mismatches is True, res. in UP seq. but not in mdl within UP
    range are counted as mismatches (N/C-terminal ones are never counted).
    """
    # We add a U to the sequence when necessary. AF2 does not model it. The
    # PDB has selenocysteine as canonical aa, see PDB entry 7Z0T.
    for pos in ns_aa_pos:
        if mdl[pos] != "-":
            _abort_msg(
                f"Position {pos+1} of non-canonical amino acid should be "
                "a gap!"
            )
        mdl = mdl[0:pos] + "U" + mdl[pos + 1 :]

    if mdl == upkb:
        mismatches = []
        up_range = (1, len(mdl))
    else:
        # align and report mismatches
        up_seq = seq.CreateSequence("UP", upkb)
        ch_seq = seq.CreateSequence("CH", mdl)
        aln = seq.alg.SemiGlobalAlign(up_seq, ch_seq, seq.alg.BLOSUM62)[0]
        # get range and mismatches
        aligned_indices = [i for i, c in enumerate(aln)
                           if c[0] != '-' and c[1] != '-']
        up_range = (
            aln.GetResidueIndex(0, aligned_indices[0]) + 1,
            aln.GetResidueIndex(0, aligned_indices[-1]) + 1,
        )
        mismatches = []
        for idx, (olc_up, olc_mdl) in enumerate(aln):
            if olc_up != olc_mdl:
                # mismatches are either extra res. in mdl/UP or mismatch
                if olc_up == '-':
                    up_pos = -(aln.GetResidueIndex(1, idx) + 1)
                else:
                    up_pos = aln.GetResidueIndex(0, idx) + 1
                # ignore if out of UP range
                if up_pos < up_range[0] or up_pos > up_range[1]:
                    continue
                # optionally ignore extra res. in UP also otherwise
                if not deletion_mismatches and olc_mdl == '-':
                    continue
                mismatches.append((up_pos, olc_up, olc_mdl))

    return mismatches, up_range


def _get_upkb_for_sequence(sqe, up_ac, up_version=None):
    """Get UniProtKB entry data for given sequence.

    If up_version is given, we start from historical data in UniSave.
    Returns the best possible hit (i.e. fewest mismatches between sqe and the
    UP sequence) as a dict
    with parsed UP data (see _get_n_parse_up_entry), with the covered range
    and mismatches (see _cmp_sequences) added as "up_range" and "mismatches".
    """
    if up_version is None:
        up_data = _fetch_upkb_entry(up_ac)
    else:
        up_data = _fetch_unisave_entry(up_ac, up_version)
    min_up_data = None
    while True:
        mismatches, up_range = _cmp_sequences(sqe, up_data["up_sequence"],
                                              up_data["up_ns_aa"])
        if min_up_data is None or \
           len(mismatches) < len(min_up_data["mismatches"]):
            min_up_data = up_data
            min_up_data["mismatches"] = mismatches
            min_up_data["up_range"] = up_range
        if len(mismatches) == 0:
            # found hit; done
            break
        # fetch next one (skip if exceptions happen)
        next_v = up_data["up_entry_version"] - 1
        while next_v > 0:
            try:
                # note: can fail to parse very old UP versions...
                up_data = _fetch_unisave_entry(up_ac, next_v)
                # can move on if no exception happened
                break
            except RuntimeError as ex:
                # _warn_msg(f"Error in parsing v{next_v} of {up_ac}:\n{ex}")
                # try next one
                next_v -= 1
        if next_v == 0:
            # warn user about failure to find match and abort
            min_mismatches = min_up_data["mismatches"]
            msg = f"Sequences not equal from file: {sqe}, from UniProtKB: " \
                  f"{min_up_data['up_sequence']} ({up_ac}), checked entire " \
                  f"entry history and best match had following mismatches " \
                  f"in v{min_up_data['up_entry_version']} (range " \
                  f"{min_up_data['up_range']}): {min_up_data['mismatches']}."
            _warn_msg(msg)
            # raise RuntimeError(msg)
            break
    return min_up_data


def _get_entities(pdb_file, up_ac, prot_name):
    """Gather data for the mmCIF (target) entities."""
    _check_file(pdb_file)
    ost_ent = io.LoadPDB(pdb_file)
    if ost_ent.chain_count != 1:
        raise RuntimeError(f"Unexpected oligomer in {pdb_file}")
    chn = ost_ent.chains[0]
    sqe_no_gaps = "".join([res.one_letter_code for res in chn.residues])
    sqe_gaps = _get_sequence(chn)
    if sqe_no_gaps != sqe_gaps:
        raise RuntimeError(f"Sequence in {pdb_file} has gaps for chain "
                           f"{chn.name}")

    # map to entities
    # special case: A0A1B3JKP4 is obsolete (there may be better ways to catch
    # that)
    if up_ac == "A0A1B3JKP4":
        up_version = 10
    else:
        up_version = None
    upkb = _get_upkb_for_sequence(sqe_no_gaps, up_ac, up_version)
    description = f"Candidatus Phytoplasma ({prot_name})"
    if len(upkb["mismatches"]) != 0:
        description += (
            " (sequence mismatches due to sequencing from experimental assay)"
        )
    cif_ent = {
        "seqres": sqe_no_gaps,
        "pdb_sequence": sqe_no_gaps,
        "pdb_chain_id": [_get_ch_name(chn, False)],
        "prot_name": prot_name,
        "description": description,
    }
    cif_ent.update(upkb)
    return [cif_ent], ost_ent


def _get_modelcif_entities(target_ents, asym_units, system):
    """Create ModelCIF entities and asymmetric units."""
    for cif_ent in target_ents:
        mdlcif_ent = modelcif.Entity(
            # NOTE: sequence here defines residues in model!
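            # The entity sequence is the gap-free SEQRES of the model; the
            # UniProt cross-reference below records which stretch of the
            # UniProtKB entry it covers (up_range) plus taxonomy, sequence
            # version date and CRC64 checksum.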
            cif_ent["seqres"],
            description=cif_ent["description"],
            source=ihm.source.Natural(
                ncbi_taxonomy_id=cif_ent["up_ncbi_taxid"],
                scientific_name=cif_ent["up_organism"],
            ),
            references=[
                modelcif.reference.UniProt(
                    cif_ent["up_id"],
                    cif_ent["up_ac"],
                    align_begin=cif_ent["up_range"][0],
                    align_end=cif_ent["up_range"][1],
                    isoform=cif_ent["up_isoform"],
                    ncbi_taxonomy_id=cif_ent["up_ncbi_taxid"],
                    organism_scientific=cif_ent["up_organism"],
                    sequence_version_date=cif_ent["up_last_mod"],
                    sequence_crc64=cif_ent["up_crc64"],
                )
            ],
        )
        # NOTE: this assigns (potentially new) alphabetic chain names
        for pdb_chain_id in cif_ent["pdb_chain_id"]:
            asym_units[pdb_chain_id] = modelcif.AsymUnit(
                mdlcif_ent, strand_id=pdb_chain_id,
            )
        system.target_entities.append(mdlcif_ent)


def _assemble_modelcif_software(soft_dict):
    """Create a modelcif.Software instance from dictionary."""
    return modelcif.Software(
        soft_dict["name"],
        soft_dict["classification"],
        soft_dict["description"],
        soft_dict["location"],
        soft_dict["type"],
        soft_dict["version"],
        citation=soft_dict["citation"],
    )


def _get_sequence_dbs(config_data):
    """Get AF seq. DBs."""
    # hard coded UniProt release (see https://www.uniprot.org/release-notes)
    # (TO BE UPDATED FOR EVERY DEPOSITION!)
    up_version = "2021_04"
    up_rel_date = datetime.datetime(2021, 11, 17)
    # fill list of DBs
    seq_dbs = []
    if config_data["use_small_bfd"]:
        seq_dbs.append(modelcif.ReferenceDatabase(
            "Reduced BFD",
            "https://storage.googleapis.com/alphafold-databases/"
            + "reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz"
        ))
    else:
        seq_dbs.append(modelcif.ReferenceDatabase(
            "BFD",
            "https://storage.googleapis.com/alphafold-databases/"
            + "casp14_versions/"
            + "bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz",
            version="6a634dc6eb105c2e9b4cba7bbae93412",
        ))
    if config_data["af_version"] < "2.3.0":
        seq_dbs.append(modelcif.ReferenceDatabase(
            "MGnify",
            "https://storage.googleapis.com/alphafold-databases/"
            + "casp14_versions/mgy_clusters_2018_12.fa.gz",
            version="2018_12",
            release_date=datetime.datetime(2018, 12, 6),
        ))
        seq_dbs.append(modelcif.ReferenceDatabase(
            "Uniclust30",
            "https://storage.googleapis.com/alphafold-databases/"
            + "casp14_versions/uniclust30_2018_08_hhsuite.tar.gz",
            version="2018_08",
            release_date=None,
        ))
    else:
        # NOTE: release date according to
        # https://ftp.ebi.ac.uk/pub/databases/metagenomics/peptide_database/2022_05/
        seq_dbs.append(modelcif.ReferenceDatabase(
            "MGnify",
            "https://storage.googleapis.com/alphafold-databases/"
            + "v2.3/mgy_clusters_2022_05.fa.gz",
            version="2022_05",
            release_date=datetime.datetime(2022, 5, 6),
        ))
        seq_dbs.append(modelcif.ReferenceDatabase(
            "UniRef30",
            "https://storage.googleapis.com/alphafold-databases/"
            + "v2.3/UniRef30_2021_03.tar.gz",
            version="2021_03",
            release_date=None,
        ))
    if config_data["use_multimer"]:
        seq_dbs.append(modelcif.ReferenceDatabase(
            "TrEMBL",
            "ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/"
            + "knowledgebase/complete/uniprot_trembl.fasta.gz",
            version=up_version,
            release_date=up_rel_date,
        ))
        seq_dbs.append(modelcif.ReferenceDatabase(
            "Swiss-Prot",
            "ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/"
            + "knowledgebase/complete/uniprot_sprot.fasta.gz",
            version=up_version,
            release_date=up_rel_date,
        ))
    seq_dbs.append(modelcif.ReferenceDatabase(
        "UniRef90",
        "ftp://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref90/"
        + "uniref90.fasta.gz",
        version=up_version,
        release_date=up_rel_date,
    ))
    if config_data["use_templates"]:
        seq_dbs.append(modelcif.ReferenceDatabase(
            "PDB70",
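            # Template search database; only listed because use_templates is
            # set in the config, pinned to the April 2020 PDB70 snapshot
            # (cf. release_date below).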
"http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/" + "hhsuite_dbs/old-releases/pdb70_from_mmcif_200401.tar.gz", release_date=datetime.datetime(2020, 4, 1) )) return seq_dbs def _get_modelcif_protocol_software(js_step): """Assemble software entries for a ModelCIF protocol step.""" if js_step["software"]: if len(js_step["software"]) == 1: sftwre = _assemble_modelcif_software(js_step["software"][0]) else: sftwre = [] for sft in js_step["software"]: sftwre.append(_assemble_modelcif_software(sft)) sftwre = modelcif.SoftwareGroup(elements=sftwre) if js_step["software_parameters"]: params = [] for key, val in js_step["software_parameters"].items(): params.append(modelcif.SoftwareParameter(key, val)) if isinstance(sftwre, modelcif.SoftwareGroup): sftwre.parameters = params else: sftwre = modelcif.SoftwareGroup( elements=(sftwre,), parameters=params ) return sftwre return None def _get_modelcif_protocol_input(js_step, target_entities, ref_dbs, model): """Assemble input data for a ModelCIF protocol step.""" if js_step["input"] == "target_sequences": input_data = modelcif.data.DataGroup(target_entities) input_data.extend(ref_dbs) elif js_step["input"] == "model": input_data = model else: raise RuntimeError(f"Unknown protocol input: '{js_step['input']}'") return input_data def _get_modelcif_protocol_output(js_step, model): """Assemble output data for a ModelCIF protocol step.""" if js_step["output"] == "model": output_data = model else: raise RuntimeError(f"Unknown protocol output: '{js_step['output']}'") return output_data def _get_modelcif_protocol(protocol_steps, target_entities, model, ref_dbs): """Create the protocol for the ModelCIF file.""" protocol = modelcif.protocol.Protocol() for js_step in protocol_steps: sftwre = _get_modelcif_protocol_software(js_step) input_data = _get_modelcif_protocol_input( js_step, target_entities, ref_dbs, model ) output_data = _get_modelcif_protocol_output(js_step, model) protocol.steps.append( modelcif.protocol.Step( input_data=input_data, output_data=output_data, name=js_step["name"], details=js_step["details"], software=sftwre, ) ) protocol.steps[-1].method_type = js_step["method_type"] return protocol def _compress_cif_file(cif_file): """Compress cif file and delete original.""" with open(cif_file, 'rb') as f_in: with gzip.open(cif_file + '.gz', 'wb') as f_out: shutil.copyfileobj(f_in, f_out) os.remove(cif_file) def _store_as_modelcif(data_json, ost_ent, out_dir, mdl_name, compress): """Mix all the data into a ModelCIF file.""" print(" generating ModelCIF objects...", end="") pstart = timer() # create system to gather all the data system = modelcif.System( title=data_json["title"], id=data_json["mdl_id"].upper(), model_details=data_json["model_details"], ) # create an asymmetric unit and an entity per target sequence asym_units = {} _get_modelcif_entities( data_json["target_entities"], asym_units, system ) # audit_authors system.authors.extend(data_json["audit_authors"]) # set up the model to produce coordinates model = _OST2ModelCIF( assembly=modelcif.Assembly(asym_units.values()), asym=asym_units, ost_entity=ost_ent, name="Top ranked model", ) print(f" ({timer()-pstart:.2f}s)") print(" processing QA scores...", end="", flush=True) pstart = timer() model.add_scores() print(f" ({timer()-pstart:.2f}s)") model_group = modelcif.model.ModelGroup( [model], name=data_json["model_group_name"] ) system.model_groups.append(model_group) ref_dbs = _get_sequence_dbs(data_json["config_data"]) protocol = _get_modelcif_protocol( data_json["protocol"], 
        system.target_entities, model, ref_dbs
    )
    system.protocols.append(protocol)

    # write modelcif System to file (NOTE: no PAE here!)
    print(" write to disk...", end="", flush=True)
    pstart = timer()
    out_path = os.path.join(out_dir, f"{mdl_name}.cif")
    with open(out_path, "w", encoding="ascii") as mmcif_fh:
        modelcif.dumper.write(mmcif_fh, [system])
    if compress:
        _compress_cif_file(out_path)
    print(f" ({timer()-pstart:.2f}s)")


def _translate2modelcif(up_ac, prot_name, opts):
    """Convert a model with its accompanying data to ModelCIF."""
    mdl_id = prot_name
    # skip if done already (disabled here due to info to be returned)
    if opts.compress:
        cifext = "cif.gz"
    else:
        cifext = "cif"
    mdl_path = os.path.join(opts.out_dir, f"{mdl_id}.{cifext}")
    # if os.path.exists(mdl_path):
    #     print(f"  {mdl_id} already done...")
    #     return

    # go for it...
    print(f" translating {mdl_id}...")
    pdb_start = timer()

    # gather data into JSON-like structure
    print(" preparing data...", end="")
    pstart = timer()

    # now we can fill all data
    config_data = _get_config()
    mdlcf_json = {}
    mdlcf_json["audit_authors"] = _get_audit_authors()
    mdlcf_json["protocol"] = _get_protocol_steps_and_software(config_data)
    mdlcf_json["config_data"] = config_data
    mdlcf_json["mdl_id"] = mdl_id

    # process coordinates
    pdb_file = os.path.join(opts.model_dir, f"{prot_name}.pdb")
    target_entities, ost_ent = _get_entities(pdb_file, up_ac, prot_name)
    mdlcf_json["target_entities"] = target_entities

    # fill annotations
    mdlcf_json["title"] = _get_title(prot_name)
    mdlcf_json["model_details"] = _get_model_details(prot_name)
    mdlcf_json["model_group_name"] = _get_model_group_name()
    print(f" ({timer()-pstart:.2f}s)")

    # save ModelCIF
    _store_as_modelcif(mdlcf_json, ost_ent, opts.out_dir, mdl_id,
                       opts.compress)

    # check if result can be read and has expected seq.
    ent = io.LoadMMCIF(mdl_path)
    exp_seqs = [trg_ent["pdb_sequence"]
                for trg_ent in mdlcf_json["target_entities"]]
    assert ent.chain_count == len(exp_seqs), f"Bad chain count {mdl_id}"
    ent_seq = "".join(res.one_letter_code for res in ent.residues)
    assert ent_seq == "".join(exp_seqs), f"Bad seq. {mdl_id}"

    print(f" ... done with {mdl_id} ({timer()-pdb_start:.2f}s).")


def _main():
    """Run as script."""
    opts = _parse_args()

    # parse/fetch global data
    metadata = _get_metadata(opts.metadata_file)

    # get on with models
    print(f"Working on {opts.metadata_file}...")

    # iterate over models
    for _, mrow in metadata.iterrows():
        up_ac = mrow.Accession.strip()
        prot_name = mrow.Name.strip()
        _translate2modelcif(up_ac, prot_name, opts)

    print(f"... done with {opts.metadata_file}.")


if __name__ == "__main__":
    _main()
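# The metadata CSV is expected to provide one row per model with the columns
# "Accession" (a UniProtKB AC) and "Name" (matching the PDB file basename in
# <MODEL DIR>); the values below are made up for illustration only:
#
#   Accession,Name
#   A0A0A0AAA0,effector_01
#   A0A0A0AAB1,effector_02
#
# A row with Name "effector_01" would be read from
# "<MODEL DIR>/effector_01.pdb" and written to "<OUTPUT DIR>/effector_01.cif"
# (or ".cif.gz" with --compress).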