Skip to content
Snippets Groups Projects
Commit f0604436 authored by Studer Gabriel's avatar Studer Gabriel
Browse files

ligand scoring: add lddt pli to dissected scoring classes

parent 18da8627
Branches
Tags
No related merge requests found
......@@ -34,6 +34,7 @@ set(OST_MOL_ALG_PYMOD_MODULES
contact_score.py
ligand_scoring_base.py
ligand_scoring_scrmsd.py
ligand_scoring_lddtpli.py
)
if (NOT ENABLE_STATIC)
......
......@@ -57,10 +57,10 @@ class LigandScorer:
# lazily computed attributes
self._chain_mapper = None
# keep track of error states
# keep track of states
# simple integers instead of enums - documentation of property describes
# encoding
self._error_states = None
self._states = None
# score matrices
self._score_matrix = None
......@@ -68,16 +68,15 @@ class LigandScorer:
self._aux_data = None
@property
def error_states(self):
""" Encodes error states of ligand pairs
def states(self):
""" Encodes states of ligand pairs
Not only critical things, but also things like: a pair of ligands
simply doesn't match. Target ligands are in rows, model ligands in
columns. States are encoded as integers <= 9. Larger numbers encode
errors for child classes.
Expect a valid score if respective location in this matrix is 0.
Target ligands are in rows, model ligands in columns. States are encoded
as integers <= 9. Larger numbers encode errors for child classes.
* -1: Unknown Error - cannot be matched
* 0: Ligand pair has valid symmetries - can be matched.
* 0: Ligand pair can be matched and valid score is computed.
* 1: Ligand pair has no valid symmetry - cannot be matched.
* 2: Ligand pair has too many symmetries - cannot be matched.
You might be able to get a match by increasing *max_symmetries*.
......@@ -90,9 +89,9 @@ class LigandScorer:
:rtype: :class:`~numpy.ndarray`
"""
if self._error_states is None:
if self._states is None:
self._compute_scores()
return self._error_states
return self._states
@property
def score_matrix(self):
......@@ -102,7 +101,7 @@ class LigandScorer:
NaN values indicate that no value could be computed (i.e. different
ligands). In other words: values are only valid if respective location
:attr:`~error_states` is 0.
:attr:`~states` is 0.
:rtype: :class:`~numpy.ndarray`
"""
......@@ -118,7 +117,7 @@ class LigandScorer:
NaN values indicate that no value could be computed (i.e. different
ligands). In other words: values are only valid if respective location
:attr:`~error_states` is 0. If `substructure_match=False`, only full
:attr:`~states` is 0. If `substructure_match=False`, only full
match isomorphisms are considered, and therefore only values of 1.0
can be observed.
......@@ -138,7 +137,7 @@ class LigandScorer:
class to provide additional information for a scored ligand pair.
empty dictionaries indicate that no value could be computed
(i.e. different ligands). In other words: values are only valid if
respective location :attr:`~error_states` is 0.
respective location :attr:`~states` is 0.
:rtype: :class:`~numpy.ndarray`
"""
......@@ -301,7 +300,7 @@ class LigandScorer:
shape = (len(self.target_ligands), len(self.model_ligands))
self._score_matrix = np.full(shape, np.nan, dtype=np.float32)
self._coverage_matrix = np.full(shape, np.nan, dtype=np.float32)
self._error_states = np.full(shape, -1, dtype=np.int32)
self._states = np.full(shape, -1, dtype=np.int32)
self._aux_data = np.empty(shape, dtype=dict)
for target_id, target_ligand in enumerate(self.target_ligands):
......@@ -324,42 +323,45 @@ class LigandScorer:
# Ligands are different - skip
LogVerbose("No symmetry between %s and %s" % (
str(model_ligand), str(target_ligand)))
self._error_states[target_id, model_id] = 1
self._states[target_id, model_id] = 1
continue
except TooManySymmetriesError:
# Ligands are too symmetrical - skip
LogVerbose("Too many symmetries between %s and %s" % (
str(model_ligand), str(target_ligand)))
self._error_states[target_id, model_id] = 2
self._states[target_id, model_id] = 2
continue
except NoIsomorphicSymmetryError:
# Ligands are different - skip
LogVerbose("No isomorphic symmetry between %s and %s" % (
str(model_ligand), str(target_ligand)))
self._error_states[target_id, model_id] = 3
self._states[target_id, model_id] = 3
continue
except DisconnectedGraphError:
LogVerbose("Disconnected graph observed for %s and %s" % (
str(model_ligand), str(target_ligand)))
self._error_states[target_id, model_id] = 4
self._states[target_id, model_id] = 4
continue
#####################################################
# Compute score by calling the child class _compute #
#####################################################
score, error_state, aux = self._compute(symmetries, target_ligand,
model_ligand)
score, state, aux = self._compute(symmetries, target_ligand,
model_ligand)
############
# Finalize #
############
if error_state != 0:
if state != 0:
# non-zero error states up to 4 are reserved for base class
if error_state <= 9:
if state <= 9:
raise RuntimeError("Child returned reserved err. state")
self._error_states[target_id, model_id] = error_state
if error_state == 0:
self._states[target_id, model_id] = state
if state == 0:
if score is None or np.isnan(score):
raise RuntimeError("LigandScorer returned invalid "
"score despite 0 error state")
# it's a valid score!
self._score_matrix[target_id, model_id] = score
cvg = len(symmetries[0][0]) / len(model_ligand.atoms)
......@@ -386,12 +388,13 @@ class LigandScorer:
:class:`ost.mol.ResidueView`
:returns: A :class:`tuple` with three elements: 1) a score
(:class:`float`) 2) error state (:class:`int`).
(:class:`float`) 2) state (:class:`int`).
3) auxiliary data for this ligand pair (:class:`dict`).
If error state is 0, the score and auxiliary data will be
If state is 0, the score and auxiliary data will be
added to :attr:`~score_matrix` and :attr:`~aux_data` as well
as the respective value in :attr:`~coverage_matrix`.
Child specific non-zero states MUST be >= 10.
Returned score must be valid in this case (not None/NaN).
Child specific non-zero states must be >= 10.
"""
raise NotImplementedError("_compute must be implemented by child class")
......
import numpy as np
from ost import LogWarning
from ost import geom
from ost import mol
from ost import seq
from ost.mol.alg import lddt
from ost.mol.alg import chain_mapping
from ost.mol.alg import ligand_scoring_base
class LDDTPLIScorer(ligand_scoring_base.LigandScorer):
def __init__(self, model, target, model_ligands=None, target_ligands=None,
resnum_alignments=False, rename_ligand_chain=False,
substructure_match=False, coverage_delta=0.2,
max_symmetries=1e5, check_resnames=True, lddt_pli_radius=6.0,
add_mdl_contacts=True,
lddt_pli_thresholds = [0.5, 1.0, 2.0, 4.0],
lddt_pli_binding_site_radius=None):
super().__init__(model, target, model_ligands = model_ligands,
target_ligands = target_ligands,
resnum_alignments = resnum_alignments,
rename_ligand_chain = rename_ligand_chain,
substructure_match = substructure_match,
coverage_delta = coverage_delta,
max_symmetries = 1e5)
self.check_resnames = check_resnames
self.lddt_pli_radius = lddt_pli_radius
self.add_mdl_contacts = add_mdl_contacts
self.lddt_pli_thresholds = lddt_pli_thresholds
self.lddt_pli_binding_site_radius = lddt_pli_binding_site_radius
# lazily precomputed variables to speedup lddt-pli computation
self._lddt_pli_target_data = dict()
self._lddt_pli_model_data = dict()
self.__mappable_atoms = None
self.__chem_mapping = None
self.__chem_group_alns = None
self.__ref_mdl_alns = None
self.__chain_mapping_mdl = None
def _compute(self, symmetries, target_ligand, model_ligand):
if self.add_mdl_contacts:
result = self._compute_lddt_pli_add_mdl_contacts(symmetries,
target_ligand,
model_ligand)
else:
result = self._compute_lddt_pli_classic(symmetries,
target_ligand,
model_ligand)
state = 0
score = result["lddt_pli"]
if score is None:
if result["lddt_pli_n_contacts"] == 0:
# it's a space ship!
state = 10
else:
# unknwon error state
state = 11
return (score, state, result)
def _compute_lddt_pli_add_mdl_contacts(self, symmetries, target_ligand,
model_ligand):
###############################
# Get stuff from model/target #
###############################
trg_residues, trg_bs, trg_chains, trg_ligand_chain, \
trg_ligand_res, scorer, chem_groups = \
self._lddt_pli_get_trg_data(target_ligand)
# Copy to make sure that we don't change anything on underlying
# references
# This is not strictly necessary in the current implementation but
# hey, maybe it avoids hard to debug errors when someone changes things
ref_indices = [a.copy() for a in scorer.ref_indices_ic]
ref_distances = [a.copy() for a in scorer.ref_distances_ic]
# distance hacking... remove any interchain distance except the ones
# with the ligand
ligand_start_idx = scorer.chain_start_indices[-1]
for at_idx in range(ligand_start_idx):
mask = ref_indices[at_idx] >= ligand_start_idx
ref_indices[at_idx] = ref_indices[at_idx][mask]
ref_distances[at_idx] = ref_distances[at_idx][mask]
mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res, \
chem_mapping = self._lddt_pli_get_mdl_data(model_ligand)
if len(mdl_chains) == 0 or len(trg_chains) == 0:
# It's a spaceship!
return {"lddt_pli": None,
"lddt_pli_n_contacts": 0,
"target_ligand": target_ligand,
"model_ligand": model_ligand,
"bs_ref_res": trg_residues,
"bs_mdl_res": mdl_residues}
####################
# Setup alignments #
####################
# ref_mdl_alns refers to full chain mapper trg and mdl structures
# => need to adapt mdl sequence that only contain residues in contact
# with ligand
cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups,
chem_mapping,
mdl_bs, trg_bs)
########################################
# Setup cache for added model contacts #
########################################
# get each chain mapping that we ever observe in scoring
chain_mappings = list(chain_mapping._ChainMappings(chem_groups,
chem_mapping))
# for each mdl ligand atom, we collect all trg ligand atoms that are
# ever mapped onto it given *symmetries*
ligand_atom_mappings = [set() for a in mdl_ligand_res.atoms]
for (trg_sym, mdl_sym) in symmetries:
for trg_i, mdl_i in zip(trg_sym, mdl_sym):
ligand_atom_mappings[mdl_i].add(trg_i)
mdl_ligand_pos = np.zeros((mdl_ligand_res.GetAtomCount(), 3))
for a_idx, a in enumerate(mdl_ligand_res.atoms):
p = a.GetPos()
mdl_ligand_pos[a_idx, 0] = p[0]
mdl_ligand_pos[a_idx, 1] = p[1]
mdl_ligand_pos[a_idx, 2] = p[2]
trg_ligand_pos = np.zeros((trg_ligand_res.GetAtomCount(), 3))
for a_idx, a in enumerate(trg_ligand_res.atoms):
p = a.GetPos()
trg_ligand_pos[a_idx, 0] = p[0]
trg_ligand_pos[a_idx, 1] = p[1]
trg_ligand_pos[a_idx, 2] = p[2]
mdl_lig_hashes = [a.hash_code for a in mdl_ligand_res.atoms]
symmetric_atoms = np.asarray(sorted(list(scorer.symmetric_atoms)),
dtype=np.int64)
# two caches to cache things for each chain mapping => lists
# of len(chain_mappings)
#
# In principle we're caching for each trg/mdl ligand atom pair all
# information to update ref_indices/ref_distances and resolving the
# symmetries of the binding site.
# in detail: each list entry in *scoring_cache* is a dict with
# key: (mdl_lig_at_idx, trg_lig_at_idx)
# value: tuple with 4 elements - 1: indices of atoms representing added
# contacts relative to overall inexing scheme in scorer 2: the
# respective distances 3: the same but only containing indices towards
# atoms of the binding site that are considered symmetric 4: the
# respective indices.
# each list entry in *penalty_cache* is a list of len N mdl lig atoms.
# For each mdl lig at it contains a penalty for this mdl lig at. That
# means the number of contacts in the mdl binding site that can
# directly be mapped to the target given the local chain mapping but
# are not present in the target binding site, i.e. interacting atoms are
# too far away.
scoring_cache = list()
penalty_cache = list()
for mapping in chain_mappings:
# flat mapping with mdl chain names as key
flat_mapping = dict()
for trg_chem_group, mdl_chem_group in zip(chem_groups, mapping):
for a,b in zip(trg_chem_group, mdl_chem_group):
if a is not None and b is not None:
flat_mapping[b] = a
# for each mdl bs atom (as atom hash), the trg bs atoms (as index in scorer)
bs_atom_mapping = dict()
for mdl_cname, ref_cname in flat_mapping.items():
aln = cut_ref_mdl_alns[(ref_cname, mdl_cname)]
ref_ch = trg_bs.Select(f"cname={mol.QueryQuoteName(ref_cname)}")
mdl_ch = mdl_bs.Select(f"cname={mol.QueryQuoteName(mdl_cname)}")
aln.AttachView(0, ref_ch)
aln.AttachView(1, mdl_ch)
for col in aln:
ref_r = col.GetResidue(0)
mdl_r = col.GetResidue(1)
if ref_r.IsValid() and mdl_r.IsValid():
for mdl_a in mdl_r.atoms:
ref_a = ref_r.FindAtom(mdl_a.GetName())
if ref_a.IsValid():
ref_h = ref_a.handle.hash_code
if ref_h in scorer.atom_indices:
mdl_h = mdl_a.handle.hash_code
bs_atom_mapping[mdl_h] = \
scorer.atom_indices[ref_h]
cache = dict()
n_penalties = list()
for mdl_a_idx, mdl_a in enumerate(mdl_ligand_res.atoms):
n_penalty = 0
trg_bs_indices = list()
close_a = mdl_bs.FindWithin(mdl_a.GetPos(),
self.lddt_pli_radius)
for a in close_a:
mdl_a_hash_code = a.hash_code
if mdl_a_hash_code in bs_atom_mapping:
trg_bs_indices.append(bs_atom_mapping[mdl_a_hash_code])
elif mdl_a_hash_code not in mdl_lig_hashes:
if a.GetChain().GetName() in flat_mapping:
# Its in a mapped chain
at_key = (a.GetResidue().GetNumber(), a.name)
cname = a.GetChain().name
cname_key = (flat_mapping[cname], cname)
if at_key in self._mappable_atoms[cname_key]:
# Its a contact in the model but not part of
# trg_bs. It can still be mapped using the
# global mdl_ch/ref_ch alignment
# d in ref > self.lddt_pli_radius + max_thresh
# => guaranteed to be non-fulfilled contact
n_penalty += 1
n_penalties.append(n_penalty)
trg_bs_indices = np.asarray(sorted(trg_bs_indices))
for trg_a_idx in ligand_atom_mappings[mdl_a_idx]:
# mask selects entries in trg_bs_indices that are not yet
# part of classic lDDT ref_indices for atom at trg_a_idx
# => added mdl contacts
mask = np.isin(trg_bs_indices, ref_indices[ligand_start_idx + trg_a_idx],
assume_unique=True, invert=True)
added_indices = np.asarray([], dtype=np.int64)
added_distances = np.asarray([], dtype=np.float64)
if np.sum(mask) > 0:
# compute ref distances on reference positions
added_indices = trg_bs_indices[mask]
tmp = scorer.positions.take(added_indices, axis=0)
np.subtract(tmp, trg_ligand_pos[trg_a_idx][None, :], out=tmp)
np.square(tmp, out=tmp)
tmp = tmp.sum(axis=1)
np.sqrt(tmp, out=tmp) # distances against all relevant atoms
added_distances = tmp
# extract the distances towards bs atoms that are symmetric
sym_mask = np.isin(added_indices, symmetric_atoms,
assume_unique=True)
cache[(mdl_a_idx, trg_a_idx)] = (added_indices, added_distances,
added_indices[sym_mask],
added_distances[sym_mask])
scoring_cache.append(cache)
penalty_cache.append(n_penalties)
# cache for model contacts towards non mapped trg chains - this is
# relevant for self._lddt_pli_unmapped_chain_penalty
# key: tuple in form (trg_ch, mdl_ch)
# value: yet another dict with
# key: ligand_atom_hash
# value: n contacts towards respective trg chain that can be mapped
non_mapped_cache = dict()
###############################################################
# compute lDDT for all possible chain mappings and symmetries #
###############################################################
best_score = -1.0
best_result = {"lddt_pli": None,
"lddt_pli_n_contacts": 0}
# dummy alignment for ligand chains which is needed as input later on
ligand_aln = seq.CreateAlignment()
trg_s = seq.CreateSequence(trg_ligand_chain.name,
trg_ligand_res.GetOneLetterCode())
mdl_s = seq.CreateSequence(mdl_ligand_chain.name,
mdl_ligand_res.GetOneLetterCode())
ligand_aln.AddSequence(trg_s)
ligand_aln.AddSequence(mdl_s)
ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms))
sym_idx_collector = [None] * scorer.n_atoms
sym_dist_collector = [None] * scorer.n_atoms
for mapping, s_cache, p_cache in zip(chain_mappings, scoring_cache, penalty_cache):
lddt_chain_mapping = dict()
lddt_alns = dict()
for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
# some mdl chains can be None
if mdl_ch is not None:
lddt_chain_mapping[mdl_ch] = ref_ch
lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]
# add ligand to lddt_chain_mapping/lddt_alns
lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
lddt_alns[mdl_ligand_chain.name] = ligand_aln
# already process model, positions will be manually hacked for each
# symmetry - small overhead for variables that are thrown away here
pos, _, _, _, _, _, lddt_symmetries = \
scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
residue_mapping = lddt_alns,
thresholds = self.lddt_pli_thresholds,
check_resnames = self.check_resnames)
# estimate a penalty for unsatisfied model contacts from chains
# that are not in the local trg binding site, but can be mapped in
# the target.
# We're using the trg chain with the closest geometric center that
# can be mapped to the mdl chain according the chem mapping.
# An alternative would be to search for the target chain with
# the minimal number of additional contacts.
# There is not good solution for this problem...
unmapped_chains = list()
for mdl_ch in mdl_chains:
if mdl_ch not in lddt_chain_mapping:
# check which chain in trg is closest
chem_group_idx = None
for i, m in enumerate(self._chem_mapping):
if mdl_ch in m:
chem_group_idx = i
break
if chem_group_idx is None:
raise RuntimeError("This should never happen... "
"ask Gabriel...")
mdl_ch_view = self._chain_mapping_mdl.FindChain(mdl_ch)
mdl_center = mdl_ch_view.geometric_center
closest_ch = None
closest_dist = None
for trg_ch in self.chain_mapper.chem_groups[chem_group_idx]:
if trg_ch not in lddt_chain_mapping.values():
c = self.chain_mapper.target.FindChain(trg_ch).geometric_center
d = geom.Distance(mdl_center, c)
if closest_dist is None or d < closest_dist:
closest_dist = d
closest_ch = trg_ch
if closest_ch is not None:
unmapped_chains.append((closest_ch, mdl_ch))
for (trg_sym, mdl_sym) in symmetries:
# update positions
for mdl_i, trg_i in zip(mdl_sym, trg_sym):
pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]
# start new ref_indices/ref_distances from original values
funky_ref_indices = [np.copy(a) for a in ref_indices]
funky_ref_distances = [np.copy(a) for a in ref_distances]
# The only distances from the binding site towards the ligand
# we care about are the ones from the symmetric atoms to
# correctly compute scorer._ResolveSymmetries.
# We collect them while updating distances from added mdl
# contacts
for idx in symmetric_atoms:
sym_idx_collector[idx] = list()
sym_dist_collector[idx] = list()
# add data from added mdl contacts cache
added_penalty = 0
for mdl_i, trg_i in zip(mdl_sym, trg_sym):
added_penalty += p_cache[mdl_i]
cache = s_cache[mdl_i, trg_i]
full_trg_i = ligand_start_idx + trg_i
funky_ref_indices[full_trg_i] = \
np.append(funky_ref_indices[full_trg_i], cache[0])
funky_ref_distances[full_trg_i] = \
np.append(funky_ref_distances[full_trg_i], cache[1])
for idx, d in zip(cache[2], cache[3]):
sym_idx_collector[idx].append(full_trg_i)
sym_dist_collector[idx].append(d)
for idx in symmetric_atoms:
funky_ref_indices[idx] = \
np.append(funky_ref_indices[idx],
np.asarray(sym_idx_collector[idx],
dtype=np.int64))
funky_ref_distances[idx] = \
np.append(funky_ref_distances[idx],
np.asarray(sym_dist_collector[idx],
dtype=np.float64))
# we can pass funky_ref_indices/funky_ref_distances as
# sym_ref_indices/sym_ref_distances in
# scorer._ResolveSymmetries as we only have distances of the bs
# to the ligand and ligand atoms are "non-symmetric"
scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds,
lddt_symmetries,
funky_ref_indices,
funky_ref_distances)
N = sum([len(funky_ref_indices[i]) for i in ligand_at_indices])
N += added_penalty
# collect number of expected contacts which can be mapped
if len(unmapped_chains) > 0:
N += self._lddt_pli_unmapped_chain_penalty(unmapped_chains,
non_mapped_cache,
mdl_bs,
mdl_ligand_res,
mdl_sym)
conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
self.lddt_pli_thresholds,
funky_ref_indices,
funky_ref_distances), axis=0)
score = None
if N > 0:
score = np.mean(conserved/N)
if score is not None and score > best_score:
best_score = score
best_result = {"lddt_pli": score,
"lddt_pli_n_contacts": N}
# fill misc info to result object
best_result["target_ligand"] = target_ligand
best_result["model_ligand"] = model_ligand
best_result["bs_ref_res"] = trg_residues
best_result["bs_mdl_res"] = mdl_residues
return best_result
def _compute_lddt_pli_classic(self, symmetries, target_ligand,
model_ligand):
###############################
# Get stuff from model/target #
###############################
max_r = None
if self.lddt_pli_binding_site_radius:
max_r = self.lddt_pli_binding_site_radius
trg_residues, trg_bs, trg_chains, trg_ligand_chain, \
trg_ligand_res, scorer, chem_groups = \
self._lddt_pli_get_trg_data(target_ligand, max_r = max_r)
# Copy to make sure that we don't change anything on underlying
# references
# This is not strictly necessary in the current implementation but
# hey, maybe it avoids hard to debug errors when someone changes things
ref_indices = [a.copy() for a in scorer.ref_indices_ic]
ref_distances = [a.copy() for a in scorer.ref_distances_ic]
# no matter what mapping/symmetries, the number of expected
# contacts stays the same
ligand_start_idx = scorer.chain_start_indices[-1]
ligand_at_indices = list(range(ligand_start_idx, scorer.n_atoms))
n_exp = sum([len(ref_indices[i]) for i in ligand_at_indices])
mdl_residues, mdl_bs, mdl_chains, mdl_ligand_chain, mdl_ligand_res, \
chem_mapping = self._lddt_pli_get_mdl_data(model_ligand)
if n_exp == 0:
# no contacts... nothing to compute...
return {"lddt_pli": None,
"lddt_pli_n_contacts": 0,
"target_ligand": target_ligand,
"model_ligand": model_ligand,
"bs_ref_res": trg_residues,
"bs_mdl_res": mdl_residues}
# Distance hacking... remove any interchain distance except the ones
# with the ligand
for at_idx in range(ligand_start_idx):
mask = ref_indices[at_idx] >= ligand_start_idx
ref_indices[at_idx] = ref_indices[at_idx][mask]
ref_distances[at_idx] = ref_distances[at_idx][mask]
####################
# Setup alignments #
####################
# ref_mdl_alns refers to full chain mapper trg and mdl structures
# => need to adapt mdl sequence that only contain residues in contact
# with ligand
cut_ref_mdl_alns = self._lddt_pli_cut_ref_mdl_alns(chem_groups,
chem_mapping,
mdl_bs, trg_bs)
###############################################################
# compute lDDT for all possible chain mappings and symmetries #
###############################################################
best_score = -1.0
# dummy alignment for ligand chains which is needed as input later on
l_aln = seq.CreateAlignment()
l_aln.AddSequence(seq.CreateSequence(trg_ligand_chain.name,
trg_ligand_res.GetOneLetterCode()))
l_aln.AddSequence(seq.CreateSequence(mdl_ligand_chain.name,
mdl_ligand_res.GetOneLetterCode()))
mdl_ligand_pos = np.zeros((model_ligand.GetAtomCount(), 3))
for a_idx, a in enumerate(model_ligand.atoms):
p = a.GetPos()
mdl_ligand_pos[a_idx, 0] = p[0]
mdl_ligand_pos[a_idx, 1] = p[1]
mdl_ligand_pos[a_idx, 2] = p[2]
for mapping in chain_mapping._ChainMappings(chem_groups, chem_mapping):
lddt_chain_mapping = dict()
lddt_alns = dict()
for ref_chem_group, mdl_chem_group in zip(chem_groups, mapping):
for ref_ch, mdl_ch in zip(ref_chem_group, mdl_chem_group):
# some mdl chains can be None
if mdl_ch is not None:
lddt_chain_mapping[mdl_ch] = ref_ch
lddt_alns[mdl_ch] = cut_ref_mdl_alns[(ref_ch, mdl_ch)]
# add ligand to lddt_chain_mapping/lddt_alns
lddt_chain_mapping[mdl_ligand_chain.name] = trg_ligand_chain.name
lddt_alns[mdl_ligand_chain.name] = l_aln
# already process model, positions will be manually hacked for each
# symmetry - small overhead for variables that are thrown away here
pos, _, _, _, _, _, lddt_symmetries = \
scorer._ProcessModel(mdl_bs, lddt_chain_mapping,
residue_mapping = lddt_alns,
thresholds = self.lddt_pli_thresholds,
check_resnames = self.check_resnames)
for (trg_sym, mdl_sym) in symmetries:
for mdl_i, trg_i in zip(mdl_sym, trg_sym):
pos[ligand_start_idx + trg_i, :] = mdl_ligand_pos[mdl_i, :]
# we can pass ref_indices/ref_distances as
# sym_ref_indices/sym_ref_distances in
# scorer._ResolveSymmetries as we only have distances of the bs
# to the ligand and ligand atoms are "non-symmetric"
scorer._ResolveSymmetries(pos, self.lddt_pli_thresholds,
lddt_symmetries,
ref_indices,
ref_distances)
# compute number of conserved distances for ligand atoms
conserved = np.sum(scorer._EvalAtoms(pos, ligand_at_indices,
self.lddt_pli_thresholds,
ref_indices,
ref_distances), axis=0)
score = np.mean(conserved/n_exp)
if score > best_score:
best_score = score
# fill misc info to result object
best_result = {"lddt_pli": best_score,
"lddt_pli_n_contacts": n_exp,
"target_ligand": target_ligand,
"model_ligand": model_ligand,
"bs_ref_res": trg_residues,
"bs_mdl_res": mdl_residues}
return best_result
def _lddt_pli_unmapped_chain_penalty(self, unmapped_chains,
non_mapped_cache,
mdl_bs,
mdl_ligand_res,
mdl_sym):
n_exp = 0
for ch_tuple in unmapped_chains:
if ch_tuple not in non_mapped_cache:
# for each ligand atom, we count the number of mappable atoms
# within lddt_pli_radius
counts = dict()
# the select statement also excludes the ligand in mdl_bs
# as it resides in a separate chain
mdl_cname = ch_tuple[1]
mdl_bs_ch = mdl_bs.Select(f"cname={mol.QueryQuoteName(mdl_cname)}")
for a in mdl_ligand_res.atoms:
close_atoms = \
mdl_bs_ch.FindWithin(a.GetPos(), self.lddt_pli_radius)
N = 0
for close_a in close_atoms:
at_key = (close_a.GetResidue().GetNumber(),
close_a.GetName())
if at_key in self._mappable_atoms[ch_tuple]:
N += 1
counts[a.hash_code] = N
# fill cache
non_mapped_cache[ch_tuple] = counts
# add number of mdl contacts which can be mapped to target
# as non-fulfilled contacts
counts = non_mapped_cache[ch_tuple]
lig_hash_codes = [a.hash_code for a in mdl_ligand_res.atoms]
for i in mdl_sym:
n_exp += counts[lig_hash_codes[i]]
return n_exp
def _lddt_pli_get_mdl_data(self, model_ligand):
if model_ligand not in self._lddt_pli_model_data:
mdl = self._chain_mapping_mdl
mdl_residues = set()
for at in model_ligand.atoms:
close_atoms = mdl.FindWithin(at.GetPos(), self.lddt_pli_radius)
for close_at in close_atoms:
mdl_residues.add(close_at.GetResidue())
max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds)
for r in mdl.residues:
r.SetIntProp("bs", 0)
for at in model_ligand.atoms:
close_atoms = mdl.FindWithin(at.GetPos(), max_r)
for close_at in close_atoms:
close_at.GetResidue().SetIntProp("bs", 1)
mdl_bs = mol.CreateEntityFromView(mdl.Select("grbs:0=1"), True)
mdl_chains = set([ch.name for ch in mdl_bs.chains])
mdl_editor = mdl_bs.EditXCS(mol.BUFFERED_EDIT)
mdl_ligand_chain = None
for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
try:
# I'm pretty sure, one of these chain names is not there...
mdl_ligand_chain = mdl_editor.InsertChain(cname)
break
except:
pass
if mdl_ligand_chain is None:
raise RuntimeError("Fuck this, I'm out...")
mdl_ligand_res = mdl_editor.AppendResidue(mdl_ligand_chain,
model_ligand,
deep=True)
mdl_editor.RenameResidue(mdl_ligand_res, "LIG")
mdl_editor.SetResidueNumber(mdl_ligand_res, mol.ResNum(1))
chem_mapping = list()
for m in self._chem_mapping:
chem_mapping.append([x for x in m if x in mdl_chains])
self._lddt_pli_model_data[model_ligand] = (mdl_residues,
mdl_bs,
mdl_chains,
mdl_ligand_chain,
mdl_ligand_res,
chem_mapping)
return self._lddt_pli_model_data[model_ligand]
def _lddt_pli_get_trg_data(self, target_ligand, max_r = None):
if target_ligand not in self._lddt_pli_target_data:
trg = self.chain_mapper.target
if max_r is None:
max_r = self.lddt_pli_radius + max(self.lddt_pli_thresholds)
trg_residues = set()
for at in target_ligand.atoms:
close_atoms = trg.FindWithin(at.GetPos(), max_r)
for close_at in close_atoms:
trg_residues.add(close_at.GetResidue())
for r in trg.residues:
r.SetIntProp("bs", 0)
for r in trg_residues:
r.SetIntProp("bs", 1)
trg_bs = mol.CreateEntityFromView(trg.Select("grbs:0=1"), True)
trg_chains = set([ch.name for ch in trg_bs.chains])
trg_editor = trg_bs.EditXCS(mol.BUFFERED_EDIT)
trg_ligand_chain = None
for cname in ["hugo_the_cat_terminator", "ida_the_cheese_monster"]:
try:
# I'm pretty sure, one of these chain names is not there yet
trg_ligand_chain = trg_editor.InsertChain(cname)
break
except:
pass
if trg_ligand_chain is None:
raise RuntimeError("Fuck this, I'm out...")
trg_ligand_res = trg_editor.AppendResidue(trg_ligand_chain,
target_ligand,
deep=True)
trg_editor.RenameResidue(trg_ligand_res, "LIG")
trg_editor.SetResidueNumber(trg_ligand_res, mol.ResNum(1))
compound_name = trg_ligand_res.name
compound = lddt.CustomCompound.FromResidue(trg_ligand_res)
custom_compounds = {compound_name: compound}
scorer = lddt.lDDTScorer(trg_bs,
custom_compounds = custom_compounds,
inclusion_radius = self.lddt_pli_radius)
chem_groups = list()
for g in self.chain_mapper.chem_groups:
chem_groups.append([x for x in g if x in trg_chains])
self._lddt_pli_target_data[target_ligand] = (trg_residues,
trg_bs,
trg_chains,
trg_ligand_chain,
trg_ligand_res,
scorer,
chem_groups)
return self._lddt_pli_target_data[target_ligand]
def _lddt_pli_cut_ref_mdl_alns(self, chem_groups, chem_mapping, mdl_bs,
ref_bs):
cut_ref_mdl_alns = dict()
for ref_chem_group, mdl_chem_group in zip(chem_groups, chem_mapping):
for ref_ch in ref_chem_group:
ref_bs_chain = ref_bs.FindChain(ref_ch)
query = "cname=" + mol.QueryQuoteName(ref_ch)
ref_view = self.chain_mapper.target.Select(query)
for mdl_ch in mdl_chem_group:
aln = self._ref_mdl_alns[(ref_ch, mdl_ch)]
aln.AttachView(0, ref_view)
mdl_bs_chain = mdl_bs.FindChain(mdl_ch)
query = "cname=" + mol.QueryQuoteName(mdl_ch)
aln.AttachView(1, self._chain_mapping_mdl.Select(query))
cut_mdl_seq = ['-'] * aln.GetLength()
cut_ref_seq = ['-'] * aln.GetLength()
for i, col in enumerate(aln):
# check ref residue
r = col.GetResidue(0)
if r.IsValid():
bs_r = ref_bs_chain.FindResidue(r.GetNumber())
if bs_r.IsValid():
cut_ref_seq[i] = col[0]
# check mdl residue
r = col.GetResidue(1)
if r.IsValid():
bs_r = mdl_bs_chain.FindResidue(r.GetNumber())
if bs_r.IsValid():
cut_mdl_seq[i] = col[1]
cut_ref_seq = ''.join(cut_ref_seq)
cut_mdl_seq = ''.join(cut_mdl_seq)
cut_aln = seq.CreateAlignment()
cut_aln.AddSequence(seq.CreateSequence(ref_ch, cut_ref_seq))
cut_aln.AddSequence(seq.CreateSequence(mdl_ch, cut_mdl_seq))
cut_ref_mdl_alns[(ref_ch, mdl_ch)] = cut_aln
return cut_ref_mdl_alns
@property
def _mappable_atoms(self):
""" Stores mappable atoms given a chain mapping
Store for each ref_ch,mdl_ch pair all mdl atoms that can be
mapped. Don't store mappable atoms as hashes but rather as tuple
(mdl_r.GetNumber(), mdl_a.GetName()). Reason for that is that one might
operate on Copied EntityHandle objects without corresponding hashes.
Given a tuple defining c_pair: (ref_cname, mdl_cname), one
can check if a certain atom is mappable by evaluating:
if (mdl_r.GetNumber(), mdl_a.GetName()) in self._mappable_atoms(c_pair)
"""
if self.__mappable_atoms is None:
self.__mappable_atoms = dict()
for (ref_cname, mdl_cname), aln in self._ref_mdl_alns.items():
self._mappable_atoms[(ref_cname, mdl_cname)] = set()
ref_ch = self.chain_mapper.target.Select(f"cname={mol.QueryQuoteName(ref_cname)}")
mdl_ch = self._chain_mapping_mdl.Select(f"cname={mol.QueryQuoteName(mdl_cname)}")
aln.AttachView(0, ref_ch)
aln.AttachView(1, mdl_ch)
for col in aln:
ref_r = col.GetResidue(0)
mdl_r = col.GetResidue(1)
if ref_r.IsValid() and mdl_r.IsValid():
for mdl_a in mdl_r.atoms:
if ref_r.FindAtom(mdl_a.name).IsValid():
c_key = (ref_cname, mdl_cname)
at_key = (mdl_r.GetNumber(), mdl_a.name)
self.__mappable_atoms[c_key].add(at_key)
return self.__mappable_atoms
@property
def _chem_mapping(self):
if self.__chem_mapping is None:
self.__chem_mapping, self.__chem_group_alns, \
self.__chain_mapping_mdl = \
self.chain_mapper.GetChemMapping(self.model)
return self.__chem_mapping
@property
def _chem_group_alns(self):
if self.__chem_group_alns is None:
self.__chem_mapping, self.__chem_group_alns, \
self.__chain_mapping_mdl = \
self.chain_mapper.GetChemMapping(self.model)
return self.__chem_group_alns
@property
def _ref_mdl_alns(self):
if self.__ref_mdl_alns is None:
self.__ref_mdl_alns = \
chain_mapping._GetRefMdlAlns(self.chain_mapper.chem_groups,
self.chain_mapper.chem_group_alignments,
self._chem_mapping,
self._chem_group_alns)
return self.__ref_mdl_alns
@property
def _chain_mapping_mdl(self):
if self.__chain_mapping_mdl is None:
self.__chem_mapping, self.__chem_group_alns, \
self.__chain_mapping_mdl = \
self.chain_mapper.GetChemMapping(self.model)
return self.__chain_mapping_mdl
......@@ -9,11 +9,11 @@ from ost.mol.alg import ligand_scoring_base
class SCRMSDScorer(ligand_scoring_base.LigandScorer):
def __init__(self, model, target, model_ligands=None, target_ligands=None,
resnum_alignments=False, rename_ligand_chain=False,
substructure_match=False, coverage_delta=0.2,
max_symmetries=1e5, bs_radius=4.0, lddt_lp_radius=15.0,
model_bs_radius=25, binding_sites_topn=100000,
full_bs_search=False):
resnum_alignments=False, rename_ligand_chain=False,
substructure_match=False, coverage_delta=0.2,
max_symmetries=1e5, bs_radius=4.0, lddt_lp_radius=15.0,
model_bs_radius=25, binding_sites_topn=100000,
full_bs_search=False):
super().__init__(model, target, model_ligands = model_ligands,
......
......@@ -10,6 +10,7 @@ try:
from ost.mol.alg.ligand_scoring_base import *
from ost.mol.alg import ligand_scoring_base
from ost.mol.alg import ligand_scoring_scrmsd
from ost.mol.alg import ligand_scoring_lddtpli
except ImportError:
print("Failed to import ligand_scoring.py. Happens when numpy, scipy or "
"networkx is missing. Ignoring test_ligand_scoring.py tests.")
......@@ -281,6 +282,7 @@ class TestLigandScoringFancy(unittest.TestCase):
with self.assertRaises(NoSymmetryError):
ligand_scoring_scrmsd.SCRMSD(trg_g3d1_sub, mdl_g3d) # no full match
def test_compute_rmsd_scores(self):
"""Test that _compute_scores works.
"""
......@@ -300,6 +302,128 @@ class TestLigandScoringFancy(unittest.TestCase):
[0.29399303],
[np.nan]]), decimal=5)
def test_compute_lddtpli_scores(self):
trg = _LoadMMCIF("1r8q.cif.gz")
mdl = _LoadMMCIF("P84080_model_02.cif.gz")
mdl_lig = io.LoadEntity(os.path.join('testfiles', "P84080_model_02_ligand_0.sdf"))
sc = ligand_scoring_lddtpli.LDDTPLIScorer(mdl, trg, [mdl_lig], None,
add_mdl_contacts = False,
lddt_pli_binding_site_radius = 4.0)
self.assertEqual(sc.score_matrix.shape, (7, 1))
self.assertTrue(np.isnan(sc.score_matrix[0, 0]))
self.assertAlmostEqual(sc.score_matrix[1, 0], 0.99843, 5)
self.assertTrue(np.isnan(sc.score_matrix[2, 0]))
self.assertTrue(np.isnan(sc.score_matrix[3, 0]))
self.assertTrue(np.isnan(sc.score_matrix[4, 0]))
self.assertAlmostEqual(sc.score_matrix[5, 0], 1.0)
self.assertTrue(np.isnan(sc.score_matrix[6, 0]))
def test_check_resnames(self):
"""Test that the check_resname argument works.
When set to True, it should raise an error if any residue in the
representation of the binding site in the model has a different
name than in the reference. Here we manually modify a residue
name to achieve that effect. This is only relevant for the LDDTPLIScorer
"""
trg_4c0a = _LoadMMCIF("4c0a.cif.gz")
trg = trg_4c0a.Select("cname=C or cname=I")
# Here we modify the name of a residue in 4C0A (THR => TPO in C.15)
# This residue is in the binding site and should trigger the error
mdl = ost.mol.CreateEntityFromView(trg, include_exlusive_atoms=False)
ed = mdl.EditICS()
ed.RenameResidue(mdl.FindResidue("C", 15), "TPO")
ed.UpdateXCS()
with self.assertRaises(RuntimeError):
sc = ligand_scoring_lddtpli.LDDTPLIScorer(mdl, trg, [mdl.FindResidue("I", 1)], [trg.FindResidue("I", 1)], check_resnames=True)
sc._compute_scores()
sc = ligand_scoring_lddtpli.LDDTPLIScorer(mdl, trg, [mdl.FindResidue("I", 1)], [trg.FindResidue("I", 1)], check_resnames=False)
sc._compute_scores()
def test_added_mdl_contacts(self):
# binding site for ligand in chain G consists of chains A and B
prot = _LoadMMCIF("1r8q.cif.gz").Copy()
# model has the full binding site
mdl = mol.CreateEntityFromView(prot.Select("cname=A,B,G"), True)
# chain C has same sequence as chain A but is not in contact
# with ligand in chain G
# target has thus incomplete binding site only from chain B
trg = mol.CreateEntityFromView(prot.Select("cname=B,C,G"), True)
# if added model contacts are not considered, the incomplete binding
# site only from chain B is perfectly reproduced by model which also has
# chain B
sc = ligand_scoring_lddtpli.LDDTPLIScorer(mdl, trg, add_mdl_contacts=False)
self.assertAlmostEqual(sc.score_matrix[0,0], 1.0, 5)
# if added model contacts are considered, contributions from chain B are
# perfectly reproduced but all contacts of ligand towards chain A are
# added as penalty
sc = ligand_scoring_lddtpli.LDDTPLIScorer(mdl, trg, add_mdl_contacts=True)
lig = prot.Select("cname=G")
A_count = 0
B_count = 0
for a in lig.atoms:
close_atoms = mdl.FindWithin(a.GetPos(), sc.lddt_pli_radius)
for ca in close_atoms:
cname = ca.GetChain().GetName()
if cname == "G":
pass # its a ligand atom...
elif cname == "A":
A_count += 1
elif cname == "B":
B_count += 1
self.assertAlmostEqual(sc.score_matrix[0,0],
B_count/(A_count + B_count), 5)
# Same as before but additionally we remove residue TRP.66
# from chain C in the target to test mapping magic...
# Chain C is NOT in contact with the ligand but we only
# add contacts from chain A as penalty that are mappable
# to the closest chain with same sequence. That would be
# chain C
query = "cname=B,G or (cname=C and rnum!=66)"
trg = mol.CreateEntityFromView(prot.Select(query), True)
sc = ligand_scoring_lddtpli.LDDTPLIScorer(mdl, trg, add_mdl_contacts=True)
TRP66_count = 0
for a in lig.atoms:
close_atoms = mdl.FindWithin(a.GetPos(), sc.lddt_pli_radius)
for ca in close_atoms:
cname = ca.GetChain().GetName()
if cname == "A" and ca.GetResidue().GetNumber().GetNum() == 66:
TRP66_count += 1
self.assertEqual(TRP66_count, 134)
# remove TRP66_count from original penalty
self.assertAlmostEqual(sc.score_matrix[0,0],
B_count/(A_count + B_count - TRP66_count), 5)
# Move a random atom in the model from chain B towards the ligand center
# chain B is also present in the target and interacts with the ligand,
# but that atom would be far away and thus adds to the penalty. Since
# the ligand is small enough, the number of added contacts should be
# exactly the number of ligand atoms.
mdl_ed = mdl.EditXCS()
at = mdl.FindResidue("B", mol.ResNum(8)).FindAtom("NZ")
mdl_ed.SetAtomPos(at, lig.geometric_center)
sc = ligand_scoring_lddtpli.LDDTPLIScorer(mdl, trg, add_mdl_contacts=True)
# compared to the last assertAlmostEqual, we add the number of ligand
# atoms as additional penalties
self.assertAlmostEqual(sc.score_matrix[0,0],
B_count/(A_count + B_count - TRP66_count + \
lig.GetAtomCount()), 5)
if __name__ == "__main__":
from ost import testutils
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment