Skip to content
Snippets Groups Projects
Unverified Commit 88397abc authored by Xavier Robin's avatar Xavier Robin
Browse files

cleanup: SCHWED-5783 code cleanup

parent e58a2afa
No related branches found
No related tags found
No related merge requests found
import os
import warnings
import numpy as np
......@@ -75,7 +74,7 @@ class LigandScorer:
:type model_ligands: :class:`list`
:param target_ligands: Target ligands, as a list of
:class:`~ost.mol.ResidueHandle` belonging to the target
entity. Can be instanciated either a :class:list of
entity. Can be instantiated either a :class:list of
:class:`~ost.mol.ResidueHandle`/:class:`ost.mol.ResidueView`
or of :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
containing a single residue each. If `None`, ligands will be
......@@ -238,7 +237,7 @@ class LigandScorer:
already_exists = new_entity.FindResidue(handle.chain.name,
handle.number).IsValid()
if already_exists:
msg = "A residue number %s already exists in chain %s" %(
msg = "A residue number %s already exists in chain %s" % (
handle.number, handle.chain.name)
raise RuntimeError(msg)
......@@ -348,7 +347,7 @@ class LigandScorer:
for resnum, old_res in enumerate(residues, 1):
seen_res_qn.append(old_res.qualified_name)
new_res = ed.AppendResidue(bs_chain, old_res.handle,
deep=True)
deep=True)
ed.SetResidueNumber(new_res, mol.ResNum(resnum))
# Add extra residues at the end.
......@@ -363,21 +362,20 @@ class LigandScorer:
# Add the ligand in chain _
ligand_chain = ed.InsertChain("_")
ligand_res = ed.AppendResidue(ligand_chain, ligand,
deep=True)
deep=True)
ed.RenameResidue(ligand_res, "LIG")
ed.SetResidueNumber(ligand_res, mol.ResNum(1))
ed.UpdateICS()
return bs_ent
def _compute_scores(self):
""""""
# Create the matrix
self._rmsd_full_matrix = np.empty((len(self.target_ligands),
len(self.model_ligands)), dtype=dict)
self._lddt_pli_full_matrix = np.empty((len(self.target_ligands),
len(self.model_ligands)), dtype=dict)
self._rmsd_full_matrix = np.empty(
(len(self.target_ligands), len(self.model_ligands)), dtype=dict)
self._lddt_pli_full_matrix = np.empty(
(len(self.target_ligands), len(self.model_ligands)), dtype=dict)
for target_i, target_ligand in enumerate(self.target_ligands):
LogDebug("Compute RMSD for target ligand %s" % target_ligand)
......@@ -473,17 +471,19 @@ class LigandScorer:
mdl_bs_ent, chain_mapping={"A": "A", "_": "_"},
no_intrachain=True,
return_dist_test=True,
check_resnames = self.check_resnames)
check_resnames=self.check_resnames)
# Save results?
best_lddt = self._lddt_pli_full_matrix[target_i, model_i]["lddt_pli"]
best_lddt = self._lddt_pli_full_matrix[
target_i, model_i]["lddt_pli"]
if global_lddt > best_lddt:
self._lddt_pli_full_matrix[target_i, model_i].update({
"lddt_pli": global_lddt,
"lddt_pli_n_contacts": lddt_tot,
})
def _find_ligand_assignment(self, mat1, mat2):
@staticmethod
def _find_ligand_assignment(mat1, mat2):
""" Find the ligand assignment based on mat1
Both mat1 and mat2 should "look" like RMSD - ie be between inf (bad)
......@@ -535,9 +535,12 @@ class LigandScorer:
for assignment in assignments:
trg_idx, mdl_idx = assignment
mdl_lig_qname = self.model_ligands[mdl_idx].qualified_name
self._rmsd[mdl_lig_qname] = self._rmsd_full_matrix[trg_idx, mdl_idx]["rmsd"]
self._rmsd_assignment[mdl_lig_qname] = self._rmsd_full_matrix[trg_idx, mdl_idx]["target_ligand"].qualified_name
self._rmsd_details[mdl_lig_qname] = self._rmsd_full_matrix[trg_idx, mdl_idx]
self._rmsd[mdl_lig_qname] = self._rmsd_full_matrix[
trg_idx, mdl_idx]["rmsd"]
self._rmsd_assignment[mdl_lig_qname] = self._rmsd_full_matrix[
trg_idx, mdl_idx]["target_ligand"].qualified_name
self._rmsd_details[mdl_lig_qname] = self._rmsd_full_matrix[
trg_idx, mdl_idx]
def _assign_ligands_lddt_pli(self):
""" Assign ligands based on lDDT-PLI.
......@@ -554,9 +557,12 @@ class LigandScorer:
for assignment in assignments:
trg_idx, mdl_idx = assignment
mdl_lig_qname = self.model_ligands[mdl_idx].qualified_name
self._lddt_pli[mdl_lig_qname] = self._lddt_pli_full_matrix[trg_idx, mdl_idx]["lddt_pli"]
self._lddt_pli_assignment[mdl_lig_qname] = self._lddt_pli_full_matrix[trg_idx, mdl_idx]["target_ligand"].qualified_name
self._lddt_pli_details[mdl_lig_qname] = self._lddt_pli_full_matrix[trg_idx, mdl_idx]
self._lddt_pli[mdl_lig_qname] = self._lddt_pli_full_matrix[
trg_idx, mdl_idx]["lddt_pli"]
self._lddt_pli_assignment[mdl_lig_qname] = self._lddt_pli_full_matrix[
trg_idx, mdl_idx]["target_ligand"].qualified_name
self._lddt_pli_details[mdl_lig_qname] = self._lddt_pli_full_matrix[
trg_idx, mdl_idx]
@property
def rmsd_matrix(self):
......@@ -574,7 +580,8 @@ class LigandScorer:
self._rmsd_matrix = np.full(shape, np.inf)
for i, j in np.ndindex(shape):
if self._rmsd_full_matrix[i, j] is not None:
self._rmsd_matrix[i, j] = self._rmsd_full_matrix[i, j]["rmsd"]
self._rmsd_matrix[i, j] = self._rmsd_full_matrix[
i, j]["rmsd"]
return self._rmsd_matrix
@property
......@@ -593,7 +600,8 @@ class LigandScorer:
self._lddt_pli_matrix = np.zeros(shape)
for i, j in np.ndindex(shape):
if self._lddt_pli_full_matrix[i, j] is not None:
self._lddt_pli_matrix[i, j] = self._lddt_pli_full_matrix[i, j]["lddt_pli"]
self._lddt_pli_matrix[i, j] = self._lddt_pli_full_matrix[
i, j]["lddt_pli"]
return self._lddt_pli_matrix
@property
......@@ -712,10 +720,13 @@ def ResidueToGraph(residue, by_atom_index=False):
Nodes are labeled with the Atom's :attr:`~ost.mol.AtomHandle.element`.
"""
nxg = networkx.Graph()
nxg.add_nodes_from([a.name for a in residue.atoms], element=[a.element for a in residue.atoms])
nxg.add_nodes_from([a.name for a in residue.atoms], element=[
a.element for a in residue.atoms])
# This will list all edges twice - once for every atom of the pair.
# But as of NetworkX 3.0 adding the same edge twice has no effect, so we're good.
nxg.add_edges_from([(b.first.name, b.second.name) for a in residue.atoms for b in a.GetBondList()])
nxg.add_edges_from([(
b.first.name,
b.second.name) for a in residue.atoms for b in a.GetBondList()])
if by_atom_index:
nxg = networkx.relabel_nodes(nxg,
......@@ -727,7 +738,7 @@ def ResidueToGraph(residue, by_atom_index=False):
def SCRMSD(model_ligand, target_ligand, transformation=geom.Mat4(),
substructure_match=False):
substructure_match=False):
"""Calculate symmetry-corrected RMSD.
Binding site superposition must be computed separately and passed as
......@@ -817,23 +828,23 @@ def _ComputeSymmetries(model_ligand, target_ligand, substructure_match=False,
# This is because a subgraph of model is isomorphic to target - but not the opposite
# as we only consider partial ligands in the reference.
# Make sure to generate the symmetries correctly in the end
GM = networkx.algorithms.isomorphism.GraphMatcher(
gm = networkx.algorithms.isomorphism.GraphMatcher(
model_graph, target_graph, node_match=lambda x, y:
x["element"] == y["element"])
if GM.is_isomorphic():
if gm.is_isomorphic():
symmetries = [
(list(isomorphism.values()), list(isomorphism.keys()))
for isomorphism in GM.isomorphisms_iter()]
for isomorphism in gm.isomorphisms_iter()]
assert len(symmetries) > 0
LogDebug("Found %s isomorphic mappings (symmetries)" % len(symmetries))
elif GM.subgraph_is_isomorphic() and substructure_match:
elif gm.subgraph_is_isomorphic() and substructure_match:
symmetries = [(list(isomorphism.values()), list(isomorphism.keys())) for isomorphism in
GM.subgraph_isomorphisms_iter()]
gm.subgraph_isomorphisms_iter()]
assert len(symmetries) > 0
# Assert that all the atoms in the target are part of the substructure
assert len(symmetries[0][0]) == len(target_ligand.atoms)
LogDebug("Found %s subgraph isomorphisms (symmetries)" % len(symmetries))
elif GM.subgraph_is_isomorphic():
elif gm.subgraph_is_isomorphic():
LogDebug("Found subgraph isomorphisms (symmetries), but"
" ignoring because substructure_match=False")
raise NoSymmetryError("No symmetry between %s and %s" % (
......@@ -851,4 +862,5 @@ class NoSymmetryError(Exception):
"""
pass
__all__ = ["LigandScorer", "ResidueToGraph", "SCRMSD", "NoSymmetryError"]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment