# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import Queue
import struct
import sys
import threading
import time
import traceback
import uuid
import weakref

from cobs import cobs
import serial

from . import exceptions
import stm32_crc

logger = logging.getLogger(__name__)

try:
    import pyftdi.serialext
except ImportError:
    pass

# Keyword arguments passed to serial.serial_for_url() by
# Connection.open_dbgserial().
DBGSERIAL_PORT_SETTINGS = dict(baudrate=230400, timeout=0.1,
                               interCharTimeout=0.01)


def get_dbgserial_tty():
    '''Returns the result of pebble_tty.find_dbgserial_tty().

    Raises TTYAutodetectionUnavailable if an ImportError is raised while
    importing or calling into pebble_tty.
    '''
    # Local import so that we only depend on this package if we're attempting
    # to autodetect the TTY. This package isn't always available (e.g., MFG),
    # so we don't want it to be required.
    try:
        import pebble_tty
        return pebble_tty.find_dbgserial_tty()
    except ImportError:
        raise exceptions.TTYAutodetectionUnavailable


def frame_splitter(istream, size=1024, timeout=1, delimiter='\0'):
    '''Returns an iterator which yields complete frames.

    Reads up to `size` bytes at a time from `istream` for as long as
    `istream.closed` is false, splitting the data on `delimiter`. Empty
    frames (back-to-back delimiters) are dropped.

    If `timeout` is greater than zero, the iterator additionally yields
    None after every read once more than `timeout` seconds have passed
    since iteration began. The start time is never reset.
    '''
    partial = []
    start_time = time.time()
    while not istream.closed:
        data = istream.read(size)
        logger.debug('frame_splitter: received %r', data)
        while True:
            left, delim, data = data.partition(delimiter)
            if left:
                partial.append(left)
            if delim:
                # A delimiter terminates whatever has been accumulated so far.
                frame = ''.join(partial)
                partial = []
                if frame:
                    yield frame
            if not data:
                break
        if timeout > 0 and time.time() > start_time + timeout:
            yield


def decode_frame(frame):
    '''Decodes a PULSE frame.

    Returns a tuple (protocol, payload) of the decoded frame.
    Raises FrameDecodeError if the frame is not valid.
    '''
    try:
        data = cobs.decode(frame)
    except cobs.DecodeError as e:
        raise exceptions.FrameDecodeError(str(e))
    # One protocol byte plus a four-byte FCS is the minimum valid frame.
    if len(data) < 5:
        raise exceptions.FrameDecodeError('frame too short')
    fcs = struct.unpack('<I', data[-4:])[0]
    crc = stm32_crc.crc32(data[:-4])
    if fcs != crc:
        raise exceptions.FrameDecodeError('FCS 0x%.08x != CRC 0x%.08x' %
                                          (fcs, crc))
    protocol = ord(data[0])
    return (protocol, data[1:-4])


def encode_frame(protocol, payload):
    '''Encodes a PULSE frame.

    The frame is the protocol number packed as one unsigned byte followed
    by the payload and a little-endian 32-bit FCS (stm32_crc.crc32 of the
    preceding bytes). The result is COBS-encoded and does not include the
    frame delimiters.
    '''
    frame = struct.pack('<B', protocol)
    frame += payload
    fcs = stm32_crc.crc32(frame)
    frame += struct.pack('<I', fcs)
    return cobs.encode(frame)


class Connection(object):
    '''A socket for sending and receiving datagrams over the PULSE serial
    protocol.
    '''

    PROTOCOL_LLC = 0x01

    # LLC payloads sent by this side.
    LLC_LINK_OPEN_REQUEST = '\x01\x03\x08\x08\x08PULSEv1\r\n'
    LLC_LINK_CLOSE_REQUEST = '\x03'
    LLC_ECHO_REQUEST = '\x05'
    LLC_CHANGE_BAUD = '\x07'

    # LLC opcodes handled by llc_handler().
    LLC_LINK_OPENED = 0x02
    LLC_LINK_CLOSED = 0x04
    LLC_ECHO_REPLY = 0x06

    # Maps attribute name -> factory; shared by all Connection instances.
    EXTENSIONS = {}

    # Maximum round-trip time
    rtt = 0.4

    def __init__(self, iostream, infinite_reconnect=False):
        '''Opens a PULSE link over `iostream`.

        Starts the receive thread, opens the link, starts the keepalive
        thread and binds every registered extension. If
        `infinite_reconnect` is true, the keepalive thread keeps retrying
        after a failed reconnect instead of giving up.

        Raises PulseError if the link cannot be established.
        '''
        self.infinite_reconnect = infinite_reconnect
        self.iostream = iostream
        self.closed = False
        try:
            self.initial_port_settings = self.iostream.getSettingsDict()
        except AttributeError:
            self.initial_port_settings = None
        self.port_settings_altered = False

        # Whether the link is open for sending.
        self._link_open = threading.Event()
        # Whether the link has been severed.
        self._link_closed = threading.Event()

        self.send_lock = threading.RLock()

        self.echoes_inflight = weakref.WeakValueDictionary()
        self.protocol_handlers = weakref.WeakValueDictionary()

        self.receive_thread = threading.Thread(target=self.run_receive_thread)
        self.receive_thread.daemon = True
        self.receive_thread.start()

        self._open_link()

        self.keepalive_thread = threading.Thread(
                target=self.run_keepalive_thread)
        self.keepalive_thread.daemon = True
        self.keepalive_thread.start()

        # Instantiate and bind all known extensions
        for name, factory in self.EXTENSIONS.iteritems():
            setattr(self, name, factory(self))

    @classmethod
    def register_extension(cls, name, factory):
        '''Register a PULSE connection extension.

        When a Connection object is instantiated, the object returned by
        factory(connection_object) is assigned to connection_object.<name>.

        Raises ValueError if `name` is already an attribute of the class.
        '''
        try:
            getattr(cls, name)
        except AttributeError:
            cls.EXTENSIONS[name] = factory
        else:
            raise ValueError('extension name %r clashes with existing attribute'
                             % (name,))

    @classmethod
    def open_dbgserial(cls, url=None, infinite_reconnect=False):
        '''Opens a Connection over the serial port named by `url`.

        If `url` is None, the TTY is autodetected with get_dbgserial_tty().
        The special value "qemu" selects 'socket://localhost:12345'.
        '''
        if url is None:
            url = get_dbgserial_tty()
        if url == "qemu":
            url = 'socket://localhost:12345'
        ser = serial.serial_for_url(url, **DBGSERIAL_PORT_SETTINGS)

        if url.startswith('socket://'):
            # Socket class for PySerial does some pointless buffering
            # setting a very small timeout effectively negates it
            ser._timeout = 0.00001

        return cls(ser, infinite_reconnect=infinite_reconnect)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def send(self, protocol, payload):
        '''Encodes and writes one frame to the I/O stream.

        The encoded frame is written with a '\\0' delimiter on each side.
        Raises PulseError if the connection is closed.
        '''
        if self.closed:
            raise exceptions.PulseError('I/O operation on closed connection')
        frame = ''.join(('\0', encode_frame(protocol, payload), '\0'))
        logger.debug('Connection: sending %r', frame)
        with self.send_lock:
            self.iostream.write(frame)

    def run_receive_thread(self):
        '''Receive loop: decodes incoming frames and dispatches them.

        Frames that fail to decode are skipped. LLC frames go to
        llc_handler(); other frames go to the handler registered for their
        protocol, or to default_receiver() if there is none. Any other
        exception raised while reading or decoding ends the loop.
        '''
        logger.debug('Connection: receive thread started')
        receiver = frame_splitter(self.iostream, timeout=0)
        while True:
            try:
                protocol, payload = decode_frame(next(receiver))
            except exceptions.FrameDecodeError:
                continue
            except:
                # Probably a PySerial exception complaining about reading from a
                # closed port. Eat the exception and shut down the thread; users
                # don't need to see the stack trace.
                logger.debug('Connection: exception in receive thread:\n%s',
                             traceback.format_exc())
                break
            logger.debug('Connection:run_receive_thread: '
                         'protocol=%d payload=%r', protocol, payload)
            if protocol == self.PROTOCOL_LLC:  # LLC can't be overridden
                self.llc_handler(payload)
                continue

            try:
                handler = self.protocol_handlers[protocol]
            except KeyError:
                self.default_receiver(protocol, payload)
            else:
                handler.on_receive(payload)
        logger.debug('Connection: receive thread exiting')

    def default_receiver(self, protocol, frame):
        '''Logs a frame whose protocol has no registered handler.'''
        logger.info('Connection:default_receiver: received frame '
                    'with protocol %d: %r', protocol, frame)

    def register_protocol_handler(self, protocol, handler):
        '''Register a handler for frames bearing the specified protocol number.

        handler.on_receive(payload) is called for each frame received with the
        protocol number.

        Protocol handlers can be unregistered by calling this function with a
        handler of None.

        Handlers are held in a WeakValueDictionary, so the caller must keep
        its own reference to the handler for the registration to persist.

        Raises ProtocolAlreadyRegistered if the protocol already has a
        handler, and ValueError if the handler has no on_receive attribute.
        '''
        if not handler:
            try:
                del self.protocol_handlers[protocol]
            except KeyError:
                pass
            return

        if protocol in self.protocol_handlers:
            raise exceptions.ProtocolAlreadyRegistered(
                    'Protocol %d is already registered by %r' % (
                        protocol, self.protocol_handlers[protocol]))

        if not hasattr(handler, 'on_receive'):
            raise ValueError('%r does not have an on_receive method'
                             % (handler,))

        self.protocol_handlers[protocol] = handler

    def llc_handler(self, frame):
        '''Handles one received LLC payload.

        The first byte is the opcode. Empty payloads and link-opened
        payloads that do not unpack are logged and dropped, because an
        exception raised here would terminate the receive thread.
        '''
        if not frame:
            logger.warning('Received empty LLC frame')
            return
        opcode = ord(frame[0])
        if opcode == self.LLC_LINK_OPENED:
            # MTU and MRU are from the perspective of this side of the
            # connection
            try:
                version, mru, mtu, timeout = struct.unpack('<xBHHB', frame)
            except struct.error:
                logger.warning('Received malformed LLC link-opened '
                               'frame: %r', frame)
                return
            self.version = version
            # The server reports the MTU inclusive of protocol number and FCS,
            # but we only care about the maximum payload length.
            self.mtu = mtu - 5
            self.mru = mru
            # Timeout is specified in deciseconds. Convert to seconds.
            self.timeout = timeout / 10.0
            self._link_closed.clear()
            self._link_open.set()
        elif opcode == self.LLC_LINK_CLOSED:
            logger.info('PULSE connection closed.')
            self._link_closed.set()
        elif opcode == self.LLC_ECHO_REPLY:
            self._on_echo_reply(frame[1:])
        else:
            logger.warning('Received LLC frame with unknown type %d: %r',
                           opcode, frame)

    def run_keepalive_thread(self):
        '''The keepalive thread monitors the link, reopening it if necessary.

        The loop exits as soon as the link-open event is found cleared,
        which happens after close() or after a reconnect is abandoned.
        '''
        logger.debug('Connection: keepalive thread started')

        OPEN, TEST_LIVENESS, RECONNECT = range(3)
        state = OPEN
        next_state = state

        ping_attempts = 0
        ping_wait = self.rtt

        while True:
            # Check whether the link is being closed from our side before
            # trying to keep it alive.
            if not self._link_open.is_set():
                return

            if state == OPEN:
                time.sleep(1)
                if self._link_closed.is_set():
                    next_state = RECONNECT
                else:
                    next_state = TEST_LIVENESS
            elif state == TEST_LIVENESS:
                if ping_attempts < 3:
                    ping_attempts += 1
                    ping_wait *= 2  # Exponential backoff
                    if self.ping(ping_wait):
                        next_state = OPEN
                    else:
                        logger.info('No response to keepalive ping -- '
                                    'strike %d', ping_attempts)
                else:
                    logger.info('Connection: keepalive timed out.')
                    next_state = RECONNECT
            elif state == RECONNECT:
                # Lock out everyone from sending so that applications don't send
                # to a connection that's in an indeterminate state.
                with self.send_lock:
                    if self.port_settings_altered:
                        # Ensure that the server has timed out and reset its
                        # baud rate so we don't get into the bad situation where
                        # we try to reconnect at the default baud rate but the
                        # server is listening at a different rate, which is
                        # practically guaranteed to fail.
                        logger.info('Letting connection time out before '
                                    'attempting to reconnect.')
                        time.sleep(self.timeout + self.rtt)
                    self._link_open.clear()
                    while not self._link_open.is_set():
                        try:
                            self._open_link()
                        except exceptions.PulseError as e:
                            logger.warning('Connection: reconnect failed. %s',
                                           e)
                            if not self.infinite_reconnect:
                                break
                            logger.warning('Will try again.')
                            logger.info('Backing off for a while before '
                                        'retrying.')
                            time.sleep(self.timeout + self.rtt)
                    else:
                        # Only reached when the loop ended without `break`,
                        # i.e. the link was reopened.
                        next_state = OPEN
            else:
                assert False, 'Invalid state %d' % state

            if next_state != state:
                if next_state == TEST_LIVENESS:
                    # Start each liveness test with fresh strike/backoff
                    # counters.
                    ping_attempts = 0
                    ping_wait = self.rtt
                state = next_state

    def _open_link(self):
        '''Sends link-open requests until the link-open event is set.

        Restores the initial port settings first, if any were captured.
        Makes up to five attempts, waiting `rtt` seconds after each.
        Raises PulseError, and marks the connection closed, if none of
        them succeeds.
        '''
        self.closed = False
        if self.initial_port_settings:
            self.iostream.applySettingsDict(self.initial_port_settings)
        for attempt in xrange(5):
            logger.info('Opening link (attempt %d)...', attempt)
            self.send(self.PROTOCOL_LLC, self.LLC_LINK_OPEN_REQUEST)
            if self._link_open.wait(self.rtt):
                logger.info('Established PULSE connection!')
                logger.info('Version=%d MTU=%d MRU=%d Timeout=%.1f',
                            self.version, self.mtu, self.mru, self.timeout)
                break
        else:
            self._link_closed.set()
            self.closed = True
            raise exceptions.PulseError('Could not establish connection')

    def close(self):
        '''Closes the link and the underlying I/O stream.

        If the link is not already closed, sends up to three link-close
        requests, waiting `rtt` seconds after each for the link-closed
        event to be set.
        '''
        self._link_open.clear()
        if not self._link_closed.is_set():
            for attempt in xrange(3):
                self.send(self.PROTOCOL_LLC, self.LLC_LINK_CLOSE_REQUEST)
                if self._link_closed.wait(self.rtt):
                    break
            else:
                logger.warning('Could not confirm link close.')
                self._link_closed.set()
        self.iostream.close()
        self.closed = True

    def ping(self, timeout=None):
        '''Sends an LLC echo request carrying a random nonce.

        Returns the result of waiting up to `timeout` seconds (default
        2 * rtt) for the matching echo reply.
        '''
        if not timeout:
            timeout = 2 * self.rtt
        nonce = uuid.uuid4().bytes
        is_received = threading.Event()
        # Weak-value mapping: the entry disappears once this call returns and
        # drops its reference to the event.
        self.echoes_inflight[nonce] = is_received
        self.send(self.PROTOCOL_LLC, self.LLC_ECHO_REQUEST + nonce)
        return is_received.wait(timeout)

    def _on_echo_reply(self, payload):
        '''Signals the ping waiting on `payload`; unknown nonces are ignored.
        '''
        try:
            receive_event = self.echoes_inflight[payload]
            receive_event.set()
        except KeyError:
            pass

    def change_baud_rate(self, new_baud):
        '''Sends an LLC change-baud request, then switches the local port to
        `new_baud`.
        '''
        # Fail fast if the IO object doesn't support changing the baud rate
        old_baud = self.iostream.baudrate
        self.send(self.PROTOCOL_LLC,
                  self.LLC_CHANGE_BAUD + struct.pack('<I', new_baud))
        # Be extra sure that the message has been sent and it's safe to adjust
        # the baud rate on the port.
        time.sleep(0.1)
        self.iostream.baudrate = new_baud
        self.port_settings_altered = True


