import unittest
import math

from ost import geom
from ost import io

def compare_atoms(a1, a2, occupancy_thresh = 0.01, bfactor_thresh = 0.01,
                  dist_thresh = 0.001):
    if abs(a1.occupancy - a2.occupancy) > occupancy_thresh:
        return False
    if abs(a1.b_factor - a2.b_factor) > bfactor_thresh:
        return False
    if geom.Distance(a1.GetPos(), a2.GetPos()) > dist_thresh:
        return False
    if a1.is_hetatom != a2.is_hetatom:
        return False
    if a1.element != a2.element:
        return False
    return True

def compare_residues(r1, r2, at_occupancy_thresh = 0.01,
                     at_bfactor_thresh = 0.01, at_dist_thresh = 0.001,
                     skip_ss = False, skip_rnums=False):
    if r1.GetName() != r2.GetName():
        return False
    if skip_rnums is False:
        if r1.GetNumber() != r2.GetNumber():
            return False
    if skip_ss is False:
        if str(r1.GetSecStructure()) != str(r2.GetSecStructure()):
            return False
    if r1.one_letter_code != r2.one_letter_code:
        return False
    if r1.chem_type != r2.chem_type:
        return False
    if r1.chem_class != r2.chem_class:
        return False
    anames1 = [a.GetName() for a in r1.atoms]
    anames2 = [a.GetName() for a in r2.atoms]
    if sorted(anames1) != sorted(anames2):
        return False
    anames = anames1
    for aname in anames:
        a1 = r1.FindAtom(aname)
        a2 = r2.FindAtom(aname)
        if not compare_atoms(a1, a2,
                             occupancy_thresh = at_occupancy_thresh,
                             bfactor_thresh = at_bfactor_thresh,
                             dist_thresh = at_dist_thresh):
            return False
    return True

def compare_chains(ch1, ch2, at_occupancy_thresh = 0.01,
                   at_bfactor_thresh = 0.01, at_dist_thresh = 0.001,
                   skip_ss=False, skip_rnums=False):
    if len(ch1.residues) != len(ch2.residues):
        return False
    for r1, r2 in zip(ch1.residues, ch2.residues):
        if not compare_residues(r1, r2,
                                at_occupancy_thresh = at_occupancy_thresh,
                                at_bfactor_thresh = at_bfactor_thresh,
                                at_dist_thresh = at_dist_thresh,
                                skip_ss = skip_ss, skip_rnums=skip_rnums):
            return False
    return True

def compare_bonds(ent1, ent2):
    bonds1 = list()
    for b in ent1.bonds:
        bond_partners = [str(b.first), str(b.second)]
        bonds1.append([min(bond_partners), max(bond_partners), b.bond_order])
    bonds2 = list()
    for b in ent2.bonds:
        bond_partners = [str(b.first), str(b.second)]
        bonds2.append([min(bond_partners), max(bond_partners), b.bond_order])
    return sorted(bonds1) == sorted(bonds2)

def compare_ent(ent1, ent2, at_occupancy_thresh = 0.01,
                at_bfactor_thresh = 0.01, at_dist_thresh = 0.001,
                skip_ss=False, skip_cnames = False, skip_bonds = False,
                skip_rnums=False, bu_idx = None):
    if bu_idx is not None:
        if ent1.GetName() + ' ' + str(bu_idx) != ent2.GetName():
            return False
    else:
        if ent1.GetName() != ent2.GetName():
            return False
    chain_names_one = [ch.GetName() for ch in ent1.chains]
    chain_names_two = [ch.GetName() for ch in ent2.chains]
    if skip_cnames:
        # only check whether we have the same number of chains
        if len(chain_names_one) != len(chain_names_two):
            return False
    else:
        if chain_names_one != chain_names_two:
            return False
    for ch1, ch2 in zip(ent1.chains, ent2.chains):
        if not compare_chains(ch1, ch2,
                              at_occupancy_thresh = at_occupancy_thresh,
                              at_bfactor_thresh = at_bfactor_thresh,
                              at_dist_thresh = at_dist_thresh,
                              skip_ss=skip_ss, skip_rnums=skip_rnums):
            return False
    if not skip_bonds:
        if not compare_bonds(ent1, ent2):
            return False
    return True

class TestOMF(unittest.TestCase):

    def setUp(self):
        ent, seqres, info = io.LoadMMCIF("testfiles/mmcif/3T6C.cif.gz", 
                                         seqres=True,
                                         info=True)
        self.ent = ent
        self.seqres = seqres
        self.info = info
        self.ent.SetName("This is a name 123")

    def test_AU(self):
        omf = io.OMF.FromEntity(self.ent)
        omf_bytes = omf.ToBytes()
        loaded_omf = io.OMF.FromBytes(omf_bytes)
        loaded_ent = loaded_omf.GetAU()
        self.assertTrue(compare_ent(self.ent, loaded_ent))

    def test_default_peplib(self):
        omf = io.OMF.FromEntity(self.ent)
        omf_bytes = omf.ToBytes()
        omf_def_pep = io.OMF.FromEntity(self.ent, options = io.OMFOption.DEFAULT_PEPLIB)
        omf_def_pep_bytes = omf_def_pep.ToBytes()
        loaded_omf_def_pep = io.OMF.FromBytes(omf_def_pep_bytes)
        loaded_ent = loaded_omf_def_pep.GetAU()

        self.assertTrue(len(omf_def_pep_bytes) < len(omf_bytes))
        self.assertTrue(compare_ent(self.ent, loaded_ent))

    def test_avg_bfactors(self):
        omf = io.OMF.FromEntity(self.ent)
        omf_bytes = omf.ToBytes()
        omf_avg_bfac = io.OMF.FromEntity(self.ent, options = io.OMFOption.AVG_BFACTORS)
        omf_avg_bfac_bytes = omf_avg_bfac.ToBytes()
        loaded_omf_avg_bfac = io.OMF.FromBytes(omf_avg_bfac_bytes)
        loaded_ent = loaded_omf_avg_bfac.GetAU()

        self.assertTrue(len(omf_avg_bfac_bytes) < len(omf_bytes))
        self.assertFalse(compare_ent(self.ent, loaded_ent))
        # just give a huge slack for bfactors and check averaging manually
        self.assertTrue(compare_ent(self.ent, loaded_ent,
                                    at_bfactor_thresh=1000))

        self.assertEqual(len(self.ent.residues), len(loaded_ent.residues))
        for r_ref, r in zip(self.ent.residues, loaded_ent.residues):
            exp_bfac = sum([a.b_factor for a in r_ref.atoms])
            exp_bfac /= r_ref.atom_count
            for a in r.atoms:
                self.assertTrue(abs(a.b_factor - exp_bfac) < 0.008)

    def test_round_bfactors(self):
        omf = io.OMF.FromEntity(self.ent)
        omf_bytes = omf.ToBytes()
        omf_round_bfac = io.OMF.FromEntity(self.ent, options = io.OMFOption.ROUND_BFACTORS)
        omf_round_bfac_bytes = omf_round_bfac.ToBytes()
        loaded_omf_round_bfac = io.OMF.FromBytes(omf_round_bfac_bytes)
        loaded_ent = loaded_omf_round_bfac.GetAU()

        self.assertTrue(len(omf_round_bfac_bytes) < len(omf_bytes))
        self.assertFalse(compare_ent(self.ent, loaded_ent))
        self.assertTrue(compare_ent(self.ent, loaded_ent,
                                    at_bfactor_thresh=0.5))

    def test_skip_ss(self):
        omf = io.OMF.FromEntity(self.ent)
        omf_bytes = omf.ToBytes()
        omf_skip_ss = io.OMF.FromEntity(self.ent, options = io.OMFOption.SKIP_SS)
        omf_skip_ss_bytes = omf_skip_ss.ToBytes()
        loaded_omf_skip_ss = io.OMF.FromBytes(omf_skip_ss_bytes)
        loaded_ent = loaded_omf_skip_ss.GetAU()

        self.assertTrue(len(omf_skip_ss_bytes) < len(omf_bytes))
        self.assertFalse(compare_ent(self.ent, loaded_ent))
        self.assertTrue(compare_ent(self.ent, loaded_ent, skip_ss=True))

    def test_infer_pep_bonds(self):
        omf = io.OMF.FromEntity(self.ent)
        omf_bytes = omf.ToBytes()
        omf_infer_pep_bonds = io.OMF.FromEntity(self.ent,
                                                options = io.OMFOption.INFER_PEP_BONDS)
        omf_infer_pep_bonds_bytes = omf_infer_pep_bonds.ToBytes()
        loaded_omf_infer_pep_bonds = io.OMF.FromBytes(omf_infer_pep_bonds_bytes)
        loaded_ent = loaded_omf_infer_pep_bonds.GetAU()

        self.assertTrue(len(omf_infer_pep_bonds_bytes) < len(omf_bytes))
        self.assertTrue(compare_ent(self.ent, loaded_ent))

    def test_lower_precition(self):
        omf = io.OMF.FromEntity(self.ent, max_error=0.5)
        omf_bytes = omf.ToBytes()
        loaded_omf = io.OMF.FromBytes(omf_bytes)
        loaded_ent = loaded_omf.GetAU()
        self.assertFalse(compare_ent(self.ent, loaded_ent))
        self.assertTrue(compare_ent(self.ent, loaded_ent, at_dist_thresh=0.5))


if __name__== '__main__':
    from ost import testutils
    if testutils.DefaultCompoundLibIsSet():
        testutils.RunTests()
    else:
        print('No compound library available. Ignoring test_stereochemistry.py tests.')