# USAGE: python test_modelCIF_MA.py [CIF-FILE]
# OUTPUT: preview of MA entry

import gemmi, sys

################################################################################
# define converters for printing or for html (here markdown used for display)
def _make_url(link_text, link_url):
    """Return a markdown link showing `link_text` and pointing to `link_url`."""
    return "[{}]({})".format(link_text, link_url)
def _make_list(list_top, list_items):
    """Return a markdown bullet list: `list_top` followed by one "- item" line per item.

    Returns "" if `list_items` is empty (previously this implicitly returned
    None, which renders as "None" when passed on to text helpers).
    """
    list_texts = [f"- {item}" for item in list_items]
    if not list_texts:
        return ""
    return "\n".join([list_top, *list_texts])
def _make_multiline_text(lines):
    """Join the given text lines into one string, separated by newlines."""
    return str.join("\n", lines)
def _make_paragraph(text):
    """Return `text` terminated by a blank line so it forms its own paragraph."""
    return "{}\n\n".format(text)
################################################################################

################################################################################
# HELPERS
################################################################################
def _get_chain_entity(block):
    """Collect per-entity chain, description, reference-DB and sequence info.

    NOTE: adapted from ParseCIF (asym_ids as list and handle optional fields)

    Returns a dict keyed on entity ID (plain string, CIF quoting removed).
    Each value is a dict with:
    - "asym_ids": list of chain IDs from _ma_target_entity_instance
    - "ma_target_ref_db_details": list of dicts, one per matching row of
      _ma_target_ref_db_details, keyed on the column names without "?"
      (optional columns absent from the file map to None; seq_db_align_begin
      and seq_db_align_end are ints, or None for null values '?' / '.')
    - "pdbx_description": from _entity (only set if that row exists)
    - "olc": sequence from _entity_poly.pdbx_seq_one_letter_code with
      newlines removed (only set if that row exists)
    Returns an empty dict if _ma_target_entity_instance is not found. Rows of
    other categories referring to entities without a target instance are
    ignored instead of raising KeyError.
    """
    chain_entity = {}

    ma_target_entity = block.find(
        "_ma_target_entity_instance.", ["asym_id", "entity_id"]
    )
    if not ma_target_entity:
        return chain_entity

    for row in ma_target_entity:
        # row.str() everywhere so that keys and chain IDs are consistently
        # unquoted strings (the first asym_id used to be stored raw)
        entity = chain_entity.setdefault(
            row.str(1), {"asym_ids": [], "ma_target_ref_db_details": []}
        )
        entity["asym_ids"].append(row.str(0))

    for row in block.find("_entity.", ["id", "pdbx_description"]):
        entity_id = row.str(0)
        if entity_id in chain_entity:
            chain_entity[entity_id]["pdbx_description"] = row.str(1)

    cols = [
        "target_entity_id",
        "db_accession",
        "db_name",
        "?db_name_other_details",
        "?organism_scientific",
        "?seq_db_align_begin",
        "?seq_db_align_end",
    ]
    # NOTE: row[col] doesn't work with '?' in find! Bad crashes happen if you
    #       try... -> only column indices are used below
    for row in block.find("_ma_target_ref_db_details.", cols):
        entity_id = row.str(0)  # make sure this stays at idx 0 in cols!
        if entity_id not in chain_entity:
            continue
        json_obj = {}
        for idx, cq in enumerate(cols[1:], start=1):
            col = cq.lstrip("?")
            if not row.has(idx):
                # optional column not present in the file
                json_obj[col] = None
            elif "seq_db_align" in col:
                # null values ('?' / '.') are not valid integers -> None
                if gemmi.cif.is_null(row[idx]):
                    json_obj[col] = None
                else:
                    json_obj[col] = gemmi.cif.as_int(row[idx])
            else:
                json_obj[col] = row.str(idx)
        chain_entity[entity_id]["ma_target_ref_db_details"].append(json_obj)

    # only entity_id and the sequence are needed here (pdbx_strand_id is not
    # required, so a missing strand column no longer hides the sequence)
    for row in block.find(
        "_entity_poly.", ["entity_id", "pdbx_seq_one_letter_code"]
    ):
        entity_id = row.str(0)
        if entity_id in chain_entity:
            chain_entity[entity_id]["olc"] = row.str(1).replace("\n", "")

    return chain_entity

def _fetch_qa_data(block):
    """Get 3 lists of QA scores and extract global scores.

    Returns a tuple (qa_global, qa_local, qa_local_pairwise). Each entry is a
    dict with "name", "mode" and "type" taken from _ma_qa_metric. Entries of
    qa_global additionally hold the "value" from _ma_qa_metric_global
    (converted with gemmi.cif.as_number). Global scores whose metric_id is
    not listed in _ma_qa_metric are ignored. qa_local and qa_local_pairwise
    hold the metrics with mode "local" and "local-pairwise" respectively.

    Raises ValueError if a global score refers to a metric whose mode is not
    "global" (explicit check instead of an assert, which -O would strip).
    """
    # fetch main info: keyed on metric "id", rest as dict for easy access
    qa_dict = {}
    for row in block.find("_ma_qa_metric.", ["id", "name", "mode", "type"]):
        qa_dict[row.str(0)] = {
            "name": row.str(1),
            "mode": row.str(2),
            "type": row.str(3),
        }
    # fetch global scores
    qa_global = []
    table = block.find("_ma_qa_metric_global.", ["metric_id", "metric_value"])
    for row in table:
        metric_id = row.str(0)
        if metric_id not in qa_dict:
            continue
        metric = qa_dict[metric_id]
        if metric["mode"] != "global":
            raise ValueError(
                f"QA metric {metric_id} has a global score but mode "
                f"'{metric['mode']}'"
            )
        metric["value"] = gemmi.cif.as_number(row.get(1))
        qa_global.append(metric)
    # fetch local scores
    qa_local = [d for d in qa_dict.values() if d["mode"] == "local"]
    qa_local_pairwise = [
        d for d in qa_dict.values() if d["mode"] == "local-pairwise"
    ]
    return qa_global, qa_local, qa_local_pairwise
################################################################################

################################################################################
# TEXT GETTERS
# -> each function turns gemmi-block into a text
# -> bits that could be reused elsewhere are given by helpers above
################################################################################
def _get_title_text(block):
    """Return the entry title from _struct.title, or "NO TITLE" if absent/empty."""
    raw_title = block.find_value("_struct.title")
    title = gemmi.cif.as_string(raw_title) if raw_title else raw_title
    return title or "NO TITLE"

def _get_overview_text(block):
    """Return the overview: abstract paragraph plus the dataset placeholder.

    The abstract comes from _struct.pdbx_model_details ("NO ABSTRACT" if it
    is absent or empty).
    """
    raw_abstract = block.find_value("_struct.pdbx_model_details")
    abstract = gemmi.cif.as_string(raw_abstract) if raw_abstract else None
    paragraphs = [
        _make_paragraph(abstract if abstract else "NO ABSTRACT"),
        # TODO: "TOFILL" below to be filled by MA using info from model set
        _make_paragraph('This model is part of the dataset "TOFILL"'),
    ]
    return "".join(paragraphs)

def _get_entity_text(block):
    """Describe the molecular entities of the model as a markdown list paragraph.

    One list item per entity: its description and chain IDs on the first
    line, then one line per reference DB entry. UniProt and OrthoDB
    accessions are linked. The aligned range and the organism are appended
    when available. Returns "" if there are no target entities.
    """
    # NOTE: reuse (updated!) code from ParseCIF
    chain_entity = _get_chain_entity(block)
    # display label and URL template for DBs we can link to
    # (keyed on _ma_target_ref_db_details.db_name)
    linkable_dbs = {
        "UNP": ("UniProt", "https://www.uniprot.org/uniprot/{}"),
        "OrthoDB": ("OrthoDB", "https://www.orthodb.org/?query={}"),
    }
    item_strings = []
    for entity in chain_entity.values():
        # _entity.pdbx_description may be missing -> avoid KeyError
        description = entity.get("pdbx_description") or "NO DESCRIPTION"
        db_links = [f"{description} (chains: {', '.join(entity['asym_ids'])})"]
        for ref in entity["ma_target_ref_db_details"]:
            db_name = ref["db_name"]
            db_acc = ref["db_accession"]
            if db_name in linkable_dbs:
                label, url_template = linkable_dbs[db_name]
                db_link = f"{label}: {_make_url(db_acc, url_template.format(db_acc))}"
            elif db_name == "Other" and ref.get("db_name_other_details"):
                db_link = f"{ref['db_name_other_details']}: {db_acc}"
            else:
                db_link = f"{db_name}: {db_acc}"
            if ref.get("seq_db_align_begin") and ref.get("seq_db_align_end"):
                db_link += f" {ref['seq_db_align_begin']}-{ref['seq_db_align_end']}"
            if ref.get("organism_scientific"):
                db_link += f"; {ref['organism_scientific']}"
            db_links.append(db_link)
        item_strings.append(_make_multiline_text(db_links))
    # and into a paragraph...
    if not item_strings:
        return ""
    list_top = "The following molecular entities are in the model:"
    return _make_paragraph(_make_list(list_top, item_strings))
    
def _get_author_family_name(name):
    """Return the family name of a CIF author name written as "Family, G.".

    Everything before the first comma is used, so multi-word family names
    ("van der Berg, J.") stay intact. Without a comma, the first
    whitespace-separated word is used. Returns "" for an empty name.
    """
    if "," in name:
        return name.split(",", 1)[0].strip()
    words = name.split()
    return words[0] if words else ""

def _get_sw_text(block):
    """Describe the software used (_software) as a markdown list paragraph.

    Each item is the software name (linked to _software.location if given),
    its version in parentheses (if given) and, if its citation_id has authors
    in _citation_author, a short citation label "Family[ et al.][ year]".
    The label is linked via DOI (preferred) or PubMed when available.
    Returns "" if no software is listed.
    """
    # get short author label for each citation: first author (+ " et al.")
    tmp = {}
    for row in block.find("_citation_author.", ["citation_id", "name"]):
        cid = row.str(0)
        if cid not in tmp:
            tmp[cid] = {"name": _get_author_family_name(row.str(1)), "etal": ""}
        else:
            tmp[cid]["etal"] = " et al."
    cit_names = {cid: d["name"] + d["etal"] for cid, d in tmp.items()}
    # add year if available
    table = block.find("_citation.", ["id", "?year"])
    if table.has_column(1):
        for row in table:
            cid = row.str(0)
            year = row.str(1)
            if cid in cit_names and year:
                cit_names[cid] += " " + year
    # add URL if available (formatters are in the same order as the columns)
    cit_urls = {}
    table = block.find(
        "_citation.", ["id", "?pdbx_database_id_DOI", "?pdbx_database_id_PubMed"]
    )
    formatters = ["https://doi.org/%s",
                  "https://www.ncbi.nlm.nih.gov/pubmed/%s"]
    for row in table:
        cid = row.str(0)
        # add whichever URL we find first
        for i, formatter in enumerate(formatters, start=1):
            if row.has(i) and row.str(i):
                cit_urls[cid] = formatter % row.str(i)
                break
    # now map this to software
    item_strings = []
    table = block.find("_software.", ["name", "?location", "?version", "?citation_id"])
    for row in table:
        sw_name = row.str(0)
        if row.has(1) and row.str(1):
            item = _make_url(sw_name, row.str(1))
        else:
            item = sw_name
        if row.has(2) and row.str(2):
            item += f" ({row.str(2)})"
        if row.has(3) and row.str(3) in cit_names:
            cid = row.str(3)
            if cid in cit_urls:
                item += f" ({_make_url(cit_names[cid], cit_urls[cid])})"
            else:
                item += f" ({cit_names[cid]})"
        item_strings.append(item)
    # and into a paragraph...
    if not item_strings:
        return ""
    list_top = "The following software was used:"
    return _make_paragraph(_make_list(list_top, item_strings))

