Something went wrong on our end
-
Studer Gabriel authored
Avoids issues in query language when funky characters appear in chain names
Studer Gabriel authoredAvoids issues in query language when funky characters appear in chain names
lddt.py 50.80 KiB
import numpy as np
from ost import mol
from ost import conop
class CustomCompound:
    """ Defines atoms for custom compounds

    lDDT requires the reference atoms of a compound which are typically
    extracted from a :class:`ost.conop.CompoundLib`. This lightweight
    container allows to handle arbitrary compounds which are not
    necessarily in the compound library.

    :param atom_names: Names of atoms of custom compound
    :type atom_names: :class:`list` of :class:`str`
    """

    def __init__(self, atom_names):
        # stored as given, no copy is made
        self.atom_names = atom_names

    @staticmethod
    def FromResidue(res):
        """ Construct custom compound from residue

        Hydrogen ("H") and deuterium ("D") atoms are skipped; all other atom
        names are taken in the order in which they appear in *res*.

        :param res: Residue from which reference atom names are extracted,
                    hydrogen/deuterium atoms are filtered out
        :type res: :class:`ost.mol.ResidueView`/:class:`ost.mol.ResidueHandle`
        :returns: :class:`CustomCompound`
        :raises: :class:`RuntimeError` if an atom name occurs more than once
        """
        heavy_atom_names = []
        for atom in res.atoms:
            if atom.element in ("H", "D"):
                continue
            heavy_atom_names.append(atom.name)
        if len(set(heavy_atom_names)) != len(heavy_atom_names):
            raise RuntimeError("Duplicate atoms detected in CustomCompound")
        return CustomCompound(heavy_atom_names)
class SymmetrySettings:
    """Container for symmetric compounds

    lDDT considers symmetries and selects the one resulting in the highest
    possible score.

    A symmetry is defined as a renaming operation on one or more atoms that
    leads to a chemically equivalent residue. Example would be OD1 and OD2 in
    ASP => renaming OD1 to OD2 and vice versa gives a chemically equivalent
    residue.

    Use :func:`AddSymmetricCompound` to define a symmetry which can then
    directly be accessed through the *symmetric_compounds* member.
    """

    def __init__(self):
        # compound name => list of atom name pairs
        self.symmetric_compounds = dict()

    def AddSymmetricCompound(self, name, symmetric_atoms):
        """Adds symmetry for compound with *name*

        :param name: Name of compound with symmetry
        :type name: :class:`str`
        :param symmetric_atoms: Pairs of atom names that define renaming
                                operation, i.e. after applying all switches
                                defined in the tuples, the resulting residue
                                should be chemically equivalent. Atom names
                                must refer to the PDB component dictionary.
                                Any iterable of pairs is accepted; it is
                                stored as a new list.
        :type symmetric_atoms: :class:`list` of :class:`tuple`
        :raises: :class:`RuntimeError` if an element is not a pair
        """
        # Materialize first: a generator would otherwise be consumed by the
        # validation loop (leaving nothing to store), and a private copy
        # keeps later modifications of the caller's list from leaking in.
        symmetric_atoms = list(symmetric_atoms)
        for pair in symmetric_atoms:
            if len(pair) != 2:
                raise RuntimeError("Expect pairs when defining symmetries")
        self.symmetric_compounds[name] = symmetric_atoms
def GetDefaultSymmetrySettings():
    """Constructs and returns :class:`SymmetrySettings` object for natural amino
    acids

    Covers ASP, GLU, LEU, VAL, ARG, PHE and TYR - the residues whose side
    chains contain atom pairs that can be swapped without changing chemistry.
    """
    # residue name => atom name pairs that may be swapped
    default_symmetries = (
        ("ASP", [("OD1", "OD2")]),
        ("GLU", [("OE1", "OE2")]),
        ("LEU", [("CD1", "CD2")]),
        ("VAL", [("CG1", "CG2")]),
        ("ARG", [("NH1", "NH2")]),
        ("PHE", [("CD1", "CD2"), ("CE1", "CE2")]),
        ("TYR", [("CD1", "CD2"), ("CE1", "CE2")]),
    )
    settings = SymmetrySettings()
    for compound_name, atom_pairs in default_symmetries:
        settings.AddSymmetricCompound(compound_name, atom_pairs)
    return settings
class lDDTScorer:
"""lDDT scorer object for a specific target
Sets up everything to score models of that target. lDDT (local distance
difference test) is defined as fraction of pairwise distances which exhibit
a difference < threshold when considering target and model. In case of
multiple thresholds, the average is returned. See
V. Mariani, M. Biasini, A. Barbato, T. Schwede, lDDT : A local
superposition-free score for comparing protein structures and models using
distance difference tests, Bioinformatics, 2013
:param target: The target
:type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
:param compound_lib: Compound library from which a compound for each residue
is extracted based on its name. Uses
:func:`ost.conop.GetDefaultLib` if not given, raises
if this returns no valid compound library. Atoms
defined in the compound are searched in the residue and
build the reference for scoring. If the residue has
atoms with names ["A", "B", "C"] but the corresponding
compound only has ["A", "B"], "A" and "B" are
considered for scoring. If the residue has atoms
["A", "B"] but the compound has ["A", "B", "C"], "C" is
considered missing and does not influence scoring, even
if present in the model.
:param custom_compounds: Custom compounds defining reference atoms. If
given, *custom_compounds* take precedence over
*compound_lib*.
:type custom_compounds: :class:`dict` with residue names (:class:`str`) as
key and :class:`CustomCompound` as value.
:type compound_lib: :class:`ost.conop.CompoundLib`
:param inclusion_radius: All pairwise distances < *inclusion_radius* are
considered for scoring
:type inclusion_radius: :class:`float`
:param sequence_separation: Only pairwise distances between atoms of
residues which are further apart than this
threshold are considered. Residue distance is
based on resnum. The default (0) considers all
pairwise distances except intra-residue
distances.
:type sequence_separation: :class:`int`
:param symmetry_settings: Define residues exhibiting internal symmetry, uses
:func:`GetDefaultSymmetrySettings` if not given.
:type symmetry_settings: :class:`SymmetrySettings`
:param seqres_mapping: Mapping of model residues at the scoring stage
happens with residue numbers defining their location
in a reference sequence (SEQRES) using one based
indexing. If the residue numbers in *target* don't
correspond to that SEQRES, you can specify the
mapping manually. You can provide a dictionary to
specify a reference sequence (SEQRES) for one or more
chain(s). Key: chain name, value: alignment
(seq1: SEQRES, seq2: sequence of residues in chain).
Example: The residues in a chain with name "A" have
sequence "YEAH" and residue numbers [42,43,44,45].
You can provide an alignment with seq1 "``HELLYEAH``"
and seq2 "``----YEAH``". "Y" gets assigned residue
number 5, "E" gets assigned 6 and so on no matter
what the original residue numbers were.
:type seqres_mapping: :class:`dict` (key: :class:`str`, value:
:class:`ost.seq.AlignmentHandle`)
:param bb_only: Only consider atoms with name "CA" in case of amino acids and
"C3'" for nucleotides. This invalidates *compound_lib*.
Raises if any residue in *target* is not
`r.chem_class.IsPeptideLinking()` or
`r.chem_class.IsNucleotideLinking()`
:type bb_only: :class:`bool`
:raises: :class:`RuntimeError` if *target* contains compound which is not in
*compound_lib*, :class:`RuntimeError` if *symmetry_settings*
specifies symmetric atoms that are not present in the according
compound in *compound_lib*, :class:`RuntimeError` if
*seqres_mapping* is not provided and *target* contains residue
numbers with insertion codes or the residue numbers for each chain
are not monotonically increasing, :class:`RuntimeError` if
*seqres_mapping* is provided but an alignment is invalid
(seq1 contains gaps, mismatch in seq1/seq2, seq2 does not match
residues in corresponding chains).
"""
def __init__(
    self,
    target,
    compound_lib=None,
    custom_compounds=None,
    inclusion_radius=15,
    sequence_separation=0,
    symmetry_settings=None,
    seqres_mapping=None,
    bb_only=False
):
    """Set up all target related members; distance members are lazy

    Parameters are described in the class docstring. *seqres_mapping*
    defaults to an empty mapping. None is accepted and treated as such; a
    fresh dict is created per call instead of sharing one mutable default
    object between all scorers.
    """
    if seqres_mapping is None:
        seqres_mapping = dict()

    self.target = target
    self.inclusion_radius = inclusion_radius
    self.sequence_separation = sequence_separation
    if compound_lib is None:
        compound_lib = conop.GetDefaultLib()
    if compound_lib is None:
        raise RuntimeError("No compound_lib given and conop.GetDefaultLib "
                           "returns no valid compound library")
    self.compound_lib = compound_lib
    self.custom_compounds = custom_compounds
    if symmetry_settings is None:
        self.symmetry_settings = GetDefaultSymmetrySettings()
    else:
        self.symmetry_settings = symmetry_settings

    # whether to only consider atoms with name "CA" (amino acids) or C3'
    # (nucleotides), invalidates *compound_lib*
    self.bb_only = bb_only

    # names of heavy atoms of each unique compound present in *target* as
    # extracted from *compound_lib*, e.g.
    # self.compound_anames["GLY"] = ["N", "CA", "C", "O"]
    self.compound_anames = dict()

    # stores symmetry information for those compounds as defined in
    # *symmetry_settings*
    self.compound_symmetric_atoms = dict()

    # list of len(target.chains) containing all chain names in *target*
    self.chain_names = list()

    # list of len(target.residues) containing all compound names in *target*
    self.compound_names = list()

    # list of len(target.residues) defining start pos in internal reference
    # positions for each residue
    self.res_start_indices = list()

    # list of len(target.residues) defining residue numbers in target
    self.res_resnums = list()

    # list of len(target.chains) defining start pos in internal reference
    # positions for each chain
    self.chain_start_indices = list()

    # list of len(target.chains) defining start pos in self.compound_names
    # for each chain
    self.chain_res_start_indices = list()

    # maps residues in *target* to indices in
    # self.compound_names/self.res_start_indices. A residue gets identified
    # by a tuple (first element: chain name, second element: residue number,
    # residue number is either the actual residue number in *target* or
    # given by *seqres_mapping*)
    self.res_mapper = dict()

    # number of atoms as specified in compounds. not all are necessarily
    # covered by structure
    self.n_atoms = None

    # stores an index for each AtomHandle in *target*
    # (atom hashcode => index)
    self.atom_indices = dict()

    # store indices of all atoms that have symmetry properties
    self.symmetric_atoms = set()

    # setup members defined above
    self._SetupEnv(self.compound_lib, self.custom_compounds,
                   self.symmetry_settings, seqres_mapping, self.bb_only)

    # distance related members are lazily computed as they're affected
    # by different flavours of lDDT (e.g. lDDT including inter-chain
    # contacts or not etc.)

    # stores for each atom the other atoms within inclusion_radius
    self._ref_indices = None
    # the corresponding distances
    self._ref_distances = None

    # The following lists will be sparsely populated. We keep for each
    # symmetry related atom the distances towards all atoms which are NOT
    # affected by symmetry. So we can evaluate two symmetric versions
    # against the fixed stuff later on and select the better scoring one.
    self._sym_ref_indices = None
    self._sym_ref_distances = None

    # total number of distances
    self._n_distances = None

    # exactly the same as above but without interchain contacts
    # => single-chain (sc)
    self._ref_indices_sc = None
    self._ref_distances_sc = None
    self._sym_ref_indices_sc = None
    self._sym_ref_distances_sc = None
    self._n_distances_sc = None

    # exactly the same as above but without intrachain contacts
    # => inter-chain (ic)
    self._ref_indices_ic = None
    self._ref_distances_ic = None
    self._sym_ref_indices_ic = None
    self._sym_ref_distances_ic = None
    self._n_distances_ic = None

    # input parameter checking
    self._ProcessSequenceSeparation()
@property
def ref_indices(self):
    """Per atom: indices of other atoms within inclusion_radius

    Computed on first access by self._SetupDistances.
    """
    if self._ref_indices is not None:
        return self._ref_indices
    self._SetupDistances()
    return self._ref_indices

@property
def ref_distances(self):
    """Per atom: distances matching :attr:`ref_indices` (lazily computed)"""
    if self._ref_distances is not None:
        return self._ref_distances
    self._SetupDistances()
    return self._ref_distances

@property
def sym_ref_indices(self):
    """Per symmetry related atom: indices of atoms not affected by symmetry
    (lazily computed)"""
    if self._sym_ref_indices is not None:
        return self._sym_ref_indices
    self._SetupDistances()
    return self._sym_ref_indices

@property
def sym_ref_distances(self):
    """Distances matching :attr:`sym_ref_indices` (lazily computed)"""
    if self._sym_ref_distances is not None:
        return self._sym_ref_distances
    self._SetupDistances()
    return self._sym_ref_distances

@property
def n_distances(self):
    """Total number of reference distances, cached after first access"""
    if self._n_distances is None:
        self._n_distances = sum(len(per_atom) for per_atom in self.ref_indices)
    return self._n_distances
@property
def ref_indices_sc(self):
    """Like ref_indices but restricted to intra-chain contacts

    Computed on first access by self._SetupDistancesSC.
    """
    if self._ref_indices_sc is not None:
        return self._ref_indices_sc
    self._SetupDistancesSC()
    return self._ref_indices_sc

@property
def ref_distances_sc(self):
    """Distances matching :attr:`ref_indices_sc` (lazily computed)"""
    if self._ref_distances_sc is not None:
        return self._ref_distances_sc
    self._SetupDistancesSC()
    return self._ref_distances_sc

@property
def sym_ref_indices_sc(self):
    """Intra-chain variant of sym_ref_indices (lazily computed)"""
    if self._sym_ref_indices_sc is not None:
        return self._sym_ref_indices_sc
    self._SetupDistancesSC()
    return self._sym_ref_indices_sc

@property
def sym_ref_distances_sc(self):
    """Distances matching :attr:`sym_ref_indices_sc` (lazily computed)"""
    if self._sym_ref_distances_sc is not None:
        return self._sym_ref_distances_sc
    self._SetupDistancesSC()
    return self._sym_ref_distances_sc

@property
def n_distances_sc(self):
    """Number of intra-chain reference distances, cached after first access"""
    if self._n_distances_sc is None:
        self._n_distances_sc = sum(len(per_atom)
                                   for per_atom in self.ref_indices_sc)
    return self._n_distances_sc
@property
def ref_indices_ic(self):
    """Like ref_indices but restricted to inter-chain contacts

    Computed on first access by self._SetupDistancesIC.
    """
    if self._ref_indices_ic is not None:
        return self._ref_indices_ic
    self._SetupDistancesIC()
    return self._ref_indices_ic

@property
def ref_distances_ic(self):
    """Distances matching :attr:`ref_indices_ic` (lazily computed)"""
    if self._ref_distances_ic is not None:
        return self._ref_distances_ic
    self._SetupDistancesIC()
    return self._ref_distances_ic

@property
def sym_ref_indices_ic(self):
    """Inter-chain variant of sym_ref_indices (lazily computed)"""
    if self._sym_ref_indices_ic is not None:
        return self._sym_ref_indices_ic
    self._SetupDistancesIC()
    return self._sym_ref_indices_ic

@property
def sym_ref_distances_ic(self):
    """Distances matching :attr:`sym_ref_indices_ic` (lazily computed)"""
    if self._sym_ref_distances_ic is not None:
        return self._sym_ref_distances_ic
    self._SetupDistancesIC()
    return self._sym_ref_distances_ic

@property
def n_distances_ic(self):
    """Number of inter-chain reference distances, cached after first access"""
    if self._n_distances_ic is None:
        self._n_distances_ic = sum(len(per_atom)
                                   for per_atom in self.ref_indices_ic)
    return self._n_distances_ic
def lDDT(self, model, thresholds=None,
         local_lddt_prop=None, local_contact_prop=None,
         chain_mapping=None, no_interchain=False,
         no_intrachain=False, penalize_extra_chains=False,
         residue_mapping=None, return_dist_test=False,
         check_resnames=True):
    """Computes lDDT of *model* - globally and per-residue

    :param model: Model to be scored - models are preferably scored upon
                  performing stereo-chemistry checks in order to punish for
                  non-sensical irregularities. This must be done separately
                  as a pre-processing step.
    :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView`
    :param thresholds: Thresholds of distance differences to be considered
                       as correct. With several thresholds, the average
                       over all of them is reported.
                       default (None): [0.5, 1.0, 2.0, 4.0]
    :type thresholds: :class:`list` of :class:`float`
    :param local_lddt_prop: If set, per-residue scores will be assigned as
                            generic float property of that name
    :type local_lddt_prop: :class:`str`
    :param local_contact_prop: If set, number of expected contacts as well
                               as number of conserved contacts will be
                               assigned as generic int property.
                               Expected contacts will be set as
                               <local_contact_prop>_exp, conserved contacts
                               as <local_contact_prop>_cons. Values
                               are summed over all thresholds.
    :type local_contact_prop: :class:`str`
    :param chain_mapping: Mapping of model chains (key) onto target chains
                          (value). This is required if target or model have
                          more than one chain.
    :type chain_mapping: :class:`dict` with :class:`str` as keys/values
    :param no_interchain: Whether to exclude interchain contacts
    :type no_interchain: :class:`bool`
    :param no_intrachain: Whether to exclude intrachain contacts (i.e. only
                          consider interface related contacts)
    :type no_intrachain: :class:`bool`
    :param penalize_extra_chains: Whether to include a fixed penalty for
                                  additional chains in the model that are
                                  not mapped to the target. ONLY AFFECTS
                                  RETURNED GLOBAL SCORE. In detail: adds the
                                  number of intra-chain contacts of each
                                  extra chain to the expected contacts, thus
                                  adding a penalty.
    :type penalize_extra_chains: :class:`bool`
    :param residue_mapping: By default, residue mapping is based on residue
                            numbers. That means, a model chain and the
                            respective target chain map to the same
                            underlying reference sequence (SEQRES).
                            Alternatively, you can specify one or
                            several alignment(s) between model and target
                            chains by providing a dictionary. key: Name
                            of chain in model (respective target chain is
                            extracted from *chain_mapping*),
                            value: Alignment with first sequence
                            corresponding to target chain and second
                            sequence to model chain. There is NO reference
                            sequence involved, so the two sequences MUST
                            exactly match the actual residues observed in
                            the respective target/model chains (ATOMSEQ).
    :type residue_mapping: :class:`dict` with key: :class:`str`,
                           value: :class:`ost.seq.AlignmentHandle`
    :param return_dist_test: Whether to additionally return the underlying
                             per-residue data for the distance difference
                             test. Adds five objects to the return tuple.
                             First: Number of total contacts summed over all
                             thresholds
                             Second: Number of conserved contacts summed
                             over all thresholds
                             Third: list with length of scored residues.
                             Contains indices referring to model.residues.
                             Fourth: numpy array of size
                             len(scored_residues) containing the number of
                             total contacts,
                             Fifth: numpy matrix of shape
                             (len(scored_residues), len(thresholds))
                             specifying how many for each threshold are
                             conserved.
    :type return_dist_test: :class:`bool`
    :param check_resnames: On by default. Enforces residue name matches
                           between mapped model and target residues.
    :type check_resnames: :class:`bool`
    :returns: global and per-residue lDDT scores as a tuple -
              first element is global lDDT score and second element
              a list of per-residue scores with length len(*model*.residues)
              None is assigned to residues that are not covered by target
    :raises: :class:`RuntimeError` if *no_interchain* and *no_intrachain*
             are both set, if *chain_mapping* refers to non-existent chains
             or on residue name mismatches (with *check_resnames*).
             :class:`NotImplementedError` if *chain_mapping* is missing but
             target or model have more than one chain.
    """
    # validate mutually exclusive flags before doing any (expensive) work
    if no_interchain and no_intrachain:
        raise RuntimeError("no_interchain and no_intrachain flags are "
                           "mutually exclusive")

    if thresholds is None:
        # fresh list per call instead of a shared mutable default
        thresholds = [0.5, 1.0, 2.0, 4.0]

    if chain_mapping is None:
        if len(self.chain_names) > 1 or len(model.chains) > 1:
            raise NotImplementedError("Must provide chain mapping if "
                                      "target or model have > 1 chains.")
        chain_mapping = {model.chains[0].GetName(): self.chain_names[0]}
    else:
        # check whether chains specified in mapping exist
        for model_chain, target_chain in chain_mapping.items():
            if target_chain not in self.chain_names:
                raise RuntimeError(f"Target chain specified in "
                                   f"chain_mapping ({target_chain}) does "
                                   f"not exist. Target has chains: "
                                   f"{self.chain_names}")
            ch = model.FindChain(model_chain)
            if not ch.IsValid():
                raise RuntimeError(f"Model chain specified in "
                                   f"chain_mapping ({model_chain}) does "
                                   f"not exist. Model has chains: "
                                   f"{[c.GetName() for c in model.chains]}")

    # initialize positions with values far in nirvana. If a position is not
    # set, it should be far away from any position in model.
    max_pos = model.bounds.GetMax()
    max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2]))
    max_coordinate += 42 * max(thresholds)
    pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate

    # for each scored residue in model a list of indices describing the
    # atoms from the reference that should be there
    res_ref_atom_indices = list()

    # for each scored residue in model a list of indices of atoms that are
    # actually there
    res_atom_indices = list()

    # indices of the scored residues
    res_indices = list()

    # Will contain one element per symmetry group
    symmetries = list()

    current_model_res_idx = -1
    for ch in model.chains:
        model_ch_name = ch.GetName()
        if model_ch_name not in chain_mapping:
            current_model_res_idx += len(ch.residues)
            continue  # additional model chain which is not mapped
        target_ch_name = chain_mapping[model_ch_name]
        rnums = self._GetChainRNums(ch, residue_mapping, model_ch_name,
                                    target_ch_name)
        for r, rnum in zip(ch.residues, rnums):
            current_model_res_idx += 1
            res_mapper_key = (target_ch_name, rnum)
            if res_mapper_key not in self.res_mapper:
                continue
            r_idx = self.res_mapper[res_mapper_key]
            if check_resnames and r.name != self.compound_names[r_idx]:
                raise RuntimeError(
                    f"Residue name mismatch for {r}, "
                    f" expect {self.compound_names[r_idx]}"
                )
            res_start_idx = self.res_start_indices[r_idx]
            rname = self.compound_names[r_idx]
            anames = self.compound_anames[rname]
            atoms = [r.FindAtom(aname) for aname in anames]
            res_ref_atom_indices.append(
                list(range(res_start_idx, res_start_idx + len(anames)))
            )
            res_atom_indices.append(list())
            res_indices.append(current_model_res_idx)
            for a_idx, a in enumerate(atoms):
                if a.IsValid():
                    p = a.GetPos()
                    pos[res_start_idx + a_idx][0] = p[0]
                    pos[res_start_idx + a_idx][1] = p[1]
                    pos[res_start_idx + a_idx][2] = p[2]
                    res_atom_indices[-1].append(res_start_idx + a_idx)
            if rname in self.compound_symmetric_atoms:
                sym_indices = list()
                for sym_tuple in self.compound_symmetric_atoms[rname]:
                    a_one = atoms[sym_tuple[0]]
                    a_two = atoms[sym_tuple[1]]
                    # symmetry can only be resolved if both atoms are there
                    if a_one.IsValid() and a_two.IsValid():
                        sym_indices.append(
                            (
                                res_start_idx + sym_tuple[0],
                                res_start_idx + sym_tuple[1],
                            )
                        )
                if len(sym_indices) > 0:
                    symmetries.append(sym_indices)

    # select the contact set matching the requested lDDT flavour
    if no_interchain:
        sym_ref_indices = self.sym_ref_indices_sc
        sym_ref_distances = self.sym_ref_distances_sc
        ref_indices = self.ref_indices_sc
        ref_distances = self.ref_distances_sc
        n_distances = self.n_distances_sc
    elif no_intrachain:
        sym_ref_indices = self.sym_ref_indices_ic
        sym_ref_distances = self.sym_ref_distances_ic
        ref_indices = self.ref_indices_ic
        ref_distances = self.ref_distances_ic
        n_distances = self.n_distances_ic
    else:
        sym_ref_indices = self.sym_ref_indices
        sym_ref_distances = self.sym_ref_distances
        ref_indices = self.ref_indices
        ref_distances = self.ref_distances
        n_distances = self.n_distances

    self._ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices,
                            sym_ref_distances)

    per_res_exp = np.asarray([self._GetNExp(res_ref_atom_indices[idx],
        ref_indices) for idx in range(len(res_indices))], dtype=np.int32)
    per_res_conserved = self._EvalResidues(pos, thresholds,
                                           res_atom_indices,
                                           ref_indices, ref_distances)

    n_thresh = len(thresholds)

    # do per-residue scores
    per_res_lDDT = [None] * len(model.residues)
    for idx in range(len(res_indices)):
        n_exp = n_thresh * per_res_exp[idx]
        if n_exp > 0:
            score = np.sum(per_res_conserved[idx, :]) / n_exp
            per_res_lDDT[res_indices[idx]] = score
        else:
            per_res_lDDT[res_indices[idx]] = 0.0

    # do full model score
    if penalize_extra_chains:
        n_distances += self._GetExtraModelChainPenalty(model, chain_mapping)

    lDDT_tot = int(n_thresh * n_distances)
    lDDT_cons = int(np.sum(per_res_conserved))
    lDDT = None
    if lDDT_tot > 0:
        lDDT = float(lDDT_cons) / lDDT_tot

    # set properties if necessary
    if local_lddt_prop:
        residues = model.residues
        for idx in res_indices:
            residues[idx].SetFloatProp(local_lddt_prop, per_res_lDDT[idx])

    if local_contact_prop:
        residues = model.residues
        exp_prop = local_contact_prop + "_exp"
        conserved_prop = local_contact_prop + "_cons"

        for i, r_idx in enumerate(res_indices):
            residues[r_idx].SetIntProp(exp_prop,
                                       n_thresh * int(per_res_exp[i]))
            residues[r_idx].SetIntProp(conserved_prop,
                                       int(np.sum(per_res_conserved[i, :])))

    if return_dist_test:
        return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \
               per_res_exp, per_res_conserved
    else:
        return lDDT, per_res_lDDT
def GetNChainContacts(self, target_chain, no_interchain=False):
    """Returns number of contacts expected for a certain chain in *target*

    The chain's atoms occupy the internal index range starting at its entry
    in self.chain_start_indices up to the next chain's start (or n_atoms for
    the last chain).

    :param target_chain: Chain in *target* for which you want the number
                         of expected contacts
    :type target_chain: :class:`str`
    :param no_interchain: Whether to exclude interchain contacts
    :type no_interchain: :class:`bool`
    :raises: :class:`RuntimeError` if specified chain doesnt exist
    """
    if target_chain not in self.chain_names:
        raise RuntimeError(f"Specified chain name ({target_chain}) not in "
                           f"target")
    chain_pos = self.chain_names.index(target_chain)
    first_atom = self.chain_start_indices[chain_pos]
    is_last_chain = chain_pos + 1 >= len(self.chain_names)
    if is_last_chain:
        end_atom = self.n_atoms
    else:
        end_atom = self.chain_start_indices[chain_pos + 1]
    contacts = self.ref_indices_sc if no_interchain else self.ref_indices
    return self._GetNExp(list(range(first_atom, end_atom)), contacts)
def _GetExtraModelChainPenalty(self, model, chain_mapping):
    """Counts n distances in extra model chains to be added as penalty

    Every model chain that is not a key in *chain_mapping* is scored on its
    own with a throw-away :class:`lDDTScorer`, and the number of reference
    distances of that scorer is added to the penalty. The throw-away scorer
    uses the same compound definitions (including *custom_compounds*) and
    contact settings as this scorer. Otherwise, a chain containing a
    compound that is only defined as custom compound would raise.

    :param model: Model whose unmapped chains are counted
    :param chain_mapping: Mapping of model chain names to target chain names
    :returns: Summed number of distances of all unmapped model chains
    """
    penalty = 0
    for chain in model.chains:
        ch_name = chain.GetName()
        if ch_name in chain_mapping:
            continue
        # quote the chain name so that unusual characters don't break the
        # query language
        mdl_sel = model.Select(f"cname={mol.QueryQuoteName(ch_name)}")
        dummy_scorer = lDDTScorer(mdl_sel, self.compound_lib,
                                  custom_compounds=self.custom_compounds,
                                  inclusion_radius=self.inclusion_radius,
                                  sequence_separation=self.sequence_separation,
                                  symmetry_settings=self.symmetry_settings,
                                  bb_only=self.bb_only)
        penalty += dummy_scorer.n_distances
    return penalty
def _GetChainRNums(self, ch, residue_mapping, model_ch_name,
target_ch_name):
"""Map residues in model chain to target residues
There are two options: one is simply using residue numbers,
the other is a custom mapping as given in *residue_mapping*
"""
if residue_mapping and model_ch_name in residue_mapping:
# extract residue numbers from target chain
ch_idx = self.chain_names.index(target_ch_name)
start_idx = self.chain_res_start_indices[ch_idx]
if ch_idx < len(self.chain_names) - 1:
end_idx = self.chain_res_start_indices[ch_idx+1]
else:
end_idx = len(self.compound_names)
target_rnums = self.res_resnums[start_idx:end_idx]
# get sequences from alignment and do consistency checks
target_seq = residue_mapping[model_ch_name].GetSequence(0)
model_seq = residue_mapping[model_ch_name].GetSequence(1)
if len(target_seq.GetGaplessString()) != len(target_rnums):
raise RuntimeError(f"Try to perform residue mapping for "
f"model chain {model_ch_name} which "
f"maps to {target_ch_name} in target. "
f"Target sequence in alignment suggests "
f"{len(target_seq.GetGaplessString())} "
f"residues but {len(target_rnums)} are "
f"expected.")
if len(model_seq.GetGaplessString()) != len(ch.residues):
raise RuntimeError(f"Try to perform residue mapping for "
f"model chain {model_ch_name} which "
f"maps to {target_ch_name} in target. "
f"Model sequence in alignment suggests "
f"{len(model_seq.GetGaplessString())} "
f"residues but {len(ch.residues)} are "
f"expected.")
rnums = list()
target_idx = -1
for col in residue_mapping[model_ch_name]:
if col[0] != '-':
target_idx += 1
# handle match
if col[0] != '-' and col[1] != '-':
rnums.append(target_rnums[target_idx])
# insertion in model adds None to rnum
if col[0] == '-' and col[1] != '-':
rnums.append(None)
else:
rnums = [r.GetNumber() for r in ch.residues]
return rnums
def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings,
              seqres_mapping, bb_only):
    """Sets target related lDDTScorer members defined in constructor

    No distance related members - see _SetupDistances

    Walks over all chains and residues of *target* and assigns each atom of
    each residue's compound a consecutive internal index. This happens
    whether or not the atom is present in the structure. Fills chain_names,
    chain_start_indices, chain_res_start_indices, res_start_indices,
    res_mapper, compound_names, res_resnums, atom_indices, symmetric_atoms
    and n_atoms.
    """
    residue_numbers = self._GetTargetResidueNumbers(self.target,
                                                    seqres_mapping)
    # running internal atom index
    current_idx = 0
    for chain in self.target.chains:
        ch_name = chain.GetName()
        self.chain_names.append(ch_name)
        self.chain_start_indices.append(current_idx)
        # position of this chain's first residue in self.compound_names
        self.chain_res_start_indices.append(len(self.compound_names))
        for r, rnum in zip(chain.residues, residue_numbers[ch_name]):
            if r.name not in self.compound_anames:
                # sets compound info in self.compound_anames and
                # self.compound_symmetric_atoms
                self._SetupCompound(r, compound_lib, custom_compounds,
                                    symmetry_settings, bb_only)
            self.res_start_indices.append(current_idx)
            # must happen before the append below: the residue's index is
            # the current length of self.compound_names
            self.res_mapper[(ch_name, rnum)] = len(self.compound_names)
            self.compound_names.append(r.name)
            self.res_resnums.append(rnum)
            atoms = [r.FindAtom(an) for an in self.compound_anames[r.name]]
            for a in atoms:
                # only atoms present in the structure get a hashcode entry,
                # but every compound atom consumes an internal index
                if a.IsValid():
                    self.atom_indices[a.handle.GetHashCode()] = current_idx
                current_idx += 1
            if r.name in self.compound_symmetric_atoms:
                # sym_tuple holds positions into atoms (compound order)
                for sym_tuple in self.compound_symmetric_atoms[r.name]:
                    for a_idx in sym_tuple:
                        a = atoms[a_idx]
                        if a.IsValid():
                            hashcode = a.handle.GetHashCode()
                            self.symmetric_atoms.add(
                                self.atom_indices[hashcode]
                            )
    # total number of compound atoms, present in structure or not
    self.n_atoms = current_idx
