Skip to content
Snippets Groups Projects
Commit f541d5e5 authored by Studer Gabriel's avatar Studer Gabriel
Browse files

lddt: make main lDDT() function more slim

parent 8356d810
No related branches found
No related tags found
No related merge requests found
...@@ -548,85 +548,13 @@ class lDDTScorer: ...@@ -548,85 +548,13 @@ class lDDTScorer:
f"not exist. Model has chains: " f"not exist. Model has chains: "
f"{[c.GetName() for c in model.chains]}") f"{[c.GetName() for c in model.chains]}")
# initialize positions with values far in nirvana. If a position is not # data objects defining model data - see _ProcessModel for rough
# set, it should be far away from any position in model. # description
max_pos = model.bounds.GetMax() pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes, \
max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2])) res_indices, symmetries = self._ProcessModel(model, chain_mapping,
max_coordinate += 42 * max(thresholds) residue_mapping = residue_mapping,
pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate thresholds = thresholds,
check_resnames = check_resnames)
# for each scored residue in model a list of indices describing the
# atoms from the reference that should be there
res_ref_atom_indices = list()
# for each scored residue in model a list of indices of atoms that are
# actually there
res_atom_indices = list()
# and the respective hash codes
# this is required if add_mdl_contacts is set to True
res_atom_hashes = list()
# indices of the scored residues
res_indices = list()
# Will contain one element per symmetry group
symmetries = list()
current_model_res_idx = -1
for ch in model.chains:
model_ch_name = ch.GetName()
if model_ch_name not in chain_mapping:
current_model_res_idx += len(ch.residues)
continue # additional model chain which is not mapped
target_ch_name = chain_mapping[model_ch_name]
rnums = self._GetChainRNums(ch, residue_mapping, model_ch_name,
target_ch_name)
for r, rnum in zip(ch.residues, rnums):
current_model_res_idx += 1
res_mapper_key = (target_ch_name, rnum)
if res_mapper_key not in self.res_mapper:
continue
r_idx = self.res_mapper[res_mapper_key]
if check_resnames and r.name != self.compound_names[r_idx]:
raise RuntimeError(
f"Residue name mismatch for {r}, "
f" expect {self.compound_names[r_idx]}"
)
res_start_idx = self.res_start_indices[r_idx]
rname = self.compound_names[r_idx]
anames = self.compound_anames[rname]
atoms = [r.FindAtom(aname) for aname in anames]
res_ref_atom_indices.append(
list(range(res_start_idx, res_start_idx + len(anames)))
)
res_atom_indices.append(list())
res_atom_hashes.append(list())
res_indices.append(current_model_res_idx)
for a_idx, a in enumerate(atoms):
if a.IsValid():
p = a.GetPos()
pos[res_start_idx + a_idx][0] = p[0]
pos[res_start_idx + a_idx][1] = p[1]
pos[res_start_idx + a_idx][2] = p[2]
res_atom_indices[-1].append(res_start_idx + a_idx)
res_atom_hashes[-1].append(a.handle.GetHashCode())
if rname in self.compound_symmetric_atoms:
sym_indices = list()
for sym_tuple in self.compound_symmetric_atoms[rname]:
a_one = atoms[sym_tuple[0]]
a_two = atoms[sym_tuple[1]]
if a_one.IsValid() and a_two.IsValid():
sym_indices.append(
(
res_start_idx + sym_tuple[0],
res_start_idx + sym_tuple[1],
)
)
if len(sym_indices) > 0:
symmetries.append(sym_indices)
if no_interchain and no_intrachain: if no_interchain and no_intrachain:
raise RuntimeError("no_interchain and no_intrachain flags are " raise RuntimeError("no_interchain and no_intrachain flags are "
...@@ -733,6 +661,95 @@ class lDDTScorer: ...@@ -733,6 +661,95 @@ class lDDTScorer:
else: else:
return self._GetNExp(list(range(s, e)), self.ref_indices) return self._GetNExp(list(range(s, e)), self.ref_indices)
def _ProcessModel(self, model, chain_mapping, residue_mapping = None,
thresholds = [0.5, 1.0, 2.0, 4.0],
check_resnames = True):
""" Helper that generates data structures from model
"""
# initialize positions with values far in nirvana. If a position is not
# set, it should be far away from any position in model.
max_pos = model.bounds.GetMax()
max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2]))
max_coordinate += 42 * max(thresholds)
pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate
# for each scored residue in model a list of indices describing the
# atoms from the reference that should be there
res_ref_atom_indices = list()
# for each scored residue in model a list of indices of atoms that are
# actually there
res_atom_indices = list()
# and the respective hash codes
# this is required if add_mdl_contacts is set to True
res_atom_hashes = list()
# indices of the scored residues
res_indices = list()
# Will contain one element per symmetry group
symmetries = list()
current_model_res_idx = -1
for ch in model.chains:
model_ch_name = ch.GetName()
if model_ch_name not in chain_mapping:
current_model_res_idx += len(ch.residues)
continue # additional model chain which is not mapped
target_ch_name = chain_mapping[model_ch_name]
rnums = self._GetChainRNums(ch, residue_mapping, model_ch_name,
target_ch_name)
for r, rnum in zip(ch.residues, rnums):
current_model_res_idx += 1
res_mapper_key = (target_ch_name, rnum)
if res_mapper_key not in self.res_mapper:
continue
r_idx = self.res_mapper[res_mapper_key]
if check_resnames and r.name != self.compound_names[r_idx]:
raise RuntimeError(
f"Residue name mismatch for {r}, "
f" expect {self.compound_names[r_idx]}"
)
res_start_idx = self.res_start_indices[r_idx]
rname = self.compound_names[r_idx]
anames = self.compound_anames[rname]
atoms = [r.FindAtom(aname) for aname in anames]
res_ref_atom_indices.append(
list(range(res_start_idx, res_start_idx + len(anames)))
)
res_atom_indices.append(list())
res_atom_hashes.append(list())
res_indices.append(current_model_res_idx)
for a_idx, a in enumerate(atoms):
if a.IsValid():
p = a.GetPos()
pos[res_start_idx + a_idx][0] = p[0]
pos[res_start_idx + a_idx][1] = p[1]
pos[res_start_idx + a_idx][2] = p[2]
res_atom_indices[-1].append(res_start_idx + a_idx)
res_atom_hashes[-1].append(a.handle.GetHashCode())
if rname in self.compound_symmetric_atoms:
sym_indices = list()
for sym_tuple in self.compound_symmetric_atoms[rname]:
a_one = atoms[sym_tuple[0]]
a_two = atoms[sym_tuple[1]]
if a_one.IsValid() and a_two.IsValid():
sym_indices.append(
(
res_start_idx + sym_tuple[0],
res_start_idx + sym_tuple[1],
)
)
if len(sym_indices) > 0:
symmetries.append(sym_indices)
return (pos, res_ref_atom_indices, res_atom_indices, res_atom_hashes,
res_indices, symmetries)
def _GetExtraModelChainPenalty(self, model, chain_mapping): def _GetExtraModelChainPenalty(self, model, chain_mapping):
"""Counts n distances in extra model chains to be added as penalty """Counts n distances in extra model chains to be added as penalty
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment