import numpy as np from ost import mol from ost import conop class CustomCompound: """ Defines atoms for custom compounds lDDT requires the reference atoms of a compound which are typically extracted from a :class:`ost.conop.CompoundLib`. This lightweight container allows to handle arbitrary compounds which are not necessarily in the compound library. :param atom_names: Names of atoms of custom compound :type atom_names: :class:`list` of :class:`str` """ def __init__(self, atom_names): self.atom_names = atom_names @staticmethod def FromResidue(res): """ Construct custom compound from residue :param res: Residue from which reference atom names are extracted :type res: :class:`ost.mol.ResidueView`/:class:`ost.mol.ResidueHandle` :returns: :class:`CustomCompound` """ atom_names = [a.GetName() for a in res.atoms] if len(atom_names) != len(set(atom_names)): raise RuntimeError("Duplicate atoms detected in CustomCompound") compound = CustomCompound(atom_names) return compound class SymmetrySettings: """Container for symmetric compounds lDDT considers symmetries and selects the one resulting in the highest possible score. A symmetry is defined as a renaming operation on one or more atoms that leads to a chemically equivalent residue. Example would be OD1 and OD2 in ASP => renaming OD1 to OD2 and vice versa gives a chemically equivalent residue. Use :func:`AddSymmetricCompound` to define a symmetry which can then directly be accessed through the *symmetric_compounds* member. """ def __init__(self): self.symmetric_compounds = dict() def AddSymmetricCompound(self, name, symmetric_atoms): """Adds symmetry for compound with *name* :param name: Name of compound with symmetry :type name: :class:`str` :param symmetric_atoms: Pairs of atom names that define renaming operation, i.e. after applying all switches defined in the tuples, the resulting residue should be chemically equivalent. Atom names must refer to the PDB component dictionary. :type symmetric_atoms: :class:`list` of :class:`tuple` """ for pair in symmetric_atoms: if len(pair) != 2: raise RuntimeError("Expect pairs when defining symmetries") self.symmetric_compounds[name] = symmetric_atoms def GetDefaultSymmetrySettings(): """Constructs and returns :class:`SymmetrySettings` object for natural amino acids """ symmetry_settings = SymmetrySettings() # ASP symmetry_settings.AddSymmetricCompound("ASP", [("OD1", "OD2")]) # GLU symmetry_settings.AddSymmetricCompound("GLU", [("OE1", "OE2")]) # LEU symmetry_settings.AddSymmetricCompound("LEU", [("CD1", "CD2")]) # VAL symmetry_settings.AddSymmetricCompound("VAL", [("CG1", "CG2")]) # ARG symmetry_settings.AddSymmetricCompound("ARG", [("NH1", "NH2")]) # PHE symmetry_settings.AddSymmetricCompound( "PHE", [("CD1", "CD2"), ("CE1", "CE2")] ) # TYR symmetry_settings.AddSymmetricCompound( "TYR", [("CD1", "CD2"), ("CE1", "CE2")] ) return symmetry_settings class lDDTScorer: """lDDT scorer object for a specific target Sets up everything to score models of that target. lDDT (local distance difference test) is defined as fraction of pairwise distances which exhibit a difference < threshold when considering target and model. In case of multiple thresholds, the average is returned. See V. Mariani, M. Biasini, A. Barbato, T. Schwede, lDDT : A local superposition-free score for comparing protein structures and models using distance difference tests, Bioinformatics, 2013 :param target: The target :type target: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView` :param compound_lib: Compound library from which a compound for each residue is extracted based on its name. Uses :func:`ost.conop.GetDefaultLib` if not given, raises if this returns no valid compound library. Atoms defined in the compound are searched in the residue and build the reference for scoring. If the residue has atoms with names ["A", "B", "C"] but the corresponding compound only has ["A", "B"], "A" and "B" are considered for scoring. If the residue has atoms ["A", "B"] but the compound has ["A", "B", "C"], "C" is considered missing and does not influence scoring, even if present in the model. :param custom_compounds: Custom compounds defining reference atoms. If given, *custom_compounds* take precedent over *compound_lib*. :type custom_compounds: :class:`dict` with residue names (:class:`str`) as key and :class:`CustomCompound` as value. :type compound_lib: :class:`ost.conop.CompoundLib` :param inclusion_radius: All pairwise distances < *inclusion_radius* are considered for scoring :type inclusion_radius: :class:`float` :param sequence_separation: Only pairwise distances between atoms of residues which are further apart than this threshold are considered. Residue distance is based on resnum. The default (0) considers all pairwise distances except intra-residue distances. :type sequence_separation: :class:`int` :param symmetry_settings: Define residues exhibiting internal symmetry, uses :func:`GetDefaultSymmetrySettings` if not given. :type symmetry_settings: :class:`SymmetrySettings` :param seqres_mapping: Mapping of model residues at the scoring stage happens with residue numbers defining their location in a reference sequence (SEQRES) using one based indexing. If the residue numbers in *target* don't correspond to that SEQRES, you can specify the mapping manually. You can provide a dictionary to specify a reference sequence (SEQRES) for one or more chain(s). Key: chain name, value: alignment (seq1: SEQRES, seq2: sequence of residues in chain). Example: The residues in a chain with name "A" have sequence "YEAH" and residue numbers [42,43,44,45]. You can provide an alignment with seq1 "``HELLYEAH``" and seq2 "``----YEAH``". "Y" gets assigned residue number 5, "E" gets assigned 6 and so on no matter what the original residue numbers were. :type seqres_mapping: :class:`dict` (key: :class:`str`, value: :class:`ost.seq.AlignmentHandle`) :param bb_only: Only consider atoms with name "CA" in case of amino acids and "C3'" for Nucleotides. this invalidates *compound_lib*. Raises if any residue in *target* is not `r.chem_class.IsPeptideLinking()` or `r.chem_class.IsNucleotideLinking()` :type bb_only: :class:`bool` :raises: :class:`RuntimeError` if *target* contains compound which is not in *compound_lib*, :class:`RuntimeError` if *symmetry_settings* specifies symmetric atoms that are not present in the according compound in *compound_lib*, :class:`RuntimeError` if *seqres_mapping* is not provided and *target* contains residue numbers with insertion codes or the residue numbers for each chain are not monotonically increasing, :class:`RuntimeError` if *seqres_mapping* is provided but an alignment is invalid (seq1 contains gaps, mismatch in seq1/seq2, seq2 does not match residues in corresponding chains). """ def __init__( self, target, compound_lib=None, custom_compounds=None, inclusion_radius=15, sequence_separation=0, symmetry_settings=None, seqres_mapping=dict(), bb_only=False ): self.target = target self.inclusion_radius = inclusion_radius self.sequence_separation = sequence_separation if compound_lib is None: compound_lib = conop.GetDefaultLib() if compound_lib is None: raise RuntimeError("No compound_lib given and conop.GetDefaultLib " "returns no valid compound library") self.compound_lib = compound_lib self.custom_compounds = custom_compounds if symmetry_settings is None: self.symmetry_settings = GetDefaultSymmetrySettings() else: self.symmetry_settings = symmetry_settings # whether to only consider atoms with name "CA" (amino acids) or C3' # (nucleotides), invalidates *compound_lib* self.bb_only=bb_only # names of heavy atoms of each unique compound present in *target* as # extracted from *compound_lib*, e.g. # self.compound_anames["GLY"] = ["N", "CA", "C", "O"] self.compound_anames = dict() # stores symmetry information for those compounds as defined in # *symmetry_settings* self.compound_symmetric_atoms = dict() # list of len(target.chains) containing all chain names in *target* self.chain_names = list() # list of len(target.residues) containing all compound names in *target* self.compound_names = list() # list of len(target.residues) defining start pos in internal reference # positions for each residue self.res_start_indices = list() # list of len(target.residues) defining residue numbers in target self.res_resnums = list() # list of len(target.chains) defining start pos in internal reference # positions for each chain self.chain_start_indices = list() # list of len(target.chains) defining start pos in self.compound_names # for each chain self.chain_res_start_indices = list() # maps residues in *target* to indices in # self.compound_names/self.res_start_indices. A residue gets identified # by a tuple (first element: chain name, second element: residue number, # residue number is either the actual residue number in *target* or # given by *seqres_mapping*) self.res_mapper = dict() # number of atoms as specified in compounds. not all are necessarily # covered by structure self.n_atoms = None # stores an index for each AtomHandle in *target* # (atom hashcode => index) self.atom_indices = dict() # store indices of all atoms that have symmetry properties self.symmetric_atoms = set() # setup members defined above self._SetupEnv(self.compound_lib, self.custom_compounds, self.symmetry_settings, seqres_mapping, self.bb_only) # distance related members are lazily computed as they're affected # by different flavours of lDDT (e.g. lDDT including inter-chain # contacts or not etc.) # stores for each atom the other atoms within inclusion_radius self._ref_indices = None # the corresponding distances self._ref_distances = None # The following lists will be sparsely populated. We keep for each # symmetry related atom the distances towards all atoms which are NOT # affected by symmetry. So we can evaluate two symmetric versions # against the fixed stuff later on and select the better scoring one. self._sym_ref_indices = None self._sym_ref_distances = None # total number of distances self._n_distances = None # exactly the same as above but without interchain contacts # => single-chain (sc) self._ref_indices_sc = None self._ref_distances_sc = None self._sym_ref_indices_sc = None self._sym_ref_distances_sc = None self._n_distances_sc = None # exactly the same as above but without intrachain contacts # => inter-chain (ic) self._ref_indices_ic = None self._ref_distances_ic = None self._sym_ref_indices_ic = None self._sym_ref_distances_ic = None self._n_distances_ic = None # input parameter checking self._ProcessSequenceSeparation() @property def ref_indices(self): if self._ref_indices is None: self._SetupDistances() return self._ref_indices @property def ref_distances(self): if self._ref_distances is None: self._SetupDistances() return self._ref_distances @property def sym_ref_indices(self): if self._sym_ref_indices is None: self._SetupDistances() return self._sym_ref_indices @property def sym_ref_distances(self): if self._sym_ref_distances is None: self._SetupDistances() return self._sym_ref_distances @property def n_distances(self): if self._n_distances is None: self._n_distances = sum([len(x) for x in self.ref_indices]) return self._n_distances @property def ref_indices_sc(self): if self._ref_indices_sc is None: self._SetupDistancesSC() return self._ref_indices_sc @property def ref_distances_sc(self): if self._ref_distances_sc is None: self._SetupDistancesSC() return self._ref_distances_sc @property def sym_ref_indices_sc(self): if self._sym_ref_indices_sc is None: self._SetupDistancesSC() return self._sym_ref_indices_sc @property def sym_ref_distances_sc(self): if self._sym_ref_distances_sc is None: self._SetupDistancesSC() return self._sym_ref_distances_sc @property def n_distances_sc(self): if self._n_distances_sc is None: self._n_distances_sc = sum([len(x) for x in self.ref_indices_sc]) return self._n_distances_sc @property def ref_indices_ic(self): if self._ref_indices_ic is None: self._SetupDistancesIC() return self._ref_indices_ic @property def ref_distances_ic(self): if self._ref_distances_ic is None: self._SetupDistancesIC() return self._ref_distances_ic @property def sym_ref_indices_ic(self): if self._sym_ref_indices_ic is None: self._SetupDistancesIC() return self._sym_ref_indices_ic @property def sym_ref_distances_ic(self): if self._sym_ref_distances_ic is None: self._SetupDistancesIC() return self._sym_ref_distances_ic @property def n_distances_ic(self): if self._n_distances_ic is None: self._n_distances_ic = sum([len(x) for x in self.ref_indices_ic]) return self._n_distances_ic def lDDT(self, model, thresholds = [0.5, 1.0, 2.0, 4.0], local_lddt_prop=None, local_contact_prop=None, chain_mapping=None, no_interchain=False, no_intrachain=False, penalize_extra_chains=False, residue_mapping=None, return_dist_test=False, check_resnames=True): """Computes lDDT of *model* - globally and per-residue :param model: Model to be scored - models are preferably scored upon performing stereo-chemistry checks in order to punish for non-sensical irregularities. This must be done separately as a pre-processing step. :type model: :class:`ost.mol.EntityHandle`/:class:`ost.mol.EntityView` :param thresholds: Thresholds of distance differences to be considered as correct - see docs in constructor for more info. default: [0.5, 1.0, 2.0, 4.0] :type thresholds: :class:`list` of :class:`floats` :param local_lddt_prop: If set, per-residue scores will be assigned as generic float property of that name :type local_lddt_prop: :class:`str` :param local_contact_prop: If set, number of expected contacts as well as number of conserved contacts will be assigned as generic int property. Excected contacts will be set as <local_contact_prop>_exp, conserved contacts as <local_contact_prop>_cons. Values are summed over all thresholds. :type local_contact_prop: :class:`str` :param chain_mapping: Mapping of model chains (key) onto target chains (value). This is required if target or model have more than one chain. :type chain_mapping: :class:`dict` with :class:`str` as keys/values :param no_interchain: Whether to exclude interchain contacts :type no_interchain: :class:`bool` :param no_intrachain: Whether to exclude intrachain contacts (i.e. only consider interface related contacts) :type no_intrachain: :class:`bool` :param penalize_extra_chains: Whether to include a fixed penalty for additional chains in the model that are not mapped to the target. ONLY AFFECTS RETURNED GLOBAL SCORE. In detail: adds the number of intra-chain contacts of each extra chain to the expected contacts, thus adding a penalty. :param penalize_extra_chains: :class:`bool` :param residue_mapping: By default, residue mapping is based on residue numbers. That means, a model chain and the respective target chain map to the same underlying reference sequence (SEQRES). Alternatively, you can specify one or several alignment(s) between model and target chains by providing a dictionary. key: Name of chain in model (respective target chain is extracted from *chain_mapping*), value: Alignment with first sequence corresponding to target chain and second sequence to model chain. There is NO reference sequence involved, so the two sequences MUST exactly match the actual residues observed in the respective target/model chains (ATOMSEQ). :type residue_mapping: :class:`dict` with key: :class:`str`, value: :class:`ost.seq.AlignmentHandle` :param return_dist_test: Whether to additionally return the underlying per-residue data for the distance difference test. Adds five objects to the return tuple. First: Number of total contacts summed over all thresholds Second: Number of conserved contacts summed over all thresholds Third: list with length of scored residues. Contains indices referring to model.residues. Fourth: numpy array of size len(scored_residues) containing the number of total contacts, Fifth: numpy matrix of shape (len(scored_residues), len(thresholds)) specifying how many for each threshold are conserved. :param check_resnames: On by default. Enforces residue name matches between mapped model and target residues. :type check_resnames: :class:`bool` :returns: global and per-residue lDDT scores as a tuple - first element is global lDDT score and second element a list of per-residue scores with length len(*model*.residues) None is assigned to residues that are not covered by target """ if chain_mapping is None: if len(self.chain_names) > 1 or len(model.chains) > 1: raise NotImplementedError("Must provide chain mapping if " "target or model have > 1 chains.") chain_mapping = {model.chains[0].GetName(): self.chain_names[0]} else: # check whether chains specified in mapping exist for model_chain, target_chain in chain_mapping.items(): if target_chain not in self.chain_names: raise RuntimeError(f"Target chain specified in " f"chain_mapping ({target_chain}) does " f"not exist. Target has chains: " f"{self.chain_names}") ch = model.FindChain(model_chain) if not ch.IsValid(): raise RuntimeError(f"Model chain specified in " f"chain_mapping ({model_chain}) does " f"not exist. Model has chains: " f"{[c.GetName() for c in model.chains]}") # initialize positions with values far in nirvana. If a position is not # set, it should be far away from any position in model. max_pos = model.bounds.GetMax() max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2])) max_coordinate += 42 * max(thresholds) pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate # for each scored residue in model a list of indices describing the # atoms from the reference that should be there res_ref_atom_indices = list() # for each scored residue in model a list of indices of atoms that are # actually there res_atom_indices = list() # indices of the scored residues res_indices = list() # Will contain one element per symmetry group symmetries = list() current_model_res_idx = -1 for ch in model.chains: model_ch_name = ch.GetName() if model_ch_name not in chain_mapping: current_model_res_idx += len(ch.residues) continue # additional model chain which is not mapped target_ch_name = chain_mapping[model_ch_name] rnums = self._GetChainRNums(ch, residue_mapping, model_ch_name, target_ch_name) for r, rnum in zip(ch.residues, rnums): current_model_res_idx += 1 res_mapper_key = (target_ch_name, rnum) if res_mapper_key not in self.res_mapper: continue r_idx = self.res_mapper[res_mapper_key] if check_resnames and r.name != self.compound_names[r_idx]: raise RuntimeError( f"Residue name mismatch for {r}, " f" expect {self.compound_names[r_idx]}" ) res_start_idx = self.res_start_indices[r_idx] rname = self.compound_names[r_idx] anames = self.compound_anames[rname] atoms = [r.FindAtom(aname) for aname in anames] res_ref_atom_indices.append( list(range(res_start_idx, res_start_idx + len(anames))) ) res_atom_indices.append(list()) res_indices.append(current_model_res_idx) for a_idx, a in enumerate(atoms): if a.IsValid(): p = a.GetPos() pos[res_start_idx + a_idx][0] = p[0] pos[res_start_idx + a_idx][1] = p[1] pos[res_start_idx + a_idx][2] = p[2] res_atom_indices[-1].append(res_start_idx + a_idx) if rname in self.compound_symmetric_atoms: sym_indices = list() for sym_tuple in self.compound_symmetric_atoms[rname]: a_one = atoms[sym_tuple[0]] a_two = atoms[sym_tuple[1]] if a_one.IsValid() and a_two.IsValid(): sym_indices.append( ( res_start_idx + sym_tuple[0], res_start_idx + sym_tuple[1], ) ) if len(sym_indices) > 0: symmetries.append(sym_indices) if no_interchain and no_intrachain: raise RuntimeError("on_interchain and no_intrachain flags are " "mutually exclusive") if no_interchain: sym_ref_indices = self.sym_ref_indices_sc sym_ref_distances = self.sym_ref_distances_sc ref_indices = self.ref_indices_sc ref_distances = self.ref_distances_sc n_distances = self.n_distances_sc elif no_intrachain: sym_ref_indices = self.sym_ref_indices_ic sym_ref_distances = self.sym_ref_distances_ic ref_indices = self.ref_indices_ic ref_distances = self.ref_distances_ic n_distances = self.n_distances_ic else: sym_ref_indices = self.sym_ref_indices sym_ref_distances = self.sym_ref_distances ref_indices = self.ref_indices ref_distances = self.ref_distances n_distances = self.n_distances self._ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices, sym_ref_distances) per_res_exp = np.asarray([self._GetNExp(res_ref_atom_indices[idx], ref_indices) for idx in range(len(res_indices))], dtype=np.int32) per_res_conserved = self._EvalResidues(pos, thresholds, res_atom_indices, ref_indices, ref_distances) n_thresh = len(thresholds) # do per-residue scores per_res_lDDT = [None] * len(model.residues) for idx in range(len(res_indices)): n_exp = n_thresh * per_res_exp[idx] if n_exp > 0: score = np.sum(per_res_conserved[idx,:]) / n_exp per_res_lDDT[res_indices[idx]] = score # do full model score if penalize_extra_chains: n_distances += self._GetExtraModelChainPenalty(model, chain_mapping) lDDT_tot = int(n_thresh * n_distances) lDDT_cons = int(np.sum(per_res_conserved)) lDDT = None if lDDT_tot > 0: lDDT = float(lDDT_cons) / lDDT_tot # set properties if necessary if local_lddt_prop: residues = model.residues for idx in res_indices: residues[idx].SetFloatProp(local_lddt_prop, per_res_lDDT[idx]) if local_contact_prop: residues = model.residues exp_prop = local_contact_prop + "_exp" conserved_prop = local_contact_prop + "_cons" for i, r_idx in enumerate(res_indices): residues[r_idx].SetIntProp(exp_prop, n_thresh * int(per_res_exp[i])) residues[r_idx].SetIntProp(conserved_prop, int(np.sum(per_res_conserved[i,:]))) if return_dist_test: return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \ per_res_exp, per_res_conserved else: return lDDT, per_res_lDDT def GetNChainContacts(self, target_chain, no_interchain=False): """Returns number of contacts expected for a certain chain in *target* :param target_chain: Chain in *target* for which you want the number of expected contacts :type target_chain: :class:`str` :param no_interchain: Whether to exclude interchain contacts :type no_interchain: :class:`bool` :raises: :class:`RuntimeError` if specified chain doesnt exist """ if target_chain not in self.chain_names: raise RuntimeError(f"Specified chain name ({target_chain}) not in " f"target") ch_idx = self.chain_names.index(target_chain) s = self.chain_start_indices[ch_idx] e = self.n_atoms if ch_idx + 1 < len(self.chain_names): e = self.chain_start_indices[ch_idx+1] if no_interchain: return self._GetNExp(list(range(s, e)), self.ref_indices_sc) else: return self._GetNExp(list(range(s, e)), self.ref_indices) def _GetExtraModelChainPenalty(self, model, chain_mapping): """Counts n distances in extra model chains to be added as penalty """ penalty = 0 for chain in model.chains: ch_name = chain.GetName() if ch_name not in chain_mapping: sm = self.symmetry_settings dummy_scorer = lDDTScorer(model.Select("cname="+ch_name), self.compound_lib, symmetry_settings = sm, inclusion_radius = self.inclusion_radius, bb_only = self.bb_only) penalty += dummy_scorer.n_distances return penalty def _GetChainRNums(self, ch, residue_mapping, model_ch_name, target_ch_name): """Map residues in model chain to target residues There are two options: one is simply using residue numbers, the other is a custom mapping as given in *residue_mapping* """ if residue_mapping and model_ch_name in residue_mapping: # extract residue numbers from target chain ch_idx = self.chain_names.index(target_ch_name) start_idx = self.chain_res_start_indices[ch_idx] if ch_idx < len(self.chain_names) - 1: end_idx = self.chain_res_start_indices[ch_idx+1] else: end_idx = len(self.compound_names) target_rnums = self.res_resnums[start_idx:end_idx] # get sequences from alignment and do consistency checks target_seq = residue_mapping[model_ch_name].GetSequence(0) model_seq = residue_mapping[model_ch_name].GetSequence(1) if len(target_seq.GetGaplessString()) != len(target_rnums): raise RuntimeError(f"Try to perform residue mapping for " f"model chain {model_ch_name} which " f"maps to {target_ch_name} in target. " f"Target sequence in alignment suggests " f"{len(target_seq.GetGaplessString())} " f"residues but {len(target_rnums)} are " f"expected.") if len(model_seq.GetGaplessString()) != len(ch.residues): raise RuntimeError(f"Try to perform residue mapping for " f"model chain {model_ch_name} which " f"maps to {target_ch_name} in target. " f"Model sequence in alignment suggests " f"{len(model_seq.GetGaplessString())} " f"residues but {len(ch.residues)} are " f"expected.") rnums = list() target_idx = -1 for col in residue_mapping[model_ch_name]: if col[0] != '-': target_idx += 1 # handle match if col[0] != '-' and col[1] != '-': rnums.append(target_rnums[target_idx]) # insertion in model adds None to rnum if col[0] == '-' and col[1] != '-': rnums.append(None) else: rnums = [r.GetNumber() for r in ch.residues] if sum([len(rn.GetInsCode().strip("\0")) for rn in rnums]) > 0: raise RuntimeError( "Residue numbers in model must not " "contain insertion codes" ) rnums = [rn.GetNum() for rn in rnums] return rnums def _SetupEnv(self, compound_lib, custom_compounds, symmetry_settings, seqres_mapping, bb_only): """Sets target related lDDTScorer members defined in constructor No distance related members - see _SetupDistances """ residue_numbers = self._GetTargetResidueNumbers(self.target, seqres_mapping) current_idx = 0 for chain in self.target.chains: ch_name = chain.GetName() self.chain_names.append(ch_name) self.chain_start_indices.append(current_idx) self.chain_res_start_indices.append(len(self.compound_names)) for r, rnum in zip(chain.residues, residue_numbers[ch_name]): if r.name not in self.compound_anames: # sets compound info in self.compound_anames and # self.compound_symmetric_atoms self._SetupCompound(r, compound_lib, custom_compounds, symmetry_settings, bb_only) self.res_start_indices.append(current_idx) self.res_mapper[(ch_name, rnum)] = len(self.compound_names) self.compound_names.append(r.name) self.res_resnums.append(rnum) atoms = [r.FindAtom(an) for an in self.compound_anames[r.name]] for a in atoms: if a.IsValid(): self.atom_indices[a.handle.GetHashCode()] = current_idx current_idx += 1 if r.name in self.compound_symmetric_atoms: for sym_tuple in self.compound_symmetric_atoms[r.name]: for a_idx in sym_tuple: a = atoms[a_idx] if a.IsValid(): hashcode = a.handle.GetHashCode() self.symmetric_atoms.add( self.atom_indices[hashcode] ) self.n_atoms = current_idx def _GetTargetResidueNumbers(self, target, seqres_mapping): """Returns residue numbers for each chain in target as dict They're either directly extracted from the raw residue number from the structure or from user provided alignments """ residue_numbers = dict() for ch in target.chains: ch_name = ch.GetName() rnums = list() if ch_name in seqres_mapping: seqres = seqres_mapping[ch_name].GetSequence(0).GetString() atomseq = seqres_mapping[ch_name].GetSequence(1).GetString() # SEQRES must not contain gaps if "-" in seqres: raise RuntimeError( "SEQRES in seqres_mapping must not " "contain gaps" ) atomseq_from_chain = [r.one_letter_code for r in ch.residues] if atomseq.replace("-", "") != atomseq_from_chain: raise RuntimeError( "ATOMSEQ in seqres_mapping must match " "raw sequence extracted from chain " "residues" ) rnum = 0 for seqres_olc, atomseq_olc in zip(seqres, atomseq): if seqres_olc != "-": rnum += 1 if atomseq_olc != "-": if seqres_olc != atomseq_olc: raise RuntimeError( f"Residue with number {rnum} in " f"chain {ch_name} has SEQRES " f"ATOMSEQ mismatch" ) rnums.append(rnum) else: rnums = [r.GetNumber() for r in ch.residues] if sum([len(rn.GetInsCode().strip("\0")) for rn in rnums]) > 0: raise RuntimeError( "Residue numbers in target must not " "contain insertion codes" ) rnums = [rnum.GetNum() for rnum in rnums] if not all(x < y for x, y in zip(rnums, rnums[1:])): raise RuntimeError( "Residue numbers in each target chain " "must be monotonically increasing" ) assert len(rnums) == len(ch.residues) residue_numbers[ch_name] = rnums return residue_numbers def _SetupCompound(self, r, compound_lib, custom_compounds, symmetry_settings, bb_only): """fill self.compound_anames/self.compound_symmetric_atoms """ if bb_only: # throw away compound_lib info if r.chem_class.IsPeptideLinking(): self.compound_anames[r.name] = ["CA"] elif r.chem_class.IsNucleotideLinking(): self.compound_anames[r.name] = ["C3'"] else: raise RuntimeError(f"Only support amino acids and nucleotides " f"if bb_only is True, failed with {str(r)}") self.compound_symmetric_atoms[r.name] = list() else: atom_names = list() symmetric_atoms = list() if custom_compounds is not None and r.GetName() in custom_compounds: atom_names = list(custom_compounds[r.GetName()].atom_names) else: compound = compound_lib.FindCompound(r.name) if compound is None: raise RuntimeError(f"no entry for {r} in compound_lib") for atom_spec in compound.GetAtomSpecs(): if atom_spec.element not in ["H", "D"]: atom_names.append(atom_spec.name) if r.name in symmetry_settings.symmetric_compounds: for pair in symmetry_settings.symmetric_compounds[r.name]: try: a = atom_names.index(pair[0]) b = atom_names.index(pair[1]) except: msg = f"Could not find symmetric atoms " msg += f"({pair[0]}, {pair[1]}) for {r.name} " msg += f"as specified in SymmetrySettings in " msg += f"compound from component dictionary. " msg += f"Atoms in compound: {atom_names}" raise RuntimeError(msg) symmetric_atoms.append((a, b)) self.compound_anames[r.name] = atom_names if len(symmetric_atoms) > 0: self.compound_symmetric_atoms[r.name] = symmetric_atoms def _SetupDistances(self): """Compute distance related members of lDDTScorer """ # init self._ref_indices = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)] self._ref_distances = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)] self._sym_ref_indices = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)] self._sym_ref_distances = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)] # initialize positions with values far in nirvana. If a position is not # set, it should be far away from any position in target (or at least # more than inclusion_radius). max_pos = self.target.bounds.GetMax() max_coordinate = abs(max(max_pos[0], max_pos[1], max_pos[2])) max_coordinate += 2 * self.inclusion_radius pos = np.ones((self.n_atoms, 3), dtype=np.float32) * max_coordinate atom_indices = list() mask_start = list() mask_end = list() for r_idx, r in enumerate(self.target.residues): r_start_idx = self.res_start_indices[r_idx] r_n_atoms = len(self.compound_anames[r.name]) r_end_idx = r_start_idx + r_n_atoms for a in r.atoms: if a.handle.GetHashCode() in self.atom_indices: idx = self.atom_indices[a.handle.GetHashCode()] p = a.GetPos() pos[idx][0] = p[0] pos[idx][1] = p[1] pos[idx][2] = p[2] atom_indices.append(idx) mask_start.append(r_start_idx) mask_end.append(r_end_idx) indices, distances = self._CloseStuff(pos, self.inclusion_radius, atom_indices, mask_start, mask_end) for i in range(len(atom_indices)): self._ref_indices[atom_indices[i]] = indices[i] self._ref_distances[atom_indices[i]] = distances[i] self._NonSymDistances(self._ref_indices, self._ref_distances, self._sym_ref_indices, self._sym_ref_distances) def _SetupDistancesSC(self): """Select subset of contacts only covering intra-chain contacts """ # init self._ref_indices_sc = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)] self._ref_distances_sc = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)] self._sym_ref_indices_sc = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)] self._sym_ref_distances_sc = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)] # start from overall contacts ref_indices = self.ref_indices ref_distances = self.ref_distances sym_ref_indices = self.sym_ref_indices sym_ref_distances = self.sym_ref_distances n_chains = len(self.chain_start_indices) for ch_idx, ch in enumerate(self.target.chains): chain_s = self.chain_start_indices[ch_idx] chain_e = self.n_atoms if ch_idx + 1 < n_chains: chain_e = self.chain_start_indices[ch_idx+1] for i in range(chain_s, chain_e): if len(ref_indices[i]) > 0: intra_idx = np.where(np.logical_and(ref_indices[i]>=chain_s, ref_indices[i]<chain_e))[0] self._ref_indices_sc[i] = ref_indices[i][intra_idx] self._ref_distances_sc[i] = ref_distances[i][intra_idx] self._NonSymDistances(self._ref_indices_sc, self._ref_distances_sc, self._sym_ref_indices_sc, self._sym_ref_distances_sc) def _SetupDistancesIC(self): """Select subset of contacts only covering inter-chain contacts """ # init self._ref_indices_ic = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)] self._ref_distances_ic = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)] self._sym_ref_indices_ic = [np.asarray([], dtype=np.int64) for idx in range(self.n_atoms)] self._sym_ref_distances_ic = [np.asarray([], dtype=np.float64) for idx in range(self.n_atoms)] # start from overall contacts ref_indices = self.ref_indices ref_distances = self.ref_distances sym_ref_indices = self.sym_ref_indices sym_ref_distances = self.sym_ref_distances n_chains = len(self.chain_start_indices) for ch_idx, ch in enumerate(self.target.chains): chain_s = self.chain_start_indices[ch_idx] chain_e = self.n_atoms if ch_idx + 1 < n_chains: chain_e = self.chain_start_indices[ch_idx+1] for i in range(chain_s, chain_e): if len(ref_indices[i]) > 0: inter_idx = np.where(np.logical_or(ref_indices[i]<chain_s, ref_indices[i]>=chain_e))[0] self._ref_indices_ic[i] = ref_indices[i][inter_idx] self._ref_distances_ic[i] = ref_distances[i][inter_idx] self._NonSymDistances(self._ref_indices_ic, self._ref_distances_ic, self._sym_ref_indices_ic, self._sym_ref_distances_ic) def _CloseStuff(self, pos, inclusion_radius, indices, mask_start, mask_end): """returns close stuff for positions specified by indices """ # TODO: this function does brute force distance computation which has # quadratic complexity... close_indices = list() distances = list() # work with squared_inclusion_radius (sir) to save some square roots sir = inclusion_radius ** 2 for idx, ms, me in zip(indices, mask_start, mask_end): p = pos[idx, :] tmp = pos - p[None, :] np.square(tmp, out=tmp) tmp = tmp.sum(axis=1) # mask out atoms of own residue => put them far away tmp[range(ms, me)] = 2 * sir close_indices.append(np.nonzero(tmp <= sir)[0]) distances.append(np.sqrt(tmp[close_indices[-1]])) return (close_indices, distances) def _NonSymDistances(self, ref_indices, ref_distances, sym_ref_indices, sym_ref_distances): """Transfer indices/distances of non-symmetric atoms in place """ for idx in self.symmetric_atoms: indices = list() distances = list() for i, d in zip(ref_indices[idx], ref_distances[idx]): if i not in self.symmetric_atoms: indices.append(i) distances.append(d) sym_ref_indices[idx] = indices sym_ref_distances[idx] = np.asarray(distances) def _EvalAtom(self, pos, atom_idx, thresholds, ref_indices, ref_distances): """Computes number of distance differences within given thresholds returns np.array with len(thresholds) elements """ a_p = pos[atom_idx, :] tmp = pos.take(ref_indices[atom_idx], axis=0) np.subtract(tmp, a_p[None, :], out=tmp) np.square(tmp, out=tmp) tmp = tmp.sum(axis=1) np.sqrt(tmp, out=tmp) # distances against all relevant atoms np.subtract(ref_distances[atom_idx], tmp, out=tmp) np.absolute(tmp, out=tmp) # absolute dist diffs return np.asarray([(tmp <= thresh).sum() for thresh in thresholds], dtype=np.int32) def _EvalAtoms( self, pos, atom_indices, thresholds, ref_indices, ref_distances ): """Calls _EvalAtom for several atoms and sums up the computed number of distance differences within given thresholds returns numpy matrix of shape (n_atoms, len(threshold)) """ conserved = np.zeros((len(atom_indices), len(thresholds)), dtype=np.int32) for a_idx, a in enumerate(atom_indices): conserved[a_idx, :] = self._EvalAtom(pos, a, thresholds, ref_indices, ref_distances) return conserved def _EvalResidues(self, pos, thresholds, res_atom_indices, ref_indices, ref_distances): """Calls _EvalAtoms for a bunch of residues residues are defined in *res_atom_indices* as lists of atom indices returns numpy matrix of shape (n_residues, len(thresholds)). """ conserved = np.zeros((len(res_atom_indices), len(thresholds)), dtype=np.int32) for rai_idx, rai in enumerate(res_atom_indices): conserved[rai_idx,:] = np.sum(self._EvalAtoms(pos, rai, thresholds, ref_indices, ref_distances), axis=0) return conserved def _ProcessSequenceSeparation(self): if self.sequence_separation != 0: raise NotImplementedError("Congratulations! You're the first one " "requesting a non-default " "sequence_separation in the new and " "awesome lDDT implementation. A crate of " "beer for Gabriel and he'll implement " "it.") def _GetNExp(self, atom_idx, ref_indices): """Returns number of close atoms around one or several atoms """ if isinstance(atom_idx, int): return len(ref_indices[atom_idx]) elif isinstance(atom_idx, list): return sum([len(ref_indices[idx]) for idx in atom_idx]) else: raise RuntimeError("invalid input type") def _ResolveSymmetries(self, pos, thresholds, symmetries, sym_ref_indices, sym_ref_distances): """Swaps symmetric positions in-place in order to maximize lDDT scores towards non-symmetric atoms. """ for sym in symmetries: atom_indices = list() for sym_tuple in sym: atom_indices += [sym_tuple[0], sym_tuple[1]] tot = self._GetNExp(atom_indices, sym_ref_indices) if tot == 0: continue # nothing to do # score as is sym_one_conserved = self._EvalAtoms( pos, atom_indices, thresholds, sym_ref_indices, sym_ref_distances, ) # switch positions and score again for pair in sym: pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]] sym_two_conserved = self._EvalAtoms( pos, atom_indices, thresholds, sym_ref_indices, sym_ref_distances, ) sym_one_score = np.sum(sym_one_conserved) / (len(thresholds) * tot) sym_two_score = np.sum(sym_two_conserved) / (len(thresholds) * tot) if sym_one_score >= sym_two_score: # switch back, initial positions were better or equal # for the equal case: we still switch back to reproduce the old # lDDT behaviour for pair in sym: pos[[pair[0], pair[1]]] = pos[[pair[1], pair[0]]]