#! /usr/local/bin/ost
"""Translate models from Edward from PDB + extra data into ModelCIF."""

# EXAMPLES for running:
"""
GT test setup:
ost scripts/translate2modelcif.py "./InputFiles/sample_files" \
    "./InputFiles/ASFV-G_proteome_accessions.csv" \
    --out_dir="./modelcif"
For full translation (takes ~6min on laptop):
ost scripts/translate2modelcif.py "./InputFiles/AlphaFold-RENAME" \
    "./InputFiles/ASFV-G_proteome_accessions.csv" \
    --out_dir="./modelcif" > script_out.txt
"""

import argparse
import datetime
import os
import sys
import gzip, shutil, zipfile

from timeit import default_timer as timer
import numpy as np
import requests
import ujson as json
import pandas as pd
import xml.dom.minidom

import ihm
import ihm.citations
import modelcif
import modelcif.associated
import modelcif.dumper
import modelcif.model
import modelcif.protocol
import modelcif.reference

from ost import io


def _parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )

    parser.add_argument(
        "model_dir",
        type=str,
        metavar="<MODEL DIR>",
        help="Directory with model(s) to be translated.",
    )
    parser.add_argument(
        "metadata_file",
        type=str,
        metavar="<METADATA FILE>",
        help="Path to CSV file with metadata.",
    )
    parser.add_argument(
        "--out_dir",
        type=str,
        metavar="<OUTPUT DIR>",
        default="",
        help="Path to separate path to store results " \
             "(model_dir used, if none given).",
    )
    parser.add_argument(
        "--compress",
        default=False,
        action="store_true",
        help="Compress ModelCIF file with gzip " \
             "(note that QA file is zipped either way).",
    )

    opts = parser.parse_args()

    # check that model dir exists
    if opts.model_dir.endswith("/"):
        opts.model_dir = opts.model_dir[:-1]
    if not os.path.exists(opts.model_dir):
        _abort_msg(f"Model directory '{opts.model_dir}' does not exist.")
    if not os.path.isdir(opts.model_dir):
        _abort_msg(f"Path '{opts.model_dir}' does not point to a directory.")
    # check metadata_file
    if not os.path.exists(opts.metadata_file):
        _abort_msg(f"Metadata file '{opts.metadata_file}' does not exist.")
    if not os.path.isfile(opts.metadata_file):
        _abort_msg(f"Path '{opts.metadata_file}' does not point to a file.")
    # check out_dir
    if not opts.out_dir:
        opts.out_dir = opts.model_dir
    else:
        if not os.path.exists(opts.out_dir):
            _abort_msg(f"Output directory '{opts.out_dir}' does not exist.")
        if not os.path.isdir(opts.out_dir):
            _abort_msg(f"Path '{opts.out_dir}' does not point to a directory.")

    return opts


# pylint: disable=too-few-public-methods
class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT):
    """Predicted accuracy according to the CA-only lDDT in [0,100]"""
    name = "pLDDT"
    software = None


class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
    """Predicted accuracy according to the CA-only lDDT in [0,100]"""
    name = "pLDDT"
    software = None

class _NcbiTrgRef(modelcif.reference.TargetReference):
    """NCBI as target reference."""
    name = "NCBI"
    other_details = None
# pylint: enable=too-few-public-methods


class _OST2ModelCIF(modelcif.model.AbInitioModel):
    """Map OST entity elements to ihm.model"""

    def __init__(self, *args, **kwargs):
        """Initialise a model"""
        self.ost_entity = kwargs.pop("ost_entity")
        self.asym = kwargs.pop("asym")

        # fetch plddts per atom and per residue
        self.plddt_entity = kwargs.pop("plddt_entity")
        if self.plddt_entity:
            bf_ent = self.plddt_entity
        else:
            bf_ent = self.ost_entity
        self.plddts = []
        self.atm_bfactors = {}
        for a in bf_ent.atoms:
            res_idx = a.residue.number.num - 1
            assert res_idx <= len(self.plddts)
            if res_idx < len(self.plddts):
                assert a.b_factor == self.plddts[res_idx]
            else:
                self.plddts.append(a.b_factor)
            self.atm_bfactors[a.qualified_name] = a.b_factor

        super().__init__(*args, **kwargs)

    def get_atoms(self):
        # ToDo [internal]: Take B-factor out since it's not a B-factor?
        for atm in self.ost_entity.atoms:
            if self.plddt_entity:
                b_factor = self.atm_bfactors[atm.qualified_name]
            else:
                b_factor = atm.b_factor
            yield modelcif.model.Atom(
                asym_unit=self.asym[atm.chain.name],
                seq_id=atm.residue.number.num,
                atom_id=atm.name,
                type_symbol=atm.element,
                x=atm.pos[0],
                y=atm.pos[1],
                z=atm.pos[2],
                het=atm.is_hetatom,
                biso=b_factor,
                occupancy=atm.occupancy,
            )

    def add_scores(self):
        """Add QA metrics from AF2 scores."""
        # global scores
        self.qa_metrics.append(
            _GlobalPLDDT(np.mean(self.plddts))
        )

        # local scores
        i = 0
        for chn_i in self.ost_entity.chains:
            for res_i in chn_i.residues:
                # local pLDDT
                self.qa_metrics.append(
                    _LocalPLDDT(
                        self.asym[chn_i.name].residue(res_i.number.num),
                        self.plddts[i],
                    )
                )
                i += 1


def _abort_msg(msg, exit_code=1):
    """Write error message and exit with exit_code."""
    print(f"{msg}\nAborting.", file=sys.stderr)
    sys.exit(exit_code)


def _check_file(file_path):
    """Make sure a file exists and is actually a file."""
    if not os.path.exists(file_path):
        _abort_msg(f"File not found: '{file_path}'.")
    if not os.path.isfile(file_path):
        _abort_msg(f"File path does not point to file: '{file_path}'.")


def _get_audit_authors():
    """Return the list of authors that produced this model."""
    return (
        "Spinard, Edward",
        "Azzinaro, Paul",
        "Rai, Ayushi",
        "Espinoza, Nallely",
        "Ramirez-Medina, Elizabeth",
        "Valladares, Alyssa",
        "Borca, Manuel",
        "Gladue, Douglas"
    )


def _get_metadata(metadata_file):
    """Read csv file with metedata and prepare for next steps."""
    metadata = pd.read_csv(metadata_file)
    # make sure protein and PDB names are unique
    assert len(set(metadata.Protein)) == metadata.shape[0]
    assert len(set(metadata["Associated PDB"])) == metadata.shape[0]
    return metadata.set_index("Protein")
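
# Columns expected in the metadata CSV (as used throughout this script):
# "Protein", "Associated PDB", "UniProt_ID", "NCBI_Accession", "notes",
# "_struct.title " (note the trailing space) and "ranking debugg model ID".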


def _get_config(is_special=False):
    """Define AF setup (special case: QP509L was run with other settings)."""
    if is_special:
        description = "Model generated using the AlphaFold (v2.1.0) " \
                      "colab notebook producing 5 models with 3 recycles " \
                      "each, without model relaxation, without templates, " \
                      "ranked by pLDDT, starting from an MSA with " \
                      "reduced_dbs setting."
        description2 = "The unrelaxed model was minimized and subjected to " \
                       "molecular dynamics for 1 ns using GROMACS."
        descriptions = [description, description2]
        af_config = {
            "db_preset": "reduced_dbs",
            "run_relax": False
        }
    else:
        description = "Model generated using AlphaFold (v2.2.0) " \
                      "producing 5 models with 3 recycles each, with AMBER " \
                      "relaxation, using templates, ranked by pLDDT, " \
                      "starting from an MSA with full_dbs setting."
        descriptions = [description]
        af_config = {
            "model_preset": "monomer",
            "db_preset": "full_dbs",
            "use_gpu_relax": True,
            "max_template_date": "2020-05-14",
        }
    return {
        "af_config": af_config,
        "af_version": "2.1.0" if is_special else "2.2.0",
        "descriptions": descriptions,
        "has_gromacs_step": is_special,
        "use_templates": not is_special,
        "use_small_bfd": is_special
    }


def _get_protocol_steps_and_software(config_data):
    """Create the list of protocol steps with software and parameters used."""
    protocol = []
    
    # modelling step
    step = {
        "method_type": "modeling",
        "name": None,
        "details": config_data["descriptions"][0],
    }
    # input/output must refer to data that _get_modelcif_protocol can resolve,
    # so we use the keywords handled there
    step["input"] = "target_sequences"
    step["output"] = "model"
    # get software
    step["software"] = [
        {
            "name": "AlphaFold",
            "classification": "model building",
            "description": "Structure prediction",
            "citation": ihm.citations.alphafold2,
            "location": "https://github.com/deepmind/alphafold",
            "type": "package",
            "version": config_data["af_version"],
        }]
    step["software_parameters"] = config_data["af_config"]
    protocol.append(step)

    # GROMACS step
    if config_data["has_gromacs_step"]:
        step = {
            "method_type": "model refinement",
            "name": None,
            "details": config_data["descriptions"][1],
        }
        step["input"] = "model"
        step["output"] = "model"
        step["software"] = [
        {
            "name": "GROMACS",
            "classification": "refinement",
            "description": "Model relaxation",
            "citation": ihm.Citation(
                pmid=None,
                title="GROMACS: High performance molecular simulations "
                + "through multi-level parallelism from laptops to "
                + "supercomputers.",
                journal="SoftwareX",
                volume=1,
                page_range=(19, 25),
                year=2015,
                authors=[
                    "Abraham, M.J.",
                    "Murtola, T.",
                    "Schulz, R.",
                    "Pall, S.",
                    "Smith, J.C.",
                    "Hess, B.",
                    "Lindahl, E."
                ],
                doi="10.1016/j.softx.2015.06.001",
            ),
            "location": "https://www.gromacs.org",
            "type": "package",
            "version": None,
        }]
        step["software_parameters"] = {}
        protocol.append(step)

    return protocol


def _get_title(mdl_title):
    """Get a title for this modelling experiment."""
    return f"AlphaFold model for {mdl_title}"


def _get_model_details(mdl_descs, mdl_notes):
    """Get the model description."""
    mdl_desc = '\n'.join(mdl_descs)
    if isinstance(mdl_notes, str):
        # fix typos...
        mdl_notes = mdl_notes.replace("hypthetical", "hypothetical") \
                             .replace("Uniport", "UniProt") \
                             .replace("Uniprot", "UniProt") \
                             .replace("Mislabled", "mislabeled")
        #
        return f"{mdl_desc}\n\nNote: {mdl_notes}."
    else:
        return mdl_desc


def _get_model_group_name():
    """Get a name for a model group."""
    return None


def _get_sequence(chn):
    """Get the sequence out of an OST chain."""
    # initialise with the first residue; if the chain does not start at
    # residue number 1, pad from position 1 onwards with "-"
    lst_rn = chn.residues[0].number.num
    idx = 1
    sqe = chn.residues[0].one_letter_code
    if lst_rn != 1:
        sqe = "-"
        idx = 0
        lst_rn = 1

    for res in chn.residues[idx:]:
        lst_rn += 1
        while lst_rn != res.number.num:
            sqe += "-"
            lst_rn += 1
        sqe += res.one_letter_code

    return sqe
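
# Illustrative example (hypothetical chain, not from the data set): for a chain
# with residues numbered 1=M, 2=K, 3=T and 5=L (residue 4 unmodelled),
# _get_sequence returns "MKT-L", i.e. missing residue numbers are padded with
# "-" so that string position i corresponds to residue number i+1.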


def _check_sequence(up_ac, sequence):
    """Verify sequence to only contain standard olc."""
    for res in sequence:
        if res not in "ACDEFGHIKLMNPQRSTVWY":
            raise RuntimeError(
                "Non-standard aa found in UniProtKB sequence "
                + f"for entry '{up_ac}': {res}"
            )


def _fetch_upkb_entry(up_ac):
    """Fetch data for an UniProtKB entry."""
    # This is a simple parser for the UniProtKB txt format. Instead of breaking
    # it up into multiple functions, we simply allow many branches & statements
    # here.
    # pylint: disable=too-many-branches,too-many-statements
    data = {}
    data["up_organism"] = ""
    data["up_sequence"] = ""
    data["up_ac"] = up_ac
    rspns = requests.get(f"https://www.uniprot.org/uniprot/{up_ac}.txt")
    for line in rspns.iter_lines(decode_unicode=True):
        if line.startswith("ID   "):
            sline = line.split()
            if len(sline) != 5:
                _abort_msg(f"Unusual UniProtKB ID line found:\n'{line}'")
            data["up_id"] = sline[1]
        elif line.startswith("OX   NCBI_TaxID="):
            # Following strictly the UniProtKB format: 'OX   NCBI_TaxID=<ID>;'
            data["up_ncbi_taxid"] = line[len("OX   NCBI_TaxID=") : -1]
            data["up_ncbi_taxid"] = data["up_ncbi_taxid"].split("{")[0].strip()
        elif line.startswith("OS   "):
            if line[-1] == ".":
                data["up_organism"] += line[len("OS   ") : -1]
            else:
                data["up_organism"] += line[len("OS   ") : -1] + " "
        elif line.startswith("SQ   "):
            sline = line.split()
            if len(sline) != 8:
                _abort_msg(f"Unusual UniProtKB SQ line found:\n'{line}'")
            data["up_seqlen"] = int(sline[2])
            data["up_crc64"] = sline[6]
        elif line.startswith("     "):
            sline = line.split()
            if len(sline) > 6:
                _abort_msg(
                    "Unusual UniProtKB sequence data line "
                    + f"found:\n'{line}'"
                )
            data["up_sequence"] += "".join(sline)
        elif line.startswith("RP   "):
            if "ISOFORM" in line.upper():
                RuntimeError(
                    f"First ISOFORM found for '{up_ac}', needs " + "handling."
                )
        elif line.startswith("DT   "):
            # e.g. "DT   03-OCT-2012, sequence version 1."
            dt_flds = line[len("DT   ") :].split(", ")
            if dt_flds[1].upper().startswith("SEQUENCE VERSION "):
                data["up_last_mod"] = datetime.datetime.strptime(
                    dt_flds[0], "%d-%b-%Y"
                )
        elif line.startswith("GN   Name="):
            data["up_gn"] = line[len("GN   Name=") :].split(";")[0]
            data["up_gn"] = data["up_gn"].split("{")[0].strip()

    # we have not seen isoforms in this data set yet, so we just set it to None
    data["up_isoform"] = None

    if "up_gn" not in data:
        _abort_msg(f"No gene name found for UniProtKB entry '{up_ac}'.")
    if "up_last_mod" not in data:
        _abort_msg(f"No sequence version found for UniProtKB entry '{up_ac}'.")
    if "up_crc64" not in data:
        _abort_msg(f"No CRC64 value found for UniProtKB entry '{up_ac}'.")
    if len(data["up_sequence"]) == 0:
        _abort_msg(f"No sequence found for UniProtKB entry '{up_ac}'.")
    # check that sequence length and CRC64 is correct
    if data["up_seqlen"] != len(data["up_sequence"]):
        _abort_msg(
            "Sequence length of SQ line and sequence data differ for "
            + f"UniProtKB entry '{up_ac}': {data['up_seqlen']} != "
            + f"{len(data['up_sequence'])}"
        )
    _check_sequence(data["up_ac"], data["up_sequence"])

    if "up_id" not in data:
        _abort_msg(f"No ID found for UniProtKB entry '{up_ac}'.")
    if "up_ncbi_taxid" not in data:
        _abort_msg(f"No NCBI taxonomy ID found for UniProtKB entry '{up_ac}'.")
    if len(data["up_organism"]) == 0:
        _abort_msg(f"No organism species found for UniProtKB entry '{up_ac}'.")

    return data
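
# For reference, the UniProtKB txt-format lines parsed above look roughly like
# this (made-up example values, not a real entry):
#   ID   EXAMPLE_ASFV            Reviewed;         256 AA.
#   OX   NCBI_TaxID=10497;
#   SQ   SEQUENCE   256 AA;  29735 MW;  B4840739BF7D4121 CRC64;
# The sequence itself follows on lines starting with five spaces.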


def _check_subset(s1, s2):
    # check if s2 is uniquely contained in s1
    # (and if so, returns values for seq_db_align_begin & seq_db_align_end)
    if s1.count(s2) == 1:
        align_begin = s1.find(s2) + 1
        align_end = align_begin + len(s2) - 1
        return align_begin, align_end
    else:
        return None
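
# Illustrative examples (hypothetical sequences) of how _check_subset derives
# the 1-based, inclusive seq_db_align_begin/_end range:
#   _check_subset("MKTAYIAKQR", "TAYI")  ->  (3, 6)
#   _check_subset("MKTAYIAKQR", "WWWW")  ->  None  # not contained
#   _check_subset("AAAA", "AA")          ->  None  # contained more than once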


def _get_ncbi_sequence(ncbi_ac):
    """Fetch OST sequence object from NCBI web service."""
    # src: https://www.ncbi.nlm.nih.gov/books/NBK25500/#_chapter1_Downloading_Full_Records_
    rspns = requests.get(f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/" \
                         f"efetch.fcgi?db=protein&id={ncbi_ac}" \
                         f"&rettype=fasta&retmode=text")
    return io.SequenceFromString(rspns.text, "fasta")


def _get_ncbi_info(ncbi_ac):
    """Fetch dict with info from NCBI web service."""
    # src: https://www.ncbi.nlm.nih.gov/books/NBK25499/#chapter4.ESummary
    rspns = requests.get(f"https://eutils.ncbi.nlm.nih.gov/entrez/eutils/" \
                         f"esummary.fcgi?db=protein&id={ncbi_ac}")
    dom = xml.dom.minidom.parseString(rspns.text)
    docsums = dom.getElementsByTagName("DocSum")
    assert len(docsums) == 1
    docsum = docsums[0]
    ncbi_dict = {}
    for cn in docsum.childNodes:
        if cn.nodeName == "Item":
            cn_name = cn.getAttribute("Name")
            cn_type = cn.getAttribute("Type")
            if cn.childNodes:
                d = cn.childNodes[0].data
                if cn_type == "String":
                    ncbi_dict[cn_name] = d
                elif cn_type == "Integer":
                    ncbi_dict[cn_name] = int(d)
                else:
                    raise RuntimeError(f"Unknown type {cn_type} for {ncbi_ac}")
            else:
                ncbi_dict[cn_name] = None
    return ncbi_dict
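
# The ESummary DocSum parsed above is expected to provide (among others) the
# items consumed in _get_entities below: "TaxId", "Status", "ReplacedBy",
# "AccessionVersion", "Gi" and "UpdateDate" (date given as YYYY/MM/DD).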


def _get_entities(pdb_file, mdl_title, up_ac, ncbi_ac):
    """Gather data for the mmCIF (target) entities."""

    ost_ent = io.LoadPDB(pdb_file)
    # sanity checks
    if ost_ent.chain_count != 1:
        raise RuntimeError(
            f"Unexpected oligomer for {mdl_title}"
        )
    chn = ost_ent.chains[0]
    sqe = _get_sequence(chn)
    cif_ent = {
        "pdb_sequence": sqe,
        "pdb_chain_id": chn.name,
        "description": f"{mdl_title} protein"
    }
    # add UniProtKB info
    up_info = _fetch_upkb_entry(up_ac)
    cif_ent.update(up_info)
    if up_info["up_sequence"] != sqe:
        up_range = _check_subset(up_info["up_sequence"], sqe)
        if not up_range:
            raise RuntimeError(f"Inconsistent UP/PDB sequences for {mdl_title}")
    else:
        up_range = (1, cif_ent["up_seqlen"])
    cif_ent["up_range"] = up_range
    # check NCBI sequence
    s_ncbi = _get_ncbi_sequence(ncbi_ac)
    if up_info["up_sequence"] != str(s_ncbi):
        raise RuntimeError(f"Inconsistent UP/NCBI sequences for {mdl_title}")
    # add NCBI info
    ncbi_info = _get_ncbi_info(ncbi_ac)
    if up_info["up_ncbi_taxid"] != str(ncbi_info["TaxId"]):
        raise RuntimeError(f"Inconsistent UP/NCBI taxid for {mdl_title}")
    if ncbi_info["Status"] != "live":
        raise RuntimeError(f"NCBI entry {ncbi_ac} for {mdl_title} not live")
    if ncbi_info["ReplacedBy"]:
        raise RuntimeError(f"Outdated NCBI entry {ncbi_ac} for {mdl_title}")
    if ncbi_info["AccessionVersion"] != ncbi_ac:
        raise RuntimeError(f"NCBI AC is not AC for {mdl_title}")
    cif_ent["ncbi_ac"] = ncbi_ac
    cif_ent["ncbi_gi"] = str(ncbi_info["Gi"])
    cif_ent["ncbi_last_mod"] = datetime.datetime.strptime(
        ncbi_info["UpdateDate"], "%Y/%m/%d"
    )

    return [cif_ent], ost_ent


def _get_modelcif_entities(target_ents, source, asym_units, system):
    """Create ModelCIF entities and asymmetric units."""
    for cif_ent in target_ents:
        mdlcif_ent = modelcif.Entity(
            cif_ent["pdb_sequence"],
            description=cif_ent["description"],
            source=source,
            references=[
                modelcif.reference.UniProt(
                    cif_ent["up_id"],
                    cif_ent["up_ac"],
                    align_begin=cif_ent["up_range"][0],
                    align_end=cif_ent["up_range"][1],
                    isoform=cif_ent["up_isoform"],
                    ncbi_taxonomy_id=cif_ent["up_ncbi_taxid"],
                    organism_scientific=cif_ent["up_organism"],
                    sequence_version_date=cif_ent["up_last_mod"],
                    sequence_crc64=cif_ent["up_crc64"],
                ),
                # NOTE: assume that UP and NCBI match on most things
                _NcbiTrgRef(
                    cif_ent["ncbi_gi"],
                    cif_ent["ncbi_ac"],
                    align_begin=cif_ent["up_range"][0],
                    align_end=cif_ent["up_range"][1],
                    ncbi_taxonomy_id=cif_ent["up_ncbi_taxid"],
                    organism_scientific=cif_ent["up_organism"],
                    sequence_version_date=cif_ent["ncbi_last_mod"]
                )
            ],
        )
        asym_units[cif_ent["pdb_chain_id"]] = modelcif.AsymUnit(
            mdlcif_ent
        )
        system.target_entities.append(mdlcif_ent)


def _assemble_modelcif_software(soft_dict):
    """Create a modelcif.Software instance from dictionary."""
    return modelcif.Software(
        soft_dict["name"],
        soft_dict["classification"],
        soft_dict["description"],
        soft_dict["location"],
        soft_dict["type"],
        soft_dict["version"],
        citation=soft_dict["citation"]
    )


def _get_sequence_dbs(config_data):
    """Get AF seq. DBs."""
    # hard coded UniProt release
    up_version = "2022_01"
    up_rel_date = datetime.datetime(2022, 2, 23)
    # fill list of DBs
    seq_dbs = []
    if config_data["use_small_bfd"]:
        seq_dbs.append(modelcif.ReferenceDatabase(
            "Reduced BFD",
            "https://storage.googleapis.com/alphafold-databases/"
            + "reduced_dbs/bfd-first_non_consensus_sequences.fasta.gz"
        ))
    else:
        seq_dbs.append(modelcif.ReferenceDatabase(
            "BFD",
            "https://storage.googleapis.com/alphafold-databases/"
            + "casp14_versions/"
            + "bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt.tar.gz",
            version="6a634dc6eb105c2e9b4cba7bbae93412",
        ))
    seq_dbs.append(modelcif.ReferenceDatabase(
        "MGnify",
        "https://storage.googleapis.com/alphafold-databases/"
        + "casp14_versions/mgy_clusters_2018_12.fa.gz",
        version="2018_12",
        release_date=datetime.datetime(2018, 12, 6),
    ))
    seq_dbs.append(modelcif.ReferenceDatabase(
        "Uniclust30",
        "https://storage.googleapis.com/alphafold-databases/"
        + "casp14_versions/uniclust30_2018_08_hhsuite.tar.gz",
        version="2018_08",
        release_date=None,
    ))
    seq_dbs.append(modelcif.ReferenceDatabase(
        "TrEMBL",
        "ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/"
        + "knowledgebase/complete/uniprot_trembl.fasta.gz",
        version=up_version,
        release_date=up_rel_date,
    ))
    seq_dbs.append(modelcif.ReferenceDatabase(
        "Swiss-Prot",
        "ftp://ftp.ebi.ac.uk/pub/databases/uniprot/current_release/"
        + "knowledgebase/complete/uniprot_sprot.fasta.gz",
        version=up_version,
        release_date=up_rel_date,
    ))
    seq_dbs.append(modelcif.ReferenceDatabase(
        "UniRef90",
        "ftp://ftp.uniprot.org/pub/databases/uniprot/uniref/uniref90/"
        + "uniref90.fasta.gz",
        version=up_version,
        release_date=up_rel_date,
    ))
    if config_data["use_templates"]:
        seq_dbs.append(modelcif.ReferenceDatabase(
            "PDB70",
            "http://wwwuser.gwdg.de/~compbiol/data/hhsuite/databases/"
            + "hhsuite_dbs/old-releases/pdb70_from_mmcif_200401.tar.gz",
            release_date=datetime.datetime(2020, 4, 1)
        ))
    return seq_dbs


def _get_modelcif_protocol(protocol_steps, target_entities, model, ref_dbs):
    """Create the protocol for the ModelCIF file."""
    protocol = modelcif.protocol.Protocol()
    for js_step in protocol_steps:
        sftwre = None
        if js_step["software"]:
            if len(js_step["software"]) == 1:
                sftwre = _assemble_modelcif_software(js_step["software"][0])
            else:
                sftwre = []
                for sft in js_step["software"]:
                    sftwre.append(_assemble_modelcif_software(sft))
                sftwre = modelcif.SoftwareGroup(elements=sftwre)
            if js_step["software_parameters"]:
                params = []
                for k, v in js_step["software_parameters"].items():
                    params.append(
                        modelcif.SoftwareParameter(k, v)
                    )
                if isinstance(sftwre, modelcif.SoftwareGroup):
                    sftwre.parameters = params
                else:
                    sftwre = modelcif.SoftwareGroup(
                        elements=(sftwre,), parameters=params
                    )

        if js_step["input"] == "target_sequences":
            input_data = modelcif.data.DataGroup(target_entities)
            input_data.extend(ref_dbs)
        elif js_step["input"] == "model":
            input_data = model
        else:
            raise RuntimeError(f"Unknown protocol input: '{js_step['input']}'")
        if js_step["output"] == "model":
            output_data = model
        else:
            raise RuntimeError(
                f"Unknown protocol output: '{js_step['output']}'"
            )
        protocol.steps.append(
            modelcif.protocol.Step(
                input_data=input_data,
                output_data=output_data,
                name=js_step["name"],
                details=js_step["details"],
                software=sftwre,
            )
        )
        protocol.steps[-1].method_type = js_step["method_type"]

    return protocol


def _compress_cif_file(cif_file):
    """Compress cif file and delete original."""
    with open(cif_file, 'rb') as f_in:
        with gzip.open(cif_file + '.gz', 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    os.remove(cif_file)


def _store_as_modelcif(data_json, ost_ent, plddt_entity, out_dir, mdl_name,
                       compress):
    """Mix all the data into a ModelCIF file."""
    print("    generating ModelCIF objects...", end="")
    pstart = timer()
    # create system to gather all the data
    system = modelcif.System(
        title=data_json["title"],
        id=mdl_name.replace(' ', '_').upper(),
        model_details=data_json["model_details"],
    )
    # create target entities, references, source, asymmetric units & assembly
    # for source we assume all chains come from the same taxon
    source = ihm.source.Natural(
        ncbi_taxonomy_id=data_json["target_entities"][0]["up_ncbi_taxid"],
        scientific_name=data_json["target_entities"][0]["up_organism"],
    )

    # create an asymmetric unit and an entity per target sequence
    asym_units = {}
    _get_modelcif_entities(
        data_json["target_entities"], source, asym_units, system
    )

    assembly = modelcif.Assembly(
        asym_units.values()
    )

    # audit_authors
    system.authors.extend(data_json["audit_authors"])

    # set up the model to produce coordinates
    if data_json['mdl_num']:
        mdl_list_name = f"Model {data_json['mdl_num']} (top ranked model)"
    else:
        mdl_list_name = "Top ranked model"
    model = _OST2ModelCIF(
        assembly=assembly,
        asym=asym_units,
        ost_entity=ost_ent,
        plddt_entity=plddt_entity,
        name=mdl_list_name,
    )
    print(f" ({timer()-pstart:.2f}s)")
    print("    processing QA scores...", end="", flush=True)
    pstart = timer()
    model.add_scores()
    print(f" ({timer()-pstart:.2f}s)")

    model_group = modelcif.model.ModelGroup(
        [model], name=data_json["model_group_name"]
    )
    system.model_groups.append(model_group)

    ref_dbs = _get_sequence_dbs(data_json["config_data"])
    protocol = _get_modelcif_protocol(
        data_json["protocol"], system.target_entities, model, ref_dbs
    )
    system.protocols.append(protocol)

    # write modelcif System to file (NOTE: no PAE here!)
    print("    write to disk...", end="", flush=True)
    pstart = timer()
    out_path = os.path.join(out_dir, f"{mdl_name}.cif")
    with open(out_path, "w", encoding="ascii") as mmcif_fh:
        modelcif.dumper.write(mmcif_fh, [system])
    if compress:
        _compress_cif_file(out_path)
    print(f" ({timer()-pstart:.2f}s)")


def _create_json(config_data):
    """Create a dictionary (mimicking JSON) that contains data which is the same
    for all models."""
    data = {}

    data["audit_authors"] = _get_audit_authors()
    data["protocol"] = _get_protocol_steps_and_software(config_data)
    data["config_data"] = config_data

    return data


def _create_model_json(data, pdb_file, md_row):
    """Create a dictionary (mimicking JSON) that contains all the data."""
    data["target_entities"], ost_ent = _get_entities(
        pdb_file, data["mdl_title"], md_row["UniProt_ID"],
        md_row["NCBI_Accession"]
    )
    data["title"] = _get_title(data["mdl_title"])
    data["model_details"] = _get_model_details(
        data["config_data"]["descriptions"], md_row["notes"]
    )
    data["model_group_name"] = _get_model_group_name()

    return ost_ent


def _is_special(file_prfx):
    """Check if there is an unrelaxed file."""
    # if special case, we need separate file to fetch pLDDT and add extra 
    # GROMACS step to protocol
    plddt_path = f"{file_prfx}-unrelaxed.pdb"
    if os.path.exists(plddt_path):
        return plddt_path, True
    else:
        return None, False
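
# Example: for file_prfx "<model_dir>/QP509L" this looks for
# "<model_dir>/QP509L-unrelaxed.pdb"; if it exists, pLDDT values are taken from
# that file and the extra GROMACS refinement step is added to the protocol.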


def _get_mdl_num(mdl_id):
    """Fetch model number from filename used by AF."""
    # e.g. mdl_id "model_4_pred_0" -> 4
    mdl_num = None
    if isinstance(mdl_id, str):
        mdl_id_split = mdl_id.split('_')
        if len(mdl_id_split) == 4:
            mdl_num = int(mdl_id_split[1])
    return mdl_num
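
# Illustrative examples (values other than "model_4_pred_0" are hypothetical):
#   _get_mdl_num("model_4_pred_0")  ->  4
#   _get_mdl_num("model_1_pred_0")  ->  1
#   _get_mdl_num(float("nan"))      ->  None  # empty CSV cell read by pandas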


def _main():
    """Run as script."""
    opts = _parse_args()

    # parse/fetch global data
    metadata = _get_metadata(opts.metadata_file)
    if opts.compress:
        cifext = "cif.gz"
    else:
        cifext = "cif"

    # get on with models
    print(f"Working on {opts.model_dir}...")

    # iterate model directory
    for fle in sorted(os.listdir(opts.model_dir)):
        # iterate PDB files
        if not fle.endswith(".pdb"):
            continue
        # check file and if to be done
        mdl_name = os.path.splitext(fle)[0]
        if mdl_name not in metadata.index:
            # skip unknown ones
            continue
        md_row = metadata.loc[mdl_name]
        assert md_row["Associated PDB"] == fle
        file_prfx = os.path.join(opts.model_dir, mdl_name)
        fle = os.path.join(opts.model_dir, fle)
        if os.path.exists(os.path.join(opts.out_dir, f"{mdl_name}.{cifext}")):
            print(f"  {mdl_name} already done...")
            continue

        # go for it
        print(f"  translating {mdl_name}...")
        pdb_start = timer()
        plddt_path, is_special = _is_special(file_prfx)
        config_data = _get_config(is_special)
        mdlcf_json = _create_json(config_data)
        mdlcf_json["mdl_title"] = md_row["_struct.title "]
        mdlcf_json["mdl_num"] = _get_mdl_num(md_row["ranking debugg model ID"])

        # gather data into JSON-like structure
        print("    preparing data...", end="")
        pstart = timer()
        ost_ent = _create_model_json(mdlcf_json, fle, md_row)
        if is_special:
            plddt_entity = io.LoadPDB(plddt_path)
        else:
            plddt_entity = None
        print(f" ({timer()-pstart:.2f}s)")

        _store_as_modelcif(mdlcf_json, ost_ent, plddt_entity, opts.out_dir,
                           mdl_name, opts.compress)
        print(f"  ... done with {mdl_name} ({timer()-pdb_start:.2f}s).")

        # check if result can be read and has expected seq.
        ent = io.LoadMMCIF(os.path.join(opts.out_dir, f"{mdl_name}.{cifext}"))
        assert ent.chain_count == 1, f"Bad chain count {mdl_name}"
        ent_seq = "".join(res.one_letter_code for res in ent.residues)
        up_range = mdlcf_json["target_entities"][0]["up_range"]
        exp_seq = mdlcf_json["target_entities"][0]["up_sequence"]
        exp_seq = exp_seq[up_range[0]-1:up_range[1]]
        assert ent_seq == exp_seq, f"Bad seq. {mdl_name}"

    print(f"... done with {opts.model_dir}.")


if __name__ == "__main__":
    _main()