# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Command/response codecs and a client for the flash imaging protocol."""

import collections
import struct
import time

from . import exceptions
from . import socket


def _unpack_response(response, response_type, response_struct):
    """Check the type byte and size of ``response`` and unpack it.

    Args:
        response: the raw response packet (a byte string).
        response_type: the expected value of the first byte.
        response_struct: the ``struct.Struct`` describing the whole packet.

    Returns:
        The tuple produced by ``response_struct.unpack(response)``.

    Raises:
        exceptions.ResponseParseError: if the packet is empty, starts with a
            different type byte, or is not exactly ``response_struct.size``
            bytes long.
    """
    # Slicing instead of indexing gives a (possibly empty) byte string on
    # both Python 2 and Python 3; wrapping it in a bytearray then yields the
    # integer value of the type byte on either version.
    header = bytearray(response[:1])
    if not header or header[0] != response_type:
        raise exceptions.ResponseParseError(
            'Unexpected response: %r' % response)
    # struct.unpack requires an exact size match; report a mismatch as a
    # parse error instead of letting struct.error escape.
    if len(response) != response_struct.size:
        raise exceptions.ResponseParseError(
            'Unexpected response length %d (expected %d): %r' % (
                len(response), response_struct.size, response))
    return response_struct.unpack(response)


def _check_address_and_length(unpacked, address, length):
    """Raise ResponseParseError unless ``unpacked`` echoes the request.

    ``unpacked`` must expose ``address`` and ``length`` attributes; they are
    compared against the ``address`` and ``length`` that were sent.
    """
    if unpacked.address != address or unpacked.length != length:
        raise exceptions.ResponseParseError(
            'Response does not match command: '
            'address=%#.08x length=%d (expected %#.08x, %d)' % (
                unpacked.address, unpacked.length, address, length))


class EraseCommand(object):
    """Erase request (command type 1) for ``length`` bytes at ``address``."""

    command_type = 1
    # Little-endian: command type byte, address (uint32), length (uint32).
    command_struct = struct.Struct('<BII')

    response_type = 128
    # Little-endian: skipped type byte, address, length, complete flag.
    response_struct = struct.Struct('<xII?')
    Response = collections.namedtuple(
        'EraseResponse', 'address length complete')

    def __init__(self, address, length):
        self.address = address
        self.length = length

    @property
    def packet(self):
        """The serialized command, ready to be sent."""
        return self.command_struct.pack(
            self.command_type, self.address, self.length)

    def parse_response(self, response):
        """Parse an erase response and check it matches this command.

        Returns:
            An ``EraseResponse`` namedtuple ``(address, length, complete)``.

        Raises:
            exceptions.ResponseParseError: if the packet is malformed, has
                the wrong type, or reports a different address or length.
        """
        unpacked = self.Response._make(_unpack_response(
            response, self.response_type, self.response_struct))
        _check_address_and_length(unpacked, self.address, self.length)
        return unpacked


class WriteCommand(object):
    """Write request (command type 2) carrying ``data`` for ``address``."""

    command_type = 2
    # Little-endian: command type byte, address (uint32); data follows.
    command_struct = struct.Struct('<BI')
    header_len = command_struct.size

    def __init__(self, address, data):
        self.address = address
        self.data = data

    @property
    def packet(self):
        """The serialized header followed by the raw data bytes."""
        header = self.command_struct.pack(self.command_type, self.address)
        return header + self.data


class WriteResponse(object):
    """Parser for write acknowledgements (response type 129)."""

    response_type = 129
    # Little-endian: skipped type byte, address, length, complete flag.
    response_struct = struct.Struct('<xII?')
    Response = collections.namedtuple(
        'WriteResponse', 'address length complete')

    @classmethod
    def parse(cls, response):
        """Parse a write acknowledgement.

        Returns:
            A ``WriteResponse`` namedtuple ``(address, length, complete)``.

        Raises:
            exceptions.ResponseParseError: if the packet is malformed or has
                the wrong type.
        """
        return cls.Response._make(_unpack_response(
            response, cls.response_type, cls.response_struct))


class CrcCommand(object):
    """CRC request (command type 3) over ``length`` bytes at ``address``."""

    command_type = 3
    # Little-endian: command type byte, address (uint32), length (uint32).
    command_struct = struct.Struct('<BII')

    response_type = 130
    # Little-endian: skipped type byte, address, length, crc (all uint32).
    response_struct = struct.Struct('<xIII')
    Response = collections.namedtuple('CrcResponse', 'address length crc')

    def __init__(self, address, length):
        self.address = address
        self.length = length

    @property
    def packet(self):
        """The serialized command, ready to be sent."""
        return self.command_struct.pack(
            self.command_type, self.address, self.length)

    def parse_response(self, response):
        """Parse a CRC response and check it matches this command.

        Returns:
            A ``CrcResponse`` namedtuple ``(address, length, crc)``.

        Raises:
            exceptions.ResponseParseError: if the packet is malformed, has
                the wrong type, or reports a different address or length.
        """
        unpacked = self.Response._make(_unpack_response(
            response, self.response_type, self.response_struct))
        _check_address_and_length(unpacked, self.address, self.length)
        return unpacked


class QueryFlashRegionCommand(object):
    """Query (command type 4) for the geometry of a flash region."""

    command_type = 4
    # Little-endian: command type byte, region identifier byte.
    command_struct = struct.Struct('<BB')

    REGION_PRF = 1
    REGION_SYSTEM_RESOURCES = 2

    response_type = 131
    # Little-endian: skipped type byte, region byte, address, length.
    response_struct = struct.Struct('<xBII')
    Response = collections.namedtuple(
        'FlashRegionGeometry', 'region address length')

    def __init__(self, region):
        self.region = region

    @property
    def packet(self):
        """The serialized command, ready to be sent."""
        return self.command_struct.pack(self.command_type, self.region)

    def parse_response(self, response):
        """Parse a region geometry response.

        Returns:
            A ``FlashRegionGeometry`` namedtuple ``(region, address,
            length)``.

        Raises:
            exceptions.ResponseParseError: if the packet is malformed or has
                the wrong type.
            exceptions.RegionDoesNotExist: if both the reported address and
                length are zero.
        """
        unpacked = self.Response._make(_unpack_response(
            response, self.response_type, self.response_struct))
        if unpacked.address == 0 and unpacked.length == 0:
            raise exceptions.RegionDoesNotExist(self.region)
        return unpacked


class FinalizeFlashRegionCommand(object):
    """Finalize request (command type 5) for a flash region."""

    command_type = 5
    # Little-endian: command type byte, region identifier byte.
    command_struct = struct.Struct('<BB')

    response_type = 132
    # Little-endian: skipped type byte, region identifier byte.
    response_struct = struct.Struct('<xB')

    def __init__(self, region):
        self.region = region

    @property
    def packet(self):
        """The serialized command, ready to be sent."""
        return self.command_struct.pack(self.command_type, self.region)

    def parse_response(self, response):
        """Check that ``response`` acknowledges this command's region.

        Returns:
            None.

        Raises:
            exceptions.ResponseParseError: if the packet is malformed, has
                the wrong type, or names a different region.
        """
        region, = _unpack_response(
            response, self.response_type, self.response_struct)
        if region != self.region:
            raise exceptions.ResponseParseError(
                'Response does not match command: '
                'response is for region %d (expected %d)' % (
                    region, self.region))


class FlashImagingProtocol(object):
    """Client that drives flash imaging commands over a ProtocolSocket."""

    PROTOCOL_NUMBER = 0x02

    RESP_BAD_CMD = 192
    RESP_INTERNAL_ERROR = 193

    REGION_PRF = QueryFlashRegionCommand.REGION_PRF
    REGION_SYSTEM_RESOURCES = QueryFlashRegionCommand.REGION_SYSTEM_RESOURCES

    def __init__(self, connection):
        self.socket = socket.ProtocolSocket(connection, self.PROTOCOL_NUMBER)

    def erase(self, address, length):
        """Erase ``length`` bytes at ``address`` and wait for completion.

        The command is (re)sent whenever no response is outstanding. Once a
        response has been seen, further responses are awaited with a longer
        timeout until one is flagged complete.

        Raises:
            exceptions.CommandTimedOut: after 10 receive timeouts.
            exceptions.ResponseParseError: if a response is malformed or does
                not match the command.
        """
        cmd = EraseCommand(address, length)
        ack_received = False
        timeouts = 0
        while timeouts < 10:
            if not ack_received:
                self.socket.send(cmd.packet)
            try:
                packet = self.socket.receive(
                    timeout=5 if ack_received else 1.5)
            except exceptions.ReceiveQueueEmpty:
                # Nothing arrived in time: fall back to resending.
                ack_received = False
                timeouts += 1
                continue
            response = cmd.parse_response(packet)
            ack_received = True
            if response.complete:
                return
        raise exceptions.CommandTimedOut

    def write(self, address, data, max_retries=5, max_in_flight=5,
              progress_cb=None):
        """Write ``data`` starting at ``address`` using a sliding window.

        The data is split into segments that fit in the socket MTU. Up to
        ``max_in_flight`` segments are kept outstanding, and a segment that
        has not been acknowledged within 0.5 seconds is resent.

        Args:
            address: address of the first byte to write.
            data: the bytes to write.
            max_retries: how many times a single segment may be resent.
            max_in_flight: maximum number of unacknowledged segments.
            progress_cb: optional callable invoked with True for every
                acknowledged segment and False for every resend.

        Returns:
            The total number of resends that were needed.

        Raises:
            ValueError: if ``max_in_flight`` is less than 1.
            exceptions.WriteError: if the MTU leaves no room for data, an ACK
                is unexpected or inconsistent, or a segment exceeds
                ``max_retries``.
            exceptions.ResponseParseError: if an ACK packet is malformed.
        """
        if max_in_flight < 1:
            # With no window nothing could ever be sent and the loop below
            # would spin forever.
            raise ValueError('max_in_flight must be at least 1')
        mtu = self.socket.mtu - WriteCommand.header_len
        if mtu <= 0:
            raise exceptions.WriteError(
                'Socket MTU (%d) leaves no room for write data' % (
                    self.socket.mtu))

        unsent = collections.OrderedDict()
        for offset in range(0, len(data), mtu):
            seg_address = address + offset
            unsent[seg_address] = WriteCommand(
                seg_address, data[offset:offset + mtu])

        # Maps segment address -> (command, send time, retry count).
        in_flight = collections.OrderedDict()
        retries = 0
        while unsent or in_flight:
            # Process ACKs (if any)
            while True:
                try:
                    packet = self.socket.receive(block=False)
                except exceptions.ReceiveQueueEmpty:
                    break
                ack = WriteResponse.parse(packet)
                # NOTE(review): an ACK for a segment that has already been
                # acknowledged (e.g. after a resend) is reported as unknown
                # -- confirm this is the intended behaviour.
                entry = in_flight.get(ack.address)
                if entry is None:
                    raise exceptions.WriteError(
                        'Received ACK for an unknown segment: '
                        '%#.08x' % ack.address)
                cmd = entry[0]
                if len(cmd.data) != ack.length:
                    raise exceptions.WriteError(
                        'ACK length %d != data length %d' % (
                            ack.length, len(cmd.data)))
                if not ack.complete:
                    raise exceptions.WriteError(
                        'ACK for segment %#.08x is not marked complete' % (
                            ack.address))
                del in_flight[ack.address]
                if progress_cb:
                    progress_cb(True)

            # Retry any in_flight writes where the ACK has timed out
            to_retry = []
            timeout_time = time.time() - 0.5
            for seg_address, (_, send_time, _) in in_flight.items():
                if send_time > timeout_time:
                    # in_flight is an OrderedDict so iteration is in
                    # chronological order.
                    break
                to_retry.append(seg_address)
            retries += len(to_retry)
            for seg_address in to_retry:
                # Pop and re-insert so the entry moves to the end and
                # in_flight stays ordered by send time.
                cmd, _, retry_count = in_flight.pop(seg_address)
                if retry_count >= max_retries:
                    raise exceptions.WriteError(
                        'Segment %#.08x exceeded the max retry count (%d)' % (
                            seg_address, max_retries))
                self.socket.send(cmd.packet)
                in_flight[seg_address] = (cmd, time.time(), retry_count + 1)
                if progress_cb:
                    progress_cb(False)

            # Send out fresh segments
            while unsent and len(in_flight) < max_in_flight:
                _, cmd = unsent.popitem(last=False)
                self.socket.send(cmd.packet)
                in_flight[cmd.address] = (cmd, time.time(), 0)

            # Give other threads a chance to run
            time.sleep(0)
        return retries

    def _command_and_response(self, cmd, timeout=0.5):
        """Send ``cmd`` up to five times until a response is received.

        Returns:
            Whatever ``cmd.parse_response`` returns for the first response.

        Raises:
            exceptions.CommandTimedOut: if no attempt yields a response
                within ``timeout`` seconds.
        """
        for _ in range(5):
            self.socket.send(cmd.packet)
            try:
                packet = self.socket.receive(timeout=timeout)
            except exceptions.ReceiveQueueEmpty:
                continue
            return cmd.parse_response(packet)
        raise exceptions.CommandTimedOut

    def crc(self, address, length):
        """Return the CRC reported for ``length`` bytes at ``address``."""
        cmd = CrcCommand(address, length)
        return self._command_and_response(cmd, timeout=1).crc

    def query_region_geometry(self, region):
        """Return the ``FlashRegionGeometry`` reported for ``region``."""
        cmd = QueryFlashRegionCommand(region)
        return self._command_and_response(cmd)

    def finalize_region(self, region):
        """Send the finalize command for ``region`` and check its ACK."""
        cmd = FinalizeFlashRegionCommand(region)
        return self._command_and_response(cmd)