# importing libraries 
from tensorflow.keras.preprocessing.image import ImageDataGenerator 
from tensorflow.keras.models import Sequential 
from tensorflow.keras.layers import Conv2D, MaxPooling2D 
from tensorflow.keras.layers import Activation, Dropout, Flatten, Dense 
from tensorflow.keras import backend as K 

import tensorflow as tf
print( tf.config.list_physical_devices() )

import matplotlib.pyplot as plt
import numpy as np

import pickle

# Size (in pixels) every image is resized to. Used for the model's
# input_shape in NNmodel() and for target_size in main()'s data generators.
img_width = 256
img_height = 256

def NNmodel():
    """Build the (uncompiled) binary image classifier.

    Architecture: three stages of Conv2D(2x2) -> ReLU -> MaxPooling2D(2x2)
    with 8, 16 and 32 filters, followed by Flatten -> Dense(64) -> ReLU ->
    Dropout(0.5) -> Dense(1) -> sigmoid.

    The input shape is taken from the module-level ``img_width`` /
    ``img_height`` and ordered according to the Keras backend's
    ``image_data_format()`` (channels first or last, 3 channels).

    Returns:
        A ``Sequential`` model. The caller is responsible for ``compile()``.
    """
    channels_first = K.image_data_format() == 'channels_first'
    input_shape = ((3, img_width, img_height) if channels_first
                   else (img_width, img_height, 3))

    net = Sequential()

    # Convolutional feature extractor. Only the first Conv2D declares the
    # input shape; the remaining stages infer theirs from the previous layer.
    for stage, n_filters in enumerate((8, 16, 32)):
        shape_kwargs = {'input_shape': input_shape} if stage == 0 else {}
        net.add(Conv2D(n_filters, (2, 2), **shape_kwargs))
        net.add(Activation('relu'))
        net.add(MaxPooling2D(pool_size=(2, 2)))

    # Classifier head: a single sigmoid output unit.
    net.add(Flatten())
    net.add(Dense(64))
    net.add(Activation('relu'))
    net.add(Dropout(0.5))
    net.add(Dense(1))
    net.add(Activation('sigmoid'))

    return net

def _collect_class_images(generator, n_class0=800):
    """Split images drawn from *generator* into two lists by binary label.

    Batches are pulled with ``next(generator)`` until ``n_class0`` images with
    label 0 have been collected. Images with label 1 seen along the way are
    kept as well.

    Args:
        generator: iterator yielding ``(images, labels)`` batches, e.g. the
            ``class_mode='binary'`` directory iterator built in ``main``.
        n_class0: number of label-0 images to gather before stopping.

    Returns:
        ``(class0, class1)``: lists of individual images.

    NOTE(review): if the generator never yields a label-0 image and never
    stops, this loops forever.
    """
    class0, class1 = [], []
    while len(class0) < n_class0:
        images, labels = next(generator)
        # Walk the batch that was actually returned: the last batch of an
        # epoch can be shorter than batch_size, so indexing up to batch_size
        # (as before) could raise IndexError.
        for image, label in zip(images, labels):
            if len(class0) >= n_class0:
                break
            if label == 0:
                class0.append(image)
            elif label == 1:
                class1.append(image)
    return class0, class1


def _save_image(image, filename, dpi=200):
    """Draw *image* with ``imshow`` in a fresh figure, save it, then close it."""
    fig = plt.figure()
    plt.imshow(image)
    fig.savefig(filename, dpi=dpi)
    plt.close(fig)  # release the figure; these plots are written to disk only


def _plot_mean_images(generator, n_class0=800):
    """Save per-class mean images, their ratio and their absolute difference.

    Writes ``mean0.png``, ``mean1.png``, ``ratio.png`` and ``bkgsub.png`` to
    the current working directory.
    """
    print('Will try and evaluate average image!')
    class0, class1 = _collect_class_images(generator, n_class0)

    avgimage0 = np.mean(class0, axis=0)
    _save_image(avgimage0, 'mean0.png')

    if not class1:
        # np.mean of an empty list would be NaN; nothing meaningful to plot.
        print('No class-1 images were drawn; skipping mean1/ratio/bkgsub plots.')
        return

    avgimage1 = np.mean(class1, axis=0)
    _save_image(avgimage1, 'mean1.png')

    # Pixels where the class-0 mean is exactly zero would give inf/NaN;
    # leave them at 0 instead.
    ratio = np.divide(avgimage1, avgimage0,
                      out=np.zeros_like(avgimage1), where=avgimage0 != 0)
    _save_image(ratio, 'ratio.png')

    _save_image(np.absolute(avgimage0 - avgimage1), 'bkgsub.png')


