#!/bin/python3
from pwn import args

import socket
from zlib import crc32
from struct import pack
from base64 import b64decode, b64encode
from pwn import xor

from Crypto.Cipher import AES
from Crypto.PublicKey import RSA
from Crypto.Cipher import PKCS1_v1_5 as Cipher_PKCS1_v1_5
from Crypto.Random import get_random_bytes

# constants
# Server endpoint (host, UDP port) and the largest datagram read per recvfrom()
serverAddressPort = "secure-ftp.dghack.fr", 4445
bufferSize = 2 * 1024
# Credentials that are encrypted and sent in the AuthMessage
username, password = "GUEST_USER", "GUEST_PASSWORD"

# PKCS#7-style padding for AES encryption and decryption
BS = 16


def pad(s):
    """Pad *s* (str or bytes) up to a multiple of BS.

    Every padding unit holds the pad length itself, and a full block of padding
    is appended when len(s) is already a multiple of BS. For str the length is
    counted in characters, so encode first when the result must be aligned in
    bytes (e.g. before AES encryption of non-ASCII text).
    """
    n = BS - len(s) % BS
    if isinstance(s, (bytes, bytearray)):
        return s + bytes([n]) * n
    return s + chr(n) * n


def unpad(s):
    """Remove the padding added by pad(); accepts str or bytes.

    The last element gives the number of units to drop. The padding is not
    validated, and an empty input raises IndexError.
    """
    last = s[-1]
    # Indexing bytes yields an int, indexing str yields a 1-char str.
    n = last if isinstance(last, int) else ord(last)
    return s[0:-n]

# Single IPv4 UDP socket shared by every request. Blocking operations on it
# raise socket.timeout after 0.5 seconds.
UDPClientSocket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
UDPClientSocket.settimeout(0.5)

# Numeric protocol id of every message type, keyed by message name
messagesType = dict(
    RsaKeyMessage=78,
    PingMessage=10,
    AuthMessage=4444,
    GetFileMessage=666,
    GetFilesMessage=45,
    ErrorMessage=1,
    ConnectMessage=1921,
    SessionKeyMessage=1337,
    RsaKeyReply=98,
    ConnectReply=4875,
    PingReply=11,
    SessionKeyReply=1338,
    AuthReply=6789,
    GetFileReply=7331,
    GetFilesReply=46,
)

# Parallel lists (same order as messagesType) used to map an id back to its name
messagesKeyList, messagesValueList = map(list, zip(*messagesType.items()))

# Pass VERBOSE (or DEBUG) as a bare command-line word for verbose output,
# e.g. `python3 <script> VERBOSE`.
def print_debug(msg, end = "\n"):
    """Print *msg* (with *end*) only when the VERBOSE or DEBUG argument is set.

    pwntools collects upper-case command-line words into ``args``; names that
    were not given read as an empty, falsy string.
    """
    if args.VERBOSE or args.DEBUG:
        print(msg, end=end)

# Send a message of type 'idType' to the server, then receive and parse the
# reply. On socket timeout the request is re-sent; `timeout` is the number of
# retries left (default 4, i.e. up to 5 attempts in total).
def _encode_field(value):
    """Return *value* as bytes (str is UTF-8 encoded, bytes pass through)."""
    return value if isinstance(value, bytes) else value.encode()


def _length_prefix(length):
    """Return the 2-byte big-endian length prefix used in message bodies.

    Raises struct.error when *length* does not fit in 16 bits, rather than
    emitting a longer prefix that would break the framing.
    """
    return pack("!H", length)


def _serialize(data):
    """Build a message body from *data*.

    - str/bytes: one length-prefixed field;
    - list: the elements joined with NUL bytes, as one length-prefixed field;
    - tuple: each element as its own length-prefixed field;
    - anything else: an empty body.
    Lengths are counted in bytes after UTF-8 encoding.
    """
    if isinstance(data, (str, bytes)):
        raw = _encode_field(data)
        return _length_prefix(len(raw)) + raw
    if isinstance(data, list):
        raw = b"\x00".join(_encode_field(el) for el in data)
        return _length_prefix(len(raw)) + raw
    if isinstance(data, tuple):
        return b"".join(_length_prefix(len(raw)) + raw for raw in map(_encode_field, data))
    return b""


def _build_packet(idType, content):
    """Frame *content* as header + size + content + CRC32 for message *idType*.

    The 2-byte header holds the message id in its upper 14 bits and the
    byte-length of the (leading-zero-stripped) size field in its lower 2 bits.
    Raises ValueError for a message type missing from messagesType.
    """
    if idType not in messagesType:
        raise ValueError(f"Unknown message type: {idType}")
    size = pack("!I", len(content)).lstrip(b"\x00")
    header = pack("!H", (messagesType[idType] << 2) | len(size))
    body = header + size + content
    return body + pack("!I", crc32(body))


def _parse_reply(msg):
    """Parse a raw reply datagram.

    Returns (fields, idType, idMessage, size, crcValid): the list of
    length-prefixed fields of the body, the reply type name and numeric id, the
    declared body size, and whether the trailing CRC32 matches.
    Raises ValueError for a message id missing from messagesType.
    """
    word = int.from_bytes(msg[0:2], byteorder="big")
    idMessage = word >> 2
    sizeLen = word & 0b11
    try:
        idType = messagesKeyList[messagesValueList.index(idMessage)]
    except ValueError:
        raise ValueError(f"Unknown reply message id: {idMessage}") from None
    size = int.from_bytes(msg[2:2 + sizeLen], byteorder="big")
    body = msg[2 + sizeLen:2 + sizeLen + size]
    # The body is a sequence of fields, each with a 2-byte length prefix.
    fields = []
    while body:
        fieldLen = int.from_bytes(body[:2], byteorder="big")
        fields.append(body[2:2 + fieldLen])
        body = body[2 + fieldLen:]
    crcValid = pack("!I", crc32(msg[:-4])) == msg[-4:]
    return fields, idType, idMessage, size, crcValid


def sendMessage(idType, data, timeout = 4):
    """Send *data* as an *idType* message and return the parsed reply.

    Returns (fields, replyType): the list of byte-string fields of the reply
    body and the reply's message-type name. For an RsaKeyReply the first field
    is already decoded with xor_rsa_key(). Exits the script with status 1 when
    every attempt timed out.
    """
    packet = _build_packet(idType, _serialize(data))
    # A negative `timeout` is treated as 0 (a single attempt).
    for retriesLeft in range(max(timeout, 0), -1, -1):
        print_debug(f"\n\033[32mSend message: {idType}\033[39m")
        try:
            UDPClientSocket.sendto(packet, serverAddressPort)
            print_debug("[DEBUG] sent: ", end="")
            print_debug(packet)

            # RECEIVING
            print("\n\033[32mReceive message... \033[39m", end="")
            msg = UDPClientSocket.recvfrom(bufferSize)[0]
        except socket.timeout:
            if retriesLeft == 0:
                print("\033[31mRetry limit exceeded, exiting.\033[39m")
                raise SystemExit(1)
            print("\033[31mError while receiving, retrying.\033[39m")
            continue
        print_debug(msg)

        fields, replyType, idMessage, size, crcValid = _parse_reply(msg)
        print("\033[32m" + replyType + "\033[39m")
        if replyType == "RsaKeyReply" and fields:
            fields[0] = xor_rsa_key(fields[0])

        print_debug(f"[DEBUG] type:      {replyType} ({idMessage})")
        print_debug(f"[DEBUG] Size:      {size}")
        print_debug(f"[DEBUG] message:   {fields}")
        print_debug(f"[DEBUG] CRC:       {'Valid' if crcValid else 'Invalid'}")
        return fields, replyType

# Decode the server RSA public key, which is sent base64-encoded and XOR-masked
def xor_rsa_key(encoded, key=b"ThisIsNotSoSecretPleaseChangeIt"):
    """Base64-decode *encoded* and XOR it with the repeating *key*.

    *encoded* may be str or bytes. The result is exactly as long as the decoded
    data; a payload shorter than the key is not stretched to the key's length.
    Raises ValueError for an empty key.
    """
    if not key:
        raise ValueError("xor key must not be empty")
    data = b64decode(encoded)
    keyLen = len(key)
    return bytes(byte ^ key[i % keyLen] for i, byte in enumerate(data))

# Encrypt message to send to server
def encryptMessageB64(AES_key, enc, IV):
    """AES-CBC encrypt *enc* under *AES_key*/*IV*; return base64(IV + ciphertext).

    *enc* may be str (UTF-8 encoded first) or bytes. PKCS#7-style padding is
    applied to the encoded bytes, so multi-byte characters still produce
    block-aligned input for the cipher.
    """
    data = enc.encode() if isinstance(enc, str) else bytes(enc)
    padLen = AES.block_size - len(data) % AES.block_size
    data += bytes([padLen]) * padLen
    cipher = AES.new(AES_key, AES.MODE_CBC, IV)
    return b64encode(IV + cipher.encrypt(data))

# Decrypt message received from server
def decryptMessageB64(AES_key, dec):
    """Base64-decode *dec*, take its first 16 bytes as the IV and AES-CBC
    decrypt the remainder with *AES_key*.

    Returns (plaintext, IV). No padding is removed from the plaintext.
    """
    raw = b64decode(dec)
    iv, body = raw[:16], raw[16:]
    plaintext = AES.new(AES_key, AES.MODE_CBC, iv).decrypt(body)
    return plaintext, iv


# ---------------------------------------------------------------------------
# Main flow
# ---------------------------------------------------------------------------


def _expect(idType, expected):
    """Exit with status 1 unless the reply type *idType* equals *expected*."""
    if idType != expected:
        print(f"\033[31mUnexpected reply: {idType} (expected {expected})\033[39m")
        raise SystemExit(1)


def _decrypt_text(key, payload):
    """Decrypt a base64 AES-CBC *payload* with *key*.

    Returns the plaintext as str, with the padding and any trailing NUL bytes
    removed, together with the IV that prefixed the ciphertext.
    """
    (plain, iv) = decryptMessageB64(key, payload)
    return unpad(plain.decode()).rstrip("\x00"), iv


def _list_files(session, key, path, iv):
    """Print the server-side listing of *path* and return the reply's IV.

    *path* is encrypted under *key* with *iv* before it is sent.
    """
    path_enc = encryptMessageB64(key, path, iv)
    (reply, idType) = sendMessage("GetFilesMessage", (session, path_enc))
    _expect(idType, "GetFilesReply")
    (listing, iv) = _decrypt_text(key, reply[0])
    print(listing)
    return iv


# get session token; the second field of the reply is the first flag
(sessionID, idType) = sendMessage("ConnectMessage", "CONNECT")
_expect(idType, "ConnectReply")

# First flag
flag1 = sessionID[1].decode()
print(f"\n\033[33m[+] FLAG: {flag1}\033[39m\n")

# get the server public RSA key (sendMessage already removed its XOR mask)
(servPubKey, idType) = sendMessage("RsaKeyMessage", sessionID[0])
_expect(idType, "RsaKeyReply")

# Generate a 256-bit AES session key and encrypt it with the server RSA key
AES_key = get_random_bytes(32)
keyPub = RSA.importKey(servPubKey[0])
cipher = Cipher_PKCS1_v1_5.new(keyPub)
cipher_text = b64encode(cipher.encrypt(AES_key))

# id session + encrypted AES key -> the reply carries an AES-encrypted salt
(msg, idType) = sendMessage("SessionKeyMessage", (sessionID[0], cipher_text))
_expect(idType, "SessionKeyReply")
enc_salt = msg[0]

# Only the IV of the salt is reused below; the decrypted salt itself is unused
(salt, IV) = decryptMessageB64(AES_key, enc_salt)

# Authenticate: the encrypted salt is echoed back with the encrypted credentials
username_enc = encryptMessageB64(AES_key, username, IV)
password_enc = encryptMessageB64(AES_key, password, IV)
(msg, idType) = sendMessage("AuthMessage", (sessionID[0], enc_salt, username_enc, password_enc))
_expect(idType, "AuthReply")

# Second flag
flag2 = msg[1].decode()
print(f"\n\033[33m[+] FLAG: {flag2}\033[39m\n")

# List /opt/ and /opt/dga2021/; each request is encrypted with the IV of the
# previous reply
IV = _list_files(sessionID[0], AES_key, "/opt/", IV)
IV = _list_files(sessionID[0], AES_key, "/opt/dga2021/", IV)

# Read file /opt/dga2021/flag
path_enc = encryptMessageB64(AES_key, "/opt/dga2021/flag", IV)
(file_content_enc, idType) = sendMessage("GetFileMessage", (sessionID[0], path_enc))
_expect(idType, "GetFileReply")

# Third flag: strip the padding with unpad() like the directory listings,
# instead of removing only a few hard-coded pad byte values
(flag3, IV) = _decrypt_text(AES_key, file_content_enc[0])

# Win!!
print(f"\n\033[33m[+] FLAG 1: {flag1}\033[39m")
print(f"\n\033[33m[+] FLAG 2: {flag2}\033[39m")
print(f"\n\033[33m[+] FLAG 3: {flag3}\033[39m")
