Skip to content
Snippets Groups Projects
Commit 355145c2 authored by Studer Gabriel's avatar Studer Gabriel
Browse files

ligand scoring: adapt compare-ligand-structures action to new developments

parent f66f46a7
No related branches found
No related tags found
No related merge requests found
......@@ -61,8 +61,9 @@ import traceback
import ost
from ost import io
from ost.mol.alg import ligand_scoring
from ost.mol.alg import ligand_scoring_base
from ost.mol.alg import ligand_scoring_lddtpli
from ost.mol.alg import ligand_scoring_scrmsd
def _ParseArgs():
parser = argparse.ArgumentParser(description = __doc__,
......@@ -205,50 +206,14 @@ def _ParseArgs():
default=0.2,
help=("Coverage delta for partial ligand assignment."))
parser.add_argument(
"-gcm",
"--global-chain-mapping",
dest="global_chain_mapping",
default=False,
action="store_true",
help=("Use a global chain mapping."))
parser.add_argument(
"-c",
"--chain-mapping",
nargs="+",
dest="chain_mapping",
help=("Custom mapping of chains between the reference and the model. "
"Each separate mapping consist of key:value pairs where key "
"is the chain name in reference and value is the chain name in "
"model. Only has an effect if the --global-chain-mapping flag "
"is set."))
parser.add_argument(
"-fbs",
"--full-bs-search",
dest="full_bs_search",
default=False,
action="store_true",
help=("Enumerate all potential binding sites in the model."))
parser.add_argument(
"-ra",
"--rmsd-assignment",
dest="rmsd_assignment",
default=True,
action="store_true",
help=("Use RMSD only for ligand assignment "
"(default since OpenStructure 2.8)."))
parser.add_argument(
"-sa",
"--separate-assignment",
dest="rmsd_assignment",
default=True,
action="store_false",
help=("Use separate ligand assignments for RMSD and lDDT-PLI "
"(opposite of --rmsd-assignment)."))
help=("Enumerate all potential binding sites in the model when "
"searching rigid superposition for RMSD computation"))
parser.add_argument(
"-u",
......@@ -278,9 +243,9 @@ def _ParseArgs():
"--radius",
dest="radius",
default=4.0,
help=("Inclusion radius for the binding site. Any residue with atoms "
"within this distance of the ligand will be included in the "
"binding site."))
help=("Inclusion radius to extract reference binding site that is used "
"for RMSD computation. Any residue with atoms within this "
"distance of the ligand will be included in the binding site."))
parser.add_argument(
"--lddt-pli-radius",
......@@ -291,7 +256,7 @@ def _ParseArgs():
parser.add_argument(
"--lddt-lp-radius",
dest="lddt_lp_radius",
default=10.0,
default=15.0,
help=("lDDT inclusion radius for lDDT-LP."))
parser.add_argument(
......@@ -302,16 +267,6 @@ def _ParseArgs():
default=3,
help="Set verbosity level. Defaults to 3 (INFO).")
parser.add_argument(
"--n-max-naive",
dest="n_max_naive",
required=False,
default=12,
type=int,
help=("If number of chains in model and reference are below or equal "
"that number, the global chain mapping will naively enumerate "
"all possible mappings. A heuristic is used otherwise."))
return parser.parse_args()
......@@ -448,36 +403,67 @@ def _QualifiedResidueNotation(r):
ins_code=resnum.ins_code.strip("\u0000"),
)
def _SetupLDDTPLIScorer(model, model_ligands, reference, reference_ligands, args):
return ligand_scoring_lddtpli.LDDTPLIScorer(model, reference,
model_ligands = model_ligands,
target_ligands = reference_ligands,
resnum_alignments = args.residue_number_alignment,
check_resnames = args.enforce_consistency,
rename_ligand_chain = True,
substructure_match = args.substructure_match,
coverage_delta = args.coverage_delta,
lddt_pli_radius = args.lddt_pli_radius)
def _SetupSCRMSDScorer(model, model_ligands, reference, reference_ligands, args):
return ligand_scoring_scrmsd.SCRMSDScorer(model, reference,
model_ligands = model_ligands,
target_ligands = reference_ligands,
resnum_alignments = args.residue_number_alignment,
rename_ligand_chain = True,
substructure_match = args.substructure_match,
coverage_delta = args.coverage_delta,
bs_radius = args.radius,
lddt_lp_radius = args.lddt_lp_radius)
def _Process(model, model_ligands, reference, reference_ligands, args):
mapping = None
if args.chain_mapping is not None:
mapping = {x.split(':')[0]: x.split(':')[1] for x in args.chain_mapping}
scorer = ligand_scoring.LigandScorer(
model=model,
target=reference,
model_ligands=model_ligands,
target_ligands=reference_ligands,
resnum_alignments=args.residue_number_alignment,
check_resnames=args.enforce_consistency,
rename_ligand_chain=True,
substructure_match=args.substructure_match,
coverage_delta=args.coverage_delta,
global_chain_mapping=args.global_chain_mapping,
full_bs_search=args.full_bs_search,
rmsd_assignment=args.rmsd_assignment,
unassigned=args.unassigned,
radius=args.radius,
lddt_pli_radius=args.lddt_pli_radius,
lddt_lp_radius=args.lddt_lp_radius,
n_max_naive=args.n_max_naive,
custom_mapping=mapping
)
out = dict()
##########################
# Setup required scorers #
##########################
lddtpli_scorer = None
scrmsd_scorer = None
if args.lddt_pli:
lddtpli_scorer = _SetupLDDTPLIScorer(model, model_ligands,
reference, reference_ligands,
args)
if args.rmsd:
scrmsd_scorer = _SetupSCRMSDScorer(model, model_ligands,
reference, reference_ligands,
args)
# basic info on ligands only requires baseclass functionality
# doesn't matter which scorer we use
scorer = None
if lddtpli_scorer is not None:
scorer = lddtpli_scorer
elif scrmsd_scorer is not None:
scorer = scrmsd_scorer
else:
ost.LogWarning("No score selected, output will be empty.")
# just create SCRMSD scorer to fill basic ligand info
scorer = _SetupSCRMSDScorer(model, model_ligands,
reference, reference_ligands,
args)
####################################
# Extract / Map ligand information #
####################################
if model_ligands is not None:
# Replace model ligand by path
if len(model_ligands) == len(scorer.model_ligands):
......@@ -486,7 +472,7 @@ def _Process(model, model_ligands, reference, reference_ligands, args):
elif len(model_ligands) < len(scorer.model_ligands):
# Multi-ligand SDF files were given
# Map ligand => path:idx
out["model_ligands"] = []
out["model_ligands"] = list()
for ligand, filename in zip(model_ligands, args.model_ligands):
assert isinstance(ligand, ost.mol.EntityHandle)
for i, residue in enumerate(ligand.residues):
......@@ -500,9 +486,6 @@ def _Process(model, model_ligands, reference, reference_ligands, args):
# Map ligand => qualified residue
out["model_ligands"] = [_QualifiedResidueNotation(l) for l in scorer.model_ligands]
model_ligands_map = {k.hash_code: v for k, v in zip(
scorer.model_ligands, out["model_ligands"])}
if reference_ligands is not None:
# Replace reference ligand by path
if len(reference_ligands) == len(scorer.target_ligands):
......@@ -511,7 +494,7 @@ def _Process(model, model_ligands, reference, reference_ligands, args):
elif len(reference_ligands) < len(scorer.target_ligands):
# Multi-ligand SDF files were given
# Map ligand => path:idx
out["reference_ligands"] = []
out["reference_ligands"] = list()
for ligand, filename in zip(reference_ligands, args.reference_ligands):
assert isinstance(ligand, ost.mol.EntityHandle)
for i, residue in enumerate(ligand.residues):
......@@ -521,80 +504,72 @@ def _Process(model, model_ligands, reference, reference_ligands, args):
raise RuntimeError("Fewer ligands in the reference scorer "
"(%d) than given (%d)" % (
len(scorer.target_ligands), len(reference_ligands)))
else:
# Map ligand => qualified residue
out["reference_ligands"] = [_QualifiedResidueNotation(l) for l in scorer.target_ligands]
reference_ligands_map = {k.hash_code: v for k, v in zip(
scorer.target_ligands, out["reference_ligands"])}
if not (args.lddt_pli or args.rmsd):
ost.LogWarning("No score selected, output will be empty.")
else:
out["unassigned_model_ligands"] = {}
for chain, unassigned_residues in scorer.unassigned_model_ligands.items():
for resnum, unassigned in unassigned_residues.items():
mdl_lig = scorer.model.FindResidue(chain, resnum)
out["unassigned_model_ligands"][model_ligands_map[
mdl_lig.hash_code]] = unassigned
out["unassigned_reference_ligands"] = {}
for chain, unassigned_residues in scorer.unassigned_target_ligands.items():
for resnum, unassigned in unassigned_residues.items():
trg_lig = scorer.target.FindResidue(chain, resnum)
out["unassigned_reference_ligands"][reference_ligands_map[
trg_lig.hash_code]] = unassigned
out["unassigned_model_ligand_descriptions"] = scorer.unassigned_model_ligand_descriptions
out["unassigned_reference_ligand_descriptions"] = scorer.unassigned_target_ligand_descriptions
##################
# Compute scores #
##################
if args.lddt_pli:
out["lddt_pli"] = {}
for chain, lddt_pli_results in scorer.lddt_pli_details.items():
for resnum, lddt_pli in lddt_pli_results.items():
if args.unassigned and lddt_pli["unassigned"]:
mdl_lig = scorer.model.FindResidue(chain, resnum)
model_key = model_ligands_map[mdl_lig.hash_code]
else:
model_key = model_ligands_map[lddt_pli["model_ligand"].hash_code]
lddt_pli["reference_ligand"] = reference_ligands_map[
lddt_pli.pop("target_ligand").hash_code]
lddt_pli["model_ligand"] = model_key
lddt_pli["bs_ref_res"] = [_QualifiedResidueNotation(r) for r in
lddt_pli["bs_ref_res"]]
lddt_pli["bs_mdl_res"] = [_QualifiedResidueNotation(r) for r in
lddt_pli["bs_mdl_res"]]
lddt_pli["inconsistent_residues"] = ["%s-%s" %(
_QualifiedResidueNotation(x), _QualifiedResidueNotation(y)) for x,y in lddt_pli[
"inconsistent_residues"]]
out["lddt_pli"][model_key] = lddt_pli
out["lddt_pli"] = dict()
for lig_pair in lddtpli_scorer.assignment:
score = float(lddtpli_scorer.score_matrix[lig_pair[0], lig_pair[1]])
coverage = float(lddtpli_scorer.coverage_matrix[lig_pair[0], lig_pair[1]])
aux_data = lddtpli_scorer.aux_matrix[lig_pair[0], lig_pair[1]]
target_key = out["reference_ligands"][lig_pair[0]]
model_key = out["model_ligands"][lig_pair[1]]
out["lddt_pli"][model_key] = {"lddt_pli": score,
"coverage": coverage,
"lddt_pli_n_contacts": aux_data["lddt_pli_n_contacts"],
"model_ligand": model_key,
"reference_ligand": target_key,
"bs_ref_res": [_QualifiedResidueNotation(r) for r in
aux_data["bs_ref_res"]],
"bs_mdl_res": [_QualifiedResidueNotation(r) for r in
aux_data["bs_mdl_res"]]}
if args.unassigned:
for i in lddtpli_scorer.unassigned_model_ligands:
model_key = out["model_ligands"][i]
reason = lddtpli_scorer.guess_model_ligand_unassigned_reason(i)
out["lddt_pli"][model_key] = {"lddt_pli": None,
"unassigned_reason": reason}
if args.rmsd:
out["rmsd"] = {}
for chain, rmsd_results in scorer.rmsd_details.items():
for _, rmsd in rmsd_results.items():
if args.unassigned and rmsd["unassigned"]:
mdl_lig = scorer.model.FindResidue(chain, resnum)
model_key = model_ligands_map[mdl_lig.hash_code]
else:
model_key = model_ligands_map[rmsd["model_ligand"].hash_code]
rmsd["reference_ligand"] = reference_ligands_map[
rmsd.pop("target_ligand").hash_code]
rmsd["model_ligand"] = model_key
transform_data = rmsd["transform"].data
rmsd["transform"] = [transform_data[i:i + 4]
for i in range(0, len(transform_data), 4)]
rmsd["bs_ref_res"] = [_QualifiedResidueNotation(r) for r in
rmsd["bs_ref_res"]]
rmsd["bs_ref_res_mapped"] = [_QualifiedResidueNotation(r) for r in
rmsd["bs_ref_res_mapped"]]
rmsd["bs_mdl_res_mapped"] = [_QualifiedResidueNotation(r) for r in
rmsd["bs_mdl_res_mapped"]]
rmsd["inconsistent_residues"] = ["%s-%s" %(
_QualifiedResidueNotation(x), _QualifiedResidueNotation(y)) for x,y in rmsd[
"inconsistent_residues"]]
out["rmsd"][model_key] = rmsd
out["rmsd"] = dict()
for lig_pair in scrmsd_scorer.assignment:
score = float(scrmsd_scorer.score_matrix[lig_pair[0], lig_pair[1]])
coverage = float(lddtpli_scorer.coverage_matrix[lig_pair[0], lig_pair[1]])
aux_data = scrmsd_scorer.aux_matrix[lig_pair[0], lig_pair[1]]
target_key = out["reference_ligands"][lig_pair[0]]
model_key = out["model_ligands"][lig_pair[1]]
transform_data = aux_data["transform"].data
out["rmsd"][model_key] = {"rmsd": score,
"coverage": coverage,
"lddt_lp": aux_data["lddt_lp"],
"bb_rmsd": aux_data["bb_rmsd"],
"model_ligand": model_key,
"reference_ligand": target_key,
"chain_mapping": aux_data["chain_mapping"],
"bs_ref_res": [_QualifiedResidueNotation(r) for r in
aux_data["bs_ref_res"]],
"bs_ref_res_mapped": [_QualifiedResidueNotation(r) for r in
aux_data["bs_ref_res_mapped"]],
"bs_mdl_res_mapped": [_QualifiedResidueNotation(r) for r in
aux_data["bs_mdl_res_mapped"]],
"inconsistent_residues": [_QualifiedResidueNotation(r) for r in
aux_data["inconsistent_residues"]],
"transform": [transform_data[i:i + 4]
for i in range(0, len(transform_data), 4)]}
if args.unassigned:
for i in scrmsd_scorer.unassigned_model_ligands:
model_key = out["model_ligands"][i]
reason = scrmsd_scorer.guess_model_ligand_unassigned_reason(i)
out["rmsd"][model_key] = {"rmsd": None,
"unassigned_reason": reason}
return out
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment