Skip to content
Snippets Groups Projects
Commit da46b61e authored by Rafal Gumienny
Browse files

feat: Save chain mappings

parent b95e8095
Branches
Tags
No related merge requests found
...@@ -15,7 +15,7 @@ from ost.io import (LoadPDB, LoadMMCIF, MMCifInfoBioUnit, MMCifInfo, ...@@ -15,7 +15,7 @@ from ost.io import (LoadPDB, LoadMMCIF, MMCifInfoBioUnit, MMCifInfo,
MMCifInfoTransOp, ReadStereoChemicalPropsFile) MMCifInfoTransOp, ReadStereoChemicalPropsFile)
from ost import PushVerbosityLevel from ost import PushVerbosityLevel
from ost.mol.alg import (qsscoring, Molck, MolckSettings, lDDTSettings, from ost.mol.alg import (qsscoring, Molck, MolckSettings, lDDTSettings,
lDDTScorer, CheckStructure) CheckStructure)
from ost.conop import CompoundLib from ost.conop import CompoundLib
...@@ -150,7 +150,8 @@ def _ParseArgs(): ...@@ -150,7 +150,8 @@ def _ParseArgs():
dest="residue_number_alignment", dest="residue_number_alignment",
default=False, default=False,
action="store_true", action="store_true",
help=("Make alignment based on residue number instead using Clustal.")) help=("Make alignment based on residue number instead of using "
"Clustal."))
# #
# lDDT options # lDDT options
# #
...@@ -331,7 +332,9 @@ def _ReadStructureFile(path): ...@@ -331,7 +332,9 @@ def _ReadStructureFile(path):
if not entity.IsValid(): if not entity.IsValid():
raise IOError("Provided file does not contain valid entity.") raise IOError("Provided file does not contain valid entity.")
entity.SetName(os.path.basename(path)) entity.SetName(os.path.basename(path))
entities.append(entity) chain_mapping = {c.name: c.name for c in entity.chains}
entities.append({"entity": entity,
"chain_mapping": chain_mapping})
except Exception: except Exception:
try: try:
tmp_entity, cif_info = LoadMMCIF(path, info=True) tmp_entity, cif_info = LoadMMCIF(path, info=True)
...@@ -347,7 +350,11 @@ def _ReadStructureFile(path): ...@@ -347,7 +350,11 @@ def _ReadStructureFile(path):
tbu.AddOperations(tinfo.GetOperations()) tbu.AddOperations(tinfo.GetOperations())
entity = tbu.PDBize(tmp_entity, min_polymer_size=0) entity = tbu.PDBize(tmp_entity, min_polymer_size=0)
entity.SetName(os.path.basename(path) + ".au") entity.SetName(os.path.basename(path) + ".au")
entities.append(entity) chain_mapping = {c.name: c.GetStringProp("original_name")
for c in entity.chains}
entities.append({
"entity": entity,
"chain_mapping": chain_mapping})
elif len(cif_info.biounits) > 1: elif len(cif_info.biounits) > 1:
for i, biounit in enumerate(cif_info.biounits): for i, biounit in enumerate(cif_info.biounits):
entity = biounit.PDBize(tmp_entity, min_polymer_size=0) entity = biounit.PDBize(tmp_entity, min_polymer_size=0)
...@@ -355,7 +362,11 @@ def _ReadStructureFile(path): ...@@ -355,7 +362,11 @@ def _ReadStructureFile(path):
raise IOError( raise IOError(
"Provided file does not contain valid entity.") "Provided file does not contain valid entity.")
entity.SetName(os.path.basename(path) + "." + str(i)) entity.SetName(os.path.basename(path) + "." + str(i))
entities.append(entity) chain_mapping = {c.name: c.GetStringProp("original_name")
for c in entity.chains}
entities.append({
"entity": entity,
"chain_mapping": chain_mapping})
else: else:
biounit = cif_info.biounits[0] biounit = cif_info.biounits[0]
entity = biounit.PDBize(tmp_entity, min_polymer_size=0) entity = biounit.PDBize(tmp_entity, min_polymer_size=0)
...@@ -363,7 +374,11 @@ def _ReadStructureFile(path): ...@@ -363,7 +374,11 @@ def _ReadStructureFile(path):
raise IOError( raise IOError(
"Provided file does not contain valid entity.") "Provided file does not contain valid entity.")
entity.SetName(os.path.basename(path)) entity.SetName(os.path.basename(path))
entities.append(entity) chain_mapping = {c.name: c.GetStringProp("original_name")
for c in entity.chains}
entities.append({
"entity": entity,
"chain_mapping": chain_mapping})
except Exception as exc: except Exception as exc:
raise exc raise exc
...@@ -386,17 +401,6 @@ def _MolckEntity(entity, options): ...@@ -386,17 +401,6 @@ def _MolckEntity(entity, options):
Molck(entity, lib, ms) Molck(entity, lib, ms)
def _AveragelDDT(scorers):
    """Return the mean of the scorers' global scores weighted by contacts.

    Every scorer contributes its ``global_score`` weighted by its
    ``total_contacts``.

    :param scorers: Iterable of objects exposing ``global_score`` and
                    ``total_contacts`` attributes. It is traversed only
                    once, so one-shot iterables (generators) are fine.
    :returns: The weighted mean as a float, or 0.0 when the summed weights
              are not positive (e.g. no scorers or no contacts at all).
    """
    numerator = 0.0
    denominator = 0
    # Accumulate both sums in a single pass so that an iterator argument is
    # not exhausted before the weights have been collected.
    for scorer in scorers:
        numerator += scorer.global_score * scorer.total_contacts
        denominator += scorer.total_contacts
    if denominator > 0:
        # float() keeps true division under Python 2 as well.
        return numerator / float(denominator)
    return 0.0
