diff --git a/projects/novelfams/translate2modelcif.py b/projects/novelfams/translate2modelcif.py new file mode 100644 index 0000000000000000000000000000000000000000..9c3f5f80a3e64e3d8057f42311c7cd3c15bb8b8f --- /dev/null +++ b/projects/novelfams/translate2modelcif.py @@ -0,0 +1,1254 @@ +#! /usr/local/bin/ost +# -*- coding: utf-8 -*- +"""Translate models for novelfams from PDB + extra data into ModelCIF.""" +# re-enable Pylint for final version +# pylint: disable=too-many-lines + +from timeit import default_timer as timer +import argparse +import filecmp +import gzip +import os +import shutil +import sys +import zipfile + +import pandas as pd +import numpy as np + +import ihm +import ihm.citations +import modelcif +import modelcif.associated +import modelcif.dumper +import modelcif.model +import modelcif.protocol +import modelcif.reference + +# pylint: disable=unused-import,wrong-import-order +from ost import io, geom, mol + +# pylint: enable=unused-import,wrong-import-order + + +# EXAMPLE for running: +# ost translate2modelcif.py ... + + +################################################################################ +# HELPERS (mostly copied from from ost/modules/io/tests/test_io_omf.py) +# to compare PDBs +def _compare_atoms( + a1, a2, occupancy_thresh=0.01, bfactor_thresh=0.01, dist_thresh=0.001 +): + if abs(a1.occupancy - a2.occupancy) > occupancy_thresh: + return False + if abs(a1.b_factor - a2.b_factor) > bfactor_thresh: + return False + # modification: look at x,y,z spearately + if abs(a1.pos.x - a2.pos.x) > dist_thresh: + return False + if abs(a1.pos.y - a2.pos.y) > dist_thresh: + return False + if abs(a1.pos.z - a2.pos.z) > dist_thresh: + return False + if a1.is_hetatom != a2.is_hetatom: + return False + if a1.element != a2.element: + return False + return True + + +def _compare_residues( + r1, + r2, + at_occupancy_thresh=0.01, + at_bfactor_thresh=0.01, + at_dist_thresh=0.001, + skip_ss=False, + skip_rnums=False, +): + if r1.GetName() != r2.GetName(): + return False + if skip_rnums is False: + if r1.GetNumber() != r2.GetNumber(): + return False + if skip_ss is False: + if str(r1.GetSecStructure()) != str(r2.GetSecStructure()): + return False + if r1.one_letter_code != r2.one_letter_code: + return False + if r1.chem_type != r2.chem_type: + return False + if r1.chem_class != r2.chem_class: + return False + anames1 = [a.GetName() for a in r1.atoms] + anames2 = [a.GetName() for a in r2.atoms] + if sorted(anames1) != sorted(anames2): + return False + anames = anames1 + for aname in anames: + a1 = r1.FindAtom(aname) + a2 = r2.FindAtom(aname) + if not _compare_atoms( + a1, + a2, + occupancy_thresh=at_occupancy_thresh, + bfactor_thresh=at_bfactor_thresh, + dist_thresh=at_dist_thresh, + ): + return False + return True + + +def _compare_chains( + ch1, + ch2, + at_occupancy_thresh=0.01, + at_bfactor_thresh=0.01, + at_dist_thresh=0.001, + skip_ss=False, + skip_rnums=False, +): + if len(ch1.residues) != len(ch2.residues): + return False + for r1, r2 in zip(ch1.residues, ch2.residues): + if not _compare_residues( + r1, + r2, + at_occupancy_thresh=at_occupancy_thresh, + at_bfactor_thresh=at_bfactor_thresh, + at_dist_thresh=at_dist_thresh, + skip_ss=skip_ss, + skip_rnums=skip_rnums, + ): + return False + return True + + +def _compare_bonds(ent1, ent2): + bonds1 = list() + for b in ent1.bonds: + bond_partners = [str(b.first), str(b.second)] + bonds1.append([min(bond_partners), max(bond_partners), b.bond_order]) + bonds2 = list() + for b in ent2.bonds: + bond_partners = [str(b.first), str(b.second)] + bonds2.append([min(bond_partners), max(bond_partners), b.bond_order]) + return sorted(bonds1) == sorted(bonds2) + + +def _compare_ent( + ent1, + ent2, + at_occupancy_thresh=0.01, + at_bfactor_thresh=0.01, + at_dist_thresh=0.001, + skip_ss=False, + skip_cnames=False, + skip_bonds=False, + skip_rnums=False, + bu_idx=None, +): + if bu_idx is not None: + if ent1.GetName() + " " + str(bu_idx) != ent2.GetName(): + return False + else: + if ent1.GetName() != ent2.GetName(): + return False + chain_names_one = [ch.GetName() for ch in ent1.chains] + chain_names_two = [ch.GetName() for ch in ent2.chains] + if skip_cnames: + # only check whether we have the same number of chains + if len(chain_names_one) != len(chain_names_two): + return False + else: + if chain_names_one != chain_names_two: + return False + for ch1, ch2 in zip(ent1.chains, ent2.chains): + if not _compare_chains( + ch1, + ch2, + at_occupancy_thresh=at_occupancy_thresh, + at_bfactor_thresh=at_bfactor_thresh, + at_dist_thresh=at_dist_thresh, + skip_ss=skip_ss, + skip_rnums=skip_rnums, + ): + return False + if not skip_bonds: + if not _compare_bonds(ent1, ent2): + return False + return True + + +def _compare_pdbs(f_name, f1, f2): + """Use atom-by-atom comparison on PDB files allowing num. errors.""" + # first do simple file diff. + if filecmp.cmp(f1, f2): + return True + else: + ent1 = io.LoadPDB(f1) + ent2 = io.LoadPDB(f2) + # allow a bit more errors as input files can have rounding errors + if _compare_ent(ent1, ent2, 0.011, 0.011, 0.0011, True, False, True): + return True + else: + # check manually and give warning... + atom_names_1 = [a.qualified_name for a in ent1.atoms] + atom_names_2 = [a.qualified_name for a in ent2.atoms] + assert atom_names_1 == atom_names_2 + b_diffs = [ + abs(a1.b_factor - a2.b_factor) + for a1, a2 in zip(ent1.atoms, ent2.atoms) + ] + max_b_diff = max(b_diffs) + rmsd = mol.alg.CalculateRMSD(ent1.Select(""), ent2.Select("")) + _warn_msg( + f"PDB file mismatch web vs top-ranked for {f_name}: " + f"RMSD {rmsd:.3f}, max. b_factor diff {max_b_diff:.3f}" + ) + return False + + +################################################################################ + + +def _abort_msg(msg, exit_code=1): + """Write error message and exit with exit_code.""" + print(f"{msg}\nAborting.", file=sys.stderr) + sys.exit(exit_code) + + +def _warn_msg(msg): + """Write a warning message to stdout.""" + print(f"WARNING: {msg}") + + +def _check_file(file_path): + """Make sure a file exists and is actually a file.""" + if not os.path.exists(file_path): + _abort_msg(f"File not found: '{file_path}'.") + if not os.path.isfile(file_path): + _abort_msg(f"File path does not point to file: '{file_path}'.") + + +def _check_folder(dir_path): + """Make sure a file exists and is actually a file.""" + if not os.path.exists(dir_path): + _abort_msg(f"Path not found: '{dir_path}'.") + if not os.path.isdir(dir_path): + _abort_msg(f"Path does not point to a directory: '{dir_path}'.") + + +def _check_opts_folder(dir_path): + """Remove trailing '/' (return fixed one) and check if path valid.""" + if dir_path.endswith("/"): + dir_path = dir_path[:-1] + _check_folder(dir_path) + return dir_path + + +def _parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + formatter_class=argparse.RawDescriptionHelpFormatter, + description=__doc__, + ) + + parser.add_argument( + "model_dir", + type=str, + metavar="<MODEL DIR>", + help='Directory with model PDBs named "{ID}_{mdl}.pdb" (with ID and ' + + "mdl matching info in metadata).", + ) + parser.add_argument( + "out_dir", + type=str, + metavar="<OUTPUT DIR>", + help="Path to directory to store results.", + ) + parser.add_argument( + "--compress", + default=False, + action="store_true", + help="Compress ModelCIF file with gzip.", + ) + opts = parser.parse_args() + + # check input + opts.model_dir = _check_opts_folder(opts.model_dir) + # check out_dir + if opts.out_dir.endswith("/"): + opts.out_dir = opts.out_dir[:-1] + if not os.path.exists(opts.out_dir): + os.makedirs(opts.out_dir) + if not os.path.isdir(opts.out_dir): + _abort_msg(f"Path '{opts.out_dir}' does not point to a directory.") + + return opts + + +# pylint: disable=too-few-public-methods +class _GlobalPTM(modelcif.qa_metric.Global, modelcif.qa_metric.PTM): + """Predicted accuracy according to the TM-score score in [0,1]""" + + name = "pTM" + software = None + + +class _GlobalPLDDT(modelcif.qa_metric.Global, modelcif.qa_metric.PLDDT): + """Predicted accuracy according to the CA-only lDDT in [0,100]""" + + name = "pLDDT" + software = None + + +class _LocalPLDDT(modelcif.qa_metric.Local, modelcif.qa_metric.PLDDT): + """Predicted accuracy according to the CA-only lDDT in [0,100]""" + + name = "pLDDT" + software = None + + +class _LocalPairwisePAE( + modelcif.qa_metric.LocalPairwise, modelcif.qa_metric.PAE +): + """Predicted aligned error (in Angstroms)""" + + name = "PAE" + software = None + + +class _NmpfamsdbTrgRef(modelcif.reference.TargetReference): + """NMPFamsDB as target reference.""" + + name = "Other" + other_details = "NMPFamsDB" + + +class _LPeptideAlphabetWithX(ihm.LPeptideAlphabet): + """Have the default amino acid alphabet plus 'X' for unknown residues.""" + + def __init__(self): + """Create the alphabet.""" + super().__init__() + self._comps["X"] = self._comps["UNK"] + + +# pylint: enable=too-few-public-methods + + +def _get_res_num(r, use_auth=False): + """Get res. num. from auth. IDs if reading from mmCIF files.""" + if use_auth: + return int(r.GetStringProp("pdb_auth_resnum")) + return r.number.num + + +def _get_ch_name(ch, use_auth=False): + """Get chain name from auth. IDs if reading from mmCIF files.""" + if use_auth: + return ch.GetStringProp("pdb_auth_chain_name") + return ch.name + + +class _OST2ModelCIF(modelcif.model.AbInitioModel): + """Map OST entity elements to ihm.model""" + + def __init__(self, *args, **kwargs): + """Initialise a model""" + for i in ["ost_entity", "asym", "scores_json"]: + if i not in kwargs: + raise TypeError(f"Required keyword argument '{i}' not found.") + self.ost_entity = kwargs.pop("ost_entity") + self.asym = kwargs.pop("asym") + self.scores_json = kwargs.pop("scores_json") + + # use auth IDs for res. nums and chain names + self.use_auth = False + # what accuracy to use for PAE? (writer uses 3 anyway) + self.pae_digits = 3 + + # fetch plddts per residue + self.plddts = [] + for res in self.ost_entity.residues: + b_factors = [a.b_factor for a in res.atoms] + assert len(set(b_factors)) == 1 # must all be equal! + self.plddts.append(b_factors[0]) + + super().__init__(*args, **kwargs) + + def get_atoms(self): + # ToDo [internal]: Take B-factor out since its not a B-factor? + # NOTE: this assumes that _get_res_num maps residue to pos. in seqres + # within asym + for atm in self.ost_entity.atoms: + yield modelcif.model.Atom( + asym_unit=self.asym[_get_ch_name(atm.chain, self.use_auth)], + seq_id=_get_res_num(atm.residue, self.use_auth), + atom_id=atm.name, + type_symbol=atm.element, + x=atm.pos[0], + y=atm.pos[1], + z=atm.pos[2], + het=atm.is_hetatom, + biso=atm.b_factor, + occupancy=atm.occupancy, + ) + + def add_scores(self): + """Add QA metrics from AF2 scores.""" + # global scores + self.qa_metrics.extend( + ( + _GlobalPLDDT(self.scores_json["plddt_global"]), + _GlobalPTM(self.scores_json["ptm"]), + ) + ) + + # local scores + lpae = [] + i = 0 + for chn_i in self.ost_entity.chains: + ch_name = _get_ch_name(chn_i, self.use_auth) + for res_i in chn_i.residues: + # local pLDDT + res_num_i = _get_res_num(res_i, self.use_auth) + self.qa_metrics.append( + _LocalPLDDT( + self.asym[ch_name].residue(res_num_i), + self.plddts[i], + ) + ) + i += 1 + + # PAE needs to go by residue index as it also stores ones + # for missing residues (i.e. X) + if "pae" in self.scores_json: + pae_i = self.scores_json["pae"][res_num_i - 1] + for chn_j in self.ost_entity.chains: + for res_j in chn_j.residues: + res_num_j = _get_res_num(res_j, self.use_auth) + pae_ij = pae_i[res_num_j - 1] + lpae.append( + _LocalPairwisePAE( + self.asym[chn_i.name].residue(res_num_i), + self.asym[chn_j.name].residue(res_num_j), + round(pae_ij, self.pae_digits), + ) + ) + + self.qa_metrics.extend(lpae) + + +def _get_audit_authors(): + """Return the list of authors that produced this model.""" + return ( + "Pavlopoulos, Georgios A.", + "Baltoumas, Fotis A.", + "Liu, Sirui", + "Selvitopi, Oguz", + "Camargo, Antonio Pedro", + "Nayfach, Stephen", + "Azad, Ariful", + "Roux, Simon", + "Call, Lee", + "Ivanova, Natalia N.", + "Chen, I-Min", + "Paez-Espino, David", + "Karatzas, Evangelos", + "Novel Metagenome Protein Families Consortium", + "Iliopoulos, Ioannis", + "Konstantinidi, Konstantinos", + "Tiedje, James M.", + "Pett-Ridge, Jennifer", + "Baker, David", + "Visel, Axel", + "Ouzounis, Christos A.", + "Ovchinnikov, Sergey", + "Buluc, Aydin", + "Kyrpides, Nikos C.", + ) + + +def _get_metadata(metadata_file): + """Read csv file with metedata and prepare for next steps.""" + metadata = pd.read_csv( + metadata_file, sep=" ", names=["ID", "mdl", "pTM", "pLDDT"] + ) + return metadata + + +def _get_pdb_files(model_dir): + """Collect PDB files from model_dir. + + Returns a list of paths to PDB files. + """ + pdb_files = [f for f in os.listdir(model_dir) if not f.startswith(".")] + pdb_paths = [] + for f in pdb_files: + f_path = os.path.join(model_dir, f) + pdb_paths.append(f_path) + + return pdb_paths + + +def _get_config(): + """Define AF setup.""" + msa_description = ( + 'MSA created by calculating the central or "pivot" ' + "sequence of each seed MSA, and refining each " + "alignment using that sequence as the guide." + ) + mdl_description = ( + "Model generated using AlphaFold (v2.0.0 with models " + "fine-tuned to return pTM weights) producing 5 models, " + "without model relaxation, without templates, ranked " + "by pLDDT, starting from a custom MSA." + ) + af_config = {} + return { + "af_config": af_config, + "af_version": "2.0.0", + "mdl_description": mdl_description, + "msa_description": msa_description, + "use_templates": False, + "use_small_bfd": False, + "use_multimer": False, + } + + +def _get_protocol_steps_and_software(config_data): + """Create the list of protocol steps with software and parameters used.""" + protocol = [] + + # MSA step + step = { + "method_type": "coevolution MSA", + "name": None, + "details": config_data["msa_description"], + } + step["input"] = "target_sequences" + step["output"] = "MSA" + step["software"] = [] + step["software_parameters"] = {} + protocol.append(step) + + # modelling step + step = { + "method_type": "modeling", + "name": None, + "details": config_data["mdl_description"], + } + # get input data + # Must refer to data already in the JSON, so we try keywords + step["input"] = "target_sequences_and_MSA" + # get output data + # Must refer to existing data, so we try keywords + step["output"] = "model" + # get software + if config_data["use_multimer"]: + step["software"] = [ + { + "name": "AlphaFold-Multimer", + "classification": "model building", + "description": "Structure prediction", + "citation": ihm.Citation( + pmid=None, + title="Protein complex prediction with " + + "AlphaFold-Multimer.", + journal="bioRxiv", + volume=None, + page_range=None, + year=2021, + authors=[ + "Evans, R.", + "O'Neill, M.", + "Pritzel, A.", + "Antropova, N.", + "Senior, A.", + "Green, T.", + "Zidek, A.", + "Bates, R.", + "Blackwell, S.", + "Yim, J.", + "Ronneberger, O.", + "Bodenstein, S.", + "Zielinski, M.", + "Bridgland, A.", + "Potapenko, A.", + "Cowie, A.", + "Tunyasuvunakool, K.", + "Jain, R.", + "Clancy, E.", + "Kohli, P.", + "Jumper, J.", + "Hassabis, D.", + ], + doi="10.1101/2021.10.04.463034", + ), + "location": "https://github.com/deepmind/alphafold", + "type": "package", + "version": config_data["af_version"], + } + ] + else: + step["software"] = [ + { + "name": "AlphaFold", + "classification": "model building", + "description": "Structure prediction", + "citation": ihm.citations.alphafold2, + "location": "https://github.com/deepmind/alphafold", + "type": "package", + "version": config_data["af_version"], + } + ] + step["software_parameters"] = config_data["af_config"] + protocol.append(step) + + return protocol + + +def _get_title(fam_name): + """Get a title for this modelling experiment.""" + return f"AlphaFold model for NMPFamsDB Family {fam_name}" + + +def _get_model_details(fam_name, max_pLDDT, max_pTM): + """Get the model description.""" + db_url = f"https://bib.fleming.gr/NMPFamsDB/family?id={fam_name}" + return ( + f'Model generated using AlphaFold (v2.0.0) for the "Representative ' + f'Sequence" of NMPFamsDB Metagenome / Metatranscriptome Family ' + f"{fam_name}.\n\nThe 5 produced models reached a max. global pLDDT of " + f"{round(max_pLDDT, 3)} and max. pTM of {round(max_pTM, 3)}.\n\n" + f"See {db_url} for additional details." + ) + + +def _get_model_group_name(): + """Get a name for a model group.""" + return None + + +def _get_sequence(chn, use_auth=False): + """Get the sequence out of an OST chain incl. '-' for gaps in resnums.""" + # initialise (add gaps if first is not at num. 1) + lst_rn = _get_res_num(chn.residues[0], use_auth) + idx = 1 + sqe = "-" * (lst_rn - 1) + chn.residues[0].one_letter_code + + for res in chn.residues[idx:]: + lst_rn += 1 + while lst_rn != _get_res_num(res, use_auth): + sqe += "-" + lst_rn += 1 + sqe += res.one_letter_code + return sqe + + +def _get_entities(pdb_file, ref_seq, fam_name): + """Gather data for the mmCIF (target) entities.""" + _check_file(pdb_file) + + ost_ent = io.LoadPDB(pdb_file) + if ost_ent.chain_count != 1: + raise RuntimeError(f"Unexpected oligomer in {pdb_file}") + chn = ost_ent.chains[0] + sqe_gaps = _get_sequence(chn) + + # NOTE: can have gaps to accommodate "X" in ref_seq + exp_seq = sqe_gaps.replace("-", "X") + len_diff = len(ref_seq.string) - len(exp_seq) + if len_diff > 0: + exp_seq += "X" * len_diff + if exp_seq != ref_seq.string: + raise RuntimeError(f"Sequence in {pdb_file} does not match ref_seq") + + cif_ent = { + "seqres": ref_seq.string, + "pdb_sequence": sqe_gaps, + "pdb_chain_id": [_get_ch_name(chn, False)], + "fam_name": fam_name, + "description": "Representative Sequence of NMPFamsDB Family " + + f"{fam_name}", + } + + return [cif_ent], ost_ent + + +def _get_modelcif_entities(target_ents, asym_units, system): + """Create ModelCIF entities and asymmetric units.""" + alphabet = _LPeptideAlphabetWithX() + for cif_ent in target_ents: + mdlcif_ent = modelcif.Entity( + # NOTE: sequence here defines residues in model! + cif_ent["seqres"], + alphabet=alphabet, + description=cif_ent["description"], + source=None, + references=[ + _NmpfamsdbTrgRef( + cif_ent["fam_name"], + cif_ent["fam_name"], + align_begin=1, + align_end=len(cif_ent["seqres"]), + ) + ], + ) + # NOTE: this assigns (potentially new) alphabetic chain names + for pdb_chain_id in cif_ent["pdb_chain_id"]: + asym_units[pdb_chain_id] = modelcif.AsymUnit( + mdlcif_ent, + strand_id=pdb_chain_id, + ) + system.target_entities.append(mdlcif_ent) + + +def _get_assoc_pae_file(entry_id, mdl_name): + """Generate a associated file object to extract PAE to extra file.""" + return modelcif.associated.LocalPairwiseQAScoresFile( + f"{mdl_name}_local_pairwise_qa.cif", + categories=["_ma_qa_metric_local_pairwise"], + copy_categories=["_ma_qa_metric"], + entry_id=entry_id, + entry_details="This file is an associated file consisting " + + "of local pairwise QA metrics. This is a partial mmCIF " + + "file and can be validated by merging with the main " + + "mmCIF file containing the model coordinates and other " + + "associated data.", + details="Predicted aligned error", + ) + + +def _get_aln_data(): + """Generate Data object for ALN.""" + aln_data = modelcif.data.Data("Custom MSA for modelling") + aln_data.data_content_type = "coevolution MSA" + return aln_data + + +def _get_assoc_aln_file(fle_path): + """Generate a modelcif.associated.File object pointing to FASTA formatted + file containing MSA. + """ + cfile = modelcif.associated.File( + fle_path, + details="Custom MSA for modelling", + data=_get_aln_data(), + ) + cfile.file_format = "fasta" + cfile.file_content = "multiple sequence alignments" + return cfile + + +def _get_associated_files(mdl_name, arc_files): + """Create entry for associated files.""" + # package all into zip file + return modelcif.associated.Repository( + "", + [modelcif.associated.ZipFile(f"{mdl_name}.zip", files=arc_files)], + ) + # NOTE: by convention MA expects zip file with same name as model-cif + + +def _assemble_modelcif_software(soft_dict): + """Create a modelcif.Software instance from dictionary.""" + return modelcif.Software( + soft_dict["name"], + soft_dict["classification"], + soft_dict["description"], + soft_dict["location"], + soft_dict["type"], + soft_dict["version"], + citation=soft_dict["citation"], + ) + + +def _get_modelcif_protocol_software(js_step): + """Assemble software entries for a ModelCIF protocol step.""" + if js_step["software"]: + if len(js_step["software"]) == 1: + sftwre = _assemble_modelcif_software(js_step["software"][0]) + else: + sftwre = [] + for sft in js_step["software"]: + sftwre.append(_assemble_modelcif_software(sft)) + sftwre = modelcif.SoftwareGroup(elements=sftwre) + if js_step["software_parameters"]: + params = [] + for key, val in js_step["software_parameters"].items(): + params.append(modelcif.SoftwareParameter(key, val)) + if isinstance(sftwre, modelcif.SoftwareGroup): + sftwre.parameters = params + else: + sftwre = modelcif.SoftwareGroup( + elements=(sftwre,), parameters=params + ) + return sftwre + return None + + +def _get_modelcif_protocol_data(data_label, target_entities, aln_data, model): + """Assemble data for a ModelCIF protocol step.""" + if data_label == "target_sequences": + data = modelcif.data.DataGroup(target_entities) + elif data_label == "MSA": + data = aln_data + elif data_label == "target_sequences_and_MSA": + data = modelcif.data.DataGroup(target_entities) + data.append(aln_data) + elif data_label == "model": + data = model + else: + raise RuntimeError(f"Unknown protocol data: '{data_label}'") + return data + + +def _get_modelcif_protocol(protocol_steps, target_entities, aln_data, model): + """Create the protocol for the ModelCIF file.""" + protocol = modelcif.protocol.Protocol() + for js_step in protocol_steps: + sftwre = _get_modelcif_protocol_software(js_step) + input_data = _get_modelcif_protocol_data( + js_step["input"], target_entities, aln_data, model + ) + output_data = _get_modelcif_protocol_data( + js_step["output"], target_entities, aln_data, model + ) + + protocol.steps.append( + modelcif.protocol.Step( + input_data=input_data, + output_data=output_data, + name=js_step["name"], + details=js_step["details"], + software=sftwre, + ) + ) + protocol.steps[-1].method_type = js_step["method_type"] + return protocol + + +def _compress_cif_file(cif_file): + """Compress cif file and delete original.""" + with open(cif_file, "rb") as f_in: + with gzip.open(cif_file + ".gz", "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + os.remove(cif_file) + + +def _package_associated_files(repo): + """Compress associated files into single zip file and delete original.""" + # zip settings tested for good speed vs compression + for archive in repo.files: + with zipfile.ZipFile(archive.path, "w", zipfile.ZIP_BZIP2) as cif_zip: + for zfile in archive.files: + cif_zip.write(zfile.path, arcname=zfile.path) + os.remove(zfile.path) + + +def _get_assoc_mdl_file(fle_path, data_json): + """Generate a modelcif.associated.File object that looks like a CIF file. + The dedicated CIFFile functionality in modelcif would also try to write it. + """ + cfile = modelcif.associated.File( + fle_path, + details=f"#{data_json['mdl_rank']} ranked model; " + + f"pTM {round(data_json['ptm'], 3)}, " + + f"pLDDT {round(data_json['plddt_global'], 3)}", + ) + cfile.file_format = "cif" + return cfile + + +def _get_assoc_zip_file(fle_path, data_json): + """Create a modelcif.associated.File object that looks like a ZIP file. + This is NOT the archive ZIP file for the PAEs but to store that in the + ZIP archive of the selected model.""" + zfile = modelcif.associated.File( + fle_path, + details="archive with multiple files for " + + f"#{data_json['mdl_rank']} ranked model", + ) + zfile.file_format = "other" + return zfile + + +def _store_as_modelcif( + data_json, + ost_ent, + out_dir, + mdl_name, + compress, + add_pae, + add_aln, + add_files, +): + """Mix all the data into a ModelCIF file.""" + print(" generating ModelCIF objects...", end="") + pstart = timer() + # create system to gather all the data + system = modelcif.System( + title=data_json["title"], + id=data_json["mdl_id"].upper(), + model_details=data_json["model_details"], + ) + + # add primary citation (not using from_pubmed_id to ensure that author names + # have no special chars) + system.citations.append( + ihm.Citation( + pmid="37821698", + title="Unraveling the functional dark matter through global " + + "metagenomics.", + journal="Nature", + volume=622, + page_range=(594, 602), + year=2023, + authors=[ + "Pavlopoulos, G.A.", + "Baltoumas, F.A.", + "Liu, S.", + "Selvitopi, O.", + "Camargo, A.P.", + "Nayfach, S.", + "Azad, A.", + "Roux, S.", + "Call, L.", + "Ivanova, N.N.", + "Chen, I.M.", + "Paez-Espino, D.", + "Karatzas, E.", + "Iliopoulos, I.", + "Konstantinidis, K.", + "Tiedje, J.M.", + "Pett-Ridge, J.", + "Baker, D.", + "Visel, A.", + "Ouzounis, C.A.", + "Ovchinnikov, S.", + "Buluc, A.", + "Kyrpides, N.C.", + ], + doi="10.1038/s41586-023-06583-7", + is_primary=True, + ) + ) + + # create an asymmetric unit and an entity per target sequence + asym_units = {} + _get_modelcif_entities(data_json["target_entities"], asym_units, system) + + # audit_authors + system.authors.extend(data_json["audit_authors"]) + + # set up the model to produce coordinates + model = _OST2ModelCIF( + assembly=modelcif.Assembly(asym_units.values()), + asym=asym_units, + ost_entity=ost_ent, + scores_json=data_json, + name=data_json["mdl_name"], + ) + print(f" ({timer()-pstart:.2f}s)") + print(" processing QA scores...", end="", flush=True) + pstart = timer() + model.add_scores() + print(f" ({timer()-pstart:.2f}s)") + + model_group = modelcif.model.ModelGroup( + [model], name=data_json["model_group_name"] + ) + system.model_groups.append(model_group) + + # handle additional files + arc_files = [] + if add_pae: + arc_files.append(_get_assoc_pae_file(system.id, mdl_name)) + if add_aln: + aln_file = _get_assoc_aln_file(data_json["aln_file_name"]) + arc_files.append(aln_file) + aln_data = aln_file.data + else: + aln_data = _get_aln_data() + aln_data.data_other_details = "MSA stored with parent entry" + arc_files.extend(add_files) + if arc_files: + system.repositories.append(_get_associated_files(mdl_name, arc_files)) + + # get data and steps + protocol = _get_modelcif_protocol( + data_json["protocol"], + system.target_entities, + aln_data, + model, + ) + system.protocols.append(protocol) + + # write modelcif System to file (NOTE: no PAE here!) + print(" write to disk...", end="", flush=True) + pstart = timer() + # copy aln file to compress them + if add_aln: + shutil.copyfile( + data_json["aln_file_path"], + os.path.join(out_dir, data_json["aln_file_name"]), + ) + # NOTE: we change path and back while being exception-safe to handle zipfile + oldpwd = os.getcwd() + os.chdir(out_dir) + mdl_fle = f"{mdl_name}.cif" + try: + with open(mdl_fle, "w", encoding="ascii") as mmcif_fh: + modelcif.dumper.write(mmcif_fh, [system]) + if arc_files: + _package_associated_files(system.repositories[0]) + if compress: + _compress_cif_file(mdl_fle) + mdl_fle += ".gz" + finally: + os.chdir(oldpwd) + print(f" ({timer()-pstart:.2f}s)") + assoc_files = [_get_assoc_mdl_file(mdl_fle, data_json)] + if arc_files: + assoc_files.append( + _get_assoc_zip_file( + system.repositories[0].files[0].path, data_json + ) + ) + return assoc_files + + +def _translate2modelcif_single( + f_name, + opts, + metadata, + pdb_files, + mdl_rank, + aln_file, + aln_path, + ref_seq, + mdl_details, + add_files=[], +): + """Convert a single model with its accompanying data to ModelCIF.""" + mdl_id = f_name + if mdl_rank > 1: + mdl_id += f"_rank_{mdl_rank}_{metadata.mdl}" + + print(f" translating {mdl_id}...") + pdb_start = timer() + + # gather data into JSON-like structure + print(" preparing data...", end="") + pstart = timer() + + config_data = _get_config() + mdlcf_json = {} + mdlcf_json["audit_authors"] = _get_audit_authors() + mdlcf_json["protocol"] = _get_protocol_steps_and_software(config_data) + mdlcf_json["config_data"] = config_data + mdlcf_json["mdl_id"] = mdl_id # used for entry ID + mdlcf_json["mdl_rank"] = mdl_rank + mdlcf_json["aln_file_name"] = aln_file + mdlcf_json["aln_file_path"] = aln_path + + # find model to process + pdb_list_sel = [f for f in pdb_files if metadata.mdl in f] + if len(pdb_list_sel) != 1: + # this should never happen + raise RuntimeError( + f"Multiple file matches found for {metadata.mdl} in {f_name}" + ) + if mdl_rank == 1: + mdlcf_json["mdl_name"] = f"Top ranked model ({metadata.mdl})" + else: + mdlcf_json["mdl_name"] = f"#{mdl_rank} ranked model ({metadata.mdl})" + + # process coordinates + pdb_file = pdb_list_sel[0] + target_entities, ost_ent = _get_entities(pdb_file, ref_seq, f_name) + mdlcf_json["target_entities"] = target_entities + # sanity check (only for top ranked model!) + if mdl_rank == 1 and opts.pdb_web_path is not None: + pdb_file_web = os.path.join(opts.pdb_web_path, f"{f_name}.pdb") + # warning handled in compare function... + _compare_pdbs(f_name, pdb_file, pdb_file_web) + + # get scores for this entry + mdlcf_json["plddt_global"] = metadata.pLDDT + mdlcf_json["ptm"] = metadata.pTM + add_pae = mdl_rank == 1 or opts.all_pae + if add_pae: + pdb_basename = os.path.basename(pdb_file) + pae_basename = os.path.splitext(pdb_basename)[0] + ".txt.gz" + pae_file = os.path.join(opts.pae_dir, pae_basename) + _check_file(pae_file) + mdlcf_json["pae"] = np.loadtxt(pae_file) + exp_num_res = len(ref_seq.string) + if mdlcf_json["pae"].shape != (exp_num_res, exp_num_res): + raise RuntimeError(f"Unexpected PAE shape in {pae_file}") + + # fill annotations + mdlcf_json["title"] = _get_title(f_name) + if mdl_rank != 1: + mdlcf_json["title"] += f" (#{mdl_rank} ranked model)" + mdlcf_json["model_details"] = mdl_details + mdlcf_json["model_group_name"] = _get_model_group_name() + print(f" ({timer()-pstart:.2f}s)") + + # save ModelCIF + assoc_files = _store_as_modelcif( + mdlcf_json, + ost_ent, + opts.out_dir, + mdl_id, + opts.compress and mdl_rank == 1, + add_pae, + opts.all_msa or mdl_rank == 1, + add_files, + ) + + # check if result can be read and has expected seq. + mdl_path = os.path.join(opts.out_dir, assoc_files[0].path) + ent, ss = io.LoadMMCIF(mdl_path, seqres=True) + exp_seqs = [ + trg_ent["pdb_sequence"] for trg_ent in mdlcf_json["target_entities"] + ] + assert ent.chain_count == len(exp_seqs), f"Bad chain count {mdl_id}" + # here we expect auth = label IDs + ent_seq = "".join([_get_sequence(chn, False) for chn in ent.chains]) + ent_seq_a = "".join([_get_sequence(chn, True) for chn in ent.chains]) + assert ent_seq == ent_seq_a + assert ent_seq == "".join(exp_seqs), f"Bad seq. {mdl_id}" + ent_seqres = "".join( + [ss.FindSequence(chn.name).string for chn in ent.chains] + ) + exp_seqres = "".join( + [trg_ent["seqres"] for trg_ent in mdlcf_json["target_entities"]] + ) + assert ent_seqres == exp_seqres, f"Bad seqres {mdl_id}" + + print(f" ... done with {mdl_id} ({timer()-pdb_start:.2f}s).") + + return assoc_files + + +def _translate2modelcif(f_name, opts, metadata_fam, pdb_files, ref_seq_check): + """Convert a family of models with their accompanying data to ModelCIF.""" + # re-enable Pylint for final version + # pylint: disable=too-many-locals + # expected to have exactly 5 models per family + if len(metadata_fam) != 5: + raise RuntimeError( + f"Unexpected number of {len(metadata_fam)} models in " + f"metadata for family {f_name}." + ) + + # skip if done already + if opts.compress: + cifext = "cif.gz" + else: + cifext = "cif" + mdl_path = os.path.join(opts.out_dir, f"{f_name}.{cifext}") + if os.path.exists(mdl_path): + print(f" {f_name} already done...") + return + + # get aln_data and ref. seq. for this entry + aln_file = f"{f_name}.fasta" + aln_path = os.path.join(opts.msa_data_dir, aln_file) + # expected 11 extra families compared to web data but those don't have MSAs + # -> skipped for consistency and to keep code simple here + if not os.path.exists(aln_path): + _warn_msg(f"Missing MSA for {f_name}. Skipping...") + return + + aln = io.LoadAlignment( + aln_path + ) # note: this checks that it's an actual MSA + ref_seq = aln.sequences[0] + if ref_seq_check is not None and ref_seq_check.string != ref_seq.string: + raise RuntimeError(f"Sequence mismatch for {f_name}") + + # get global model details + mdl_details = _get_model_details( + f_name, metadata_fam.pLDDT.max(), metadata_fam.pTM.max() + ) + # rank available models + metadata_sorted = metadata_fam.sort_values("pLDDT", ascending=False) + add_files = [] + if opts.all_models: + for idx in range(1, 5): + assoc_files = _translate2modelcif_single( + f_name, + opts, + metadata_sorted.iloc[idx], + pdb_files, + idx + 1, + aln_file, + aln_path, + ref_seq, + mdl_details, + ) + add_files.extend(assoc_files) + # process top ranked one + _translate2modelcif_single( + f_name, + opts, + metadata_sorted.iloc[0], + pdb_files, + 1, + aln_file, + aln_path, + ref_seq, + mdl_details, + add_files, + ) + + +def _main(): + """Run as script.""" + s_tmstmp = timer() + opts = _parse_args() + + # get a list of PDB files with the path to load them. + pdb_files = _get_pdb_files(opts.model_dir) + n_mdls = len(pdb_files) + + # iterate over models + print(f"Processing {n_mdls} models.") + tmstmp = s_tmstmp + for f_name in sorted(pdb_files): + n_mdls -= 1 + """ + if f_name.startswith(opts.prefix): + _translate2modelcif( + f_name, + opts, + metadata_full[metadata_full.ID == f_name], + pdb_files_split[f_name], + refseqs.FindSequence(f_name) if refseqs is not None else None, + ) + """ + # report progress after a bit of time + if timer() - tmstmp > 60: + print( + f"... {n_mdls} models left after " + + f"{(timer() - s_tmstmp)/60:.2f}min, last seen: " + + f"{os.path.splitext(os.path.basename(f_name))[0]}" + ) + tmstmp = timer() + + print( + f"... done, {n_mdls} models left after " + + f"{(timer() - s_tmstmp)/60:.2f}min." + ) + + +if __name__ == "__main__": + _main()