import sys
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
from scipy import stats
from sklearn import preprocessing
import random
import joblib


#fix RNG seeds in python, numpy, and tensorflow for reproducible runs
seed=1337
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

'''
auxiliary method converts input sequences to float vectors
'''
def convert_seq2vec(in_string):
	#map each base to its index via the module-level nt_dict
	#(raises KeyError for characters outside A/C/G/T, upper or lower case)
	indices = [nt_dict[base] for base in in_string]
	#expand indices to a (len, max_features) one-hot float32 matrix
	return keras.utils.to_categorical(indices, num_classes=max_features, dtype='float32')


def prepare_data(test_split, input_file, prim_input_col, output_col, metadata_cols, qc_col=None, ordering_asc=False, target_logarithmic=False):
	"""Load, filter, split, and scale the training data.

	Targets are standard-scaled (sklearn StandardScaler, fit on the training
	split only); each metadata column is MinMax-scaled to [0, 1].  All fitted
	scalers are persisted to the module-level ``scaler_dir``.

	Parameters
	----------
	test_split : float
		Fraction of rows held out as the test set.
	input_file : str
		Path to a csv/tsv file ('.tsv' anywhere in the name selects tab sep).
	prim_input_col : str
		Column holding the nucleotide sequences (one-hot encoded here).
	output_col : str
		Target column name.
	metadata_cols : list of str
		Numeric side-information columns, each scaled independently.
	qc_col : str, optional
		If given, rows are sorted by this column before the split, so the
		test set is the first rows of that ordering; otherwise the rows are
		shuffled before the split.
	ordering_asc : bool
		Sort direction for ``qc_col``.
	target_logarithmic : bool
		If True, rows with non-positive targets are dropped and the
		remaining targets are log-transformed.

	Returns
	-------
	tuple
		(prim_seq_train, prim_seq_test, meta_train, meta_test,
		 target_train_sc, target_test_usc, output_scaler,
		 metadata_scalers, raw_test)
	"""
	#load raw data from csv/tsv
	sep = '\t' if '.tsv' in input_file else ','
	raw_data = pd.read_csv(input_file, low_memory=False, sep=sep)
	raw_data.reset_index(inplace=True, drop=True)

	#only take input sequences short enough for the model
	raw_data = raw_data.loc[raw_data[prim_input_col].str.len() <= maxlen_prim]

	#optional log transform of the target.
	#BUGFIX: the original assigned the log of a filtered series back to the
	#full frame; pandas index alignment then left NaN targets for every row
	#with a non-positive value.  Drop those rows first instead.
	if target_logarithmic:
		raw_data = raw_data.loc[raw_data[output_col] > 0.0]
		raw_data[output_col] = np.log(raw_data[output_col])

	if qc_col is None:
		#no quality column: random shuffle before the train/test split
		raw_data = raw_data.sample(frac=1).reset_index(drop=True)
	else:
		#quality-ordered split: test set = first nbr_test rows of this ordering
		raw_data = raw_data.sort_values(by=[qc_col], ascending=ordering_asc)

	nbr_total = raw_data.shape[0]
	nbr_test = int(test_split * nbr_total)

	raw_test = raw_data.iloc[:nbr_test]
	raw_train = raw_data.iloc[nbr_test:]

	#random reshuffling of training data
	raw_train = raw_train.sample(frac=1).reset_index(drop=True)

	#fit the target scaler on training targets only and persist it
	output_scaler = preprocessing.StandardScaler()
	target_train_sc = output_scaler.fit_transform(raw_train[[output_col]].values.reshape(-1, 1))
	joblib.dump(output_scaler, scaler_dir + 'output_scaler.gz')

	#one MinMax scaler per metadata column: fit on train, apply to test
	metadata_scalers = [preprocessing.MinMaxScaler(feature_range=(0, 1)) for _ in metadata_cols]
	meta_train = pd.DataFrame(columns=metadata_cols)
	meta_test = pd.DataFrame(columns=metadata_cols)
	for col, mdsc in zip(metadata_cols, metadata_scalers):
		meta_train[col] = pd.DataFrame(mdsc.fit_transform(raw_train[col].values.reshape(-1, 1)))
		meta_test[col] = pd.DataFrame(mdsc.transform(raw_test[col].values.reshape(-1, 1)))
		joblib.dump(mdsc, scaler_dir + col + '_scaler.gz')

	#test targets are returned unscaled for direct evaluation downstream
	target_test_usc = raw_test[[output_col]].values

	#one-hot encode sequences and left-pad with -1.0 (masked by the model)
	def _encode_sequences(frame):
		#one (len, max_features) matrix per row, padded to maxlen_prim
		vecs = [convert_seq2vec(row[0]) for row in frame[[prim_input_col]].values]
		return keras.utils.pad_sequences(vecs, maxlen=maxlen_prim, dtype='float32', padding='pre', value=-1.0)

	prim_seq_train = _encode_sequences(raw_train)
	prim_seq_test = _encode_sequences(raw_test)

	return prim_seq_train, prim_seq_test, meta_train, meta_test, target_train_sc, target_test_usc, output_scaler, metadata_scalers, raw_test


'''
begin script
'''
max_features = 4  #4 nucleic bases
maxlen_prim = 100  #maximum length of UTRs
prim_input_col='utr'  #column holding the 5'UTR sequence strings
metadata_cols = ['UTR_length','normalized_5p_folding_energy','GC_content','number_outframe_uAUGs','number_inframe_uAUGs']  #numeric side features
output_col='rl'  #regression target column

