diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py index 48e40c264089f37905fa8c7b61ed9ce355819998..adc0a8901971657d39f8eb3f692f82cf615efaa8 100644 --- a/modules/mol/alg/pymod/lddt.py +++ b/modules/mol/alg/pymod/lddt.py @@ -548,85 +548,13 @@ class lDDTScorer: f"not exist. Model has chains: " f"{[c.GetName() for c in model.chains]}") - # initialize positions with values far in nirvana. If a position is not - # set, it should be far away from any position in model. - max_pos = model.bounds.GetMax() - max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2])) - max_coordinate += 42 * max(thresholds) - pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate - - # for each scored residue in model a list of indices describing the - # atoms from the reference that should be there - res_ref_atom_indices = list() - - # for each scored residue in model a list of indices of atoms that are - # actually there - res_atom_indices = list() - - # and the respective hash codes - # this is required if add_mdl_contacts is set to True - res_atom_hashes = list() - - # indices of the scored residues - res_indices = list() - - # Will contain one element per symmetry group - symmetries = list() - - current_model_res_idx = -1 - for ch in model.chains: - model_ch_name = ch.GetName() - if model_ch_name not in chain_mapping: - current_model_res_idx += len(ch.residues) - continue # additional model chain which is not mapped - target_ch_name = chain_mapping[model_ch_name] - - rnums = self._GetChainRNums(ch, residue_mapping, model_ch_name, - target_ch_name) - - for r, rnum in zip(ch.residues, rnums): - current_model_res_idx += 1 - res_mapper_key = (target_ch_name, rnum) - if res_mapper_key not in self.res_mapper: - continue - r_idx = self.res_mapper[res_mapper_key] - if check_resnames and r.name != self.compound_names[r_idx]: - raise RuntimeError( - f"Residue name mismatch for {r}, " - f" expect {self.compound_names[r_idx]}" - ) - res_start_idx = self.res_start_indices[r_idx] - rname = self.compound_names[r_idx] - anames = self.compound_anames[rname] - atoms = [r.FindAtom(aname) for aname in anames] - res_ref_atom_indices.append( - list(range(res_start_idx, res_start_idx + len(anames))) - ) - res_atom_indices.append(list()) - res_atom_hashes.append(list()) - res_indices.append(current_model_res_idx) - for a_idx, a in enumerate(atoms): - if a.IsValid(): - p = a.GetPos() - pos[res_start_idx + a_idx][0] = p[0] - pos[res_start_idx + a_idx][1] = p[1] - pos[res_start_idx + a_idx][2] = p[2] - res_atom_indices[-1].append(res_start_idx + a_idx) - res_atom_hashes[-1].append(a.handle.GetHashCode()) - if rname in self.compound_symmetric_atoms: - sym_indices = list() - for sym_tuple in self.compound_symmetric_atoms[rname]: - a_one = atoms[sym_tuple[0]] - a_two = atoms[sym_tuple[1]] - if a_one.IsValid() and a_two.IsValid(): - sym_indices.append( - ( - res_start_idx + sym_tuple[0], - res_start_idx + sym_tuple[1], - ) - ) - if len(sym_indices) > 0: - symmetries.append(sym_indices) + # data objects defining model data - see _ProcessModel for rough + # description + pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \ + res_indices, symmetries = self._ProcessModel(model, chain_mapping, + residue_mapping = residue_mapping, + thresholds = thresholds, + check_resnames = check_resnames) if no_interchain and no_intrachain: raise RuntimeError("no_interchain and no_intrachain flags are " @@ -733,6 +661,95 @@ class lDDTScorer: else: return self._GetNExp(list(range(s, e)), self.ref_indices) + def _ProcessModel(self, model, chain_mapping, residue_mapping = None, + thresholds = [0.5, 1.0, 2.0, 4.0], + check_resnames = True): + """ Helper that generates data structures from model + """ + + # initialize positions with values far in nirvana. If a position is not + # set, it should be far away from any position in model. + max_pos = model.bounds.GetMax() + max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2])) + max_coordinate += 42 * max(thresholds) + pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate + + # for each scored residue in model a list of indices describing the + # atoms from the reference that should be there + res_ref_atom_indices = list() + + # for each scored residue in model a list of indices of atoms that are + # actually there + res_atom_indices = list() + + # and the respective hash codes + # this is required if add_mdl_contacts is set to True + res_atom_hashes = list() + + # indices of the scored residues + res_indices = list() + + # Will contain one element per symmetry group + symmetries = list() + + current_model_res_idx = -1 + for ch in model.chains: + model_ch_name = ch.GetName() + if model_ch_name not in chain_mapping: + current_model_res_idx += len(ch.residues) + continue # additional model chain which is not mapped + target_ch_name = chain_mapping[model_ch_name] + + rnums = self._GetChainRNums(ch, residue_mapping, model_ch_name, + target_ch_name) + + for r, rnum in zip(ch.residues, rnums): + current_model_res_idx += 1 + res_mapper_key = (target_ch_name, rnum) + if res_mapper_key not in self.res_mapper: + continue + r_idx = self.res_mapper[res_mapper_key] + if check_resnames and r.name != self.compound_names[r_idx]: + raise RuntimeError( + f"Residue name mismatch for {r}, " + f" expect {self.compound_names[r_idx]}" + ) + res_start_idx = self.res_start_indices[r_idx] + rname = self.compound_names[r_idx] + anames = self.compound_anames[rname] + atoms = [r.FindAtom(aname) for aname in anames] + res_ref_atom_indices.append( + list(range(res_start_idx, res_start_idx + len(anames))) + ) + res_atom_indices.append(list()) + res_atom_hashes.append(list()) + res_indices.append(current_model_res_idx) + for a_idx, a in enumerate(atoms): + if a.IsValid(): + p = a.GetPos() + pos[res_start_idx + a_idx][0] = p[0] + pos[res_start_idx + a_idx][1] = p[1] + pos[res_start_idx + a_idx][2] = p[2] + res_atom_indices[-1].append(res_start_idx + a_idx) + res_atom_hashes[-1].append(a.handle.GetHashCode()) + if rname in self.compound_symmetric_atoms: + sym_indices = list() + for sym_tuple in self.compound_symmetric_atoms[rname]: + a_one = atoms[sym_tuple[0]] + a_two = atoms[sym_tuple[1]] + if a_one.IsValid() and a_two.IsValid(): + sym_indices.append( + ( + res_start_idx + sym_tuple[0], + res_start_idx + sym_tuple[1], + ) + ) + if len(sym_indices) > 0: + symmetries.append(sym_indices) + + return (pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, + res_indices, symmetries) + def _GetExtraModelChainPenalty(self, model, chain_mapping): """Counts n distances in extra model chains to be added as penalty