import unittest, os, sys import ost from ost import io, mol, settings, conop, seq # check if we can import: fails if numpy or scipy not available try: from ost.mol.alg.lddt import * from ost.mol.alg.scoring import * except ImportError: print("Failed to import qsscoring. Happens when numpy or scipy missing. " \ "Ignoring test_lddt.py tests.") sys.exit(0) def _LoadFile(file_name): """Helper to avoid repeating input path over and over.""" return io.LoadPDB(os.path.join('testfiles', file_name)) class TestlDDT(unittest.TestCase): # compare monomers to lDDT C++ reference implementation def test_lDDT_monomer(self): # do 7SGN model = _LoadFile("7SGN_C_model.pdb") target = _LoadFile("7SGN_C_target.pdb") # do awesome implementation scorer = lDDTScorer(target) aws_score, aws_per_res_scores = scorer.lDDT(model) # do reference implementation dl = mol.alg.CreateDistanceList(target.CreateFullView(), 15.0) classic_score = mol.alg.LDDTHA(model.CreateFullView(), dl) classic_per_res_scores = list() for r in model.residues: if r.HasProp("locallddt"): classic_per_res_scores.append(r.GetFloatProp("locallddt")) else: classic_per_res_scores.append(None) self.assertAlmostEqual(aws_score, classic_score, places=5) self.assertEqual(len(aws_per_res_scores), len(classic_per_res_scores)) for a,b in zip(aws_per_res_scores, classic_per_res_scores): if a is None and b is None: continue # only check for 3 places. Reason for that is that the distance # difference between GLN30.CB and TYR35.O is within floating point # accuracy of the 0.5A threshold. So the two involved residues may # have a difference of 1 with respect to conserved distances. self.assertAlmostEqual(a, b, places = 3) # do 7W1F_B model = _LoadFile("7W1F_B_model.pdb") target = _LoadFile("7W1F_B_target.pdb") # do awesome implementation scorer = lDDTScorer(target) aws_score, aws_per_res_scores = scorer.lDDT(model) # do reference implementation dl = mol.alg.CreateDistanceList(target.CreateFullView(), 15.0) classic_score = mol.alg.LDDTHA(model.CreateFullView(), dl) classic_per_res_scores = list() for r in model.residues: if r.HasProp("locallddt"): classic_per_res_scores.append(r.GetFloatProp("locallddt")) else: classic_per_res_scores.append(None) self.assertAlmostEqual(aws_score, classic_score, places=5) self.assertEqual(len(aws_per_res_scores), len(classic_per_res_scores)) for a,b in zip(aws_per_res_scores, classic_per_res_scores): if a is None and b is None: continue self.assertAlmostEqual(a, b, places = 5) # check oligo functionality def test_lDDT_oligo(self): ent_full = _LoadFile("4br6.1.pdb") model = ent_full.Select('peptide=true') target = ent_full.Select('peptide=true and cname=A,B') # hardcoded chain mapping chain_mapping = {"A": "A", "B": "B"} lddt_scorer = lDDTScorer(target) score, per_res_scores = lddt_scorer.lDDT(model, chain_mapping=chain_mapping) self.assertAlmostEqual(score, 1.0, places=5) score, per_res_scores = lddt_scorer.lDDT(model, chain_mapping=chain_mapping, no_interchain=True) self.assertAlmostEqual(score, 1.0, places=5) score, per_res_scores = lddt_scorer.lDDT(model, chain_mapping=chain_mapping, no_interchain=False, penalize_extra_chains=True) self.assertAlmostEqual(score, 0.52084655, places=5) score, per_res_scores = lddt_scorer.lDDT(model, chain_mapping=chain_mapping, no_interchain=True, penalize_extra_chains=True) self.assertAlmostEqual(score, 0.499570048, places=5) def test_lDDT_custom_resmapping(self): ent_full = _LoadFile("4br6.1.pdb") model = ent_full.Copy().Select('peptide=true') target = ent_full.Select('peptide=true and cname=A,B') # shift residue numbers in model ed = model.handle.EditXCS() for ch in model.chains: ed.RenumberChain(ch.handle, 42, True) # hardcoded chain mapping chain_mapping = {"A": "A", "B": "B"} lddt_scorer = lDDTScorer(target) # naively running lDDT will fail, as residue-residue mapping happens # with resnums. Since we shifted that stuff above we'll get an error # complaining about residue name mismatch with self.assertRaises(RuntimeError): score, per_res_scores = lddt_scorer.lDDT(model, chain_mapping=chain_mapping, no_interchain=False, penalize_extra_chains=True) # we can rescue that with alignments res_map = dict() for mdl_ch_name, trg_ch_name in chain_mapping.items(): mdl_ch = model.FindChain(mdl_ch_name) trg_ch = target.FindChain(trg_ch_name) mdl_seq = ''.join([r.one_letter_code for r in mdl_ch.residues]) mdl_seq = seq.CreateSequence(mdl_ch_name, mdl_seq) trg_seq = ''.join([r.one_letter_code for r in trg_ch.residues]) trg_seq = seq.CreateSequence(trg_ch_name, trg_seq) aln = seq.alg.GlobalAlign(trg_seq, mdl_seq, seq.alg.BLOSUM62)[0] res_map[mdl_ch_name] = aln score, per_res_scores = lddt_scorer.lDDT(model, chain_mapping=chain_mapping, no_interchain=False, penalize_extra_chains=True, residue_mapping=res_map) self.assertAlmostEqual(score, 0.52084655, places=5) def test_lDDT_seqsep(self): target = _LoadFile("7SGN_C_target.pdb") with self.assertRaises(NotImplementedError): scorer = lDDTScorer(target, sequence_separation=42) scorer = lDDTScorer(target, sequence_separation=0) def test_bb_only(self): model = _LoadFile("7SGN_C_model.pdb") target = _LoadFile("7SGN_C_target.pdb") # do scoring and select aname=CA scorer = lDDTScorer(target.Select("aname=CA")) score_one, per_res_scores_one = scorer.lDDT(model) score_two, per_res_scores_two = scorer.lDDT(model.Select("aname=CA")) # no selection, just setting bb_only flag should give the same scorer = lDDTScorer(target, bb_only=True) score_three, per_res_scores_three = scorer.lDDT(model) # check self.assertAlmostEqual(score_one, score_two, places=5) self.assertAlmostEqual(score_one, score_three, places=5) for a,b in zip(per_res_scores_one, per_res_scores_two): self.assertAlmostEqual(a, b, places=5) for a,b in zip(per_res_scores_one, per_res_scores_three): self.assertAlmostEqual(a, b, places=5) def test_resname_match(self): model = _LoadFile("7SGN_C_model.pdb") target = _LoadFile("7SGN_C_target.pdb") # introduce name mismatch ed = model.handle.EditXCS() ed.RenameResidue(model.residues[42], "asdf") # do scoring and select aname=CA scorer = lDDTScorer(target.Select("aname=CA")) with self.assertRaises(RuntimeError): scorer.lDDT(model) scorer.lDDT(model, check_resnames=False) def test_intra_interchain(self): ent_full = _LoadFile("4br6.1.pdb") model = ent_full.Select('peptide=true and cname=A,B') target = ent_full.Select('peptide=true and cname=A,B') chain_mapping = {"A": "A", "B": "B"} lddt_scorer = lDDTScorer(target) # do lDDT only on interchain contacts (ic) lDDT_ic, per_res_lDDT_ic, lDDT_tot_ic, lDDT_cons_ic, \ res_indices_ic, per_res_exp_ic, per_res_conserved_ic =\ lddt_scorer.lDDT(model, no_intrachain=True, chain_mapping = chain_mapping, return_dist_test = True) # do lDDT only on intrachain contacts (sc for single chain) lDDT_sc, per_res_lDDT_sc, lDDT_tot_sc, lDDT_cons_sc, \ res_indices_sc, per_res_exp_sc, per_res_conserved_sc =\ lddt_scorer.lDDT(model, no_interchain=True, chain_mapping = chain_mapping, return_dist_test = True) # do lDDT on everything lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, per_res_exp, \ per_res_conserved = lddt_scorer.lDDT(model, chain_mapping = chain_mapping, return_dist_test = True) # sum of lDDT_tot_ic and lDDT_tot_sc should be equal to lDDT_tot self.assertEqual(lDDT_tot_ic + lDDT_tot_sc, lDDT_tot) # same for the conserved contacts self.assertEqual(lDDT_cons_ic + lDDT_cons_sc, lDDT_cons) def test_add_mdl_contacts(self): model = _LoadFile("7SGN_C_model.pdb") target = _LoadFile("7SGN_C_target.pdb") lddt_scorer = lDDTScorer(target) lDDT, per_res_lDDT, lDDT_tot, lDDT_cons, res_indices, per_res_exp, \ per_res_conserved = lddt_scorer.lDDT(model, return_dist_test = True, add_mdl_contacts=True) # this value is just blindly copied in without checking whether it makes # any sense... it's sole purpose is to trigger the respective flag # in lDDT computation self.assertEqual(lDDT, 0.6171511842396518) class TestlDDTBS(unittest.TestCase): def test_basic(self): mdl = _LoadFile("lddtbs_mdl.pdb") ref = _LoadFile("lddtbs_ref_1r8q.1.pdb") lddtbs_scorer = lDDTBSScorer(reference=ref, model=mdl) bs_repr = lddtbs_scorer.ScoreBS(ref.Select("rname=AFB"), radius = 5.0, lddt_radius = 12.0) # select residues manually from reference for at in ref.Select("rname=AFB").atoms: close_atoms = ref.FindWithin(at.GetPos(), 5.0) for close_at in close_atoms: close_at.GetResidue().SetIntProp("asdf", 1) ref_bs = ref.Select("grasdf:0=1") ref_bs = ref_bs.Select("peptide=true") ref_bs_names = [r.GetQualifiedName() for r in ref_bs.residues] repr_bs_names = [r.GetQualifiedName() for r in bs_repr.ref_residues] self.assertEqual(sorted(ref_bs_names), sorted(repr_bs_names)) # everything below basically computes lDDTBS manually and # compares with the result we got above # select residues manually from model fancy_mapping = {"A":"B", "B":"A"} # hardcoded chain mapping... mdl_bs = mdl.CreateEmptyView() for r in ref_bs.residues: mdl_res = mdl.FindResidue(fancy_mapping[r.GetChain().GetName()], r.GetNumber()) mdl_bs.AddResidue(mdl_res, mol.ViewAddFlag.INCLUDE_ALL) # put that stuff in single chain structures sc_ref_bs = mol.CreateEntity() ed = sc_ref_bs.EditXCS() ch = ed.InsertChain("A") for r in ref_bs.residues: added_r = ed.AppendResidue(ch, r.GetName()) for a in r.atoms: ed.InsertAtom(added_r, a.GetName(), a.GetPos()) sc_mdl_bs = mol.CreateEntity() ed = sc_mdl_bs.EditXCS() ch = ed.InsertChain("A") for r in mdl_bs.residues: added_r = ed.AppendResidue(ch, r.GetName()) for a in r.atoms: ed.InsertAtom(added_r, a.GetName(), a.GetPos()) # compute and compare lddt_scorer = lDDTScorer(sc_ref_bs, inclusion_radius=12.0) self.assertAlmostEqual(bs_repr.lDDT, lddt_scorer.lDDT(sc_mdl_bs)[0]) if __name__ == "__main__": from ost import testutils if testutils.DefaultCompoundLibIsSet(): testutils.RunTests() else: print('No compound library available. Ignoring test_lddt.py tests.')