def _get_ref_db_text(block):
    """List the reference databases (_ma_data_ref_db) as a markdown paragraph.

    Each item is the DB name, followed in parentheses by its version or, if
    no version is given, its release date. Returns "" if none are listed.
    """
    item_strings = []
    for row in block.find("_ma_data_ref_db.", ["name", "?version", "?release_date"]):
        # first non-empty of version / release date (if any)
        version_info = next(
            (row.str(col) for col in (1, 2) if row.has(col) and row.str(col)),
            None,
        )
        if version_info is None:
            item_strings.append(row.str(0))
        else:
            item_strings.append(f"{row.str(0)} ({version_info})")
    if not item_strings:
        return ""
    return _make_paragraph(
        _make_list("The following reference databases were used:", item_strings)
    )

def _get_tpl_db_text(db_name, db_acc):
    """Return "DB: accession" text for a template DB entry (linked for PDB/PubChem).

    Prints a note for DBs without URL support.
    """
    if db_name == "PDB":
        link_url = f"http://dx.doi.org/10.2210/pdb{db_acc}/pdb"
    elif db_name == "PubChem":
        link_url = f"https://pubchem.ncbi.nlm.nih.gov/compound/{db_acc}"
    else:
        link_url = None
        print(f"URLs for {db_name} NOT SUPPORTED YET")
    if link_url:
        return f"{db_name}: {_make_url(db_acc, link_url)}"
    return f"{db_name}: {db_acc}"

def _get_tpl_text(block):
    """Describe the templates used (_ma_template_*) as a markdown list paragraph.

    Templates are grouped per reference DB entry. Each item shows the DB
    entry, the template chain IDs and the comp. IDs / details of non-polymer
    templates (if any). Chains are shown as label_asym_id if that is known
    for all templates of the entry, otherwise as auth_asym_id. Templates
    without a _ma_template_ref_db_details row are listed on their own as
    "Template <id>" (previously a KeyError). Returns "" if no templates are
    listed.
    """
    # collect info per tpl-id
    tpl_dict = {}  # keyed on template_id
    # fetch main info
    cols = ["template_id", "target_asym_id", "template_auth_asym_id",
            "?template_label_asym_id"]
    for row in block.find("_ma_template_details.", cols):
        tid = row.str(0)
        tpl_dict[tid] = {
            "trg_asym_id": row.str(1),
            "tpl_auth_asym_id": row.str(2)
        }
        if row.has(3) and row.str(3):
            tpl_dict[tid]["tpl_label_asym_id"] = row.str(3)
    # add ref DBs
    cols = ["template_id", "db_accession_code", "db_name",
            "?db_name_other_details"]
    for row in block.find("_ma_template_ref_db_details.", cols):
        tid = row.str(0)
        if tid in tpl_dict:
            tpl_dict[tid]["db_acc"] = row.str(1)
            if row.str(2) == "Other" and row.has(3) and row.str(3):
                tpl_dict[tid]["db_name"] = row.str(3)
            else:
                tpl_dict[tid]["db_name"] = row.str(2)
    # add info for small molecules
    cols = ["template_id", "?comp_id", "?details"]
    for row in block.find("_ma_template_non_poly.", cols):
        tid = row.str(0)
        if tid in tpl_dict:
            if row.has(1) and row.str(1):
                tpl_dict[tid]["non_poly_comp_id"] = row.str(1)
            if row.has(2) and row.str(2):
                tpl_dict[tid]["non_poly_details"] = row.str(2)
    # aggregate per template DB entry for displaying
    # (tuple keys avoid collisions of joined "name-acc" strings)
    tpl_td = {}
    for tid, tpl in tpl_dict.items():
        has_db = "db_name" in tpl
        did = ("db", tpl["db_name"], tpl["db_acc"]) if has_db else ("tpl", tid)
        if did not in tpl_td:
            if has_db:
                tpl_text = _get_tpl_db_text(tpl["db_name"], tpl["db_acc"])
            else:
                tpl_text = f"Template {tid}"
            tpl_td[did] = {
                "tpl_text": tpl_text,
                "tpl_chains_label": [],
                "tpl_chains_auth": [],
                "tpl_chains_all_label": True,
                "tpl_non_poly_ids": []
            }
        # collect chain names
        if "tpl_label_asym_id" in tpl:
            # if here it is guaranteed to be non-empty
            tpl_td[did]["tpl_chains_label"].append(tpl["tpl_label_asym_id"])
        else:
            # if any missing, we set all to False and fall back to auth
            tpl_td[did]["tpl_chains_all_label"] = False
        if tpl["tpl_auth_asym_id"]:
            # only add non empty ones
            tpl_td[did]["tpl_chains_auth"].append(tpl["tpl_auth_asym_id"])
        # collect info on non poly if available (prefer short comp. ID)
        if "non_poly_comp_id" in tpl:
            tpl_td[did]["tpl_non_poly_ids"].append(tpl["non_poly_comp_id"])
        elif "non_poly_details" in tpl:
            tpl_td[did]["tpl_non_poly_ids"].append(tpl["non_poly_details"])
    # turn into text
    item_strings = []
    for tpl in tpl_td.values():
        item = tpl["tpl_text"]
        if tpl["tpl_chains_all_label"] and tpl["tpl_chains_label"]:
            chain_ids = sorted(set(tpl['tpl_chains_label']))
            item += f"; chains (label_asym_id): {', '.join(chain_ids)}"
        elif tpl["tpl_chains_auth"]:
            chain_ids = sorted(set(tpl['tpl_chains_auth']))
            item += f"; chains (auth_asym_id): {', '.join(chain_ids)}"
        if tpl["tpl_non_poly_ids"]:
            np_ids = sorted(set(tpl['tpl_non_poly_ids']))
            item += f"; non-polymers: {', '.join(np_ids)}"
        item_strings.append(item)
    # and into a paragraph...
    if not item_strings:
        return ""
    list_top = "The following templates were used:"
    return _make_paragraph(_make_list(list_top, item_strings))

def _get_protocol_steps_text(block):
    """Return one line per _ma_protocol_step ("Step <id> - <method>[ : details]").

    The lines form a single paragraph. Returns "" if no steps are listed.
    """
    table = block.find("_ma_protocol_step.", ["step_id", "method_type", "?details"])
    lines = []
    for row in table:
        has_details = row.has(2) and row.str(2)
        suffix = f" : {row.str(2)}" if has_details else ""
        lines.append(f"Step {row.str(0)} - {row.str(1)}{suffix}")
    return _make_paragraph(_make_multiline_text(lines)) if lines else ""

def _get_qa_acc_text(block):
    """Describe QA scores and accompanying data as markdown paragraphs.

    Where applicable, the text contains:
    - a paragraph on the global QA scores;
    - a paragraph on where the local (per-residue) and local-pairwise
      (per-residue-pair) QA scores can be found;
    - the list of files in the accompanying data.
    It always ends with a pointer to the ModelCIF format.
    """
    # get QA part (can reuse if already used elsewhere)
    qa_global, qa_local, qa_local_pairwise = _fetch_qa_data(block)
    # parse accompanying data
    file_contents = block.find_values("_ma_entry_associated_files.file_content")
    has_single_zip_file = len(file_contents) == 1 and \
                          file_contents.str(0) == "archive with multiple files"
    if has_single_zip_file:
        # the files inside the single archive are described in
        # _ma_associated_archive_file_details -> use those contents instead
        file_contents = block.find_values(
            "_ma_associated_archive_file_details.file_content"
        )
    has_loc_pw_in_acc = any(
        gemmi.cif.as_string(v) == "local pairwise QA scores"
        for v in file_contents
    )
    # put together text
    text = ""
    # text for global QA
    if len(qa_global) > 1:
        score_strings = [f"{v['name']} of {v['value']}" for v in qa_global]
        text += _make_paragraph(
            "The model has the following global model confidence scores:"
            f" {', '.join(score_strings)}."
        )
    elif len(qa_global) == 1:
        text += _make_paragraph(
            "The model has a global model confidence score "
            f"({qa_global[0]['name']}) of {qa_global[0]['value']}."
        )
    # lots of options for local QA string
    qa_local_names = ", ".join(v["name"] for v in qa_local)
    qa_loc_pw_names = ", ".join(v["name"] for v in qa_local_pairwise)
    item = ""
    if qa_local_names and qa_loc_pw_names:
        if has_loc_pw_in_acc:
            item = f"Local per-residue model confidence scores ({qa_local_names}) " \
                   f"are available in the model mmCIF file " \
                   f"and local per-residue-pair scores ({qa_loc_pw_names}) " \
                   f"in the accompanying data download."
        else:
            item = f"Local per-residue model confidence scores ({qa_local_names}) " \
                   f"and local per-residue-pair scores ({qa_loc_pw_names}) " \
                   f"are available in the model mmCIF file."
    elif qa_local_names:
        item = f"Local per-residue model confidence scores ({qa_local_names}) " \
               f"are available in the model mmCIF file."
    elif qa_loc_pw_names:
        if has_loc_pw_in_acc:
            item = f"Local per-residue-pair model confidence scores ({qa_loc_pw_names}) " \
                   f"are available in the accompanying data download."
        else:
            item = f"Local per-residue-pair model confidence scores ({qa_loc_pw_names}) " \
                   f"are available in the model mmCIF file."
    if item:
        text += _make_paragraph(item)
    # list files in accompanying data (if any)
    if has_single_zip_file:
        table = block.find("_ma_associated_archive_file_details.",
                           ["file_path", "?file_content", "?description"])
    else:
        # NOTE: aimed to work legacy-style for Baker-models but should be obsoleted
        # -> can replace below with "table = None" in future
        table = block.find("_ma_entry_associated_files.",
                           ["file_url", "?file_content", "?details"])
    list_items = []
    if table:
        for row in table:
            item = f"{row.str(0)}"
            if row.has(1) and row.str(1):
                item += f" ({row.str(1)})"
            if row.has(2) and row.str(2):
                item += f": {row.str(2)}"
            list_items.append(item)
    # only add the list if there is something in it
    if list_items:
        list_top = "Files in accompanying data:"
        text += _make_paragraph(_make_list(list_top, list_items))
    # conclude with standard pointer to ModelCIF file
    model_cif_link = _make_url(
        "ModelCIF format",
        "https://mmcif.wwpdb.org/dictionaries/mmcif_ma.dic/Index/"
    )
    # NOTE: the space after the link was missing before ("...)in the model")
    text += _make_paragraph(
        f"Full details are available in {model_cif_link} "
        "in the model mmCIF file."
    )
    return text
