From 8cee0b7a560c0206cb99e76d0fb475432de5de93 Mon Sep 17 00:00:00 2001
From: Xavier Robin <xavalias-github@xavier.robin.name>
Date: Thu, 15 Jun 2023 08:56:54 +0200
Subject: [PATCH] fix: write charge in SDF writer

---
 modules/io/src/mol/sdf_writer.cc | 25 ++++++++++++++++++++++++-
 modules/io/tests/test_io_sdf.py  | 22 ++++++++++++++++++++++
 2 files changed, 46 insertions(+), 1 deletion(-)

diff --git a/modules/io/src/mol/sdf_writer.cc b/modules/io/src/mol/sdf_writer.cc
index 8deaf9a78..e39444d0c 100644
--- a/modules/io/src/mol/sdf_writer.cc
+++ b/modules/io/src/mol/sdf_writer.cc
@@ -50,7 +50,12 @@ namespace {
               << format("%10.4f") % atom.GetPos()[1]
               << format("%10.4f ") % atom.GetPos()[2]
               << format("%-3s") % SDFAtomWriter::FormatEle(atom.GetElement())
-              << " 0  0  0  0  0  0"
+              << " 0" // Mass difference
+              << format("%-3s") % SDFAtomWriter::FormatCharge(atom.GetCharge()) // Charge
+              << "  0" // Atom stereo parity
+              << "  0" // Hydrogen count + 1
+              << "  0" // Stereo care box
+              << "  0" // Valence
               << std::endl;
         return true;
       }
@@ -66,6 +71,24 @@ namespace {
         }
         return return_ele;
       }
+
+      static String FormatCharge(const Real& chg) {
+        // Encode a charge as the 3-character charge field of a molfile atom
+        // line (https://doi.org/10.1021/ci00007a012): 0 = uncharged, 1 = +3,
+        // 2 = +2, 3 = +1, 5 = -1, 6 = -2, 7 = -3. Code 4 (doublet radical) is
+        // never emitted: the charge is rounded to an integer before mapping.
+        // Throws IOException outside -3..+3; the negated test also traps NaN.
+        if (!(chg >= -3 && chg <= 3)) {
+          String msg = "SDF format only supports charges from -3 to +3, not %g";
+          throw IOException(str(format(msg) % chg));
+        }
+        // Round half away from zero; the cast is safe after the range check.
+        const int chg_int = static_cast<int>(chg < 0 ? chg - 0.5 : chg + 0.5);
+        if (chg_int == 0) {
+          return "  0";
+        }
+        return str(format("%3d") % (4 - chg_int));
+      }
     private:
       std::ostream&      ostr_;
       std::map<long, int>& atom_indices_;
diff --git a/modules/io/tests/test_io_sdf.py b/modules/io/tests/test_io_sdf.py
index 718ac0691..735158fc3 100644
--- a/modules/io/tests/test_io_sdf.py
+++ b/modules/io/tests/test_io_sdf.py
@@ -16,6 +16,28 @@ class TestSDF(unittest.TestCase):
     ent = io.LoadSDF('testfiles/sdf/6d5w_rank1_crlf.sdf.gz')
     self.assertEqual(len(ent.atoms), 21)
     self.assertEqual(len(ent.bonds), 24)
+
+  def test_Charge(self):
+    # Charges from -3 to +3 must survive an SDF write/read round trip.
+    ent = io.LoadSDF('testfiles/sdf/simple.sdf')
+    self.assertEqual(ent.FindAtom("00001_Simple Ligand", 1, "6").charge,  0)
+
+    # Write and read charges properly
+    for chg in range(-3, 4):
+      ent.FindAtom("00001_Simple Ligand", 1, "6").charge = chg
+      ent = io.SDFStrToEntity(io.EntityToSDFStr(ent))
+      self.assertEqual(ent.FindAtom("00001_Simple Ligand", 1, "6").charge,  chg)
+
+    # Only -3 to +3 can be written (remove these checks once M CHG is
+    # implemented). Assign outside assertRaises so only the writer can raise.
+    ent.FindAtom("00001_Simple Ligand", 1, "6").charge = 4
+    with self.assertRaises(Exception):
+      io.EntityToSDFStr(ent)
+
+    ent.FindAtom("00001_Simple Ligand", 1, "6").charge = -4
+    with self.assertRaises(Exception):
+      io.EntityToSDFStr(ent)
+
     
 if __name__== '__main__':
   from ost import testutils
-- 
GitLab