#!/bin/python3
from pwn import args

import socket
from zlib import crc32
from struct import pack
from base64 import b64decode, b64encode
from pwn import xor

from Crypto.Cipher import AES
from Crypto.PublicKey import RSA
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
from Crypto.Random import get_random_bytes

# Connection and credential constants
# (host, port) of the remote server; every datagram is sent to this address
serverAddressPort = ("secure-ftp.dghack.fr", 4445)
# Maximum number of bytes read from the socket for one reply datagram
bufferSize = 2048
# Credentials that are AES-encrypted and sent in the AuthMessage
username = "GUEST_USER"
password = "GUEST_PASSWORD"

# PKCS#7-style padding helpers for AES encryption and decryption
BS = 16


def pad(s):
    """Return s extended to a multiple of BS units.

    The number of padding units n is BS - len(s) % BS (a full extra block
    when s is already aligned); each padding unit has the value n.

    s may be a str (padded with chr(n)) or bytes/bytearray (padded with the
    byte n); the result has the same kind as the input.
    """
    count = BS - len(s) % BS
    if isinstance(s, (bytes, bytearray)):
        return s + bytes([count]) * count
    return s + chr(count) * count


def unpad(s):
    """Strip the padding added by pad() and return the remaining prefix.

    Works on both str and bytes: indexing bytes yields an int directly,
    whereas indexing str yields a one-character string that needs ord().

    Raises IndexError if s is empty.
    """
    last = s[-1]
    count = last if isinstance(last, int) else ord(last)
    return s[0:-count]

# Client-side IPv4 datagram (UDP) socket; blocking calls give up after 0.5 s
UDPClientSocket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
UDPClientSocket.settimeout(0.5)


# Numeric identifier of every protocol message, keyed by message name
messagesType = {
    "RsaKeyMessage": 78,
    "PingMessage": 10,
    "AuthMessage": 4444,
    "GetFileMessage": 666,
    "GetFilesMessage": 45,
    "ErrorMessage": 1,
    "ConnectMessage": 1921,
    "SessionKeyMessage": 1337,
    "RsaKeyReply": 98,
    "ConnectReply": 4875,
    "PingReply": 11,
    "SessionKeyReply": 1338,
    "AuthReply": 6789,
    "GetFileReply": 7331,
    "GetFilesReply": 46,
}

# Parallel lists (same order as messagesType) used to map an identifier
# back to its message name
messagesKeyList = [*messagesType]
messagesValueList = [*messagesType.values()]


def _encode_field(value):
    """Serialize one str/bytes value as a 2-byte big-endian length plus its bytes."""
    raw = value if isinstance(value, bytes) else value.encode()
    # The length is taken from the encoded bytes (not the str) so non-ASCII
    # text is framed correctly; "!H" keeps the prefix at exactly 2 bytes,
    # which is what the reply parser reads for each field.
    return pack("!H", len(raw)) + raw


def _build_packet(idType, data):
    """Assemble header + size + content + CRC32 for one outgoing message."""
    if isinstance(data, (str, bytes)):
        content = _encode_field(data)
    elif isinstance(data, tuple):
        content = b"".join(_encode_field(el) for el in data)
    else:
        # any other kind of data is sent as an empty content
        content = b""

    # content size, big-endian, without leading zero bytes (0 to 4 bytes long)
    size = pack("!I", len(content)).lstrip(b"\x00")

    # 2-byte header: message type in the upper bits, byte-length of `size`
    # in the two lowest bits. Unknown message names get no header at all.
    header = b""
    if idType in messagesType:
        header = pack("!H", (messagesType[idType] << 2) + len(size))

    body = header + size + content
    return body + pack("!I", crc32(body))


def _parse_reply(msg):
    """Split a raw reply into (fields, type name, type id, content size, crc_ok)."""
    head = int.from_bytes(msg[0:2], byteorder="big")
    idMessage = head >> 2
    sizeLen = head & 0b11

    replyType = next(
        (name for name, value in messagesType.items() if value == idMessage),
        None,
    )
    if replyType is None:
        raise ValueError(f"unknown reply message id: {idMessage}")

    size = int.from_bytes(msg[2:2 + sizeLen], byteorder="big")
    remaining = msg[2 + sizeLen:2 + sizeLen + size]

    # the content is a sequence of (2-byte length, value) fields
    fields = []
    while remaining:
        fieldLen = int.from_bytes(remaining[:2], byteorder="big")
        fields.append(remaining[2:fieldLen + 2])
        remaining = remaining[fieldLen + 2:]

    # the last 4 bytes carry the CRC32 of everything before them
    crcOk = pack("!I", crc32(msg[:-4])) == msg[-4:]
    return fields, replyType, idMessage, size, crcOk


def sendMessage(idType, data, timeout = 4):
    """Send one message to the server, then receive and parse its reply.

    Parameters:
        idType:  name of the message to send (a key of messagesType).
        data:    a str or bytes (one field), a tuple of str/bytes (several
                 fields); anything else sends an empty content.
        timeout: accepted for backward compatibility but not used.
                 NOTE(review): the request is sent exactly once and no retry
                 is performed; the wait for the reply is governed by the
                 timeout configured on UDPClientSocket.

    Returns:
        (fields, replyType): the list of bytes fields found in the reply and
        the name of the reply's message type. For an "RsaKeyReply" the first
        field is replaced by its decoded form (see xor_rsa_key).

    Raises:
        ValueError:     the reply carries a message id not in messagesType.
        struct.error:   a field is too long for the 2-byte length prefix.
        socket.timeout: no reply arrived before the socket timed out.

    A CRC mismatch on the reply is only reported in the debug output.
    """
    print(f"\n\033[32mSend message: {idType}\033[39m")
    packet = _build_packet(idType, data)
    UDPClientSocket.sendto(packet, serverAddressPort)
    print("[DEBUG] sent: ", end="")
    print(packet)

    print("\n\033[32mReceive message... \033[39m", end="")
    msg = (UDPClientSocket.recvfrom(bufferSize))[0]
    print(msg)

    computedMessage, replyType, idMessage, size, crcOk = _parse_reply(msg)
    print("\033[32m" + replyType + "\033[39m")

    # the RSA key comes obfuscated; guard against a reply without any field
    if replyType == "RsaKeyReply" and computedMessage:
        computedMessage[0] = xor_rsa_key(computedMessage[0])

    is_valid = "Valid" if crcOk else "Invalid"

    # print debug informations
    print(f"[DEBUG] type:      {replyType} ({idMessage})")
    print(f"[DEBUG] Size:      {size}")
    print(f"[DEBUG] message:   {computedMessage}")
    print(f"[DEBUG] CRC:       {is_valid}")

    return computedMessage, replyType

def xor_rsa_key(encoded):
    """Base64-decode `encoded`, XOR it with the fixed pattern and return the result."""
    return xor(b64decode(encoded), b"ThisIsNotSoSecretPleaseChangeIt")

def encryptMessageB64(AES_key, enc, IV):
    """Encrypt `enc` with AES-CBC and return base64(IV + ciphertext).

    Parameters:
        AES_key: the AES key handed to AES.new.
        enc:     the plaintext, as str (encoded with str.encode()) or bytes.
        IV:      the initialization vector; it is also prepended to the
                 ciphertext before base64 encoding.

    Returns:
        bytes: the base64 encoding of IV followed by the ciphertext.
    """
    raw = enc.encode() if isinstance(enc, str) else bytes(enc)
    # Pad AFTER encoding: the cipher works on bytes, so the padding length
    # must be derived from the byte length (it differs from the character
    # count as soon as the text contains non-ASCII characters).
    fill = BS - len(raw) % BS
    cipher = AES.new(AES_key, AES.MODE_CBC, IV)
    return b64encode(IV + cipher.encrypt(raw + bytes([fill]) * fill))

def decryptMessageB64(AES_key, dec):
    """Decode a base64 blob of IV + AES-CBC ciphertext and decrypt it.

    The first 16 decoded bytes are used as the IV, the rest as ciphertext.

    Returns:
        (plaintext, IV): the decrypted bytes exactly as produced by the
        cipher (no padding is removed here) and the IV that was extracted.
    """
    blob = b64decode(dec)
    iv, ciphertext = blob[:16], blob[16:]
    plaintext = AES.new(AES_key, AES.MODE_CBC, iv).decrypt(ciphertext)
    return plaintext, iv

# Step 1: open a session. The reply's first field is reused below as the
# session identifier; its second field is printed as the first flag.
(sessionID, idType) = sendMessage("ConnectMessage", "CONNECT")
if idType != "ConnectReply":
    raise SystemExit(1)

# First flag
flag1 = sessionID[1].decode()
print(f"\n\033[33m[+] FLAG: {flag1}\033[39m\n")

# Step 2: get the server's public RSA key (already decoded by sendMessage)
(servPubKey , idType) = sendMessage("RsaKeyMessage", sessionID[0])
if idType != "RsaKeyReply":
    raise SystemExit(1)

# Step 3: generate a fresh 256-bit AES session key
AES_key = get_random_bytes(32)

# Encrypt the AES key with the server's public RSA key (PKCS#1 v1.5), then base64
keyPub = RSA.importKey(servPubKey[0])
cipher = Cipher_PKCS1_v1_5.new(keyPub)
cipher_text = cipher.encrypt(AES_key)
cipher_text = b64encode(cipher_text)

# Send two fields: the session id as received, and the base64 of the
# RSA-encrypted AES key
(msg, idType) = sendMessage("SessionKeyMessage", (sessionID[0], cipher_text))
# Check the reply type here as for every other step, instead of trying to
# decrypt whatever came back
if idType != "SessionKeyReply":
    raise SystemExit(1)
enc_salt = msg[0]

# Decrypt the reply; the IV it carried is reused to encrypt the credentials
# (the decrypted salt itself is not used further in this script)
(salt, IV) = decryptMessageB64(AES_key, enc_salt)

# Step 4: authenticate with session id, the still-encrypted salt, and the
# AES-encrypted credentials
username_enc = encryptMessageB64(AES_key, username, IV)
password_enc = encryptMessageB64(AES_key, password, IV)
(msg, idType) = sendMessage("AuthMessage", (sessionID[0], enc_salt, username_enc, password_enc))
if idType != "AuthReply":
    raise SystemExit(1)

# Second flag
flag2 = msg[1].decode()
print(f"\n\033[33m[+] FLAG: {flag2}\033[39m\n")
