import warnings
from flask import Flask,request
import pickle
import pandas as pd
import nltk
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfVectorizer
import sys
from typing import Dict, Optional
import re
import sys
from base64 import b64encode, b64decode
import email
import email.message
from Crypto.Signature import PKCS1_v1_5
from Crypto.Hash import SHA256
from Crypto.PublicKey import RSA
from Crypto.Util.asn1 import DerSequence, DerNull, DerOctetString, DerObjectId
import Crypto.Util
from Crypto.Util.number import bytes_to_long, long_to_bytes
import dns.resolver

# Silence every warning category for the whole process.
warnings.filterwarnings("ignore")

# Flask application object; the routes defined in this module register on it.
# Templates are looked up in the 'template' folder.
app = Flask(__name__, template_folder='template')  # still relative to module

@app.route('/predict', methods=['POST'])
def predict():
    """Classify the posted message as ``"ham"`` or ``"spam"``.

    Expects a JSON body of the form ``{"message": "<text>"}``. The text is
    normalised (placeholder tokens for e-mail addresses, URLs, currency
    symbols, phone numbers and numbers; punctuation dropped; lower-cased),
    English stop words are removed, the remaining terms are stemmed, and the
    result is vectorised with the pickled TF-IDF model and scored by the
    pickled classifier.

    Returns:
        ``"ham"`` when the classifier predicts label ``0``, otherwise
        ``"spam"``. A ``400`` response is returned when the body is not a
        JSON object carrying a string ``message``.
    """
    payload = request.get_json(silent=True)
    print("Request", payload)
    message = payload.get("message") if isinstance(payload, dict) else None
    if not isinstance(message, str):
        return "JSON body must contain a string 'message' field", 400

    # NOTE: pickle.load executes code embedded in the file, so these model
    # files must only ever come from a trusted source.
    with open("./model/spam_model.pkl", "rb") as model_file:
        model = pickle.load(model_file)
    with open("./model/tfidf_model.pkl", "rb") as tfidf_file:
        tfidf_model = pickle.load(tfidf_file)

    # Replace characters outside the Basic Multilingual Plane before printing
    # so consoles that cannot render them do not fail.
    non_bmp_map = dict.fromkeys(range(0x10000, sys.maxunicode + 1), 0xfffd)
    print(message.translate(non_bmp_map))

    data = pd.DataFrame({'message': [message]})

    # Ordered (pattern, replacement) pairs; the order matters because later
    # patterns operate on the output of earlier ones.
    substitutions = (
        (r'^.+@[^\.].*\.[a-z]{2,}$', 'emailaddress'),
        (r'^http\://[a-zA-Z0-9\-\.]+\.[a-zA-Z]{2,3}(/\S*)?$', 'webaddress'),
        (r'£|\$', 'money-symbol'),
        (r'^\(?[\d]{3}\)?[\s-]?[\d]{3}[\s-]?[\d]{4}$', 'phone-number'),
        (r'\d+(\.\d+)?', 'number'),
        (r'[^\w\d\s]', ' '),
        (r'\s+', ' '),
        (r'^\s+|\s*?$', ' '),
    )
    cleaned = data["message"]
    for pattern, replacement in substitutions:
        # regex=True is required: pandas >= 2.0 defaults to literal matching,
        # which would leave every one of these patterns unmatched.
        cleaned = cleaned.str.replace(pattern, replacement, regex=True)
    data["message"] = cleaned.str.lower()
    print("data", data.to_string())

    stop_words = set(stopwords.words('english'))
    data["message"] = data["message"].fillna("").apply(lambda x: ' '.join(
        term for term in x.split() if term not in stop_words))

    ss = nltk.SnowballStemmer("english")
    data["message"] = data["message"].fillna("").apply(
        lambda x: ' '.join(ss.stem(term) for term in x.split()))
    print("message formed", data.to_string())

    tfidf_vec = tfidf_model.transform(data["message"])
    tfidf_data = pd.DataFrame(tfidf_vec.toarray())
    my_prediction = model.predict(tfidf_data)
    print("my_prediction", my_prediction)

    return "ham" if my_prediction[0] == 0 else "spam"

