diff --git a/modules/mol/alg/pymod/ligand_scoring.py b/modules/mol/alg/pymod/ligand_scoring.py index 3971195f298adc9fcd81b005898d81b5ef1e81ab..847f548b8e16adb7e3d81145b85547de18885926 100644 --- a/modules/mol/alg/pymod/ligand_scoring.py +++ b/modules/mol/alg/pymod/ligand_scoring.py @@ -308,15 +308,32 @@ class LigandScorer: pass -def ResidueToGraph(residue): +def ResidueToGraph(residue, by_atom_index=False): """Return a NetworkX graph representation of the residue. - Nodes are labeled with the Atom's :attr:`~ost.mol.AtomHandle.element`""" + :param residue: the residue from which to derive the graph + :type residue: :class:`ost.mol.ResidueHandle` or + :class:`ost.mol.ResidueView` + :param by_atom_index: Set this parameter to True if you need the nodes to + be labeled by atom index (within the residue). + Otherwise, if False, the nodes will be labeled by + atom names. + :type by_atom_index: :class:`bool` + + Nodes are labeled with the Atom's :attr:`~ost.mol.AtomHandle.element` + """ nxg = networkx.Graph() nxg.add_nodes_from([a.name for a in residue.atoms], element=[a.element for a in residue.atoms]) # This will list all edges twice - once for every atom of the pair. - # But as of NetworkX 3.0 adding the same edge twice has no effect so we're good. + # But as of NetworkX 3.0 adding the same edge twice has no effect, so we're good. nxg.add_edges_from([(b.first.name, b.second.name) for a in residue.atoms for b in a.GetBondList()]) + + if by_atom_index: + nxg = networkx.relabel_nodes(nxg, + {a: b for a, b in zip( + [a.name for a in residue.atoms], + range(len(residue.atoms)))}, + True) return nxg @@ -398,25 +415,14 @@ def _ComputeSymmetries(model_ligand, target_ligand, substructure_match=False, to refer to atom index (within the residue). Otherwise, if False, the symmetries refer to atom names. + :type by_atom_index: :class:`bool` :raises: :class:`NoSymmetryError` when no symmetry can be found. """ # Get the Graphs of the ligands - model_graph = ResidueToGraph(model_ligand) - target_graph = ResidueToGraph(target_ligand) - - if by_atom_index: - networkx.relabel_nodes(model_graph, - {a: b for a, b in zip( - [a.name for a in model_ligand.atoms], - range(len(model_ligand.atoms)))}, - False) - networkx.relabel_nodes(target_graph, - {a: b for a, b in zip( - [a.name for a in target_ligand.atoms], - range(len(target_ligand.atoms)))}, - False) + model_graph = ResidueToGraph(model_ligand, by_atom_index=by_atom_index) + target_graph = ResidueToGraph(target_ligand, by_atom_index=by_atom_index) # Note the argument order (model, target) which differs from spyrmsd. # This is because a subgraph of model is isomorphic to target - but not the opposite diff --git a/modules/mol/alg/tests/test_ligand_scoring.py b/modules/mol/alg/tests/test_ligand_scoring.py index 8b3121d9753619c33b275e83a11c96a5c340be39..d7cf4ef07a0ed0fe7d1eed5ee4d74b9fd9518081 100644 --- a/modules/mol/alg/tests/test_ligand_scoring.py +++ b/modules/mol/alg/tests/test_ligand_scoring.py @@ -134,6 +134,12 @@ class TestLigandScoring(unittest.TestCase): # Check an arbitrary node assert [a for a in graph.adj["14"].keys()] == ["13", "29"] + graph = ResidueToGraph(mdl_lig.residues[0], by_atom_index=True) + assert len(graph.edges) == 34 + assert len(graph.nodes) == 32 + # Check an arbitrary node + assert [a for a in graph.adj[13].keys()] == [12, 28] + def test__ComputeSymmetries(self): """Test that _ComputeSymmetries works. """