diff --git a/actions/ost-compare-ligand-structures b/actions/ost-compare-ligand-structures index b51a6c45418b4776b2b132296c742dc148a816c8..3c33dceb50fa6378532e6066067283de524530ac 100644 --- a/actions/ost-compare-ligand-structures +++ b/actions/ost-compare-ligand-structures @@ -176,6 +176,16 @@ def _ParseArgs(): action="store_true", help=("Use a global chain mapping.")) + parser.add_argument( + "-c", + "--chain-mapping", + nargs="+", + dest="chain_mapping", + help=("Custom mapping of chains between the reference and the model. " + "Each separate mapping consist of key:value pairs where key " + "is the chain name in reference and value is the chain name in " + "model. Only has an effect if global-chain-mapping flag is set.")) + parser.add_argument( "-ra", "--rmsd-assignment", @@ -330,6 +340,10 @@ def _QualifiedResidueNotation(r): def _Process(model, model_ligands, reference, reference_ligands, args): + mapping = None + if args.chain_mapping is not None: + mapping = {x.split(':')[0]: x.split(':')[1] for x in args.chain_mapping} + scorer = ligand_scoring.LigandScorer( model=model, target=reference, @@ -344,7 +358,8 @@ def _Process(model, model_ligands, reference, reference_ligands, args): radius=args.radius, lddt_pli_radius=args.lddt_pli_radius, lddt_lp_radius=args.lddt_lp_radius, - n_max_naive=args.n_max_naive + n_max_naive=args.n_max_naive, + custom_mapping=mapping ) out = dict() diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py index ed1e955220f6f8976b49b90f8ea343cfe11cc51e..09f578e7887b6ea12275e5f1c13006746c398530 100644 --- a/modules/mol/alg/pymod/ligand_scoring.py +++ b/modules/mol/alg/pymod/ligand_scoring.py @@ -205,6 +205,12 @@ class LigandScorer: ligand may be scored against different chain mappings). :type global_chain_mapping: :class:`bool` + :param custom_mapping: Provide custom chain mapping between *model* and + *target* that is used as global chain mapping. + Dictionary with target chain names as key and model + chain names as value. Only has an effect if + *global_chain_mapping* is True. + :type custom_mapping: :class:`dict` :param rmsd_assignment: assign ligands based on RMSD only. The default (False) is to use a combination of lDDT-PLI and RMSD for the assignment. @@ -222,7 +228,8 @@ class LigandScorer: chain_mapper=None, substructure_match=False, radius=4.0, lddt_pli_radius=6.0, lddt_lp_radius=10.0, binding_sites_topn=100000, global_chain_mapping=False, - rmsd_assignment=False, n_max_naive=12): + rmsd_assignment=False, n_max_naive=12, + custom_mapping=None): if isinstance(model, mol.EntityView): self.model = mol.CreateEntityFromView(model, False) @@ -287,6 +294,9 @@ class LigandScorer: self._binding_sites = {} self.__model_mapping = None + if custom_mapping is not None: + self._set_custom_mapping(custom_mapping) + @property def chain_mapper(self): """ Chain mapper object for the given :attr:`target`. @@ -1054,6 +1064,98 @@ class LigandScorer: return self._lddt_pli_details + def _set_custom_mapping(self, mapping): + """ sets self.__model_mapping with a full blown MappingResult object + + :param mapping: mapping with trg chains as key and mdl ch as values + :type mapping: :class:`dict` + """ + chain_mapper = self.chain_mapper + chem_mapping, chem_group_alns, mdl = \ + chain_mapper.GetChemMapping(self.model) + + # now that we have a chem mapping, lets do consistency checks + # - check whether chain names are unique and available in structures + # - check whether the mapped chains actually map to the same chem groups + if len(mapping) != len(set(mapping.keys())): + raise RuntimeError(f"Expect unique trg chain names in mapping. Got " + f"{mapping.keys()}") + if len(mapping) != len(set(mapping.values())): + raise RuntimeError(f"Expect unique mdl chain names in mapping. Got " + f"{mapping.values()}") + + trg_chains = set([ch.GetName() for ch in chain_mapper.target.chains]) + mdl_chains = set([ch.GetName() for ch in mdl.chains]) + for k,v in mapping.items(): + if k not in trg_chains: + raise RuntimeError(f"Target chain \"{k}\" is not available " + f"in target processed for chain mapping " + f"({trg_chains})") + if v not in mdl_chains: + raise RuntimeError(f"Model chain \"{v}\" is not available " + f"in model processed for chain mapping " + f"({mdl_chains})") + + for trg_ch, mdl_ch in mapping.items(): + trg_group_idx = None + mdl_group_idx = None + for idx, group in enumerate(chain_mapper.chem_groups): + if trg_ch in group: + trg_group_idx = idx + break + for idx, group in enumerate(chem_mapping): + if mdl_ch in group: + mdl_group_idx = idx + break + if trg_group_idx is None or mdl_group_idx is None: + raise RuntimeError("Could not establish a valid chem grouping " + "of chain names provided in custom mapping.") + + if trg_group_idx != mdl_group_idx: + raise RuntimeError(f"Chem group mismatch in custom mapping: " + f"target chain \"{trg_ch}\" groups with the " + f"following chemically equivalent target " + f"chains: " + f"{chain_mapper.chem_groups[trg_group_idx]} " + f"but model chain \"{mdl_ch}\" maps to the " + f"following target chains: " + f"{chain_mapper.chem_groups[mdl_group_idx]}") + + pairs = set([(trg_ch, mdl_ch) for trg_ch, mdl_ch in mapping.items()]) + ref_mdl_alns = \ + chain_mapping._GetRefMdlAlns(chain_mapper.chem_groups, + chain_mapper.chem_group_alignments, + chem_mapping, + chem_group_alns, + pairs = pairs) + + # translate mapping format + final_mapping = list() + for ref_chains in chain_mapper.chem_groups: + mapped_mdl_chains = list() + for ref_ch in ref_chains: + if ref_ch in mapping: + mapped_mdl_chains.append(mapping[ref_ch]) + else: + mapped_mdl_chains.append(None) + final_mapping.append(mapped_mdl_chains) + + alns = dict() + for ref_group, mdl_group in zip(chain_mapper.chem_groups, + final_mapping): + for ref_ch, mdl_ch in zip(ref_group, mdl_group): + if ref_ch is not None and mdl_ch is not None: + aln = ref_mdl_alns[(ref_ch, mdl_ch)] + trg_view = chain_mapper.target.Select(f"cname={ref_ch}") + mdl_view = mdl.Select(f"cname={mdl_ch}") + aln.AttachView(0, trg_view) + aln.AttachView(1, mdl_view) + alns[(ref_ch, mdl_ch)] = aln + + self.__model_mapping = chain_mapping.MappingResult(chain_mapper.target, mdl, + chain_mapper.chem_groups, + final_mapping, alns) + def _ResidueToGraph(residue, by_atom_index=False): """Return a NetworkX graph representation of the residue.