Skip to content
Snippets Groups Projects
Commit 9406815f authored by Studer Gabriel
Browse files

make contact scoring fast again

parent 2454b691
No related branches found
No related tags found
No related merge requests found
......@@ -188,8 +188,6 @@ class ContactEntity:
return self._sequence[chain_name]
def _SetupContacts(self):
# this function is incredibly inefficient... if performance is an issue,
# go ahead and optimize
self._contacts = dict()
self._hr_contacts = list()
......@@ -198,32 +196,93 @@ class ContactEntity:
for r_idx, r in enumerate(ch.residues):
r.SetIntProp("contact_idx", r_idx)
residue_lists = list()
min_res_x = list()
min_res_y = list()
min_res_z = list()
max_res_x = list()
max_res_y = list()
max_res_z = list()
per_res_pos = list()
min_chain_pos = list()
max_chain_pos = list()
for cname in self.chain_names:
# q1 selects stuff in current chain that is close to any other chain
q1 = f"cname={mol.QueryQuoteName(cname)} and {self.contact_d} <> [cname!={mol.QueryQuoteName(cname)}]"
# q2 selects stuff in other chains that is close to current chain
q2 = f"cname!={mol.QueryQuoteName(cname)} and {self.contact_d} <> [cname={mol.QueryQuoteName(cname)}]"
v1 = self.view.Select(q1)
v2 = self.view.Select(q2)
v1_p = [geom.Vec3List([a.pos for a in r.atoms]) for r in v1.residues]
for r1, p1 in zip(v1.residues, v1_p):
for ch2 in v2.chains:
cname2 = ch2.GetName()
if cname2 > cname:
v2_p = [geom.Vec3List([a.pos for a in r.atoms]) for r in ch2.residues]
for r2, p2 in zip(ch2.residues, v2_p):
if p1.IsWithin(p2, self.contact_d):
cname_key = (cname, cname2)
if cname_key not in self._contacts:
self._contacts[cname_key] = set()
self._contacts[cname_key].add((r1.GetIntProp("contact_idx"),
r2.GetIntProp("contact_idx")))
rnum1 = r1.GetNumber()
hr1 = f"{cname}.{rnum1.num}.{rnum1.ins_code}"
rnum2 = r2.GetNumber()
hr2 = f"{cname2}.{rnum2.num}.{rnum2.ins_code}"
self._hr_contacts.append((hr1.strip("\u0000"),
hr2.strip("\u0000")))
ch = self.view.FindChain(cname)
if ch.GetAtomCount() == 0:
raise RuntimeError(f"Chain without atoms observed: \"{cname}\"")
residue_lists.append([r for r in ch.residues])
res_pos = list()
for r in residue_lists[-1]:
pos = np.zeros((r.GetAtomCount(), 3))
for at_idx, at in enumerate(r.atoms):
p = at.GetPos()
pos[(at_idx, 0)] = p[0]
pos[(at_idx, 1)] = p[1]
pos[(at_idx, 2)] = p[2]
res_pos.append(pos)
min_res_pos = np.vstack([p.min(0) for p in res_pos])
max_res_pos = np.vstack([p.max(0) for p in res_pos])
min_res_x.append(min_res_pos[:, 0])
min_res_y.append(min_res_pos[:, 1])
min_res_z.append(min_res_pos[:, 2])
max_res_x.append(max_res_pos[:, 0])
max_res_y.append(max_res_pos[:, 1])
max_res_z.append(max_res_pos[:, 2])
min_chain_pos.append(min_res_pos.min(0))
max_chain_pos.append(max_res_pos.max(0))
per_res_pos.append(res_pos)
# operate on squared contact_d (scd) to save some square roots
scd = self.contact_d * self.contact_d
for ch1_idx in range(len(self.chain_names)):
for ch2_idx in range(ch1_idx + 1, len(self.chain_names)):
# chains which fulfill the following expressions have no contact
# within self.contact_d
if np.max(min_chain_pos[ch1_idx] - max_chain_pos[ch2_idx]) > self.contact_d:
continue
if np.max(min_chain_pos[ch2_idx] - max_chain_pos[ch1_idx]) > self.contact_d:
continue
# same thing for residue positions but all at once
skip_one = np.subtract.outer(min_res_x[ch1_idx], max_res_x[ch2_idx]) > self.contact_d
skip_one = np.logical_or(skip_one, np.subtract.outer(min_res_y[ch1_idx], max_res_y[ch2_idx]) > self.contact_d)
skip_one = np.logical_or(skip_one, np.subtract.outer(min_res_z[ch1_idx], max_res_z[ch2_idx]) > self.contact_d)
skip_two = np.subtract.outer(min_res_x[ch2_idx], max_res_x[ch1_idx]) > self.contact_d
skip_two = np.logical_or(skip_two, np.subtract.outer(min_res_y[ch2_idx], max_res_y[ch1_idx]) > self.contact_d)
skip_two = np.logical_or(skip_two, np.subtract.outer(min_res_z[ch2_idx], max_res_z[ch1_idx]) > self.contact_d)
skip = np.logical_or(skip_one, skip_two.T)
# identify residue pairs for which we cannot exclude a contact
r1_indices, r2_indices = np.nonzero(np.logical_not(skip))
ch1_per_res_pos = per_res_pos[ch1_idx]
ch2_per_res_pos = per_res_pos[ch2_idx]
for r1_idx, r2_idx in zip(r1_indices, r2_indices):
# compute pairwise distances
p1 = ch1_per_res_pos[r1_idx]
p2 = ch2_per_res_pos[r2_idx]
x2 = np.sum(p1**2, axis=1) # (m)
y2 = np.sum(p2**2, axis=1) # (n)
xy = np.matmul(p1, p2.T) # (m, n)
x2 = x2.reshape(-1, 1)
squared_distances = x2 - 2*xy + y2 # (m, n)
if np.min(squared_distances) <= scd:
# it's a contact!
r1 = residue_lists[ch1_idx][r1_idx]
r2 = residue_lists[ch2_idx][r2_idx]
cname_key = (self.chain_names[ch1_idx], self.chain_names[ch2_idx])
if cname_key not in self._contacts:
self._contacts[cname_key] = set()
self._contacts[cname_key].add((r1.GetIntProp("contact_idx"),
r2.GetIntProp("contact_idx")))
rnum1 = r1.GetNumber()
hr1 = f"{self.chain_names[ch1_idx]}.{rnum1.num}.{rnum1.ins_code}"
rnum2 = r2.GetNumber()
hr2 = f"{self.chain_names[ch2_idx]}.{rnum2.num}.{rnum2.ins_code}"
self._hr_contacts.append((hr1.strip("\u0000"),
hr2.strip("\u0000")))
def _SetupInterfaceResidues(self):
self._interface_residues = {cname: set() for cname in self.chain_names}
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment