import random
from matplotlib import pyplot as plt
import numpy as np
import os
#os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
##os.environ["OMP_NUM_THREADS"] = "16"  # Set to the number of logical processors
#os.environ["MKL_NUM_THREADS"] = "16"
import tensorflow as tf

#tf.config.threading.set_intra_op_parallelism_threads(16)
###tf.config.threading.set_inter_op_parallelism_threads(16)

# Goal: place (nest) one neural network inside another, larger neural network

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input, GaussianNoise
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential



class Parameters:
    """Block-code parameters shared by the encoder, data generators and decoders."""

    # Number of message bits per frame (one row of G per message bit).
    k = 8
    # Number of coded bits per frame (one column of G per coded bit).
    N = 16
    # k x N binary generator matrix; each string below is one row, written
    # left to right exactly as the bits appear in that row.
    G = np.array([
        [int(bit) for bit in row]
        for row in (
            "1111000000000000",
            "1100110000000000",
            "1111111100000000",
            "1100000011000000",
            "1111000011110000",
            "1100110011001100",
            "1010101010101010",
            "1111111111111111",
        )
    ])
    
def encode(message, Parameters) -> np.ndarray:
    '''
    Encode a binary message with the generator matrix ``Parameters.G``.

    Args:
        message: Array-like of message bits. A 1-D sequence encodes a single
            message; a 2-D array encodes one message per row.
        Parameters: Any object exposing the generator matrix as attribute
            ``G``. (The parameter name shadows the module-level class of the
            same name; it is kept so existing callers keep working.)

    Returns:
        The product ``message . G`` reduced modulo 2.
    '''
    generator = Parameters.G
    # Integer matrix product followed by mod 2 gives the GF(2) codeword.
    return np.dot(np.asarray(message), generator) % 2

def bpsk_modulation(bit_sequence: np.ndarray) -> np.ndarray:
    '''
    Map a bit sequence to BPSK symbols: bit 0 -> -1, bit 1 -> +1.

    Args:
        bit_sequence: Array-like of 0/1 values. Lists and tuples are accepted
            as well as arrays, matching what ``encode`` accepts.

    Returns:
        Array of the same shape holding ``2 * bit - 1`` for every bit.
    '''
    # np.asarray avoids list repetition semantics (2 * [0, 1] == [0, 1, 0, 1]).
    return 2 * np.asarray(bit_sequence) - 1

def add_noise(signal: np.ndarray, EbOverN0: float) -> np.ndarray:
    '''
    Add white Gaussian noise to a signal.

    Args:
        signal: Array-like of transmitted symbols, of any shape.
        EbOverN0: Signal-to-noise ratio on a linear (not dB) scale. The noise
            standard deviation is ``sqrt(1 / (2 * EbOverN0))``, so callers are
            expected to apply any code-rate scaling before calling.

    Returns:
        ``signal`` plus independent zero-mean Gaussian noise, same shape.

    Raises:
        ValueError: If ``EbOverN0`` is not strictly positive (the noise
            standard deviation would be infinite or undefined).
    '''
    if EbOverN0 <= 0:
        raise ValueError(f"EbOverN0 must be positive on a linear scale, got {EbOverN0!r}")

    signal = np.asarray(signal)
    noise_std = np.sqrt(1 / (2 * EbOverN0))

    # Draw one noise sample per element. Using the full shape (instead of
    # len(signal)) also handles batched, multi-dimensional signals; for 1-D
    # input the random stream consumed is the same as np.random.randn(n).
    noise = noise_std * np.random.standard_normal(signal.shape)
    return signal + noise

def build_decoder(input_dim, num_classes, activation, loss, num_neurons) -> Model:
    """Build and compile a dense decoder with a single hidden layer.

    Args:
        input_dim: Length of the input vector fed to the network.
        num_classes: Number of units in the output layer.
        activation: Activation of the output layer.
        loss: Loss passed to ``compile``.
        num_neurons: Width of the hidden layer.

    Returns:
        The compiled Keras model (Adam, learning rate 0.01, accuracy metric).
    """
    received = Input(shape=(input_dim,))
    # One ReLU hidden layer; its width is a hyper-parameter to experiment with.
    hidden = Dense(num_neurons, activation="relu")(received)
    decoded = Dense(num_classes, activation=activation)(hidden)

    decoder = Model(inputs=received, outputs=decoded)
    optimizer = Adam(learning_rate=0.01)
    decoder.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
    return decoder

def build_decoder_with_noise(input_dim, num_classes, activation, loss, num_neurons, noise_stddev=0.1):
    """Build a decoder wrapped in an outer model that injects Gaussian noise.

    The inner model is input -> ReLU hidden layer -> output layer. The outer
    model passes its input through a ``GaussianNoise`` layer and then through
    the inner model, so both share the same weights. (Keras documents
    ``GaussianNoise`` as active during training only.)

    Args:
        input_dim: Length of the input vector.
        num_classes: Number of units in the output layer.
        activation: Activation of the output layer.
        loss: Loss passed to ``compile`` of the outer model.
        num_neurons: Width of the hidden layer.
        noise_stddev: Standard deviation given to the ``GaussianNoise`` layer.

    Returns:
        Tuple ``(model, inner_model)``: the compiled outer model and the
        (uncompiled) inner decoder, returned so it can be extracted later.
    """
    # Inner decoder: input -> hidden -> output.
    decoder_in = Input(shape=(input_dim,), name="inner_input")
    decoder_hidden = Dense(num_neurons, activation="relu", name="hidden_layer")(decoder_in)
    decoder_out = Dense(num_classes, activation=activation, name="inner_output")(decoder_hidden)
    inner_model = Model(decoder_in, decoder_out, name="inner_decoder")

    # Outer wrapper: add noise to the input, then run the inner decoder on it.
    wrapper_in = Input(shape=(input_dim,), name="outer_input")
    noise_layer = GaussianNoise(noise_stddev, name="noise_layer")
    wrapper_out = inner_model(noise_layer(wrapper_in))
    model = Model(wrapper_in, wrapper_out, name="big_decoder")

    optimizer = Adam(learning_rate=0.01)
    model.compile(optimizer=optimizer, loss=loss, metrics=['accuracy'])
    return model, inner_model

def generate_training_data(num_frames, EbOverN0dB):
    """Generate noisy received frames and their one-hot message labels.

    For every frame a random ``Parameters.k``-bit message is encoded with
    ``encode``, BPSK-modulated and passed through ``add_noise``.

    Args:
        num_frames: Number of frames to generate.
        EbOverN0dB: Eb/N0 in dB used for every frame.

    Returns:
        Tuple ``(X, Y)``: ``X`` holds one received frame per row; ``Y`` holds
        the matching one-hot labels with ``2**Parameters.k`` classes, where the
        class index is the message read as a binary number (first bit is the
        most significant).
    """
    k = Parameters.k
    EbOverN0Linear = 10 ** (EbOverN0dB / 10)  # linear scale
    # Scale by the code rate k/N before handing the value to add_noise.
    # Loop-invariant, so it is computed once.
    noise_scaling = EbOverN0Linear * (k / Parameters.N)
    # Weights 2**(k-1) ... 2**0 turn a bit vector into its integer value
    # without a string round-trip.
    bit_weights = 1 << np.arange(k - 1, -1, -1)

    X_train = []
    labels = []
    for _ in range(num_frames):
        # Draw order (message, then noise) is kept per frame so seeded runs
        # consume the random stream exactly as before.
        message = np.random.randint(0, 2, size=k)
        modulated_message = bpsk_modulation(encode(message, Parameters))
        X_train.append(add_noise(modulated_message, noise_scaling))
        labels.append(int(message @ bit_weights))

    # One batched one-hot conversion instead of one call per frame.
    Y_train = to_categorical(labels, num_classes=2**k)
    return np.array(X_train), Y_train

def generate_test_data_onehot(num_frames, EbOverN0dB):
    """Generate noisy received frames and one-hot labels for evaluation.

    Test frames are produced by exactly the same procedure as training frames
    (random message -> encode -> BPSK -> noise -> one-hot label), so this
    delegates to ``generate_training_data`` instead of repeating its body.

    Args:
        num_frames: Number of frames to generate.
        EbOverN0dB: Eb/N0 in dB used for every frame.

    Returns:
        Tuple ``(X_test, Y_test)`` as returned by ``generate_training_data``.
    """
    return generate_training_data(num_frames, EbOverN0dB)


