diff --git a/actions/ost-compare-ligand-structures b/actions/ost-compare-ligand-structures index 42128cb1b849bfc3245c71217d62979d46c5cf43..874ca73d2db92e6a840abb988e55aa5816258469 100644 --- a/actions/ost-compare-ligand-structures +++ b/actions/ost-compare-ligand-structures @@ -61,8 +61,9 @@ import traceback import ost from ost import io -from ost.mol.alg import ligand_scoring - +from ost.mol.alg import ligand_scoring_base +from ost.mol.alg import ligand_scoring_lddtpli +from ost.mol.alg import ligand_scoring_scrmsd def _ParseArgs(): parser = argparse.ArgumentParser(description = __doc__, @@ -205,50 +206,14 @@ def _ParseArgs(): default=0.2, help=("Coverage delta for partial ligand assignment.")) - parser.add_argument( - "-gcm", - "--global-chain-mapping", - dest="global_chain_mapping", - default=False, - action="store_true", - help=("Use a global chain mapping.")) - - parser.add_argument( - "-c", - "--chain-mapping", - nargs="+", - dest="chain_mapping", - help=("Custom mapping of chains between the reference and the model. " - "Each separate mapping consist of key:value pairs where key " - "is the chain name in reference and value is the chain name in " - "model. Only has an effect if the --global-chain-mapping flag " - "is set.")) - parser.add_argument( "-fbs", "--full-bs-search", dest="full_bs_search", default=False, action="store_true", - help=("Enumerate all potential binding sites in the model.")) - - parser.add_argument( - "-ra", - "--rmsd-assignment", - dest="rmsd_assignment", - default=True, - action="store_true", - help=("Use RMSD only for ligand assignment " - "(default since OpenStructure 2.8).")) - - parser.add_argument( - "-sa", - "--separate-assignment", - dest="rmsd_assignment", - default=True, - action="store_false", - help=("Use separate ligand assignments for RMSD and lDDT-PLI " - "(opposite of --rmsd-assignment).")) + help=("Enumerate all potential binding sites in the model when " + "searching rigid superposition for RMSD computation")) parser.add_argument( "-u", @@ -278,9 +243,9 @@ def _ParseArgs(): "--radius", dest="radius", default=4.0, - help=("Inclusion radius for the binding site. Any residue with atoms " - "within this distance of the ligand will be included in the " - "binding site.")) + help=("Inclusion radius to extract reference binding site that is used " + "for RMSD computation. Any residue with atoms within this " + "distance of the ligand will be included in the binding site.")) parser.add_argument( "--lddt-pli-radius", @@ -291,7 +256,7 @@ def _ParseArgs(): parser.add_argument( "--lddt-lp-radius", dest="lddt_lp_radius", - default=10.0, + default=15.0, help=("lDDT inclusion radius for lDDT-LP.")) parser.add_argument( @@ -302,16 +267,6 @@ def _ParseArgs(): default=3, help="Set verbosity level. Defaults to 3 (INFO).") - parser.add_argument( - "--n-max-naive", - dest="n_max_naive", - required=False, - default=12, - type=int, - help=("If number of chains in model and reference are below or equal " - "that number, the global chain mapping will naively enumerate " - "all possible mappings. A heuristic is used otherwise.")) - return parser.parse_args() @@ -448,36 +403,67 @@ def _QualifiedResidueNotation(r): ins_code=resnum.ins_code.strip("\u0000"), ) +def _SetupLDDTPLIScorer(model, model_ligands, reference, reference_ligands, args): + return ligand_scoring_lddtpli.LDDTPLIScorer(model, reference, + model_ligands = model_ligands, + target_ligands = reference_ligands, + resnum_alignments = args.residue_number_alignment, + check_resnames = args.enforce_consistency, + rename_ligand_chain = True, + substructure_match = args.substructure_match, + coverage_delta = args.coverage_delta, + lddt_pli_radius = args.lddt_pli_radius) + +def _SetupSCRMSDScorer(model, model_ligands, reference, reference_ligands, args): + return ligand_scoring_scrmsd.SCRMSDScorer(model, reference, + model_ligands = model_ligands, + target_ligands = reference_ligands, + resnum_alignments = args.residue_number_alignment, + rename_ligand_chain = True, + substructure_match = args.substructure_match, + coverage_delta = args.coverage_delta, + bs_radius = args.radius, + lddt_lp_radius = args.lddt_lp_radius) def _Process(model, model_ligands, reference, reference_ligands, args): - mapping = None - if args.chain_mapping is not None: - mapping = {x.split(':')[0]: x.split(':')[1] for x in args.chain_mapping} - - scorer = ligand_scoring.LigandScorer( - model=model, - target=reference, - model_ligands=model_ligands, - target_ligands=reference_ligands, - resnum_alignments=args.residue_number_alignment, - check_resnames=args.enforce_consistency, - rename_ligand_chain=True, - substructure_match=args.substructure_match, - coverage_delta=args.coverage_delta, - global_chain_mapping=args.global_chain_mapping, - full_bs_search=args.full_bs_search, - rmsd_assignment=args.rmsd_assignment, - unassigned=args.unassigned, - radius=args.radius, - lddt_pli_radius=args.lddt_pli_radius, - lddt_lp_radius=args.lddt_lp_radius, - n_max_naive=args.n_max_naive, - custom_mapping=mapping - ) - out = dict() + ########################## + # Setup required scorers # + ########################## + + lddtpli_scorer = None + scrmsd_scorer = None + + if args.lddt_pli: + lddtpli_scorer = _SetupLDDTPLIScorer(model, model_ligands, + reference, reference_ligands, + args) + + if args.rmsd: + scrmsd_scorer = _SetupSCRMSDScorer(model, model_ligands, + reference, reference_ligands, + args) + + # basic info on ligands only requires baseclass functionality + # doesn't matter which scorer we use + scorer = None + if lddtpli_scorer is not None: + scorer = lddtpli_scorer + elif scrmsd_scorer is not None: + scorer = scrmsd_scorer + else: + ost.LogWarning("No score selected, output will be empty.") + # just create SCRMSD scorer to fill basic ligand info + scorer = _SetupSCRMSDScorer(model, model_ligands, + reference, reference_ligands, + args) + + #################################### + # Extract / Map ligand information # + #################################### + if model_ligands is not None: # Replace model ligand by path if len(model_ligands) == len(scorer.model_ligands): @@ -486,7 +472,7 @@ def _Process(model, model_ligands, reference, reference_ligands, args): elif len(model_ligands) < len(scorer.model_ligands): # Multi-ligand SDF files were given # Map ligand => path:idx - out["model_ligands"] = [] + out["model_ligands"] = list() for ligand, filename in zip(model_ligands, args.model_ligands): assert isinstance(ligand, ost.mol.EntityHandle) for i, residue in enumerate(ligand.residues): @@ -500,9 +486,6 @@ def _Process(model, model_ligands, reference, reference_ligands, args): # Map ligand => qualified residue out["model_ligands"] = [_QualifiedResidueNotation(l) for l in scorer.model_ligands] - model_ligands_map = {k.hash_code: v for k, v in zip( - scorer.model_ligands, out["model_ligands"])} - if reference_ligands is not None: # Replace reference ligand by path if len(reference_ligands) == len(scorer.target_ligands): @@ -511,7 +494,7 @@ def _Process(model, model_ligands, reference, reference_ligands, args): elif len(reference_ligands) < len(scorer.target_ligands): # Multi-ligand SDF files were given # Map ligand => path:idx - out["reference_ligands"] = [] + out["reference_ligands"] = list() for ligand, filename in zip(reference_ligands, args.reference_ligands): assert isinstance(ligand, ost.mol.EntityHandle) for i, residue in enumerate(ligand.residues): @@ -521,80 +504,72 @@ def _Process(model, model_ligands, reference, reference_ligands, args): raise RuntimeError("Fewer ligands in the reference scorer " "(%d) than given (%d)" % ( len(scorer.target_ligands), len(reference_ligands))) - else: # Map ligand => qualified residue out["reference_ligands"] = [_QualifiedResidueNotation(l) for l in scorer.target_ligands] - reference_ligands_map = {k.hash_code: v for k, v in zip( - scorer.target_ligands, out["reference_ligands"])} - - - if not (args.lddt_pli or args.rmsd): - ost.LogWarning("No score selected, output will be empty.") - else: - out["unassigned_model_ligands"] = {} - for chain, unassigned_residues in scorer.unassigned_model_ligands.items(): - for resnum, unassigned in unassigned_residues.items(): - mdl_lig = scorer.model.FindResidue(chain, resnum) - out["unassigned_model_ligands"][model_ligands_map[ - mdl_lig.hash_code]] = unassigned - out["unassigned_reference_ligands"] = {} - for chain, unassigned_residues in scorer.unassigned_target_ligands.items(): - for resnum, unassigned in unassigned_residues.items(): - trg_lig = scorer.target.FindResidue(chain, resnum) - out["unassigned_reference_ligands"][reference_ligands_map[ - trg_lig.hash_code]] = unassigned - out["unassigned_model_ligand_descriptions"] = scorer.unassigned_model_ligand_descriptions - out["unassigned_reference_ligand_descriptions"] = scorer.unassigned_target_ligand_descriptions + ################## + # Compute scores # + ################## if args.lddt_pli: - out["lddt_pli"] = {} - for chain, lddt_pli_results in scorer.lddt_pli_details.items(): - for resnum, lddt_pli in lddt_pli_results.items(): - if args.unassigned and lddt_pli["unassigned"]: - mdl_lig = scorer.model.FindResidue(chain, resnum) - model_key = model_ligands_map[mdl_lig.hash_code] - else: - model_key = model_ligands_map[lddt_pli["model_ligand"].hash_code] - lddt_pli["reference_ligand"] = reference_ligands_map[ - lddt_pli.pop("target_ligand").hash_code] - lddt_pli["model_ligand"] = model_key - lddt_pli["bs_ref_res"] = [_QualifiedResidueNotation(r) for r in - lddt_pli["bs_ref_res"]] - lddt_pli["bs_mdl_res"] = [_QualifiedResidueNotation(r) for r in - lddt_pli["bs_mdl_res"]] - lddt_pli["inconsistent_residues"] = ["%s-%s" %( - _QualifiedResidueNotation(x), _QualifiedResidueNotation(y)) for x,y in lddt_pli[ - "inconsistent_residues"]] - out["lddt_pli"][model_key] = lddt_pli + out["lddt_pli"] = dict() + for lig_pair in lddtpli_scorer.assignment: + score = float(lddtpli_scorer.score_matrix[lig_pair[0], lig_pair[1]]) + coverage = float(lddtpli_scorer.coverage_matrix[lig_pair[0], lig_pair[1]]) + aux_data = lddtpli_scorer.aux_matrix[lig_pair[0], lig_pair[1]] + target_key = out["reference_ligands"][lig_pair[0]] + model_key = out["model_ligands"][lig_pair[1]] + out["lddt_pli"][model_key] = {"lddt_pli": score, + "coverage": coverage, + "lddt_pli_n_contacts": aux_data["lddt_pli_n_contacts"], + "model_ligand": model_key, + "reference_ligand": target_key, + "bs_ref_res": [_QualifiedResidueNotation(r) for r in + aux_data["bs_ref_res"]], + "bs_mdl_res": [_QualifiedResidueNotation(r) for r in + aux_data["bs_mdl_res"]]} + if args.unassigned: + for i in lddtpli_scorer.unassigned_model_ligands: + model_key = out["model_ligands"][i] + reason = lddtpli_scorer.guess_model_ligand_unassigned_reason(i) + out["lddt_pli"][model_key] = {"lddt_pli": None, + "unassigned_reason": reason} + if args.rmsd: - out["rmsd"] = {} - for chain, rmsd_results in scorer.rmsd_details.items(): - for _, rmsd in rmsd_results.items(): - if args.unassigned and rmsd["unassigned"]: - mdl_lig = scorer.model.FindResidue(chain, resnum) - model_key = model_ligands_map[mdl_lig.hash_code] - else: - model_key = model_ligands_map[rmsd["model_ligand"].hash_code] - rmsd["reference_ligand"] = reference_ligands_map[ - rmsd.pop("target_ligand").hash_code] - rmsd["model_ligand"] = model_key - transform_data = rmsd["transform"].data - rmsd["transform"] = [transform_data[i:i + 4] - for i in range(0, len(transform_data), 4)] - rmsd["bs_ref_res"] = [_QualifiedResidueNotation(r) for r in - rmsd["bs_ref_res"]] - rmsd["bs_ref_res_mapped"] = [_QualifiedResidueNotation(r) for r in - rmsd["bs_ref_res_mapped"]] - rmsd["bs_mdl_res_mapped"] = [_QualifiedResidueNotation(r) for r in - rmsd["bs_mdl_res_mapped"]] - rmsd["inconsistent_residues"] = ["%s-%s" %( - _QualifiedResidueNotation(x), _QualifiedResidueNotation(y)) for x,y in rmsd[ - "inconsistent_residues"]] - out["rmsd"][model_key] = rmsd + out["rmsd"] = dict() + for lig_pair in scrmsd_scorer.assignment: + score = float(scrmsd_scorer.score_matrix[lig_pair[0], lig_pair[1]]) + coverage = float(lddtpli_scorer.coverage_matrix[lig_pair[0], lig_pair[1]]) + aux_data = scrmsd_scorer.aux_matrix[lig_pair[0], lig_pair[1]] + target_key = out["reference_ligands"][lig_pair[0]] + model_key = out["model_ligands"][lig_pair[1]] + transform_data = aux_data["transform"].data + out["rmsd"][model_key] = {"rmsd": score, + "coverage": coverage, + "lddt_lp": aux_data["lddt_lp"], + "bb_rmsd": aux_data["bb_rmsd"], + "model_ligand": model_key, + "reference_ligand": target_key, + "chain_mapping": aux_data["chain_mapping"], + "bs_ref_res": [_QualifiedResidueNotation(r) for r in + aux_data["bs_ref_res"]], + "bs_ref_res_mapped": [_QualifiedResidueNotation(r) for r in + aux_data["bs_ref_res_mapped"]], + "bs_mdl_res_mapped": [_QualifiedResidueNotation(r) for r in + aux_data["bs_mdl_res_mapped"]], + "inconsistent_residues": [_QualifiedResidueNotation(r) for r in + aux_data["inconsistent_residues"]], + "transform": [transform_data[i:i + 4] + for i in range(0, len(transform_data), 4)]} + if args.unassigned: + for i in scrmsd_scorer.unassigned_model_ligands: + model_key = out["model_ligands"][i] + reason = scrmsd_scorer.guess_model_ligand_unassigned_reason(i) + out["rmsd"][model_key] = {"rmsd": None, + "unassigned_reason": reason} return out