diff --git a/modules/mol/alg/pymod/contact_score.py b/modules/mol/alg/pymod/contact_score.py index 5dd7e48a8f5ba71dd3d4624af06688650783e275..845f620e0da0032deccc5184ee8a3bac482f12c4 100644 --- a/modules/mol/alg/pymod/contact_score.py +++ b/modules/mol/alg/pymod/contact_score.py @@ -188,8 +188,6 @@ class ContactEntity: return self._sequence[chain_name] def _SetupContacts(self): - # this function is incredibly inefficient... if performance is an issue, - # go ahead and optimize self._contacts = dict() self._hr_contacts = list() @@ -198,32 +196,93 @@ class ContactEntity: for r_idx, r in enumerate(ch.residues): r.SetIntProp("contact_idx", r_idx) + residue_lists = list() + min_res_x = list() + min_res_y = list() + min_res_z = list() + max_res_x = list() + max_res_y = list() + max_res_z = list() + per_res_pos = list() + min_chain_pos = list() + max_chain_pos = list() + for cname in self.chain_names: - # q1 selects stuff in current chain that is close to any other chain - q1 = f"cname={mol.QueryQuoteName(cname)} and {self.contact_d} <> [cname!={mol.QueryQuoteName(cname)}]" - # q2 selects stuff in other chains that is close to current chain - q2 = f"cname!={mol.QueryQuoteName(cname)} and {self.contact_d} <> [cname={mol.QueryQuoteName(cname)}]" - v1 = self.view.Select(q1) - v2 = self.view.Select(q2) - v1_p = [geom.Vec3List([a.pos for a in r.atoms]) for r in v1.residues] - for r1, p1 in zip(v1.residues, v1_p): - for ch2 in v2.chains: - cname2 = ch2.GetName() - if cname2 > cname: - v2_p = [geom.Vec3List([a.pos for a in r.atoms]) for r in ch2.residues] - for r2, p2 in zip(ch2.residues, v2_p): - if p1.IsWithin(p2, self.contact_d): - cname_key = (cname, cname2) - if cname_key not in self._contacts: - self._contacts[cname_key] = set() - self._contacts[cname_key].add((r1.GetIntProp("contact_idx"), - r2.GetIntProp("contact_idx"))) - rnum1 = r1.GetNumber() - hr1 = f"{cname}.{rnum1.num}.{rnum1.ins_code}" - rnum2 = r2.GetNumber() - hr2 = f"{cname2}.{rnum2.num}.{rnum2.ins_code}" - self._hr_contacts.append((hr1.strip("\u0000"), - hr2.strip("\u0000"))) + ch = self.view.FindChain(cname) + if ch.GetAtomCount() == 0: + raise RuntimeError(f"Chain without atoms observed: \"{cname}\"") + residue_lists.append([r for r in ch.residues]) + res_pos = list() + for r in residue_lists[-1]: + pos = np.zeros((r.GetAtomCount(), 3)) + for at_idx, at in enumerate(r.atoms): + p = at.GetPos() + pos[(at_idx, 0)] = p[0] + pos[(at_idx, 1)] = p[1] + pos[(at_idx, 2)] = p[2] + res_pos.append(pos) + min_res_pos = np.vstack([p.min(0) for p in res_pos]) + max_res_pos = np.vstack([p.max(0) for p in res_pos]) + min_res_x.append(min_res_pos[:, 0]) + min_res_y.append(min_res_pos[:, 1]) + min_res_z.append(min_res_pos[:, 2]) + max_res_x.append(max_res_pos[:, 0]) + max_res_y.append(max_res_pos[:, 1]) + max_res_z.append(max_res_pos[:, 2]) + min_chain_pos.append(min_res_pos.min(0)) + max_chain_pos.append(max_res_pos.max(0)) + per_res_pos.append(res_pos) + + # operate on squared contact_d (scd) to save some square roots + scd = self.contact_d * self.contact_d + + for ch1_idx in range(len(self.chain_names)): + for ch2_idx in range(ch1_idx + 1, len(self.chain_names)): + # chains which fulfill the following expressions have no contact + # within self.contact_d + if np.max(min_chain_pos[ch1_idx] - max_chain_pos[ch2_idx]) > self.contact_d: + continue + if np.max(min_chain_pos[ch2_idx] - max_chain_pos[ch1_idx]) > self.contact_d: + continue + + # same thing for residue positions but all at once + skip_one = np.subtract.outer(min_res_x[ch1_idx], max_res_x[ch2_idx]) > self.contact_d + skip_one = np.logical_or(skip_one, np.subtract.outer(min_res_y[ch1_idx], max_res_y[ch2_idx]) > self.contact_d) + skip_one = np.logical_or(skip_one, np.subtract.outer(min_res_z[ch1_idx], max_res_z[ch2_idx]) > self.contact_d) + skip_two = np.subtract.outer(min_res_x[ch2_idx], max_res_x[ch1_idx]) > self.contact_d + skip_two = np.logical_or(skip_two, np.subtract.outer(min_res_y[ch2_idx], max_res_y[ch1_idx]) > self.contact_d) + skip_two = np.logical_or(skip_two, np.subtract.outer(min_res_z[ch2_idx], max_res_z[ch1_idx]) > self.contact_d) + skip = np.logical_or(skip_one, skip_two.T) + + # identify residue pairs for which we cannot exclude a contact + r1_indices, r2_indices = np.nonzero(np.logical_not(skip)) + ch1_per_res_pos = per_res_pos[ch1_idx] + ch2_per_res_pos = per_res_pos[ch2_idx] + for r1_idx, r2_idx in zip(r1_indices, r2_indices): + # compute pairwise distances + p1 = ch1_per_res_pos[r1_idx] + p2 = ch2_per_res_pos[r2_idx] + x2 = np.sum(p1**2, axis=1) # (m) + y2 = np.sum(p2**2, axis=1) # (n) + xy = np.matmul(p1, p2.T) # (m, n) + x2 = x2.reshape(-1, 1) + squared_distances = x2 - 2*xy + y2 # (m, n) + if np.min(squared_distances) <= scd: + # its a contact! + r1 = residue_lists[ch1_idx][r1_idx] + r2 = residue_lists[ch2_idx][r2_idx] + cname_key = (self.chain_names[ch1_idx], self.chain_names[ch2_idx]) + if cname_key not in self._contacts: + self._contacts[cname_key] = set() + self._contacts[cname_key].add((r1.GetIntProp("contact_idx"), + r2.GetIntProp("contact_idx"))) + rnum1 = r1.GetNumber() + hr1 = f"{self.chain_names[ch1_idx]}.{rnum1.num}.{rnum1.ins_code}" + rnum2 = r2.GetNumber() + hr2 = f"{self.chain_names[ch2_idx]}.{rnum2.num}.{rnum2.ins_code}" + self._hr_contacts.append((hr1.strip("\u0000"), + hr2.strip("\u0000"))) + def _SetupInterfaceResidues(self): self._interface_residues = {cname: set() for cname in self.chain_names}