Something went wrong on our end
-
Studer Gabriel authored
QS_global is the "original" QS-score, whereas QS_best only considers residues that are actually mapped, i.e. it does not penalize a structure for having an additional chain or more residues in one of its chains.
Studer Gabriel authored: QS_global is the "original" QS-score, whereas QS_best only considers residues that are actually mapped, i.e. it does not penalize a structure for having an additional chain or more residues in one of its chains.
qsscore.py 23.50 KiB
import itertools
import numpy as np
from scipy.spatial import distance
import time
from ost import mol
class QSEntity:
    """ Structure wrapper used for QS-score computation

    Reduces a structure to one representative atom per residue (CB for
    peptide residues, CA for GLY, C3' for nucleotides) and provides cached
    access to chain names, sequences, positions, pairwise distances and
    interacting chain pairs, i.e. interfaces.

    :param ent: Structure for QS score computation
    :type ent: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
    :param contact_d: Pairwise distance of residues to be considered as
                      contacts
    :type contact_d: :class:`float`
    """
    def __init__(self, ent, contact_d = 12.0):
        pep_query = "(peptide=true and (aname=\"CB\" or (rname=\"GLY\" and aname=\"CA\")))"
        nuc_query = "(nucleotide=True and aname=\"C3'\")"
        self._contact_d = contact_d
        self._view = ent.Select(pep_query + " or " + nuc_query)

        # caches, filled on first access
        self._pair_dist = dict()
        self._pos = dict()
        self._sequence = dict()
        self._interacting_chains = None
        self._chain_names = None

    @property
    def view(self):
        """ Selection holding only the representative atoms

        One atom per residue: CB for peptide residues (CA for GLY) and
        C3' for nucleotides.

        :type: :class:`ost.mol.EntityView`
        """
        return self._view

    @property
    def contact_d(self):
        """ Contact distance threshold as passed at construction

        :type: :class:`float`
        """
        return self._contact_d

    @property
    def chain_names(self):
        """ Sorted names of all chains in :attr:`~view`

        Computed on first access and cached afterwards.

        :type: :class:`list` of :class:`str`
        """
        if self._chain_names is None:
            names = [chain.name for chain in self.view.chains]
            names.sort()
            self._chain_names = names
        return self._chain_names

    @property
    def interacting_chains(self):
        """ Chain pairs in :attr:`~view` that have at least one contact

        A pair is reported if any entry of its pairwise distance matrix is
        below :attr:`~contact_d`. Pairs follow the order of
        :attr:`~chain_names`. Computed on first access and cached.

        :type: :class:`list` of :class:`tuples`
        """
        if self._interacting_chains is None:
            self._interacting_chains = []
            for first, second in itertools.combinations(self.chain_names, 2):
                within_cutoff = self.PairDist(first, second) < self.contact_d
                if np.count_nonzero(within_cutoff) > 0:
                    self._interacting_chains.append((first, second))
        return self._interacting_chains

    def GetChain(self, chain_name):
        """ Look up a chain of :attr:`~view` by its name

        :param chain_name: Chain in :attr:`~view`
        :type chain_name: :class:`str`
        :raises: :class:`RuntimeError` if :attr:`~view` has no valid chain
                 with that name
        """
        found = self.view.FindChain(chain_name)
        if found.IsValid():
            return found
        raise RuntimeError(f"view has no chain named \"{chain_name}\"")

    def GetSequence(self, chain_name):
        """ One letter code sequence of a chain as raw :class:`str`

        Built from the residues of the chain in :attr:`~view` and cached
        per chain name.

        :param chain_name: Chain in :attr:`~view`
        :type chain_name: :class:`str`
        """
        try:
            return self._sequence[chain_name]
        except KeyError:
            pass
        residues = self.GetChain(chain_name).residues
        olcs = "".join(res.one_letter_code for res in residues)
        self._sequence[chain_name] = olcs
        return olcs

    def GetPos(self, chain_name):
        """ Representative positions of a chain

        Uses the position of the first atom of every residue in
        :attr:`~view`, i.e. CB for peptide residues (CA for GLY) and C3'
        for nucleotides. Returned as numpy array of shape (n_residues, 3),
        cached per chain name.

        :param chain_name: Chain in :attr:`~view`
        :type chain_name: :class:`str`
        """
        if chain_name in self._pos:
            return self._pos[chain_name]
        residues = self.GetChain(chain_name).residues
        coords = np.zeros((len(residues), 3))
        for row, res in zip(coords, residues):
            # rows are views into coords => in-place assignment fills it
            row[:] = res.atoms[0].GetPos().data
        self._pos[chain_name] = coords
        return coords

    def PairDist(self, chain_name_one, chain_name_two):
        """ Pairwise euclidean distances between two chains

        Returns distances as numpy array of shape (a, b), where a is the
        number of residues of the chain that comes BEFORE the other in
        :attr:`~chain_names`, irrespective of the order of the arguments.
        Results are cached per chain pair.
        """
        if chain_name_one <= chain_name_two:
            key = (chain_name_one, chain_name_two)
        else:
            key = (chain_name_two, chain_name_one)
        if key not in self._pair_dist:
            first_pos = self.GetPos(key[0])
            second_pos = self.GetPos(key[1])
            self._pair_dist[key] = distance.cdist(first_pos, second_pos,
                                                  'euclidean')
        return self._pair_dist[key]
class QSScorerResult:
    """
    Container for the four sums a QS-score is built from. The scores are
    derived from them as:

    ::

      - QS_best = weighted_scores / (weight_sum + weight_extra_mapped)
      - QS_global = weighted_scores / (weight_sum + weight_extra_all)
      -> weighted_scores = sum(w(min(d1,d2)) * (1 - abs(d1-d2)/12)) for shared
      -> weight_sum = sum(w(min(d1,d2))) for shared
      -> weight_extra_mapped = sum(w(d)) for all mapped but non-shared
      -> weight_extra_all = sum(w(d)) for all non-shared
      -> w(d) = 1 if d <= 5, exp(-2 * ((d-5.0)/4.28)^2) else

    Terms used above:

    * "d": distance between the representative atoms of two residues
      forming an "inter-chain contact" ("d1", "d2" for "shared" contacts,
      one per structure).
    * "mapped": residues that are aligned to each other in chains that are
      mapped between the two structures.
    * "shared": pairs of residues which are "mapped" and have
      "inter-chain contact" in both structures.
    * "inter-chain contact": pair of residues from two different chains
      with a distance within the contact threshold.
    * "w(d)": weighting function (prob. of 2 res. to interact given CB
      distance) from
      `Xu et al. 2009 <https://dx.doi.org/10.1016%2Fj.jmb.2008.06.002>`_.

    Both scores are defined as 0.0 if their denominator is zero.
    """
    def __init__(self, weighted_scores, weight_sum, weight_extra_mapped,
                 weight_extra_all):
        self._weight_extra_all = weight_extra_all
        self._weight_extra_mapped = weight_extra_mapped
        self._weight_sum = weight_sum
        self._weighted_scores = weighted_scores

    @staticmethod
    def _SafeRatio(nominator, denominator):
        # a zero denominator is defined to give a score of 0.0
        if denominator == 0.0:
            return 0.0
        return nominator/denominator

    @property
    def weighted_scores(self):
        """ The weighted_scores term of the formulas in the class description

        :type: :class:`float`
        """
        return self._weighted_scores

    @property
    def weight_sum(self):
        """ The weight_sum term of the formulas in the class description

        :type: :class:`float`
        """
        return self._weight_sum

    @property
    def weight_extra_mapped(self):
        """ The weight_extra_mapped term of the formulas in the class
        description

        :type: :class:`float`
        """
        return self._weight_extra_mapped

    @property
    def weight_extra_all(self):
        """ The weight_extra_all term of the formulas in the class
        description

        :type: :class:`float`
        """
        return self._weight_extra_all

    @property
    def QS_best(self):
        """ QS_best - only penalizes non-shared contacts between mapped
        residues, 0.0 if the denominator is zero

        :type: :class:`float`
        """
        return self._SafeRatio(self.weighted_scores,
                               self.weight_sum + self.weight_extra_mapped)

    @property
    def QS_global(self):
        """ QS_global - penalizes all non-shared contacts, 0.0 if the
        denominator is zero

        :type: :class:`float`
        """
        return self._SafeRatio(self.weighted_scores,
                               self.weight_sum + self.weight_extra_all)
class QSScorer:
    """ Helper object to compute QS-score

    Tightly integrated into the mechanisms from the chain_mapping module.
    The preferred way to derive an object of type :class:`QSScorer` is
    through the static constructor: :func:`~FromMappingResult`. Example
    score computation including mapping:

    ::

      from ost import io
      from ost.mol.alg.qsscore import QSScorer
      from ost.mol.alg.chain_mapping import ChainMapper

      ent_1 = io.LoadPDB("path_to_assembly_1.pdb")
      ent_2 = io.LoadPDB("path_to_assembly_2.pdb")

      chain_mapper = ChainMapper(ent_1)
      mapping_result = chain_mapper.GetlDDTMapping(ent_2)
      qs_scorer = QSScorer.FromMappingResult(mapping_result)
      score_result = qs_scorer.Score(mapping_result.mapping)
      print("score:", score_result.QS_global)

    QS-score computation in :func:`QSScorer.Score` implements caching.
    Repeated computations with alternative chain mappings thus become
    faster.

    :param target: Structure designated as "target". Can be fetched from
                   :class:`ost.mol.alg.chain_mapping.MappingResult`
    :type target: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
    :param chem_groups: Groups of chemically equivalent chains in *target*.
                        Can be fetched from
                        :class:`ost.mol.alg.chain_mapping.MappingResult`
    :type chem_groups: :class:`list` of :class:`list` of :class:`str`
    :param model: Structure designated as "model". Can be fetched from
                  :class:`ost.mol.alg.chain_mapping.MappingResult`
    :type model: :class:`ost.mol.EntityView`/:class:`ost.mol.EntityHandle`
    :param alns: Each alignment is accessible with
                 ``alns[(t_chain,m_chain)]``. First sequence is the sequence
                 of the respective chain in :attr:`~qsent1`, second sequence
                 the one from :attr:`~qsent2`.
                 Can be fetched from
                 :class:`ost.mol.alg.chain_mapping.MappingResult`
    :type alns: :class:`dict` with key: :class:`tuple` of :class:`str`,
                value: :class:`ost.seq.AlignmentHandle`
    :param contact_d: Pairwise distance of residues to be considered as
                      contacts. Passed to the :class:`QSEntity` objects
                      built for *target* and *model*.
    :type contact_d: :class:`float`
    :raises: :class:`RuntimeError` if the chain names in *chem_groups* do
             not exactly match the chain names of the processed *target*
    """
    def __init__(self, target, chem_groups, model, alns, contact_d = 12.0):
        self._qsent1 = QSEntity(target, contact_d = contact_d)

        # ensure that target chain names match the ones in chem_groups
        chem_group_ch_names = list(itertools.chain.from_iterable(chem_groups))
        if self._qsent1.chain_names != sorted(chem_group_ch_names):
            raise RuntimeError(f"Expect exact same chain names in chem_groups "
                               f"and in target (which is processed to only "
                               f"contain peptides/nucleotides). target: "
                               f"{self._qsent1.chain_names}, chem_groups: "
                               f"{chem_group_ch_names}")

        self._chem_groups = chem_groups
        self._qsent2 = QSEntity(model, contact_d = contact_d)
        self._alns = alns

        # cache for mapped interface scores
        # key: tuple of tuple ((qsent1_ch1, qsent1_ch2),
        #                      ((qsent2_ch1, qsent2_ch2))
        # value: tuple with four numbers referring to QS-score formalism
        #        1: weighted_scores
        #        2: weight_sum
        #        3: weight_extra_mapped
        #        4: weight_extra_all
        self._mapped_cache = dict()

        # cache for non-mapped interfaces in qsent1
        # key: tuple (qsent1_ch1, qsent1_ch2)
        # value: contribution of that interface to weight_extra_all
        self._qsent_1_penalties = dict()

        # same for qsent2
        self._qsent_2_penalties = dict()

    @staticmethod
    def FromMappingResult(mapping_result, contact_d = 12.0):
        """ The preferred way to get a :class:`QSScorer`

        Static constructor that derives an object of type :class:`QSScorer`
        using a :class:`ost.mol.alg.chain_mapping.MappingResult`

        :param mapping_result: Data source
        :type mapping_result: :class:`ost.mol.alg.chain_mapping.MappingResult`
        :param contact_d: Forwarded to the :class:`QSScorer` constructor
        :type contact_d: :class:`float`
        """
        qs_scorer = QSScorer(mapping_result.target, mapping_result.chem_groups,
                             mapping_result.model, alns = mapping_result.alns,
                             contact_d = contact_d)
        return qs_scorer

    @property
    def qsent1(self):
        """ Represents *target*

        :type: :class:`QSEntity`
        """
        return self._qsent1

    @property
    def chem_groups(self):
        """ Groups of chemically equivalent chains in *target*

        Provided at object construction

        :type: :class:`list` of :class:`list` of :class:`str`
        """
        return self._chem_groups

    @property
    def qsent2(self):
        """ Represents *model*

        :type: :class:`QSEntity`
        """
        return self._qsent2

    @property
    def alns(self):
        """ Alignments between chains in :attr:`~qsent1` and :attr:`~qsent2`

        Provided at object construction. Each alignment is accessible with
        ``alns[(t_chain,m_chain)]``. First sequence is the sequence of the
        respective chain in :attr:`~qsent1`, second sequence the one from
        :attr:`~qsent2`.

        :type: :class:`dict` with key: :class:`tuple` of :class:`str`, value:
               :class:`ost.seq.AlignmentHandle`
        """
        return self._alns

    def Score(self, mapping, check=True):
        """ Computes QS-score given chain mapping

        Again, the preferred way to get *mapping* is from an object
        of type :class:`ost.mol.alg.chain_mapping.MappingResult`.

        :param mapping: see
                        :attr:`ost.mol.alg.chain_mapping.MappingResult.mapping`.
                        Must have the same dimensions as
                        :attr:`~chem_groups`, entries that are None are
                        treated as unmapped.
        :type mapping: :class:`list` of :class:`list` of :class:`str`
        :param check: Perform input checks, can be disabled for speed purposes
                      if you know what you're doing.
        :type check: :class:`bool`
        :returns: Result object of type :class:`QSScorerResult`
        :raises: :class:`RuntimeError` if *check* is True and the dimensions
                 of *mapping* do not match :attr:`~chem_groups` or *mapping*
                 contains a chain name that is not in :attr:`~qsent2`
        """
        if check:
            # ensure that dimensionality of mapping matches self.chem_groups
            if len(self.chem_groups) != len(mapping):
                raise RuntimeError("Dimensions of self.chem_groups and mapping "
                                   "must match")
            for a,b in zip(self.chem_groups, mapping):
                if len(a) != len(b):
                    raise RuntimeError("Dimensions of self.chem_groups and "
                                       "mapping must match")
            # ensure that chain names in mapping are all present in qsent2,
            # set is built once to avoid repeated list scans
            model_chain_names = set(self.qsent2.chain_names)
            for name in itertools.chain.from_iterable(mapping):
                if name is not None and name not in model_chain_names:
                    raise RuntimeError(f"Each chain in mapping must be present "
                                       f"in self.qsent2. No match for "
                                       f"\"{name}\"")

        # target chain name -> model chain name, unmapped chains are skipped
        flat_mapping = dict()
        for a, b in zip(self.chem_groups, mapping):
            flat_mapping.update({x: y for x, y in zip(a, b) if y is not None})

        return self.FromFlatMapping(flat_mapping)

    def FromFlatMapping(self, flat_mapping):
        """ Same as :func:`Score` but with flat mapping

        Interfaces for which both chains are mapped contribute to all four
        terms of the result. Interfaces with at least one unmapped chain
        only contribute to weight_extra_all.

        :param flat_mapping: Dictionary with target chain names as keys and
                             the mapped model chain names as value
        :type flat_mapping: :class:`dict` with :class:`str` as key and value
        :returns: Result object of type :class:`QSScorerResult`
        """
        weighted_scores = 0.0
        weight_sum = 0.0
        weight_extra_mapped = 0.0
        weight_extra_all = 0.0

        # keep track of processed interfaces in qsent2
        processed_qsent2_interfaces = set()

        for int1 in self.qsent1.interacting_chains:
            if int1[0] in flat_mapping and int1[1] in flat_mapping:
                int2 = (flat_mapping[int1[0]], flat_mapping[int1[1]])
                a, b, c, d = self._MappedInterfaceScores(int1, int2)
                weighted_scores += a
                weight_sum += b
                weight_extra_mapped += c
                weight_extra_all += d
                # store with sorted chain names, that's how interfaces are
                # listed in qsent2.interacting_chains
                processed_qsent2_interfaces.add((min(int2[0], int2[1]),
                                                 max(int2[0], int2[1])))
            else:
                weight_extra_all += self._InterfacePenalty1(int1)

        # process interfaces that only exist in qsent2
        r_flat_mapping = {v:k for k,v in flat_mapping.items()} # reverse mapping
        for int2 in self.qsent2.interacting_chains:
            if int2 in processed_qsent2_interfaces:
                continue
            if int2[0] in r_flat_mapping and int2[1] in r_flat_mapping:
                int1 = (r_flat_mapping[int2[0]], r_flat_mapping[int2[1]])
                a, b, c, d = self._MappedInterfaceScores(int1, int2)
                weighted_scores += a
                weight_sum += b
                weight_extra_mapped += c
                weight_extra_all += d
            else:
                weight_extra_all += self._InterfacePenalty2(int2)

        return QSScorerResult(weighted_scores, weight_sum, weight_extra_mapped,
                              weight_extra_all)

    def _MappedInterfaceScores(self, int1, int2):
        """ Cached version of :func:`_InterfaceScores`

        The cache is also queried with both interfaces flipped, as that
        describes the same pair of mapped interfaces.

        :returns: :class:`tuple` (weighted_scores, weight_sum,
                  weight_extra_mapped, weight_extra_all)
        """
        key_one = (int1, int2)
        if key_one in self._mapped_cache:
            return self._mapped_cache[key_one]
        key_two = ((int1[1], int1[0]), (int2[1], int2[0]))
        if key_two in self._mapped_cache:
            return self._mapped_cache[key_two]

        scores = self._InterfaceScores(int1, int2)
        self._mapped_cache[key_one] = scores
        return scores

    @staticmethod
    def _ContactWeights(distances):
        """ Evaluates weighting function w(d) for each distance

        w(d) = 1 if d <= 5, exp(-2 * ((d-5.0)/4.28)^2) else

        :param distances: One dimensional numpy array of distances
        :returns: numpy array of the same length as *distances*
        """
        weights = np.ones(distances.shape[0])
        bigger_5_mask = distances > 5.0
        weights[bigger_5_mask] = \
        np.exp(-2.0*np.square((distances[bigger_5_mask]-5.0)/4.28))
        return weights

    def _InterfaceScores(self, int1, int2):
        """ Computes QS-score terms for one pair of mapped interfaces

        :param int1: Interface in :attr:`~qsent1`, tuple of two chain names
        :param int2: Interface in :attr:`~qsent2`, int2[0] is the chain
                     mapped to int1[0] and int2[1] the one mapped to int1[1]
        :returns: :class:`tuple` (weighted_scores, weight_sum,
                  weight_extra_mapped, weight_extra_all)
        """
        d1 = self.qsent1.PairDist(int1[0], int1[1])
        d2 = self.qsent2.PairDist(int2[0], int2[1])

        # given two chain names a and b: PairDist returns distances with
        # shape (len(a), len(b)) if a < b. However, if a > b, it's
        # (len(b), len(a)) => transpose to get rows for the first chain
        if int1[0] > int1[1]:
            d1 = d1.transpose()
        if int2[0] > int2[1]:
            d2 = d2.transpose()

        # indices of the first chain in the two interfaces
        mapped_indices_1_1, mapped_indices_1_2 = \
        self._IndexMapping(int1[0], int2[0])
        # indices of the second chain in the two interfaces
        mapped_indices_2_1, mapped_indices_2_2 = \
        self._IndexMapping(int1[1], int2[1])

        # get shared_masks - for this we first need to select the actual
        # mapped positions to get a one-to-one relationship and map it back
        # to the original mask size
        assert(self.qsent1.contact_d == self.qsent2.contact_d)
        contact_d = self.qsent1.contact_d
        mapped_idx_grid_1 = np.ix_(mapped_indices_1_1, mapped_indices_2_1)
        mapped_idx_grid_2 = np.ix_(mapped_indices_1_2, mapped_indices_2_2)
        mapped_d1_contacts = d1[mapped_idx_grid_1] < contact_d
        mapped_d2_contacts = d2[mapped_idx_grid_2] < contact_d
        assert(mapped_d1_contacts.shape == mapped_d2_contacts.shape)
        shared_mask = np.logical_and(mapped_d1_contacts, mapped_d2_contacts)
        shared_mask_d1 = np.full(d1.shape, False, dtype=bool)
        shared_mask_d1[mapped_idx_grid_1] = shared_mask
        shared_mask_d2 = np.full(d2.shape, False, dtype=bool)
        shared_mask_d2[mapped_idx_grid_2] = shared_mask

        # get mapped but nonshared masks
        mapped_nonshared_mask_d1 = np.full(d1.shape, False, dtype=bool)
        mapped_nonshared_mask_d1[mapped_idx_grid_1] = \
        np.logical_and(np.logical_not(shared_mask), mapped_d1_contacts)
        mapped_nonshared_mask_d2 = np.full(d2.shape, False, dtype=bool)
        mapped_nonshared_mask_d2[mapped_idx_grid_2] = \
        np.logical_and(np.logical_not(shared_mask), mapped_d2_contacts)

        # contributions from shared contacts
        shared_d1 = d1[shared_mask_d1]
        shared_d2 = d2[shared_mask_d2]
        shared_min = np.minimum(shared_d1, shared_d2)
        shared_abs_diff_div_12 = np.abs(shared_d1 - shared_d2)/12.0
        weight_term = self._ContactWeights(shared_min)
        diff_term = 1.0 - shared_abs_diff_div_12
        weighted_scores = np.sum(weight_term * diff_term)
        weight_sum = np.sum(weight_term)

        # weight_extra_all: all contacts that are not shared, no matter
        # whether the involved residues are mapped or not
        nonshared_contact_mask_d1 = np.logical_and(np.logical_not(shared_mask_d1),
                                                   d1 < contact_d)
        nonshared_contact_mask_d2 = np.logical_and(np.logical_not(shared_mask_d2),
                                                   d2 < contact_d)
        weight_extra_all = \
        np.sum(self._ContactWeights(d1[nonshared_contact_mask_d1])) + \
        np.sum(self._ContactWeights(d2[nonshared_contact_mask_d2]))

        # weight_extra_mapped: non-shared contacts between mapped residues
        weight_extra_mapped = \
        np.sum(self._ContactWeights(d1[mapped_nonshared_mask_d1])) + \
        np.sum(self._ContactWeights(d2[mapped_nonshared_mask_d2]))

        return (weighted_scores, weight_sum, weight_extra_mapped,
                weight_extra_all)

    def _IndexMapping(self, ch1, ch2):
        """ Fetches aln and returns indices of aligned residues

        Iterates the columns of ``alns[(ch1, ch2)]``. A column contributes
        if neither of its two entries is a gap ('-').

        :returns: 2 numpy integer arrays containing the indices of residues
                  in ch1 and ch2 which are aligned
        """
        mapped_indices_1 = list()
        mapped_indices_2 = list()
        idx_1 = 0
        idx_2 = 0
        for col in self.alns[(ch1, ch2)]:
            if col[0] != '-' and col[1] != '-':
                mapped_indices_1.append(idx_1)
                mapped_indices_2.append(idx_2)
            if col[0] != '-':
                idx_1 +=1
            if col[1] != '-':
                idx_2 +=1
        # explicit dtype: an empty list would give a float array otherwise,
        # which is no valid index array
        return (np.array(mapped_indices_1, dtype=int),
                np.array(mapped_indices_2, dtype=int))

    def _InterfacePenalty1(self, interface):
        """ Cached :func:`_InterfacePenalty` for interface in qsent1
        """
        if interface not in self._qsent_1_penalties:
            self._qsent_1_penalties[interface] = \
            self._InterfacePenalty(self.qsent1, interface)
        return self._qsent_1_penalties[interface]

    def _InterfacePenalty2(self, interface):
        """ Cached :func:`_InterfacePenalty` for interface in qsent2
        """
        if interface not in self._qsent_2_penalties:
            self._qsent_2_penalties[interface] = \
            self._InterfacePenalty(self.qsent2, interface)
        return self._qsent_2_penalties[interface]

    def _InterfacePenalty(self, qsent, interface):
        """ Summed weights w(d) of all contacts in *interface* of *qsent*

        That's the contribution of an interface to weight_extra_all if it
        cannot be mapped.
        """
        d = qsent.PairDist(interface[0], interface[1])
        contact_distances = d[d < qsent.contact_d]
        return np.sum(self._ContactWeights(contact_distances))
# Public interface of this module, i.e. the names exported by a star import.
# NOTE(review): QSScorerResult is returned by QSScorer.Score and
# QSScorer.FromFlatMapping but is not listed here - confirm this is intended.
__all__ = ('QSEntity', 'QSScorer')