"""I read a Ammos datastream from a socket."""

import select
import socket
from collections import deque
import numpy as np
from ammosreader.AmmosAudioDataHeader import AmmosAudioDataHeader
from ammosreader.AmmosExtendedAudioDataHeader import AmmosExtendedAudioDataHeader
from ammosreader.AmmosGlobalFrameHeader import AmmosGlobalFrameHeader
from ammosreader import logger

class AmmosAudioSocketReader:
    def __init__(self, socket:socket.socket):
        """
        Initializes the AmmosAudioSocketReader

        Args:
            socket (socket.socket): socket to read from
        """

        #buffer for reading socket bytewise und check for the magic word
        self.__magic_word_buffer = deque(maxlen=4)

        #input socket to read from
        self.__socket = socket

    def __get_next_data(self, byte_count: int) -> bytearray:
        """
        Gets the next bytes from the socket, for example headers and body data.

        Args:
            byte_count (int): number of bytes to read

        Raises:
            TimeoutError: Raises TimeoutError if the socket does not serve data anymore

        Returns:
            bytearray: data from socket as bytearray
        """

        byte_array = []
        
        while len(b''.join(byte_array)) < byte_count:
            logger.info(f"Remaining Bytes: {byte_count - len(b''.join(byte_array))}")
            self.__socket.settimeout(5)
            new_bytes = self.__socket.recv(byte_count - len(b''.join(byte_array)), socket.MSG_WAITALL)

            if not new_bytes:
                raise TimeoutError("Socket timed out while reading data")

            logger.info(f"Got {len(new_bytes)} bytes of {byte_count - len(b''.join(byte_array))} ramining")

            byte_array.append(new_bytes)

        return b''.join(byte_array)

    def __audio_data_body_to_numpy(self, audio_data_body:bytearray) -> np.ndarray:
        """
        converts the audio data body to a numpy array

        Args:
            audio_data_body (bytearray): audio data from audio data body

        Returns:
            np.ndarray: audio data as numpy array
        """

        return np.frombuffer(audio_data_body, dtype=np.int16)

    def read_next_frame(self) -> tuple[np.ndarray, int]:
        """Reads the next ammos audio frame from socket

        Raises:
            TimeoutError: Raisees TimeoutError if the socket does not serve data anymore

        Returns:
            tuple[np.ndarray, int]: Contains the audio data and the sample rate
        """

        # get first byte of the day
        self.__socket.settimeout(5)

        new_byte = self.__socket.recv(1, socket.MSG_WAITALL)
        # raise Exception if socket does not return anything
        if len(new_byte) < 1:
            raise TimeoutError      

        #read loop
        while new_byte:
            #
            self.__magic_word_buffer.append(new_byte)
            byte_array = b''.join(self.__magic_word_buffer)

            if byte_array.hex() == '726574fb':
                #print(byte_array.hex())

                ammos_global_header_buffer = list(self.__magic_word_buffer)
                ammos_global_header_buffer.append(self.__get_next_data(20))
                #while len(b''.join(ammos_global_header_buffer)) < 24:
                #    ammos_global_header_buffer.append(self.__socket.recv(24 - len(b''.join(ammos_global_header_buffer))))
                    
                ammos_global_header = AmmosGlobalFrameHeader.from_bytes(b''.join(ammos_global_header_buffer))
                logger.info(ammos_global_header)

                if ammos_global_header.data_header_length == 44 and ammos_global_header.frame_type == 256:
                    byte_array_header = self.__get_next_data(44)
                    #while len(b''.join(byte_array_header)) < 44:
                    #    byte_array_header.append(self.__socket.recv(44 - len(b''.join(byte_array_header))))

                    ammos_extended_audio_data_header = AmmosExtendedAudioDataHeader.from_bytes(byte_array_header)
                    print(str(ammos_extended_audio_data_header.number_of_samples), str(ammos_extended_audio_data_header.number_of_channels), str(ammos_extended_audio_data_header.sample_size))
                    audio_body = self.__get_next_data(ammos_extended_audio_data_header.number_of_samples* 
                                                      ammos_extended_audio_data_header.number_of_channels* 
                                                      ammos_extended_audio_data_header.sample_size)

                    audio_array = self.__audio_data_body_to_numpy(audio_body)
                    print(str(len(audio_array)), str(len(audio_array)/ammos_extended_audio_data_header.sample_rate))

                    return [audio_array, ammos_extended_audio_data_header.sample_rate]

                elif ammos_global_header.data_header_length == 36 and ammos_global_header.frame_type == 256:
                    byte_array_header = self.__get_next_data(36)
                    #while len(b''.join(byte_array_header)) < 36:
                    #    byte_array_header.append(self.__socket.recv(36 - len(b''.join(byte_array_header))))

                    ammos_audio_data_header = AmmosAudioDataHeader.from_bytes(byte_array_header)
                    print(str(ammos_audio_data_header.number_of_samples), str(ammos_audio_data_header.number_of_channels), str(ammos_audio_data_header.sample_size))
                    audio_body = self.__get_next_data(ammos_extended_audio_data_header.number_of_samples* 
                                                      ammos_extended_audio_data_header.number_of_channels* 
                                                      ammos_extended_audio_data_header.sample_size)

                    audio_array = self.__audio_data_body_to_numpy(audio_body)
                    logger.info(str(len(audio_array)), str(len(audio_array)/ammos_audio_data_header.sample_rate))

                    return [audio_array, ammos_audio_data_header.sample_rate]

            # get the next byte
            self.__socket.settimeout(5)

            new_byte = self.__socket.recv(1, socket.MSG_WAITALL)
            # raise Exception if socket does not return anything
            if len(new_byte) < 1:
                raise TimeoutError   

        return None
