"""I read a Ammos datastream from a socket."""

import socket
from collections import deque
import numpy as np
from ammosreader.AmmosAudioDataHeader import AmmosAudioDataHeader
from ammosreader.AmmosExtendedAudioDataHeader import AmmosExtendedAudioDataHeader
from ammosreader.AmmosGlobalFrameHeader import AmmosGlobalFrameHeader


class AmmosAudioSocketReader:
    def __init__(self, socket:socket.socket):
        """
        Initializes the AmmosAudioSocketReader

        Args:
            socket (socket.socket): socket to read from
        """

        #buffer for reading socket bytewise und check for the magic word
        self.__magic_word_buffer = deque(maxlen=4)

        #input socket to read from
        self.__socket = socket

    def __read_next_audio_data_body(self, sample_count:int, channel_count:int, sample_size:int) -> bytearray:
        """
        reads the next audio data body

        Args:
            sample_count (int): amount of samples
            channel_count (int): amount of channels
            sample_size (int): size of a sample in bytes

        Returns:
            bytearray: contains the audio data
        """

        total = sample_count*channel_count*sample_size
        byte_array = []

        while len(b''.join(byte_array)) < total:
            byte_array.append(self.__socket.recv(total - len(b''.join(byte_array))))

        if len(b''.join(byte_array)) != total:
            print("Can not read all", total, "bytes of data body")
            return None
        return b''.join(byte_array)

    def __audio_data_body_to_numpy(self, audio_data_body:bytearray) -> np.ndarray:
        """
        converts the audio data body to a numpy array

        Args:
            audio_data_body (bytearray): audio data from audio data body

        Returns:
            np.ndarray: audio data as numpy array
        """

        return np.frombuffer(audio_data_body, dtype=np.int16)

    def read_next_frame(self) -> tuple[bytearray, int]:
        """
        reads the next ammos audio frame

        Returns:
            tuple[bytearray, int]: contains the audio data and the sample rate
        """
        #read loop
        byte = self.__socket.recv(1)

        while byte:
            #
            self.__magic_word_buffer.append(byte)
            byte_array = b''.join(self.__magic_word_buffer)

            if byte_array.hex() == '726574fb':
                #print(byte_array.hex())

                ammos_global_header_buffer = list(self.__magic_word_buffer)
                while len(b''.join(ammos_global_header_buffer)) < 24:
                    ammos_global_header_buffer.append(self.__socket.recv(24 - len(b''.join(ammos_global_header_buffer))))
                    
                ammos_global_header = AmmosGlobalFrameHeader.from_bytes(b''.join(ammos_global_header_buffer))
                print(ammos_global_header)

                if ammos_global_header.data_header_length == 44 and ammos_global_header.frame_type == 256:
                    byte_array_header = []
                    while len(b''.join(byte_array_header)) < 44:
                        byte_array_header.append(self.__socket.recv(44 - len(b''.join(byte_array_header))))

                    ammos_extended_audio_data_header = AmmosExtendedAudioDataHeader.from_bytes(b''.join(byte_array_header))
                    print(ammos_extended_audio_data_header.sample_count, ammos_extended_audio_data_header.channel_count, ammos_extended_audio_data_header.sample_size)
                    audio_body = self.__read_next_audio_data_body(ammos_extended_audio_data_header.sample_count, 
                                                                  ammos_extended_audio_data_header.channel_count, 
                                                                  ammos_extended_audio_data_header.sample_size)

                    audio_array = self.__audio_data_body_to_numpy(audio_body)
                    print(len(audio_array), len(audio_array)/ammos_extended_audio_data_header.sample_rate)

                    return [audio_array, ammos_extended_audio_data_header.sample_rate]

                elif ammos_global_header.data_header_length == 36 and ammos_global_header.frame_type == 256:
                    byte_array_header = []
                    while len(b''.join(byte_array_header)) < 36:
                        byte_array_header.append(self.__socket.recv(36 - len(b''.join(byte_array_header))))

                    ammos_audio_data_header = AmmosAudioDataHeader.from_bytes(b''.join(byte_array_header))
                    print(ammos_audio_data_header.sample_count, ammos_audio_data_header.channel_count, ammos_audio_data_header.sample_size)
                    audio_body = self.__read_next_audio_data_body(ammos_audio_data_header.sample_count, 
                                                                  ammos_audio_data_header.channel_count, 
                                                                  ammos_audio_data_header.sample_size)

                    audio_array = self.__audio_data_body_to_numpy(audio_body)
                    print(len(audio_array), len(audio_array)/ammos_audio_data_header.sample_rate)

                    return [audio_array, ammos_audio_data_header.sample_rate]

            byte = self.__socket.recv(1)

        return None
