Skip to content
Snippets Groups Projects
Select Git revision
  • 9ddd09a836c220ab6b39da7e2069990a5f0c1227
  • master default protected
  • develop protected
  • cmake_boost_refactor
  • ubuntu_ci
  • mmtf
  • non-orthogonal-maps
  • no_boost_filesystem
  • data_viewer
  • 2.11.1
  • 2.11.0
  • 2.10.0
  • 2.9.3
  • 2.9.2
  • 2.9.1
  • 2.9.0
  • 2.8.0
  • 2.7.0
  • 2.6.1
  • 2.6.0
  • 2.6.0-rc4
  • 2.6.0-rc3
  • 2.6.0-rc2
  • 2.6.0-rc
  • 2.5.0
  • 2.5.0-rc2
  • 2.5.0-rc
  • 2.4.0
  • 2.4.0-rc2
29 results

qsscore.py

Blame
  • Gabriel Studer's avatar
    Studer Gabriel authored
    QS_global is the "original" QS-score, whereas QS_best only considers
    residues that are actually mapped, i.e. it does not penalize a structure
    for having an additional chain or more residues in one of the chains.
    f9192f20
    History
    qsscore.py 23.50 KiB
    import itertools
    import numpy as np
    from scipy.spatial import distance
    
    import time
    from ost import mol
    
    class QSEntity:
        """ Structure wrapper used for QS-score computation

        Reduces the input to representative atoms (CB for peptide residues,
        CA for GLY, C3' for nucleotides) and provides lazily evaluated,
        cached getters for chain names, sequences, positions, pairwise
        distances and interacting chain pairs, i.e. interfaces.

        :param ent: Structure for QS score computation
        :type ent: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
        :param contact_d: Pairwise distance of residues to be considered as contacts
        :type contact_d: :class:`float`
        """
        def __init__(self, ent, contact_d = 12.0):
            peptide_sel = "(peptide=true and (aname=\"CB\" or (rname=\"GLY\" and aname=\"CA\")))"
            nucleotide_sel = "(nucleotide=True and aname=\"C3'\")"
            self._view = ent.Select(peptide_sel + " or " + nucleotide_sel)
            self._contact_d = contact_d

            # caches - filled on first access by the respective getters
            self._chain_names = None
            self._interacting_chains = None
            self._sequence = {}
            self._pos = {}
            self._pair_dist = {}

        @property
        def view(self):
            """ Processed structure

            Result of the representative atom selection performed at
            construction: CB for peptide residues (CA for GLY) and C3' for
            nucleotides.

            :type: :class:`ost.mol.EntityView`
            """
            return self._view

        @property
        def contact_d(self):
            """ Distance threshold below which two residues are in contact

            Value passed at :class:`QSEntity` construction

            :type: :class:`float`
            """
            return self._contact_d

        @property
        def chain_names(self):
            """ Sorted names of all chains in :attr:`~view`

            Computed on first access and cached.

            :type: :class:`list` of :class:`str`
            """
            if self._chain_names is not None:
                return self._chain_names
            self._chain_names = sorted(ch.name for ch in self.view.chains)
            return self._chain_names

        @property
        def interacting_chains(self):
            """ Chain pairs in :attr:`~view` that have at least one contact

            A pair is reported if any pairwise distance between the two chains
            is below :attr:`~contact_d`. Pairs are built from
            :attr:`~chain_names`, computed on first access and cached.

            :type: :class:`list` of :class:`tuple`
            """
            if self._interacting_chains is not None:
                return self._interacting_chains
            interfaces = list()
            self._interacting_chains = interfaces
            for ch_a, ch_b in itertools.combinations(self.chain_names, 2):
                in_contact = self.PairDist(ch_a, ch_b) < self.contact_d
                if np.count_nonzero(in_contact):
                    interfaces.append((ch_a, ch_b))
            return interfaces

        def GetChain(self, chain_name):
            """ Fetch chain from :attr:`~view` by name

            :param chain_name: Chain in :attr:`~view`
            :type chain_name: :class:`str`
            :raises: :class:`RuntimeError` if :attr:`~view` has no such chain
            """
            found = self.view.FindChain(chain_name)
            if found.IsValid():
                return found
            raise RuntimeError(f"view has no chain named \"{chain_name}\"")

        def GetSequence(self, chain_name):
            """ Get one letter code sequence of chain

            Concatenates the one letter codes of all residues of the chain
            into a raw :class:`str`. The result is cached per chain.

            :param chain_name: Chain in :attr:`~view`
            :type chain_name: :class:`str`
            """
            seq = self._sequence.get(chain_name)
            if seq is None:
                residues = self.GetChain(chain_name).residues
                seq = ''.join(res.one_letter_code for res in residues)
                self._sequence[chain_name] = seq
            return seq

        def GetPos(self, chain_name):
            """ Get representative positions of chain

            Takes the position of the first atom of every residue in
            :attr:`~view`. Returns a numpy array of shape (n_residues, 3)
            which is cached per chain.

            :param chain_name: Chain in :attr:`~view`
            :type chain_name: :class:`str`
            """
            coords = self._pos.get(chain_name)
            if coords is None:
                residues = self.GetChain(chain_name).residues
                coords = np.zeros((len(residues), 3))
                for row, res in enumerate(residues):
                    coords[row, :] = res.atoms[0].GetPos().data
                self._pos[chain_name] = coords
            return coords

        def PairDist(self, chain_name_one, chain_name_two):
            """ Get pairwise euclidean distances between two chains

            Returns distances as numpy array of shape (a, b), where a is the
            number of residues of the chain whose name sorts first and b the
            number of residues of the other one - independent of the order in
            which the two names are passed. The result is cached.

            :param chain_name_one: Chain in :attr:`~view`
            :type chain_name_one: :class:`str`
            :param chain_name_two: Chain in :attr:`~view`
            :type chain_name_two: :class:`str`
            """
            first, second = sorted((chain_name_one, chain_name_two))
            cache_key = (first, second)
            if cache_key not in self._pair_dist:
                pos_first = self.GetPos(first)
                pos_second = self.GetPos(second)
                self._pair_dist[cache_key] = distance.cdist(pos_first,
                                                            pos_second,
                                                            'euclidean')
            return self._pair_dist[cache_key]
    
    class QSScorerResult:
        """
        Container for the four sums from which QS-scores are derived:

        ::

          - QS_best = weighted_scores / (weight_sum + weight_extra_mapped)
          - QS_global = weighted_scores / (weight_sum + weight_extra_all)
          -> weighted_scores = sum(w(min(d1,d2)) * (1 - abs(d1-d2)/12)) for shared
          -> weight_sum = sum(w(min(d1,d2))) for shared
          -> weight_extra_mapped = sum(w(d)) for all mapped but non-shared
          -> weight_extra_all = sum(w(d)) for all non-shared
          -> w(d) = 1 if d <= 5, exp(-2 * ((d-5.0)/4.28)^2) else

        Terms used above:

        * "d": distance of an "inter-chain contact" ("d1", "d2" for "shared"
          contacts, one distance per structure).
        * "mapped": residues of mapped chains that are aligned to each other.
        * "shared": pairs of residues which are "mapped" and have an
          "inter-chain contact" in both structures.
        * "inter-chain contact": pair of residues from two different chains
          that are within the contact distance of each other.
        * "w(d)": weighting function (prob. of 2 res. to interact given CB distance)
          from `Xu et al. 2009 <https://dx.doi.org/10.1016%2Fj.jmb.2008.06.002>`_.

        :param weighted_scores: see formulas above
        :type weighted_scores: :class:`float`
        :param weight_sum: see formulas above
        :type weight_sum: :class:`float`
        :param weight_extra_mapped: see formulas above
        :type weight_extra_mapped: :class:`float`
        :param weight_extra_all: see formulas above
        :type weight_extra_all: :class:`float`
        """
        def __init__(self, weighted_scores, weight_sum, weight_extra_mapped,
                     weight_extra_all):
            self._weighted_scores = weighted_scores
            self._weight_sum = weight_sum
            self._weight_extra_mapped = weight_extra_mapped
            self._weight_extra_all = weight_extra_all

        @property
        def weighted_scores(self):
            """ Numerator of both scores, see formulas in class documentation

            :type: :class:`float`
            """
            return self._weighted_scores

        @property
        def weight_sum(self):
            """ Summed weights of shared contacts, see formulas in class
            documentation

            :type: :class:`float`
            """
            return self._weight_sum

        @property
        def weight_extra_mapped(self):
            """ Summed weights of mapped but non-shared contacts, see formulas
            in class documentation

            :type: :class:`float`
            """
            return self._weight_extra_mapped

        @property
        def weight_extra_all(self):
            """ Summed weights of all non-shared contacts, see formulas in
            class documentation

            :type: :class:`float`
            """
            return self._weight_extra_all

        @property
        def QS_best(self):
            """ QS_best as defined in the formulas in class documentation

            Evaluates to 0.0 if the denominator is zero.

            :type: :class:`float`
            """
            numerator = self.weighted_scores
            denominator = self.weight_sum + self.weight_extra_mapped
            return numerator/denominator if denominator != 0.0 else 0.0

        @property
        def QS_global(self):
            """ QS_global as defined in the formulas in class documentation

            Evaluates to 0.0 if the denominator is zero.

            :type: :class:`float`
            """
            numerator = self.weighted_scores
            denominator = self.weight_sum + self.weight_extra_all
            return numerator/denominator if denominator != 0.0 else 0.0
    
    
    class QSScorer:
        """ Helper object to compute QS-score

        Tightly integrated into the mechanisms from the chain_mapping module.
        The prefered way to derive an object of type :class:`QSScorer` is through
        the static constructor: :func:`~FromMappingResult`. Example score
        computation including mapping:

        ::

            from ost.mol.alg.qsscore import QSScorer
            from ost.mol.alg.chain_mapping import ChainMapper

            ent_1 = io.LoadPDB("path_to_assembly_1.pdb")
            ent_2 = io.LoadPDB("path_to_assembly_2.pdb")

            chain_mapper = ChainMapper(ent_1)
            mapping_result = chain_mapper.GetlDDTMapping(ent_2)
            qs_scorer = QSScorer.FromMappingResult(mapping_result)
            score_result = qs_scorer.Score(mapping_result.mapping)
            print("score:", score_result.QS_global)

        QS-score computation in :func:`QSScorer.Score` implements caching.
        Repeated computations with alternative chain mappings thus become faster.

        :param target: Structure designated as "target". Can be fetched from
                       :class:`ost.mol.alg.chain_mapping.MappingResult`
        :type target: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
        :param chem_groups: Groups of chemically equivalent chains in *target*.
                            Can be fetched from
                            :class:`ost.mol.alg.chain_mapping.MappingResult`
        :type chem_groups: :class:`list` of :class:`list` of :class:`str`
        :param model: Structure designated as "model". Can be fetched from
                      :class:`ost.mol.alg.chain_mapping.MappingResult`
        :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
        :param alns: Each alignment is accessible with ``alns[(t_chain,m_chain)]``.
                     First sequence is the sequence of the respective chain in
                     :attr:`~qsent1`, second sequence the one from :attr:`~qsent2`.
                     Can be fetched from
                     :class:`ost.mol.alg.chain_mapping.MappingResult`
        :type alns: :class:`dict` with key: :class:`tuple` of :class:`str`, value:
                    :class:`ost.seq.AlignmentHandle`
        :param contact_d: Pairwise distance of residues to be considered as
                          contacts. Used for both, *target* and *model*.
        :type contact_d: :class:`float`
        :raises: :class:`RuntimeError` if the chain names in *chem_groups* do
                 not exactly match the chain names of the processed *target*
        """
        def __init__(self, target, chem_groups, model, alns, contact_d = 12.0):

            self._qsent1 = QSEntity(target, contact_d = contact_d)

            # ensure that target chain names match the ones in chem_groups
            chem_group_ch_names = list(itertools.chain.from_iterable(chem_groups))
            if self._qsent1.chain_names != sorted(chem_group_ch_names):
                raise RuntimeError(f"Expect exact same chain names in chem_groups "
                                   f"and in target (which is processed to only "
                                   f"contain peptides/nucleotides). target: "
                                   f"{self._qsent1.chain_names}, chem_groups: "
                                   f"{chem_group_ch_names}")

            self._chem_groups = chem_groups
            self._qsent2 = QSEntity(model, contact_d = contact_d)
            self._alns = alns

            # cache for mapped interface scores
            # key: tuple of tuple ((qsent1_ch1, qsent1_ch2),
            #                     ((qsent2_ch1, qsent2_ch2))
            # value: tuple with four numbers referring to QS-score formalism
            #        1: weighted_scores
            #        2: weight_sum
            #        3: weight_extra_mapped
            #        4: weight_extra_all
            self._mapped_cache = dict()

            # cache for non-mapped interfaces in qsent1
            # key: tuple (qsent1_ch1, qsent1_ch2)
            # value: contribution of that interface to weight_extra_all
            self._qsent_1_penalties = dict()

            # same for qsent2
            self._qsent_2_penalties = dict()

        @staticmethod
        def FromMappingResult(mapping_result, contact_d = 12.0):
            """ The preferred way to get a :class:`QSScorer`

            Static constructor that derives an object of type :class:`QSScorer`
            using a :class:`ost.mol.alg.chain_mapping.MappingResult`

            :param mapping_result: Data source
            :type mapping_result: :class:`ost.mol.alg.chain_mapping.MappingResult`
            :param contact_d: Pairwise distance of residues to be considered as
                              contacts, forwarded to the :class:`QSScorer`
                              constructor
            :type contact_d: :class:`float`
            """
            qs_scorer = QSScorer(mapping_result.target, mapping_result.chem_groups,
                                 mapping_result.model, alns = mapping_result.alns,
                                 contact_d = contact_d)
            return qs_scorer

        @property
        def qsent1(self):
            """ Represents *target*

            :type: :class:`QSEntity`
            """
            return self._qsent1

        @property
        def chem_groups(self):
            """ Groups of chemically equivalent chains in *target*

            Provided at object construction

            :type: :class:`list` of :class:`list` of :class:`str`
            """
            return self._chem_groups

        @property
        def qsent2(self):
            """ Represents *model*

            :type: :class:`QSEntity`
            """
            return self._qsent2

        @property
        def alns(self):
            """ Alignments between chains in :attr:`~qsent1` and :attr:`~qsent2`

            Provided at object construction. Each alignment is accessible with
            ``alns[(t_chain,m_chain)]``. First sequence is the sequence of the
            respective chain in :attr:`~qsent1`, second sequence the one from
            :attr:`~qsent2`.

            :type: :class:`dict` with key: :class:`tuple` of :class:`str`, value:
                   :class:`ost.seq.AlignmentHandle`
            """
            return self._alns

        def Score(self, mapping, check=True):
            """ Computes QS-score given chain mapping

            Again, the preferred way is to get *mapping* is from an object
            of type :class:`ost.mol.alg.chain_mapping.MappingResult`.

            :param mapping: see
                            :attr:`ost.mol.alg.chain_mapping.MappingResult.mapping`
            :type mapping: :class:`list` of :class:`list` of :class:`str`
            :param check: Perform input checks, can be disabled for speed purposes
                          if you know what you're doing.
            :type check: :class:`bool`
            :returns: Result object of type :class:`QSScorerResult`
            :raises: :class:`RuntimeError` if *check* is enabled and *mapping*
                     does not have the same dimensions as :attr:`~chem_groups`
                     or contains a chain name that is not in :attr:`~qsent2`
            """

            if check:
                # ensure that dimensionality of mapping matches self.chem_groups
                if len(self.chem_groups) != len(mapping):
                    raise RuntimeError("Dimensions of self.chem_groups and mapping "
                                       "must match")
                for a,b in zip(self.chem_groups, mapping):
                    if len(a) != len(b):
                        raise RuntimeError("Dimensions of self.chem_groups and "
                                           "mapping must match")
                # ensure that chain names in mapping are all present in qsent2
                # (set is built once to avoid a list scan per name)
                qsent2_chain_names = set(self.qsent2.chain_names)
                for name in itertools.chain.from_iterable(mapping):
                    if name is not None and name not in qsent2_chain_names:
                        raise RuntimeError(f"Each chain in mapping must be present "
                                           f"in self.qsent2. No match for "
                                           f"\"{name}\"")

            # None entries in mapping denote target chains without model
            # counterpart, they are left out of the flat mapping
            flat_mapping = dict()
            for a, b in zip(self.chem_groups, mapping):
                flat_mapping.update({x: y for x, y in zip(a, b) if y is not None})

            return self.FromFlatMapping(flat_mapping)

        def FromFlatMapping(self, flat_mapping):
            """ Same as :func:`Score` but with flat mapping

            :param flat_mapping: Dictionary with target chain names as keys and
                                 the mapped model chain names as value
            :type flat_mapping: :class:`dict` with :class:`str` as key and value
            :returns: Result object of type :class:`QSScorerResult`
            """

            weighted_scores = 0.0
            weight_sum = 0.0
            weight_extra_mapped = 0.0
            weight_extra_all = 0.0

            # keep track of processed interfaces in qsent2
            processed_qsent2_interfaces = set()

            for int1 in self.qsent1.interacting_chains:
                if int1[0] in flat_mapping and int1[1] in flat_mapping:
                    int2 = (flat_mapping[int1[0]], flat_mapping[int1[1]])
                    a, b, c, d = self._MappedInterfaceScores(int1, int2)
                    weighted_scores += a
                    weight_sum += b
                    weight_extra_mapped += c
                    weight_extra_all += d
                    # stored with sorted names to match the tuples in
                    # self.qsent2.interacting_chains
                    processed_qsent2_interfaces.add((min(int2[0], int2[1]),
                                                     max(int2[0], int2[1])))
                else:
                    # at least one chain is unmapped: the full interface
                    # only contributes to weight_extra_all
                    weight_extra_all += self._InterfacePenalty1(int1)

            # process interfaces that only exist in qsent2
            r_flat_mapping = {v:k for k,v in flat_mapping.items()} # reverse mapping...
            for int2 in self.qsent2.interacting_chains:
                if int2 not in processed_qsent2_interfaces:
                    if int2[0] in r_flat_mapping and int2[1] in r_flat_mapping:
                        int1 = (r_flat_mapping[int2[0]], r_flat_mapping[int2[1]])
                        a, b, c, d = self._MappedInterfaceScores(int1, int2)
                        weighted_scores += a
                        weight_sum += b
                        weight_extra_mapped += c
                        weight_extra_all += d
                    else:
                        weight_extra_all += self._InterfacePenalty2(int2)

            return QSScorerResult(weighted_scores, weight_sum, weight_extra_mapped,
                                  weight_extra_all)

        def _MappedInterfaceScores(self, int1, int2):
            """ Cached version of :func:`_InterfaceScores`

            Swapping the chains in both interfaces describes the same mapped
            interface, the cache is therefore queried with both key variants.

            :param int1: Chain name pair in :attr:`~qsent1`
            :param int2: Chain name pair in :attr:`~qsent2`
            :returns: :class:`tuple` (weighted_scores, weight_sum,
                      weight_extra_mapped, weight_extra_all)
            """
            key_one = (int1, int2)
            if key_one in self._mapped_cache:
                return self._mapped_cache[key_one]
            key_two = ((int1[1], int1[0]), (int2[1], int2[0]))
            if key_two in self._mapped_cache:
                return self._mapped_cache[key_two]

            scores = self._InterfaceScores(int1, int2)
            self._mapped_cache[key_one] = scores
            return scores

        def _InterfaceScores(self, int1, int2):
            """ Computes the four QS-score sums for a pair of mapped interfaces

            :param int1: Chain name pair in :attr:`~qsent1`
            :param int2: Chain name pair in :attr:`~qsent2`, int2[i] being the
                         chain mapped to int1[i]
            :returns: :class:`tuple` (weighted_scores, weight_sum,
                      weight_extra_mapped, weight_extra_all)
            """

            d1 = self.qsent1.PairDist(int1[0], int1[1])
            d2 = self.qsent2.PairDist(int2[0], int2[1])

            # given two chain names a and b: PairDist always returns shape
            # (len(a), len(b)) if a < b. However, if a > b, its
            # (len(b), len(a)) => transpose to get int[0] on the first axis
            if int1[0] > int1[1]:
                d1 = d1.transpose()
            if int2[0] > int2[1]:
                d2 = d2.transpose()

            # indices of the first chain in the two interfaces
            mapped_indices_1_1, mapped_indices_1_2 = \
            self._IndexMapping(int1[0], int2[0])
            # indices of the second chain in the two interfaces
            mapped_indices_2_1, mapped_indices_2_2 = \
            self._IndexMapping(int1[1], int2[1])

            # get shared_masks - for this we first need to select the actual
            # mapped positions to get a one-to-one relationship and map it back
            # to the original mask size
            assert(self.qsent1.contact_d == self.qsent2.contact_d)
            contact_d = self.qsent1.contact_d
            mapped_idx_grid_1 = np.ix_(mapped_indices_1_1, mapped_indices_2_1)
            mapped_idx_grid_2 = np.ix_(mapped_indices_1_2, mapped_indices_2_2)
            mapped_d1_contacts = d1[mapped_idx_grid_1] < contact_d
            mapped_d2_contacts = d2[mapped_idx_grid_2] < contact_d
            assert(mapped_d1_contacts.shape == mapped_d2_contacts.shape)
            shared_mask = np.logical_and(mapped_d1_contacts, mapped_d2_contacts)
            shared_mask_d1 = np.full(d1.shape, False, dtype=bool)
            shared_mask_d1[mapped_idx_grid_1] = shared_mask
            shared_mask_d2 = np.full(d2.shape, False, dtype=bool)
            shared_mask_d2[mapped_idx_grid_2] = shared_mask

            # get mapped but nonshared masks
            mapped_nonshared_mask_d1 = np.full(d1.shape, False, dtype=bool)
            mapped_nonshared_mask_d1[mapped_idx_grid_1] = \
            np.logical_and(np.logical_not(shared_mask), mapped_d1_contacts)
            mapped_nonshared_mask_d2 = np.full(d2.shape, False, dtype=bool)
            mapped_nonshared_mask_d2[mapped_idx_grid_2] = \
            np.logical_and(np.logical_not(shared_mask), mapped_d2_contacts)

            # contributions from shared contacts
            shared_d1 = d1[shared_mask_d1]
            shared_d2 = d2[shared_mask_d2]
            shared_min = np.minimum(shared_d1, shared_d2)
            # NOTE(review): normalization uses the fixed 12.0 from the
            # QS-score formula and does not follow a custom contact_d
            shared_abs_diff_div_12 = np.abs(np.subtract(shared_d1, shared_d2))/12.0
            weight_term = np.ones(shared_min.shape[0])
            bigger_5_mask = shared_min > 5.0
            weights = np.exp(-2.0*np.square((shared_min[bigger_5_mask]-5.0)/4.28))
            weight_term[bigger_5_mask] = weights
            diff_term = np.subtract(np.ones(weight_term.shape[0]),
                                    shared_abs_diff_div_12)
            weighted_scores = np.sum(np.multiply(weight_term, diff_term))
            weight_sum = np.sum(weight_term)

            # weight_extra_all: all contacts that are not shared, first for
            # interface one, interface two is added on top
            nonshared_contact_mask_d1 = np.logical_and(np.logical_not(shared_mask_d1),
                                                       d1 < contact_d)
            nonshared_contact_mask_d2 = np.logical_and(np.logical_not(shared_mask_d2),
                                                       d2 < contact_d)
            weight_extra_all = \
            self._SumContactWeights(d1[nonshared_contact_mask_d1])
            weight_extra_all = \
            self._SumContactWeights(d2[nonshared_contact_mask_d2],
                                    weight_extra_all)

            # weight_extra_mapped: same but restricted to mapped residues
            weight_extra_mapped = \
            self._SumContactWeights(d1[mapped_nonshared_mask_d1])
            weight_extra_mapped = \
            self._SumContactWeights(d2[mapped_nonshared_mask_d2],
                                    weight_extra_mapped)

            return (weighted_scores, weight_sum, weight_extra_mapped, weight_extra_all)

        @staticmethod
        def _SumContactWeights(contact_distances, initial = 0.0):
            """ Sums up w(d) for all distances and adds it to *initial*

            w(d) = 1 if d <= 5, exp(-2 * ((d-5.0)/4.28)^2) else

            :param contact_distances: 1D numpy array of contact distances
            :param initial: Value to which the weights are added
            :returns: *initial* + sum of weights
            """
            bigger_5 = contact_distances[contact_distances > 5]
            total = initial + np.sum(np.exp(-2.0*np.square((bigger_5-5.0)/4.28)))
            # add 1.0 for all contact distances <= 5.0
            total += contact_distances.shape[0] - bigger_5.shape[0]
            return total

        def _IndexMapping(self, ch1, ch2):
            """ Fetches aln and returns indices of aligned residues

            Columns with a gap ('-') in either sequence are skipped.

            :param ch1: Chain name in :attr:`~qsent1`
            :param ch2: Chain name in :attr:`~qsent2`
            :returns: 2 integer numpy arrays of equal length containing the
                      indices of residues in ch1 and ch2 which are aligned
            """
            mapped_indices_1 = list()
            mapped_indices_2 = list()
            idx_1 = 0
            idx_2 = 0
            for col in self.alns[(ch1, ch2)]:
                if col[0] != '-' and col[1] != '-':
                    mapped_indices_1.append(idx_1)
                    mapped_indices_2.append(idx_2)
                if col[0] != '-':
                    idx_1 +=1
                if col[1] != '-':
                    idx_2 +=1
            # explicit dtype: an empty list would otherwise become a float
            # array which is no valid index array
            return (np.array(mapped_indices_1, dtype=int),
                    np.array(mapped_indices_2, dtype=int))

        def _InterfacePenalty1(self, interface):
            """ Cached :func:`_InterfacePenalty` for interface in qsent1
            """
            if interface not in self._qsent_1_penalties:
                self._qsent_1_penalties[interface] = \
                self._InterfacePenalty(self.qsent1, interface)
            return self._qsent_1_penalties[interface]

        def _InterfacePenalty2(self, interface):
            """ Cached :func:`_InterfacePenalty` for interface in qsent2
            """
            if interface not in self._qsent_2_penalties:
                self._qsent_2_penalties[interface] = \
                self._InterfacePenalty(self.qsent2, interface)
            return self._qsent_2_penalties[interface]

        def _InterfacePenalty(self, qsent, interface):
            """ Summed weights of all contacts in a non-mapped interface

            :param qsent: :class:`QSEntity` containing the interface
            :param interface: Chain name pair in *qsent*
            :returns: Contribution of *interface* to weight_extra_all
            """
            d = qsent.PairDist(interface[0], interface[1])
            return self._SumContactWeights(d[d < qsent.contact_d])
    
    # specify public interface
    # NOTE(review): QSScorerResult (the type returned by QSScorer.Score and
    # QSScorer.FromFlatMapping) is not listed, so star-imports do not expose
    # it - confirm that this is intended.
    __all__ = ('QSEntity', 'QSScorer')