import itertools, pandas as pd, numpy as np, matplotlib.pyplot as plt, seaborn as sns, openbabel, openbabel.pybel, scipy, scipy.stats, openbabel.pybel

from modules import *

# Stage the DUD-E SMILES sets under results/smiles/: the actives file is copied
# unchanged, and 645 lines are sampled from the decoys file.
# NOTE(review): 645 presumably matches the number of actives -- confirm.
rule smiles_dude:
    input:
        smi_actives = 'resources/DHODH_DUD-E/pyrd/actives_final.ism',
        smi_decoys = 'resources/DHODH_DUD-E/pyrd/decoys_final.ism',
    output:
        smi_actives = 'results/smiles/dude_actives.smi',
        smi_decoys = 'results/smiles/dude_decoys.smi',
    params:
        # File handed to `shuf --random-source` as its source of random bytes,
        # so the decoy sample is reproducible while this archive is unchanged.
        random_source = 'resources/DHODH_DUD-E/pyrd.tar.gz',
    shell: """
        cp {input.smi_actives} {output.smi_actives}
        shuf -n 645 --random-source={params.random_source} {input.smi_decoys} -o {output.smi_decoys}
    """

# rdconf is declared a local rule, i.e. it is not submitted as a cluster job.
localrules: rdconf

# Turn results/smiles/{compounds_id}.smi into results/rdconf/{compounds_id}.sdf
# by running the bundled rdkit-scripts/rdconf.py with --maxconfs 1.
# NOTE(review): presumably this generates a single 3D conformer per input
# molecule -- confirm against rdconf.py.
rule rdconf:
    input:
        smi = 'results/smiles/{compounds_id}.smi',
    output:
        sdf = 'results/rdconf/{compounds_id}.sdf',
    shell: """
        python software/rdkit-scripts/rdconf.py --maxconfs 1 {input.smi} {output.sdf}
    """

# https://doi.org/10.1038/s41598-021-91069-7
# https://onlinelibrary.wiley.com/doi/full/10.1002/jhet.3644
# https://www.nature.com/articles/s41598-022-23006-1
# https://www.sciencedirect.com/science/article/pii/S0022286020310176
# https://www.mdpi.com/1420-3049/27/12/3660
# Maps a PDB entry id (the value of the {struct_id} wildcard) to the residue
# name of one ligand in that entry. Rule `pdb` removes that residue from the
# receptor file and writes it to a separate *_lig.pdb; rule `all` iterates the
# keys to decide which structures are processed.
pdb_lig = {
    '1d3g': 'BRE',
    '1d3h': 'A26',
    '4igh': '1EA',
    '6gk0': 'F1W',
    '6j3c': 'B5X',
    '5zf4': '9BL',
    '3g0u': 'MDY',
    '4oqv': '2V6',
    '3u2o': '03U',
    '6oc0': 'M4J',
}

# pdb is declared a local rule, i.e. it is not submitted as a cluster job.
localrules: pdb

# Download one PDB entry and prepare receptor and ligand files from it:
#   1. fetch the biological unit and drop waters (HOH),
#   2. run pdbfixer to add missing residues and heavy atoms,
#   3. align the result onto the reference model given in params.ref,
#   4. write the receptor without the residues in params.delres, the entry's
#      ligand (params.lig) and ORO,
#   5. write the ligand residue alone to a separate file.
# The ligand file is declared as an output so that Snakemake tracks it like the
# receptor file instead of it being an undeclared side effect of the shell.
rule pdb:
    output:
        pdb = 'results/pdb/{struct_id}.pdb',
        fetch = temp('results/pdb/{struct_id}_fetch.pdb'),
        fixed = temp('results/pdb/{struct_id}_pdbfixer.pdb'),
        aligned = temp('results/pdb/{struct_id}_pdbfixer_aligned.pdb'),
        lig = 'results/pdb/{struct_id}_lig.pdb',
    params:
        # Comma-separated residue names stripped from every receptor.
        delres = 'SO4,GOL,CL,ACY,ACT,LDA,ZWI,DDQ,NA,DOR',
        # Residue name of this entry's ligand, looked up in pdb_lig.
        lig = lambda wc: pdb_lig[wc.struct_id],
        # Reference structure every entry is aligned onto.
        ref = 'resources/alphafill/AF-Q02127-F1-model_v2-FMN.pdb'
    # `prody align` is run from inside results/pdb because the following steps
    # expect its result at {struct_id}_pdbfixer_aligned.pdb in that directory;
    # the subshell keeps the directory change from leaking into later commands.
    shell: """
        pdb_fetch -biounit {wildcards.struct_id} | pdb_delresname -HOH > {output.fetch}
        pdbfixer {output.fetch} --add-residues --add-atoms=heavy --output={output.fixed}
        (cd results/pdb && prody align ../../{params.ref} {wildcards.struct_id}_pdbfixer.pdb)
        pdb_delresname -{params.delres},{params.lig},ORO {output.aligned} > {output.pdb}
        pdb_selresname -{params.lig} {output.aligned} > {output.lig}
    """