################################################################################

# full parsing of file
def _process_file(file_path):
    """Read a single-block CIF file and print the preview of its MA entry.

    The preview consists of the title, overview, material and procedure
    sections, separated by dashed lines.
    """
    print(f"ModelCIF->MA for {file_path}")
    block = gemmi.cif.read(file_path).sole_block()
    separator = '-' * 80
    print(separator)
    print(f"TITLE: {_get_title_text(block)}")
    print(separator)
    print(f"OVERVIEW:\n{separator}")
    print(_get_overview_text(block))
    print(separator)
    print(f"MATERIAL:\n{separator}")
    material_getters = (
        _get_entity_text, _get_sw_text, _get_ref_db_text, _get_tpl_text
    )
    print("".join([getter(block) for getter in material_getters]))
    print(separator)
    print(f"PROCEDURE:\n{separator}")
    procedure_getters = (_get_protocol_steps_text, _get_qa_acc_text)
    print("".join([getter(block) for getter in procedure_getters]))
    print(separator)

def _main():
    """Command-line entry point: preview the CIF file given as sole argument.

    Unless exactly one argument is given, prints the usage to stderr and
    exits with status 1. Previously it printed to stdout and exited 0, which
    reported success for a usage error.
    """
    if len(sys.argv) != 2:
        print("USAGE: python test_modelCIF_MA.py [CIF-FILE]", file=sys.stderr)
        sys.exit(1)
    _process_file(sys.argv[1])

if __name__ == "__main__":
    _main()