From aaec9ede55e5a9174b2bb43deb6395ae9f9921ba Mon Sep 17 00:00:00 2001 From: Gabriel Studer <gabriel.studer@unibas.ch> Date: Thu, 15 Feb 2024 17:07:38 +0100 Subject: [PATCH] lddt: speedup for big complexes Pairwise distance computation for the reference distances was performed with N squared complexity and some funny guy had the idea to throw a 180mer at it... One possibibility would be the use of some KD tree data structure. However, the construction itself comes with computational cost. The implemented solution makes use of the expected spatial proximity of atoms in the same chain and distances are computed as follows: - process each chain individually - perform crude collision detection - process potentially interacting chain pairs - concatenate distances from all processing steps The new algorithm has been tested and compared to the previous implementation by randomly selecting 3 models of each CASP15 oligo target. Global lDDT has been tested for a match within 0.0001 and per-residue lDDT for a match within 0.001. Reason for lower threshold in per-residue lDDT is floating point accuracy. Also for the changed unit test, one distance difference was within floating point accuracy of one of the thresholds (see comments there). Accuracy of 0.001 still means that we only allow a discrepancy of one for 1000 checked distances... Observed speedups are size dependent and range from lower 2 digit percentages up to several fold speedup for larger CASP15 targets. The mentioned 180mer now concludes in a few minutes as oposed to almost a day. --- modules/mol/alg/pymod/lddt.py | 179 ++++++++++++++++++++--------- modules/mol/alg/tests/test_lddt.py | 6 +- 2 files changed, 130 insertions(+), 55 deletions(-) diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py index ff6119a53..ef35eb0f5 100644 --- a/modules/mol/alg/pymod/lddt.py +++ b/modules/mol/alg/pymod/lddt.py @@ -3,6 +3,17 @@ import numpy as np from ost import mol from ost import conop +# use cdist of scipy, fallback to (slower) numpy implementation if scipy is not +# available +try: + from scipy.spatial.distance import cdist +except: + def cdist(p1, p2): + x2 = np.sum(p1**2, axis=1) # (m) + y2 = np.sum(p2**2, axis=1) # (n) + xy = np.matmul(p1, p2.T) # (m, n) + x2 = x2.reshape(-1, 1) + return np.sqrt(x2 - 2*xy + y2) # (m, n) class CustomCompound: """ Defines atoms for custom compounds @@ -878,47 +889,127 @@ class lDDTScorer: def _SetupDistances(self): """Compute distance related members of lDDTScorer + + Brute force all vs all distance computation kills lDDT for large + complexes. Instead of building some KD tree data structure, we make use + of expected spatial proximity of atoms in the same chain. Distances are + computed as follows: + + - process each chain individually + - perform crude collision detection + - process potentially interacting chain pairs + - concatenate distances from all processing steps """ - # init self._ref_indices = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)] self._ref_distances = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)] self._sym_ref_indices = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)] self._sym_ref_distances = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)] - # initialize positions with values far in nirvana. If a position is not - # set, it should be far away from any position in target (or at least - # more than inclusion_radius). - max_pos = self.target.bounds.GetMax() - max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2])) - max_coordinate += 2 * self.inclusion_radius + indices = [list() for _ in range(self.n_atoms)] + distances = [list() for _ in range(self.n_atoms)] + per_chain_pos = list() + per_chain_indices = list() + + # Process individual chains + for ch_idx, ch in enumerate(self.target.chains): + ch_start_idx = self.chain_start_indices[ch_idx] + pos_list = list() + atom_indices = list() + mask_start = list() + mask_end = list() + r_start_idx = 0 + for r_idx, r in enumerate(ch.residues): + n_valid_atoms = 0 + for a in r.atoms: + hash_code = a.handle.GetHashCode() + if hash_code in self.atom_indices: + p = a.GetPos() + pos_list.append(np.asarray([p[0], p[1], p[2]])) + atom_indices.append(self.atom_indices[hash_code]) + n_valid_atoms += 1 + mask_start.extend([r_start_idx] * n_valid_atoms) + mask_end.extend([r_start_idx + n_valid_atoms] * n_valid_atoms) + r_start_idx += n_valid_atoms + pos = np.vstack(pos_list) + atom_indices = np.asarray(atom_indices) + dists = cdist(pos, pos) + + # apply masks + far_away = 2 * self.inclusion_radius + for idx in range(atom_indices.shape[0]): + dists[idx, range(mask_start[idx], mask_end[idx])] = far_away + + # fish out and store close atoms within inclusion radius + within_mask = dists < self.inclusion_radius + for idx in range(atom_indices.shape[0]): + indices_to_append = atom_indices[within_mask[idx,:]] + if indices_to_append.shape[0] > 0: + full_at_idx = atom_indices[idx] + indices[full_at_idx].append(indices_to_append) + distances[full_at_idx].append(dists[idx, within_mask[idx,:]]) + + per_chain_pos.append(pos) + per_chain_indices.append(atom_indices) + + # perform crude collision detection + min_pos = [p.min(0) for p in per_chain_pos] + max_pos = [p.max(0) for p in per_chain_pos] + chain_pairs = list() + for idx_one in range(len(self.chain_start_indices)): + for idx_two in range(idx_one + 1, len(self.chain_start_indices)): + if np.max(min_pos[idx_one] - max_pos[idx_two]) > self.inclusion_radius: + continue + if np.max(min_pos[idx_two] - max_pos[idx_one]) > self.inclusion_radius: + continue + chain_pairs.append((idx_one, idx_two)) + + # process potentially interacting chains + for pair in chain_pairs: + dists = cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]]) + within = dists <= self.inclusion_radius + + # process pair[0] + tmp = within.sum(axis=1) + for idx in range(tmp.shape[0]): + if tmp[idx] > 0: + # even though not being a strict requirement, we perform an + # insertion here such that the indices for each atom will be + # sorted after the hstack operation + at_idx = per_chain_indices[pair[0]][idx] + indices_to_insert = per_chain_indices[pair[1]][within[idx,:]] + distances_to_insert = dists[idx, within[idx, :]] + insertion_idx = len(indices[at_idx]) + for i in range(insertion_idx): + if indices_to_insert[0] > indices[at_idx][i][0]: + insertion_idx = i + break + indices[at_idx].insert(insertion_idx, indices_to_insert) + distances[at_idx].insert(insertion_idx, distances_to_insert) + + # process pair[1] + tmp = within.sum(axis=0) + for idx in range(tmp.shape[0]): + if tmp[idx] > 0: + # even though not being a strict requirement, we perform an + # insertion here such that the indices for each atom will be + # sorted after the hstack operation + at_idx = per_chain_indices[pair[1]][idx] + indices_to_insert = per_chain_indices[pair[0]][within[:, idx]] + distances_to_insert = dists[within[:, idx], idx] + insertion_idx = len(indices[at_idx]) + for i in range(insertion_idx): + if indices_to_insert[0] > indices[at_idx][i][0]: + insertion_idx = i + break + indices[at_idx].insert(insertion_idx, indices_to_insert) + distances[at_idx].insert(insertion_idx, distances_to_insert) + + # concatenate distances from all processing steps + for at_idx in range(self.n_atoms): + if len(indices[at_idx]) > 0: + self._ref_indices[at_idx] = np.hstack(indices[at_idx]) + self._ref_distances[at_idx] = np.hstack(distances[at_idx]) - pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate - atom_indices = list() - mask_start = list() - mask_end = list() - - for r_idx, r in enumerate(self.target.residues): - r_start_idx = self.res_start_indices[r_idx] - r_n_atoms = len(self.compound_anames[r.name]) - r_end_idx = r_start_idx + r_n_atoms - for a in r.atoms: - if a.handle.GetHashCode() in self.atom_indices: - idx = self.atom_indices[a.handle.GetHashCode()] - p = a.GetPos() - pos[idx][0] = p[0] - pos[idx][1] = p[1] - pos[idx][2] = p[2] - atom_indices.append(idx) - mask_start.append(r_start_idx) - mask_end.append(r_end_idx) - - indices, distances = self._CloseStuff(pos, self.inclusion_radius, - atom_indices, mask_start, - mask_end) - - for i in range(len(atom_indices)): - self._ref_indices[atom_indices[i]] = indices[i] - self._ref_distances[atom_indices[i]] = distances[i] self._NonSymDistances(self._ref_indices, self._ref_distances, self._sym_ref_indices, self._sym_ref_distances) @@ -987,26 +1078,6 @@ class lDDTScorer: self._sym_ref_indices_ic, self._sym_ref_distances_ic) - def _CloseStuff(self, pos, inclusion_radius, indices, mask_start, mask_end): - """returns close stuff for positions specified by indices - """ - # TODO: this function does brute force distance computation which has - # quadratic complexity... - close_indices = list() - distances = list() - # work with squared_inclusion_radius (sir) to save some square roots - sir = inclusion_radius ** 2 - for idx, ms, me in zip(indices, mask_start, mask_end): - p = pos[idx, :] - tmp = pos - p[None, :] - np.square(tmp, out=tmp) - tmp = tmp.sum(axis=1) - # mask out atoms of own residue => put them far away - tmp[range(ms, me)] = 2 * sir - close_indices.append(np.nonzero(tmp <= sir)[0]) - distances.append(np.sqrt(tmp[close_indices[-1]])) - return (close_indices, distances) - def _NonSymDistances(self, ref_indices, ref_distances, sym_ref_indices, sym_ref_distances): """Transfer indices/distances of non-symmetric atoms in place diff --git a/modules/mol/alg/tests/test_lddt.py b/modules/mol/alg/tests/test_lddt.py index 6cfe9d5b6..0e31f6cac 100644 --- a/modules/mol/alg/tests/test_lddt.py +++ b/modules/mol/alg/tests/test_lddt.py @@ -44,7 +44,11 @@ class TestlDDT(unittest.TestCase): for a,b in zip(aws_per_res_scores, classic_per_res_scores): if a is None and b is None: continue - self.assertAlmostEqual(a, b, places = 5) + # only check for 3 places. Reason for that is that the distance + # difference between GLN30.CB and TYR35.O is within floating point + # accuracy of the 0.5A threshold. So the two involved residues may + # have a difference of 1 with respect to conserved distances. + self.assertAlmostEqual(a, b, places = 3) # do 7W1F_B model = _LoadFile("7W1F_B_model.pdb") -- GitLab