data_path = '../HEK293_training_data/opt100_nonseq_feat.tsv'  #input table (tsv)
scaler_dir = '../HEK293_training_data/scalers/'  #fitted scalers are dumped here
integrated_model_path = 'TranslateLSTM_opt100.h5'  #trained model save/load path


#nucleotide dictionary: maps each base (case-insensitive) to its one-hot index;
#no entry for 'N' or other ambiguity codes, so convert_seq2vec raises KeyError
#on non-ACGT characters
nt_dict = {'A' : 0, 'C' : 1, 'G' : 2, 'T' : 3, 'a' : 0, 'c' : 1, 'g' : 2, 't' : 3}


#load and scale data; with ordering_asc=False the test set is the 4.8% of rows
#with the highest 'total_reads' values
prim_seq_train, prim_seq_test, meta_train, meta_test, target_train_sc, target_test_usc, output_scaler, metadata_scaler, raw_test = prepare_data(0.048, data_path, prim_input_col=prim_input_col, output_col=output_col, metadata_cols=metadata_cols, qc_col='total_reads', ordering_asc=False, target_logarithmic=False)


#build and train the integrated model (sequence LSTM branch + metadata branch).
#NOTE(review): the original comments claimed pretrained models are loaded and
#only the last layer is trained; this code actually builds the model from
#scratch and trains ALL layers — no weights are loaded and nothing is frozen.
if len(sys.argv) < 2 or sys.argv[1] == 'train':
	#define inputs: variable-length one-hot sequence, fixed-size metadata vector
	prim_seq_inputs = keras.Input(shape=(None,max_features),dtype='float32')
	#NOTE(review): shape=(len(metadata_cols)) is a plain int, not a 1-tuple;
	#shape=(len(metadata_cols),) would be unambiguous — confirm Keras accepts it
	meta_inputs = keras.Input(shape=(len(metadata_cols)),dtype='float32')

	#mask the -1.0 padding values inserted by pad_sequences;
	#NOTE(review): input_shape is redundant in the functional API and carries an
	#extra leading None here — presumably ignored, verify
	x = layers.Masking(mask_value=-1.,input_shape=(None,None,max_features))(prim_seq_inputs)
	x = layers.Bidirectional(layers.LSTM(64,return_sequences=True))(x)
	x = layers.Bidirectional(layers.LSTM(64))(x)
	x = layers.Dropout(0.2)(x)

	#metadata branch passes through unchanged (dropout rate 0.0 is a no-op)
	z = layers.Dropout(0.0)(meta_inputs)

	#concatenate both branches and regress the scaled target with a linear head
	a = layers.Concatenate(axis=1)([x,z])
	outputs = layers.Dense(1,activation='linear')(a)

	integrated_model = keras.Model(inputs=[prim_seq_inputs,meta_inputs], outputs=[outputs])
	integrated_model.compile(optimizer='adam', loss='mean_squared_error', metrics=[])
	integrated_model.summary()

	#stop when validation loss stalls for 10 epochs; keep the best weights seen
	callbacks = [keras.callbacks.EarlyStopping(patience=10,restore_best_weights=True)]

	integrated_model.fit([prim_seq_train,meta_train], target_train_sc,batch_size=16,epochs=200,validation_split=0.2,callbacks=callbacks,)

	integrated_model.save(integrated_model_path)


#evaluate on the held-out test set and make a scatterplot
if len(sys.argv) < 2 or sys.argv[1] == 'predict':
	#load model from disk only in prediction-only mode; when the script runs
	#without arguments the model trained above is still in scope
	if len(sys.argv)>=2 and sys.argv[1] == 'predict':
		integrated_model = keras.models.load_model(integrated_model_path)

	#predict scaled targets, invert the scaling, compare against unscaled truth
	pred = output_scaler.inverse_transform(integrated_model.predict([prim_seq_test,meta_test]))[:,0]
	control = target_test_usc[:,0]

	#obtain m (slope) and b (intercept) of linear regression line
	m, b = np.polyfit(control, pred, 1)
	#NOTE: the second assignment overwrites pval; neither p-value is reported
	rho_p, pval = stats.pearsonr(control, pred)
	rho_s, pval = stats.spearmanr(control, pred)

	#scatter of true vs. predicted values (no color coding)
	plt.scatter(control, pred, s=2)
	plt.xlabel('True Values')
	plt.ylabel('Predictions')

	#regression line (drawn in black, despite the original "red" comment)
	plt.plot(control,m*control+b,c='black')
	xmin, xmax, ymin, ymax = plt.axis()
	#annotate the plot with both correlation coefficients
	plt.text(xmin+1.0, ymax-0.3, '$R_{Pearson}$=%.3f, $R_{Spearman}$=%.3f' % (rho_p,rho_s),fontsize = 12,color='black')

	plt.savefig('scatterplot_TranslateLSTM_opt100.pdf')
	plt.close()

	#persist test-set predictions alongside the raw test rows;
	#NOTE(review): raw_test is a slice of raw_data, so this assignment may emit
	#a SettingWithCopyWarning — consider raw_test = raw_test.copy() upstream
	raw_test['predicted_'+output_col] = pred
	raw_test.to_csv("predictions_test_TranslateLSTM_opt100_"+output_col+".tsv",sep="\t",index=False)