Skip to content
Snippets Groups Projects
Commit db020c14 authored by Studer Gabriel's avatar Studer Gabriel
Browse files

lddt speedups through caching

parent 7f194707
No related branches found
No related tags found
No related merge requests found
......@@ -724,7 +724,7 @@ class lDDTScorer:
self._sym_ref_distances)
def _SetupDistancesSC(self):
"""Same as above but on a single chain basis => not interchain contacts
"""Select subset of contacts only covering intra-chain contacts
"""
# init
self._ref_indices_sc = [[] for idx in range(self.n_atoms)]
......@@ -732,54 +732,25 @@ class lDDTScorer:
self._sym_ref_indices_sc = [[] for idx in range(self.n_atoms)]
self._sym_ref_distances_sc = [[] for idx in range(self.n_atoms)]
# initialize positions with values far in nirvana. If a position is not
# set, it should be far away from any position in target (or at least
# more than inclusion_radius).
max_pos = self.target.bounds.GetMax()
max_coordinate = max(max_pos[0], max_pos[1], max_pos[2])
max_coordinate += 2 * self.inclusion_radius
# start from overall contacts
ref_indices = self.ref_indices
ref_distances = self.ref_distances
sym_ref_indices = self.sym_ref_indices
sym_ref_distances = self.sym_ref_distances
# same as above but chain-wise
r_idx = 0
n_chains = len(self.target.chains)
n_chains = len(self.chain_start_indices)
for ch_idx, ch in enumerate(self.target.chains):
chain_start_idx = self.chain_start_indices[ch_idx]
chain_end_idx = self.n_atoms
chain_s = self.chain_start_indices[ch_idx]
chain_e = self.n_atoms
if ch_idx + 1 < n_chains:
chain_end_idx = self.chain_start_indices[ch_idx+1]
n_chain_atoms = chain_end_idx - chain_start_idx
pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate
atom_indices = list()
mask_start = list()
mask_end = list()
for r in ch.residues:
r_start_idx = self.res_start_indices[r_idx]
r_start_idx -= chain_start_idx # map to chain reference
r_n_atoms = len(self.compound_anames[r.name])
r_end_idx = r_start_idx + r_n_atoms
for a in r.atoms:
if a.handle.GetHashCode() in self.atom_indices:
idx = self.atom_indices[a.handle.GetHashCode()]
idx -= chain_start_idx # map to chain reference
p = a.GetPos()
pos[idx][0] = p[0]
pos[idx][1] = p[1]
pos[idx][2] = p[2]
atom_indices.append(idx)
mask_start.append(r_start_idx)
mask_end.append(r_end_idx)
r_idx += 1
indices, distances = self._CloseStuff(pos,
self.inclusion_radius,
atom_indices, mask_start,
mask_end)
for i in range(len(atom_indices)):
# map back to global reference
self._ref_indices_sc[atom_indices[i]+chain_start_idx] =\
[x+chain_start_idx for x in indices[i]]
self._ref_distances_sc[atom_indices[i]+chain_start_idx] = distances[i]
chain_e = self.chain_start_indices[ch_idx+1]
for i in range(chain_s, chain_e):
if len(ref_indices[i]) > 0:
intra_idx = np.where(np.logical_and(ref_indices[i]>=chain_s,
ref_indices[i]<chain_e))[0]
self._ref_indices_sc[i] = ref_indices[i][intra_idx]
self._ref_distances_sc[i] = ref_distances[i][intra_idx]
self._NonSymDistances(self._ref_indices_sc, self._ref_distances_sc,
self._sym_ref_indices_sc,
self._sym_ref_distances_sc)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment