import os
from ost import io
from ost import seq
from promod3 import loop

# Load the sequences of all chains that are part of the loop modelling
# benchmark. Each non-empty line of fread_data.txt is read as whitespace
# separated fields, the first being the PDB id and the second the chain id;
# the sequence is then loaded from sequences/<pdb_id>_<chain_id>.fasta.
target_sequences = list()
with open('fread_data.txt', 'r') as fh:
    fread_data = fh.readlines()
for item in fread_data:
    fields = item.split()
    if not fields:
        # tolerate blank lines (e.g. trailing empty lines) instead of
        # crashing with an IndexError
        continue
    pdb_id, chain_id = fields[0], fields[1]
    seq_path = os.path.join('sequences', pdb_id + '_' + chain_id + '.fasta')
    target_sequences.append(io.LoadSequence(seq_path))

# Iterate over all entries of the default StructureDB and keep only those
# that are not redundant with any of the benchmark target sequences.

# An entry counts as redundant with a target if the sequence identity of
# their alignment exceeds SEQ_ID_THRESHOLD (presumably a percentage, matching
# the scale of seq.alg.SequenceIdentity) AND more than MIN_ALIGNED_FRACTION of
# the shorter sequence's residues are aligned to residues of the other one.
SEQ_ID_THRESHOLD = 90.0
MIN_ALIGNED_FRACTION = 0.8  # arbitrary; most residues are expected to align


def _aligned_fraction(aln):
    """Return the fraction of residues of the shorter sequence in *aln* that
    are aligned to a residue (not a gap '-') of the other sequence.

    Returns 0.0 if either sequence contains no residues at all, avoiding a
    division by zero.
    """
    s0 = str(aln.GetSequence(0))
    s1 = str(aln.GetSequence(1))
    n_aligned = sum(1 for a, b in zip(s0, s1) if a != '-' and b != '-')
    shorter = min(len(s0) - s0.count('-'), len(s1) - s1.count('-'))
    if shorter == 0:
        return 0.0
    return float(n_aligned) / shorter


def _is_redundant(db_seq, targets):
    """Return True if *db_seq* is redundant with any sequence in *targets*.

    Each target is semi-globally aligned to *db_seq* using BLOSUM62; the
    search stops at the first target that passes both thresholds.
    """
    for ts in targets:
        aln = seq.alg.SemiGlobalAlign(ts, db_seq, seq.alg.BLOSUM62)[0]
        if seq.alg.SequenceIdentity(aln) <= SEQ_ID_THRESHOLD:
            continue
        # A high sequence identity can result from a poor alignment that
        # consists mostly of gaps, so additionally require that most
        # residues of the shorter sequence are actually aligned.
        if _aligned_fraction(aln) > MIN_ALIGNED_FRACTION:
            return True
    return False


structure_db = loop.LoadStructureDB()
indices_to_keep = list()
for idx in range(structure_db.GetNumCoords()):
    print('processing ',idx)
    coord_info = structure_db.GetCoordInfo(idx)
    # sequence of the full entry: fragment starting at 0 spanning all residues
    frag_info = loop.FragmentInfo(idx, 0, coord_info.size)
    s = seq.CreateSequence('db_seq', structure_db.GetSequence(frag_info))
    if _is_redundant(s, target_sequences):
        print('skip entry with pdb id', coord_info.id)
    else:
        indices_to_keep.append(idx)

# Extract the nonredundant entries into their own StructureDB, print its
# statistics and store it in portable format. The corresponding FragDB is
# built from this reduced StructureDB.
non_redundant_db = structure_db.GetSubDB(indices_to_keep)
non_redundant_db.PrintStatistics()
non_redundant_db.SavePortable('nonredundant_structure_db_portable.dat')

# FragDB parameters: bin sizes passed to the FragDB constructor and the RMSD
# cutoff passed to AddFragments (presumably used to suppress near-duplicate
# fragments - confirm against the ProMod3 docs).
max_pairwise_rmsd = 1.0
dist_bin_size = 1.0
angle_bin_size = 20
fragdb = loop.FragDB(dist_bin_size, angle_bin_size)

# Add fragments of every length from min_frag_length up to and including
# max_frag_length (i.e. lengths 3 to 14).
min_frag_length = 3
max_frag_length = 14
for frag_length in range(min_frag_length, max_frag_length + 1):
    print('start to add fragments of length ', frag_length)
    fragdb.AddFragments(frag_length, max_pairwise_rmsd, non_redundant_db)

fragdb.PrintStatistics()
fragdb.SavePortable('nonredundant_frag_db_portable.dat')