Commit d62287c2 authored by Studer Gabriel's avatar Studer Gabriel
Browse files

sidechain modelling performance benchmark

parent 0a4cc0cf
,schdaude,schdaudoputer,29.06.2020 09:45,file:///home/schdaude/.config/libreoffice/4;
\ No newline at end of file
Scripts and data to reproduce the sidechain modelling accuracy benchmark.
The whole benchmark is a three step process:
1. Fetch Benchmark Set
2. Modelling with ProMod3 and SCWRL4
3. Evaluation
Fetch Benchmark Set
-------------------
Structures for benchmarking are fetched from the PDB using the remote loading
capabilities of OpenStructure:
`ost get_scwrl_testset_structures.py`
loads, cleans and saves structures to the scwrl_testset directory.
Modelling with ProMod3 and SCWRL4
---------------------------------
`pm reconstruct_sidechains_promod.py` reconstructs sidechains using
different settings:
- Flexible Rotamer Model (FRM) with subrotamer optimization with default
(backbone dependent) rotamer library
- FRM without subrotamer optimization with default rotamer library
- Rigid Rotamer Model (RRM) with default rotamer library
- FRM with subrotamer optimization and backbone independent rotamer library
- RRM with backbone independent rotamer library
The generated models appear in the subdirectories of models/promod
`python reconstruct_sidechains_scwrl.py` reconstructs sidechains with
the flexible rotamer model as well as the rigid rotamer model. The script
requires the Scwrl4 executable in your path and the generated models appear
in the subdirectories of models/scwrl
Evaluation
----------
`pm do_plot.py` recreates Fig 2 from ProMod3 manuscript.
`pm do_supplemental_tables.py` creates detailed csv tables with performances
of the alternative modelling runs from the last step.
import os
import numpy as np
import traceback
from promod3 import sidechain
from ost import io, seq, geom
def _GetAmbigRMSD(r_one, r_two, fix_atoms, ambig_atom_pairs):
    """Return the RMSD of two residues that contain exchangeable atom names.

    Two sums of squared distances are accumulated. The first one matches
    every atom with the equally named atom of the other residue. The second
    one matches the atoms in *fix_atoms* by name but swaps the two names of
    every entry in *ambig_atom_pairs* (all entries are swapped together).
    The smaller of the two resulting RMSD values is returned.

    :param r_one: first residue, queried with ``FindAtom(name)``
    :param r_two: second residue, queried with ``FindAtom(name)``
    :param fix_atoms: atom names that are always matched by name
    :param ambig_atom_pairs: sequences holding two exchangeable atom names
    :returns: the smaller RMSD, or NaN if any requested atom is invalid in
              one of the residues
    """

    def _squared_deviation(name_one, name_two):
        # None signals that at least one of the two atoms is not available
        atom_one = r_one.FindAtom(name_one)
        atom_two = r_two.FindAtom(name_two)
        if not (atom_one.IsValid() and atom_two.IsValid()):
            return None
        d = geom.Distance(atom_one.GetPos(), atom_two.GetPos())
        return d * d

    sum_direct = 0.0   # names matched as they are
    sum_flipped = 0.0  # ambiguous names exchanged in r_two

    # unambiguous atoms contribute equally to both variants
    for name in fix_atoms:
        sq = _squared_deviation(name, name)
        if sq is None:
            return float('NaN')
        sum_direct += sq
        sum_flipped += sq

    for pair in ambig_atom_pairs:
        first, second = pair[0], pair[1]
        contributions = (_squared_deviation(first, first),
                         _squared_deviation(second, second),
                         _squared_deviation(first, second),
                         _squared_deviation(second, first))
        if any(c is None for c in contributions):
            return float('NaN')
        sum_direct += contributions[0]
        sum_direct += contributions[1]
        sum_flipped += contributions[2]
        sum_flipped += contributions[3]

    n_atoms = len(fix_atoms) + 2 * len(ambig_atom_pairs)
    return min(np.sqrt(sum_direct / n_atoms), np.sqrt(sum_flipped / n_atoms))
def GetRMSD(r_one, r_two):
    """Return the RMSD between two residues of the same type.

    For TYR, VAL, PHE, ASP and GLU only the sidechain atoms listed below are
    compared and the exchangeable atom names are resolved by _GetAmbigRMSD.
    For all other residue types every atom of *r_one* is compared with the
    equally named atom of *r_two*.

    :param r_one: residue whose atoms define what is compared
    :param r_two: residue in which the equally named atoms are looked up
    :returns: the RMSD, or NaN if the residue names differ, a required atom
              is invalid, or *r_one* has no atoms at all
    """
    rname = r_one.GetName()
    if rname != r_two.GetName():
        return float('NaN')

    # residue name -> (atoms matched by name, pairs of exchangeable atoms)
    ambiguous = {
        'TYR': (['CG', 'CZ', 'OH'], [('CD1', 'CD2'), ('CE1', 'CE2')]),
        'VAL': ([], [('CG1', 'CG2')]),
        'PHE': (['CG', 'CZ'], [('CD1', 'CD2'), ('CE1', 'CE2')]),
        'ASP': (['CG'], [('OD1', 'OD2')]),
        'GLU': (['CG', 'CD'], [('OE1', 'OE2')]),
    }
    if rname in ambiguous:
        fix_atoms, ambig_atom_pairs = ambiguous[rname]
        return _GetAmbigRMSD(r_one, r_two, fix_atoms, ambig_atom_pairs)

    # the general case
    rm = 0.0
    counter = 0
    for a in r_one.atoms:
        b = r_two.FindAtom(a.GetName())
        if not (a.IsValid() and b.IsValid()):
            return float('NaN')
        dist = geom.Distance(a.GetPos(), b.GetPos())
        rm += dist * dist
        counter += 1

    # a residue without atoms has no defined RMSD; report it like every
    # other failure instead of dividing by zero
    if counter == 0:
        return float('NaN')
    return np.sqrt(rm / counter)
class AAEval:
    """Container for the evaluation counters of one amino acid type.

    All counters start at zero (stored as floats) and the list of RMSD
    values starts empty; they are filled by GetAAEvaluations.
    """

    def __init__(self):
        # chi1: number of comparable angles / number of matching angles
        self.num_valid_chi1 = self.num_correct_chi1 = 0.0
        # chi2: number of comparable angles / number of matching angles
        self.num_valid_chi2 = self.num_correct_chi2 = 0.0
        # same as chi2, restricted to residues whose chi1 matches
        self.num_valid_chi2_given_chi1 = 0.0
        self.num_correct_chi2_given_chi1 = 0.0
        # one RMSD per evaluated residue
        self.rmsd_values = []
        # number of residues that entered the evaluation
        self.num_residues = 0.0
def GetAAEvaluations(model_dir, target_dir = 'scwrl_testset',
                     angle_thresh = 20.0/180.0*np.pi):
    """Compare modelled sidechains with reference structures per amino acid.

    The first four characters of every file name in *target_dir* are used
    as target identifiers. For each identifier, ``<id>-molcked.pdb`` is
    loaded from *target_dir* and ``<id>.pdb`` from *model_dir*. Residues of
    the reference (restricted to one chain per distinct sequence) are
    matched to the model by chain name and residue number and compared by
    chi1/chi2 similarity and by RMSD.

    A target that cannot be evaluated is reported on stdout together with a
    traceback and the function continues with the next target; counts that
    were accumulated for that target before the failure are kept.

    :param model_dir: directory containing the models
    :param target_dir: directory containing the reference structures
    :param angle_thresh: tolerance passed to ``SimilarDihedral``; the default
                         is 20 degrees expressed in radians
    :returns: dict mapping three letter codes to :class:`AAEval` objects
    """
    AA = ['ARG','ASN','ASP','CYS','GLN','GLU','HIS','ILE','LEU','LYS','MET','PHE',
          'PRO','SER','THR','TRP','TYR','VAL']
    eval_data = {aa : AAEval() for aa in AA}

    # sorted, so that targets are processed in the same order in every run
    testset_ids = sorted(set(f[:4] for f in os.listdir(target_dir)))

    for t in testset_ids:
        try:
            target_path = os.path.join(target_dir, t + '-molcked.pdb')
            model_path = os.path.join(model_dir, t + '.pdb')
            target_structure = io.LoadPDB(target_path).Select('peptide=true')
            model_structure = io.LoadPDB(model_path).Select('peptide=true')

            # If several chains, we only consider unique sequences of chains
            # => first occurrence matters (e.g. in case of homo oligomers)
            present_chains = list()
            present_sequences = list()
            for ch in target_structure.chains:
                ch_s = seq.CreateSequence('A', ''.join([r.one_letter_code for r in ch.residues]))
                already_there = False
                for s in present_sequences:
                    aln = seq.alg.GlobalAlign(ch_s, s, seq.alg.BLOSUM62)[0]
                    s_id = seq.alg.SequenceIdentity(aln)
                    # NOTE(review): OpenStructure appears to document
                    # SequenceIdentity as returning a value in the range
                    # 0-100. If so, this cutoff means 0.9 % rather than 90 %.
                    # Left untouched because it affects which chains are
                    # evaluated - confirm before changing.
                    if s_id > 0.9:
                        # High sequence identity can also result from a very
                        # poor alignment that mostly consists of gaps. Let's
                        # estimate the fraction of aligned residues for the
                        # shorter sequence
                        s0 = str(aln.GetSequence(0))
                        s1 = str(aln.GetSequence(1))
                        n_aligned = sum([int(a!='-' and b!='-') for a,b in zip(s0, s1)])
                        frac = float(n_aligned) / min(sum([int(c!='-') for c in s0]),
                                                      sum([int(c!='-') for c in s1]))
                        # set arbitrary threshold, most residues are expected
                        # to be aligned
                        if frac > 0.8:
                            already_there = True
                            # one match is enough, no need to align against
                            # the remaining sequences
                            break
                if not already_there:
                    present_chains.append(ch.GetName())
                    present_sequences.append(ch_s)

            relevant_query = 'cname=' + ','.join(present_chains)
            relevant_target = target_structure.Select(relevant_query)

            for r_t in relevant_target.residues:
                rot_id = sidechain.TLCToRotID(r_t.GetName())
                if rot_id == sidechain.XXX or rot_id == sidechain.GLY or rot_id == sidechain.ALA:
                    continue

                r_m = model_structure.FindResidue(r_t.GetChain().GetName(), r_t.GetNumber())
                if not r_m.IsValid():
                    continue

                try:
                    rot_t = sidechain.RotamerLibEntry.FromResidue(r_t)
                    rot_m = sidechain.RotamerLibEntry.FromResidue(r_m)
                except Exception:
                    continue # one of the residues is incomplete

                aa_eval = eval_data[r_t.GetName()]
                aa_eval.num_residues += 1

                # an angle can only be compared if it is defined (not NaN)
                # in the reference as well as in the model
                chi1_valid = not (np.isnan(rot_t.chi1) or np.isnan(rot_m.chi1))
                chi2_valid = not (np.isnan(rot_t.chi2) or np.isnan(rot_m.chi2))

                chi1_ok = rot_t.SimilarDihedral(rot_m, 0, angle_thresh, rot_id)
                chi2_ok = rot_t.SimilarDihedral(rot_m, 1, angle_thresh, rot_id)

                if chi1_valid:
                    aa_eval.num_valid_chi1 += 1
                    if chi1_ok:
                        aa_eval.num_correct_chi1 += 1

                if chi2_valid:
                    aa_eval.num_valid_chi2 += 1
                    if chi2_ok:
                        aa_eval.num_correct_chi2 += 1

                # chi2 statistics restricted to residues with matching chi1
                if chi1_ok:
                    if chi2_valid:
                        aa_eval.num_valid_chi2_given_chi1 += 1
                        if chi2_ok:
                            aa_eval.num_correct_chi2_given_chi1 += 1

                rm = GetRMSD(r_t, r_m)
                if not np.isnan(rm):
                    aa_eval.rmsd_values.append(rm)

        except Exception:
            # Exception instead of a bare except: a failing target must not
            # stop the benchmark, but Ctrl-C / SystemExit must still work
            print('failed to evaluate target', t)
            traceback.print_exc()

    return eval_data
def PrintPerformance(model_dir, only_table=True, csv_path=None):
    """Evaluate the models in *model_dir* and report the performance.

    The evaluation itself is done by GetAAEvaluations with its default
    reference directory and angle threshold.

    :param model_dir: directory containing the models to evaluate
    :param only_table: if False, overall and per amino acid statistics are
                       additionally printed to stdout
    :param csv_path: if given, a csv table with one row per amino acid is
                     written to this path (without trailing newline)
    :returns: None
    """

    def _fraction(num_correct, num_valid):
        # NaN instead of a ZeroDivisionError if nothing could be evaluated
        if num_valid > 0:
            return float(num_correct) / num_valid
        return float('NaN')

    eval_data = GetAAEvaluations(model_dir)

    # accumulate the counts over all amino acids
    total_num_chi1 = 0
    total_num_chi2 = 0
    total_num_residues = 0
    total_correct_num_chi1 = 0
    total_correct_num_chi2 = 0
    for data in eval_data.values():
        total_num_chi1 += data.num_valid_chi1
        total_num_chi2 += data.num_valid_chi2
        total_num_residues += data.num_residues
        total_correct_num_chi1 += data.num_correct_chi1
        total_correct_num_chi2 += data.num_correct_chi2

    if not only_table:
        print('total num valid chi1 angles: ', total_num_chi1)
        print('fraction correct:', _fraction(total_correct_num_chi1, total_num_chi1))
        print('total num valid chi2 angles: ', total_num_chi2)
        print('fraction correct:', _fraction(total_correct_num_chi2, total_num_chi2))
        print('total num residues: ', total_num_residues)
        for key, data in eval_data.items():
            print(key)
            if data.num_valid_chi1 > 0:
                print('chi1 accuracy: ',
                      data.num_correct_chi1/data.num_valid_chi1,
                      data.num_valid_chi1, data.num_residues)
            if data.num_valid_chi2 > 0:
                print('chi2 accuracy: ',
                      data.num_correct_chi2/data.num_valid_chi2,
                      data.num_valid_chi2, data.num_residues)
            if data.num_valid_chi2_given_chi1 > 0:
                print('chi2 accuracy given chi1: ',
                      data.num_correct_chi2_given_chi1/data.num_valid_chi2_given_chi1,
                      data.num_valid_chi2_given_chi1)

    # output for csv table
    csv_out = list()
    csv_out.append('AA,num,X1 correct (%),X2 correct (%),X2 correct given X1 (%),avg RMSD')
    AA = ['ARG','ASN','ASP','CYS','GLN','GLU','HIS','ILE','LEU','LYS','MET','PHE',
          'PRO','SER','THR','TRP','TYR','VAL']
    for aa in AA:
        data = eval_data[aa]
        num = '%i'%(int(data.num_valid_chi1))
        # cells without data are filled with a single blank
        chi1_fraction = ' '
        chi2_fraction = ' '
        chi2_fraction_given_chi1 = ' '
        rmsd = ' '
        if data.num_valid_chi1 > 0:
            chi1_fraction = '%.2f'%(100*float(data.num_correct_chi1)/data.num_valid_chi1)
        if data.num_valid_chi2 > 0:
            chi2_fraction = '%.2f'%(100*float(data.num_correct_chi2)/data.num_valid_chi2)
        if data.num_valid_chi2_given_chi1 > 0:
            chi2_fraction_given_chi1 = '%.2f'%(100*data.num_correct_chi2_given_chi1/data.num_valid_chi2_given_chi1)
        if len(data.rmsd_values) > 0:
            rmsd = '%.2f'%(np.mean(data.rmsd_values))
        csv_out.append(','.join([aa, num, chi1_fraction, chi2_fraction,
                                 chi2_fraction_given_chi1, rmsd]))

    if csv_path:
        with open(csv_path, 'w') as fh:
            fh.write('\n'.join(csv_out))
import matplotlib.pyplot as plt
import numpy as np
import analyze
# Bar plot comparing the chi1 accuracy per amino acid of two sets of models.

model_dir_one = 'models/promod/frm'
model_dir_two = 'models/scwrl/frm'
label_one = 'ProMod3'
label_two = 'SCWRL4'
fig_out_path = 'chi1_accuracy_frm.png'

one = analyze.GetAAEvaluations(model_dir_one)
two = analyze.GetAAEvaluations(model_dir_two)

AA = ['ARG','ASN','ASP','CYS','GLN','GLU','HIS','ILE','LEU','LYS','MET','PHE',
      'PRO','SER','THR','TRP','TYR','VAL']

# fraction of correct chi1 angles for every amino acid
one_chi1 = [one[aa].num_correct_chi1/one[aa].num_valid_chi1 for aa in AA]
two_chi1 = [two[aa].num_correct_chi1/two[aa].num_valid_chi1 for aa in AA]

# left edges of the bars: two bars of width 0.4 next to each other per
# amino acid
x_promod = [i+0.1 for i in range(len(AA))]
x_scwrl = [i+0.5 for i in range(len(AA))]

cred = (128.0/255,0.0,0.0)
cblue = (102.0/255,153.0/255,204.0/255)

# appearance shared by both bar series
bar_style = dict(width = 0.4, linewidth=2.0, alpha=0.75, align='edge',
                 edgecolor='k')
plt.bar(x_promod, one_chi1, color=cred, label=label_one, **bar_style)
plt.bar(x_scwrl, two_chi1, color=cblue, label=label_two, **bar_style)

plt.ylim((0.5, 1.0))
plt.legend(frameon=False, loc='upper left')

# ticks are placed where the two bars of an amino acid touch
x_tick_positions = [x + 0.4 for x in x_promod]
plt.xticks(x_tick_positions, AA, rotation=30)

plt.savefig(fig_out_path)
plt.clf()
import analyze
# Writes one csv table per modelling run and prints the detailed statistics.

model_directories = ['models/promod/frm',
                     'models/promod/frm_no_subrotamer_optimization',
                     'models/promod/rrm',
                     'models/promod/bb_indep_frm',
                     'models/promod/bb_indep_rrm',
                     'models/scwrl/frm',
                     'models/scwrl/rrm']

# The number of evaluated aspartic acids in this testset differs between
# SCWRL4 and ProMod3.
# The residue responsible is: 2pst, X.ASP9
# The problem is: a backbone oxygen is missing.
# SCWRL4 completely deletes the residue.
# ProMod3 doesn't touch it and the correct sidechain remains in place.
# This is unfair in favour of ProMod3, as the correct sidechain is evaluated.
# However, the effect on overall performance can be considered negligible.

for md in model_directories:
    print('processing:', md)
    # e.g. 'models/promod/frm' -> 'eval_promod_frm.csv'
    tool_name, run_name = md.split('/')[-2:]
    csv_path = 'eval_' + tool_name + '_' + run_name + '.csv'
    analyze.PrintPerformance(md, False, csv_path=csv_path)
AA,num,X1 correct (%),X2 correct (%),X2 correct given X1 (%),avg RMSD
ARG,3601,68.09,68.20,74.18,2.02
ASN,2875,76.17,39.76,47.49,0.75
ASP,4013,71.72,53.35,62.09,1.07
CYS,999,84.68, , ,0.22
GLN,2501,72.57,62.69,72.62,1.16
GLU,4611,65.80,62.18,69.28,1.59
HIS,1542,79.64,45.72,49.59,1.03
ILE,3964,91.60,84.06,86.97,0.34
LEU,6554,85.66,83.25,94.14,0.38
LYS,3819,68.29,73.21,76.88,1.23
MET,1406,78.73,67.07,76.60,0.81
PHE,2715,90.39,85.27,88.96,0.77
PRO,3230,79.01,78.20,98.79,0.17
SER,4101,59.33, , ,0.41
THR,3784,75.29, , ,0.38
TRP,979,83.55,71.50,81.05,1.26
TYR,2346,89.00,81.50,86.11,0.90
VAL,5018,87.23, , ,0.35
\ No newline at end of file
AA,num,X1 correct (%),X2 correct (%),X2 correct given X1 (%),avg RMSD
ARG,3601,66.98,69.98,74.71,2.07
ASN,2875,74.57,37.18,44.82,0.78
ASP,4013,70.15,50.83,59.68,1.13
CYS,999,84.38, , ,0.22
GLN,2501,71.57,61.22,70.06,1.19
GLU,4611,63.96,62.61,68.90,1.67
HIS,1542,78.08,40.53,44.44,1.09
ILE,3964,90.94,83.25,86.10,0.35
LEU,6554,84.82,83.06,94.39,0.40
LYS,3819,67.32,73.24,77.17,1.26
MET,1406,77.52,67.00,77.52,0.84
PHE,2715,86.37,79.52,85.12,0.91
PRO,3230,78.30,77.43,98.73,0.17
SER,4101,58.55, , ,0.41
THR,3784,74.71, , ,0.38
TRP,979,80.59,66.09,75.03,1.45
TYR,2346,85.51,79.28,85.14,1.07
VAL,5018,87.35, , ,0.35
\ No newline at end of file
AA,num,X1 correct (%),X2 correct (%),X2 correct given X1 (%),avg RMSD
ARG,3601,73.67,68.65,73.80,1.96
ASN,2875,83.34,49.04,56.43,0.63
ASP,4013,82.01,61.50,70.86,0.77
CYS,999,87.89, , ,0.18
GLN,2501,75.77,63.97,72.03,1.11
GLU,4611,70.11,62.91,70.40,1.49
HIS,1542,84.95,48.57,52.21,0.89
ILE,3964,95.86,86.28,88.05,0.26
LEU,6554,88.74,85.96,94.77,0.33
LYS,3819,74.94,74.00,77.88,1.12
MET,1406,81.58,72.05,80.38,0.74
PHE,2715,94.00,87.62,90.56,0.61
PRO,3230,80.93,80.00,98.81,0.14
SER,4101,69.06, , ,0.32
THR,3784,89.27, , ,0.21
TRP,979,89.68,76.51,83.37,0.99
TYR,2346,92.16,84.83,88.39,0.74
VAL,5018,93.08, , ,0.25
\ No newline at end of file
AA,num,X1 correct (%),X2 correct (%),X2 correct given X1 (%),avg RMSD
ARG,3601,73.76,71.26,76.05,1.97
ASN,2875,83.17,49.15,56.71,0.63
ASP,4013,82.28,61.48,70.32,0.77
CYS,999,87.69, , ,0.18
GLN,2501,76.33,63.89,71.50,1.11
GLU,4611,70.07,65.26,72.79,1.48
HIS,1542,85.02,48.44,51.79,0.89
ILE,3964,95.71,86.15,87.98,0.26
LEU,6554,88.59,85.92,94.64,0.33
LYS,3819,75.18,74.18,77.88,1.13
MET,1406,81.93,72.12,80.03,0.75
PHE,2715,92.04,87.11,90.16,0.67
PRO,3230,80.93,80.00,98.81,0.14
SER,4101,69.03, , ,0.32
THR,3784,89.19, , ,0.21
TRP,979,87.54,75.69,82.03,1.07
TYR,2346,90.58,84.78,88.33,0.80
VAL,5018,93.20, , ,0.24
\ No newline at end of file
AA,num,X1 correct (%),X2 correct (%),X2 correct given X1 (%),avg RMSD
ARG,3601,71.79,68.93,73.93,2.03
ASN,2875,81.81,46.09,53.66,0.66
ASP,4013,80.36,58.58,68.74,0.83
CYS,999,87.99, , ,0.18
GLN,2501,75.53,63.13,70.62,1.14
GLU,4611,68.08,63.70,71.39,1.56
HIS,1542,84.18,44.36,48.07,0.93
ILE,3964,95.23,85.19,87.05,0.28
LEU,6554,87.46,84.67,94.49,0.35
LYS,3819,73.61,73.42,77.20,1.16
MET,1406,81.01,70.70,79.10,0.78
PHE,2715,91.42,82.80,86.62,0.70
PRO,3230,80.25,79.29,98.77,0.14
SER,4101,69.06, , ,0.32
THR,3784,89.16, , ,0.21
TRP,979,87.23,71.60,78.81,1.17
TYR,2346,89.09,81.80,87.03,0.89
VAL,5018,93.12, , ,0.25
\ No newline at end of file
AA,num,X1 correct (%),X2 correct (%),X2 correct given X1 (%),avg RMSD
ARG,3601,74.23,71.20,76.81,1.93
ASN,2875,81.60,47.62,55.24,0.67
ASP,4012,81.56,60.82,70.17,0.82
CYS,999,87.99, , ,0.20
GLN,2501,75.37,63.97,72.41,1.13
GLU,4611,69.29,65.73,73.49,1.49
HIS,1542,83.72,45.59,49.26,0.95
ILE,3964,95.31,84.76,86.69,0.27
LEU,6554,87.60,85.26,94.79,0.36
LYS,3819,74.31,74.10,77.80,1.14
MET,1406,80.94,72.33,79.88,0.78
PHE,2715,92.08,85.56,87.96,0.72
PRO,3230,80.80,79.41,98.20,0.14
SER,4101,69.11, , ,0.33
THR,3784,89.48, , ,0.21
TRP,979,87.44,76.00,82.48,1.10
TYR,2346,90.49,83.08,86.62,0.86
VAL,5018,92.63, , ,0.26
\ No newline at end of file
AA,num,X1 correct (%),X2 correct (%),X2 correct given X1 (%),avg RMSD
ARG,3601,72.06,68.29,73.87,2.02
ASN,2875,80.87,46.75,54.75,0.68
ASP,4012,79.49,57.63,67.89,0.88
CYS,999,87.59, , ,0.20
GLN,2501,74.09,63.17,71.40,1.16
GLU,4611,67.53,64.32,72.03,1.55
HIS,1542,83.46,42.48,46.15,0.97
ILE,3964,94.75,83.70,85.92,0.28
LEU,6554,86.41,83.98,94.40,0.38
LYS,3819,72.72,73.13,76.99,1.18
MET,1406,79.23,70.77,78.82,0.83
PHE,2715,90.87,82.10,85.37,0.77
PRO,3230,79.72,78.27,98.10,0.15
SER,4101,67.93, , ,0.34
THR,3784,88.82, , ,0.22
TRP,979,85.90,68.74,76.10,1.23
TYR,2346,88.49,80.18,85.02,0.97
VAL,5018,92.07, , ,0.27
\ No newline at end of file
import os
from ost.mol.alg import MolckSettings, Molck
from ost import conop, io
import traceback
# Fetches every structure listed in scwrl_testset_ids.txt and stores two
# files per entry in the output directory: the peptide selection as loaded
# and the peptide selection after processing with Molck.

# the identifiers are separated by arbitrary whitespace
with open('scwrl_testset_ids.txt') as fh:
    testset_ids = fh.read().split()

outdir = 'scwrl_testset'
os.makedirs(outdir, exist_ok=True)

compound_lib = conop.GetDefaultLib()
molck_settings = MolckSettings(rm_unk_atoms=True,
                               rm_non_std=True,
                               rm_hyd_atoms=True,
                               rm_oxt_atoms=True,
                               rm_zero_occ_atoms=True,
                               colored=False,
                               map_nonstd_res=True,
                               assign_elem=True)

for item in testset_ids:
    try:
        print('processing ', item)
        out_path = os.path.join(outdir, item + '.pdb')
        prot = io.LoadPDB(item, remote = True)
        io.SavePDB(prot.Select('peptide=true'), out_path)
        molcked_out_path = os.path.join(outdir, item + '-molcked.pdb')
        # the structure is saved a second time after Molck has processed it
        Molck(prot, compound_lib, molck_settings)
        io.SavePDB(prot.Select('peptide=true'), molcked_out_path)
    except Exception:
        # Exception instead of a bare except: a failing entry is skipped,
        # but Ctrl-C / SystemExit still terminate the download loop
        print('failed for', item)
        traceback.print_exc()
This diff is collapsed.
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment