diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures index d4d04f94cb036ad58a6e64cd0371f7b6d53bcf0f..0014f7dd3802ba1f97dee4d26142e02bc64e2cec 100644 --- a/actions/ost-compare-structures +++ b/actions/ost-compare-structures @@ -371,6 +371,39 @@ def _ParseArgs(): return opts +def _RevertChainNames(ent): + """Revert chain names to original names.""" + editor = ent.EditXCS() + suffix = "_tmp" # just a suffix for temporary chain name + used_names = dict() + reverted_chains = dict() + for chain in ent.chains: + try: + original_name = chain.GetStringProp("original_name") + except Exception as ex: + ost.LogError("Cannot revert chain %s back to original: %s" % ( + chain.name, + str(ex))) + reverted_chains[chain.name] = chain.name + editor.RenameChain(chain, chain.name + suffix) + continue + new_name = original_name + if new_name not in used_names: + used_names[original_name] = 1 + reverted_chains[chain.name] = new_name + editor.RenameChain(chain, chain.name + suffix) + else: + new_name = "%s_%i" % (original_name, # dot causes selection error + used_names[original_name]) + reverted_chains[chain.name] = new_name + editor.RenameChain(chain, chain.name + suffix) + used_names[original_name] += 1 + for chain in ent.chains: + editor.RenameChain(chain, reverted_chains[chain.name[:-len(suffix)]]) + rev_out = ["%s -> %s" % (on, nn) for on, nn in reverted_chains.iteritems()] + ost.LogInfo("Reverted chains: %s" % ", ".join(rev_out)) + + def _ReadStructureFile(path): """Safely read structure file into OST entity. @@ -389,9 +422,7 @@ def _ReadStructureFile(path): if not entity.IsValid(): raise IOError("Provided file does not contain valid entity.") entity.SetName(os.path.basename(path)) - chain_mapping = {c.name: c.name for c in entity.chains} - entities.append({"entity": entity, - "chain_mapping": chain_mapping}) + entities.append(entity) except Exception: try: tmp_entity, cif_info = LoadMMCIF(path, info=True) @@ -407,11 +438,8 @@ def _ReadStructureFile(path): tbu.AddOperations(tinfo.GetOperations()) entity = tbu.PDBize(tmp_entity, min_polymer_size=0) entity.SetName(os.path.basename(path) + ".au") - chain_mapping = {c.name: c.GetStringProp("original_name") - for c in entity.chains} - entities.append({ - "entity": entity, - "chain_mapping": chain_mapping}) + _RevertChainNames(entity) + entities.append(entity) elif len(cif_info.biounits) > 1: for i, biounit in enumerate(cif_info.biounits): entity = biounit.PDBize(tmp_entity, min_polymer_size=0) @@ -419,11 +447,8 @@ def _ReadStructureFile(path): raise IOError( "Provided file does not contain valid entity.") entity.SetName(os.path.basename(path) + "." + str(i)) - chain_mapping = {c.name: c.GetStringProp("original_name") - for c in entity.chains} - entities.append({ - "entity": entity, - "chain_mapping": chain_mapping}) + _RevertChainNames(entity) + entities.append(entity) else: biounit = cif_info.biounits[0] entity = biounit.PDBize(tmp_entity, min_polymer_size=0) @@ -431,11 +456,8 @@ def _ReadStructureFile(path): raise IOError( "Provided file does not contain valid entity.") entity.SetName(os.path.basename(path)) - chain_mapping = {c.name: c.GetStringProp("original_name") - for c in entity.chains} - entities.append({ - "entity": entity, - "chain_mapping": chain_mapping}) + _RevertChainNames(entity) + entities.append(entity) except Exception as exc: raise exc @@ -471,22 +493,22 @@ def _Main(): references = _ReadStructureFile(opts.reference) if opts.molck: for i in range(len(references)): - _MolckEntity(references[i]["entity"], opts) - references[i]["entity"] = references[i]["entity"].CreateFullView() + _MolckEntity(references[i], opts) + references[i] = references[i].CreateFullView() for i in range(len(models)): - _MolckEntity(models[i]["entity"], opts) - models[i]["entity"] = models[i]["entity"].CreateFullView() + _MolckEntity(models[i], opts) + models[i] = models[i].CreateFullView() else: for i in range(len(references)): - references[i]["entity"] = references[i]["entity"].CreateFullView() + references[i] = references[i].CreateFullView() for i in range(len(models)): - models[i]["entity"] = models[i]["entity"].CreateFullView() + models[i] = models[i].CreateFullView() if opts.structural_checks: stereochemical_parameters = ReadStereoChemicalPropsFile( opts.parameter_file) ost.LogInfo("Performing structural checks for reference(s)") for reference in references: - CheckStructure(reference["entity"], + CheckStructure(reference, stereochemical_parameters.bond_table, stereochemical_parameters.angle_table, stereochemical_parameters.nonbonded_table, @@ -494,7 +516,7 @@ def _Main(): opts.angle_tolerance) ost.LogInfo("Performing structural checks for model(s)") for model in models: - CheckStructure(model["entity"], + CheckStructure(model, stereochemical_parameters.bond_table, stereochemical_parameters.angle_table, stereochemical_parameters.nonbonded_table, @@ -511,12 +533,10 @@ def _Main(): result["options"]["cwd"] = os.path.abspath(os.getcwd()) # # Perform scoring - for model_data in models: - model = model_data["entity"] + for model in models: model_name = model.GetName() model_results = dict() - for reference_data in references: - reference = reference_data["entity"] + for reference in references: reference_name = reference.GetName() reference_results = dict() ost.LogInfo("#\nComparing %s to %s" % ( @@ -530,11 +550,6 @@ def _Main(): "Using custom chain mapping: %s" % str( opts.chain_mapping)) qs_scorer.chain_mapping = opts.chain_mapping - original_chain_mapping = dict() - for mdl_cname, ref_cname in qs_scorer.chain_mapping.iteritems(): - orig_mdl_cname = model_data["chain_mapping"][mdl_cname] - orig_ref_cname = reference_data["chain_mapping"][ref_cname] - original_chain_mapping[orig_mdl_cname] = orig_ref_cname if opts.qs_score: ost.LogInfo("Computing QS-score") try: @@ -545,8 +560,7 @@ def _Main(): "reference_name": reference_name, "global_score": qs_scorer.global_score, "best_score": qs_scorer.best_score, - "chain_mapping": qs_scorer.chain_mapping, - "original_chain_mapping": original_chain_mapping + "chain_mapping": qs_scorer.chain_mapping } except qsscoring.QSscoreError as ex: # default handling: report failure and set score to 0 @@ -558,8 +572,7 @@ def _Main(): "reference_name": reference.GetName(), "global_score": 0.0, "best_score": 0.0, - "chain_mapping": qs_scorer.chain_mapping, - "original_chain_mapping": original_chain_mapping + "chain_mapping": qs_scorer.chain_mapping } # Calculate lDDT if opts.lddt: @@ -594,8 +607,6 @@ def _Main(): lddt_results["single_chain_lddt"].append({ "status": "SUCCESS", "error": "", - "original_model_chain": model_data["chain_mapping"][model_chain], - "original_reference_chain": reference_data["chain_mapping"][reference_chain], "model_chain": model_chain, "reference_chain": reference_chain, "global_score": lddt_scorer.global_score, @@ -606,8 +617,6 @@ def _Main(): lddt_results["single_chain_lddt"].append({ "status": "FAILURE", "error": str(ex), - "original_model_chain": model_data["chain_mapping"][model_chain], - "original_reference_chain": reference_data["chain_mapping"][reference_chain], "model_chain": model_chain, "reference_chain": reference_chain, "global_score": 0.0,