From e0eea5e08b0ab79403665962ea98fa0fd5234be4 Mon Sep 17 00:00:00 2001
From: Gabriel Studer <gabriel.studer@unibas.ch>
Date: Wed, 26 Feb 2025 11:37:39 +0100
Subject: [PATCH] chain mapping: change definition of
 ChainMapper.chem_group_alignments

Old definition: List of alignments, one alignment for each chem group.
Each sequence in such an alignment was bound to an ATOMSEQ in the
target structure. The sequences were sorted by length and the first
one was considered the reference sequence.

New defition: List of alignments, one alignment for each chem group.
The first sequence is considered the reference sequence and is not
bound to any ATOMSEQ in the target structure. All subsequent sequences
are bound to an ATOMSEQ in the target structure. So if you derive these
alignments from structure only, the first sequence is simply a copy of
the longest ATOMSEQ sequence, essentially increasing the size of the
alignment by 1.

Motivation of this change is the ability of setting an arbitrary
reference sequence which can be the SEQRES.
---
 modules/mol/alg/pymod/chain_mapping.py      | 119 ++++++++++++++++----
 modules/mol/alg/tests/test_chain_mapping.py |  14 +--
 2 files changed, 101 insertions(+), 32 deletions(-)

diff --git a/modules/mol/alg/pymod/chain_mapping.py b/modules/mol/alg/pymod/chain_mapping.py
index 04d278745..7c144730d 100644
--- a/modules/mol/alg/pymod/chain_mapping.py
+++ b/modules/mol/alg/pymod/chain_mapping.py
@@ -649,41 +649,33 @@ class ChainMapper:
         :type: :class:`list` of :class:`list` of :class:`str` (chain names)
         """
         if self._chem_groups is None:
-            self._chem_groups = list()
-            for a in self.chem_group_alignments:
-                self._chem_groups.append([s.GetName() for s in a.sequences])
+            self._ComputeChemGroups()
         return self._chem_groups
     
     @property
     def chem_group_alignments(self):
         """MSA for each group in :attr:`~chem_groups`
 
-        Sequences in MSAs exhibit same order as in :attr:`~chem_groups` and
-        have the respective :class:`ost.mol.EntityView` from *target* attached.
+        The first sequence is the reference sequence.
+        The subsequent sequences represent the ATOMSEQ sequences in
+        :attr:`~target` in same order as in :attr:`~chem_groups`.
 
         :getter: Computed on first use (cached)
         :type: :class:`ost.seq.AlignmentList`
         """
         if self._chem_group_alignments is None:
-            self._SetChemGroupAlignments()
+            self._ComputeChemGroups()
         return self._chem_group_alignments
 
     @property
     def chem_group_ref_seqs(self):
-        """Reference (longest) sequence for each group in :attr:`~chem_groups`
-
-        Respective :class:`EntityView` from *target* for each sequence s are
-        available as ``s.GetAttachedView()``
+        """Reference sequence for each group in :attr:`~chem_groups`
 
         :getter: Computed on first use (cached)
         :type: :class:`ost.seq.SequenceList`
         """
         if self._chem_group_ref_seqs is None:
-            self._chem_group_ref_seqs = seq.CreateSequenceList()
-            for a in self.chem_group_alignments:
-                s = seq.CreateSequence(a.GetSequence(0).GetName(),
-                                       a.GetSequence(0).GetGaplessString())
-                self._chem_group_ref_seqs.AddSequence(s)
+            self._ComputeChemGroups()
         return self._chem_group_ref_seqs
 
     @property
@@ -698,7 +690,7 @@ class ChainMapper:
         :type: :class:`list` of :class:`ost.mol.ChemType`
         """
         if self._chem_group_types is None:
-            self._SetChemGroupAlignments()
+            self._ComputeChemGroups()
         return self._chem_group_types
         
     def GetChemMapping(self, model):
@@ -1672,8 +1664,36 @@ class ChainMapper:
         aln.AddSequence(seq.CreateSequence(s2.GetName(), ''.join(aln_s2)))
         return aln
 
-    def _SetChemGroupAlignments(self):
-        """Sets self._chem_group_alignments and self._chem_group_types
+
+    def _ComputeChemGroups(self):
+        """ Sets properties :attr:`~chem_groups`,
+        :attr:`~chem_group_alignments`, :attr:`~chem_group_ref_seqs`,
+        :attr:`~chem_group_types`
+        """
+
+        self._chem_group_alignments, self._chem_group_types =\
+        self._ChemGroupAlignmentsFromATOMSEQ()
+
+
+        self._chem_group_ref_seqs = seq.CreateSequenceList()
+        for a in self.chem_group_alignments:
+            s = seq.CreateSequence(a.GetSequence(0).GetName(),
+                                   a.GetSequence(0).GetGaplessString())
+            self._chem_group_ref_seqs.AddSequence(s)
+
+        self._chem_groups = list()
+        for a in self.chem_group_alignments:
+            group = list()
+            for s_idx in range(1, a.GetCount()):
+                s = a.GetSequence(s_idx)
+                group.append(s.GetName())
+            self._chem_groups.append(group)
+
+    def _ChemGroupAlignmentsFromATOMSEQ(self):
+        """ Groups target sequences based on ATOMSEQ
+
+        returns tuple that can be set as self._chem_group_alignments and
+        self._chem_group_types
         """
         pep_groups = self._GroupSequences(self.polypep_seqs, self.pep_seqid_thr,
                                           self.min_pep_length,
@@ -1681,12 +1701,49 @@ class ChainMapper:
         nuc_groups = self._GroupSequences(self.polynuc_seqs, self.nuc_seqid_thr,
                                           self.min_nuc_length,
                                           mol.ChemType.NUCLEOTIDES)
+
+        # pep_groups and nuc_groups give us alignments based on ATOMSEQ.
+        # For example: If you have polymer chains A,B and C in the same
+        # group and A is the longest one, you get an alignment that looks
+        # like:
+        # A: ASDFE
+        # B: ASDF-
+        # C: -SDFE
+        #
+        # however, the first sequence in chem group alignments must not be
+        # bound to any ATOMSEQ and represent the reference sequence. In the
+        # case of this function, this is simply a copy of sequence A:
+        # REF: ASDFE
+        # A:   ASDFE
+        # B:   ASDF-
+        # C:   -SDFE
+
+        # do pep_groups
+        tmp = list()
+        for a in pep_groups:
+            new_a = seq.CreateAlignment()
+            new_a.AddSequence(a.GetSequence(0))
+            for s_idx in range(a.GetCount()):
+                new_a.AddSequence(a.GetSequence(s_idx))
+            tmp.append(new_a)
+        pep_groups = tmp
+
+        # do nuc groups        
+        tmp = list()
+        for a in nuc_groups:
+            new_a = seq.CreateAlignment()
+            new_a.AddSequence(a.GetSequence(0))
+            for s_idx in range(a.GetCount()):
+                new_a.AddSequence(a.GetSequence(s_idx))
+            tmp.append(new_a)
+        nuc_groups = tmp
+
         group_types = [mol.ChemType.AMINOACIDS] * len(pep_groups)
         group_types += [mol.ChemType.NUCLEOTIDES] * len(nuc_groups)
         groups = pep_groups
         groups.extend(nuc_groups)
-        self._chem_group_alignments = groups
-        self._chem_group_types = group_types
+
+        return (groups, group_types)
 
     def _GroupSequences(self, seqs, seqid_thr, min_length, chem_type):
         """Get list of alignments representing groups of equivalent sequences
@@ -1898,13 +1955,21 @@ def _GetRefMdlAlns(ref_chem_groups, ref_chem_group_msas, mdl_chem_groups,
                 # obtain alignments of mdl and ref chains towards chem
                 # group ref sequence and merge them
                 aln_list = seq.AlignmentList()
+                
                 # do ref aln
+                ############
+                # reference sequence
                 s1 = ref_aln.GetSequence(0)
-                s2 = ref_aln.GetSequence(ref_chains.index(ref_ch))
+                # ATOMSEQ of ref_ch
+                s2 = ref_aln.GetSequence(1+ref_chains.index(ref_ch))
                 aln_list.append(seq.CreateAlignment(s1, s2))
+
                 # do mdl aln
+                ############
                 aln_list.append(mdl_alns[mdl_ch])
+
                 # merge
+                #######
                 ref_seq = seq.CreateSequence(s1.GetName(),
                                              s1.GetGaplessString())
                 merged_aln = seq.alg.MergePairwiseAlignments(aln_list,
@@ -3001,7 +3066,9 @@ def _GetRefPos(trg, mdl, trg_msas, mdl_alns, max_pos = None):
         # extract positions
         trg_pos.append(list())
         mdl_pos.append(list())
-        for s_idx in range(trg_msa.GetCount()):
+        # first seq in trg_msa is ref sequence and does not refer to any
+        # ATOMSEQ
+        for s_idx in range(1, trg_msa.GetCount()):
             trg_pos[-1].append(_ExtractMSAPos(trg_msa, s_idx, trg_indices,
                                               bb_trg))
         # first seq in mdl_msa is ref sequence in trg and does not belong to mdl
@@ -3056,13 +3123,15 @@ def _ExtractMSAPos(msa, s_idx, indices, view):
     Indices refers to column indices in msa!
     """
     s = msa.GetSequence(s_idx)
-    s_v = _CSel(view, [s.GetName()])
+    ch = view.FindChain(s.GetName())
 
     # sanity check
-    assert(len(s.GetGaplessString()) == len(s_v.residues))
+    assert(len(s.GetGaplessString()) == ch.GetResidueCount())
+
+    residues = ch.residues
 
     residue_idx = [s.GetResidueIndex(i) for i in indices]
-    return geom.Vec3List([s_v.residues[i].atoms[0].pos for i in residue_idx])
+    return geom.Vec3List([residues[i].atoms[0].pos for i in residue_idx])
 
 def _NChemGroupMappings(ref_chains, mdl_chains):
     """ Number of mappings within one chem group
diff --git a/modules/mol/alg/tests/test_chain_mapping.py b/modules/mol/alg/tests/test_chain_mapping.py
index 3183ed7c8..1529efa3f 100644
--- a/modules/mol/alg/tests/test_chain_mapping.py
+++ b/modules/mol/alg/tests/test_chain_mapping.py
@@ -83,16 +83,16 @@ class TestChainMapper(unittest.TestCase):
 
     # check chem_group_alignments attribute
     self.assertEqual(len(mapper.chem_group_alignments), 3)
-    self.assertEqual(mapper.chem_group_alignments[0].GetCount(), 2)
-    self.assertEqual(mapper.chem_group_alignments[1].GetCount(), 1)
-    self.assertEqual(mapper.chem_group_alignments[2].GetCount(), 1)
-    s0 = mapper.chem_group_alignments[0].GetSequence(0)
-    s1 = mapper.chem_group_alignments[0].GetSequence(1)
+    self.assertEqual(mapper.chem_group_alignments[0].GetCount(), 3)
+    self.assertEqual(mapper.chem_group_alignments[1].GetCount(), 2)
+    self.assertEqual(mapper.chem_group_alignments[2].GetCount(), 2)
+    s0 = mapper.chem_group_alignments[0].GetSequence(1)
+    s1 = mapper.chem_group_alignments[0].GetSequence(2)
     self.assertEqual(s0.GetGaplessString(), str(pep_s_one))
     self.assertEqual(s1.GetGaplessString(), str(pep_s_two))
-    s0 = mapper.chem_group_alignments[1].GetSequence(0)
+    s0 = mapper.chem_group_alignments[1].GetSequence(1)
     self.assertEqual(s0.GetGaplessString(), str(nuc_s_one))
-    s0 = mapper.chem_group_alignments[2].GetSequence(0)
+    s0 = mapper.chem_group_alignments[2].GetSequence(1)
     self.assertEqual(s0.GetGaplessString(), str(nuc_s_two))
 
     # ensure that error is triggered if there are insertion codes
-- 
GitLab