def blockwise_cdist(A, B, block_size=1000):
    """ Memory efficient pairwise distances, computed block by block

    cdist returns 64 bit floats (double), which can hit the upper memory
    limit of most machines once the number of positions becomes large.
    E.g. ~4000 residues might have 35000 atom positions. That's almost
    10GB to hold all pairwise distances in 64 bit floats. This function
    calls cdist on blocks of rows of *A* and stores the results in a
    32 bit float matrix, so only one block of 64 bit distances exists at
    any time.

    This function is adapted from chatgpt output

    :param A: Positions, one per row - shape (M, d)
    :param B: Positions, one per row - shape (N, d)
    :param block_size: Number of rows of *A* processed per cdist call
    :type block_size: :class:`int`
    :returns: :class:`numpy.ndarray` of shape (M, N) and dtype float32 with
              the distance between ``A[i]`` and ``B[j]`` at ``[i, j]``
    :raises: :class:`ValueError` if *block_size* is smaller than 1
    """
    # A non-positive step would either make range() fail with an unrelated
    # message (0) or skip the loop entirely (negative), in which case the
    # uninitialized output buffer would be returned as if it were distances.
    if block_size < 1:
        raise ValueError("block_size must be a positive integer, got "
                         + str(block_size))

    # asarray only copies if the input is not already a float32 array
    A = np.asarray(A, dtype=np.float32)
    B = np.asarray(B, dtype=np.float32)
    n_rows, n_cols = A.shape[0], B.shape[0]

    # every element is overwritten in the loop below
    D = np.empty((n_rows, n_cols), dtype=np.float32)
    for start in range(0, n_rows, block_size):
        stop = start + block_size
        # the assignment casts to float32 on the fly, no need for an
        # additional block sized temporary
        D[start:stop, :] = cdist(A[start:stop], B)
    return D
class lDDTScorer: hash_code = a.handle.GetHashCode() if hash_code in atom_index_mapping: p = a.GetPos() - pos_list.append(np.asarray([p[0], p[1], p[2]])) + pos_list.append(np.asarray([p[0], p[1], p[2]], dtype=np.float32)) atom_indices.append(atom_index_mapping[hash_code]) n_valid_atoms += 1 mask_start.extend([r_start_idx] * n_valid_atoms) @@ -1311,9 +1330,13 @@ class lDDTScorer: # nothing to do... continue - pos = np.vstack(pos_list) - atom_indices = np.asarray(atom_indices) - dists = cdist(pos, pos) + pos = np.vstack(pos_list, dtype=np.float32) + atom_indices = np.asarray(atom_indices, dtype=np.int32) + + if atom_indices.shape[0] > 20000: + dists = blockwise_cdist(pos, pos) + else: + dists = cdist(pos, pos) # apply masks far_away = 2 * inclusion_radius @@ -1329,6 +1352,8 @@ class lDDTScorer: indices[full_at_idx].append(indices_to_append) distances[full_at_idx].append(dists[idx, within_mask[idx,:]]) + dists = None + per_chain_pos.append(pos) per_chain_indices.append(atom_indices) @@ -1346,7 +1371,10 @@ class lDDTScorer: # process potentially interacting chains for pair in chain_pairs: - dists = cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]]) + if per_chain_pos[pair[0]].shape[0] > 20000 or per_chain_pos[pair[1]].shape[0] > 20000: + dists = blockwise_cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]]) + else: + dists = cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]]) within = dists <= inclusion_radius # process pair[0] @@ -1385,6 +1413,8 @@ class lDDTScorer: indices[at_idx].insert(insertion_idx, indices_to_insert) distances[at_idx].insert(insertion_idx, distances_to_insert) + dists = None + # concatenate distances from all processing steps for at_idx in range(n_atoms): if len(indices[at_idx]) > 0: @@ -1399,8 +1429,8 @@ class lDDTScorer: """Select subset of contacts only covering intra-chain contacts """ # init - ref_indices_sc = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)] - ref_distances_sc = [np.asarray([], dtype=np.float64) for idx in 
range(n_atoms)] + ref_indices_sc = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)] + ref_distances_sc = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)] n_chains = len(chain_start_indices) for ch_idx in range(n_chains): @@ -1423,8 +1453,8 @@ class lDDTScorer: """Select subset of contacts only covering inter-chain contacts """ # init - ref_indices_ic = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)] - ref_distances_ic = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)] + ref_indices_ic = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)] + ref_distances_ic = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)] n_chains = len(chain_start_indices) for ch_idx in range(n_chains): @@ -1446,8 +1476,8 @@ class lDDTScorer: """Transfer indices/distances of non-symmetric atoms and return """ - sym_ref_indices = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)] - sym_ref_distances = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)] + sym_ref_indices = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)] + sym_ref_distances = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)] for idx in symmetric_atoms: indices = list() diff --git a/modules/mol/alg/tests/test_lddt.py b/modules/mol/alg/tests/test_lddt.py index 11bce0503e3d2b2b0d829428f45852429cd32249..57b24252e6d8e4df90b868d6f9767e21041669fb 100644 --- a/modules/mol/alg/tests/test_lddt.py +++ b/modules/mol/alg/tests/test_lddt.py @@ -234,7 +234,7 @@ class TestlDDT(unittest.TestCase): # this value is just blindly copied in without checking whether it makes # any sense... 
it's sole purpose is to trigger the respective flag # in lDDT computation - self.assertEqual(lDDT, 0.6171511842396518) + self.assertAlmostEqual(lDDT, 0.6171511842396518, places=5) def test_drmsd(self): model = _LoadFile("7SGN_C_model.pdb") @@ -245,7 +245,7 @@ class TestlDDT(unittest.TestCase): # this value is just blindly copied in without checking whether it makes # any sense... it's sole purpose is to trigger DRMSD computation - self.assertEqual(drmsd, 1.895447711911706) + self.assertAlmostEqual(drmsd, 1.895447711911706, places=5) class TestlDDTBS(unittest.TestCase):