import os
import time

from model import *

#We redefine some methods
def generate_training_data(num_frames, EbOverN0dB):
    """Build a noiseless (X, Y) training set of encoded, BPSK-modulated frames.

    Each frame is a uniformly random message of ``Parameters.k`` bits. The
    message is encoded with ``encode``, modulated with ``bpsk_modulation`` and
    stored as-is: no channel noise is added here. In ``main()`` the model is
    built with ``build_decoder_with_noise(noise_stddev=...)``, which presumably
    injects the noise during training. The label is the message read as a
    binary integer (first bit = MSB), one-hot encoded over
    ``2**Parameters.k`` classes.

    Args:
        num_frames: Number of frames to generate.
        EbOverN0dB: Unused. It is kept only so the signature stays compatible
            with the noisy version this function redefines (presumably the one
            from ``model``).

    Returns:
        ``(X, Y)`` as NumPy arrays: the modulated frames and their one-hot
        labels.
    """
    num_classes = 2**Parameters.k
    X_train = []
    Y_train = []
    for _ in range(num_frames):
        message = np.random.randint(0, 2, size=Parameters.k)
        encoded_message = encode(message, Parameters)
        modulated_message = bpsk_modulation(encoded_message)
        # Noise is deliberately NOT added here (see docstring).
        X_train.append(modulated_message)
        # Interpret the bits as a binary number, first bit most significant:
        # [1, 0, 1] -> 5. Then one-hot encode it.
        int_label = int("".join(map(str, message)), 2)
        Y_train.append(to_categorical(int_label, num_classes=num_classes))

    return np.array(X_train), np.array(Y_train)

def main():
    """Train a one-hot decoder behind a noise layer, then save the inner decoder.

    Steps:
      1. Build a model with ``build_decoder_with_noise``, whose noise std is
         derived from the chosen Eb/N0 and code rate k/N.
      2. Train it on noiseless frames from ``generate_training_data``.
      3. Find the layer whose name contains ``'decoder'``, compile it and save
         it as ``models/<file_name>``, plus ``models/<file_name>.weights.h5``.
      4. Call ``plot`` on the saved model file.

    Raises:
        RuntimeError: if the trained model has no layer whose name contains
            ``'decoder'``, so nothing could be saved.
    """
    clear()

    input_dim = Parameters.N
    num_classes = 2**Parameters.k

    ####### Parameters NOT to change
    loss = 'categorical_crossentropy'
    activation = 'softmax'
    batch_size = 256
    #######

    # Tunable parameters
    epochs = 1000
    num_frames = 20000
    EbOverN0dB = 1
    num_neurons = 64

    file_name = f"onehot_decoder_{EbOverN0dB}_{epochs}_{batch_size}_{num_frames}_{num_neurons}.keras"

    # Noise std for the training-time noise layer:
    #   sigma = 1 / sqrt(2 * R * Eb/N0), with code rate R = k / N.
    EbOverN0Linear = 10 ** (EbOverN0dB / 10)
    noise_scaling = EbOverN0Linear * (Parameters.k / Parameters.N)
    noise_std = 1 / np.sqrt(2 * noise_scaling)

    model, _ = build_decoder_with_noise(input_dim=input_dim,
                                        num_classes=num_classes,
                                        activation=activation,
                                        loss=loss,
                                        num_neurons=num_neurons,
                                        noise_stddev=noise_std)
    model.summary()

    # Both output directories are written below; create them up front so a
    # fresh checkout does not fail on a missing folder.
    os.makedirs('temp', exist_ok=True)
    os.makedirs('models', exist_ok=True)

    # Checkpoint the best epoch, decay the LR when loss plateaus, and stop
    # after 100 epochs without improvement.
    callbacks = [
        ModelCheckpoint('temp/' + file_name, monitor='loss', save_best_only=True, verbose=1),
        ReduceLROnPlateau(monitor='loss', factor=0.9, patience=10, verbose=1),
        EarlyStopping(monitor='loss', patience=100, verbose=1)
    ]

    X, Y = generate_training_data(num_frames=num_frames, EbOverN0dB=EbOverN0dB)

    model.fit(X, Y, epochs=epochs, batch_size=batch_size, callbacks=callbacks, verbose=1)

    # To keep the full model including the noise layer, uncomment:
    # model.save('noise_models/' + file_name)

    # Save only the inner decoder sub-model (without the noise layer).
    for layer in model.layers:
        if 'decoder' in layer.name:
            layer.compile(optimizer='adam', loss=loss, metrics=['accuracy'])
            # NOTE(review): `_name` is private; depending on the Keras version
            # this may not change `layer.name` -- confirm.
            layer._name = file_name
            layer.save('models/' + file_name)
            layer.save_weights('models/' + file_name + '.weights.h5')
            break
    else:
        # Without this, plot() below would fail on a file that was never written.
        raise RuntimeError(
            f"No layer with 'decoder' in its name found in the trained model; "
            f"nothing was saved to models/{file_name}"
        )

    plot('models/' + file_name)

def clear():
    """Clear the terminal screen: ``cls`` on Windows, ``clear`` elsewhere."""
    # `cls` only exists on Windows; on POSIX shells it just prints an error.
    os.system('cls' if os.name == 'nt' else 'clear')

def load_model(file_name):
    """Load and return the Keras model stored at ``file_name``.

    A thin wrapper around ``tf.keras.models.load_model``.
    """
    loaded = tf.keras.models.load_model(file_name)
    return loaded

def compare_models():
    """Interactively evaluate one saved model against the MAP reference curve.

    Lists the model files in ``models/`` (weights-only files are hidden) and
    prompts for an index. The chosen model is evaluated with
    ``evaluate_onehot_model`` at Eb/N0 = -5..5 dB in 1 dB steps, using 50000
    test frames per point. Two dicts keyed by model name are then printed:

    * the trapezoidal integral of the model's BEP curve minus the same
      integral of the MAP reference (``get_MAP_integral``);
    * the trapezoidal integral of the model's NVE curve.
    """
    # main() writes both '<name>.keras' and '<name>.keras.weights.h5' into
    # models/. Weights-only files cannot be opened by load_model, so hide them.
    # Sort so the menu order is stable (os.listdir order is arbitrary).
    modelsFiles = sorted(
        f for f in os.listdir('models') if not f.endswith('.weights.h5')
    )
    if not modelsFiles:
        print("No models found in 'models/'.")
        return

    print('Choose a model to evaluate:')
    for i, modelFile in enumerate(modelsFiles):
        print(f"{i}: {modelFile}")

    # Re-prompt on non-numeric or out-of-range input. Negative indices are
    # rejected explicitly, because Python would otherwise wrap them around.
    while True:
        try:
            modelIndex = int(input())
        except ValueError:
            modelIndex = -1
        if 0 <= modelIndex < len(modelsFiles):
            break
        print(f"Please enter a number between 0 and {len(modelsFiles) - 1}:")

    models = [load_model('models/' + modelsFiles[modelIndex])]

    # Eb/N0 grid: -5..5 dB inclusive (np.arange excludes the stop value 6).
    EbOverN0dBs = np.arange(-5, 6, 1)

    map_integral = get_MAP_integral()
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    dict_of_results = {}
    dict_of_results_NVE = {}
    for model in models:
        clear()
        print(f"Evaluating model {model.name}")
        time.sleep(1)
        BEPs = []
        NVEs = []

        for EbOverN0dB in EbOverN0dBs:
            BEP, NVE = evaluate_onehot_model(model, num_test_frames=50000, EbOverN0dB=EbOverN0dB)
            BEPs.append(BEP)
            NVEs.append(NVE)

        dict_of_results[model.name] = trapezoid(BEPs, EbOverN0dBs) - map_integral
        dict_of_results_NVE[model.name] = trapezoid(NVEs, EbOverN0dBs)

    print(dict_of_results)
    print(dict_of_results_NVE)

# Tabulated MAP bit-error probabilities, keyed by integer Eb/N0 in dB (-5..7).
# TODO(review): document where these reference values come from.
_MAP_BEP_TABLE = {
    7: 6.25e-5, 6: 0.00038125, 5: 0.0018625, 4: 0.006425, 3: 0.0197625,
    2: 0.0419, 1: 0.0793125, 0: 0.1251, -1: 0.179025, -2: 0.2329,
    -3: 0.2825, -4: 0.323275, -5: 0.36095,
}


def get_MAP_BEP(EbOverN0dB):
    """Return the tabulated MAP bit-error probability at ``EbOverN0dB`` dB.

    Args:
        EbOverN0dB: Eb/N0 in dB. It must equal one of the integers -5..7;
            equal-valued floats or NumPy scalars hash the same and also work.

    Raises:
        KeyError: if ``EbOverN0dB`` is not a tabulated point.
    """
    try:
        return _MAP_BEP_TABLE[EbOverN0dB]
    except KeyError:
        raise KeyError(
            f"No MAP BEP tabulated for Eb/N0 = {EbOverN0dB} dB; "
            f"available: {sorted(_MAP_BEP_TABLE)}"
        ) from None

def get_MAP_integral(EbOverN0Range=None):
    """Return the trapezoidal integral of the MAP reference BEP curve.

    Args:
        EbOverN0Range: Eb/N0 grid in dB to integrate over. Every value must be
            a point tabulated by ``get_MAP_BEP`` (integers -5..7). Defaults to
            ``np.arange(-5, 6, 1)``, i.e. -5..5 dB in 1 dB steps, which is the
            same grid ``compare_models()`` evaluates on.

    Returns:
        The integral of BEP over Eb/N0 (in dB) on that grid.
    """
    if EbOverN0Range is None:
        EbOverN0Range = np.arange(-5, 6, 1)
    BEPs = [get_MAP_BEP(EbOverN0dB) for EbOverN0dB in EbOverN0Range]
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return trapezoid(BEPs, EbOverN0Range)


if __name__ == '__main__':
    # Script entry point. Exactly one action runs; swap the active call to
    # change what the script does:
    #   main()                    -> train a new decoder and save it under models/
    #   plot()                    -> plot all models (per the original note)
    #   print(get_MAP_integral()) -> print the MAP reference integral
    compare_models()