diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py index 617afbb7c80f1b3946b6d36196e1100a82a327e4..0cbefb658089b9ada83c485fa2f2c227b1d25759 100644 --- a/modules/mol/alg/pymod/lddt.py +++ b/modules/mol/alg/pymod/lddt.py @@ -571,7 +571,7 @@ class lDDTScorer: res_indices, ref_res_indices, symmetries = \ self._ProcessModel(model, chain_mapping, residue_mapping = residue_mapping, - thresholds = thresholds, + nirvana_dist = self.inclusion_radius + max(thresholds), check_resnames = check_resnames) if no_interchain and no_intrachain: @@ -715,6 +715,178 @@ class lDDTScorer: else: return lDDT, per_res_lDDT + def DRMSD(self, model, dist_cap = 5, + chain_mapping=None, no_interchain=False, + no_intrachain=False, residue_mapping=None, + check_resnames=True, add_mdl_contacts=False, + interaction_data=None): + """ EXPERIMENTAL DRMSD of *model* - globally and per-residue + + Very similar to LDDT as we operate on distance differences for all + interatomic distances within the same inclusion radius as in LDDT. + DRMSD is the distance rmsd, i.e. the RMSD of distance differences. + Distance differences are capped at *dist_cap* which is also the default + value for missing distances. + + :param model: Model to be scored - models are preferably scored upon + performing stereo-chemistry checks in order to punish for + non-sensical irregularities. This must be done separately + as a pre-processing step. Target contacts that are not + covered by *model* are considered not conserved, thus + increasing DRMSD score. This also includes missing model + chains or model chains for which no mapping is provided in + *chain_mapping*. + :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView` + :param dist_cap: Cap for distance differences. + :type dist_cap: :class:`float` + :param chain_mapping: Mapping of model chains (key) onto target chains + (value). This is required if target or model have + more than one chain. + :type chain_mapping: :class:`dict` with :class:`str` as keys/values + :param no_interchain: Whether to exclude interchain contacts + :type no_interchain: :class:`bool` + :param no_intrachain: Whether to exclude intrachain contacts (i.e. only + consider interface related contacts) + :type no_intrachain: :class:`bool` + :param residue_mapping: By default, residue mapping is based on residue + numbers. That means, a model chain and the + respective target chain map to the same + underlying reference sequence (SEQRES). + Alternatively, you can specify one or + several alignment(s) between model and target + chains by providing a dictionary. key: Name + of chain in model (respective target chain is + extracted from *chain_mapping*), + value: Alignment with first sequence + corresponding to target chain and second + sequence to model chain. There is NO reference + sequence involved, so the two sequences MUST + exactly match the actual residues observed in + the respective target/model chains (ATOMSEQ). + :type residue_mapping: :class:`dict` with key: :class:`str`, + value: :class:`ost.seq.AlignmentHandle` + :param check_resnames: On by default. Enforces residue name matches + between mapped model and target residues. + :type check_resnames: :class:`bool` + :param add_mdl_contacts: Adds model contacts - Only using contacts that + are within a certain distance threshold in the + target does not penalize for added model + contacts. If set to True, this flag will also + consider target contacts that are within the + specified distance threshold in the model but + not necessarily in the target. No contact will + be added if the respective atom pair is not + resolved in the target. + :type add_mdl_contacts: :class:`bool` + :param interaction_data: Pro param - don't use + :type interaction_data: :class:`tuple` + + :returns: global and per-residue DRMSD scores as a tuple - + first element is global DRMSD score (None if *target* has no + contacts) and second element a list of per-residue scores with + length len(*model*.residues). None is assigned to residues that + are not covered by target. If a residue is covered but has no + contacts in *target*, None is assigned. + """ + if chain_mapping is None: + if len(self.chain_names) > 1 or len(model.chains) > 1: + raise NotImplementedError("Must provide chain mapping if " + "target or model have > 1 chains.") + chain_mapping = {model.chains[0].GetName(): self.chain_names[0]} + else: + # check whether chains specified in mapping exist + for model_chain, target_chain in chain_mapping.items(): + if target_chain not in self.chain_names: + raise RuntimeError(f"Target chain specified in " + f"chain_mapping ({target_chain}) does " + f"not exist. Target has chains: " + f"{self.chain_names}") + ch = model.FindChain(model_chain) + if not ch.IsValid(): + raise RuntimeError(f"Model chain specified in " + f"chain_mapping ({model_chain}) does " + f"not exist. Model has chains: " + f"{[c.GetName() for c in model.chains]}") + + # data objects defining model data - see _ProcessModel for rough + # description + pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \ + res_indices, ref_res_indices, symmetries = \ + self._ProcessModel(model, chain_mapping, + residue_mapping = residue_mapping, + nirvana_dist = self.inclusion_radius + dist_cap, + check_resnames = check_resnames) + + if no_interchain and no_intrachain: + raise RuntimeError("no_interchain and no_intrachain flags are " + "mutually exclusive") + + sym_ref_indices = None + sym_ref_distances = None + ref_indices = None + ref_distances = None + + if interaction_data is None: + if no_interchain: + sym_ref_indices = self.sym_ref_indices_sc + sym_ref_distances = self.sym_ref_distances_sc + ref_indices = self.ref_indices_sc + ref_distances = self.ref_distances_sc + elif no_intrachain: + sym_ref_indices = self.sym_ref_indices_ic + sym_ref_distances = self.sym_ref_distances_ic + ref_indices = self.ref_indices_ic + ref_distances = self.ref_distances_ic + else: + sym_ref_indices = self.sym_ref_indices + sym_ref_distances = self.sym_ref_distances + ref_indices = self.ref_indices + ref_distances = self.ref_distances + + if add_mdl_contacts: + ref_indices, ref_distances = \ + self._AddMdlContacts(model, res_atom_indices, res_atom_hashes, + ref_indices, ref_distances, + no_interchain, no_intrachain) + # recompute symmetry related indices/distances + sym_ref_indices, sym_ref_distances = \ + lDDTScorer._NonSymDistances(self.n_atoms, self.symmetric_atoms, + ref_indices, ref_distances) + else: + sym_ref_indices, sym_ref_distances, ref_indices, ref_distances = \ + interaction_data + + self._ResolveSymmetriesSSD(pos, dist_cap, symmetries, sym_ref_indices, + sym_ref_distances) + + atom_indices = list(itertools.chain.from_iterable(res_atom_indices)) + + per_atom_exp = np.asarray([self._GetNExp(i, ref_indices) + for i in atom_indices], dtype=np.int32) + per_res_exp = np.asarray([self._GetNExp(res_ref_atom_indices[idx], + ref_indices) for idx in range(len(res_indices))], dtype=np.int32) + per_atom_ssd = self._EvalAtomsSSD(pos, atom_indices, dist_cap, + ref_indices, ref_distances) + + # do per residue scores + start_idx = 0 + per_res_drmsd = [None] * model.GetResidueCount() + for r_idx in range(len(res_atom_indices)): + end_idx = start_idx + len(res_atom_indices[r_idx]) + n_exp = per_res_exp[r_idx] + if n_exp > 0: + ssd = np.sum(per_atom_ssd[start_idx:end_idx]) + per_res_drmsd[res_indices[r_idx]] = np.sqrt(ssd/n_exp) + start_idx = end_idx + + # do full model score + drmsd = None + n_exp = np.sum(per_atom_exp) + if n_exp > 0: + drmsd = np.sqrt(np.sum(per_atom_ssd)/n_exp) + + return drmsd, per_res_drmsd + def GetNChainContacts(self, target_chain, no_interchain=False): """Returns number of contacts expected for a certain chain in *target* @@ -739,7 +911,7 @@ class lDDTScorer: return self._GetNExp(list(range(s, e)), self.ref_indices) def _ProcessModel(self, model, chain_mapping, residue_mapping = None, - thresholds = [0.5, 1.0, 2.0, 4.0], + nirvana_dist = 100, check_resnames = True): """ Helper that generates data structures from model """ @@ -748,7 +920,7 @@ class lDDTScorer: # set, it should be far away from any position in model. max_pos = model.bounds.GetMax() max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2])) - max_coordinate += 42 * max(thresholds) + max_coordinate += 42 * nirvana_dist pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate # for each scored residue in model a list of indices describing the @@ -1398,3 +1570,73 @@ class lDDTScorer: # LDDT behaviour for pair in sym: pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]] + + def _EvalAtomSSD(self, pos, atom_idx, dist_cap, ref_indices, ref_distances): + """ Computes summed squared distances + + distances are capped at dist_cap + """ + a_p = pos[atom_idx, :] + tmp = pos.take(ref_indices[atom_idx], axis=0) + np.subtract(tmp, a_p[None, :], out=tmp) + np.square(tmp, out=tmp) + tmp = tmp.sum(axis=1) + np.sqrt(tmp, out=tmp) # distances against all relevant atoms + np.subtract(ref_distances[atom_idx], tmp, out=tmp) # distance difference + np.square(tmp, out=tmp) # squared distance difference + squared_dist_cap = dist_cap*dist_cap + tmp[tmp > squared_dist_cap] = squared_dist_cap + return tmp.sum() + + def _EvalAtomsSSD( + self, pos, atom_indices, dist_cap, ref_indices, ref_distances + ): + """Calls _EvalAtomSSD for several atoms + """ + return np.asarray([self._EvalAtomSSD(pos, a, dist_cap, ref_indices, + ref_distances) for a in atom_indices], + dtype=np.float32) + + def _ResolveSymmetriesSSD(self, pos, dist_cap, symmetries, sym_ref_indices, + sym_ref_distances): + """Swaps symmetric positions in-place in order to maximize summed + squared distances towards non-symmetric atoms. + """ + for sym in symmetries: + + atom_indices = list() + for sym_tuple in sym: + atom_indices += [sym_tuple[0], sym_tuple[1]] + tot = self._GetNExp(atom_indices, sym_ref_indices) + + if tot == 0: + continue # nothing to do + + # score as is + sym_one_ssd = self._EvalAtomsSSD( + pos, + atom_indices, + dist_cap, + sym_ref_indices, + sym_ref_distances, + ) + + # switch positions and score again + for pair in sym: + pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]] + + sym_two_ssd = self._EvalAtomsSSD( + pos, + atom_indices, + dist_cap, + sym_ref_indices, + sym_ref_distances, + ) + + sym_one_score = np.sum(sym_one_ssd) + sym_two_score = np.sum(sym_two_ssd) + + if sym_one_score < sym_two_score: + # switch back, initial positions were better + for pair in sym: + pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]] diff --git a/modules/mol/alg/pymod/ligand_scoring_lddtpli.py b/modules/mol/alg/pymod/ligand_scoring_lddtpli.py index 9338b484f5f6adad14e993870f5adc4309af9313..ad73ed91b80cafe74f988a1c44b67770378aaedb 100644 --- a/modules/mol/alg/pymod/ligand_scoring_lddtpli.py +++ b/modules/mol/alg/pymod/ligand_scoring_lddtpli.py @@ -436,7 +436,7 @@ class LDDTPLIScorer(ligand_scoring_base.LigandScorer): pos, _, _, _, _, _, lddt_symmetries = \ scorer._ProcessModel(mdl_bs, lddt_chain_mapping, residue_mapping = lddt_alns, - thresholds = self.lddt_pli_thresholds, + nirvana_dist = self.lddt_pli_radius + max(self.lddt_pli_thresholds), check_resnames = False) # estimate a penalty for unsatisfied model contacts from chains @@ -668,7 +668,7 @@ class LDDTPLIScorer(ligand_scoring_base.LigandScorer): pos, _, _, _, _, _, lddt_symmetries = \ scorer._ProcessModel(mdl_bs, lddt_chain_mapping, residue_mapping = lddt_alns, - thresholds = self.lddt_pli_thresholds, + nirvana_dist = self.lddt_pli_radius + max(self.lddt_pli_thresholds), check_resnames = False) for (trg_sym, mdl_sym) in symmetries: diff --git a/modules/mol/alg/tests/test_lddt.py b/modules/mol/alg/tests/test_lddt.py index 755717f597606579b816e15b8afac3b6e120860c..11bce0503e3d2b2b0d829428f45852429cd32249 100644 --- a/modules/mol/alg/tests/test_lddt.py +++ b/modules/mol/alg/tests/test_lddt.py @@ -236,7 +236,16 @@ class TestlDDT(unittest.TestCase): # in lDDT computation self.assertEqual(lDDT, 0.6171511842396518) + def test_drmsd(self): + model = _LoadFile("7SGN_C_model.pdb") + target = _LoadFile("7SGN_C_target.pdb") + lddt_scorer = lDDTScorer(target) + drmsd, per_res_drmsd = lddt_scorer.DRMSD(model) + + # this value is just blindly copied in without checking whether it makes + # any sense... it's sole purpose is to trigger DRMSD computation + self.assertEqual(drmsd, 1.895447711911706) class TestlDDTBS(unittest.TestCase):