class ProtocolSocket(object):
    '''A socket for sending and receiving datagrams of a single protocol over a
    PULSE connection.

    It is also an example of a Connection protocol handler implementation.
    '''

    def __init__(self, connection, protocol):
        '''Registers this socket as the handler for `protocol` on
        `connection`.
        '''
        self.connection = connection
        self.protocol = protocol
        self.receive_queue = Queue.Queue()
        self.connection.register_protocol_handler(protocol, self)

    def on_receive(self, frame):
        '''Queues a received frame for a later receive() call.'''
        self.receive_queue.put(frame)

    def receive(self, block=True, timeout=None):
        '''Returns the next queued frame.

        `block` and `timeout` are passed to Queue.get(). Raises
        ReceiveQueueEmpty if no frame is available.
        '''
        try:
            return self.receive_queue.get(block, timeout)
        except Queue.Empty:
            raise exceptions.ReceiveQueueEmpty

    def send(self, frame):
        '''Sends `frame` over the connection using this socket's protocol.'''
        self.connection.send(self.protocol, frame)

    @property
    def mtu(self):
        return self.connection.mtu


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    with Connection.open_dbgserial(sys.argv[1]) as sock:
        sock.change_baud_rate(921600)
        for _ in xrange(20):
            time.sleep(0.5)
            send_time = time.time()
            if sock.ping():
                print("Ping rtt=%.2f ms" % ((time.time() - send_time) * 1000))
            else:
                print("No echo")