diff --git a/translate2modelcif.py b/translate2modelcif.py
new file mode 100644
index 0000000000000000000000000000000000000000..4deb91eb24aefa5390f73bb7a2fc192e308eb5d8
--- /dev/null
+++ b/translate2modelcif.py
@@ -0,0 +1,991 @@
+#! /usr/local/bin/ost
+"""Translate models from Tara/Xabi (PDB + extra data) into ModelCIF."""
+
+import argparse
+import datetime
+import gzip
+import os
+import shutil
+import sys
+import zipfile
+
+from timeit import default_timer as timer
+import numpy as np
+import requests
+import ujson as json
+
+import ihm
+import ihm.citations
+import modelcif
+import modelcif.associated
+import modelcif.dumper
+import modelcif.model
+import modelcif.protocol
+import modelcif.reference
+
+from ost import io
+
+
+# EXAMPLES for running:
+# ost scripts/translate2modelcif.py --rank 1 --out_dir="./modelcif" \
+#     "A0A1B0GTU1-O75152"
+
+
+def _parse_args():
+    """Parse command line arguments."""
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        description=__doc__,
+    )
+
+    parser.add_argument(
+        "model_dir",
+        type=str,
+        metavar="<MODEL DIR>",
+        help="Directory with model(s) to be translated. Must be of form "
+        + "'<UniProtKB AC>-<UniProtKB AC>'",
+    )
+    parser.add_argument(
+        "--rank",
+        type=str,
+        default=None,
+        help="Only process the model with this rank.",
+    )
+    parser.add_argument(
+        "--out_dir",
+        type=str,
+        metavar="<OUTPUT DIR>",
+        default="",
+        help="Path to a separate directory to store results "
+        "(<MODEL DIR> used, if none given).",
+    )
+    parser.add_argument(
+        "--compress",
+        default=False,
+        action="store_true",
+        help="Compress ModelCIF file with gzip "
+        "(note that QA file is zipped either way).",
+    )
+    opts = parser.parse_args()
+
+    # check that model dir exists
+    if opts.model_dir.endswith("/"):
+        opts.model_dir = opts.model_dir[:-1]
+    if not os.path.exists(opts.model_dir):
+        _abort_msg(f"Model directory '{opts.model_dir}' does not exist.")
+    if not os.path.isdir(opts.model_dir):
+        _abort_msg(f"Path '{opts.model_dir}' does not point to a directory.")
+    # check out_dir
+    if not opts.out_dir:
+        opts.out_dir = opts.model_dir
+    else:
+        if not os.path.exists(opts.out_dir):
+            _abort_msg(f"Output directory '{opts.out_dir}' does not exist.")
+        if not os.path.isdir(opts.out_dir):
+            _abort_msg(f"Path '{opts.out_dir}' does not point to a directory.")
+
+    return opts
+
+
+# pylint: disable=too-few-public-methods
+class _GlobalPTM(modelcif.qa_metric.Global, modelcif.qa_metric.PTM):
+    """Predicted accuracy according to the predicted TM-score in [0,1]"""
+
+    name = "pTM"
+    software = None
+
+
+class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT):
+    """Predicted accuracy according to the CA-only lDDT in [0,100]"""
+
+    name = "pLDDT"
+    software = None
+
+
+class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT):
+    """Predicted accuracy according to the CA-only lDDT in [0,100]"""
+
+    name = "pLDDT"
+    software = None
+
+
+class _PAE(modelcif.qa_metric.MetricType):
+    """Predicted aligned error (in Angstroms)"""
+
+    type = "PAE"
+    other_details = None
+
+
+class _LocalPairwisePAE(modelcif.qa_metric.LocalPairwise, _PAE):
+    """Predicted aligned error (in Angstroms)"""
+
+    name = "PAE"
+    software = None
+
+
+# pylint: enable=too-few-public-methods
+
+
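+# The QA metric classes above are filled from ColabFold's per-model scores
+# JSON (see add_scores() below). As a rough, illustrative sketch, the only
+# keys consumed are the following (values hypothetical; "plddt" holds one
+# value per residue, "pae" an NxN matrix over all residues):
+#
+# {
+#     "plddt": [87.3, 91.2, ...],
+#     "ptm": 0.82,
+#     "pae": [[0.8, 4.2, ...], [3.9, 0.7, ...], ...]
+# }
+
+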
+class _OST2ModelCIF(modelcif.model.AbInitioModel):
+    """Map OST entity elements to modelcif.model"""
+
+    def __init__(self, *args, **kwargs):
+        """Initialise a model"""
+        # pop our custom arguments before handing the rest to AbInitioModel
+        self.ost_entity = kwargs.pop("ost_entity")
+        self.asym = kwargs.pop("asym")
+
+        super().__init__(*args, **kwargs)
+
+    def get_atoms(self):
+        """Yield the atoms of the OST entity as modelcif.model.Atom."""
+        for atm in self.ost_entity.atoms:
+            yield modelcif.model.Atom(
+                asym_unit=self.asym[atm.chain.name],
+                seq_id=atm.residue.number.num,
+                atom_id=atm.name,
+                type_symbol=atm.element,
+                x=atm.pos[0],
+                y=atm.pos[1],
+                z=atm.pos[2],
+                het=atm.is_hetatom,
+                biso=atm.b_factor,
+                occupancy=atm.occupancy,
+            )
+
+    def add_scores(self, scores_json, entry_id, mdl_name):
+        """Add QA metrics from AF2 scores."""
+        # global scores
+        self.qa_metrics.extend(
+            (
+                _GlobalPLDDT(np.mean(scores_json["plddt"])),
+                _GlobalPTM(scores_json["ptm"]),
+            )
+        )
+
+        # local scores
+        lpae = []
+        i = 0
+        for chn_i in self.ost_entity.chains:
+            for res_i in chn_i.residues:
+                # local pLDDT
+                self.qa_metrics.append(
+                    _LocalPLDDT(
+                        self.asym[chn_i.name].residue(res_i.number.num),
+                        scores_json["plddt"][i],
+                    )
+                )
+
+                # pairwise aligned error
+                j = 0
+                # We do a 2nd iteration over the OST entity instead of doing
+                # index magic on the PAE matrix because it keeps the code
+                # cleaner and is only around 0.3s slower than iterating the
+                # array directly. The majority of time goes into writing
+                # files, anyway.
+                for chn_j in self.ost_entity.chains:
+                    for res_j in chn_j.residues:
+                        lpae.append(
+                            _LocalPairwisePAE(
+                                self.asym[chn_i.name].residue(res_i.number.num),
+                                self.asym[chn_j.name].residue(res_j.number.num),
+                                scores_json["pae"][i][j],
+                            )
+                        )
+                        j += 1
+
+                i += 1
+
+        self.qa_metrics.extend(lpae)
+
+        ac_file = f"{mdl_name}_local_pairwise_qa.cif"
+        qa_file = modelcif.associated.LocalPairwiseQAScoresFile(
+            ac_file,
+            categories=["_ma_qa_metric_local_pairwise"],
+            copy_categories=["_ma_qa_metric"],
+            entry_id=entry_id,
+            entry_details="This file is an associated file consisting "
+            + "of local pairwise QA metrics. This is a partial mmCIF "
+            + "file and can be validated by merging with the main "
+            + "mmCIF file containing the model coordinates and other "
+            + "associated data.",
+            details="Predicted aligned error",
+        )
+        return modelcif.associated.Repository(
+            "",
+            [modelcif.associated.ZipFile(f"{mdl_name}.zip", files=[qa_file])],
+        )
+        # NOTE: by convention, ModelArchive expects the zip file to carry the
+        # same name as the ModelCIF file
+
+
+def _abort_msg(msg, exit_code=1):
+    """Write error message and exit with exit_code."""
+    print(f"{msg}\nAborting.", file=sys.stderr)
+    sys.exit(exit_code)
+
+
+def _check_file(file_path):
+    """Make sure a file exists and is actually a file."""
+    if not os.path.exists(file_path):
+        _abort_msg(f"File not found: '{file_path}'.")
+    if not os.path.isfile(file_path):
+        _abort_msg(f"File path does not point to file: '{file_path}'.")
+
+
+def _check_interaction_extra_files_present(model_dir):
+    """Make sure some general files are present for an interaction."""
+    cnfg = os.path.join(model_dir, "config.json")
+    _check_file(cnfg)
+    return cnfg
+
+
+def _check_model_extra_files_present(model_dir, pdb_file):
+    """Check that all files needed to process this model are present."""
+    uid = os.path.splitext(pdb_file)[0]
+    prfx = os.path.join(model_dir, uid)
+    scrs = f"{prfx}_scores.json"
+    _check_file(scrs)
+    return prfx, uid
+
+
+def _get_audit_authors():
+    """Return the list of authors that produced this model."""
+    return (
+        "Bartolec, T.",
+        "Vazquez-Campos, X.",
+        "Johnson, M.",
+        "Norman, A.",
+        "Payne, R.",
+        "Wilkins, M.",
+        "Mackay, J.",
+        "Low, J.",
+    )
+
+
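+# _parse_colabfold_config() below relies on a handful of keys in config.json.
+# A minimal sketch of such a file (values hypothetical, restricted to the
+# keys actually read below):
+#
+# {
+#     "version": "1.2.0",
+#     "msa_mode": "MMseqs2 (UniRef+Environmental)",
+#     "model_type": "AlphaFold2-multimer-v2",
+#     "num_models": 5,
+#     "num_recycles": 3,
+#     "stop_at_score": 100,
+#     "use_amber": false,
+#     "use_templates": false,
+#     "rank_by": "multimer",
+#     "pair_mode": "unpaired+paired"
+# }
+
+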
+def _parse_colabfold_config(cnfg_file):
+    """Read config.json and fetch relevant data from it."""
+    # NOTE: following code from
+    # https://github.com/sokrypton/ColabFold/blob/main/colabfold/batch.py to
+    # understand config
+
+    # fetch and drop fields which are not relevant for model building
+    with open(cnfg_file, encoding="utf8") as jfh:
+        cf_config = json.load(jfh)
+    if "num_queries" in cf_config:
+        del cf_config["num_queries"]
+    # fetch relevant data
+    # -> MSA mode
+    if cf_config["msa_mode"] == "MMseqs2 (UniRef+Environmental)":
+        seq_dbs = ["UniRef", "Environmental"]
+        use_mmseqs = True
+        use_msa = True
+    elif cf_config["msa_mode"] == "MMseqs2 (UniRef only)":
+        seq_dbs = ["UniRef"]
+        use_mmseqs = True
+        use_msa = True
+    elif cf_config["msa_mode"] == "single_sequence":
+        seq_dbs = []
+        use_mmseqs = False
+        use_msa = False
+    elif cf_config["msa_mode"] == "custom":
+        print(
+            "WARNING: Custom MSA mode used. Not clear from config what to do "
+            + "here!"
+        )
+        seq_dbs = []
+        use_mmseqs = False
+        use_msa = True
+    else:
+        raise ValueError(f"Unknown msa_mode {cf_config['msa_mode']}")
+    # -> model type
+    if cf_config["model_type"] == "AlphaFold2-multimer-v1":
+        # AF-Multimer as introduced in AlphaFold v2.1.0
+        use_multimer = True
+        multimer_version = 1
+    elif cf_config["model_type"] == "AlphaFold2-multimer-v2":
+        # AF-Multimer as introduced in AlphaFold v2.2.0
+        use_multimer = True
+        multimer_version = 2
+    elif cf_config["model_type"] == "AlphaFold2-ptm":
+        use_multimer = False
+        multimer_version = None
+    else:
+        raise ValueError(f"Unknown model_type {cf_config['model_type']}")
+
+    # write modeling description
+    mdl_description = f"Model generated using ColabFold v{cf_config['version']}"
+    if use_multimer:
+        mdl_description += f" with AlphaFold-Multimer (v{multimer_version})"
+    else:
+        mdl_description += " with AlphaFold"
+    if cf_config["stop_at_score"] < 100:
+        # early stopping feature of ColabFold
+        upto = "up to "
+    else:
+        upto = ""
+    mdl_description += (
+        f" producing {upto}{cf_config['num_models']} models"
+        f" with {upto}{cf_config['num_recycles']} recycles each"
+    )
+    if cf_config["use_amber"]:
+        mdl_description += ", with AMBER relaxation"
+    else:
+        mdl_description += ", without model relaxation"
+    if cf_config["use_templates"]:
+        print(
+            "WARNING: ColabFold may use PDB70 or custom templates. "
+            "Not clear from config!"
+        )
+        mdl_description += ", using templates"
+    else:
+        mdl_description += ", without templates"
+    if cf_config["rank_by"] == "plddt":
+        mdl_description += ", ranked by pLDDT"
+    elif cf_config["rank_by"] == "ptmscore":
+        mdl_description += ", ranked by pTM"
+    elif cf_config["rank_by"] == "multimer":
+        mdl_description += ", ranked by ipTM*0.8+pTM*0.2"
+    else:
+        raise ValueError(f"Unknown rank_by {cf_config['rank_by']}")
+    if use_msa:
+        mdl_description += ", starting from"
+        if use_mmseqs:
+            msa_type = "MSA"
+            msa_article = "an"
+        else:
+            msa_type = "custom MSA"
+            msa_article = "a"
+        if use_multimer:
+            if cf_config["pair_mode"] == "unpaired+paired":
+                mdl_description += f" paired and unpaired {msa_type}s"
+            elif cf_config["pair_mode"] == "paired":
+                mdl_description += f" paired {msa_type}s"
+            elif cf_config["pair_mode"] == "unpaired":
+                mdl_description += f" unpaired {msa_type}s"
+            else:
+                raise ValueError(f"Unknown pair_mode {cf_config['pair_mode']}")
+        else:
+            mdl_description += f" {msa_article} {msa_type}"
+        if use_mmseqs:
+            mdl_description += f" from MMseqs2 ({'+'.join(seq_dbs)})"
+    else:
+        mdl_description += " without an MSA"
+    mdl_description += "."
+
+    # write selection description
+    slct_description = (
+        "Select best model, which is either the top-ranked model as "
+        + "determined by the ColabFold pipeline (ipTM*0.8+pTM*0.2), or else "
+        + "the model with best congruence with crosslinks reported in the "
+        + "related study."
+    )
+    return {
+        "config": cf_config,
+        "seq_dbs": seq_dbs,
+        "use_mmseqs": use_mmseqs,
+        "use_msa": use_msa,
+        "use_multimer": use_multimer,
+        "multimer_version": multimer_version,
+        "modeling_description": mdl_description,
+        "selection_description": slct_description,
+    }
+
+
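+# For a config like the sketch above, the assembled description would read:
+# "Model generated using ColabFold v1.2.0 with AlphaFold-Multimer (v2)
+# producing 5 models with 3 recycles each, without model relaxation, without
+# templates, ranked by ipTM*0.8+pTM*0.2, starting from paired and unpaired
+# MSAs from MMseqs2 (UniRef+Environmental)."
+
+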
+def _get_protocol_steps_and_software(config_data):
+    """Create the list of protocol steps with software and parameters used."""
+    protocol = []
+    # modelling step
+    step = {
+        "method_type": "modeling",
+        "name": None,
+        "details": config_data["modeling_description"],
+    }
+    # get input data
+    # Must refer to data already in the JSON, so we use keywords which are
+    # resolved later in _get_modelcif_protocol_input
+    step["input"] = "target_sequences"
+    # get output data
+    # Must refer to existing data, so we use keywords which are resolved
+    # later in _get_modelcif_protocol_output
+    step["output"] = "model"
+    # get software
+    step["software"] = [
+        {
+            "name": "ColabFold",
+            "classification": "model building",
+            "description": "Structure prediction",
+            "citation": ihm.citations.colabfold,
+            "location": "https://github.com/sokrypton/ColabFold",
+            "type": "package",
+            "version": "1.2.0",
+        }
+    ]
+    if config_data["use_mmseqs"]:
+        step["software"].append(
+            {
+                "name": "MMseqs2",
+                "classification": "data collection",
+                "description": "Many-against-Many sequence searching",
+                "citation": ihm.citations.mmseqs2,
+                "location": "https://github.com/soedinglab/mmseqs2",
+                "type": "package",
+                "version": None,
+            }
+        )
+    if config_data["use_multimer"]:
+        step["software"].append(
+            {
+                "name": "AlphaFold-Multimer",
+                "classification": "model building",
+                "description": "Structure prediction",
+                "citation": ihm.Citation(
+                    pmid=None,
+                    title="Protein complex prediction with "
+                    + "AlphaFold-Multimer.",
+                    journal="bioRxiv",
+                    volume=None,
+                    page_range=None,
+                    year=2021,
+                    authors=[
+                        "Evans, R.",
+                        "O'Neill, M.",
+                        "Pritzel, A.",
+                        "Antropova, N.",
+                        "Senior, A.",
+                        "Green, T.",
+                        "Zidek, A.",
+                        "Bates, R.",
+                        "Blackwell, S.",
+                        "Yim, J.",
+                        "Ronneberger, O.",
+                        "Bodenstein, S.",
+                        "Zielinski, M.",
+                        "Bridgland, A.",
+                        "Potapenko, A.",
+                        "Cowie, A.",
+                        "Tunyasuvunakool, K.",
+                        "Jain, R.",
+                        "Clancy, E.",
+                        "Kohli, P.",
+                        "Jumper, J.",
+                        "Hassabis, D.",
+                    ],
+                    doi="10.1101/2021.10.04.463034",
+                ),
+                "location": "https://github.com/deepmind/alphafold",
+                "type": "package",
+                "version": None,
+            }
+        )
+    else:
+        step["software"].append(
+            {
+                "name": "AlphaFold",
+                "classification": "model building",
+                "description": "Structure prediction",
+                "citation": ihm.citations.alphafold2,
+                "location": "https://github.com/deepmind/alphafold",
+                "type": "package",
+                "version": None,
+            }
+        )
+    step["software_parameters"] = config_data["config"]
+    protocol.append(step)
+
+    # model selection step
+    if (
+        "selection_description" not in config_data
+        or len(config_data["selection_description"]) == 0
+    ):
+        return protocol
+
+    step = {
+        "method_type": "model selection",
+        "name": None,
+        "details": config_data["selection_description"],
+    }
+    step["input"] = "model"
+    step["output"] = "model"
+    step["software"] = []
+    step["software_parameters"] = {}
+    protocol.append(step)
+    return protocol
+
+
+def _get_title(gene_names):
+    """Get a title for this modelling experiment."""
+    return f"Predicted interaction between {' and '.join(gene_names)}"
+
+
+def _get_model_details(gene_names):
+    """Get the model description."""
+    return (
+        f"Dimer model generated for {' and '.join(gene_names)}, produced "
+        + "using AlphaFold-Multimer (AlphaFold v2.2.0) as implemented by "
+        + "ColabFold (v1.2.0) which uses MMseqs2 for MSA generation (UniRef30 "
+        + "+ Environmental)."
+    )
+
+
+def _get_model_group_name():
+    """Get a name for a model group."""
+    return "Crosslinked Heterodimer AlphaFold-Multimer v2 Models"
+
+
+def _get_sequence(chn):
+    """Get the sequence out of an OST chain."""
+    # initialise
+    lst_rn = chn.residues[0].number.num
+    idx = 1
+    sqe = chn.residues[0].one_letter_code
+    if lst_rn != 1:
+        # chain does not start at residue 1: prepend a gap for position 1 and
+        # let the loop below fill in gaps up to the first modelled residue
+        sqe = "-"
+        lst_rn = 1
+        idx = 0
+
+    for res in chn.residues[idx:]:
+        lst_rn += 1
+        while lst_rn != res.number.num:
+            sqe += "-"
+            lst_rn += 1
+        sqe += res.one_letter_code
+    return sqe
+
+
+def _check_sequence(up_ac, sequence):
+    """Verify that the sequence only contains standard one-letter codes."""
+    for res in sequence:
+        if res not in "ACDEFGHIKLMNPQRSTVWY":
+            raise RuntimeError(
+                "Non-standard aa found in UniProtKB sequence "
+                + f"for entry '{up_ac}': {res}"
+            )
+
+
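+# _fetch_upkb_entry() below parses the UniProtKB flat-file (txt) format. The
+# lines of interest roughly look like this (shortened, illustrative entry;
+# real entries carry many more line types, which are ignored):
+#
+# ID   EXMPL_HUMAN             Reviewed;         120 AA.
+# DT   03-OCT-2012, sequence version 1.
+# GN   Name=EXMPL;
+# OS   Homo sapiens (Human).
+# OX   NCBI_TaxID=9606;
+# SQ   SEQUENCE   120 AA;  13000 MW;  0123456789ABCDEF CRC64;
+#      MSSGENLYFQ GSHMTEYKLV VVGAGGVGKS ...
+
+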
+def _fetch_upkb_entry(up_ac):
+    """Fetch data for an UniProtKB entry."""
+    # This is a simple parser for the UniProtKB txt format; instead of
+    # breaking it up into multiple functions, we allow many branches &
+    # statements here.
+    # pylint: disable=too-many-branches,too-many-statements
+    data = {}
+    data["up_organism"] = ""
+    data["up_sequence"] = ""
+    data["up_ac"] = up_ac
+    rspns = requests.get(
+        f"https://www.uniprot.org/uniprot/{up_ac}.txt", timeout=180
+    )
+    for line in rspns.iter_lines(decode_unicode=True):
+        if line.startswith("ID   "):
+            sline = line.split()
+            if len(sline) != 5:
+                _abort_msg(f"Unusual UniProtKB ID line found:\n'{line}'")
+            data["up_id"] = sline[1]
+        elif line.startswith("OX   NCBI_TaxID="):
+            # Following strictly the UniProtKB format: 'OX   NCBI_TaxID=<ID>;'
+            data["up_ncbi_taxid"] = line[len("OX   NCBI_TaxID=") : -1]
+            data["up_ncbi_taxid"] = data["up_ncbi_taxid"].split("{")[0].strip()
+        elif line.startswith("OS   "):
+            if line[-1] == ".":
+                data["up_organism"] += line[len("OS   ") : -1]
+            else:
+                data["up_organism"] += line[len("OS   ") : -1] + " "
+        elif line.startswith("SQ   "):
+            sline = line.split()
+            if len(sline) != 8:
+                _abort_msg(f"Unusual UniProtKB SQ line found:\n'{line}'")
+            data["up_seqlen"] = int(sline[2])
+            data["up_crc64"] = sline[6]
+        elif line.startswith("     "):
+            sline = line.split()
+            if len(sline) > 6:
+                _abort_msg(
+                    "Unusual UniProtKB sequence data line "
+                    + f"found:\n'{line}'"
+                )
+            data["up_sequence"] += "".join(sline)
+        elif line.startswith("RP   "):
+            if "ISOFORM" in line.upper():
+                raise RuntimeError(
+                    f"First ISOFORM found for '{up_ac}', needs handling."
+                )
+        elif line.startswith("DT   "):
+            # e.g. 'DT   03-OCT-2012, sequence version 1.'
+            dt_flds = line[len("DT   ") :].split(", ")
+            if dt_flds[1].upper().startswith("SEQUENCE VERSION "):
+                data["up_last_mod"] = datetime.datetime.strptime(
+                    dt_flds[0], "%d-%b-%Y"
+                )
+        elif line.startswith("GN   Name="):
+            data["up_gn"] = line[len("GN   Name=") :].split(";")[0]
+            data["up_gn"] = data["up_gn"].split("{")[0].strip()
+
+    # we have not seen isoforms in the data set, yet, so we just set them to
+    # None
+    data["up_isoform"] = None
+
+    if "up_gn" not in data:
+        _abort_msg(f"No gene name found for UniProtKB entry '{up_ac}'.")
+    if "up_last_mod" not in data:
+        _abort_msg(f"No sequence version found for UniProtKB entry '{up_ac}'.")
+    if "up_crc64" not in data:
+        _abort_msg(f"No CRC64 value found for UniProtKB entry '{up_ac}'.")
+    if len(data["up_sequence"]) == 0:
+        _abort_msg(f"No sequence found for UniProtKB entry '{up_ac}'.")
+    # check that the sequence length matches the SQ line
+    if data["up_seqlen"] != len(data["up_sequence"]):
+        _abort_msg(
+            "Sequence length of SQ line and sequence data differ for "
+            + f"UniProtKB entry '{up_ac}': {data['up_seqlen']} != "
+            + f"{len(data['up_sequence'])}"
+        )
+    _check_sequence(data["up_ac"], data["up_sequence"])
+
+    if "up_id" not in data:
+        _abort_msg(f"No ID found for UniProtKB entry '{up_ac}'.")
+    if "up_ncbi_taxid" not in data:
+        _abort_msg(f"No NCBI taxonomy ID found for UniProtKB entry '{up_ac}'.")
+    if len(data["up_organism"]) == 0:
+        _abort_msg(f"No organism species found for UniProtKB entry '{up_ac}'.")
+    return data
+
+
+def _get_upkb_for_sequence(sqe, up_ac):
+    """Get UniProtKB entry data for given sequence."""
+    up_data = _fetch_upkb_entry(up_ac)
+    if sqe != up_data["up_sequence"]:
+        raise RuntimeError(
+            f"Sequences not equal from file: {sqe}, from UniProtKB: "
+            + f"{up_data['up_sequence']}"
+        )
+    return up_data
+
+
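+# NOTE: _get_entities() below matches chains to UniProtKB ACs purely by
+# order: chain i of the PDB file is validated against up_acs[i], i.e. the
+# i-th AC from the '<UniProtKB AC>-<UniProtKB AC>' directory name. Any
+# mismatch is caught by the sequence comparison in _get_upkb_for_sequence().
+
+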
+def _get_entities(pdb_file, up_acs):
+    """Gather data for the mmCIF (target) entities."""
+    entities = []
+
+    ost_ent = io.LoadPDB(pdb_file)
+    for i, chn in enumerate(ost_ent.chains):
+        cif_ent = {}
+        sqe = _get_sequence(chn)
+        upkb = _get_upkb_for_sequence(sqe, up_acs[i])
+        cif_ent["pdb_sequence"] = sqe
+        cif_ent["pdb_chain_id"] = chn.name
+        cif_ent["description"] = (
+            f"{upkb['up_organism']} {upkb['up_gn']} ({upkb['up_ac']})"
+        )
+        cif_ent.update(upkb)
+        entities.append(cif_ent)
+    return entities, ost_ent
+
+
+def _get_scores(data, prfx):
+    """Read the model's quality scores from its '<prefix>_scores.json'."""
+    scrs_fle = f"{prfx}_scores.json"
+    with open(scrs_fle, encoding="utf8") as jfh:
+        scrs_json = json.load(jfh)
+
+    # NOTE for reuse of data when iterating multiple models: this will
+    # overwrite scores in data but will not delete any scores if prev. models
+    # had more...
+    data.update(scrs_json)
+
+
+def _get_modelcif_entities(target_ents, source, asym_units, system):
+    """Create ModelCIF entities and asymmetric units."""
+    for cif_ent in target_ents:
+        mdlcif_ent = modelcif.Entity(
+            cif_ent["pdb_sequence"],
+            description=cif_ent["description"],
+            source=source,
+            references=[
+                modelcif.reference.UniProt(
+                    cif_ent["up_id"],
+                    cif_ent["up_ac"],
+                    # Careful: aligning 1..up_seqlen is only this easy
+                    # because model and UniProtKB sequences were verified to
+                    # be identical!
+                    align_begin=1,
+                    align_end=cif_ent["up_seqlen"],
+                    isoform=cif_ent["up_isoform"],
+                    ncbi_taxonomy_id=cif_ent["up_ncbi_taxid"],
+                    organism_scientific=cif_ent["up_organism"],
+                    sequence_version_date=cif_ent["up_last_mod"],
+                    sequence_crc64=cif_ent["up_crc64"],
+                )
+            ],
+        )
+        asym_units[cif_ent["pdb_chain_id"]] = modelcif.AsymUnit(mdlcif_ent)
+        system.target_entities.append(mdlcif_ent)
+
+
+def _assemble_modelcif_software(soft_dict):
+    """Create a modelcif.Software instance from dictionary."""
+    return modelcif.Software(
+        soft_dict["name"],
+        soft_dict["classification"],
+        soft_dict["description"],
+        soft_dict["location"],
+        soft_dict["type"],
+        soft_dict["version"],
+        citation=soft_dict["citation"],
+    )
+
+
+def _get_sequence_dbs(seq_dbs):
+    """Get ColabFold sequence databases."""
+    # NOTE: hard coded for ColabFold versions before 2022/07/13
+    # -> afterwards UniRef30 updated to 2022_02 (and maybe more changes)
+    db_dict = {
+        "UniRef": modelcif.ReferenceDatabase(
+            "UniRef30",
+            "http://wwwuser.gwdg.de/~compbiol/colabfold/uniref30_2103.tar.gz",
+            version="2021_03",
+        ),
+        "Environmental": modelcif.ReferenceDatabase(
+            "ColabFold DB",
+            "http://wwwuser.gwdg.de/~compbiol/colabfold/"
+            + "colabfold_envdb_202108.tar.gz",
+            version="2021_08",
+        ),
+    }
+    return [db_dict[seq_db] for seq_db in seq_dbs]
+
+
+def _get_modelcif_protocol_software(js_step):
+    """Assemble software entries for a ModelCIF protocol step."""
+    if js_step["software"]:
+        if len(js_step["software"]) == 1:
+            sftwre = _assemble_modelcif_software(js_step["software"][0])
+        else:
+            sftwre = []
+            for sft in js_step["software"]:
+                sftwre.append(_assemble_modelcif_software(sft))
+            sftwre = modelcif.SoftwareGroup(elements=sftwre)
+        if js_step["software_parameters"]:
+            params = []
+            for key, val in js_step["software_parameters"].items():
+                params.append(modelcif.SoftwareParameter(key, val))
+            if isinstance(sftwre, modelcif.SoftwareGroup):
+                sftwre.parameters = params
+            else:
+                sftwre = modelcif.SoftwareGroup(
+                    elements=(sftwre,), parameters=params
+                )
+        return sftwre
+    return None
+
+
+def _get_modelcif_protocol_input(js_step, target_entities, ref_dbs, model):
+    """Assemble input data for a ModelCIF protocol step."""
+    if js_step["input"] == "target_sequences":
+        input_data = modelcif.data.DataGroup(target_entities)
+        input_data.extend(ref_dbs)
+    elif js_step["input"] == "model":
+        input_data = model
+    else:
+        raise RuntimeError(f"Unknown protocol input: '{js_step['input']}'")
+    return input_data
+
+
+def _get_modelcif_protocol_output(js_step, model):
+    """Assemble output data for a ModelCIF protocol step."""
+    if js_step["output"] == "model":
+        output_data = model
+    else:
+        raise RuntimeError(f"Unknown protocol output: '{js_step['output']}'")
+    return output_data
+
+
+def _get_modelcif_protocol(protocol_steps, target_entities, model, ref_dbs):
+    """Create the protocol for the ModelCIF file."""
+    protocol = modelcif.protocol.Protocol()
+    for js_step in protocol_steps:
+        sftwre = _get_modelcif_protocol_software(js_step)
+        input_data = _get_modelcif_protocol_input(
+            js_step, target_entities, ref_dbs, model
+        )
+        output_data = _get_modelcif_protocol_output(js_step, model)
+
+        protocol.steps.append(
+            modelcif.protocol.Step(
+                input_data=input_data,
+                output_data=output_data,
+                name=js_step["name"],
+                details=js_step["details"],
+                software=sftwre,
+            )
+        )
+        protocol.steps[-1].method_type = js_step["method_type"]
+    return protocol
+
+
+def _compress_cif_file(cif_file):
+    """Compress cif file and delete original."""
+    with open(cif_file, "rb") as f_in:
+        with gzip.open(cif_file + ".gz", "wb") as f_out:
+            shutil.copyfileobj(f_in, f_out)
+    os.remove(cif_file)
+
+
+def _package_associated_files(mdl_name):
+    """Compress associated files into single zip file and delete original."""
+    # file names must match ones from add_scores
+    zip_path = f"{mdl_name}.zip"
+    files = [f"{mdl_name}_local_pairwise_qa.cif"]
+    # zip settings tested for good speed vs compression
+    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_BZIP2) as myzip:
+        for file in files:
+            myzip.write(file)
+            os.remove(file)
+
+
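+# _store_as_modelcif() below produces, per model, roughly this output layout
+# in out_dir (names derived from the PDB file prefix; '.gz' only with
+# --compress):
+#
+#   <mdl_name>.cif[.gz]   (main ModelCIF file)
+#   <mdl_name>.zip        (associated files: <mdl_name>_local_pairwise_qa.cif)
+
+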
+def _store_as_modelcif(data_json, ost_ent, out_dir, file_prfx, compress):
+    """Mix all the data into a ModelCIF file."""
+    print(" generating ModelCIF objects...", end="")
+    pstart = timer()
+    # create system to gather all the data
+    system = modelcif.System(
+        title=data_json["title"],
+        id=data_json["data_block_id"].upper(),
+        model_details=data_json["model_details"],
+    )
+
+    # create target entities, references, source, asymmetric units & assembly
+    # for source we assume all chains come from the same taxon
+    # create an asymmetric unit and an entity per target sequence
+    asym_units = {}
+    _get_modelcif_entities(
+        data_json["target_entities"],
+        ihm.source.Natural(
+            ncbi_taxonomy_id=data_json["target_entities"][0]["up_ncbi_taxid"],
+            scientific_name=data_json["target_entities"][0]["up_organism"],
+        ),
+        asym_units,
+        system,
+    )
+
+    # audit_authors
+    system.authors.extend(data_json["audit_authors"])
+
+    # set up the model to produce coordinates
+    if data_json["rank_num"] == 1:
+        mdl_list_name = f"Model {data_json['mdl_num']} (top ranked model)"
+    else:
+        mdl_list_name = (
+            f"Model {data_json['mdl_num']} "
+            f"(#{data_json['rank_num']} ranked model)"
+        )
+    model = _OST2ModelCIF(
+        assembly=modelcif.Assembly(asym_units.values()),
+        asym=asym_units,
+        ost_entity=ost_ent,
+        name=mdl_list_name,
+    )
+    print(f" ({timer()-pstart:.2f}s)")
+    print(" processing QA scores...", end="", flush=True)
+    pstart = timer()
+    mdl_name = os.path.basename(file_prfx)
+    system.repositories.append(model.add_scores(data_json, system.id, mdl_name))
+    print(f" ({timer()-pstart:.2f}s)")
+
+    system.model_groups.append(
+        modelcif.model.ModelGroup([model], name=data_json["model_group_name"])
+    )
+
+    ref_dbs = _get_sequence_dbs(data_json["config_data"]["seq_dbs"])
+    protocol = _get_modelcif_protocol(
+        data_json["protocol"], system.target_entities, model, ref_dbs
+    )
+    system.protocols.append(protocol)
+
+    # write modelcif System to file
+    print(" write to disk...", end="", flush=True)
+    pstart = timer()
+    # NOTE: this will dump PAE on the path provided in add_scores
+    # -> hence we cheat by changing the working directory and back while
+    #    being exception-safe...
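+    # (os.chdir affects the whole process, so the try/finally below restores
+    # the original working directory even if the dumper or zip packaging
+    # fails)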
+ oldpwd = os.getcwd() + os.chdir(out_dir) + try: + with open(f"{mdl_name}.cif", "w", encoding="ascii") as mmcif_fh: + modelcif.dumper.write(mmcif_fh, [system]) + _package_associated_files(mdl_name) + if compress: + _compress_cif_file(f"{mdl_name}.cif") + finally: + os.chdir(oldpwd) + + print(f" ({timer()-pstart:.2f}s)") + + +def _create_interaction_json(config_data): + """Create a dictionary (mimicking JSON) that contains data which is the same + for all models.""" + data = {} + data["audit_authors"] = _get_audit_authors() + data["protocol"] = _get_protocol_steps_and_software(config_data) + data["config_data"] = config_data + return data + + +def _create_model_json(data, pdb_file, up_acs, block_id): + """Create a dictionary (mimicking JSON) that contains all the data.""" + data["target_entities"], ost_ent = _get_entities(pdb_file, up_acs) + gns = [] + for i in data["target_entities"]: + gns.append(i["up_gn"]) + data["title"] = _get_title(gns) + data["data_block_id"] = block_id + data["model_details"] = _get_model_details(gns) + data["model_group_name"] = _get_model_group_name() + return ost_ent + + +def _main(): + """Run as script.""" + opts = _parse_args() + + interaction = os.path.split(opts.model_dir)[-1] + print(f"Working on {interaction}...") + + # get UniProtKB ACs from directory name + up_acs = interaction.split("-") + + cnfg = _check_interaction_extra_files_present(opts.model_dir) + config_data = _parse_colabfold_config(cnfg) + + # iterate model directory + found_ranked = False + for fle in sorted(os.listdir(opts.model_dir)): + # iterate PDB files + if not fle.endswith(".pdb"): + continue + if opts.rank is not None and f"rank_{opts.rank}" not in fle: + continue + found_ranked = True + print(f" translating {fle}...") + pdb_start = timer() + file_prfx, uid = _check_model_extra_files_present(opts.model_dir, fle) + fle = os.path.join(opts.model_dir, fle) + + # gather data into JSON-like structure + print(" preparing data...", end="") + pstart = timer() + + mdlcf_json = _create_interaction_json(config_data) + + # uid = ..._rank_X_model_Y.pdb + mdl_name_parts = uid.split("_") + assert mdl_name_parts[-4] == "rank" + assert mdl_name_parts[-2] == "model" + mdlcf_json["rank_num"] = int(mdl_name_parts[-3]) + mdlcf_json["mdl_num"] = int(mdl_name_parts[-1]) + + ost_ent = _create_model_json(mdlcf_json, fle, up_acs, uid) + + # read quality scores from JSON file + _get_scores(mdlcf_json, file_prfx) + print(f" ({timer()-pstart:.2f}s)") + + _store_as_modelcif( + mdlcf_json, ost_ent, opts.out_dir, file_prfx, opts.compress + ) + print(f" ... done with {fle} ({timer()-pdb_start:.2f}s).") + + if opts.rank and not found_ranked: + _abort_msg(f"Could not find model of requested rank '{opts.rank}'") + print(f"... done with {opts.model_dir}.") + + +if __name__ == "__main__": + _main()
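+
+
+# NOTE: expected layout of <MODEL DIR>, as assumed by _main() above (file
+# names illustrative, following the '..._rank_X_model_Y' convention):
+#
+#   A0A1B0GTU1-O75152/
+#       config.json                      (ColabFold settings, per interaction)
+#       ..._rank_1_model_3.pdb           (model coordinates)
+#       ..._rank_1_model_3_scores.json   (pLDDT/pTM/PAE for that model)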