import Env
import Message
from Log import LOG
import sys
import os
import utils
from protocols import protos
import traceback

class Bufferv2:
    """Reassemble a captured TCP stream into framed messages and dispatch them.

    Each frame starts with a 2-byte big-endian header word: ``word >> 2`` is the
    protocol id and ``word & 0b11`` (``size_type``) is the number of big-endian
    length bytes that follow the header word. Frames are parsed from
    ``self.buffer`` and handed to ``env.dispatch_packet`` as
    ``Message.Message`` objects.

    Attributes:
        env: object whose ``dispatch_packet(message)`` receives parsed messages.
        buffer: reassembled, not yet consumed bytes of the non-"SRV" direction.
        server_buffer: payload of the most recent "SRV" packet.
        serv_buffer: not read anywhere in this class; kept for compatibility.
        initial_offset: stream position (TCP seq/ack number) corresponding to
            ``buffer[0]``; -1 until the first packet is seen.
        received_data: payload bytes added to ``buffer`` minus bytes consumed by
            dispatched frames (clamped at 0, reset to 0 when a frame is refused).
        protocol, size_type, message_size: header fields of the frame most
            recently parsed by ``get_headers``.
    """

    def __init__(self, env):
        self.env = env
        self.buffer = bytearray(b'')
        # server_packet_process reads server_buffer, so it must exist from the start.
        self.server_buffer = bytearray(b'')
        self.serv_buffer = bytearray(b'')
        self.initial_offset = -1
        self.received_data = 0
        self.protocol = None
        self.size_type = None
        self.message_size = None

    def get_headers(self, raw_data):
        """Parse the frame header at the start of ``raw_data`` into attributes.

        Sets ``protocol`` (header word >> 2), ``size_type`` (header word & 0b11)
        and ``message_size`` (big-endian integer read from the ``size_type``
        bytes after the header word, or 0 when ``size_type`` is 0). Short input
        is not rejected: the integers are read from whatever bytes are present.
        """
        val = int.from_bytes(raw_data[:2], "big")
        self.protocol = val >> 2
        self.size_type = int(val & 0b11)
        self.message_size = 0
        if self.size_type > 0:
            self.message_size = int.from_bytes(raw_data[2:2 + self.size_type], 'big')

    def packet_processing(self, validated):
        """Parse and dispatch the frames at the front of ``buffer``.

        Stops when fewer than 2 bytes remain, or when a frame reports
        ``missing == True`` (it is kept for the next call). If a frame is refused
        (``valid == False``), the whole buffer is discarded,
        ``initial_offset`` is advanced past it and ``received_data`` is reset.

        Args:
            validated: forwarded unchanged as the last argument of
                ``Message.Message``. ``add_packet`` passes the acknowledged ack
                number relative to ``initial_offset``.
        """
        while len(self.buffer) >= 2:
            self.get_headers(self.buffer)
            # get_headers above never yields -1; kept as a defensive guard.
            if -1 in (self.protocol, self.size_type, self.message_size):
                self.buffer = bytearray(b'')
                return
            header_len = 2 + self.size_type
            frame_len = header_len + self.message_size
            message = Message.Message(
                self.buffer[header_len:frame_len],
                self.protocol,
                self.message_size,
                self.size_type,
                self.buffer[:frame_len],
                self.buffer[:frame_len],
                validated,
            )
            if message.valid == False:
                LOG.debug(f"Refused protocol: {self.protocol} ({len(self.buffer)})")
                self.received_data = 0
                self.initial_offset += len(self.buffer)
                self.buffer = bytearray(b'')
                return
            if message.valid == True:
                if message.missing == True:
                    # Frame is not fully buffered yet: keep it and wait for more data.
                    message.print_infos()
                    return
                self.received_data = max(0, self.received_data - frame_len)
                try:
                    self.env.dispatch_packet(message)
                except Exception:
                    # A failing handler must not stall reassembly: report and move on.
                    traceback.print_exc()
            self.buffer = self.buffer[frame_len:]
            self.initial_offset += frame_len

    def server_packet_process(self):
        """Parse the frame header at the start of ``server_buffer`` and dispatch it if valid.

        Unlike ``packet_processing``, this does not loop over several frames.
        """
        self.get_headers(self.server_buffer)
        # NOTE(review): the payload argument is server_buffer[2:], which still
        # contains the size_type length bytes, whereas packet_processing skips
        # them -- confirm which form Message expects.
        message = Message.Message(self.server_buffer[2:], self.protocol, self.message_size, self.size_type, self.server_buffer, self.server_buffer[:2 + self.size_type + self.message_size], len(self.server_buffer))
        if message.valid == True:
            self.env.dispatch_packet(message)

    def add_packet(self, packet, direction):
        """Feed one captured TCP packet into the reassembler.

        ``packet`` must expose ``packet.tcp.seq``, ``packet.tcp.ack`` and, when it
        carries data, ``packet.tcp.payload.raw_value`` as a hex string (presumably
        a pyshark packet -- confirm against the caller).

        For ``direction == "SRV"``, the ack number is compared with
        ``received_data``, the packet's own payload (if any) goes through
        ``server_packet_process``, and then the buffered frames are processed.
        Any other direction writes the payload into ``buffer`` at the position
        given by its sequence number.
        """
        if direction == "SRV":
            self._add_server_packet(packet)
            return
        self._add_client_payload(packet)

    def _add_server_packet(self, packet):
        """Handle an "SRV" packet: check its ack, dispatch its payload, drain ``buffer``."""
        ack = int(packet.tcp.ack)
        if self.received_data == 0:
            self.initial_offset = ack
        validated = ack - self.initial_offset
        if validated > self.received_data:
            LOG.error(f'data received {self.received_data}, expecting {validated}')
        # packet.tcp was already dereferenced above, so only the payload can be absent.
        if hasattr(packet.tcp, 'payload'):
            self.server_buffer = bytearray.fromhex(packet.tcp.payload.raw_value)
            self.server_packet_process()
        # Recomputed rather than reusing `validated`: the dispatch above runs handler code.
        self.packet_processing(ack - self.initial_offset)

    def _add_client_payload(self, packet):
        """Write a non-"SRV" packet's payload into ``buffer`` at its stream offset."""
        seq = int(packet.tcp.seq)
        if self.initial_offset == -1 or self.received_data == 0:
            self.initial_offset = seq
        offset = seq - self.initial_offset
        value = bytearray.fromhex(packet.tcp.payload.raw_value)
        payload_len = len(value)
        if offset < 0:
            # The segment starts before buffer[0]; those bytes were already consumed,
            # so drop them instead of misaligning the rest at offset 0.
            value = value[-offset:]
            offset = 0
        end = offset + len(value)
        old_len = len(self.buffer)
        if end > old_len:
            # Grow the buffer (zero-filling any gap left by a missing segment) and copy
            # only the bytes past the old end. A retransmission that overlaps buffered
            # data thus still contributes its new tail without rewriting what is there.
            self.buffer.extend(bytearray(end - old_len))
            start = max(offset, old_len)
            self.buffer[start:end] = value[start - offset:]
        # Accounting counts the full payload, as before (including duplicate bytes).
        self.received_data += payload_len