def _GetTargetResidueNumbers(self, target, seqres_mapping):
"""Returns residue numbers for each chain in target as dict
They're either directly extracted from the raw residue number
from the structure or from user provided alignments
"""
residue_numbers = dict()
for ch in target.chains:
ch_name = ch.GetName()
rnums = list()
if ch_name in seqres_mapping:
seqres = seqres_mapping[ch_name].GetSequence(0).GetString()
atomseq = seqres_mapping[ch_name].GetSequence(1).GetString()
# SEQRES must not contain gaps
if "-" in seqres:
raise RuntimeError(
"SEQRES in seqres_mapping must not " "contain gaps"
)
atomseq_from_chain = [r.one_letter_code for r in ch.residues]
if atomseq.replace("-", "") != atomseq_from_chain:
raise RuntimeError(
"ATOMSEQ in seqres_mapping must match "
"raw sequence extracted from chain "
"residues"
)
rnum = 0
for seqres_olc, atomseq_olc in zip(seqres, atomseq):
if seqres_olc != "-":
rnum += 1
if atomseq_olc != "-":
if seqres_olc != atomseq_olc:
raise RuntimeError(
f"Residue with number {rnum} in "
f"chain {ch_name} has SEQRES "
f"ATOMSEQ mismatch"
)
rnums.append(mol.ResNum(rnum))
else:
rnums = [r.GetNumber() for r in ch.residues]
assert len(rnums) == len(ch.residues)
residue_numbers[ch_name] = rnums
return residue_numbers
def _SetupCompound(self, r, compound_lib, custom_compounds,
symmetry_settings, bb_only):
"""fill self.compound_anames/self.compound_symmetric_atoms
"""
if bb_only:
# throw away compound_lib info
if r.chem_class.IsPeptideLinking():
self.compound_anames[r.name] = ["CA"]
elif r.chem_class.IsNucleotideLinking():
self.compound_anames[r.name] = ["C3'"]
else:
raise RuntimeError(f"Only support amino acids and nucleotides "
f"if bb_only is True, failed with {str(r)}")
self.compound_symmetric_atoms[r.name] = list()
else:
atom_names = list()
symmetric_atoms = list()
if custom_compounds is not None and r.GetName() in custom_compounds:
atom_names = list(custom_compounds[r.GetName()].atom_names)
else:
compound = compound_lib.FindCompound(r.name)
if compound is None:
raise RuntimeError(f"no entry for {r} in compound_lib")
for atom_spec in compound.GetAtomSpecs():
if atom_spec.element not in ["H", "D"]:
atom_names.append(atom_spec.name)
if r.name in symmetry_settings.symmetric_compounds:
for pair in symmetry_settings.symmetric_compounds[r.name]:
try:
a = atom_names.index(pair[0])
b = atom_names.index(pair[1])
except:
msg = f"Could not find symmetric atoms "
msg += f"({pair[0]}, {pair[1]}) for {r.name} "
msg += f"as specified in SymmetrySettings in "
msg += f"compound from component dictionary. "
msg += f"Atoms in compound: {atom_names}"
raise RuntimeError(msg)
symmetric_atoms.append((a, b))
self.compound_anames[r.name] = atom_names
if len(symmetric_atoms) > 0:
self.compound_symmetric_atoms[r.name] = symmetric_atoms
def _SetupDistances(self):
    """Compute distance related members of lDDTScorer

    Gathers the coordinates of all target atoms that have an internal index
    and collects, per atom, the atoms within inclusion_radius. Atoms of the
    same residue are excluded via self._CloseStuff. Fills
    self._ref_indices/self._ref_distances and derives the symmetry related
    lists via self._NonSymDistances.
    """
    n = self.n_atoms
    self._ref_indices = [np.asarray([], dtype=np.int64) for _ in range(n)]
    self._ref_distances = [np.asarray([], dtype=np.float64) for _ in range(n)]
    self._sym_ref_indices = [np.asarray([], dtype=np.int64)
                             for _ in range(n)]
    self._sym_ref_distances = [np.asarray([], dtype=np.float64)
                               for _ in range(n)]

    # Atoms without coordinates are parked at a point beyond the largest
    # coordinate plus twice the inclusion radius, so they never come within
    # inclusion_radius of a real atom.
    upper = self.target.bounds.GetMax()
    far_away = abs(max(upper[0], upper[1], upper[2]))
    far_away += 2 * self.inclusion_radius
    pos = np.ones((n, 3), dtype=np.float32) * far_away

    present = list()    # internal indices of atoms with coordinates
    res_first = list()  # first internal index of the atom's residue
    res_end = list()    # one past the last internal index of that residue
    for r_idx, r in enumerate(self.target.residues):
        first = self.res_start_indices[r_idx]
        end = first + len(self.compound_anames[r.name])
        for a in r.atoms:
            hashcode = a.handle.GetHashCode()
            if hashcode not in self.atom_indices:
                continue
            idx = self.atom_indices[hashcode]
            p = a.GetPos()
            pos[idx][0] = p[0]
            pos[idx][1] = p[1]
            pos[idx][2] = p[2]
            present.append(idx)
            res_first.append(first)
            res_end.append(end)

    indices, distances = self._CloseStuff(pos, self.inclusion_radius,
                                          present, res_first, res_end)
    for atom_idx, close_idx, close_dist in zip(present, indices, distances):
        self._ref_indices[atom_idx] = close_idx
        self._ref_distances[atom_idx] = close_dist

    self._NonSymDistances(self._ref_indices, self._ref_distances,
                          self._sym_ref_indices,
                          self._sym_ref_distances)
def _SetupDistancesSC(self):
"""Select subset of contacts only covering intra-chain contacts
"""
# init
self._ref_indices_sc = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)]
self._ref_distances_sc = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)]
self._sym_ref_indices_sc = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)]
self._sym_ref_distances_sc = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)]
# start from overall contacts
ref_indices = self.ref_indices
ref_distances = self.ref_distances
sym_ref_indices = self.sym_ref_indices
sym_ref_distances = self.sym_ref_distances
n_chains = len(self.chain_start_indices)
for ch_idx, ch in enumerate(self.target.chains):
chain_s = self.chain_start_indices[ch_idx]
chain_e = self.n_atoms
if ch_idx + 1 < n_chains:
chain_e = self.chain_start_indices[ch_idx+1]
for i in range(chain_s, chain_e):
if len(ref_indices[i]) > 0:
intra_idx = np.where(np.logical_and(ref_indices[i]>=chain_s,
ref_indices[i]<chain_e))[0]
self._ref_indices_sc[i] = ref_indices[i][intra_idx]
self._ref_distances_sc[i] = ref_distances[i][intra_idx]
self._NonSymDistances(self._ref_indices_sc, self._ref_distances_sc,
self._sym_ref_indices_sc,
self._sym_ref_distances_sc)
def _SetupDistancesIC(self):
    """Select subset of contacts only covering inter-chain contacts

    Starts from the overall contacts (``self.ref_indices`` /
    ``self.ref_distances``). For every atom, it keeps only the contacts whose
    partner index lies outside the atom range of the atom's own chain,
    i.e. outside ``[chain_start, chain_end)``. Chain starts are taken from
    ``self.chain_start_indices``, and the last chain extends to
    ``self.n_atoms``.

    Populates ``_ref_indices_ic`` / ``_ref_distances_ic``. The non-symmetric
    subsets ``_sym_ref_indices_ic`` / ``_sym_ref_distances_ic`` are then
    derived via ``_NonSymDistances``.
    """
    n_atoms = self.n_atoms
    # init: one (possibly empty) array per atom; atoms without any
    # inter-chain contact keep these empty arrays
    self._ref_indices_ic = [np.asarray([], dtype=np.int64)
                            for _ in range(n_atoms)]
    self._ref_distances_ic = [np.asarray([], dtype=np.float64)
                              for _ in range(n_atoms)]
    self._sym_ref_indices_ic = [np.asarray([], dtype=np.int64)
                                for _ in range(n_atoms)]
    self._sym_ref_distances_ic = [np.asarray([], dtype=np.float64)
                                  for _ in range(n_atoms)]

    # start from overall contacts
    ref_indices = self.ref_indices
    ref_distances = self.ref_distances

    chain_starts = self.chain_start_indices
    n_chains = len(chain_starts)
    for ch_idx, _ in enumerate(self.target.chains):
        chain_s = chain_starts[ch_idx]
        # a chain's atom range ends where the next chain starts; the last
        # chain runs up to n_atoms
        if ch_idx + 1 < n_chains:
            chain_e = chain_starts[ch_idx + 1]
        else:
            chain_e = n_atoms
        for i in range(chain_s, chain_e):
            if len(ref_indices[i]) == 0:
                continue
            # boolean mask: partner atom lies in a different chain
            inter = (ref_indices[i] < chain_s) | (ref_indices[i] >= chain_e)
            self._ref_indices_ic[i] = ref_indices[i][inter]
            self._ref_distances_ic[i] = ref_distances[i][inter]

    self._NonSymDistances(self._ref_indices_ic, self._ref_distances_ic,
                          self._sym_ref_indices_ic,
                          self._sym_ref_distances_ic)
def _CloseStuff(self, pos, inclusion_radius, indices, mask_start, mask_end):
    """returns close stuff for positions specified by indices

    For each ``idx`` in *indices*, finds all rows of *pos* within
    *inclusion_radius* of ``pos[idx]``. The matching entries of *mask_start*
    and *mask_end* define an index range ``[ms, me)`` that is always
    excluded; the caller in this file passes the atom's own residue range.

    :param pos: Positions, one row per atom (presumably float with 3
                columns - TODO confirm with callers)
    :param inclusion_radius: Distance cutoff (inclusive)
    :param indices: Atom indices for which close atoms are searched
    :param mask_start: Per entry in *indices*, first index to exclude
    :param mask_end: Per entry in *indices*, one past the last index to
                     exclude
    :returns: Tuple (close_indices, distances). Both are lists with one
              numpy array per entry in *indices*: the indices of close atoms
              and their distances to ``pos[idx]``.
    """
    # TODO: this function does brute force distance computation which has
    # quadratic complexity...
    close_indices = list()
    distances = list()
    # work with squared_inclusion_radius (sir) to save some square roots
    sir = inclusion_radius ** 2
    for idx, ms, me in zip(indices, mask_start, mask_end):
        sq_dists = np.square(pos - pos[idx, :][None, :]).sum(axis=1)
        within = sq_dists <= sir
        # Mask out atoms of own residue in the selection itself. The former
        # "set distance to 2 * sir" trick fails for inclusion_radius == 0,
        # where 2 * sir == 0 still passes the <= test.
        within[ms:me] = False
        close = np.nonzero(within)[0]
        close_indices.append(close)
        distances.append(np.sqrt(sq_dists[close]))
    return (close_indices, distances)
def _NonSymDistances(self, ref_indices, ref_distances,
                     sym_ref_indices, sym_ref_distances):
    """Transfer indices/distances of non-symmetric atoms in place

    For each atom index in ``self.symmetric_atoms``, the contacts from
    *ref_indices* / *ref_distances* are filtered so that only partners which
    are themselves not in ``self.symmetric_atoms`` remain. The results are
    written into *sym_ref_indices* (as a list) and *sym_ref_distances* (as a
    numpy array) at that atom's position. Entries of non-symmetric atoms are
    left untouched.
    """
    sym_atoms = self.symmetric_atoms
    for atom_idx in sym_atoms:
        kept = [(partner, dist)
                for partner, dist in zip(ref_indices[atom_idx],
                                         ref_distances[atom_idx])
                if partner not in sym_atoms]
        sym_ref_indices[atom_idx] = [partner for partner, _ in kept]
        sym_ref_distances[atom_idx] = np.asarray([dist for _, dist in kept])
def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances):
    """Computes number of distance differences within given thresholds

    Compares the current distances between ``pos[atom_idx]`` and the atoms
    listed in ``ref_indices[atom_idx]`` against the reference distances in
    ``ref_distances[atom_idx]``. For every threshold, it counts how many
    absolute differences are <= that threshold.

    returns np.array with len(thresholds) elements (dtype int32)
    """
    center = pos[atom_idx, :]
    # take() returns a copy, so the in-place ops below never touch pos
    work = pos.take(ref_indices[atom_idx], axis=0)
    np.subtract(work, center[None, :], out=work)
    sq_dists = np.square(work, out=work).sum(axis=1)
    dist_diffs = np.sqrt(sq_dists, out=sq_dists)
    np.subtract(ref_distances[atom_idx], dist_diffs, out=dist_diffs)
    np.absolute(dist_diffs, out=dist_diffs)
    counts = [(dist_diffs <= t).sum() for t in thresholds]
    return np.asarray(counts, dtype=np.int32)
def _EvalAtoms(
    self, pos, atom_indices, thresholds, ref_indices, ref_distances
):
    """Calls _EvalAtom for several atoms and collects the per-threshold
    counts of distance differences within the given thresholds.

    returns numpy matrix of shape (len(atom_indices), len(thresholds)),
    with one row per atom, dtype int32
    """
    shape = (len(atom_indices), len(thresholds))
    rows = [self._EvalAtom(pos, atom, thresholds, ref_indices, ref_distances)
            for atom in atom_indices]
    # reshape keeps the (0, n_thresholds) shape when atom_indices is empty
    return np.asarray(rows, dtype=np.int32).reshape(shape)
def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices,
                  ref_distances):
    """Calls _EvalAtoms for a bunch of residues

    Each residue is given in *res_atom_indices* as a list of atom indices.
    The per-atom counts of a residue are summed over its atoms.

    returns numpy matrix of shape (n_residues, len(thresholds)), dtype int32
    """
    shape = (len(res_atom_indices), len(thresholds))
    per_residue = [
        np.sum(self._EvalAtoms(pos, atom_idcs, thresholds, ref_indices,
                               ref_distances), axis=0)
        for atom_idcs in res_atom_indices
    ]
    # reshape keeps the (0, n_thresholds) shape when there are no residues
    return np.asarray(per_residue, dtype=np.int32).reshape(shape)
def _ProcessSequenceSeparation(self):
    """Checks that ``self.sequence_separation`` is the supported default.

    Only a value of 0 is accepted. Any other value raises
    :class:`NotImplementedError`.
    """
    if self.sequence_separation == 0:
        return
    message = ("Congratulations! You're the first one "
               "requesting a non-default "
               "sequence_separation in the new and "
               "awesome lDDT implementation. A crate of "
               "beer for Gabriel and he'll implement "
               "it.")
    raise NotImplementedError(message)
def _GetNExp(self, atom_idx, ref_indices):
    """Returns number of close atoms around one or several atoms

    :param atom_idx: A single atom index (Python or numpy integer), or a
                     sequence of atom indices (list, tuple or 1D numpy
                     array)
    :param ref_indices: Per-atom containers of contact indices
    :returns: Total number of entries in ``ref_indices`` for the given
              atom(s)
    :raises: :class:`RuntimeError` if *atom_idx* is of any other type
    """
    # numpy integer scalars (e.g. taken from index arrays) are not int
    # subclasses and must be accepted explicitly
    if isinstance(atom_idx, (int, np.integer)):
        return len(ref_indices[atom_idx])
    if isinstance(atom_idx, (list, tuple, np.ndarray)):
        return sum(len(ref_indices[idx]) for idx in atom_idx)
    raise RuntimeError("invalid input type")
def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices,
                       sym_ref_distances):
    """Swaps symmetric positions in-place in order to maximize lDDT scores
    towards non-symmetric atoms.

    For every symmetry in *symmetries* (an iterable of atom index pairs),
    the involved atoms are scored with ``_EvalAtoms`` using
    *sym_ref_indices* / *sym_ref_distances*. Then all pairs are swapped in
    *pos* and the atoms are scored again. If the original assignment scores
    better or equal, the swap is undone. Otherwise the swapped positions
    are kept.

    :param pos: Position array, modified in place (rows are swapped)
    :param thresholds: Distance-difference thresholds passed to
                       ``_EvalAtoms``
    :param symmetries: Iterable of symmetries. Each one is an iterable of
                       pairs whose elements 0 and 1 are atom indices.
    :param sym_ref_indices: Per-atom contact indices used for scoring
                            (presumably the non-symmetric subset produced by
                            ``_NonSymDistances`` - confirm with callers)
    :param sym_ref_distances: Per-atom reference distances matching
                              *sym_ref_indices*
    """
    for sym in symmetries:
        # flatten all pairs of this symmetry into one list of atom indices
        atom_indices = list()
        for sym_tuple in sym:
            atom_indices += [sym_tuple[0], sym_tuple[1]]
        # total number of contacts to score; skip symmetries without any
        tot = self._GetNExp(atom_indices, sym_ref_indices)
        if tot == 0:
            continue # nothing to do
        # score as is
        sym_one_conserved = self._EvalAtoms(
            pos,
            atom_indices,
            thresholds,
            sym_ref_indices,
            sym_ref_distances,
        )
        # switch positions and score again
        # (the right-hand side uses advanced indexing, which yields a copy,
        # so this is a proper row swap)
        for pair in sym:
            pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]
        sym_two_conserved = self._EvalAtoms(
            pos,
            atom_indices,
            thresholds,
            sym_ref_indices,
            sym_ref_distances,
        )
        # both scores share the same denominator, so the comparison below is
        # equivalent to comparing the raw conserved counts
        sym_one_score = np.sum(sym_one_conserved) / (len(thresholds) * tot)
        sym_two_score = np.sum(sym_two_conserved) / (len(thresholds) * tot)
        if sym_one_score >= sym_two_score:
            # switch back, initial positions were better or equal
            # for the equal case: we still switch back to reproduce the old
            # lDDT behaviour
            for pair in sym:
                pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]