def _evaluate_predictions(model, generator, steps=20, threshold=0.5):
    """Predict on *generator*, print predicted classes, save a score histogram.

    Args:
        model: trained model whose output is a single sigmoid unit.
        generator: directory iterator providing ``class_indices`` and ``len()``.
        steps: maximum number of batches to predict on.
        threshold: score above which a sample is assigned class index 1.

    Returns:
        List of predicted class names, in generator order.
    """
    # Never ask for more batches than the generator has.
    steps = min(steps, len(generator))
    pred = model.predict(generator, steps=steps)

    # The model ends in one sigmoid unit, so pred has a single column and
    # argmax over it would always be 0. Threshold the score instead.
    predicted_class_indices = (pred[:, 0] > threshold).astype(int)
    print(predicted_class_indices)

    labels = generator.class_indices  # maps class name -> class index
    index_to_label = {index: name for name, index in labels.items()}
    predictions = [index_to_label[k] for k in predicted_class_indices]
    print(labels)

    fig = plt.figure()
    plt.hist(pred.ravel(), bins=20)
    fig.savefig('preds.png', dpi=200)
    plt.close(fig)
    return predictions


def main(*, train_data_dir='/user/msullivan/TauCP/tau-cp-dev/images/',
         evaluate_mean=False, evaluate_predictions=False):
    """Train the CNN on images under *train_data_dir* and pickle the history.

    The directory is split into training/validation subsets (20% validation)
    via ``ImageDataGenerator``. After training, ``history.history`` is written
    to ``history.pickle``.

    Args:
        train_data_dir: root directory passed to ``flow_from_directory``.
        evaluate_mean: if True, save per-class mean-image plots before training.
        evaluate_predictions: if True, after training predict on the
            validation subset and save a histogram of scores to ``preds.png``.
    """
    nb_train_samples = 2000
    epochs = 200
    batch_size = 32
    validation_split = 0.2

    model = NNmodel()
    model.compile(loss='binary_crossentropy',
                  optimizer='rmsprop',
                  metrics=['accuracy', 'mse'])
    # summary() prints on its own and returns None; print(summary()) also printed "None".
    model.summary()

    # Shear/zoom augmentation is applied to the training subset only.
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0.1,
        zoom_range=0.1,
        validation_split=validation_split)
    # Validation images get the same rescaling and split but no augmentation,
    # so validation metrics are measured on unmodified images.
    # NOTE(review): this relies on Keras assigning subsets deterministically
    # from the directory listing, so the same validation_split on both
    # generators keeps the subsets disjoint -- confirm for the installed Keras.
    val_datagen = ImageDataGenerator(
        rescale=1. / 255,
        validation_split=validation_split)

    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size, class_mode='binary', subset='training')

    if evaluate_mean:
        _plot_mean_images(train_generator)

    # shuffle=False keeps prediction order aligned with the generator's files;
    # validation metrics during fit() cover every batch regardless of order.
    validation_generator = val_datagen.flow_from_directory(
        train_data_dir,
        target_size=(img_width, img_height),
        batch_size=batch_size, class_mode='binary', subset='validation',
        shuffle=False)

    # NOTE: steps_per_epoch is fixed by nb_train_samples, not derived from the
    # number of images actually found.
    history = model.fit(
        train_generator,
        steps_per_epoch=nb_train_samples // batch_size,
        epochs=epochs,
        validation_data=validation_generator)

    with open('history.pickle', 'wb') as handle:
        pickle.dump(history.history, handle, protocol=pickle.HIGHEST_PROTOCOL)

    if evaluate_predictions:
        _evaluate_predictions(model, validation_generator)

if __name__ == "__main__":
    # Train only when executed as a script, not when this module is imported.
    main()