def get_public_key(domain: str, selector: str) -> Optional[RSA.RsaKey]:
    """Fetch the DKIM public key published for *selector* under *domain*.

    Queries the TXT record ``<selector>._domainkey.<domain>.`` and imports
    the base64 value of its ``p=`` tag as an RSA key.

    Args:
        domain: Signing domain (the ``d=`` tag of the signature).
        selector: Key selector (the ``s=`` tag of the signature).

    Returns:
        The imported RSA public key, or ``None`` when the DNS lookup fails,
        the record carries no (or an empty) ``p=`` tag, or the key data
        cannot be decoded.
    """
    record_name = "{}._domainkey.{}.".format(selector, domain)

    # dnspython 2.x renamed query() to resolve(); use whichever is available.
    lookup = getattr(dns.resolver, "resolve", None) or dns.resolver.query
    try:
        answer = lookup(record_name, "TXT")
    except dns.exception.DNSException as error:
        print("\nDNS lookup failed", error)
        return None

    # answer.rrset is the TXT RRset itself; response.answer[0] would be the
    # CNAME RRset when the selector is delegated through a CNAME.
    dns_response = answer.rrset.to_text()

    # Long TXT values are split into several quoted chunks, hence the quote
    # and space in the character class; '=' keeps the base64 padding.
    match = re.search(r'\bp=([\w/+=" ]*)', dns_response)
    encoded_key = match.group(1).replace(" ", "").replace("\"", "") if match else ""
    if not encoded_key:
        return None

    try:
        return RSA.importKey(b64decode(encoded_key))
    except (ValueError, IndexError, TypeError) as error:
        # Covers bad base64 (binascii.Error is a ValueError) and key data
        # that RSA.importKey cannot parse.
        print("\npublic key could not be decoded", error)
        return None


def parse_dkim_header(dkim_header: str) -> Optional[Dict[str, str]]:
    """Split a DKIM-Signature header value into its ``tag=value`` pairs.

    Tags are separated by ``;``. Tag names are stripped of surrounding
    whitespace and all whitespace (including folding) is removed from the
    values. Empty segments, such as the one produced by a trailing ``;``,
    are ignored.

    Args:
        dkim_header: Raw value of the DKIM-Signature header; may be ``None``
            or empty when the message carries no such header.

    Returns:
        A dict mapping tag name to value, or ``None`` when the header is
        missing, contains no tags, or has a segment without ``=``.
    """
    if not dkim_header:
        return None

    parameter = {}
    for part in dkim_header.split(";"):
        if not part.strip():
            # A trailing ';' is legal and yields an empty segment.
            continue
        key, separator, value = part.partition("=")
        if not separator:
            return None
        parameter[key.strip()] = re.sub(r'\s+', "", value)
    return parameter or None


def hash_headers(mail: email.message.Message, header_to_hash: str) -> SHA256.SHA256Hash:
    """Build the SHA-256 hash of the header data covered by a DKIM signature.

    Headers are serialised as ``name:value`` with the name lower-cased,
    continuation lines unfolded, whitespace runs collapsed to one space and
    surrounding whitespace removed, each terminated by CRLF. The
    DKIM-Signature header is appended last with the value of its ``b=`` tag
    emptied and without a trailing CRLF.

    Args:
        mail: Parsed message whose headers are hashed.
        header_to_hash: Colon-separated header names (the ``h=`` tag).

    Returns:
        A SHA-256 hash object over the serialised headers.

    Raises:
        ValueError: If the message has no DKIM-Signature header.
    """
    def _relaxed_value(value) -> str:
        # Unfold, collapse whitespace runs and trim the header value.
        unfolded = re.sub(r'(\n|\r)', "", str(value))
        return re.sub(r'\s+', " ", unfolded).strip()

    dkim_header = mail.get("DKIM-Signature")
    if dkim_header is None:
        raise ValueError("message has no DKIM-Signature header")

    # Per header name, the instances not yet consumed. A name listed several
    # times in h= consumes instances from the bottom of the message upwards;
    # names with no remaining instance contribute nothing to the hash.
    remaining = {}
    lines = []
    for header in header_to_hash.split(":"):
        name = header.strip().lower()
        if not name:
            continue
        if name not in remaining:
            remaining[name] = list(mail.get_all(name) or [])
        if not remaining[name]:
            continue
        lines.append(name + ":" + _relaxed_value(remaining[name].pop()) + "\r\n")

    # Blank only the b= tag of the signature header itself; anchoring on the
    # tag separator keeps "b=" sequences inside other values untouched.
    signature_field = re.sub(r'(^|;\s*)b=[^;]*', r'\1b=', _relaxed_value(dkim_header))
    lines.append("dkim-signature:" + signature_field)

    return SHA256.new("".join(lines).encode())


def pkcs1_v1_5_encode(msg_hash: SHA256.SHA256Hash, emLen: int) -> bytes:
    """Build the EMSA-PKCS1-v1_5 encoded message for *msg_hash*.

    The result is ``00 01 FF..FF 00 || DigestInfo`` padded to exactly
    *emLen* bytes, where DigestInfo is the DER structure holding the hash
    algorithm identifier (with NULL parameters) and the digest.

    Args:
        msg_hash: Hash object exposing ``oid`` and ``digest()``.
        emLen: Required length of the encoded message in bytes.

    Returns:
        The encoded message, *emLen* bytes long.

    Raises:
        TypeError: If *emLen* is too small to hold the DigestInfo plus the
            minimum padding.
    """
    digest_bytes = msg_hash.digest()

    # AlgorithmIdentifier: the hash OID followed by NULL parameters.
    digestAlgo = DerSequence([DerObjectId(msg_hash.oid).encode()])
    digestAlgo.append(DerNull().encode())

    digestInfo = DerSequence([
        digestAlgo.encode(),
        DerOctetString(digest_bytes).encode(),
    ]).encode()

    # We need at least 11 bytes for the remaining data: 3 fixed bytes and
    # at least 8 bytes of padding.
    if emLen < len(digestInfo) + 11:
        # Report the length of the raw digest, not of its DER wrapper object.
        raise TypeError("Selected hash algorith has a too long digest (%d bytes)." % len(digest_bytes))

    PS = b'\xFF' * (emLen - len(digestInfo) - 3)
    return b'\x00\x01' + PS + b'\x00' + digestInfo


def verify_signature(hashed_header: SHA256.SHA256Hash, signature: bytes, public_key: RSA.RsaKey) -> bool:
    """Check an RSA PKCS#1 v1.5 signature against *hashed_header*.

    The signature is raised to the public exponent modulo ``n`` and the
    result is compared with the expected EMSA-PKCS1-v1_5 encoding of the
    hash.

    Args:
        hashed_header: Hash object over the signed header data.
        signature: Raw signature bytes (the decoded ``b=`` tag).
        public_key: RSA public key providing ``n`` and ``e``.

    Returns:
        ``True`` when the signature matches, ``False`` otherwise (including
        when the signature has the wrong length or is not below ``n``).
    """
    modBits = Crypto.Util.number.size(public_key.n)
    # Length of the modulus in bytes, rounded up so keys whose bit length is
    # not a multiple of 8 are handled too.
    emLen = (modBits + 7) // 8

    # A well-formed signature is exactly as long as the modulus.
    if len(signature) != emLen:
        return False

    signature_long = bytes_to_long(signature)
    # The signature representative must lie in [0, n - 1].
    if signature_long >= public_key.n:
        return False

    expected_message_int = pow(signature_long, public_key.e, public_key.n)
    expected_message = long_to_bytes(expected_message_int, emLen)

    padded_hash = pkcs1_v1_5_encode(hashed_header, emLen)
    return padded_hash == expected_message

@app.route('/verify-dkim', methods=['POST'])
def VerifyEmailDkim():
    """Verify the DKIM header signature of the posted e-mail.

    Expects a JSON body of the form ``{"mail": "<raw message text>"}``. The
    DKIM-Signature header is parsed, the signer's public key is fetched via
    DNS, and the signature in the ``b=`` tag is checked against the hash of
    the signed headers.

    NOTE: only the header signature is checked; the body hash (``bh=`` tag)
    is never compared with the message body.

    Returns:
        ``"True"`` when the signature verifies, ``"False"`` in every other
        case (missing or malformed input, no public key, bad signature).
    """
    print("Inside VerifyEmailDkim")

    payload = request.get_json(silent=True)
    raw_mail = payload.get("mail") if isinstance(payload, dict) else None
    if not isinstance(raw_mail, str):
        print("\nmail missing from request")
        return "False"

    # Parsing the text directly avoids failing on non-ASCII characters.
    mail = email.message_from_string(raw_mail)

    dkim_header = mail.get("DKIM-Signature")
    dkim_parameter = parse_dkim_header(dkim_header) if dkim_header else None
    if not dkim_parameter:
        print("\ndkim_parameter missing")
        return "False"

    # Every tag read below must be present in the signature.
    if any(tag not in dkim_parameter for tag in ('d', 's', 'h', 'b')):
        print("\ndkim_parameter incomplete")
        return "False"

    public_key = get_public_key(dkim_parameter['d'], dkim_parameter['s'])
    if not public_key:
        print("\npublic key not found")
        return "False"

    hashed_header = hash_headers(mail, dkim_parameter['h'])

    try:
        signature = b64decode(dkim_parameter['b'])
    except ValueError:
        # binascii.Error (bad padding) is a ValueError subclass.
        print("\nsignature is not valid base64")
        return "False"

    if verify_signature(hashed_header, signature, public_key):
        print("\nsignature is valid")
        return "True"

    print("\nsignature is NOOOOT valid")
    return "False"
# Script entry point: when the module is run directly, start the Flask
# server bound to all network interfaces (0.0.0.0) on port 5000.
if __name__ == '__main__':
    app.run(host='0.0.0.0',port=5000)
