From 95c5150922aca4845f346acf2a07cc30521b7739 Mon Sep 17 00:00:00 2001 From: Gabriel Studer <gabriel.studer@unibas.ch> Date: Mon, 12 May 2025 16:32:34 +0200 Subject: [PATCH] reduce lDDT memory peaks lDDT computes all vs all pairwise distances within individual chains, as well as between pairs of chains. A chain of 5000 residues is already considered long and leads to around 40000 atom positions. scipy cdist strictly operates on doubles (64bit floats). In case of 40000 positions, it needs around 12GB to hold all pairwise distances. The new implementation uses a fallback for every cdist call above 20000 positions. cdist is called blockwise and stores the 64 bit results for each block in a 32bit result matrix, halving the memory requirements. --- modules/mol/alg/pymod/lddt.py | 58 ++++++++++++++++++++++-------- modules/mol/alg/tests/test_lddt.py | 4 +-- 2 files changed, 46 insertions(+), 16 deletions(-) diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py index 0cbefb65..f2b3540d 100644 --- a/modules/mol/alg/pymod/lddt.py +++ b/modules/mol/alg/pymod/lddt.py @@ -16,6 +16,26 @@ except: x2 = x2.reshape(-1, 1) return np.sqrt(x2 - 2*xy + y2) # (m, n) +def blockwise_cdist(A, B, block_size=1000): + """ Memory efficient cdist implementation that performs blockwise operations + + scipy cdist uses 64 bit floats (double) which can scratch at the upper + memory end for most machines when number of positions become larger. + E.g. ~4000 residues might for example have 35000 atom positions. That's + Almost 10GB to hold all pairwise distances in 64bit floats. This function + calls cdist blockwise and stores the results in a 32bit float matrix. + + This function is adapted from chatgpt output + """ + A = A.astype(np.float32) + B = B.astype(np.float32) + M, N = A.shape[0], B.shape[0] + D = np.empty((M, N), dtype=np.float32) # Output in float32 to save memory + for i in range(0, M, block_size): + A_block = A[i:i+block_size] + D[i:i+block_size, :] = cdist(A_block, B).astype(np.float32) + return D + class CustomCompound: """ Defines atoms for custom compounds @@ -578,7 +598,6 @@ class lDDTScorer: raise RuntimeError("no_interchain and no_intrachain flags are " "mutually exclusive") - sym_ref_indices = None sym_ref_distances = None ref_indices = None @@ -1279,8 +1298,8 @@ class lDDTScorer: - process potentially interacting chain pairs - concatenate distances from all processing steps """ - ref_indices = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)] - ref_distances = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)] + ref_indices = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)] + ref_distances = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)] indices = [list() for _ in range(n_atoms)] distances = [list() for _ in range(n_atoms)] @@ -1300,7 +1319,7 @@ class lDDTScorer: hash_code = a.handle.GetHashCode() if hash_code in atom_index_mapping: p = a.GetPos() - pos_list.append(np.asarray([p[0], p[1], p[2]])) + pos_list.append(np.asarray([p[0], p[1], p[2]], dtype=np.float32)) atom_indices.append(atom_index_mapping[hash_code]) n_valid_atoms += 1 mask_start.extend([r_start_idx] * n_valid_atoms) @@ -1311,9 +1330,13 @@ class lDDTScorer: # nothing to do... continue - pos = np.vstack(pos_list) - atom_indices = np.asarray(atom_indices) - dists = cdist(pos, pos) + pos = np.vstack(pos_list, dtype=np.float32) + atom_indices = np.asarray(atom_indices, dtype=np.int32) + + if atom_indices.shape[0] > 20000: + dists = blockwise_cdist(pos, pos) + else: + dists = cdist(pos, pos) # apply masks far_away = 2 * inclusion_radius @@ -1329,6 +1352,8 @@ class lDDTScorer: indices[full_at_idx].append(indices_to_append) distances[full_at_idx].append(dists[idx, within_mask[idx,:]]) + dists = None + per_chain_pos.append(pos) per_chain_indices.append(atom_indices) @@ -1346,7 +1371,10 @@ class lDDTScorer: # process potentially interacting chains for pair in chain_pairs: - dists = cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]]) + if per_chain_pos[pair[0]].shape[0] > 20000 or per_chain_pos[pair[1]].shape[0] > 20000: + dists = blockwise_cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]]) + else: + dists = cdist(per_chain_pos[pair[0]], per_chain_pos[pair[1]]) within = dists <= inclusion_radius # process pair[0] @@ -1385,6 +1413,8 @@ class lDDTScorer: indices[at_idx].insert(insertion_idx, indices_to_insert) distances[at_idx].insert(insertion_idx, distances_to_insert) + dists = None + # concatenate distances from all processing steps for at_idx in range(n_atoms): if len(indices[at_idx]) > 0: @@ -1399,8 +1429,8 @@ class lDDTScorer: """Select subset of contacts only covering intra-chain contacts """ # init - ref_indices_sc = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)] - ref_distances_sc = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)] + ref_indices_sc = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)] + ref_distances_sc = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)] n_chains = len(chain_start_indices) for ch_idx in range(n_chains): @@ -1423,8 +1453,8 @@ class lDDTScorer: """Select subset of contacts only covering inter-chain contacts """ # init - ref_indices_ic = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)] - ref_distances_ic = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)] + ref_indices_ic = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)] + ref_distances_ic = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)] n_chains = len(chain_start_indices) for ch_idx in range(n_chains): @@ -1446,8 +1476,8 @@ class lDDTScorer: """Transfer indices/distances of non-symmetric atoms and return """ - sym_ref_indices = [np.asarray([], dtype=np.int64) for idx in range(n_atoms)] - sym_ref_distances = [np.asarray([], dtype=np.float64) for idx in range(n_atoms)] + sym_ref_indices = [np.asarray([], dtype=np.int32) for idx in range(n_atoms)] + sym_ref_distances = [np.asarray([], dtype=np.float32) for idx in range(n_atoms)] for idx in symmetric_atoms: indices = list() diff --git a/modules/mol/alg/tests/test_lddt.py b/modules/mol/alg/tests/test_lddt.py index 11bce050..57b24252 100644 --- a/modules/mol/alg/tests/test_lddt.py +++ b/modules/mol/alg/tests/test_lddt.py @@ -234,7 +234,7 @@ class TestlDDT(unittest.TestCase): # this value is just blindly copied in without checking whether it makes # any sense... it's sole purpose is to trigger the respective flag # in lDDT computation - self.assertEqual(lDDT, 0.6171511842396518) + self.assertAlmostEqual(lDDT, 0.6171511842396518, places=5) def test_drmsd(self): model = _LoadFile("7SGN_C_model.pdb") @@ -245,7 +245,7 @@ class TestlDDT(unittest.TestCase): # this value is just blindly copied in without checking whether it makes # any sense... it's sole purpose is to trigger DRMSD computation - self.assertEqual(drmsd, 1.895447711911706) + self.assertAlmostEqual(drmsd, 1.895447711911706, places=5) class TestlDDTBS(unittest.TestCase): -- GitLab