# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import Queue
import struct
import sys
import threading
import time
import traceback
import uuid
import weakref

from cobs import cobs
import serial

from . import exceptions
import stm32_crc

# Module-level logger used for this module's log output.
logger = logging.getLogger(__name__)

# Optional dependency whose absence is tolerated. Nothing in this module
# references pyftdi directly, so this import is presumably kept for its side
# effects (e.g. making extra URL schemes available to serial.serial_for_url)
# -- TODO confirm.
try:
    import pyftdi.serialext
except ImportError:
    pass

# Keyword arguments passed to serial.serial_for_url() by
# Connection.open_dbgserial(). For socket:// URLs, open_dbgserial() afterwards
# overrides the port's private _timeout attribute.
DBGSERIAL_PORT_SETTINGS = dict(baudrate=230400, timeout=0.1,
                               interCharTimeout=0.01)


def get_dbgserial_tty():
    '''Autodetect the dbgserial TTY.

    Returns whatever pebble_tty.find_dbgserial_tty() returns; callers use
    it as the serial URL/device to open.

    Raises:
        exceptions.TTYAutodetectionUnavailable: the optional pebble_tty
            package cannot be imported.
    '''
    # Local import so that we only depend on this package if we're attempting
    # to autodetect the TTY. This package isn't always available (e.g., MFG),
    # so we don't want it to be required.
    try:
        import pebble_tty
    except ImportError:
        raise exceptions.TTYAutodetectionUnavailable
    # Called outside the try block so that an ImportError raised from inside
    # find_dbgserial_tty() is not misreported as pebble_tty being missing.
    return pebble_tty.find_dbgserial_tty()


def frame_splitter(istream, size=1024, timeout=1, delimiter='\0'):
    '''Returns an iterator which yields complete frames.

    Repeatedly calls istream.read(size) until istream.closed becomes true,
    and yields every non-empty run of data terminated by `delimiter` (the
    delimiter itself is stripped). Empty frames (back-to-back delimiters)
    are skipped. Unterminated data still buffered when the stream closes is
    discarded.

    If `timeout` > 0, None is additionally yielded after each read once
    `timeout` seconds have passed since iteration began.
    '''
    buffered = []
    began = time.time()
    while not istream.closed:
        chunk = istream.read(size)
        logger.debug('frame_splitter: received %r', chunk)
        # Every piece except the last was followed by a delimiter in this
        # chunk and therefore completes a frame. The last piece (possibly
        # empty) is the beginning of a frame that is still arriving.
        pieces = chunk.split(delimiter)
        for piece in pieces[:-1]:
            if piece:
                buffered.append(piece)
            completed = ''.join(buffered)
            buffered = []
            if completed:
                yield completed
        if pieces[-1]:
            buffered.append(pieces[-1])
        # NOTE(review): `began` is never reset, so once the timeout has
        # elapsed, *every* later read yields None. Confirm that this
        # overall-deadline behaviour is intended rather than a per-frame
        # timeout.
        if timeout > 0 and time.time() > began + timeout:
            yield

def decode_frame(frame):
    '''Decodes a PULSE frame.

    A frame is COBS-encoded (without delimiters). Once decoded, it consists
    of a one-byte protocol number, the payload, and a trailing little-endian
    32-bit FCS computed with stm32_crc.crc32() over the protocol byte and
    payload (the layout produced by encode_frame()).

    Returns a tuple (protocol, payload) of the decoded frame.
    Raises FrameDecodeError if the frame is not valid.
    '''
    try:
        data = cobs.decode(frame)
    except cobs.DecodeError as e:
        # str(e) instead of the deprecated BaseException.message attribute.
        raise exceptions.FrameDecodeError(str(e))
    # Smallest valid frame: 1-byte protocol number + 4-byte FCS (empty payload).
    if len(data) < 5:
        raise exceptions.FrameDecodeError('frame too short')
    fcs = struct.unpack('<I', data[-4:])[0]
    crc = stm32_crc.crc32(data[:-4])
    if fcs != crc:
        raise exceptions.FrameDecodeError('FCS 0x%.08x != CRC 0x%.08x' % (fcs, crc))
    protocol = ord(data[0])
    return (protocol, data[1:-4])

def encode_frame(protocol, payload):
    '''Encodes a PULSE frame; the counterpart of decode_frame().

    The protocol number (one unsigned byte) and the payload are followed by
    a little-endian 32-bit FCS computed with stm32_crc.crc32() over both.
    The whole sequence is then COBS-encoded. No delimiters are added here;
    Connection.send() wraps the result in them.
    '''
    body = struct.pack('<B', protocol) + payload
    trailer = struct.pack('<I', stm32_crc.crc32(body))
    return cobs.encode(body + trailer)


class Connection(object):
    '''A socket for sending and receiving datagrams over the PULSE serial
    protocol.

    Instantiating a Connection does the following:
    - starts a daemon thread that reads and dispatches incoming frames;
    - opens the link, raising exceptions.PulseError if the peer never
      confirms it;
    - starts a daemon keepalive thread that pings the peer and reconnects
      when needed;
    - binds every extension registered with register_extension().

    Frames bearing PROTOCOL_LLC are handled internally by llc_handler().
    Frames for other protocols go to the handler registered with
    register_protocol_handler(), or to default_receiver() if there is none.
    '''

    PROTOCOL_LLC = 0x01

    # LLC payloads sent by this side; the first byte is the opcode.
    LLC_LINK_OPEN_REQUEST = '\x01\x03\x08\x08\x08PULSEv1\r\n'
    LLC_LINK_CLOSE_REQUEST = '\x03'
    LLC_ECHO_REQUEST = '\x05'
    LLC_CHANGE_BAUD = '\x07'

    # LLC opcodes received from the peer.
    LLC_LINK_OPENED = 0x02
    LLC_LINK_CLOSED = 0x04
    LLC_ECHO_REPLY = 0x06

    # Extension name -> factory, populated by register_extension(). Shared
    # by this class and any subclasses.
    EXTENSIONS = {}

    # Maximum round-trip time, in seconds (used as a wait timeout).
    rtt = 0.4

    def __init__(self, iostream, infinite_reconnect=False):
        '''Open a PULSE link over iostream.

        Args:
            iostream: a PySerial-like stream with read(), write(), close()
                and a `closed` attribute. If it also has getSettingsDict(),
                those settings are captured here and re-applied each time
                the link is (re)opened.
            infinite_reconnect: if true, the keepalive thread keeps trying
                to reconnect after a lost link. Otherwise it gives up after
                the first failed reconnect attempt.

        Raises:
            exceptions.PulseError: the link could not be opened.
        '''
        self.infinite_reconnect = infinite_reconnect
        self.iostream = iostream
        self.closed = False
        try:
            self.initial_port_settings = self.iostream.getSettingsDict()
        except AttributeError:
            self.initial_port_settings = None
        self.port_settings_altered = False
        # Whether the link is open for sending.
        self._link_open = threading.Event()
        # Whether the link has been severed.
        self._link_closed = threading.Event()
        self.send_lock = threading.RLock()
        # Both maps hold their values weakly: ping() keeps its Event alive
        # while waiting, and callers must keep their handlers alive.
        self.echoes_inflight = weakref.WeakValueDictionary()
        self.protocol_handlers = weakref.WeakValueDictionary()
        self.receive_thread = threading.Thread(target=self.run_receive_thread)
        self.receive_thread.daemon = True
        self.receive_thread.start()
        self._open_link()

        self.keepalive_thread = threading.Thread(
                target=self.run_keepalive_thread)
        self.keepalive_thread.daemon = True
        self.keepalive_thread.start()

        # Instantiate and bind all known extensions
        for name, factory in self.EXTENSIONS.iteritems():
            setattr(self, name, factory(self))

    @classmethod
    def register_extension(cls, name, factory):
        '''Register a PULSE connection extension.

        When a Connection object is instantiated, the object returned by
        factory(connection_object) is assigned to connection_object.<name>.

        Raises:
            ValueError: `name` is already an attribute of the class.
        '''
        try:
            getattr(cls, name)
        except AttributeError:
            cls.EXTENSIONS[name] = factory
        else:
            raise ValueError('extension name %r clashes with existing attribute'
                    % (name,))

    @classmethod
    def open_dbgserial(cls, url=None, infinite_reconnect=False):
        '''Open a Connection on a PySerial URL/device.

        If `url` is None, the dbgserial TTY is autodetected. If it is "qemu",
        it is mapped to socket://localhost:12345.
        '''
        if url is None:
            url = get_dbgserial_tty()
        if url == "qemu":
            url = 'socket://localhost:12345'
        ser = serial.serial_for_url(url, **DBGSERIAL_PORT_SETTINGS)

        if url.startswith('socket://'):
            # Socket class for PySerial does some pointless buffering
            # setting a very small timeout effectively negates it
            ser._timeout = 0.00001

        return cls(ser, infinite_reconnect=infinite_reconnect)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        # Best-effort cleanup. If __init__ raised before the link events were
        # created, close() would only fail with AttributeError, so skip it.
        if hasattr(self, '_link_closed'):
            self.close()

    def send(self, protocol, payload):
        '''Send `payload` as a delimited PULSE frame tagged with `protocol`.

        Raises:
            exceptions.PulseError: the connection is closed.
        '''
        if self.closed:
            raise exceptions.PulseError('I/O operation on closed connection')
        frame = ''.join(('\0', encode_frame(protocol, payload), '\0'))
        logger.debug('Connection: sending %r', frame)
        with self.send_lock:
            self.iostream.write(frame)

    def run_receive_thread(self):
        '''Receive-thread body: decode incoming frames and dispatch them.

        Frames that fail to decode are dropped. Any other exception while
        reading/decoding (e.g. the port being closed) ends the thread.
        NOTE(review): exceptions raised by a protocol handler are not caught
        here and also terminate the thread.
        '''
        logger.debug('Connection: receive thread started')
        receiver = frame_splitter(self.iostream, timeout=0)
        while True:
            try:
                protocol, payload = decode_frame(next(receiver))
            except exceptions.FrameDecodeError:
                continue
            except:
                # Probably a PySerial exception complaining about reading from a
                # closed port. Eat the exception and shut down the thread; users
                # don't need to see the stack trace.
                logger.debug('Connection: exception in receive thread:\n%s',
                             traceback.format_exc())
                break
            logger.debug('Connection:run_receive_thread: '
                    'protocol=%d payload=%r', protocol, payload)
            if protocol == self.PROTOCOL_LLC:  # LLC can't be overridden
                self.llc_handler(payload)
                continue
            try:
                handler = self.protocol_handlers[protocol]
            except KeyError:
                self.default_receiver(protocol, payload)
            else:
                handler.on_receive(payload)
        logger.debug('Connection: receive thread exiting')

    def default_receiver(self, protocol, frame):
        '''Log frames for protocols that have no registered handler.'''
        logger.info('Connection:default_receiver: received frame '
                'with protocol %d: %r', protocol, frame)

    def register_protocol_handler(self, protocol, handler):
        '''Register a handler for frames bearing the specified protocol number.

        handler.on_receive(payload) is called for each frame received with the
        protocol number. The handler is held only weakly, so the caller must
        keep a reference to it.

        Protocol handlers can be unregistered by calling this function with a
        handler of None.

        Raises:
            exceptions.ProtocolAlreadyRegistered: another handler already
                owns `protocol`.
            ValueError: `handler` has no on_receive attribute.
        '''
        if not handler:
            try:
                del self.protocol_handlers[protocol]
            except KeyError:
                pass
            return
        if protocol in self.protocol_handlers:
            raise exceptions.ProtocolAlreadyRegistered(
                'Protocol %d is already registered by %r' % (
                    protocol, self.protocol_handlers[protocol]))
        if not hasattr(handler, 'on_receive'):
            raise ValueError(
                '%r does not have an on_receive method' % (handler,))
        self.protocol_handlers[protocol] = handler

    def llc_handler(self, frame):
        '''Handle a payload received on the link-control (LLC) protocol.

        Malformed payloads are logged and ignored, so they cannot terminate
        the receive thread that calls this method.
        '''
        if not frame:
            logger.warning('Received empty LLC frame')
            return
        opcode = ord(frame[0])
        if opcode == self.LLC_LINK_OPENED:
            # Layout: opcode, version (u8), MRU (u16), MTU (u16),
            # timeout (u8, deciseconds), all little-endian.
            # MTU and MRU are from the perspective of this side of the
            # connection
            try:
                version, mru, mtu, timeout = struct.unpack('<xBHHB', frame)
            except struct.error:
                logger.warning('Received malformed LLC link-opened frame: %r',
                               frame)
                return
            self.version = version
            # The server reports the MTU inclusive of protocol number and FCS,
            # but we only care about the maximum payload length.
            self.mtu = mtu - 5
            self.mru = mru
            # Timeout is specified in deciseconds. Convert to seconds.
            self.timeout = timeout / 10.0
            self._link_closed.clear()
            self._link_open.set()
        elif opcode == self.LLC_LINK_CLOSED:
            logger.info('PULSE connection closed.')
            self._link_closed.set()
        elif opcode == self.LLC_ECHO_REPLY:
            self._on_echo_reply(frame[1:])
        else:
            logger.warning('Received LLC frame with unknown type %d: %r',
                           opcode, frame)

    def run_keepalive_thread(self):
        '''The keepalive thread monitors the link, reopening it if necessary.

        States:
          OPEN: sleep 1 s, then go to RECONNECT if the peer closed the link,
              otherwise to TEST_LIVENESS.
          TEST_LIVENESS: up to three pings with a doubling timeout starting
              at 2*rtt; any reply goes back to OPEN, three misses go to
              RECONNECT.
          RECONNECT: reopen the link while holding send_lock.

        The thread returns once _link_open is cleared and not re-set, i.e.
        after close() or after a failed reconnect without infinite_reconnect.
        '''
        logger.debug('Connection: keepalive thread started')
        OPEN, TEST_LIVENESS, RECONNECT = range(3)
        state = OPEN
        next_state = state
        ping_attempts = 0
        ping_wait = self.rtt
        while True:
            # Check whether the link is being closed from our side before
            # trying to keep it alive.
            if not self._link_open.is_set():
                return

            if state == OPEN:
                time.sleep(1)
                if self._link_closed.is_set():
                    next_state = RECONNECT
                else:
                    next_state = TEST_LIVENESS
            elif state == TEST_LIVENESS:
                if ping_attempts < 3:
                    ping_attempts += 1
                    ping_wait *= 2  # Exponential backoff
                    if self.ping(ping_wait):
                        next_state = OPEN
                    else:
                        logger.info('No response to keepalive ping -- '
                                    'strike %d', ping_attempts)
                else:
                    logger.info('Connection: keepalive timed out.')
                    next_state = RECONNECT
            elif state == RECONNECT:
                # Lock out everyone from sending so that applications don't send
                # to a connection that's in an indeterminate state.
                with self.send_lock:
                    if self.port_settings_altered:
                        # Ensure that the server has timed out and reset its
                        # baud rate so we don't get into the bad situation where
                        # we try to reconnect at the default baud rate but the
                        # server is listening at a different rate, which is
                        # practically guaranteed to fail.
                        logger.info('Letting connection time out before '
                                    'attempting to reconnect.')
                        time.sleep(self.timeout + self.rtt)
                    self._link_open.clear()
                    while not self._link_open.is_set():
                        try:
                            self._open_link()
                        except exceptions.PulseError as e:
                            logger.warning('Connection: reconnect failed. %s', e)
                            if not self.infinite_reconnect:
                                break
                            logger.warning('Will try again.')
                            logger.info('Backing off for a while before retrying.')
                            time.sleep(self.timeout + self.rtt)
                        else:
                            next_state = OPEN
            else:
                assert False, 'Invalid state %d' % state

            if next_state != state:
                if next_state == TEST_LIVENESS:
                    ping_attempts = 0
                    ping_wait = self.rtt
            state = next_state

    def _open_link(self):
        '''Send link-open requests until the peer confirms the link.

        First restores the port settings captured in __init__ (if any). Then
        makes up to five attempts, waiting rtt seconds after each for
        llc_handler() to see LLC_LINK_OPENED.

        Raises:
            exceptions.PulseError: no confirmation after five attempts. The
                link is then marked severed and the connection closed.
        '''
        self.closed = False
        if self.initial_port_settings:
            self.iostream.applySettingsDict(self.initial_port_settings)
            # The port is back at its initial settings (undoing e.g.
            # change_baud_rate()), so later reconnects need not wait for the
            # peer to time out and reset its own settings.
            self.port_settings_altered = False
        for attempt in xrange(5):
            logger.info('Opening link (attempt %d)...', attempt)
            self.send(self.PROTOCOL_LLC, self.LLC_LINK_OPEN_REQUEST)
            if self._link_open.wait(self.rtt):
                logger.info('Established PULSE connection!')
                logger.info('Version=%d  MTU=%d  MRU=%d  Timeout=%.1f',
                            self.version, self.mtu, self.mru, self.timeout)
                break
        else:
            self._link_closed.set()
            self.closed = True
            raise exceptions.PulseError('Could not establish connection')

    def close(self):
        '''Close the link and the underlying stream.

        Unless the link is already severed (or the connection already
        closed), sends up to three close requests, waiting rtt seconds for
        confirmation after each. The stream is always closed and the
        connection marked closed, even if sending a close request fails.
        Safe to call more than once.
        '''
        # Clearing this first also makes the keepalive thread exit.
        self._link_open.clear()
        try:
            # send() would raise on an already-closed connection, so skip
            # the close handshake in that case.
            if not self.closed and not self._link_closed.is_set():
                for _ in xrange(3):
                    self.send(self.PROTOCOL_LLC, self.LLC_LINK_CLOSE_REQUEST)
                    if self._link_closed.wait(self.rtt):
                        break
                else:
                    logger.warning('Could not confirm link close.')
                    self._link_closed.set()
        finally:
            self.iostream.close()
            self.closed = True

    def ping(self, timeout=None):
        '''Send an LLC echo request carrying a random nonce.

        Returns the result of waiting up to `timeout` seconds (default
        2*rtt; any falsy value selects the default) for the matching echo
        reply.
        '''
        if not timeout:
            timeout = 2 * self.rtt
        nonce = uuid.uuid4().bytes
        is_received = threading.Event()
        self.echoes_inflight[nonce] = is_received
        self.send(self.PROTOCOL_LLC, self.LLC_ECHO_REQUEST + nonce)
        return is_received.wait(timeout)

    def _on_echo_reply(self, payload):
        '''Wake the ping() waiting on this nonce, if it is still waiting.'''
        receive_event = self.echoes_inflight.get(payload)
        if receive_event is not None:
            receive_event.set()

    def change_baud_rate(self, new_baud):
        '''Ask the peer to switch to `new_baud`, then switch the local port.

        No acknowledgement is awaited; a fixed 0.1 s pause is the only
        synchronization with the peer.
        '''
        # Fail fast if the IO object doesn't support changing the baud rate
        old_baud = self.iostream.baudrate
        self.send(self.PROTOCOL_LLC,
                  self.LLC_CHANGE_BAUD + struct.pack('<I', new_baud))
        # Be extra sure that the message has been sent and it's safe to adjust
        # the baud rate on the port.
        time.sleep(0.1)
        self.iostream.baudrate = new_baud
        self.port_settings_altered = True


class ProtocolSocket(object):
    '''Datagram socket bound to a single protocol number on a PULSE
    Connection.

    Payloads arriving for that protocol are queued in FIFO order until read
    with receive(). This class is also a reference implementation of a
    Connection protocol handler: it registers itself on construction and
    provides the required on_receive(payload) method.
    '''

    def __init__(self, connection, protocol):
        self.connection = connection
        self.protocol = protocol
        self.receive_queue = Queue.Queue()
        # The connection keeps only a weak reference to its handlers, so this
        # socket must be kept alive by its owner.
        connection.register_protocol_handler(protocol, self)

    def on_receive(self, frame):
        '''Queue a payload delivered by the connection.'''
        self.receive_queue.put(frame)

    def receive(self, block=True, timeout=None):
        '''Pop the oldest queued payload.

        `block` and `timeout` behave as in Queue.Queue.get().

        Raises:
            exceptions.ReceiveQueueEmpty: no payload became available.
        '''
        inbox = self.receive_queue
        try:
            payload = inbox.get(block, timeout)
        except Queue.Empty:
            raise exceptions.ReceiveQueueEmpty
        return payload

    def send(self, frame):
        '''Send `frame` as a datagram tagged with this socket's protocol.'''
        conn = self.connection
        conn.send(self.protocol, frame)

    @property
    def mtu(self):
        '''Maximum payload length reported by the underlying connection.'''
        return self.connection.mtu


if __name__ == '__main__':
    # Usage: <script> [URL]
    # Opens a PULSE connection on URL (a PySerial URL/device, or "qemu"). The
    # dbgserial TTY is autodetected when URL is omitted. The script then
    # switches to 921600 baud and sends 20 echo pings, printing each
    # round-trip time.
    logging.basicConfig(level=logging.INFO)
    url = sys.argv[1] if len(sys.argv) > 1 else None
    with Connection.open_dbgserial(url) as conn:
        conn.change_baud_rate(921600)
        for _ in xrange(20):
            time.sleep(0.5)
            send_time = time.time()
            if conn.ping():
                print("Ping rtt=%.2f ms" % ((time.time() - send_time) * 1000))
            else:
                print("No echo")