diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py index ff6119a53c9f9052296e339de458223a29ba505a..ef35eb0f54590ece8043cafbdc1ab80a077f7d96 100644 --- a/modules/mol/alg/pymod/lddt.py +++ b/modules/mol/alg/pymod/lddt.py @@ -3,6 +3,17 @@ import numpy as np from ost import mol from ost import conop +# use cdist of scipy, fallback to (slower) numpy implementation if scipy is not +# available +try: + from scipy.spatial.distance import cdist +except: + def cdist(p1, p2): + x2 = np.sum(p1**2, axis=1) # (m) + y2 = np.sum(p2**2, axis=1) # (n) + xy = np.matmul(p1, p2.T) # (m, n) + x2 = x2.reshape(-1, 1) + return np.sqrt(x2 - 2*xy + y2) # (m, n) class CustomCompound: """ Defines atoms for custom compounds @@ -878,47 +889,127 @@ class lDDTScorer: def _SetupDistances(self): """Compute distance related members of lDDTScorer + + Brute force all vs all distance computation kills lDDT for large + complexes. Instead of building some KD tree data structure, we make use + of expected spatial proximity of atoms in the same chain. Distances are + computed as follows: + + - process each chain individually + - perform crude collision detection + - process potentially interacting chain pairs + - concatenate distances from all processing steps """ - # init self._ref_indices = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)] self._ref_distances = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)] self._sym_ref_indices = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)] self._sym_ref_distances = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)] - # initialize positions with values far in nirvana. If a position is not - # set, it should be far away from any position in target (or at least - # more than inclusion_radius). - max_pos = self.target.bounds.GetMax() - max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2])) - max_coordinate += 2 * self.inclusion_radius + indices = [list() for _ in range(self.n_atoms)] + distances = [list() for _ in range(self.n_atoms)] + per_chain_pos = list() + per_chain_indices = list() + + # Process individual chains + for ch_idx, ch in enumerate(self.target.chains): + ch_start_idx = self.chain_start_indices[ch_idx] + pos_list = list() + atom_indices = list() + mask_start = list() + mask_end = list() + r_start_idx = 0 + for r_idx, r in enumerate(ch.residues): + n_valid_atoms = 0 + for a in r.atoms: + hash_code = a.handle.GetHashCode() + if hash_code in self.atom_indices: + p = a.GetPos() + pos_list.append(np.asarray([p[0], p[1], p[2]])) + atom_indices.append(self.atom_indices[hash_code]) + n_valid_atoms += 1 + mask_start.extend([r_start_idx] * n_valid_atoms) + mask_end.extend([r_start_idx + n_valid_atoms] * n_valid_atoms) + r_start_idx += n_valid_atoms + pos = np.vstack(pos_list) + atom_indices = np.asarray(atom_indices) + dists = cdist(pos, pos) + + # apply masks + far_away = 2 * self.inclusion_radius + for idx in range(atom_indices.shape[0]): + dists[idx, range(mask_start[idx], mask_end[idx])] = far_away + + # fish out and store close atoms within inclusion radius + within_mask = dists < self.inclusion_radius + for idx in range(atom_indices.shape[0]): + indices_to_append = atom_indices[within_mask[idx,:]] + if indices_to_append.shape[0] > 0: + full_at_idx = atom_indices[idx] + indices[full_at_idx].append(indices_to_append) + distances[full_at_idx].append(dists[idx, within_mask[idx,:]]) + + per_chain_pos.append(pos) + per_chain_indices.append(atom_indices) + + # perform crude collision detection + min_pos = [p.min(0) for p in per_chain_pos] + max_pos = [p.max(0) for p in per_chain_pos] + chain_pairs = list() + for idx_one in range(len(self.chain_start_indices)): + for idx_two in range(idx_one + 1, len(self.chain_start_indices)): + if np.max(min_pos[idx_one] - max_pos[idx_two]) > self.inclusion_radius: + continue + if np.max(min_pos[idx_two] - max_pos[idx_one]) > self.inclusion_radius: + continue + chain_pairs.append((idx_one, idx_two)) + + # process potentially interacting chains + for pair in chain_pairs: + dists = cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]]) + within = dists <= self.inclusion_radius + + # process pair[0] + tmp = within.sum(axis=1) + for idx in range(tmp.shape[0]): + if tmp[idx] > 0: + # even though not being a strict requirement, we perform an + # insertion here such that the indices for each atom will be + # sorted after the hstack operation + at_idx = per_chain_indices[pair[0]][idx] + indices_to_insert = per_chain_indices[pair[1]][within[idx,:]] + distances_to_insert = dists[idx, within[idx, :]] + insertion_idx = len(indices[at_idx]) + for i in range(insertion_idx): + if indices_to_insert[0] > indices[at_idx][i][0]: + insertion_idx = i + break + indices[at_idx].insert(insertion_idx, indices_to_insert) + distances[at_idx].insert(insertion_idx, distances_to_insert) + + # process pair[1] + tmp = within.sum(axis=0) + for idx in range(tmp.shape[0]): + if tmp[idx] > 0: + # even though not being a strict requirement, we perform an + # insertion here such that the indices for each atom will be + # sorted after the hstack operation + at_idx = per_chain_indices[pair[1]][idx] + indices_to_insert = per_chain_indices[pair[0]][within[:, idx]] + distances_to_insert = dists[within[:, idx], idx] + insertion_idx = len(indices[at_idx]) + for i in range(insertion_idx): + if indices_to_insert[0] > indices[at_idx][i][0]: + insertion_idx = i + break + indices[at_idx].insert(insertion_idx, indices_to_insert) + distances[at_idx].insert(insertion_idx, distances_to_insert) + + # concatenate distances from all processing steps + for at_idx in range(self.n_atoms): + if len(indices[at_idx]) > 0: + self._ref_indices[at_idx] = np.hstack(indices[at_idx]) + self._ref_distances[at_idx] = np.hstack(distances[at_idx]) - pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate - atom_indices = list() - mask_start = list() - mask_end = list() - - for r_idx, r in enumerate(self.target.residues): - r_start_idx = self.res_start_indices[r_idx] - r_n_atoms = len(self.compound_anames[r.name]) - r_end_idx = r_start_idx + r_n_atoms - for a in r.atoms: - if a.handle.GetHashCode() in self.atom_indices: - idx = self.atom_indices[a.handle.GetHashCode()] - p = a.GetPos() - pos[idx][0] = p[0] - pos[idx][1] = p[1] - pos[idx][2] = p[2] - atom_indices.append(idx) - mask_start.append(r_start_idx) - mask_end.append(r_end_idx) - - indices, distances = self._CloseStuff(pos, self.inclusion_radius, - atom_indices, mask_start, - mask_end) - - for i in range(len(atom_indices)): - self._ref_indices[atom_indices[i]] = indices[i] - self._ref_distances[atom_indices[i]] = distances[i] self._NonSymDistances(self._ref_indices, self._ref_distances, self._sym_ref_indices, self._sym_ref_distances) @@ -987,26 +1078,6 @@ class lDDTScorer: self._sym_ref_indices_ic, self._sym_ref_distances_ic) - def _CloseStuff(self, pos, inclusion_radius, indices, mask_start, mask_end): - """returns close stuff for positions specified by indices - """ - # TODO: this function does brute force distance computation which has - # quadratic complexity... - close_indices = list() - distances = list() - # work with squared_inclusion_radius (sir) to save some square roots - sir = inclusion_radius ** 2 - for idx, ms, me in zip(indices, mask_start, mask_end): - p = pos[idx, :] - tmp = pos - p[None, :] - np.square(tmp, out=tmp) - tmp = tmp.sum(axis=1) - # mask out atoms of own residue => put them far away - tmp[range(ms, me)] = 2 * sir - close_indices.append(np.nonzero(tmp <= sir)[0]) - distances.append(np.sqrt(tmp[close_indices[-1]])) - return (close_indices, distances) - def _NonSymDistances(self, ref_indices, ref_distances, sym_ref_indices, sym_ref_distances): """Transfer indices/distances of non-symmetric atoms in place diff --git a/modules/mol/alg/tests/test_lddt.py b/modules/mol/alg/tests/test_lddt.py index 6cfe9d5b6ba1e5f6b71d37898e15f1f5147eac0c..0e31f6cac9ec1b90639c66ea8686fe0c48b4a1af 100644 --- a/modules/mol/alg/tests/test_lddt.py +++ b/modules/mol/alg/tests/test_lddt.py @@ -44,7 +44,11 @@ class TestlDDT(unittest.TestCase): for a,b in zip(aws_per_res_scores, classic_per_res_scores): if a is None and b is None: continue - self.assertAlmostEqual(a, b, places = 5) + # only check for 3 places. Reason for that is that the distance + # difference between GLN30.CB and TYR35.O is within floating point + # accuracy of the 0.5A threshold. So the two involved residues may + # have a difference of 1 with respect to conserved distances. + self.assertAlmostEqual(a, b, places = 3) # do 7W1F_B model = _LoadFile("7W1F_B_model.pdb")