def _Main(): def _Main():
"""Do the magic.""" """Do the magic."""
...@@ -410,17 +414,22 @@ def _Main(): ...@@ -410,17 +414,22 @@ def _Main():
references = _ReadStructureFile(opts.reference) references = _ReadStructureFile(opts.reference)
if opts.molck: if opts.molck:
for i in range(len(references)): for i in range(len(references)):
_MolckEntity(references[i], opts) _MolckEntity(references[i]["entity"], opts)
references[i] = references[i].CreateFullView() references[i]["entity"] = references[i]["entity"].CreateFullView()
for i in range(len(models)): for i in range(len(models)):
_MolckEntity(models[i], opts) _MolckEntity(models[i]["entity"], opts)
models[i] = models[i].CreateFullView() models[i]["entity"] = models[i]["entity"].CreateFullView()
else:
for i in range(len(references)):
references[i]["entity"] = references[i]["entity"].CreateFullView()
for i in range(len(models)):
models[i]["entity"] = models[i]["entity"].CreateFullView()
if opts.structural_checks: if opts.structural_checks:
stereochemical_parameters = ReadStereoChemicalPropsFile( stereochemical_parameters = ReadStereoChemicalPropsFile(
opts.parameter_file) opts.parameter_file)
ost.LogInfo("Performing structural checks for reference(s)") ost.LogInfo("Performing structural checks for reference(s)")
for reference in references: for reference in references:
CheckStructure(reference, CheckStructure(reference["entity"],
stereochemical_parameters.bond_table, stereochemical_parameters.bond_table,
stereochemical_parameters.angle_table, stereochemical_parameters.angle_table,
stereochemical_parameters.nonbonded_table, stereochemical_parameters.nonbonded_table,
...@@ -428,7 +437,7 @@ def _Main(): ...@@ -428,7 +437,7 @@ def _Main():
opts.angle_tolerance) opts.angle_tolerance)
ost.LogInfo("Performing structural checks for model(s)") ost.LogInfo("Performing structural checks for model(s)")
for model in models: for model in models:
CheckStructure(model, CheckStructure(model["entity"],
stereochemical_parameters.bond_table, stereochemical_parameters.bond_table,
stereochemical_parameters.angle_table, stereochemical_parameters.angle_table,
stereochemical_parameters.nonbonded_table, stereochemical_parameters.nonbonded_table,
...@@ -445,10 +454,12 @@ def _Main(): ...@@ -445,10 +454,12 @@ def _Main():
result["options"]["cwd"] = os.path.abspath(os.getcwd()) result["options"]["cwd"] = os.path.abspath(os.getcwd())
# #
# Perform scoring # Perform scoring
for model in models: for model_data in models:
model = model_data["entity"]
model_name = model.GetName() model_name = model.GetName()
model_results = dict() model_results = dict()
for reference in references: for reference_data in references:
reference = reference_data["entity"]
reference_name = reference.GetName() reference_name = reference.GetName()
reference_results = dict() reference_results = dict()
ost.LogInfo("#\nComparing %s to %s" % ( ost.LogInfo("#\nComparing %s to %s" % (
...@@ -462,6 +473,11 @@ def _Main(): ...@@ -462,6 +473,11 @@ def _Main():
"Using custom chain mapping: %s" % str( "Using custom chain mapping: %s" % str(
opts.chain_mapping)) opts.chain_mapping))
qs_scorer.chain_mapping = opts.chain_mapping qs_scorer.chain_mapping = opts.chain_mapping
original_chain_mapping = dict()
for mdl_cname, ref_cname in qs_scorer.chain_mapping.iteritems():
orig_mdl_cname = model_data["chain_mapping"][mdl_cname]
orig_ref_cname = reference_data["chain_mapping"][ref_cname]
original_chain_mapping[orig_mdl_cname] = orig_ref_cname
if opts.qs_score: if opts.qs_score:
ost.LogInfo("Computing QS-score") ost.LogInfo("Computing QS-score")
try: try:
...@@ -472,7 +488,8 @@ def _Main(): ...@@ -472,7 +488,8 @@ def _Main():
"reference_name": reference_name, "reference_name": reference_name,
"global_score": qs_scorer.global_score, "global_score": qs_scorer.global_score,
"best_score": qs_scorer.best_score, "best_score": qs_scorer.best_score,
"chain_mapping": qs_scorer.chain_mapping "chain_mapping": qs_scorer.chain_mapping,
"original_chain_mapping": original_chain_mapping
} }
except qsscoring.QSscoreError as ex: except qsscoring.QSscoreError as ex:
# default handling: report failure and set score to 0 # default handling: report failure and set score to 0
...@@ -484,7 +501,8 @@ def _Main(): ...@@ -484,7 +501,8 @@ def _Main():
"reference_name": reference.GetName(), "reference_name": reference.GetName(),
"global_score": 0.0, "global_score": 0.0,
"best_score": 0.0, "best_score": 0.0,
"chain_mapping": None "chain_mapping": qs_scorer.chain_mapping,
"original_chain_mapping": original_chain_mapping
} }
# Calculate lDDT # Calculate lDDT
if opts.lddt: if opts.lddt:
...@@ -513,11 +531,16 @@ def _Main(): ...@@ -513,11 +531,16 @@ def _Main():
for lddt_scorer in oligo_lddt_scorer.sc_lddt_scorers: for lddt_scorer in oligo_lddt_scorer.sc_lddt_scorers:
# Get chains and renumber according to alignment (for lDDT) # Get chains and renumber according to alignment (for lDDT)
try: try:
model_chain = lddt_scorer.model.chains[0].GetName()
reference_chain = \
lddt_scorer.references[0].chains[0].GetName()
lddt_results["single_chain_lddt"].append({ lddt_results["single_chain_lddt"].append({
"status": "SUCCESS", "status": "SUCCESS",
"error": "", "error": "",
"model_chain": lddt_scorer.model.chains[0].GetName(), "original_model_chain": model_data["chain_mapping"][model_chain],
"reference_chain": lddt_scorer.references[0].chains[0].GetName(), "original_reference_chain": reference_data["chain_mapping"][reference_chain],
"model_chain": model_chain,
"reference_chain": reference_chain,
"global_score": lddt_scorer.global_score, "global_score": lddt_scorer.global_score,
"conserved_contacts": lddt_scorer.conserved_contacts, "conserved_contacts": lddt_scorer.conserved_contacts,
"total_contacts": lddt_scorer.total_contacts}) "total_contacts": lddt_scorer.total_contacts})
...@@ -526,8 +549,10 @@ def _Main(): ...@@ -526,8 +549,10 @@ def _Main():
lddt_results["single_chain_lddt"].append({ lddt_results["single_chain_lddt"].append({
"status": "FAILURE", "status": "FAILURE",
"error": str(ex), "error": str(ex),
"model_chain": lddt_scorer.model.chains[0].GetName(), "original_model_chain": model_data["chain_mapping"][model_chain],
"reference_chain": lddt_scorer.references[0].chains[0].GetName(), "original_reference_chain": reference_data["chain_mapping"][reference_chain],
"model_chain": model_chain,
"reference_chain": reference_chain,
"global_score": 0.0, "global_score": 0.0,
"conserved_contacts": 0.0, "conserved_contacts": 0.0,
"total_contacts": 0.0}) "total_contacts": 0.0})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment