From 3eb3b00b3b210fa8f08cb806db71763d043d9796 Mon Sep 17 00:00:00 2001
From: Xavier Robin <xavier.robin@unibas.ch>
Date: Wed, 16 Aug 2023 14:41:46 +0200
Subject: [PATCH] feat: FindCompounds for SMILES, InChI, formula

This replaces the awkward FindCompound(... by=) which would return
a single arbitrary compound.
---
 CHANGELOG.txt                             |   4 +-
 modules/conop/doc/compoundlib.rst         |  43 ++++--
 modules/conop/pymod/export_compound.cc    |  25 +++-
 modules/conop/src/compound_lib.cc         | 172 ++++++++++++++--------
 modules/conop/src/compound_lib.hh         |  14 +-
 modules/conop/src/compound_lib_base.hh    |   3 +-
 modules/conop/src/minimal_compound_lib.cc |   9 +-
 modules/conop/src/minimal_compound_lib.hh |   3 +-
 modules/conop/tests/test_compound.py      |  28 ++--
 9 files changed, 195 insertions(+), 106 deletions(-)

diff --git a/CHANGELOG.txt b/CHANGELOG.txt
index 6154c0643..bc9e71819 100644
--- a/CHANGELOG.txt
+++ b/CHANGELOG.txt
@@ -12,8 +12,8 @@ Changes in Release 2.6.0
    read (SMILES will be empty, charges set to 0, no information about the
    obsolete status of compounds, with warnings). Older files are no
    longer supported.
- * FindCompound now takes an optional 'by' argument to find compounds by SMILES
-   string, InChI code or InChI key.
+ * New FindCompounds method of CompoundLib to query compounds by SMILES
+   string, InChI code, InChI key or formula.
  * The compound library now reads the "InChI=" part of InChI codes.
  * Several bug fixes and improvements.
 
diff --git a/modules/conop/doc/compoundlib.rst b/modules/conop/doc/compoundlib.rst
index 75c1065af..ebe5c58a3 100644
--- a/modules/conop/doc/compoundlib.rst
+++ b/modules/conop/doc/compoundlib.rst
@@ -58,27 +58,38 @@ built with OST 1.5.0 or later can be loaded.
     
     Create a new compound library
     
-  .. method:: FindCompound(id, dialect='PDB', by="tlc")
-  
-    Lookup a compound. By default the compound is searched by its
-    three-letter-code, e.g ALA. This can be changed with the `by` argument.
-    The following keys are available: "tlc" (three-letter-code or compound ID),
-    "inchi_code", "inchi_key" and "smiles".
-
-    .. note::
-      Multiple compounds may share the same SMILES string or InChI code or key.
-      An arbitrary compound will be returned, with a preference for a
-      non-obsolete one.
-
-    If no compound with that name exists, the function returns None.
+  .. method:: FindCompound(id, dialect='PDB')
 
-    Compounds are cached after they have been loaded with FindCompound.
-    To delete the compound cache, use
+    Lookup compound by its three-letter-code, e.g ALA. If no compound with that
+    name exists, the function returns None. Compounds are cached after they
+    have been loaded with FindCompound. To delete the compound cache, use
     :meth:`ClearCache`.
     
     :returns: The found compound
     :rtype: :class:`Compound`
-  
+
+  .. method:: FindCompounds(query, by, dialect='PDB')
+
+    Lookup one or more compound by SMILES string, InChI code, InChI key or
+    formula.
+
+    The compound library is queried for exact matches. Many SMILES strings
+    can represent the same compound, so this function is only useful for SMILES
+    strings coming from the PDB. This is also the case for InChI codes,
+    although to a lesser extent.
+
+    :param query: the string to lookup.
+    :type query: :class:`string`
+    :param key: the field into which to look up the query. One of: "smiles",
+      "inchi_code", "inchi_key" or "formula".
+    :type key: :class:`string`
+    :param dialect: the dialect to select for (typically "PDB", or "CHARMM" if
+      your compound library was built with charmm support).
+    :type dialect: :class:`string`
+    :returns: A list of found compounds, or an empty list if no compound was
+      found.
+    :rtype: :class:`list` or :class:`Compound`
+
   .. method:: Copy(dst_filename)
   
     Copy database to dst_filename. The new library will be an exact copy of the 
diff --git a/modules/conop/pymod/export_compound.cc b/modules/conop/pymod/export_compound.cc
index 465dc7c63..d193801bb 100644
--- a/modules/conop/pymod/export_compound.cc
+++ b/modules/conop/pymod/export_compound.cc
@@ -70,10 +70,25 @@ char get_chemtype(CompoundPtr compound)
 }
 
 CompoundPtr find_compound(CompoundLibPtr comp_lib, 
-                          const String& id, const String& dialect,
-                          const String& by="tlc")
+                          const String& id, const String& dialect)
 {
-  return comp_lib->FindCompound(id, tr_dialect(dialect), by);
+  return comp_lib->FindCompound(id, tr_dialect(dialect));
+}
+
+boost::python::list find_compounds(CompoundLibPtr comp_lib,
+                                        const String& query,
+                                        const String& by,
+                                        const String& dialect)
+{
+  CompoundPtrList ptr_list = comp_lib->FindCompounds(query, by, tr_dialect(dialect));
+  // We can't return ptr_list directly - the list was full of non working
+  // compounds for no obvious reason. So we convert it to a boost python list
+  // of Compounds.
+  boost::python::list l;
+  for(auto it = ptr_list.begin(); it != ptr_list.end(); ++it) {
+    l.append(*it);
+  }
+  return l;
 }
 
 bool is_residue_complete(CompoundLibPtr comp_lib,
@@ -156,7 +171,9 @@ void export_Compound() {
   class_<CompoundLib>("CompoundLib", no_init)
     .def("Load", &CompoundLib::Load, arg("readonly")=true).staticmethod("Load")
     .def("FindCompound", &find_compound, 
-         (arg("id"), arg("dialect")="PDB", arg("by")="tlc"))
+         (arg("id"), arg("dialect")="PDB"))
+    .def("FindCompounds", &find_compounds,
+         (arg("query"), arg("by"), arg("dialect")="PDB"))
     .def("IsResidueComplete", &is_residue_complete, (arg("residue"), 
                                                      arg("check_hydrogens")=false,
                                                      arg("dialect")="PDB"))
diff --git a/modules/conop/src/compound_lib.cc b/modules/conop/src/compound_lib.cc
index 7d906fb3f..2ec4ea8b2 100644
--- a/modules/conop/src/compound_lib.cc
+++ b/modules/conop/src/compound_lib.cc
@@ -524,12 +524,78 @@ void CompoundLib::LoadBondsFromDB(CompoundPtr comp, int pk) const {
   sqlite3_finalize(stmt);
 }
 
-CompoundPtr CompoundLib::FindCompound(const String& id, 
-                                      Compound::Dialect dialect,
-                                      const String& by) const {
+String CompoundLib::BuildFindCompoundQuery(const String& id,
+                                           Compound::Dialect dialect,
+                                           const String& by) const {
+
+  // Build the query
+  String query="SELECT id, tlc, olc, chem_class, dialect, formula, chem_type, name, inchi_code, inchi_key";
+  if(smiles_available_) {
+    query+=", smiles";
+  }
+  if(obsolete_available_) {
+    query+=", obsolete, replaced_by";
+  }
+  query+=" FROM chem_compounds"
+         " WHERE " + by + "=? AND dialect='"+String(1, char(dialect))+"'";
+
+  return query;
+}
+
+CompoundPtr CompoundLib::LoadCompoundFromDB(sqlite3_stmt* stmt) const {
+  int pk=sqlite3_column_int(stmt, 0);
+  const char* id=reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
+  CompoundPtr compound(new Compound(id));
+  compound->SetOneLetterCode((sqlite3_column_text(stmt, 2))[0]);
+  compound->SetChemClass(mol::ChemClass(sqlite3_column_text(stmt, 3)[0]));
+  compound->SetDialect(Compound::Dialect(sqlite3_column_text(stmt, 4)[0]));
+  const char* f=reinterpret_cast<const char*>(sqlite3_column_text(stmt, 5));
+  compound->SetFormula(f);
+  compound->SetChemType(mol::ChemType(sqlite3_column_text(stmt, 6)[0]));
+  const char* name=reinterpret_cast<const char*>(sqlite3_column_text(stmt, 7));
+  compound->SetName(name);
+  const char* inchi_code=reinterpret_cast<const char*>(sqlite3_column_text(stmt, 8));
+  if (inchi_code) {
+    compound->SetInchi(inchi_code);
+  }
+  const char* inchi_key=reinterpret_cast<const char*>(sqlite3_column_text(stmt, 9));
+  if (inchi_key) {
+    compound->SetInchiKey(inchi_key);
+  }
+  int next_column = 10;
+  if (smiles_available_) {
+    const char* smiles=reinterpret_cast<const char*>(sqlite3_column_text(stmt, next_column));
+    next_column++;
+    if (smiles) {
+      compound->SetSMILES(smiles);
+    }
+  }
+  if (obsolete_available_) {
+    bool obsolete=sqlite3_column_int(stmt, next_column);
+    compound->SetObsolete(obsolete);
+    next_column++;
+    const char* replaced_by=reinterpret_cast<const char*>(sqlite3_column_text(stmt, next_column));
+    next_column++;
+    if (replaced_by) {
+      compound->SetReplacedBy(replaced_by);
+    }
+  }
+
+  // Load atoms and bonds
+  this->LoadAtomsFromDB(compound, pk);
+  this->LoadBondsFromDB(compound, pk);
+
+  return compound;
+}
+
+
+CompoundPtrList CompoundLib::FindCompounds(const String& query,
+                                                    const String& by,
+                                                    Compound::Dialect dialect) const {
+  CompoundPtrList compounds_vec;
 
   // Validate "by" argument
-  std::set<std::string> allowed_keys{"tlc", "inchi_code", "inchi_key"};
+  std::set<std::string> allowed_keys{"inchi_code", "inchi_key", "formula"};
   if(smiles_available_) {
      allowed_keys.insert("smiles");
   }
@@ -539,28 +605,54 @@ CompoundPtr CompoundLib::FindCompound(const String& id,
     throw ost::Error(msg.str());
   }
 
+  String sql_query = BuildFindCompoundQuery(query, dialect, by);
+
+  if(obsolete_available_) {
+    // Prefer active compounds, then the ones with a replacement
+    sql_query += " ORDER BY obsolete, replaced_by IS NULL";
+  }
+
+  // Run the query
+  sqlite3_stmt* stmt;
+  int retval=sqlite3_prepare_v2(db_->ptr, sql_query.c_str(),
+                                static_cast<int>(sql_query.length()),
+                                &stmt, NULL);
+  sqlite3_bind_text(stmt, 1, query.c_str(),
+                      strlen(query.c_str()), NULL);
+
+  if (SQLITE_OK==retval) {
+    int ret=sqlite3_step(stmt);
+    if (SQLITE_DONE==ret) {
+      sqlite3_finalize(stmt);
+      return compounds_vec;  // Empty
+    }
+    while (SQLITE_ROW==ret) {
+      CompoundPtr compound = LoadCompoundFromDB(stmt);
+      compounds_vec.push_back(compound);
+      // next row
+      ret=sqlite3_step(stmt);
+    }
+    assert(SQLITE_DONE==sqlite3_step(stmt));
+  } else {
+    LOG_ERROR("ERROR: " << sqlite3_errmsg(db_->ptr));
+    sqlite3_finalize(stmt);
+    return compounds_vec;  // empty
+  }
+  sqlite3_finalize(stmt);
+  return compounds_vec;
+}
+
+CompoundPtr CompoundLib::FindCompound(const String& id, 
+                                      Compound::Dialect dialect) const {
   // Check cache
-  String cache_key = by + "_" + id;
+  String cache_key = id;
   CompoundMap::const_iterator i=compound_cache_.find(cache_key);
   if (i!=compound_cache_.end()) {
     LOG_DEBUG("Retrieved compound " << cache_key << " from cache");
     return i->second;
   }
 
-  // Build the query
-  String query="SELECT id, tlc, olc, chem_class, dialect, formula, chem_type, name, inchi_code, inchi_key";
-  if(smiles_available_) {
-    query+=", smiles";
-  }
-  if(obsolete_available_) {
-    query+=", obsolete, replaced_by";
-  }
-  query+=" FROM chem_compounds"
-         " WHERE " + by + "=? AND dialect='"+String(1, char(dialect))+"'";
-  if(obsolete_available_) {
-    // Prefer active compounds, then the ones with a replacement
-    query+=" ORDER BY obsolete, replaced_by IS NULL";
-  }
+  String query = BuildFindCompoundQuery(id, dialect, "tlc");
 
   // Run the query
   sqlite3_stmt* stmt;
@@ -577,47 +669,7 @@ CompoundPtr CompoundLib::FindCompound(const String& id,
       return CompoundPtr();
     }
     if (SQLITE_ROW==ret) {
-      int pk=sqlite3_column_int(stmt, 0);
-      const char* id=reinterpret_cast<const char*>(sqlite3_column_text(stmt, 1));
-      CompoundPtr compound(new Compound(id));
-      compound->SetOneLetterCode((sqlite3_column_text(stmt, 2))[0]);
-      compound->SetChemClass(mol::ChemClass(sqlite3_column_text(stmt, 3)[0]));
-      compound->SetDialect(Compound::Dialect(sqlite3_column_text(stmt, 4)[0]));
-      const char* f=reinterpret_cast<const char*>(sqlite3_column_text(stmt, 5));
-      compound->SetFormula(f);
-      compound->SetChemType(mol::ChemType(sqlite3_column_text(stmt, 6)[0]));
-      const char* name=reinterpret_cast<const char*>(sqlite3_column_text(stmt, 7));
-      compound->SetName(name);
-      const char* inchi_code=reinterpret_cast<const char*>(sqlite3_column_text(stmt, 8));
-      if (inchi_code) {
-        compound->SetInchi(inchi_code);
-      }
-      const char* inchi_key=reinterpret_cast<const char*>(sqlite3_column_text(stmt, 9));
-      if (inchi_key) {
-        compound->SetInchiKey(inchi_key);
-      }
-      int next_column = 10;
-      if (smiles_available_) {
-        const char* smiles=reinterpret_cast<const char*>(sqlite3_column_text(stmt, next_column));
-        next_column++;
-        if (smiles) {
-          compound->SetSMILES(smiles);
-        }
-      }
-      if (obsolete_available_) {
-        bool obsolete=sqlite3_column_int(stmt, next_column);
-        compound->SetObsolete(obsolete);
-        next_column++;
-        const char* replaced_by=reinterpret_cast<const char*>(sqlite3_column_text(stmt, next_column));
-        next_column++;
-        if (replaced_by) {
-          compound->SetReplacedBy(replaced_by);
-        }
-      }
-
-      // Load atoms and bonds      
-      this->LoadAtomsFromDB(compound, pk);
-      this->LoadBondsFromDB(compound, pk);
+      CompoundPtr compound = LoadCompoundFromDB(stmt);
       compound_cache_.insert(std::make_pair(cache_key, compound));
       sqlite3_finalize(stmt);
       return compound;   
diff --git a/modules/conop/src/compound_lib.hh b/modules/conop/src/compound_lib.hh
index 8226ecdba..3b7d50b67 100644
--- a/modules/conop/src/compound_lib.hh
+++ b/modules/conop/src/compound_lib.hh
@@ -21,6 +21,7 @@
 
 #include <map>
 #include <boost/shared_ptr.hpp>
+#include <sqlite3.h>
 
 #include "module_config.hh"
 #include "compound.hh"
@@ -31,6 +32,7 @@ namespace ost { namespace conop {
 class CompoundLib;
 
 typedef boost::shared_ptr<CompoundLib> CompoundLibPtr;
+typedef std::vector<CompoundPtr> CompoundPtrList;
 
 class DLLEXPORT_OST_CONOP CompoundLib : public CompoundLibBase {
 public:
@@ -39,8 +41,10 @@ public:
   ~CompoundLib();
   
   virtual CompoundPtr FindCompound(const String& id, 
-                                   Compound::Dialect dialect,
-                                   const String& by="tlc") const;
+                                   Compound::Dialect dialect) const;
+  virtual CompoundPtrList FindCompounds(const String& query,
+                                   const String& by,
+                                   Compound::Dialect dialect) const;
   void AddCompound(const CompoundPtr& compound);
   CompoundLibPtr Copy(const String& filename) const;
   void ClearCache();
@@ -51,7 +55,11 @@ private:
     CompoundLib();
 
     void LoadAtomsFromDB(CompoundPtr comp, int pk) const;
-    void LoadBondsFromDB(CompoundPtr comp, int pk) const;    
+    void LoadBondsFromDB(CompoundPtr comp, int pk) const;
+    String BuildFindCompoundQuery(const String& id,
+                                   Compound::Dialect dialect,
+                                   const String& by) const;
+    CompoundPtr LoadCompoundFromDB(sqlite3_stmt* stmt) const;
 private:
   struct Database;
   Database* db_;
diff --git a/modules/conop/src/compound_lib_base.hh b/modules/conop/src/compound_lib_base.hh
index d8cdd7e43..aee5215b9 100644
--- a/modules/conop/src/compound_lib_base.hh
+++ b/modules/conop/src/compound_lib_base.hh
@@ -13,8 +13,7 @@ class DLLEXPORT_OST_CONOP CompoundLibBase {
 public:
   virtual ~CompoundLibBase() {}
   virtual CompoundPtr FindCompound(const String& id, 
-                                   Compound::Dialect dialect,
-                                   const String& by="tlc") const = 0;
+                                   Compound::Dialect dialect) const = 0;
 
   bool IsResidueComplete(const ost::mol::ResidueHandle& res, 
                          bool check_hydrogens, 
diff --git a/modules/conop/src/minimal_compound_lib.cc b/modules/conop/src/minimal_compound_lib.cc
index 3b0d6dc8c..9eeb10753 100644
--- a/modules/conop/src/minimal_compound_lib.cc
+++ b/modules/conop/src/minimal_compound_lib.cc
@@ -42,15 +42,8 @@ CompoundMap MinimalCompoundLib::InitCompounds() {
 
 
 CompoundPtr MinimalCompoundLib::FindCompound(const String& id, 
-                                             Compound::Dialect dialect,
-                                             const String& by) const
+                                             Compound::Dialect dialect) const
 {
-   if (by != "tlc") {
-     // Only tlc is supported by the minimal compound lib
-     std::stringstream msg;
-     msg << "Invalid 'by' key: " << by;
-     throw ost::Error(msg.str());
-   }
    CompoundMap::const_iterator i = MinimalCompoundLib::compounds_.find(id); 
    if (i != MinimalCompoundLib::compounds_.end()) {
      return i->second;
diff --git a/modules/conop/src/minimal_compound_lib.hh b/modules/conop/src/minimal_compound_lib.hh
index 16b58bff4..b0defbaba 100644
--- a/modules/conop/src/minimal_compound_lib.hh
+++ b/modules/conop/src/minimal_compound_lib.hh
@@ -17,8 +17,7 @@ public:
     CompoundLibBase()
   {}
   virtual CompoundPtr FindCompound(const String& id, 
-                                   Compound::Dialect dialect,
-                                   const String& by="tlc") const;
+                                   Compound::Dialect dialect) const;
 private:
   static CompoundMap InitCompounds();
   // since this information is never going to change, it is shared 
diff --git a/modules/conop/tests/test_compound.py b/modules/conop/tests/test_compound.py
index 18b88b861..21f0808eb 100644
--- a/modules/conop/tests/test_compound.py
+++ b/modules/conop/tests/test_compound.py
@@ -26,26 +26,36 @@ class TestCompound(unittest.TestCase):
 
     def testFindCompoundBySMILES(self):
         """ Test FindCompound by="smiles"."""
-        compound = self.compound_lib.FindCompound('O', by="smiles")
-        self.assertNotEqual(compound, None)
-        self.assertEqual(compound.smiles, 'O')
+        compounds = self.compound_lib.FindCompounds('O', by="smiles")
+        # Make sure all the compounds have the right smiles
+        for compound in compounds:
+            self.assertNotEqual(compound, None)
+            self.assertEqual(compound.smiles, 'O')
 
-        # Now we should prefer a non-obsolete compound
+        # Now we should prefer a non-obsolete compound first.
         # Default ordering has DIS as first pick but FindCompound should sort
         # active compounds first.
         # This assumes there are non-obsolete O/HOH compounds in the compound
         # lib, which should always be the case.
-        self.assertFalse(compound.obsolete)
+        self.assertFalse(compounds[0].obsolete)
 
     def testFindCompoundByInChI(self):
         """ Test FindCompound by="inchi_code|key"."""
         inchi_code = "InChI=1/H2O/h1H2"
         inchi_key = "XLYOFNOQVPJJNP-UHFFFAOYAF"
-        compound = self.compound_lib.FindCompound(inchi_code, by="inchi_code")
-        self.assertNotEqual(compound, None)
-        self.assertEqual(compound.inchi, inchi_code)
-        self.assertEqual(compound.inchi_key, inchi_key)
+        compounds = self.compound_lib.FindCompounds(inchi_code, by="inchi_code")
+        # Make sure all the compounds have the right inchis
+        for compound in compounds:
+            self.assertNotEqual(compound, None)
+            self.assertEqual(compound.inchi, inchi_code)
+            self.assertEqual(compound.inchi_key, inchi_key)
 
+        compounds = self.compound_lib.FindCompounds(inchi_key, by="inchi_key")
+        # Make sure all the compounds have the right inchis
+        for compound in compounds:
+            self.assertNotEqual(compound, None)
+            self.assertEqual(compound.inchi, inchi_code)
+            self.assertEqual(compound.inchi_key, inchi_key)
      
 if __name__=='__main__':
     from ost import testutils
-- 
GitLab