diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures index 136e952a1ee7376f92ecab3eee6ca7fe878b7e2d..04958cf84c4d0d95e9894b05d7c9dad0d34ab316 100644 --- a/actions/ost-compare-structures +++ b/actions/ost-compare-structures @@ -15,7 +15,7 @@ from ost.io import (LoadPDB, LoadMMCIF, MMCifInfoBioUnit, MMCifInfo, MMCifInfoTransOp, ReadStereoChemicalPropsFile) from ost import PushVerbosityLevel from ost.mol.alg import (qsscoring, Molck, MolckSettings, lDDTSettings, - lDDTScorer, CheckStructure) + CheckStructure) from ost.conop import CompoundLib @@ -150,7 +150,8 @@ def _ParseArgs(): dest="residue_number_alignment", default=False, action="store_true", - help=("Make alignment based on residue number instead using Clustal.")) + help=("Make alignment based on residue number instead of using " + "Clustal.")) # # lDDT options # @@ -331,7 +332,9 @@ def _ReadStructureFile(path): if not entity.IsValid(): raise IOError("Provided file does not contain valid entity.") entity.SetName(os.path.basename(path)) - entities.append(entity) + chain_mapping = {c.name: c.name for c in entity.chains} + entities.append({"entity": entity, + "chain_mapping": chain_mapping}) except Exception: try: tmp_entity, cif_info = LoadMMCIF(path, info=True) @@ -347,7 +350,11 @@ def _ReadStructureFile(path): tbu.AddOperations(tinfo.GetOperations()) entity = tbu.PDBize(tmp_entity, min_polymer_size=0) entity.SetName(os.path.basename(path) + ".au") - entities.append(entity) + chain_mapping = {c.name: c.GetStringProp("original_name") + for c in entity.chains} + entities.append({ + "entity": entity, + "chain_mapping": chain_mapping}) elif len(cif_info.biounits) > 1: for i, biounit in enumerate(cif_info.biounits): entity = biounit.PDBize(tmp_entity, min_polymer_size=0) @@ -355,7 +362,11 @@ def _ReadStructureFile(path): raise IOError( "Provided file does not contain valid entity.") entity.SetName(os.path.basename(path) + "." + str(i)) - entities.append(entity) + chain_mapping = {c.name: c.GetStringProp("original_name") + for c in entity.chains} + entities.append({ + "entity": entity, + "chain_mapping": chain_mapping}) else: biounit = cif_info.biounits[0] entity = biounit.PDBize(tmp_entity, min_polymer_size=0) @@ -363,7 +374,11 @@ def _ReadStructureFile(path): raise IOError( "Provided file does not contain valid entity.") entity.SetName(os.path.basename(path)) - entities.append(entity) + chain_mapping = {c.name: c.GetStringProp("original_name") + for c in entity.chains} + entities.append({ + "entity": entity, + "chain_mapping": chain_mapping}) except Exception as exc: raise exc @@ -386,17 +401,6 @@ def _MolckEntity(entity, options): Molck(entity, lib, ms) -def _AveragelDDT(scorers): - scores = [s.global_score for s in scorers] - weights = [s.total_contacts for s in scorers] - nominator = sum([s * w for s, w in zip(scores, weights)]) - denominator = sum(weights) - if denominator > 0: - return nominator / float(denominator) - else: - return 0.0 - - def _Main(): """Do the magic.""" @@ -410,17 +414,22 @@ def _Main(): references = _ReadStructureFile(opts.reference) if opts.molck: for i in range(len(references)): - _MolckEntity(references[i], opts) - references[i] = references[i].CreateFullView() + _MolckEntity(references[i]["entity"], opts) + references[i]["entity"] = references[i]["entity"].CreateFullView() for i in range(len(models)): - _MolckEntity(models[i], opts) - models[i] = models[i].CreateFullView() + _MolckEntity(models[i]["entity"], opts) + models[i]["entity"] = models[i]["entity"].CreateFullView() + else: + for i in range(len(references)): + references[i]["entity"] = references[i]["entity"].CreateFullView() + for i in range(len(models)): + models[i]["entity"] = models[i]["entity"].CreateFullView() if opts.structural_checks: stereochemical_parameters = ReadStereoChemicalPropsFile( opts.parameter_file) ost.LogInfo("Performing structural checks for reference(s)") for reference in references: - CheckStructure(reference, + CheckStructure(reference["entity"], stereochemical_parameters.bond_table, stereochemical_parameters.angle_table, stereochemical_parameters.nonbonded_table, @@ -428,7 +437,7 @@ def _Main(): opts.angle_tolerance) ost.LogInfo("Performing structural checks for model(s)") for model in models: - CheckStructure(model, + CheckStructure(model["entity"], stereochemical_parameters.bond_table, stereochemical_parameters.angle_table, stereochemical_parameters.nonbonded_table, @@ -445,10 +454,12 @@ def _Main(): result["options"]["cwd"] = os.path.abspath(os.getcwd()) # # Perform scoring - for model in models: + for model_data in models: + model = model_data["entity"] model_name = model.GetName() model_results = dict() - for reference in references: + for reference_data in references: + reference = reference_data["entity"] reference_name = reference.GetName() reference_results = dict() ost.LogInfo("#\nComparing %s to %s" % ( @@ -462,6 +473,11 @@ def _Main(): "Using custom chain mapping: %s" % str( opts.chain_mapping)) qs_scorer.chain_mapping = opts.chain_mapping + original_chain_mapping = dict() + for mdl_cname, ref_cname in qs_scorer.chain_mapping.iteritems(): + orig_mdl_cname = model_data["chain_mapping"][mdl_cname] + orig_ref_cname = reference_data["chain_mapping"][ref_cname] + original_chain_mapping[orig_mdl_cname] = orig_ref_cname if opts.qs_score: ost.LogInfo("Computing QS-score") try: @@ -472,7 +488,8 @@ def _Main(): "reference_name": reference_name, "global_score": qs_scorer.global_score, "best_score": qs_scorer.best_score, - "chain_mapping": qs_scorer.chain_mapping + "chain_mapping": qs_scorer.chain_mapping, + "original_chain_mapping": original_chain_mapping } except qsscoring.QSscoreError as ex: # default handling: report failure and set score to 0 @@ -484,7 +501,8 @@ def _Main(): "reference_name": reference.GetName(), "global_score": 0.0, "best_score": 0.0, - "chain_mapping": None + "chain_mapping": qs_scorer.chain_mapping, + "original_chain_mapping": original_chain_mapping } # Calculate lDDT if opts.lddt: @@ -513,11 +531,16 @@ def _Main(): for lddt_scorer in oligo_lddt_scorer.sc_lddt_scorers: # Get chains and renumber according to alignment (for lDDT) try: + model_chain = lddt_scorer.model.chains[0].GetName() + reference_chain = \ + lddt_scorer.references[0].chains[0].GetName() lddt_results["single_chain_lddt"].append({ "status": "SUCCESS", "error": "", - "model_chain": lddt_scorer.model.chains[0].GetName(), - "reference_chain": lddt_scorer.references[0].chains[0].GetName(), + "original_model_chain": model_data["chain_mapping"][model_chain], + "original_reference_chain": reference_data["chain_mapping"][reference_chain], + "model_chain": model_chain, + "reference_chain": reference_chain, "global_score": lddt_scorer.global_score, "conserved_contacts": lddt_scorer.conserved_contacts, "total_contacts": lddt_scorer.total_contacts}) @@ -526,8 +549,10 @@ def _Main(): lddt_results["single_chain_lddt"].append({ "status": "FAILURE", "error": str(ex), - "model_chain": lddt_scorer.model.chains[0].GetName(), - "reference_chain": lddt_scorer.references[0].chains[0].GetName(), + "original_model_chain": model_data["chain_mapping"][model_chain], + "original_reference_chain": reference_data["chain_mapping"][reference_chain], + "model_chain": model_chain, + "reference_chain": reference_chain, "global_score": 0.0, "conserved_contacts": 0.0, "total_contacts": 0.0})