diff --git a/modules/bindings/pymod/cadscore.py b/modules/bindings/pymod/cadscore.py index 7ecf79874e65ae912f5b4e189f7a6e61e8edade2..39676da905fd0da0ebeaccf843a4f92dfba4082c 100644 --- a/modules/bindings/pymod/cadscore.py +++ b/modules/bindings/pymod/cadscore.py @@ -31,7 +31,19 @@ Authors: Valerio Mariani, Alessandro Barbato import subprocess, os, tempfile, platform, re from ost import settings, io, mol -def _SetupFiles(model,reference): +def _SetupFiles(model, reference, chain_mapping): + + if chain_mapping is not None: + model_handle = model + if isinstance(model_handle, mol.EntityView): + model_handle = mol.CreateEntityFromView(model_handle, False) + mapped_model = mol.CreateEntity() + ed = mapped_model.EditXCS() + for k,v in chain_mapping.items(): + if v is not None: + ed.InsertChain(v, model_handle.FindChain(k), deep=True) + model = mapped_model + # create temporary directory tmp_dir_name=tempfile.mkdtemp() dia = 'PDB' @@ -60,6 +72,7 @@ def _SetupFiles(model,reference): dia = 'CHARMM' break; io.SavePDB(reference, os.path.join(tmp_dir_name, 'reference.pdb'),dialect=dia) + return tmp_dir_name def _CleanupFiles(dir_name): @@ -241,7 +254,6 @@ def _RunCAD(tmp_dir, mode, cad_bin_path, old_regime): raise RuntimeError("Invalid CAD mode! Allowed are: " "[\"classic\", \"voronota\"]") - return CADResult(globalAA,localAA) def _HasInsertionCodes(model, reference): @@ -253,18 +265,27 @@ def _HasInsertionCodes(model, reference): return True return False -def _MapLabels(model, cad_results, label): - for k,v in cad_results.localAA.items(): - r = model.FindResidue(k[0], k[1]) - if not r.IsValid(): - raise RuntimeError("Failed to map cadscore on residues: " + - "CAD score estimated for residue in chain \"" + - k[0] + "\" with ResNum " + str(k[1]) + ". Residue " + - "could not be found in model.") - r.SetFloatProp(label, v) +def _MapLabels(model, cad_results, label, chain_mapping): + + if chain_mapping is None: + for k,v in cad_results.localAA.items(): + r = model.FindResidue(k[0], k[1]) + if r.IsValid(): + r.SetFloatProp(label, v) + else: + # chain_mapping has mdl chains as key and target chains as values + # the raw CAD results refer to the target chains => reverse mapping + rev_mapping = {v:k for k,v in chain_mapping.items()} + for k,v in cad_results.localAA.items(): + cname = k[0] + rnum = k[1] + if cname in rev_mapping: + r = model.FindResidue(rev_mapping[cname], rnum) + if r.IsValid(): + r.SetFloatProp(label, v) def CADScore(model, reference, mode = "voronota", label = "localcad", - old_regime = False, cad_bin_path = None): + old_regime = False, cad_bin_path = None, chain_mapping=None): """ Calculates global and local atom-atom (AA) CAD Scores. @@ -303,6 +324,13 @@ def CADScore(model, reference, mode = "voronota", label = "localcad", or ["voronota-cadscore"] for "voronota" *mode*). If not set, the env path is searched. :type cad_bin_path: :class:`str` + :param chain_mapping: Provide custom chain mapping in case of oligomers + (only supported for "voronota" *mode*). Provided as + :class:`dict` with model chain name as key and target + chain name as value. If set, scoring happens on a + substructure of model that is stripped to chains with + valid mapping. + :type chain_mapping: :class:`dict` :returns: The result of the CAD score calculation :rtype: :class:`CADResult` @@ -313,8 +341,32 @@ def CADScore(model, reference, mode = "voronota", label = "localcad", if mode == "classic" and _HasInsertionCodes(model, reference): raise RuntimeError("The classic CAD score implementation does not support " "insertion codes in residues") - tmp_dir_name=_SetupFiles(model, reference) + + if chain_mapping is not None: + if model == "classic": + raise RuntimeError("The classic CAD score implementation does not " + "support custom chain mappings") + + # do consistency checks of custom chain mapping + mdl_cnames = [ch.GetName() for ch in model.chains] + ref_cnames = [ch.GetName() for ch in reference.chains] + + # check that each model chain name in the mapping is actually there + for cname in chain_mapping.keys(): + if cname not in mdl_cnames: + raise RuntimeError(f"Model chain name \"{cname}\" provided in " + f"custom chain mapping is not present in provided " + f"model structure.") + + # check that each target chain name in the mapping is actually there + for cname in chain_mapping.values(): + if cname not in ref_cnames: + raise RuntimeError(f"Reference chain name \"{cname}\" provided in " + f"custom chain mapping is not present in provided " + f"reference structure.") + + tmp_dir_name=_SetupFiles(model, reference, chain_mapping) result=_RunCAD(tmp_dir_name, mode, cad_bin_path, old_regime) _CleanupFiles(tmp_dir_name) - _MapLabels(model, result, label) + _MapLabels(model, result, label, chain_mapping) return result