__author__ = 'DarkWeb'

import string
import time
import re
import hashlib
import base64
import io
import configparser
import json
import keras
import cv2
import numpy as np
from keras.preprocessing import image
from keras.applications.imagenet_utils import preprocess_input
from keras.models import Model
from datetime import datetime, timedelta
from lxml import html as lxml
from selenium.webdriver.common.by import By
from Crypto.Cipher import AES
from Crypto.Util.Padding import pad, unpad
from PIL import Image
from urllib.parse import urlsplit, urljoin


def generate_aes_key():
    """Derive the 16-byte AES key from the ``[Encryption] secret`` setting.

    The secret is read from ``../../setup.ini`` (a path relative to the current
    working directory), UTF-8 encoded and passed through PBKDF2-HMAC-SHA256.
    The first 16 bytes of the derived key are returned.

    NOTE(review): PBKDF2 runs with an empty salt and a single iteration, which
    adds essentially no hardening. Changing either one changes the key and makes
    previously encrypted data undecryptable, so any change needs a migration.
    """
    settings = configparser.ConfigParser()
    settings.read('../../setup.ini')

    password = settings.get('Encryption', 'secret').encode("utf-8")
    derived_key = hashlib.pbkdf2_hmac('sha256', password, b'', 1)

    return derived_key[:16]


# Padding block size passed to pad()/unpad() by aes_encryption/aes_decryption.
# NOTE(review): the ciphertext depends on this value, so changing it would break
# decryption of anything already stored.
BLOCK_SIZE = 32
# Key derived once at import time from setup.ini via generate_aes_key().
aes_key = generate_aes_key()
# Separate cipher objects for encryption and decryption, both in ECB mode.
# NOTE(review): ECB maps identical plaintext blocks to identical ciphertext
# blocks and so leaks structure. Switching modes would make previously stored
# ciphertext undecryptable, so it requires a data migration.
encryptCipher = AES.new(aes_key, AES.MODE_ECB)
decryptCipher = AES.new(aes_key, AES.MODE_ECB)

# ImageNet-weighted ResNet50, built at import time (heavy module-import side effect).
model = keras.applications.ResNet50(weights='imagenet', include_top=True)
# Sub-model that returns the activations of ResNet50's 'avg_pool' layer rather
# than the model's final output; used by extract_hidden_layer_output.
feat_extractor = Model(inputs=model.input, outputs=model.get_layer('avg_pool').output)

# Shared SIFT detector used by extract_keypoints.
sift = cv2.SIFT_create(
    nfeatures=0,                # Number of features, 0 for unlimited
    nOctaveLayers=3,            # Number of layers per octave
    contrastThreshold=0.09,     # Contrast threshold
    edgeThreshold=10,           # Edge threshold
    sigma=1.6                   # Initial Gaussian blur sigma
)


def generate_image_hash(image_string):
    """Return the hex SHA-256 digest of the bytes encoded in a base64 string.

    Args:
        image_string: base64 text (str) of the raw image bytes.

    Returns:
        64-character lowercase hexadecimal digest.
    """
    decoded = base64.b64decode(bytes(image_string, encoding='utf-8'))
    digest = hashlib.sha256(decoded)
    return digest.hexdigest()


def extract_hidden_layer_output(image_string):
    """Embed a base64-encoded image with the module-level ResNet50 feature extractor.

    The image is decoded and converted to RGB. It is then resized
    (nearest-neighbour) to the model's input height/width, preprocessed with
    ``preprocess_input`` and fed through ``feat_extractor``.

    Args:
        image_string: base64 text (str) of an image file PIL can open.

    Returns:
        JSON string holding the feature vector of the single image as a list.
    """
    decoded = base64.b64decode(bytes(image_string, encoding='utf-8'))
    rgb_image = Image.open(io.BytesIO(decoded)).convert('RGB')

    # model.input_shape is (batch, height, width, channels); keep only (height, width).
    target_size = model.input_shape[1:3]
    pixels = image.smart_resize(image.img_to_array(rgb_image), size=target_size, interpolation='nearest')

    # The model expects a batch dimension, so wrap the single image in a batch of one.
    batch = preprocess_input(np.expand_dims(pixels, axis=0))

    features = feat_extractor.predict(batch)[0]
    return json.dumps(features.tolist())
    

def extract_keypoints(image_string):
    """Detect SIFT keypoints and descriptors in a base64-encoded image.

    The image is decoded as grayscale and run through the module-level ``sift``
    detector.

    Args:
        image_string: base64 text (str) of an encoded image (PNG, JPEG, ...).

    Returns:
        ``(keypoints_json, descriptors_json)``. ``keypoints_json`` is a JSON
        list of dicts produced by ``wrap_keypoints``. ``descriptors_json`` is
        the descriptor matrix as nested JSON lists. Returns ``(None, None)``
        when the data is empty, cannot be decoded as an image, or yields no
        keypoints.
    """
    image_bytes = base64.b64decode(bytes(image_string, encoding='utf-8'))
    if not image_bytes:
        # cv2.imdecode rejects an empty buffer with an error; treat it like
        # any other undecodable image.
        return None, None

    image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8)
    img = cv2.imdecode(image_array, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # imdecode signals undecodable data by returning None. Passing that on
        # to SIFT would raise an opaque cv2.error.
        return None, None

    keypoints, descriptors = sift.detectAndCompute(img, None)

    if len(keypoints) == 0 or descriptors is None:
        return None, None

    return json.dumps(wrap_keypoints(keypoints)), json.dumps(descriptors.tolist())


def wrap_keypoints(keypoints):
    """Convert keypoint objects into JSON-serializable dicts.

    Args:
        keypoints: iterable of objects exposing ``pt``, ``size``, ``angle``,
            ``octave``, ``response`` and ``class_id`` attributes (e.g. the
            keypoints returned by ``sift.detectAndCompute``).

    Returns:
        List with one dict per keypoint, keyed by those attribute names. This
        is the inverse of ``unwrap_keypoints``.
    """
    return [
        {
            'pt': kp.pt,
            'size': kp.size,
            'angle': kp.angle,
            'octave': kp.octave,
            'response': kp.response,
            'class_id': kp.class_id,
        }
        for kp in keypoints
    ]


def unwrap_keypoints(keypoints_list):
    """Rebuild ``cv2.KeyPoint`` objects from dicts in the ``wrap_keypoints`` format.

    Args:
        keypoints_list: iterable of dicts with keys ``pt`` (indexable x/y pair),
            ``size``, ``angle``, ``octave``, ``response`` and ``class_id``.

    Returns:
        Tuple of ``cv2.KeyPoint`` in input order (empty tuple for empty input).
    """
    return tuple(
        cv2.KeyPoint(
            x=entry['pt'][0],
            y=entry['pt'][1],
            size=entry['size'],
            angle=entry['angle'],
            octave=entry['octave'],
            response=entry['response'],
            class_id=entry['class_id'],
        )
        for entry in keypoints_list
    )


def cleanText(originalText):
    """Remove every character outside a fixed ASCII whitelist from each string.

    Allowed characters are ASCII letters, digits, the space character and the
    punctuation ``_/&$#@+-*=:;.,?!{}[]()%`~^|<>``. Everything else is dropped,
    including quotes, backslashes, tabs/newlines and any non-ASCII character.

    Args:
        originalText: mutable sequence (e.g. a list) of strings. It is
            modified in place.

    Returns:
        The same sequence object, with each element cleaned.
    """
    # One explicit literal, so that no character depends on implicit string
    # concatenation. A frozenset gives O(1) membership tests.
    safe_chars = frozenset(
        string.ascii_letters + string.digits + " _/&$#@+-*=:;.,?!{}[]()%`~^|<>"
    )

    for index, text in enumerate(originalText):
        originalText[index] = ''.join(char for char in text if char in safe_chars)

    return originalText


def cleanLink(originalLink):
    """Return ``originalLink`` with everything except ASCII letters and digits removed."""
    allowed = string.ascii_letters + string.digits
    kept = [ch for ch in originalLink if ch in allowed]
    return ''.join(kept)


def organizeTopics(forum, nm, board, author, topic, views, posts, href, addDate, image_author):
    """Build one comma-separated topic row per listing entry.

    Args:
        forum: forum name (column 0).
        nm: number of rows to build. ``author`` and ``topic`` (and every
            non-empty optional sequence below) must have at least ``nm`` items.
        board: board name shared by every row (column 1).
        author, topic: per-row strings (columns 2 and 3).
        views, posts, href: per-row strings, or an empty sequence to write "-1"
            (columns 4-6).
        addDate, image_author: per-row values passed through ``str()``, or an
            empty sequence to write "-1" (columns 7 and 9).

    Returns:
        List of ``nm`` strings, each with 20 comma-separated columns. Column 8
        is the insertion timestamp, shared by all rows. Columns 10-19 are
        always "-1" placeholders.
    """
    current_time = datetime.now()
    # NOTE(review): "%I" is a 12-hour clock with no AM/PM marker, so morning and
    # evening insert times look the same. Left unchanged because the consumer
    # of this column is not visible here.
    inserted_at = current_time.strftime("%m/%d/%Y") + " " + current_time.strftime("%I:%M:%S")

    # Columns 10-19: name_user, status_user, reputation_user, interest_user,
    # signature_user, content_post, feedback_post, dateadded_post, image_post,
    # classification_post. None of these is known at topic-listing time.
    placeholders = ["-1"] * 10

    rows = []
    for n in range(nm):
        fields = [
            forum,                                                      # 0
            board,                                                      # 1 board_topic
            author[n],                                                  # 2
            topic[n],                                                   # 3 topic_title
            "-1" if len(views) == 0 else views[n],                      # 4 views_topic
            "-1" if len(posts) == 0 else posts[n],                      # 5 posts_topic
            "-1" if len(href) == 0 else href[n],                        # 6 href_topic
            "-1" if len(addDate) == 0 else str(addDate[n]),             # 7 dateadded_topic
            inserted_at,                                                # 8 dateinserted_topic
            "-1" if len(image_author) == 0 else str(image_author[n]),   # 9 image_user
        ]
        rows.append(",".join(fields + placeholders))

    return rows


# Per-character substitutions applied by cleanString; None deletes the character.
_CLEAN_STRING_TABLE = str.maketrans({
    ",": None,        # commas (presumably so values fit in comma-joined rows)
    "\n": None,       # newlines
    "\t": None,       # tabs
    "\r": None,       # carriage returns
    "'": "^",         # single quotes / apostrophes become carets
    "\u00bb": None,   # "»" (right-pointing double angle quote, used as an arrow)
    "!": None,        # exclamation points
    ";": None,        # semicolons
})


def cleanString(originalString):
    """Return ``originalString`` made safe for the comma-separated row format.

    Commas, newlines, tabs, carriage returns, "»", "!" and ";" are removed,
    and every single quote (') is replaced by a caret (^).

    Args:
        originalString: text to clean.

    Returns:
        The cleaned string.
    """
    # A single translate pass replaces the chain of .replace() calls. Every
    # substitution is single-character, so the result does not depend on order.
    return originalString.translate(_CLEAN_STRING_TABLE)


def cleanNumbers(inputString):
    """Return ``inputString`` keeping only decimal digits and '.' characters.

    For example, "1,234.56 views" becomes "1234.56".
    """
    non_numeric = r'[^\d.]+'
    return re.sub(non_numeric, '', inputString)


def aes_encryption(data_bytes):
    """Pad ``data_bytes`` to a multiple of BLOCK_SIZE and encrypt it with ``encryptCipher``.

    NOTE(review): ``encryptCipher`` is the module-level AES cipher created in
    ECB mode; see the note where it is defined.

    Args:
        data_bytes: plaintext bytes.

    Returns:
        Ciphertext bytes.
    """
    padded = pad(data_bytes, BLOCK_SIZE)
    return encryptCipher.encrypt(padded)


def aes_decryption(data_bytes):
    """Decrypt ``data_bytes`` with ``decryptCipher`` and strip the BLOCK_SIZE padding.

    This is the inverse of ``aes_encryption``. Errors from decryption or
    unpadding are not caught here; they propagate to the caller.

    Args:
        data_bytes: ciphertext bytes.

    Returns:
        Plaintext bytes.
    """
    padded_plaintext = decryptCipher.decrypt(data_bytes)
    return unpad(padded_plaintext, BLOCK_SIZE)


def encrypt_encode_image_to_base64(driver, xpath):
    """Screenshot the element at ``xpath``, encrypt the PNG and return it as base64 text.

    This is best effort: any exception (element not found, screenshot or
    encryption failure) is printed and ``None`` is returned.

    Args:
        driver: Selenium WebDriver displaying the page.
        xpath: XPath of the element to capture.

    Returns:
        Base64 (str) of the AES-encrypted PNG bytes, or ``None`` on failure.
    """
    try:
        element = driver.find_element(by=By.XPATH, value=xpath)
        ciphertext = aes_encryption(element.screenshot_as_png)
        return base64.b64encode(ciphertext).decode('utf-8')
    except Exception as e:
        print(e)

    return None


def decode_decrypt_image_in_base64(image_string):
    """Reverse ``encrypt_encode_image_to_base64``: base64 ciphertext in, base64 plaintext out.

    This is best effort: any exception (bad base64, wrong key or padding,
    wrong input type) is printed and ``None`` is returned.

    Args:
        image_string: base64 text (str) of AES-encrypted image bytes.

    Returns:
        Base64 (str) of the decrypted image bytes, or ``None`` on failure.
    """
    try:
        ciphertext = base64.b64decode(bytes(image_string, encoding='utf-8'))
        plaintext = aes_decryption(ciphertext)
        return base64.b64encode(plaintext).decode('utf-8')
    except Exception as e:
        print(e)

    return None


def replace_image_sources(driver, html_content):
    """Inline every ``<img>`` of ``html_content`` as an encrypted base64 PNG screenshot.

    ``<source>`` elements inside ``<picture>`` are removed, leaving the
    ``<img>`` element. Each ``<img>`` is located in the live browser by the
    XPath of its position in the parsed document. It is then screenshotted
    through ``encrypt_encode_image_to_base64``, and its ``src`` is replaced by
    a ``data:image/png;base64,...`` URI. Images that cannot be captured are
    removed from the document.

    Args:
        driver: Selenium WebDriver currently displaying ``html_content``.
        html_content: HTML source of the page.

    Returns:
        The modified HTML as a str.
    """
    tree = lxml.fromstring(html_content)
    root_tree = tree.getroottree()

    # Compute every XPath before touching the tree. The paths must describe the
    # unmodified document (which is what the browser holds), and removing an
    # <img> would shift the positional index of any later sibling <img>.
    img_tags = tree.findall('.//img')
    img_xpaths = [root_tree.getpath(img_tag) for img_tag in img_tags]

    for picture_tag in tree.findall('.//picture'):
        for source_tag in picture_tag.findall('.//source'):
            # drop_tree() keeps the element's tail text. It also works when
            # <source> is nested deeper than a direct child of <picture>.
            source_tag.drop_tree()

    for img_tag, img_xpath in zip(img_tags, img_xpaths):
        string_image = encrypt_encode_image_to_base64(driver, img_xpath)

        if string_image:
            img_tag.set('src', f'data:image/png;base64,{string_image}')
        else:
            # drop_tree() rather than parent.remove(): the latter would also
            # discard any text that follows the <img>.
            img_tag.drop_tree()

    return lxml.tostring(tree, encoding='utf-8').decode('utf-8')


def cleanHTML(driver, html):
    """Return a copy of a page's HTML with images inlined and active content stripped.

    First every ``<img>`` is handled by ``replace_image_sources``: it is
    screenshotted through ``driver`` and inlined as an encrypted base64 data
    URI, or removed if capture fails. Regex passes then remove:

    - ``<svg>`` and ``<canvas>``;
    - ``<object>`` elements mentioning an image format;
    - ``<script>`` and ``<iframe>``;
    - JavaScript ``<object>``s, ``mayscript`` applets and ``scriptable`` embeds;
    - ``<div>``s whose inline style sets a background-image.

    Args:
        driver: Selenium WebDriver currently displaying the page.
        html: HTML source of that page.

    Returns:
        The cleaned HTML as a str.
    """
    clean_html = replace_image_sources(driver, html)

    formats = [
        "jpg", "jpeg", "jfif", "pjpeg", "pjp",
        "png", "apng", "svg", "bmp", "gif",
        "avif", "webp", "ico", "cur", "tiff"
    ]

    # remove images
    clean_html = re.sub(r"<svg[\s\S]*?svg>", "", clean_html)
    for fmat in formats:
        # Both literals are raw strings: "\s" and "\S" are invalid escape
        # sequences in a plain string literal.
        clean_html = re.sub(r"<object.*" + fmat + r"[\s\S]*?object>", "", clean_html)
    clean_html = re.sub(r"<canvas[\s\S]*?canvas>", "", clean_html)

    # remove JavaScript
    clean_html = re.sub(r"<script[\s\S]*?script>", "", clean_html)
    clean_html = re.sub(r"<iframe[\s\S]*?iframe>", "", clean_html)
    clean_html = re.sub(r"<object.*javascript[\s\S]*?object>", "", clean_html)
    # The HTML element is spelled <applet>.
    clean_html = re.sub(r"<applet.*mayscript[\s\S]*?applet>", "", clean_html)
    clean_html = re.sub(r"<embed.*scriptable[\s\S]*?embed>", "", clean_html)

    # image and JavaScript
    clean_html = re.sub(r"<div[^>]*style=\"[^\"]*background-image[\s\S]*?div>", "", clean_html)

    return clean_html


def get_relative_url(target_url):
    """Return the path of ``target_url``, followed by ``?query`` when it has one.

    Absolute URLs lose their scheme, host and fragment. Relative URLs are
    resolved against a placeholder root first, so "a/b" becomes "/a/b".
    """
    # Joining onto a throwaway absolute base lets urlsplit treat absolute and
    # relative inputs the same way.
    parts = urlsplit(urljoin("http://dummybaseurl.com/", target_url))

    if not parts.query:
        return parts.path

    return f"{parts.path}?{parts.query}"