# Post-process a prepared receptor in three steps: run `reduce -FLIP` on it,
# pass the result through pdbfixer, then convert to PDBQT with Open Babel
# (-xr output option, Gasteiger partial charges).
# The `|| true` after reduce means its exit status is ignored.
# NOTE(review): presumably because reduce exits non-zero even on usable
# output -- confirm; a genuinely failed run only surfaces in the later steps.
rule reduce:
    input:
        pdb = 'results/pdb/{struct_id}.pdb',
    output:
        pdb = 'results/pdb.reduce/{struct_id}.pdb',
        fix = 'results/pdb.reduce/{struct_id}.fix.pdb',
        pdbqt = 'results/pdb.reduce/{struct_id}.pdbqt',
    shell: """
        #export REDUCE_HET_DICT=/cluster/project/beltrao/jjaenes/software/miniconda3/envs/adfr-suite/bin/reduce_wwPDB_het_dict.txt
        #conda run -p /cluster/project/beltrao/jjaenes/software/miniconda3/envs/adfr-suite reduce {input.pdb} > {output.pdb}
        export REDUCE_HET_DICT=software/reduce/reduce_src/../reduce_wwPDB_het_dict.txt
        software/reduce/reduce_src/reduce -FLIP {input.pdb} > {output.pdb} || true
        pdbfixer {output.pdb} --output={output.fix}
        obabel {output.fix} -xr --partialcharge gasteiger -O{output.pdbqt}
    """

# Dock the compounds in results/rdconf/{compounds_id}.sdf against the receptor
# results/pdb/{struct_id}.pdb with smina, using a fixed seed (4) and
# exhaustiveness 64; the thread count is forwarded to --cpu.
# NOTE(review): --autobox_ligand is given the receptor file itself, so the
# search box is derived from the whole receptor. Rule `pdb` also writes
# results/pdb/{struct_id}_lig.pdb, which is not used here -- confirm that a
# receptor-wide box (rather than a box around the crystal ligand) is intended.
rule smina:
    input:
        sdf = 'results/rdconf/{compounds_id}.sdf',
        pdb = 'results/pdb/{struct_id}.pdb',
    output:
        sdf = 'results/pdb.smina/{struct_id}+{compounds_id}.sdf',
    threads: 64
    shell: """
        software/bin/smina --cpu {threads}\
            --ligand {input.sdf}\
            --receptor {input.pdb}\
            --autobox_ligand {input.pdb}\
            --out {output.sdf}\
            --exhaustiveness 64\
            --seed 4
    """

# Same docking set-up as rule `smina`, but run with gnina: CPU only (--no_gpu),
# with the crossdock_default2018 CNN model, after loading the gcc/9.3.0
# environment module. Results go to results/pdb.gnina/.
# NOTE(review): as in rule `smina`, --autobox_ligand points at the receptor
# file itself -- confirm that a receptor-wide search box is intended.
rule gnina:
    input:
        sdf = 'results/rdconf/{compounds_id}.sdf',
        pdb = 'results/pdb/{struct_id}.pdb',
    output:
        sdf = 'results/pdb.gnina/{struct_id}+{compounds_id}.sdf',
    threads: 64
    shell: """
        module load gcc/9.3.0; software/bin/gnina --cpu {threads}\
            --ligand {input.sdf}\
            --receptor {input.pdb}\
            --autobox_ligand {input.pdb}\
            --out {output.sdf}\
            --exhaustiveness 64\
            --seed 4 --no_gpu --cnn crossdock_default2018
    """

# lig_decoys_plot is declared a local rule, i.e. it is not submitted as a
# cluster job.
localrules: lig_decoys_plot

# For one structure, collect the smina results of every compound set listed by
# read_dude(), then write
#   - a two-panel figure: split violins of minimizedAffinity per label
#     (hue = is_decoy) next to a bar plot of decoyNormalisedAffinity for the
#     non-decoy rows, both ordered by descending decoyNormalisedAffinity;
#   - the combined score table as a TSV.
rule lig_decoys_plot:
    input:
        sdf = [ 'results/pdb.smina/{struct_id}+%s.sdf' % (compounds_id,) for compounds_id in read_dude()['compounds_id'] ],
    output:
        png = 'results/pdb.smina.plot_lig_decoys/{struct_id}.png',
        tsv = 'results/pdb.smina.plot_lig_decoys/{struct_id}.tsv',
    run:
        df_ = read_dude()
        # input.sdf was generated from the same read_dude() table, one path per
        # row and in row order; the column assignment raises if the lengths
        # ever disagree.
        df_['sdf_out'] = list(input.sdf)

        df_scores_ = pd.concat(
            [
                read_affinities_dude(sdf_out, compounds_id)
                for sdf_out, compounds_id in zip(df_['sdf_out'], df_['compounds_id'])
            ],
            axis=0,
        )
        # The non-decoy rows drive both the label ordering and the bar plot.
        actives_ = df_scores_.query('~is_decoy')
        order_ = actives_.sort_values('decoyNormalisedAffinity', ascending=False)['label'].to_list()

        fig, ax = plt.subplots(1, 2, sharey=True, figsize=(10,8))
        try:
            sns.violinplot(data=df_scores_, x='minimizedAffinity', y='label', hue='is_decoy', ax=ax[0], split=True, cut=0, inner='stick', order=order_)
            sns.barplot(data=actives_, x='decoyNormalisedAffinity', y='label', ax=ax[1], color='tab:blue', order=order_)
            #ax[1].axvline(1.25, color='tab:red')
            fig.savefig(output.png, bbox_inches='tight')
        finally:
            # Release the figure even if plotting or saving fails, so figures do
            # not accumulate in the running process.
            plt.close(fig)
        df_scores_.to_csv(output.tsv, sep='\t', header=True, index=False)

# Default target of the workflow (default_target: True): requests one
# lig/decoy plot PNG per structure listed in pdb_lig. The commented-out entries
# below are alternative target sets (individual smina/gnina docking outputs)
# that are currently not requested.
rule all:
    """
        profile_euler/run_local --dry-run
        profile_euler/run_sbatch --dry-run
    """
    default_target: True
    input:
        #[f'results/pdb.smina/{struct_id}+dude_actives.sdf' for struct_id in pdb_lig.keys() ],
        #[f'results/pdb.smina/{struct_id}+dude_decoys.sdf' for struct_id in pdb_lig.keys() ],
        #[f'results/pdb.smina/{struct_id}+DHODH_ChEMBL.sdf'  for struct_id in pdb_lig.keys() ],
        #[f'results/pdb.smina/{struct_id}+DHODH_inhibitors.sdf'  for struct_id in pdb_lig.keys() ],
        #[f'results/pdb.smina/{struct_id}+DHODH_inhibitors_pubchem.sdf'  for struct_id in pdb_lig.keys() ],
        #[f'results/pdb.gnina/{struct_id}+dude_actives.sdf' for struct_id in pdb_lig.keys() ],
        #[f'results/pdb.gnina/{struct_id}+dude_decoys.sdf' for struct_id in pdb_lig.keys() ],
        #[f'results/pdb.gnina/{struct_id}+DHODH_ChEMBL.sdf'  for struct_id in pdb_lig.keys() ],
        #[f'results/pdb.gnina/{struct_id}+DHODH_inhibitors.sdf' for struct_id in pdb_lig.keys() ],
        #[f'results/pdb.gnina/{struct_id}+DHODH_inhibitors_pubchem.sdf' for struct_id in pdb_lig.keys() ],
        #expand('results/pdb.smina/{struct_id}+{compounds_id}.sdf', struct_id=pdb_lig.keys(), compounds_id=read_dude()['compounds_id']),
        #expand('results/pdb.gnina/{struct_id}+{compounds_id}.sdf', struct_id=pdb_lig.keys(), compounds_id=read_dude()['compounds_id']),
        expand('results/pdb.smina.plot_lig_decoys/{struct_id}.png', struct_id=pdb_lig.keys())