"""I provide a source to read AMMOS data from a socket."""

import socket

from ammosreader.AmmosSource import AmmosSource
from ammosreader import logger


class AmmosSocketSource(AmmosSource):
    """I implement a descendent of AmmosSource that reads from a socket."""

    def read_bytes(self, bytes_to_read):
        """
        I read bytes_to_read bytes from my socket. Incomplete reads get discarded.

        :param bytes_to_read: the number of bytes to read
        :type bytes_to_read: int

        :return: the complete byte string or None if not all the bytes are read
        :rtype: bytes
        """

        assert bytes_to_read > 0
        byte_array = []
        logger.info("Start reading bytes from socket")
        try:
            while bytes_to_read > len(byte_array):
                logger.info("Remaining Bytes: %s", bytes_to_read - len(byte_array))
                self.source.settimeout(self.timeout)
                new_bytes = self.source.recv(bytes_to_read - len(byte_array), socket.MSG_WAITALL)

                if not new_bytes:
                    raise TimeoutError("Socket timed out while reading data")
                logger.info("Got %s bytes of %s remaining", len(new_bytes), bytes_to_read - len(byte_array))
                byte_array.append(new_bytes)
        except TimeoutError:
            logger.info("Timeout error while reading from socket")
            return None
        logger.info("All bytes read")
        return b''.join(byte_array)