def evaluate_onehot_model(model, num_test_frames=50000, EbOverN0dB=2):
    """Estimate the bit error rate of a one-hot decoder at one Eb/N0 point.

    Args:
        model: Object with a ``predict`` method returning one score per class
            for every input row.
        num_test_frames: Number of freshly generated test frames.
        EbOverN0dB: Eb/N0 in dB; must be a value tabulated by ``get_MAP_BEP``.

    Returns:
        Tuple ``(nn_BER, NVE)``: the measured bit error rate, and that rate
        divided by the tabulated MAP bit error probability.

    Raises:
        ValueError: If ``num_test_frames`` is not positive.
    """
    if num_test_frames <= 0:
        raise ValueError("num_test_frames must be a positive integer")

    X_test, Y_test = generate_test_data_onehot(num_test_frames, EbOverN0dB)
    predicted_classes = np.argmax(model.predict(X_test), axis=1)
    true_classes = np.argmax(Y_test, axis=1)

    total_bits = num_test_frames * Parameters.k
    # The class index is the message as an integer, so the XOR has a 1 in every
    # bit position where the decoded message differs from the true one.
    # Counting bits over a uint64 byte view works for any k up to 64; casting
    # to uint8 would silently drop error bits above the 8th.
    differing_bits = (predicted_classes ^ true_classes).astype(np.uint64)
    nn_bit_errors = np.sum(np.unpackbits(differing_bits.view(np.uint8)))

    nn_BER = nn_bit_errors / total_bits
    MAP_BER = get_MAP_BEP(EbOverN0dB)
    NVE = nn_BER / MAP_BER
    print(f"\nNeural Decoder Bit Error Rate (BER): {nn_BER:.6f}")
    print(f"MAP Decoder Bit Error Rate (BER):    {MAP_BER:.6f}")
    print(f"Normalized Validation Error (NVE): {NVE:.2f}")
    return nn_BER, NVE

def get_MAP_BEP(EbOverN0dB):
    """Return the tabulated MAP-decoder bit error probability for an Eb/N0.

    Args:
        EbOverN0dB: Eb/N0 in dB. Only the integer points -5 ... 7 are
            tabulated (an equal float such as ``2.0`` also matches).

    Returns:
        The tabulated bit error probability.

    Raises:
        KeyError: If ``EbOverN0dB`` is not one of the tabulated points.
    """
    BEP = {
        -5: 0.36095,
        -4: 0.323275,
        -3: 0.2825,
        -2: 0.2329,
        -1: 0.179025,
        0: 0.1251,
        1: 0.0793125,
        2: 0.0419,
        3: 0.0197625,
        4: 0.006425,
        5: 0.0018625,
        6: 0.00038125,
        7: 6.25e-5,
    }
    # Earlier array-form table, kept for reference. Its entries from 0.36095
    # onward coincide with the -5 ... 5 dB values above; the first five have no
    # key in this table (presumably -10 ... -6 dB; unconfirmed).
    #   [0.4575875, 0.4447875, 0.4297125, 0.4108125,
    #    0.38385, 0.36095, 0.323275, 0.2825,
    #    0.2329, 0.179025, 0.1251, 0.0793125,
    #    0.0419, 0.0197625, 0.006425, 0.0018625]

    try:
        return BEP[EbOverN0dB]
    except KeyError:
        # Same exception type as before, but say what went wrong and what works.
        raise KeyError(
            f"No MAP bit error probability tabulated for Eb/N0 = {EbOverN0dB!r} dB; "
            f"supported values: {sorted(BEP)}"
        ) from None


def main():
    """Train a one-hot neural decoder, save it under ``new_models/`` and plot it.

    The model file name encodes the training settings. If a model with the
    same settings already exists, the user may choose to only plot that model;
    otherwise a numeric suffix is appended so the existing file is not
    overwritten, and a new model is trained.
    """
    # Clear the terminal; 'cls' only exists on Windows, 'clear' elsewhere.
    os.system('cls' if os.name == 'nt' else 'clear')

    input_dim = Parameters.N
    num_classes = 2**Parameters.k

    # parameters
    epochs = 100
    batch_size = 256
    num_frames = 500000
    EbOverN0dB = 1
    activation = 'softmax'
    num_neurons = 128
    loss = 'categorical_crossentropy'

    model_dir = "new_models"
    # Shared stem so the plain and the indexed file names cannot diverge.
    base_name = f"{model_dir}/onehot_decoder_{EbOverN0dB}_{epochs}_{batch_size}_{num_frames}_{loss}_{activation}_{num_neurons}"
    file_name = f"{base_name}.keras"

    print(file_name)
    if os.path.exists(file_name):
        print(f"Model already exists: {file_name}")
        print()
        answer = input("Plot models? (Y/N): ")
        if answer.lower() == 'y':
            plot(file_name)
            return

    # Pick the first unused "<stem>_<index>.keras" name.
    index = 0
    while os.path.exists(file_name):
        index += 1
        file_name = f"{base_name}_{index}.keras"

    # Make sure the output directory exists before the long data generation
    # and training, so a path problem surfaces immediately.
    os.makedirs(model_dir, exist_ok=True)

    # Create model
    model = build_decoder(input_dim, num_classes, activation, loss, num_neurons)
    model.summary()

    # Checkpoint the best (lowest training loss) model and decay the learning
    # rate when the training loss plateaus.
    callbacks = [
        ModelCheckpoint(file_name, monitor='loss', save_best_only=True, verbose=1),
        ReduceLROnPlateau(monitor='loss', factor=0.9, patience=4, verbose=1),
    ]

    # Generate training data
    X, Y = generate_training_data(num_frames=num_frames, EbOverN0dB=EbOverN0dB)

    # Train the model
    model.fit(X, Y, epochs=epochs, batch_size=batch_size, callbacks=callbacks, verbose=1)

    # Save the model
    # NOTE(review): this writes the final-epoch model to the same path the
    # checkpoint callback used for the best-loss model — confirm that
    # overwriting the checkpoint is intended.
    model.save(file_name)

    plot(file_name)

def plot(file_name = None):
    """Plot BER versus Eb/N0 for saved one-hot decoders against the MAP table.

    Args:
        file_name: Path of a single saved model to evaluate. When ``None``,
            every entry of the ``models/`` directory is considered.

    Each model is evaluated with ``evaluate_onehot_model`` at every tabulated
    Eb/N0 point; its curve is added to the figure and the sum of its NVE
    values is printed. Files whose name does not contain "onehot" are skipped.
    """
    EbOverN0dB = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7]
    if file_name is None:
        # NOTE(review): main() saves into 'new_models/' while this scans
        # 'models/' — confirm which directory is intended.
        model_path = 'models/'
        models = os.listdir(model_path)
    else:
        model_path = ''
        models = [file_name]

    _, ax = plt.subplots(1, 1, figsize=(10, 10))
    num_frames = 20000
    BEP_map = [get_MAP_BEP(snr_db) for snr_db in EbOverN0dB]
    ax.semilogy(EbOverN0dB, BEP_map, 'o-', label='ML Simulated')

    for model_file in models:
        # Only one-hot decoders have an evaluation routine. Skipping here
        # (before loading) avoids appending undefined or stale results from a
        # previous model for any other file.
        if "onehot" not in model_file:
            print(f"Skipping {model_file}: not a one-hot decoder model")
            continue

        path = os.path.join(model_path, model_file)
        model = tf.keras.models.load_model(path)

        BEP_nn = []
        NVE = []
        for snr_db in EbOverN0dB:
            BEP, NVE_nn = evaluate_onehot_model(model, num_test_frames=num_frames, EbOverN0dB=snr_db)
            BEP_nn.append(BEP)
            NVE.append(NVE_nn)

        ax.semilogy(EbOverN0dB, BEP_nn, 'o-', label=path)
        sum_NVE = sum(NVE)
        print(f"Model {model_file}")
        print(sum_NVE)

    ax.set_xlabel('Eb/N0 (dB)')
    ax.set_ylabel('BER')
    ax.legend()
    ax.grid(True)
    plt.show()


# Run the train-and-plot workflow only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()