From 4e6e83055b51d4bd60eefce79aad2cfde23809b7 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Wed, 4 Sep 2024 09:34:37 +0200
Subject: [PATCH] lddt: enable per-atom lDDT scores

---
 actions/ost-compare-structures   | 35 +++++++++++++-
 modules/doc/actions.rst          | 13 +++++
 modules/mol/alg/pymod/lddt.py    | 62 ++++++++++++++++++++++--
 modules/mol/alg/pymod/scoring.py | 81 ++++++++++++++++++++++++++++----
 4 files changed, 176 insertions(+), 15 deletions(-)

diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures
index 9a466ce05..1afa5198a 100644
--- a/actions/ost-compare-structures
+++ b/actions/ost-compare-structures
@@ -247,6 +247,25 @@ def _ParseArgs():
               "counterparts. Atoms specified in there follow the following "
               "format: <chain_name>.<resnum>.<resnum_inscode>.<atom_name>"))
 
+    parser.add_argument(
+        "--aa-local-lddt",
+        dest="aa_local_lddt",
+        default=False,
+        action="store_true",
+        help=("Compute per-atom lDDT scores with default parameterization "
+              "and store as key \"aa_local_lddt\". Score for each atom is "
+              "accessible by key "
+              "<chain_name>.<resnum>.<resnum_inscode>.<aname>. "
+              "Alpha carbon from residue with number 42 in chain X can be "
+              "extracted with: data[\"aa_local_lddt\"][\"X.42..CA\"]. "
+              "If there is a residue insertion code, lets say A, the atom key "
+              "becomes \"X.42.A.CA\". "
+              "Stereochemical irregularities affecting lDDT are reported as "
+              "keys \"model_clashes\", \"model_bad_bonds\", "
+              "\"model_bad_angles\" and the respective reference "
+              "counterparts. Atoms specified in there follow the following "
+              "format: <chain_name>.<resnum>.<resnum_inscode>.<atom_name>"))
+
     parser.add_argument(
         "--bb-lddt",
         dest="bb_lddt",
@@ -712,6 +731,17 @@ def _LocalScoresToJSONDict(score_dict):
             json_dict[f"{ch}.{num.num}.{ins_code}"] = _RoundOrNone(s)
     return json_dict
 
+def _LocalAAScoresToJSONDict(score_dict):
+    """ Flatten {chain: {ResNum: {atom name: score}}} for JSON serialization
+    """
+    flat = dict()
+    for cname, per_res in score_dict.items():
+        for rnum, per_atom in per_res.items():
+            prefix = f"{cname}.{rnum.num}." + rnum.ins_code.strip("\u0000")
+            for aname, score in per_atom.items():
+                flat[f"{prefix}.{aname}"] = _RoundOrNone(score)
+    return flat
+
 def _InterfaceResiduesToJSONList(interface_dict):
     """ Convert ResNums to str for JSON serialization.
 
@@ -816,7 +846,10 @@ def _Process(model, reference, args, model_format, reference_format):
     if args.local_lddt:
         out["local_lddt"] = _LocalScoresToJSONDict(scorer.local_lddt)
 
-    if args.lddt or args.local_lddt:
+    if args.aa_local_lddt:
+        out["aa_local_lddt"] = _LocalAAScoresToJSONDict(scorer.aa_local_lddt)
+
+    if args.lddt or args.local_lddt or args.aa_local_lddt:
         out["model_clashes"] = [x.ToJSON() for x in scorer.model_clashes]
         out["model_bad_bonds"] = [x.ToJSON() for x in scorer.model_bad_bonds]
         out["model_bad_angles"] = [x.ToJSON() for x in scorer.model_bad_angles]
diff --git a/modules/doc/actions.rst b/modules/doc/actions.rst
index d1794c49c..89b96d823 100644
--- a/modules/doc/actions.rst
+++ b/modules/doc/actions.rst
@@ -202,6 +202,19 @@ Details on the usage (output of ``ost compare-structures --help``):
                           counterparts. Atoms specified in there follow the
                           following format:
                           <chain_name>.<resnum>.<resnum_inscode>.<atom_name>
+    --aa-local-lddt       Compute per-atom lDDT scores with default
+                          parameterization and store as key "aa_local_lddt".
+                          Score for each atom is accessible by key
+                          <chain_name>.<resnum>.<resnum_inscode>.<aname>. Alpha
+                          carbon from residue with number 42 in chain X can be
+                          extracted with: data["aa_local_lddt"]["X.42..CA"]. If
+                          there is a residue insertion code, lets say A, the
+                          atom key becomes "X.42.A.CA". Stereochemical
+                          irregularities affecting lDDT are reported as keys
+                          "model_clashes", "model_bad_bonds", "model_bad_angles"
+                          and the respective reference counterparts. Atoms
+                          specified in there follow the following format:
+                          <chain_name>.<resnum>.<resnum_inscode>.<atom_name>
     --bb-lddt             Compute global lDDT score with default
                           parameterization and store as key "bb_lddt". lDDT in
                           this case is only computed on backbone atoms: CA for
diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py
index 1453db731..83c1b2d60 100644
--- a/modules/mol/alg/pymod/lddt.py
+++ b/modules/mol/alg/pymod/lddt.py
@@ -1,3 +1,4 @@
+import itertools
 import numpy as np
 
 from ost import mol
@@ -438,7 +439,7 @@ class lDDTScorer:
              no_intrachain=False, penalize_extra_chains=False,
              residue_mapping=None, return_dist_test=False,
              check_resnames=True, add_mdl_contacts=False,
-             interaction_data=None):
+             interaction_data=None, set_atom_props=False):
         """Computes lDDT of *model* - globally and per-residue
 
         :param model: Model to be scored - models are preferably scored upon
@@ -530,6 +531,12 @@ class lDDTScorer:
         :type add_mdl_contacts: :class:`bool`
         :param interaction_data: Pro param - don't use
         :type interaction_data: :class:`tuple`
+        :param set_atom_props: If True, sets generic properties on a per atom
+                               level if *local_lddt_prop*/*local_contact_prop*
+                               are set as well.
+                               In other words: this is the only way you can
+                               get per-atom lDDT values.
+        :type set_atom_props: :class:`bool`
 
         :returns: global and per-residue lDDT scores as a tuple -
                   first element is global lDDT score (None if *target* has no
@@ -610,16 +617,28 @@ class lDDTScorer:
         self._ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices,
                                 sym_ref_distances)
 
+        atom_indices = list(itertools.chain.from_iterable(res_atom_indices))
+
+        per_atom_exp = np.asarray([self._GetNExp(i, ref_indices)
+            for i in atom_indices], dtype=np.int32)
         per_res_exp = np.asarray([self._GetNExp(res_ref_atom_indices[idx],
             ref_indices) for idx in range(len(res_indices))], dtype=np.int32)
-        per_res_conserved = self._EvalResidues(pos, thresholds,
-                                               res_atom_indices,
-                                               ref_indices, ref_distances)
+
+        per_atom_conserved = self._EvalAtoms(pos, atom_indices, thresholds,
+                                             ref_indices, ref_distances)
+        per_res_conserved = np.zeros((len(res_atom_indices), len(thresholds)),
+                                     dtype=np.int32)
+        start_idx = 0
+        for r_idx in range(len(res_atom_indices)):
+            end_idx = start_idx + len(res_atom_indices[r_idx])
+            per_res_conserved[r_idx] = np.sum(per_atom_conserved[start_idx:end_idx,:],
+                                              axis=0)
+            start_idx = end_idx
 
         n_thresh = len(thresholds)
 
         # do per-residue scores
-        per_res_lDDT = [None] * len(model.residues)
+        per_res_lDDT = [None] * model.GetResidueCount()
         for idx in range(len(res_indices)):
             n_exp = n_thresh * per_res_exp[idx]
             if n_exp > 0:
@@ -656,6 +675,39 @@ class lDDTScorer:
                 residues[r_idx].SetIntProp(conserved_prop,
                                            int(np.sum(per_res_conserved[i,:])))
 
+        if set_atom_props and (local_lddt_prop or local_contact_prop):
+            # collect model atoms in the same order as atom_indices
+            atom_list = list()
+            residues = model.residues
+            for i, indices in enumerate(res_atom_indices):
+                r = residues[res_indices[i]]
+                res_start_idx = self.res_start_indices[ref_res_indices[i]]
+                anames = self.compound_anames[r.GetName()]
+                for a_i in indices:
+                    a = r.FindAtom(anames[a_i - res_start_idx])
+                    assert(a.IsValid())
+                    atom_list.append(a)
+
+            summed_per_atom_conserved = per_atom_conserved.sum(axis=1)
+            if local_lddt_prop:
+                # the only place where we actually need per-atom lDDT scores
+                for a, n_cons, n_exp in zip(atom_list, summed_per_atom_conserved,
+                                            per_atom_exp):
+                    # no expected contacts => 0.0 instead of NaN from 0/0
+                    tmp = n_cons / n_exp / n_thresh if n_exp > 0 else 0.0
+                    a.SetFloatProp(local_lddt_prop, float(tmp))
+
+            if local_contact_prop:
+                conserved_prop = local_contact_prop + "_cons"
+                exp_prop = local_contact_prop + "_exp"
+                for a_idx in range(len(atom_list)):
+                    # SetIntProp gets plain Python ints, as for residues
+                    tmp = int(summed_per_atom_conserved[a_idx])
+                    atom_list[a_idx].SetIntProp(conserved_prop, tmp)
+                    # do number of expected contacts
+                    tmp = int(per_atom_exp[a_idx]) * n_thresh
+                    atom_list[a_idx].SetIntProp(exp_prop, tmp)
+
         if return_dist_test:
             return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \
             per_res_exp, per_res_conserved
diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py
index 37304aa94..7445c1f2b 100644
--- a/modules/mol/alg/pymod/scoring.py
+++ b/modules/mol/alg/pymod/scoring.py
@@ -344,6 +344,7 @@ class Scorer:
         # lazily computed scores
         self._lddt = None
         self._local_lddt = None
+        self._aa_local_lddt = None
         self._bb_lddt = None
         self._bb_local_lddt = None
         self._ilddt = None
@@ -743,6 +744,23 @@ class Scorer:
             self._compute_lddt()
         return self._local_lddt
 
+    @property
+    def aa_local_lddt(self):
+        """ Per atom lDDT scores in range [0.0, 1.0]
+
+        Computed based on :attr:`~stereochecked_model` but scores for all
+        atoms in :attr:`~model` are reported. If an atom has been removed
+        by stereochemistry checks, the respective score is set to 0.0. If an
+        atom is not covered by the target or is in a chain skipped by the
+        chain mapping procedure (happens for super short chains), the respective
+        score is set to None. In case of oligomers, :attr:`~mapping` is used.
+
+        :type: :class:`dict` - {chain name: {ResNum: {atom name: score}}}
+        """
+        if self._aa_local_lddt is None:
+            self._compute_lddt()
+        return self._aa_local_lddt
+
     @property
     def bb_lddt(self):
         """ Backbone only global lDDT score in range [0.0, 1.0]
@@ -1824,6 +1842,7 @@ class Scorer:
         # score variables to be set
         lddt_score = None
         local_lddt = None
+        aa_local_lddt = None
 
         if self.lddt_no_stereochecks:
             lddt_chain_mapping = dict()
@@ -1835,19 +1854,39 @@ class Scorer:
                                                residue_mapping = alns,
                                                check_resnames=False,
                                                local_lddt_prop="lddt",
-                                               add_mdl_contacts = self.lddt_add_mdl_contacts)[0]
+                                               add_mdl_contacts = self.lddt_add_mdl_contacts,
+                                               set_atom_props=True)[0]
             local_lddt = dict()
+            aa_local_lddt = dict()
             for r in self.model.residues:
+
                 cname = r.GetChain().GetName()
                 if cname not in local_lddt:
                     local_lddt[cname] = dict()
+                    aa_local_lddt[cname] = dict()
+
+                rnum = r.GetNumber()
+                if rnum not in aa_local_lddt[cname]:
+                    aa_local_lddt[cname][rnum] = dict()
+
                 if r.HasProp("lddt"):
                     score = round(r.GetFloatProp("lddt"), 3)
-                    local_lddt[cname][r.GetNumber()] = score
+                    local_lddt[cname][rnum] = score
                 else:
                     # not covered by trg or skipped in chain mapping procedure
                     # the latter happens if its part of a super short chain
-                    local_lddt[cname][r.GetNumber()] = None
+                    local_lddt[cname][rnum] = None
+
+                for a in r.atoms:
+                    if a.HasProp("lddt"):
+                        score = round(a.GetFloatProp("lddt"), 3)
+                        aa_local_lddt[cname][rnum][a.GetName()] = score
+                    else:
+                        # not covered by trg or skipped in chain mapping
+                        # procedure - the latter happens if it's part of a
+                        # super short chain
+                        aa_local_lddt[cname][rnum][a.GetName()] = None
+
 
         else:
             lddt_chain_mapping = dict()
@@ -1859,22 +1898,40 @@ class Scorer:
                                                residue_mapping = stereochecked_alns,
                                                check_resnames=False,
                                                local_lddt_prop="lddt",
-                                               add_mdl_contacts = self.lddt_add_mdl_contacts)[0]
+                                               add_mdl_contacts = self.lddt_add_mdl_contacts,
+                                               set_atom_props=True)[0]
             local_lddt = dict()
+            aa_local_lddt = dict()
             for r in self.model.residues:
                 cname = r.GetChain().GetName()
                 if cname not in local_lddt:
                     local_lddt[cname] = dict()
+                    aa_local_lddt[cname] = dict()
+                rnum = r.GetNumber()
+                if rnum not in aa_local_lddt[cname]:
+                    aa_local_lddt[cname][rnum] = dict()
+
                 if r.HasProp("lddt"):
                     score = round(r.GetFloatProp("lddt"), 3)
-                    local_lddt[cname][r.GetNumber()] = score
+                    local_lddt[cname][rnum] = score
+
+                    for a in r.atoms:
+                        if a.HasProp("lddt"):
+                            score = round(a.GetFloatProp("lddt"), 3)
+                            aa_local_lddt[cname][rnum][a.GetName()] = score
+                        else:
+                            # must have been removed by stereochecks
+                            aa_local_lddt[cname][rnum][a.GetName()] = 0.0
+
                 else:
                     # rsc => residue stereo checked...
-                    mdl_res = self.stereochecked_model.FindResidue(cname, r.GetNumber())
+                    mdl_res = self.stereochecked_model.FindResidue(cname, rnum)
                     if mdl_res.IsValid():
                         # not covered by trg or skipped in chain mapping procedure
                         # the latter happens if its part of a super short chain
-                        local_lddt[cname][r.GetNumber()] = None
+                        local_lddt[cname][rnum] = None
+                        for a in r.atoms:
+                            aa_local_lddt[cname][rnum][a.GetName()] = None
                     else:
                         # opt 1: removed by stereochecks => assign 0.0
                         # opt 2: removed by stereochecks AND not covered by ref
@@ -1888,12 +1945,18 @@ class Scorer:
                                         trg_r = col.GetResidue(0)
                                         break
                         if trg_r is None:
-                            local_lddt[cname][r.GetNumber()] = None
+                            local_lddt[cname][rnum] = None
+                            for a in r.atoms:
+                                aa_local_lddt[cname][rnum][a.GetName()] = None
+
                         else:
-                            local_lddt[cname][r.GetNumber()] = 0.0
+                            local_lddt[cname][rnum] = 0.0
+                            for a in r.atoms:
+                                aa_local_lddt[cname][rnum][a.GetName()] = 0.0
 
         self._lddt = lddt_score
         self._local_lddt = local_lddt
+        self._aa_local_lddt = aa_local_lddt
 
     def _compute_bb_lddt(self):
         LogScript("Computing backbone lDDT")
-- 
GitLab