Skip to content
Snippets Groups Projects
Commit db020c14 authored by Studer Gabriel's avatar Studer Gabriel
Browse files

lddt speedups through caching

parent 7f194707
No related branches found
No related tags found
No related merge requests found
...@@ -724,7 +724,7 @@ class lDDTScorer: ...@@ -724,7 +724,7 @@ class lDDTScorer:
self._sym_ref_distances) self._sym_ref_distances)
def _SetupDistancesSC(self): def _SetupDistancesSC(self):
"""Same as above but on a single chain basis => not interchain contacts """Select subset of contacts only covering intra-chain contacts
""" """
# init # init
self._ref_indices_sc = [[] for idx in range(self.n_atoms)] self._ref_indices_sc = [[] for idx in range(self.n_atoms)]
...@@ -732,54 +732,25 @@ class lDDTScorer: ...@@ -732,54 +732,25 @@ class lDDTScorer:
self._sym_ref_indices_sc = [[] for idx in range(self.n_atoms)] self._sym_ref_indices_sc = [[] for idx in range(self.n_atoms)]
self._sym_ref_distances_sc = [[] for idx in range(self.n_atoms)] self._sym_ref_distances_sc = [[] for idx in range(self.n_atoms)]
# initialize positions with values far in nirvana. If a position is not # start from overall contacts
# set, it should be far away from any position in target (or at least ref_indices = self.ref_indices
# more than inclusion_radius). ref_distances = self.ref_distances
max_pos = self.target.bounds.GetMax() sym_ref_indices = self.sym_ref_indices
max_coordinate = max(max_pos[0], max_pos[1], max_pos[2]) sym_ref_distances = self.sym_ref_distances
max_coordinate += 2 * self.inclusion_radius
# same as above but chain-wise n_chains = len(self.chain_start_indices)
r_idx = 0
n_chains = len(self.target.chains)
for ch_idx, ch in enumerate(self.target.chains): for ch_idx, ch in enumerate(self.target.chains):
chain_start_idx = self.chain_start_indices[ch_idx] chain_s = self.chain_start_indices[ch_idx]
chain_end_idx = self.n_atoms chain_e = self.n_atoms
if ch_idx + 1 < n_chains: if ch_idx + 1 < n_chains:
chain_end_idx = self.chain_start_indices[ch_idx+1] chain_e = self.chain_start_indices[ch_idx+1]
n_chain_atoms = chain_end_idx - chain_start_idx for i in range(chain_s, chain_e):
pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate if len(ref_indices[i]) > 0:
atom_indices = list() intra_idx = np.where(np.logical_and(ref_indices[i]>=chain_s,
mask_start = list() ref_indices[i]<chain_e))[0]
mask_end = list() self._ref_indices_sc[i] = ref_indices[i][intra_idx]
for r in ch.residues: self._ref_distances_sc[i] = ref_distances[i][intra_idx]
r_start_idx = self.res_start_indices[r_idx]
r_start_idx -= chain_start_idx # map to chain reference
r_n_atoms = len(self.compound_anames[r.name])
r_end_idx = r_start_idx + r_n_atoms
for a in r.atoms:
if a.handle.GetHashCode() in self.atom_indices:
idx = self.atom_indices[a.handle.GetHashCode()]
idx -= chain_start_idx # map to chain reference
p = a.GetPos()
pos[idx][0] = p[0]
pos[idx][1] = p[1]
pos[idx][2] = p[2]
atom_indices.append(idx)
mask_start.append(r_start_idx)
mask_end.append(r_end_idx)
r_idx += 1
indices, distances = self._CloseStuff(pos,
self.inclusion_radius,
atom_indices, mask_start,
mask_end)
for i in range(len(atom_indices)):
# map back to global reference
self._ref_indices_sc[atom_indices[i]+chain_start_idx] =\
[x+chain_start_idx for x in indices[i]]
self._ref_distances_sc[atom_indices[i]+chain_start_idx] = distances[i]
self._NonSymDistances(self._ref_indices_sc, self._ref_distances_sc, self._NonSymDistances(self._ref_indices_sc, self._ref_distances_sc,
self._sym_ref_indices_sc, self._sym_ref_indices_sc,
self._sym_ref_distances_sc) self._sym_ref_distances_sc)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment