# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import struct
import time
from binascii import crc32
from random import randint

from hdlc import HDLCDecoder, hdlc_encode_data
from serial_port_wrapper import SerialPortWrapper

# Running CRC-32 over a frame that ends with its own little-endian CRC-32 (as
# _send_frame() builds them) yields a fixed residue, equal to the CRC-32 of
# four zero bytes. Frame.is_valid() compares against this value.
CRC_RESIDUE = crc32('\0\0\0\0')
# Seconds _read_frame() waits for a valid server frame before giving up.
READ_TIMEOUT = 1
ACCESSORY_CONSOLE_BAUD_RATE = 115200
ACCESSORY_IMAGING_BAUD_RATE = 921600


class AccessoryImagingError(Exception):
    """Raised when the accessory imaging exchange with the watch fails."""
    pass


class AccessoryImaging(object):
    """Client for the watch's accessory imaging protocol over a serial port.

    The watch is switched into imaging mode from its serial console (start()),
    after which HDLC-encoded frames are exchanged at ACCESSORY_IMAGING_BAUD_RATE.
    """

    # Number of times a command is sent before giving up on a response.
    COMMAND_RETRIES = 5
    # Maximum frame data length used when the server reports version >= 1.
    V1_MAX_DATA_LENGTH = 2048

    class Frame(object):
        """A raw frame exchanged with the imaging server.

        Layout: 1 flags byte, 1 opcode byte, payload, 4-byte little-endian
        CRC-32 over everything before it.
        """

        # Maximum frame data length for version 0 servers. This is never
        # mutated; see AccessoryImaging._max_data_length() for the
        # per-connection value.
        MAX_DATA_LENGTH = 1024

        FLAG_IS_SERVER = (1 << 0)
        FLAG_VERSION = (0b111 << 1)

        OPCODE_PING = 0x01
        OPCODE_DISCONNECT = 0x02
        OPCODE_RESET = 0x03
        OPCODE_FLASH_GEOMETRY = 0x11
        OPCODE_FLASH_ERASE = 0x12
        OPCODE_FLASH_WRITE = 0x13
        OPCODE_FLASH_CRC = 0x14
        OPCODE_FLASH_FINALIZE = 0x15
        OPCODE_FLASH_READ = 0x16

        REGION_PRF = 0x01
        REGION_RESOURCES = 0x02
        REGION_FW_SCRATCH = 0x03
        REGION_PFS = 0x04
        REGION_COREDUMP = 0x05

        FLASH_READ_FLAG_ALL_SAME = (1 << 0)

        def __init__(self, raw_data):
            self._data = raw_data

        def __repr__(self):
            if self.is_valid():
                return '<{}@{:#x}: opcode={}>' \
                    .format(self.__class__.__name__, id(self), self.get_opcode())
            else:
                return '<{}@{:#x}: INVALID>' \
                    .format(self.__class__.__name__, id(self))

        def is_valid(self):
            """Return a truthy value if the frame is long enough and its CRC checks out."""
            # minimum packet size is 6 (2 bytes of header and 4 bytes of checksum)
            return self._data and len(self._data) >= 6 and crc32(self._data) == CRC_RESIDUE

        def flag_is_server(self):
            """Return True if the frame was sent by the server (watch) side."""
            return bool(ord(self._data[0]) & self.FLAG_IS_SERVER)

        def flag_version(self):
            """Return the 3-bit protocol version carried in the flags byte."""
            return (ord(self._data[0]) & self.FLAG_VERSION) >> 1

        def get_opcode(self):
            """Return the opcode byte of the frame."""
            return ord(self._data[1])

        def get_payload(self):
            """Return the bytes between the 2-byte header and the 4-byte CRC."""
            return self._data[2:-4]

    class FlashBlock(object):
        """A chunk of image data destined for a given flash address.

        The block remembers its own CRC-32 so that a FLASH_CRC response can
        confirm it was written correctly.
        """

        def __init__(self, addr, data):
            self._addr = addr
            self._data = data
            # Masked to an unsigned 32-bit value to match the '<I' CRC the
            # server reports.
            self._crc = crc32(self._data) & 0xFFFFFFFF
            self._validated = False

        def get_write_payload(self):
            """Return the FLASH_WRITE payload: '<I' address followed by the data."""
            return struct.pack('<I', self._addr) + self._data

        def get_crc_payload(self):
            """Return the FLASH_CRC request entry: '<II' address and length."""
            return struct.pack('<II', self._addr, len(self._data))

        def validate(self, raw_response):
            """Mark the block validated if a 12-byte '<III' CRC entry covers and matches it."""
            addr, length, crc = struct.unpack('<III', raw_response)
            # check if this response completely includes this block
            if addr <= self._addr and (addr + length) >= self._addr + len(self._data):
                self._validated = (crc == self._crc)

        def is_validated(self):
            return self._validated

        def __repr__(self):
            return '<{}@{:#x}: addr={:#x}, length={}>' \
                .format(self.__class__.__name__, id(self), self._addr, len(self._data))

    def __init__(self, tty):
        self._serial = SerialPortWrapper(tty, None, ACCESSORY_CONSOLE_BAUD_RATE)
        self._hdlc_decoder = HDLCDecoder()
        # Version taken from the flags of the most recent valid server frame;
        # only known once a frame has been received (e.g. after ping()).
        self._server_version = 0

    def _max_data_length(self):
        """Return the maximum frame data length for the connected server.

        Must be called after a server frame has been received (e.g. after
        ping()), since that is when _server_version is learned. Kept per
        instance so one connection never changes the limit used by another.
        """
        if self._server_version >= 1:
            return self.V1_MAX_DATA_LENGTH
        return self.Frame.MAX_DATA_LENGTH

    def _send_frame(self, opcode, payload):
        """HDLC-encode and send a client frame (flags=0) with a trailing CRC-32."""
        data = struct.pack('<BB', 0, opcode)
        data += payload
        data += struct.pack('<I', crc32(data) & 0xFFFFFFFF)
        self._serial.write_fast(hdlc_encode_data(data))

    def _read_frame(self):
        """Return the next valid server frame, or None after READ_TIMEOUT seconds.

        Invalid frames and frames without the server flag are discarded.
        Updates _server_version from each accepted frame.
        """
        start_time = time.time()
        while True:
            # process any queued frames
            for frame_data in iter(self._hdlc_decoder.get_frame, None):
                frame = self.Frame(frame_data)
                if frame.is_valid() and frame.flag_is_server():
                    self._server_version = frame.flag_version()
                    return frame
            if (time.time() - start_time) > READ_TIMEOUT:
                return None
            self._hdlc_decoder.write(self._serial.read(0.001))

    def _command_and_response(self, opcode, payload=''):
        """Send a command and return the payload of the matching response.

        The command is sent up to COMMAND_RETRIES times, waiting up to
        READ_TIMEOUT seconds for a reply each time.

        Raises:
            AccessoryImagingError: if the reply has a different opcode, or if
                no reply arrives after all attempts.
        """
        # NOTE: this used to be `elif --retries == 0`, which is double negation
        # (not a decrement) in Python, so it retried forever.
        for _ in range(self.COMMAND_RETRIES):
            self._send_frame(opcode, payload)
            frame = self._read_frame()
            if frame is not None:
                if frame.get_opcode() != opcode:
                    raise AccessoryImagingError('ERROR: Got unexpected response ({:#x}, {})'
                                                .format(opcode, frame))
                return frame.get_payload()
        raise AccessoryImagingError('ERROR: Watch did not respond to request ({:#x})'
                                    .format(opcode))

    def _get_prompt(self):
        """Send Ctrl-C until the console answers with a '>' prompt (5 second limit)."""
        timeout = time.time() + 5
        while True:
            # we could be in stop mode, so send a few
            self._serial.write('\x03')
            self._serial.write('\x03')
            self._serial.write('\x03')
            read_data = self._serial.read()
            if read_data and read_data[-1] == '>':
                break
            time.sleep(0.5)
            if time.time() > timeout:
                raise AccessoryImagingError('ERROR: Timed-out connecting to the watch!')

    def start(self):
        """Enter accessory imaging mode from the console and switch to the imaging baud rate.

        The server version (and so the usable frame size) is not known until
        a frame has been exchanged; call ping() afterwards.
        """
        self._serial.s.baudrate = ACCESSORY_CONSOLE_BAUD_RATE
        self._get_prompt()
        self._serial.write_fast('accessory imaging start\r\n')
        self._serial.read()
        self._serial.s.baudrate = ACCESSORY_IMAGING_BAUD_RATE

    def ping(self):
        """Round-trip 10 random bytes and verify they are echoed back unchanged."""
        payload = ''.join(chr(randint(0, 255)) for _ in range(10))
        if self._command_and_response(self.Frame.OPCODE_PING, payload) != payload:
            raise AccessoryImagingError('ERROR: Invalid ping payload in response!')

    def disconnect(self):
        """Send DISCONNECT and return the serial port to the console baud rate."""
        self._command_and_response(self.Frame.OPCODE_DISCONNECT)
        self._serial.s.baudrate = ACCESSORY_CONSOLE_BAUD_RATE

    def reset(self):
        """Send the RESET command."""
        self._command_and_response(self.Frame.OPCODE_RESET)

    def flash_geometry(self, region):
        """Return (addr, length) of a flash region as reported by the server.

        Raises:
            AccessoryImagingError: if the region needs a newer server, or the
                response names a different region or a zero length.
        """
        if region == self.Frame.REGION_PFS or region == self.Frame.REGION_COREDUMP:
            # These regions require >= v1
            if self._server_version < 1:
                raise AccessoryImagingError('ERROR: Server does not support this region')
        payload = struct.pack('<B', region)
        response = self._command_and_response(self.Frame.OPCODE_FLASH_GEOMETRY, payload)
        response_region, addr, length = struct.unpack('<BII', response)
        if response_region != region or length == 0:
            raise AccessoryImagingError('ERROR: Did not get region information ({:#x})'
                                        .format(region))
        return addr, length

    def flash_erase(self, addr, length):
        """Erase [addr, addr+length), polling every 0.5 s until the server reports completion."""
        payload = struct.pack('<II', addr, length)
        while True:
            response = self._command_and_response(self.Frame.OPCODE_FLASH_ERASE, payload)
            response_addr, response_length, response_complete = struct.unpack('<IIB', response)
            if response_addr != addr or response_length != length:
                raise AccessoryImagingError('ERROR: Got invalid response (expected '
                                            '[{:#x},{:#x}], got [{:#x},{:#x}])'
                                            .format(addr, length, response_addr,
                                                    response_length))
            elif response_complete != 0:
                break
            time.sleep(0.5)
        time.sleep(1)

    def flash_write(self, block):
        """Send a FLASH_WRITE frame for a FlashBlock without waiting for a response."""
        self._send_frame(self.Frame.OPCODE_FLASH_WRITE, block.get_write_payload())

    def flash_crc(self, blocks):
        """Request CRCs for the given blocks; return one raw 12-byte entry per block."""
        payload = ''.join(x.get_crc_payload() for x in blocks)
        response = self._command_and_response(self.Frame.OPCODE_FLASH_CRC, payload)
        response_fmt = '<III'
        entry_size = struct.calcsize(response_fmt)
        num_entries = len(response) // entry_size
        if len(response) % entry_size != 0:
            raise AccessoryImagingError('ERROR: Invalid response length ({})'.format(len(response)))
        elif num_entries != len(blocks):
            raise AccessoryImagingError('ERROR: Invalid number of response entries ({})'
                                        .format(num_entries))
        responses = [response[i:i+entry_size] for i in xrange(0, len(response), entry_size)]
        assert len(responses) == len(blocks)
        return responses

    def flash_finalize(self, region):
        """Send FLASH_FINALIZE for a region and check the server echoes the same region."""
        payload = struct.pack('<B', region)
        response = self._command_and_response(self.Frame.OPCODE_FLASH_FINALIZE, payload)
        response_region = struct.unpack('<B', response)[0]
        if response_region != region:
            raise AccessoryImagingError('ERROR: Did not get correct region ({:#x})'.format(region))

    def flash_read(self, region, progress):
        """Connect, read an entire flash region, finalize and disconnect.

        Returns:
            A list of single-character strings, one per byte of the region.

        Raises:
            AccessoryImagingError: if the server is older than v1, or any
                read response is malformed or has the wrong length.
        """
        if progress:
            print('Connecting...')
        self.start()
        self.ping()

        # flash reading was added in v1
        if self._server_version < 1:
            raise AccessoryImagingError('ERROR: Server does not support reading from flash')

        addr, length = self.flash_geometry(region)

        if progress:
            print('Reading...')

        # The server version is known now (after ping), so the frame size is too.
        max_data_length = self._max_data_length()
        read_bytes = []
        last_percent = 0
        for offset in xrange(0, length, max_data_length):
            chunk_length = min(max_data_length, length - offset)
            data = struct.pack('<II', offset + addr, chunk_length)
            response = self._command_and_response(self.Frame.OPCODE_FLASH_READ, payload=data)
            if not response:
                raise AccessoryImagingError('ERROR: Invalid flash read response')
            # the first byte of the response is the flags (0th bit: repeat the single data byte)
            if bool(ord(response[0]) & self.Frame.FLASH_READ_FLAG_ALL_SAME):
                if len(response) != 2:
                    raise AccessoryImagingError('ERROR: Invalid flash read response')
                read_bytes.extend(response[1] * chunk_length)
            else:
                # A short/long chunk would silently misalign all later data.
                if len(response) - 1 != chunk_length:
                    raise AccessoryImagingError('ERROR: Invalid flash read response')
                read_bytes.extend(response[1:])
            if progress:
                # don't spam the progress (only every 5%)
                percent = (offset * 100) // length
                if percent >= last_percent + 5:
                    print('{}% of the data read'.format(percent))
                    last_percent = percent

        self.flash_finalize(region)
        self.disconnect()

        if progress:
            print('Done!')
        return read_bytes

    def flash_image(self, image, region, progress):
        """Connect, erase a region, write and CRC-verify an image, then finalize.

        Blocks that fail CRC validation are rewritten until all pass. The
        FW_SCRATCH region ends with a reset; other regions with a disconnect.

        Raises:
            AccessoryImagingError: if the image does not fit the region or
                the server misbehaves.
        """
        if progress:
            print('Connecting...')
        self.start()
        self.ping()

        addr, length = self.flash_geometry(region)
        if len(image) > length:
            raise AccessoryImagingError('ERROR: Image is too big! (size={}, region_length={})'
                                        .format(len(image), length))

        if progress:
            print('Erasing...')
        self.flash_erase(addr, length)

        # The server version is known now (after ping), so the frame size is too.
        max_data_length = self._max_data_length()

        total_blocks = []
        # the block size should be as big as possible, but we need to leave 4 bytes for the address
        block_size = max_data_length - 4
        for offset in xrange(0, len(image), block_size):
            total_blocks.append(self.FlashBlock(addr + offset, image[offset:offset+block_size]))

        if progress:
            print('Writing...')
        num_total = len(total_blocks)
        num_errors = 0
        pending_blocks = [x for x in total_blocks if not x.is_validated()]
        while len(pending_blocks) > 0:
            # We will split up the outstanding blocks into packets which should be as big as
            # possible, but are limited by the fact that the flash CRC response is 12 bytes per
            # block.
            packet_size = max_data_length // 12
            packets = []
            for i in xrange(0, len(pending_blocks), packet_size):
                packets += [pending_blocks[i:i+packet_size]]
            for packet in packets:
                # write each of the blocks
                for block in packet:
                    self.flash_write(block)

                # CRC each of the blocks
                crc_results = self.flash_crc(packet)
                for block, result in zip(packet, crc_results):
                    block.validate(result)

                # update the pending blocks
                pending_blocks = [x for x in total_blocks if not x.is_validated()]
                if progress:
                    percent = ((num_total - len(pending_blocks)) * 100) // num_total
                    num_errors += len([x for x in packet if not x.is_validated()])
                    print('{}% of blocks written ({} errors)'.format(percent, num_errors))

        self.flash_finalize(region)
        if region == self.Frame.REGION_FW_SCRATCH:
            self.reset()
        else:
            self.disconnect()

        if progress:
            print('Done!')