Skip to content
Snippets Groups Projects
Commit aaec9ede authored by Studer Gabriel's avatar Studer Gabriel
Browse files

lddt: speedup for big complexes

Pairwise distance computation for the reference distances was
performed with N squared complexity and some funny guy had the idea
to throw a 180mer at it...

One possibibility would be the use of some KD tree data structure.
However, the construction itself comes with computational cost.
The implemented solution makes use of the expected spatial proximity
of atoms in the same chain and distances are computed as follows:

- process each chain individually
- perform crude collision detection
- process potentially interacting chain pairs
- concatenate distances from all processing steps

The new algorithm has been tested and compared to the previous
implementation by randomly selecting 3 models of each CASP15 oligo
target. Global lDDT has been tested for a match within 0.0001 and
per-residue lDDT for a match within 0.001. Reason for lower threshold
in per-residue lDDT is floating point accuracy. Also for the changed
unit test, one distance difference was within floating point accuracy
of one of the thresholds (see comments there). Accuracy of 0.001 still
means that we only allow a discrepancy of one for 1000 checked distances...

Observed speedups are size dependent and range from lower 2 digit
percentages up to several fold speedup for larger CASP15 targets.
The mentioned 180mer now concludes in a few minutes as oposed to
almost a day.
parent 0356d1b9
Branches
Tags
No related merge requests found
......@@ -3,6 +3,17 @@ import numpy as np
from ost import mol
from ost import conop
# use cdist of scipy, fallback to (slower) numpy implementation if scipy is not
# available
try:
from scipy.spatial.distance import cdist
except:
def cdist(p1, p2):
x2 = np.sum(p1**2, axis=1) # (m)
y2 = np.sum(p2**2, axis=1) # (n)
xy = np.matmul(p1, p2.T) # (m, n)
x2 = x2.reshape(-1, 1)
return np.sqrt(x2 - 2*xy + y2) # (m, n)
class CustomCompound:
""" Defines atoms for custom compounds
......@@ -878,47 +889,127 @@ class lDDTScorer:
def _SetupDistances(self):
"""Compute distance related members of lDDTScorer
Brute force all vs all distance computation kills lDDT for large
complexes. Instead of building some KD tree data structure, we make use
of expected spatial proximity of atoms in the same chain. Distances are
computed as follows:
- process each chain individually
- perform crude collision detection
- process potentially interacting chain pairs
- concatenate distances from all processing steps
"""
# init
self._ref_indices = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)]
self._ref_distances = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)]
self._sym_ref_indices = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)]
self._sym_ref_distances = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)]
# initialize positions with values far in nirvana. If a position is not
# set, it should be far away from any position in target (or at least
# more than inclusion_radius).
max_pos = self.target.bounds.GetMax()
max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2]))
max_coordinate += 2 * self.inclusion_radius
indices = [list() for _ in range(self.n_atoms)]
distances = [list() for _ in range(self.n_atoms)]
per_chain_pos = list()
per_chain_indices = list()
# Process individual chains
for ch_idx, ch in enumerate(self.target.chains):
ch_start_idx = self.chain_start_indices[ch_idx]
pos_list = list()
atom_indices = list()
mask_start = list()
mask_end = list()
r_start_idx = 0
for r_idx, r in enumerate(ch.residues):
n_valid_atoms = 0
for a in r.atoms:
hash_code = a.handle.GetHashCode()
if hash_code in self.atom_indices:
p = a.GetPos()
pos_list.append(np.asarray([p[0], p[1], p[2]]))
atom_indices.append(self.atom_indices[hash_code])
n_valid_atoms += 1
mask_start.extend([r_start_idx] * n_valid_atoms)
mask_end.extend([r_start_idx + n_valid_atoms] * n_valid_atoms)
r_start_idx += n_valid_atoms
pos = np.vstack(pos_list)
atom_indices = np.asarray(atom_indices)
dists = cdist(pos, pos)
# apply masks
far_away = 2 * self.inclusion_radius
for idx in range(atom_indices.shape[0]):
dists[idx, range(mask_start[idx], mask_end[idx])] = far_away
# fish out and store close atoms within inclusion radius
within_mask = dists < self.inclusion_radius
for idx in range(atom_indices.shape[0]):
indices_to_append = atom_indices[within_mask[idx,:]]
if indices_to_append.shape[0] > 0:
full_at_idx = atom_indices[idx]
indices[full_at_idx].append(indices_to_append)
distances[full_at_idx].append(dists[idx, within_mask[idx,:]])
per_chain_pos.append(pos)
per_chain_indices.append(atom_indices)
# perform crude collision detection
min_pos = [p.min(0) for p in per_chain_pos]
max_pos = [p.max(0) for p in per_chain_pos]
chain_pairs = list()
for idx_one in range(len(self.chain_start_indices)):
for idx_two in range(idx_one + 1, len(self.chain_start_indices)):
if np.max(min_pos[idx_one] - max_pos[idx_two]) > self.inclusion_radius:
continue
if np.max(min_pos[idx_two] - max_pos[idx_one]) > self.inclusion_radius:
continue
chain_pairs.append((idx_one, idx_two))
# process potentially interacting chains
for pair in chain_pairs:
dists = cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]])
within = dists <= self.inclusion_radius
# process pair[0]
tmp = within.sum(axis=1)
for idx in range(tmp.shape[0]):
if tmp[idx] > 0:
# even though not being a strict requirement, we perform an
# insertion here such that the indices for each atom will be
# sorted after the hstack operation
at_idx = per_chain_indices[pair[0]][idx]
indices_to_insert = per_chain_indices[pair[1]][within[idx,:]]
distances_to_insert = dists[idx, within[idx, :]]
insertion_idx = len(indices[at_idx])
for i in range(insertion_idx):
if indices_to_insert[0] > indices[at_idx][i][0]:
insertion_idx = i
break
indices[at_idx].insert(insertion_idx, indices_to_insert)
distances[at_idx].insert(insertion_idx, distances_to_insert)
# process pair[1]
tmp = within.sum(axis=0)
for idx in range(tmp.shape[0]):
if tmp[idx] > 0:
# even though not being a strict requirement, we perform an
# insertion here such that the indices for each atom will be
# sorted after the hstack operation
at_idx = per_chain_indices[pair[1]][idx]
indices_to_insert = per_chain_indices[pair[0]][within[:, idx]]
distances_to_insert = dists[within[:, idx], idx]
insertion_idx = len(indices[at_idx])
for i in range(insertion_idx):
if indices_to_insert[0] > indices[at_idx][i][0]:
insertion_idx = i
break
indices[at_idx].insert(insertion_idx, indices_to_insert)
distances[at_idx].insert(insertion_idx, distances_to_insert)
# concatenate distances from all processing steps
for at_idx in range(self.n_atoms):
if len(indices[at_idx]) > 0:
self._ref_indices[at_idx] = np.hstack(indices[at_idx])
self._ref_distances[at_idx] = np.hstack(distances[at_idx])
pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate
atom_indices = list()
mask_start = list()
mask_end = list()
for r_idx, r in enumerate(self.target.residues):
r_start_idx = self.res_start_indices[r_idx]
r_n_atoms = len(self.compound_anames[r.name])
r_end_idx = r_start_idx + r_n_atoms
for a in r.atoms:
if a.handle.GetHashCode() in self.atom_indices:
idx = self.atom_indices[a.handle.GetHashCode()]
p = a.GetPos()
pos[idx][0] = p[0]
pos[idx][1] = p[1]
pos[idx][2] = p[2]
atom_indices.append(idx)
mask_start.append(r_start_idx)
mask_end.append(r_end_idx)
indices, distances = self._CloseStuff(pos, self.inclusion_radius,
atom_indices, mask_start,
mask_end)
for i in range(len(atom_indices)):
self._ref_indices[atom_indices[i]] = indices[i]
self._ref_distances[atom_indices[i]] = distances[i]
self._NonSymDistances(self._ref_indices, self._ref_distances,
self._sym_ref_indices,
self._sym_ref_distances)
......@@ -987,26 +1078,6 @@ class lDDTScorer:
self._sym_ref_indices_ic,
self._sym_ref_distances_ic)
def _CloseStuff(self, pos, inclusion_radius, indices, mask_start, mask_end):
"""returns close stuff for positions specified by indices
"""
# TODO: this function does brute force distance computation which has
# quadratic complexity...
close_indices = list()
distances = list()
# work with squared_inclusion_radius (sir) to save some square roots
sir = inclusion_radius ** 2
for idx, ms, me in zip(indices, mask_start, mask_end):
p = pos[idx, :]
tmp = pos - p[None, :]
np.square(tmp, out=tmp)
tmp = tmp.sum(axis=1)
# mask out atoms of own residue => put them far away
tmp[range(ms, me)] = 2 * sir
close_indices.append(np.nonzero(tmp <= sir)[0])
distances.append(np.sqrt(tmp[close_indices[-1]]))
return (close_indices, distances)
def _NonSymDistances(self, ref_indices, ref_distances,
sym_ref_indices, sym_ref_distances):
"""Transfer indices/distances of non-symmetric atoms in place
......
......@@ -44,7 +44,11 @@ class TestlDDT(unittest.TestCase):
for a,b in zip(aws_per_res_scores, classic_per_res_scores):
if a is None and b is None:
continue
self.assertAlmostEqual(a, b, places = 5)
# only check for 3 places. Reason for that is that the distance
# difference between GLN30.CB and TYR35.O is within floating point
# accuracy of the 0.5A threshold. So the two involved residues may
# have a difference of 1 with respect to conserved distances.
self.assertAlmostEqual(a, b, places = 3)
# do 7W1F_B
model = _LoadFile("7W1F_B_model.pdb")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment