# Copyright (c) 2013-2020, SIB - Swiss Institute of Bioinformatics and
#                          Biozentrum - University of Basel
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#   http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Unit tests for modelling.
"""
import unittest
from promod3 import modelling
from ost import conop, seq, io, mol, geom

class ModellingTests(unittest.TestCase):

    #######################################################################
    # HELPERs
    #######################################################################
    def checkModel(self, mhandle):
        '''
        Assert basic consistency of all residues in mhandle.model: each one
        must be flagged as peptide linking and as protein, residues with
        consecutive numbers must be in sequence, and phi/omega (resp. psi)
        torsions must be valid whenever the previous (resp. next) residue is
        in sequence.
        '''
        for res in mhandle.model.residues:
            # residue type flags
            self.assertTrue(res.peptide_linking)
            self.assertTrue(res.is_protein)
            # consecutive residue numbers imply a link to the next residue
            nxt = res.next
            if nxt.IsValid():
                if (nxt.number - res.number) == 1:
                    self.assertTrue(mol.InSequence(res, nxt))
            # torsions towards the previous residue
            prv = res.prev
            if prv.IsValid() and mol.InSequence(prv, res):
                self.assertTrue(res.phi_torsion.IsValid())
                self.assertTrue(res.omega_torsion.IsValid())
            # torsion towards the next residue
            if nxt.IsValid() and mol.InSequence(res, nxt):
                self.assertTrue(res.psi_torsion.IsValid())
    #######################################################################

    def testRaiseNoAttachedView(self):
        # BuildRawModel must throw if the alignment has no view attached
        trg_seq = seq.CreateSequence('A', 'acdef')
        tpl_seq = seq.CreateSequence('B', 'ac-ef')
        aln = seq.CreateAlignment(trg_seq, tpl_seq)
        with self.assertRaises(RuntimeError):
            modelling.BuildRawModel(aln)

    def testRaiseInvalidChainNames(self):
        # accepted and rejected ways of passing chain names to BuildRawModel
        template = io.LoadPDB('data/gly.pdb')
        aln = io.LoadAlignment('data/seq.fasta')
        aln.AttachView(1, template.CreateFullView())

        def expectChainNames(mhandle, names):
            # compare the leading chains of the model with the given names
            for idx, name in enumerate(names):
                self.assertEqual(mhandle.model.chains[idx].GetName(), name)

        # perform stuff that doesn't raise, i.e. do it correctly
        modelling.BuildRawModel(aln)
        expectChainNames(modelling.BuildRawModel(aln, chain_names='cheese'),
                         ['cheese'])
        aln_lst = seq.AlignmentList()
        for _ in range(2):
            aln_lst.append(aln)
        modelling.BuildRawModel(aln_lst)
        expectChainNames(modelling.BuildRawModel(aln_lst, ['cheese', 'steak']),
                         ['cheese', 'steak'])
        # a string given for an alignment list yields one character per chain
        for name_string in ('ch', 'cheese'):
            expectChainNames(modelling.BuildRawModel(aln_lst, name_string),
                             ['c', 'h'])

        failing_calls = (
            # we only accept a string as chain_names if aln is AlignmentHandle
            (aln, 1),
            (aln, ['A']),
            # we only accept a list or str as chain_names if aln is
            # AlignmentList
            (aln_lst, 1),
            # size also matters...
            (aln_lst, ['A']),
            (aln_lst, 'A'),
        )
        for bad_aln, bad_names in failing_calls:
            with self.assertRaises(RuntimeError):
                modelling.BuildRawModel(bad_aln, bad_names)

        # increase size of aln_list => at some point we should run out of
        # default chain_names
        aln_lst = [aln] * 100
        with self.assertRaises(RuntimeError):
            modelling.BuildRawModel(aln_lst)

    def testModeledSequence(self):
        # the raw model must carry exactly the target sequence, without gaps
        template = io.LoadPDB('data/gly.pdb')
        aln = io.LoadAlignment('data/seq.fasta')
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        model_chain = mhandle.model.chains[0]
        model_seq = seq.SequenceFromChain('MODEL', model_chain)
        self.assertEqual(len(mhandle.gaps), 0)
        target_seq = aln.sequences[0]
        self.assertEqual(model_seq.string, target_seq.string)

    def testDeletion(self):
        # a deletion in the alignment must show up as a single gap with an
        # empty sequence at the expected position
        template = io.LoadPDB('data/gly.pdb')
        aln = io.LoadAlignment('data/del.fasta')
        aln.AttachView(1, template.CreateFullView())
        # disable aln_preprocessing to keep the deletion in the example aln
        mhandle = modelling.BuildRawModel(aln, aln_preprocessing=False)
        model_residues = mhandle.model.residues
        self.assertEqual(len(mhandle.gaps), 1)
        gap = mhandle.gaps[0]
        self.assertEqual(gap.before, model_residues[2])
        self.assertEqual(gap.after, model_residues[3])
        self.assertEqual(gap.seq, '')

    def testInsertion(self):
        # an insertion in the alignment must show up as a single gap holding
        # the inserted residues at the expected position
        template = io.LoadPDB('data/gly.pdb')
        aln = io.LoadAlignment('data/ins.fasta')
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        model_residues = mhandle.model.residues
        self.assertEqual(len(mhandle.gaps), 1)
        gap = mhandle.gaps[0]
        self.assertEqual(gap.before, model_residues[1])
        self.assertEqual(gap.after, model_residues[2])
        self.assertEqual(gap.seq, 'AV')

    def testNonFeasiblePeptideBond(self):
        # two residues which cannot be connected with a feasible peptide
        # bond must be separated by a gap (with empty sequence)
        template = io.LoadPDB('data/gly_shifted_nter.pdb')
        aln = io.LoadAlignment('data/seq.fasta')
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        model_residues = mhandle.model.residues
        self.assertEqual(len(mhandle.gaps), 1)
        gap = mhandle.gaps[0]
        self.assertEqual(gap.before, model_residues[1])
        self.assertEqual(gap.after, model_residues[2])
        self.assertEqual(gap.seq, '')

    def testTer(self):
        # expect two terminal gaps: one in front of the first model residue
        # and one behind the last model residue
        template = io.LoadPDB('data/gly.pdb')
        aln = io.LoadAlignment('data/ter.fasta')
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        model_residues = mhandle.model.residues
        self.assertEqual(len(mhandle.gaps), 2)
        # gap at the beginning: default-constructed handle on the open side
        n_gap = mhandle.gaps[0]
        self.assertEqual(n_gap.before, mol.ResidueHandle())
        self.assertEqual(n_gap.after, model_residues[0])
        self.assertEqual(n_gap.seq, 'G')
        # gap at the end
        c_gap = mhandle.gaps[1]
        self.assertEqual(c_gap.before, model_residues[-1])
        self.assertEqual(c_gap.after, mol.ResidueHandle())
        self.assertEqual(c_gap.seq, 'G')

    def testTerTrg(self):
        # terminal gaps in the target sequence must not produce model gaps
        template = io.LoadPDB('data/gly.pdb')
        trg_seq = seq.CreateSequence('trg', '--GG--')
        tpl_seq = seq.CreateSequence('tpl', 'GGGGGG')
        aln = seq.CreateAlignment(trg_seq, tpl_seq)
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        self.assertEqual(len(mhandle.gaps), 0)

    def testModified(self):
        # test if we correctly strip off modifications: the single model
        # residue must consist of exactly the six atoms listed below
        tpl = io.LoadPDB('data/sep.pdb')
        aln = io.LoadAlignment('data/sep.fasta')
        aln.AttachView(1, tpl.CreateFullView())
        result = modelling.BuildRawModel(aln)
        residues = result.model.residues
        self.assertEqual(len(residues), 1)
        residue = residues[0]
        self.assertEqual(len(residue.atoms), 6)
        # FindAtom hands back a handle even for a missing atom (see
        # testInsertCBeta), so query IsValid() explicitly instead of relying
        # on the truth value of the handle.
        for atom_name in ("N", "CA", "C", "O", "CB", "OG"):
            self.assertTrue(residue.FindAtom(atom_name).IsValid(),
                            "atom %s missing in model residue" % atom_name)

    def testModifiedMismatch(self):
        # test if we allow OLC mismatch for modified AA ('X' in the template
        # sequence vs. 'S' in the target sequence)
        tpl = io.LoadPDB('data/sep.pdb')
        aln = seq.CreateAlignment(
            seq.CreateSequence('trg', 'S'),
            seq.CreateSequence('tpl', 'X'))
        aln.AttachView(1, tpl.CreateFullView())
        result = modelling.BuildRawModel(aln)
        residues = result.model.residues
        # same as in testModified as OLC of SEP is 'S' (i.e. matches)
        self.assertEqual(len(residues), 1)
        residue = residues[0]
        self.assertEqual(len(residue.atoms), 6)
        # FindAtom hands back a handle even for a missing atom (see
        # testInsertCBeta), so query IsValid() explicitly instead of relying
        # on the truth value of the handle.
        for atom_name in ("N", "CA", "C", "O", "CB", "OG"):
            self.assertTrue(residue.FindAtom(atom_name).IsValid(),
                            "atom %s missing in model residue" % atom_name)
        # NOTE: relevant seq-vs-str mismatch tested in testOffset
        # See OST's nonstandard.cc for additional tests of handling modified
        # residues. Code duplication will be removed in SCHWED-3569.

    def testInsertCBeta(self):
        # model residues must contain a cbeta, unless they are glycines
        template = io.LoadPDB('data/cbeta.pdb')
        aln = io.LoadAlignment('data/cbeta.fasta')
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        model_residues = mhandle.model.residues
        # per residue index: do we expect a valid CB atom?
        for idx, expect_cb in enumerate((False, False, True, True)):
            cb_is_valid = model_residues[idx].FindAtom("CB").IsValid()
            if expect_cb:
                self.assertTrue(cb_is_valid)
            else:
                self.assertFalse(cb_is_valid)

    def testOffset(self):
        # a raw model can only be built here once the sequence offset is set
        template = io.LoadPDB('data/2jlp-1.pdb')
        aln = io.LoadAlignment('data/2jlp-1.fasta')
        chain_a_view = template.Select("cname=A").CreateFullView()
        aln.AttachView(1, chain_a_view)
        # fail w/o offset
        self.assertRaises(RuntimeError, modelling.BuildRawModel, aln)
        # try with hard-coded offset
        aln.SetSequenceOffset(1, 55)
        mhandle = modelling.BuildRawModel(aln)
        self.assertEqual(len(mhandle.gaps), 4)

    def testOnTop(self):
        # do we clean up atoms that are on top of each other?
        # ResNum 8 has CA-pos == C-pos, so we expect a gap spanning it
        template = io.LoadPDB('data/gly_on_top.pdb')
        aln = io.LoadAlignment('data/seq.fasta')
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        self.assertEqual(len(mhandle.gaps), 1)
        gap = mhandle.gaps[0]
        self.assertEqual(gap.before.number.num, 7)
        self.assertEqual(gap.after.number.num, 9)

    #######################################################################

    def testModellingHandleOperations(self):
        # check Copy and MergeMHandle on a modelling handle with one gap

        def asStrings(items):
            return [str(item) for item in items]

        def expectSameContent(candidate, reference):
            # atoms, gaps and seqres are compared via their string form
            self.assertEqual(asStrings(candidate.model.atoms),
                             asStrings(reference.model.atoms))
            self.assertEqual(asStrings(candidate.gaps),
                             asStrings(reference.gaps))
            self.assertEqual(asStrings(candidate.seqres),
                             asStrings(reference.seqres))
            # rest of the fields not checked...yet

        # handle with gap
        template = io.LoadPDB('data/1crn_cut.pdb')
        aln = io.LoadAlignment('data/1crn.fasta')
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)

        # check ModellingHandle
        self.assertEqual(len(mhandle.gaps), 1)
        self.assertEqual(str(mhandle.gaps[0]), 'A.ALA24-(ICATYT)-A.GLY31')
        self.checkModel(mhandle)

        # check copy
        expectSameContent(mhandle.Copy(), mhandle)

        # handle without gap
        template_closed = io.LoadPDB('data/1crn_build.pdb')
        trg_string = str(aln.sequences[0])
        aln_closed = seq.CreateAlignment(
            seq.CreateSequence('trg', trg_string),
            seq.CreateSequence('tpl', trg_string))
        aln_closed.AttachView(1, template_closed.CreateFullView())
        mhandle_closed = modelling.BuildRawModel(aln_closed)

        def mergeIntoCopy(first_num, last_num):
            # merge the closed handle into a fresh copy of the gapped one
            # (first chain of both, default-constructed transformation)
            target = mhandle.Copy()
            modelling.MergeMHandle(mhandle_closed, target, 0, 0,
                                   first_num, last_num, geom.Mat4())
            return target

        # merge with no overlap (no change)
        merged = mergeIntoCopy(10, 19)
        expectSameContent(merged, mhandle)
        self.checkModel(merged)

        # merge with overlap on N-side
        merged = mergeIntoCopy(20, 29)
        self.assertEqual(len(merged.gaps), 1)
        self.assertEqual(str(merged.gaps[0]), 'A.TYR29-(T)-A.GLY31')
        self.checkModel(merged)

        # merge with overlap on C-side
        merged = mergeIntoCopy(27, 35)
        self.assertEqual(len(merged.gaps), 1)
        self.assertEqual(str(merged.gaps[0]), 'A.ALA24-(IC)-A.ALA27')
        self.checkModel(merged)

        # merge with full overlap
        merged = mergeIntoCopy(20, 35)
        self.assertEqual(len(merged.gaps), 0)
        self.checkModel(merged)

    #######################################################################

    def testMergeGaps(self):
        # merge the first two gaps repeatedly until only one gap is left
        template = io.LoadPDB('data/1mcg.pdb')
        trg_seq = seq.CreateSequence('trg', 'DDFAGTHN')
        tpl_seq = seq.CreateSequence('tpl', 'N-N-A-LF')
        aln = seq.CreateAlignment(trg_seq, tpl_seq)
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)

        def modelSequence():
            # one letter codes of all residues currently in the model
            return ''.join(r.one_letter_code for r in mhandle.model.residues)

        # check initial state
        self.assertEqual(modelSequence(), 'DFGHN')
        self.assertEqual(len(mhandle.gaps), 3)
        # per merge: expected model sequence, number of gaps and first gap
        expected_after_merge = (
            ('DGHN', 2, 'A.ASP1-(DFA)-A.GLY5'),
            ('DHN', 1, 'A.ASP1-(DFAGT)-A.HIS7'),
        )
        for model_seq, num_gaps, first_gap in expected_after_merge:
            modelling.MergeGaps(mhandle, 0)
            self.assertEqual(modelSequence(), model_seq)
            self.assertEqual(len(mhandle.gaps), num_gaps)
            self.assertEqual(str(mhandle.gaps[0]), first_gap)
        # last gap: should throw exception
        self.assertRaises(RuntimeError, modelling.MergeGaps, mhandle, 0)

    def testMergeGapsOligo(self):
        # merging gaps must fail if the gaps are on different chains
        template = io.LoadPDB('data/5d52-1.pdb')
        # chain name -> (target sequence, template sequence)
        chain_alignments = (
            ("A", 'GIVEQAAACCTSICSLYQLENYCN',
                  'GIVEQ---CCTSICSLYQLENYCN'),
            ("B", 'FVNQHLCG---LEALTLVCGERGFFYTPKA',
                  'FVNQHLCGSHLVEALYLVCGERGFFYTPKA'),
        )
        pairwise = []
        for _, trg_string, tpl_string in chain_alignments:
            pairwise.append(seq.CreateAlignment(
                seq.CreateSequence('trg', trg_string),
                seq.CreateSequence('tpl', tpl_string)))
        for (cname, _, _), chain_aln in zip(chain_alignments, pairwise):
            chain_view = template.Select("cname=" + cname).CreateFullView()
            chain_aln.AttachView(1, chain_view)
        alns = seq.AlignmentList()
        for chain_aln in pairwise:
            alns.append(chain_aln)
        mhandle = modelling.BuildRawModel(alns)
        # check
        self.assertEqual(len(mhandle.gaps), 2)
        self.assertRaises(RuntimeError, modelling.MergeGaps, mhandle, 0)

    def testCountEnclosedGaps(self):
        # count gaps / insertions enclosed by gaps of the model and by
        # custom gaps
        template = io.LoadPDB('data/2dbs.pdb')
        trg_seq = seq.CreateSequence(
            'trg', 'GATLNGFTVPAGNTLVLN---PD--KG---ATVTMAGA')
        tpl_seq = seq.CreateSequence(
            'tpl', '--NGG--TL--LI--PNGTYHFLGIQMKSNVHIRVE--')
        aln = seq.CreateAlignment(trg_seq, tpl_seq)
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        # check: each gap of the handle encloses itself only and counts as
        # an insertion only if it has a non-zero length
        self.assertEqual(len(mhandle.gaps), 8)
        for gap in mhandle.gaps:
            self.assertEqual(modelling.CountEnclosedGaps(mhandle, gap), 1)
            self.assertEqual(modelling.CountEnclosedInsertions(mhandle, gap),
                             int(gap.length > 0))

        chain = mhandle.model.chains[0]
        # (before, after, gap sequence, enclosed gaps, enclosed insertions)
        custom_gaps = (
            # none
            (15, 17, "L", 0, 0),
            # extended singles
            (3, 9, "LNGFT", 1, 1),
            (20, 22, "K", 1, 0),
            # doubles
            (3, 12, "LNGFTVPA", 2, 2),
            (13, 19, "TLVLN", 2, 1),
            (20, 25, "KGAT", 2, 0),
        )
        for before, after, gap_seq, num_gaps, num_insertions in custom_gaps:
            custom_gap = modelling.StructuralGap(chain.FindResidue(before),
                                                 chain.FindResidue(after),
                                                 gap_seq)
            self.assertEqual(modelling.CountEnclosedGaps(mhandle, custom_gap),
                             num_gaps)
            self.assertEqual(
                modelling.CountEnclosedInsertions(mhandle, custom_gap),
                num_insertions)

    #######################################################################

    def testClearGapsExceptions(self):
        # ClearGaps must throw (and leave the gaps alone) when used wrongly
        template = io.LoadPDB('data/1mcg.pdb')
        trg_seq = seq.CreateSequence('trg', 'DDFAGTHN')
        tpl_seq = seq.CreateSequence('tpl', 'N-N-A-LF')
        aln = seq.CreateAlignment(trg_seq, tpl_seq)
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        # get gaps before we remove residues in MergeGaps
        chain = mhandle.model.chains[0]
        bad_gaps = [
            modelling.StructuralGap(chain.FindResidue(0),
                                    chain.FindResidue(3), "DD"),
            modelling.StructuralGap(chain.FindResidue(3),
                                    chain.FindResidue(9), "AGTHN"),
            modelling.StructuralGap(chain.FindResidue(3),
                                    chain.FindResidue(5), "A"),
        ]
        # check
        for _ in range(2):
            modelling.MergeGaps(mhandle, 0)
        self.assertEqual(len(mhandle.gaps), 1)
        for bad_gap in bad_gaps:
            self.assertRaises(RuntimeError, modelling.ClearGaps, mhandle,
                              bad_gap)
        self.assertEqual(len(mhandle.gaps), 1)

    def testClearGaps(self):
        # clear gaps one by one and check return value and remaining gaps
        template = io.LoadPDB('data/2dbs.pdb')
        trg_seq = seq.CreateSequence('trg',
                                     'TLNGFTVPAGNTLVLN---PDKG--ATVTM-A')
        tpl_seq = seq.CreateSequence('tpl',
                                     'N-GG-TLLI--PNGTYHFLGIQMKSNVHIRVE')
        aln = seq.CreateAlignment(trg_seq, tpl_seq)
        aln.AttachView(1, template.CreateFullView())
        # disable aln_preprocessing to also keep the small deletion in the end
        mhandle = modelling.BuildRawModel(aln, aln_preprocessing=False)
        # check
        self.assertEqual(len(mhandle.gaps), 6)
        # gaps taken from the handle:
        # (index into mhandle.gaps, expected return value, gaps left)
        for gap_idx, expected, gaps_left in ((1, 1, 5), (2, 2, 4)):
            cleared = modelling.ClearGaps(mhandle, mhandle.gaps[gap_idx])
            self.assertEqual(cleared, expected)
            self.assertEqual(len(mhandle.gaps), gaps_left)
        # special gaps:
        # (before, after, gap sequence, expected return value, gaps left)
        chain = mhandle.model.chains[0]
        special_gaps = (
            (8, 13, "AGNT", 1, 3),
            (19, 22, "GA", 1, 2),
            (25, 27, "A", -1, 1),
            (0, 3, "TL", -1, 0),
        )
        last_gap = None
        for before, after, gap_seq, expected, gaps_left in special_gaps:
            last_gap = modelling.StructuralGap(chain.FindResidue(before),
                                               chain.FindResidue(after),
                                               gap_seq)
            self.assertEqual(modelling.ClearGaps(mhandle, last_gap), expected)
            self.assertEqual(len(mhandle.gaps), gaps_left)
        # clearing once more with no gaps left
        self.assertEqual(modelling.ClearGaps(mhandle, last_gap), -1)

    def testClearMultipleGaps(self):
        # check that we can clear multiple gaps at once
        template = io.LoadPDB('data/2dbs.pdb')
        trg_seq = seq.CreateSequence('trg',
                                     'TLNGFTVPAGNTLVLN---PD--KG---ATVTMA')
        tpl_seq = seq.CreateSequence('tpl',
                                     'NGG--TL--LI--PNGTYHFLGIQMKSNVHIRVE')
        aln = seq.CreateAlignment(trg_seq, tpl_seq)
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        # check
        self.assertEqual(len(mhandle.gaps), 6)
        chain = mhandle.model.chains[0]
        # (before, after, gap sequence, expected return value, gaps left)
        wide_gaps = (
            (1, 10, "LNGFTVPA", 0, 4),
            (11, 17, "TLVLN", 0, 2),
            (18, 23, "KGAT", -1, 0),
        )
        for before, after, gap_seq, expected, gaps_left in wide_gaps:
            wide_gap = modelling.StructuralGap(chain.FindResidue(before),
                                               chain.FindResidue(after),
                                               gap_seq)
            self.assertEqual(modelling.ClearGaps(mhandle, wide_gap), expected)
            self.assertEqual(len(mhandle.gaps), gaps_left)

    def testClearGapsTermnini(self):
        # check that we can clear terminal gaps
        template = io.LoadPDB('data/1mcg.pdb')
        trg_seq = seq.CreateSequence('trg', 'DDFAGTHN')
        tpl_seq = seq.CreateSequence('tpl', '--NNALF-')
        aln = seq.CreateAlignment(trg_seq, tpl_seq)
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        # check
        self.assertEqual(len(mhandle.gaps), 2)
        # always clear the first remaining gap:
        # (expected return value, gaps left)
        for expected, gaps_left in ((0, 1), (-1, 0)):
            cleared = modelling.ClearGaps(mhandle, mhandle.gaps[0])
            self.assertEqual(cleared, expected)
            self.assertEqual(len(mhandle.gaps), gaps_left)

    def testClearGapsOligo(self):
        # check that we can clear gaps in oligomers
        template = io.LoadPDB('data/5d52-1.pdb')
        trg_a = seq.CreateSequence('trg', 'GIVEQPAYFWPPPAHCCTSICSLYQLENYCN')
        tpl_a = seq.CreateSequence('tpl', 'GIVEQ----------CCTSICSLYQLENYCN')
        aln_A = seq.CreateAlignment(trg_a, tpl_a)
        trg_b = seq.CreateSequence('trg', 'FVNQHLCGSGHLVEALYLVCGERGFFYTPKA')
        tpl_b = seq.CreateSequence('tpl', 'FVNQHLCGS-HLVEALYLVCGERGFFYTPKA')
        aln_B = seq.CreateAlignment(trg_b, tpl_b)
        # one template chain per alignment
        aln_A.AttachView(1, template.Select("cname=A").CreateFullView())
        aln_B.AttachView(1, template.Select("cname=B").CreateFullView())
        alns = seq.AlignmentList()
        alns.append(aln_A)
        alns.append(aln_B)
        mhandle = modelling.BuildRawModel(alns)
        # check
        self.assertEqual(len(mhandle.gaps), 2)
        first_result = modelling.ClearGaps(mhandle, mhandle.gaps[0])
        self.assertEqual(first_result, 0)
        self.assertEqual(len(mhandle.gaps), 1)
        second_result = modelling.ClearGaps(mhandle, mhandle.gaps[0])
        self.assertEqual(second_result, -1)
        self.assertEqual(len(mhandle.gaps), 0)

    #######################################################################

    def testRemoveTerminalGaps(self):
        # check that we can remove both terminal gaps
        template = io.LoadPDB('data/1mcg.pdb')
        trg_seq = seq.CreateSequence('trg', 'DDFAGTHN')
        tpl_seq = seq.CreateSequence('tpl', '--NNALF-')
        aln = seq.CreateAlignment(trg_seq, tpl_seq)
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        # check: both gaps are reported as removed and none is left
        self.assertEqual(len(mhandle.gaps), 2)
        num_removed = modelling.RemoveTerminalGaps(mhandle)
        self.assertEqual(num_removed, 2)
        self.assertEqual(len(mhandle.gaps), 0)

    def testRemoveTerminalGapsN(self):
        # check that we can remove N-terminal gaps
        template = io.LoadPDB('data/1mcg.pdb')
        trg_seq = seq.CreateSequence('trg', 'DDFAGTH')
        tpl_seq = seq.CreateSequence('tpl', '--NNALF')
        aln = seq.CreateAlignment(trg_seq, tpl_seq)
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        # check: the single gap is reported as removed and none is left
        self.assertEqual(len(mhandle.gaps), 1)
        num_removed = modelling.RemoveTerminalGaps(mhandle)
        self.assertEqual(num_removed, 1)
        self.assertEqual(len(mhandle.gaps), 0)

    def testRemoveTerminalGapsC(self):
        # check that we can remove C-terminal gaps
        template = io.LoadPDB('data/1mcg.pdb')
        trg_seq = seq.CreateSequence('trg', 'FAGTHN')
        tpl_seq = seq.CreateSequence('tpl', 'NNALF-')
        aln = seq.CreateAlignment(trg_seq, tpl_seq)
        aln.AttachView(1, template.CreateFullView())
        mhandle = modelling.BuildRawModel(aln)
        # check: the single gap is reported as removed and none is left
        self.assertEqual(len(mhandle.gaps), 1)
        num_removed = modelling.RemoveTerminalGaps(mhandle)
        self.assertEqual(num_removed, 1)
        self.assertEqual(len(mhandle.gaps), 0)
        

if __name__ == "__main__":
    # when executed as a script, hand over to the test runner from ost
    from ost.testutils import RunTests
    RunTests()