From 4e6e83055b51d4bd60eefce79aad2cfde23809b7 Mon Sep 17 00:00:00 2001 From: Gabriel Studer <gabriel.studer@unibas.ch> Date: Wed, 4 Sep 2024 09:34:37 +0200 Subject: [PATCH] lddt: enable per-atom lDDT scores --- actions/ost-compare-structures | 35 +++++++++++++- modules/doc/actions.rst | 13 +++++ modules/mol/alg/pymod/lddt.py | 62 ++++++++++++++++++++++-- modules/mol/alg/pymod/scoring.py | 81 ++++++++++++++++++++++++++++---- 4 files changed, 176 insertions(+), 15 deletions(-) diff --git a/actions/ost-compare-structures b/actions/ost-compare-structures index 9a466ce05..1afa5198a 100644 --- a/actions/ost-compare-structures +++ b/actions/ost-compare-structures @@ -247,6 +247,25 @@ def _ParseArgs(): "counterparts. Atoms specified in there follow the following " "format: <chain_name>.<resnum>.<resnum_inscode>.<atom_name>")) + parser.add_argument( + "--aa-local-lddt", + dest="aa_local_lddt", + default=False, + action="store_true", + help=("Compute per-atom lDDT scores with default parameterization " + "and store as key \"aa_local_lddt\". Score for each atom is " + "accessible by key " + "<chain_name>.<resnum>.<resnum_inscode>.<aname>. " + "Alpha carbon from residue with number 42 in chain X can be " + "extracted with: data[\"aa_local_lddt\"][\"X.42..CA\"]. " + "If there is a residue insertion code, lets say A, the atom key " + "becomes \"X.42.A.CA\". " + "Stereochemical irregularities affecting lDDT are reported as " + "keys \"model_clashes\", \"model_bad_bonds\", " + "\"model_bad_angles\" and the respective reference " + "counterparts. 
Atoms specified in there follow the following " + "format: <chain_name>.<resnum>.<resnum_inscode>.<atom_name>")) + parser.add_argument( "--bb-lddt", dest="bb_lddt", @@ -712,6 +731,17 @@ def _LocalScoresToJSONDict(score_dict): json_dict[f"{ch}.{num.num}.{ins_code}"] = _RoundOrNone(s) return json_dict +def _LocalAAScoresToJSONDict(score_dict): + """ Convert ResNums and atom names to str for JSON serialization + """ + json_dict = dict() + for ch, ch_scores in score_dict.items(): + for num, res_scores in ch_scores.items(): + ins_code = num.ins_code.strip("\u0000") + for a, s in res_scores.items(): + json_dict[f"{ch}.{num.num}.{ins_code}.{a}"] = _RoundOrNone(s) + return json_dict + def _InterfaceResiduesToJSONList(interface_dict): """ Convert ResNums to str for JSON serialization. @@ -816,7 +846,10 @@ def _Process(model, reference, args, model_format, reference_format): if args.local_lddt: out["local_lddt"] = _LocalScoresToJSONDict(scorer.local_lddt) - if args.lddt or args.local_lddt: + if args.aa_local_lddt: + out["aa_local_lddt"] = _LocalAAScoresToJSONDict(scorer.aa_local_lddt) + + if args.lddt or args.local_lddt or args.aa_local_lddt: out["model_clashes"] = [x.ToJSON() for x in scorer.model_clashes] out["model_bad_bonds"] = [x.ToJSON() for x in scorer.model_bad_bonds] out["model_bad_angles"] = [x.ToJSON() for x in scorer.model_bad_angles] diff --git a/modules/doc/actions.rst b/modules/doc/actions.rst index d1794c49c..89b96d823 100644 --- a/modules/doc/actions.rst +++ b/modules/doc/actions.rst @@ -202,6 +202,19 @@ Details on the usage (output of ``ost compare-structures --help``): counterparts. Atoms specified in there follow the following format: <chain_name>.<resnum>.<resnum_inscode>.<atom_name> + --aa-local-lddt Compute per-atom lDDT scores with default + parameterization and store as key "aa_local_lddt". + Score for each atom is accessible by key + <chain_name>.<resnum>.<resnum_inscode>.<aname>. 
Alpha + carbon from residue with number 42 in chain X can be + extracted with: data["aa_local_lddt"]["X.42..CA"]. If + there is a residue insertion code, lets say A, the + atom key becomes "X.42.A.CA". Stereochemical + irregularities affecting lDDT are reported as keys + "model_clashes", "model_bad_bonds", "model_bad_angles" + and the respective reference counterparts. Atoms + specified in there follow the following format: + <chain_name>.<resnum>.<resnum_inscode>.<atom_name> --bb-lddt Compute global lDDT score with default parameterization and store as key "bb_lddt". lDDT in this case is only computed on backbone atoms: CA for diff --git a/modules/mol/alg/pymod/lddt.py b/modules/mol/alg/pymod/lddt.py index 1453db731..83c1b2d60 100644 --- a/modules/mol/alg/pymod/lddt.py +++ b/modules/mol/alg/pymod/lddt.py @@ -1,3 +1,4 @@ +import itertools import numpy as np from ost import mol @@ -438,7 +439,7 @@ class lDDTScorer: no_intrachain=False, penalize_extra_chains=False, residue_mapping=None, return_dist_test=False, check_resnames=True, add_mdl_contacts=False, - interaction_data=None): + interaction_data=None, set_atom_props=False): """Computes lDDT of *model* - globally and per-residue :param model: Model to be scored - models are preferably scored upon @@ -530,6 +531,12 @@ class lDDTScorer: :type add_mdl_contacts: :class:`bool` :param interaction_data: Pro param - don't use :type interaction_data: :class:`tuple` + :param set_atom_props: If True, sets generic properties on a per atom + level if *local_lddt_prop*/*local_contact_prop* + are set as well. + In other words: this is the only way you can + get per-atom lDDT values. 
+ :type set_atom_props: :class:`bool` :returns: global and per-residue lDDT scores as a tuple - first element is global lDDT score (None if *target* has no @@ -610,16 +617,28 @@ class lDDTScorer: self._ResolveSymmetries(pos, thresholds, symmetries, sym_ref_indices, sym_ref_distances) + atom_indices = list(itertools.chain.from_iterable(res_atom_indices)) + + per_atom_exp = np.asarray([self._GetNExp(i, ref_indices) + for i in atom_indices], dtype=np.int32) per_res_exp = np.asarray([self._GetNExp(res_ref_atom_indices[idx], ref_indices) for idx in range(len(res_indices))], dtype=np.int32) - per_res_conserved = self._EvalResidues(pos, thresholds, - res_atom_indices, - ref_indices, ref_distances) + + per_atom_conserved = self._EvalAtoms(pos, atom_indices, thresholds, + ref_indices, ref_distances) + per_res_conserved = np.zeros((len(res_atom_indices), len(thresholds)), + dtype=np.int32) + start_idx = 0 + for r_idx in range(len(res_atom_indices)): + end_idx = start_idx + len(res_atom_indices[r_idx]) + per_res_conserved[r_idx] = np.sum(per_atom_conserved[start_idx:end_idx,:], + axis=0) + start_idx = end_idx n_thresh = len(thresholds) # do per-residue scores - per_res_lDDT = [None] * len(model.residues) + per_res_lDDT = [None] * model.GetResidueCount() for idx in range(len(res_indices)): n_exp = n_thresh * per_res_exp[idx] if n_exp > 0: @@ -656,6 +675,39 @@ class lDDTScorer: residues[r_idx].SetIntProp(conserved_prop, int(np.sum(per_res_conserved[i,:]))) + if set_atom_props and (local_lddt_prop or local_contact_prop): + atom_list = list() + residues = model.residues + for i, indices in enumerate(res_atom_indices): + r = residues[res_indices[i]] + r_idx = ref_res_indices[i] + res_start_idx = self.res_start_indices[r_idx] + anames = self.compound_anames[r.GetName()] + for a_i in indices: + a = r.FindAtom(anames[a_i - res_start_idx]) + assert(a.IsValid()) + atom_list.append(a) + + summed_per_atom_conserved = per_atom_conserved.sum(axis=1) + if local_lddt_prop: + # the only place 
where we actually need to compute per-atom lDDT + # scores + for a_idx in range(len(atom_list)): + tmp = summed_per_atom_conserved[a_idx] / per_atom_exp[a_idx] + tmp = tmp / n_thresh + atom_list[a_idx].SetFloatProp(local_lddt_prop, tmp) + + if local_contact_prop: + conserved_prop = local_contact_prop + "_cons" + exp_prop = local_contact_prop + "_exp" + for a_idx in range(len(atom_list)): + # do number of conserved contacts + tmp = summed_per_atom_conserved[a_idx] + atom_list[a_idx].SetIntProp(conserved_prop, tmp) + # do number of expected contacts + tmp = per_atom_exp[a_idx] * n_thresh + atom_list[a_idx].SetIntProp(exp_prop, tmp) + if return_dist_test: return lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, \ per_res_exp, per_res_conserved diff --git a/modules/mol/alg/pymod/scoring.py b/modules/mol/alg/pymod/scoring.py index 37304aa94..7445c1f2b 100644 --- a/modules/mol/alg/pymod/scoring.py +++ b/modules/mol/alg/pymod/scoring.py @@ -344,6 +344,7 @@ class Scorer: # lazily computed scores self._lddt = None self._local_lddt = None + self._aa_local_lddt = None self._bb_lddt = None self._bb_local_lddt = None self._ilddt = None @@ -743,6 +744,23 @@ class Scorer: self._compute_lddt() return self._local_lddt + @property + def aa_local_lddt(self): + """ Per atom lDDT scores in range [0.0, 1.0] + + Computed based on :attr:`~stereochecked_model` but scores for all + atoms in :attr:`~model` are reported. If an atom has been removed + by stereochemistry checks, the respective score is set to 0.0. If an + atom is not covered by the target or is in a chain skipped by the + chain mapping procedure (happens for super short chains), the respective + score is set to None. In case of oligomers, :attr:`~mapping` is used. 
+ + :type: :class:`dict` + """ + if self._aa_local_lddt is None: + self._compute_lddt() + return self._aa_local_lddt + @property def bb_lddt(self): """ Backbone only global lDDT score in range [0.0, 1.0] @@ -1824,6 +1842,7 @@ class Scorer: # score variables to be set lddt_score = None local_lddt = None + aa_local_lddt = None if self.lddt_no_stereochecks: lddt_chain_mapping = dict() @@ -1835,19 +1854,39 @@ class Scorer: residue_mapping = alns, check_resnames=False, local_lddt_prop="lddt", - add_mdl_contacts = self.lddt_add_mdl_contacts)[0] + add_mdl_contacts = self.lddt_add_mdl_contacts, + set_atom_props=True)[0] local_lddt = dict() + aa_local_lddt = dict() for r in self.model.residues: + cname = r.GetChain().GetName() if cname not in local_lddt: local_lddt[cname] = dict() + aa_local_lddt[cname] = dict() + + rnum = r.GetNumber() + if rnum not in aa_local_lddt[cname]: + aa_local_lddt[cname][rnum] = dict() + if r.HasProp("lddt"): score = round(r.GetFloatProp("lddt"), 3) - local_lddt[cname][r.GetNumber()] = score + local_lddt[cname][rnum] = score else: # not covered by trg or skipped in chain mapping procedure # the latter happens if its part of a super short chain - local_lddt[cname][r.GetNumber()] = None + local_lddt[cname][rnum] = None + + for a in r.atoms: + if a.HasProp("lddt"): + score = round(a.GetFloatProp("lddt"), 3) + aa_local_lddt[cname][rnum][a.GetName()] = score + else: + # not covered by trg or skipped in chain mapping + # procedure the latter happens if its part of a + # super short chain + aa_local_lddt[cname][rnum][a.GetName()] = None + else: lddt_chain_mapping = dict() @@ -1859,22 +1898,40 @@ class Scorer: residue_mapping = stereochecked_alns, check_resnames=False, local_lddt_prop="lddt", - add_mdl_contacts = self.lddt_add_mdl_contacts)[0] + add_mdl_contacts = self.lddt_add_mdl_contacts, + set_atom_props=True)[0] local_lddt = dict() + aa_local_lddt = dict() for r in self.model.residues: cname = r.GetChain().GetName() if cname not in local_lddt: 
local_lddt[cname] = dict() + aa_local_lddt[cname] = dict() + rnum = r.GetNumber() + if rnum not in aa_local_lddt[cname]: + aa_local_lddt[cname][rnum] = dict() + if r.HasProp("lddt"): score = round(r.GetFloatProp("lddt"), 3) - local_lddt[cname][r.GetNumber()] = score + local_lddt[cname][rnum] = score + + for a in r.atoms: + if a.HasProp("lddt"): + score = round(a.GetFloatProp("lddt"), 3) + aa_local_lddt[cname][rnum][a.GetName()] = score + else: + # must have been removed by stereochecks + aa_local_lddt[cname][rnum][a.GetName()] = 0.0 + else: # rsc => residue stereo checked... - mdl_res = self.stereochecked_model.FindResidue(cname, r.GetNumber()) + mdl_res = self.stereochecked_model.FindResidue(cname, rnum) if mdl_res.IsValid(): # not covered by trg or skipped in chain mapping procedure # the latter happens if its part of a super short chain - local_lddt[cname][r.GetNumber()] = None + local_lddt[cname][rnum] = None + for a in r.atoms: + aa_local_lddt[cname][rnum][a.GetName()] = None else: # opt 1: removed by stereochecks => assign 0.0 # opt 2: removed by stereochecks AND not covered by ref @@ -1888,12 +1945,18 @@ class Scorer: trg_r = col.GetResidue(0) break if trg_r is None: - local_lddt[cname][r.GetNumber()] = None + local_lddt[cname][rnum] = None + for a in r.atoms: + aa_local_lddt[cname][rnum][a.GetName()] = None + else: - local_lddt[cname][r.GetNumber()] = 0.0 + local_lddt[cname][rnum] = 0.0 + for a in r.atoms: + aa_local_lddt[cname][rnum][a.GetName()] = 0.0 self._lddt = lddt_score self._local_lddt = local_lddt + self._aa_local_lddt = aa_local_lddt def _compute_bb_lddt(self): LogScript("Computing backbone lDDT") -- GitLab