diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py index 1e1230f897281c125ea08fd60410f1b20c2367a7..40e98d90c6a287b0baf7eb687dd3a8c86c723ff2 100644 --- a/modules/mol/alg/pymod/lddt.py +++ b/modules/mol/alg/pymod/lddt.py @@ -254,6 +254,14 @@ class lDDTScorer: self._sym_ref_distances_sc = None self._n_distances_sc = None + # exactly the same as above but without intrachain contacts + # => inter-chain (ic) + self._ref_indices_ic = None + self._ref_distances_ic = None + self._sym_ref_indices_ic = None + self._sym_ref_distances_ic = None + self._n_distances_ic = None + # input parameter checking self._ProcessSequenceSeparation() @@ -317,11 +325,42 @@ class lDDTScorer: self._n_distances_sc = sum([len(x) for x in self.ref_indices_sc]) return self._n_distances_sc + @property + def ref_indices_ic(self): + if self._ref_indices_ic is None: + self._SetupDistancesIC() + return self._ref_indices_ic + + @property + def ref_distances_ic(self): + if self._ref_distances_ic is None: + self._SetupDistancesIC() + return self._ref_distances_ic + + @property + def sym_ref_indices_ic(self): + if self._sym_ref_indices_ic is None: + self._SetupDistancesIC() + return self._sym_ref_indices_ic + + @property + def sym_ref_distances_ic(self): + if self._sym_ref_distances_ic is None: + self._SetupDistancesIC() + return self._sym_ref_distances_ic + + @property + def n_distances_ic(self): + if self._n_distances_ic is None: + self._n_distances_ic = sum([len(x) for x in self.ref_indices_ic]) + return self._n_distances_ic + def lDDT(self, model, thresholds = [0.5, 1.0, 2.0, 4.0], local_lddt_prop=None, local_contact_prop=None, chain_mapping=None, no_interchain=False, - penalize_extra_chains=False, residue_mapping=None, - return_dist_test=False, check_resnames=True): + no_intrachain=False, penalize_extra_chains=False, + residue_mapping=None, return_dist_test=False, + check_resnames=True): """Computes lDDT of *model* - globally and per-residue :param model: Model to be scored - 
models are preferably scored upon @@ -350,6 +389,9 @@ class lDDTScorer: :type chain_mapping: :class:`dict` with :class:`str` as keys/values :param no_interchain: Whether to exclude interchain contacts :type no_interchain: :class:`bool` + :param no_intrachain: Whether to exclude intrachain contacts (i.e. only + consider interface related contacts) + :type no_intrachain: :class:`bool` :param penalize_extra_chains: Whether to include a fixed penalty for additional chains in the model that are not mapped to the target. ONLY AFFECTS @@ -492,12 +534,22 @@ class lDDTScorer: if len(sym_indices) > 0: symmetries.append(sym_indices) + if no_interchain and no_intrachain: + raise RuntimeError("no_interchain and no_intrachain flags are " + "mutually exclusive") + if no_interchain: sym_ref_indices = self.sym_ref_indices_sc sym_ref_distances = self.sym_ref_distances_sc ref_indices = self.ref_indices_sc ref_distances = self.ref_distances_sc n_distances = self.n_distances_sc + elif no_intrachain: + sym_ref_indices = self.sym_ref_indices_ic + sym_ref_distances = self.sym_ref_distances_ic + ref_indices = self.ref_indices_ic + ref_distances = self.ref_distances_ic + n_distances = self.n_distances_ic else: sym_ref_indices = self.sym_ref_indices sym_ref_distances = self.sym_ref_distances @@ -864,6 +916,38 @@ class lDDTScorer: self._sym_ref_indices_sc, self._sym_ref_distances_sc) + def _SetupDistancesIC(self): + """Select subset of contacts only covering inter-chain contacts + """ + # init + self._ref_indices_ic = [[] for idx in range(self.n_atoms)] + self._ref_distances_ic = [[] for idx in range(self.n_atoms)] + self._sym_ref_indices_ic = [[] for idx in range(self.n_atoms)] + self._sym_ref_distances_ic = [[] for idx in range(self.n_atoms)] + + # start from overall contacts + ref_indices = self.ref_indices + ref_distances = self.ref_distances + sym_ref_indices = self.sym_ref_indices + sym_ref_distances = self.sym_ref_distances + + n_chains = len(self.chain_start_indices) + for ch_idx, ch 
in enumerate(self.target.chains): + chain_s = self.chain_start_indices[ch_idx] + chain_e = self.n_atoms + if ch_idx + 1 < n_chains: + chain_e = self.chain_start_indices[ch_idx+1] + for i in range(chain_s, chain_e): + if len(ref_indices[i]) > 0: + inter_idx = np.where(np.logical_or(ref_indices[i]<chain_s, + ref_indices[i]>=chain_e))[0] + self._ref_indices_ic[i] = ref_indices[i][inter_idx] + self._ref_distances_ic[i] = ref_distances[i][inter_idx] + + self._NonSymDistances(self._ref_indices_ic, self._ref_distances_ic, + self._sym_ref_indices_ic, + self._sym_ref_distances_ic) + def _CloseStuff(self, pos, inclusion_radius, indices, mask_start, mask_end): """returns close stuff for positions specified by indices """ diff --git a/modules/mol/alg/tests/test_lddt.py b/modules/mol/alg/tests/test_lddt.py index 9f2a079346c48f92dcd94c7cbb41f09d999317f4..6f3e3be47386b1a50cc42d4330fdb052869daa02 100644 --- a/modules/mol/alg/tests/test_lddt.py +++ b/modules/mol/alg/tests/test_lddt.py @@ -184,6 +184,40 @@ class TestlDDT(unittest.TestCase): scorer.lDDT(model, check_resnames=False) + def test_intra_interchain(self): + ent_full = _LoadFile("4br6.1.pdb") + model = ent_full.Select('peptide=true and cname=A,B') + target = ent_full.Select('peptide=true and cname=A,B') + chain_mapping = {"A": "A", "B": "B"} + + lddt_scorer = lDDTScorer(target) + + # do lDDT only on interchain contacts (ic) + lDDT_ic, per_res_lDDT_ic, lDDT_tot_ic, lDDT_cons_ic, \ + res_indices_ic, per_res_exp_ic, per_res_conserved_ic =\ + lddt_scorer.lDDT(model, no_intrachain=True, + chain_mapping = chain_mapping, + return_dist_test = True) + + # do lDDT only on intrachain contacts (sc for single chain) + lDDT_sc, per_res_lDDT_sc, lDDT_tot_sc, lDDT_cons_sc, \ + res_indices_sc, per_res_exp_sc, per_res_conserved_sc =\ + lddt_scorer.lDDT(model, no_interchain=True, + chain_mapping = chain_mapping, + return_dist_test = True) + + # do lDDT on everything + lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, per_res_exp, \ + 
per_res_conserved = lddt_scorer.lDDT(model, + chain_mapping = chain_mapping, + return_dist_test = True) + + # sum of lDDT_tot_ic and lDDT_tot_sc should be equal to lDDT_tot + self.assertEqual(lDDT_tot_ic + lDDT_tot_sc, lDDT_tot) + + # same for the conserved contacts + self.assertEqual(lDDT_cons_ic + lDDT_cons_sc, lDDT_cons) + class TestlDDTBS(unittest.TestCase):