#!/usr/bin/python
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from analyze_mcu_flash_find_unclaimed import claim
from analyze_mcu_flash_config import *
import argparse
import binutils
import json
import os.path
import re
import sh


# Matches "<root>/<rest>" after an optional leading ".", "..", "/" prefix
# (e.g. "./", "../", "/"). Written as a raw string: in a plain literal "\." is
# an invalid escape sequence and warns on Python 3.
ROOT_PATH_SPLIT_RE = re.compile(r'\.?\.?/?([^/]+)/(.*)')


def split_root_path(path):
    """ Takes a file path and returns a tuple: (root, rest_of_path)
        For example: "src/tintin/main.c" -> ("src", "tintin/main.c")

        A leading "./", "../" or "/" is skipped before taking the root, e.g.
        "../src/fw/x.c" -> ("src", "fw/x.c"). Note that the pattern also drops
        up to two leading dots of the first component ("./hidden/x" and
        ".hidden/x" both yield root "hidden").

        If path contains no further "/" separator after the root, the whole
        path is returned as the root and rest_of_path is None:
        "main.c" -> ("main.c", None)
    """
    match = ROOT_PATH_SPLIT_RE.match(path)
    if match is None:
        return (path, None)
    # The pattern has exactly two groups, so this is a (root, rest) 2-tuple.
    return match.groups()


def tree_add_value(tree, path, value):
    """ Adds value as a leaf of tree at the location described by path.

        Every path component except the last becomes a nested dict (created
        if missing); the last component becomes the leaf key holding value.
        If that leaf already holds an int, value is added to it instead of
        overwriting it, so two equally-named symbols attributed to the same
        path (e.g. under "?") are both counted.

        tree is modified in place and also returned for convenience.
    """
    root, rest = split_root_path(path)
    if rest:
        # We haven't reached the leaf yet: descend into (or create) the
        # folder/file node. The recursive call mutates that dict in place,
        # so no merge back into tree is needed.
        tree_add_value(tree.setdefault(root, {}), rest, value)
    else:
        # Leaf is reached! Accumulate instead of overwriting so no bytes are
        # lost from the tree when the same name occurs more than once.
        existing = tree.get(root)
        if isinstance(existing, int):
            value = existing + value
        tree[root] = value
    return tree


def generate_tree(f, additional_symbols, config):
    """ Generates a tree based on the output of arm-none-eabi-nm. The tree its
        branches are the folders and files in the code base. The leaves of the
        tree are symbols and their sizes. Only symbols from .text are included.
        The tree is represented with dict() objects, where the folder, file or
        symbol name are keys. In case of a folder or file, another dict() will
        be the value. In case of a symbol, the value is an int() of its size.

        f: path of the ELF file handed to binutils.nm_generator().
        additional_symbols: mapping of source path -> collection of symbol
            names, used to attribute symbols for which nm reports no source.
        config: provides memory_region_to_analyze(), the initial region from
            which symbol address ranges are claimed.

        Returns (tree, total_size), where total_size is the sum of the sizes
        of all symbols added to the tree.
    """
    unclaimed_regions = {config.memory_region_to_analyze()}
    tree = {}
    total_size = 0
    for addr, section, symbol, src_path, _line, size in \
            binutils.nm_generator(f):
        if section != 't':
            # Not .text
            continue
        if not claim((addr, addr + size), unclaimed_regions, symbol):
            # Region is already claimed by another symbol
            continue
        if not src_path or src_path == '?':
            src_path = _guess_src_path(symbol, additional_symbols)
        tree = tree_add_value(tree, os.path.join(src_path, symbol), size)
        total_size += size
    return (tree, total_size)


def _guess_src_path(symbol, additional_symbols):
    """ Picks a source path for a symbol that nm could not attribute.

        The explicit symbol sets in additional_symbols are checked first; only
        if none contains the symbol is the sys_/syscall name heuristic applied
        (independently of whether any additional symbol sets exist).
        Falls back to '?' when nothing matches.
    """
    for lib_path in additional_symbols:
        if symbol in additional_symbols[lib_path]:
            return lib_path
    if symbol.startswith(('sys_', 'syscall')):
        return 'build/src/fw/syscall.auto.s'
    return '?'


def convert_tree_to_d3(parent_name, tree):
    """ Converts a tree as generated by generate_tree() to a dict() that
        can be converted to JSON to use with the d3.js graphing library.

        Each node becomes {'name': <key>}. A dict value adds a 'children'
        list of converted sub-nodes; an int value (a size) adds 'value'.

        Raises TypeError (an Exception subclass) for any other value type,
        including bool.
    """
    def convert_to_d3_node(name, val):
        node = {'name': name}
        if isinstance(val, dict):
            # .items() instead of the Python 2-only .iteritems().
            node['children'] = [convert_to_d3_node(k, v)
                                for k, v in val.items()]
        elif isinstance(val, int) and not isinstance(val, bool):
            node['value'] = val
        else:
            raise TypeError("Unexpected node type: %s, "
                            "parent_name=%s, val=%s" %
                            (str(type(val)), name, val))
        return node
    return convert_to_d3_node(parent_name, tree)


if __name__ == '__main__':
    # Command-line entry point: builds the .text treemap for an ELF file and
    # writes it as a JSONP payload next to this script.
    parser = argparse.ArgumentParser()
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument(
        '--config', default='tintin', choices=CONFIG_CLASSES.keys())
    parser.add_argument('elf_file', nargs='?')
    args = parser.parse_args()

    config_class = CONFIG_CLASSES[args.config]
    config = config_class()

    # Fall back to the config's default ELF when none is given.
    elf_file = args.elf_file or config.default_elf_abs_path()

    # Create the tree:
    lib_symbols = config.lib_symbols()
    tree, total_size = generate_tree(elf_file, lib_symbols, config)

    # Unclaimed is space for which no symbols were found.
    # Run analyze_mcu_flash_find_unclaimed.py to get a dump of these regions.
    text_size = binutils.size(elf_file)[0]
    unclaimed_size = text_size - total_size
    # Only a positive remainder is a meaningful treemap area; a negative leaf
    # value would corrupt the d3 layout.
    if unclaimed_size > 0:
        tree["Unclaimed"] = unclaimed_size

    config.apply_tree_tweaks(tree)

    # Convert to a structure that works with the d3.js graphing lib:
    d3_tree = convert_tree_to_d3('tintin', tree)

    # Dump to a .jsonp file (the JSON wrapped in a renderJson(...) call),
    # presumably consumed by the companion .html page named below.
    json_filename = 'analyze_mcu_flash_usage_treemap.jsonp'
    script_dir = os.path.dirname(os.path.realpath(__file__))
    json_path = os.path.join(script_dir, json_filename)

    # Context manager closes the file even if json.dump raises; text mode
    # makes writing str valid on both Python 2 and Python 3.
    with open(json_path, 'w') as file_out:
        file_out.write("renderJson(")
        json.dump(d3_tree, file_out)
        file_out.write(");")

    # Print out some stats. Single-argument print(...) produces the same
    # output as the Python 2 print statement and is also valid Python 3.
    print("Total .text bytes:         %u" % text_size)
    print("Total bytes mapped:        %u" % total_size)
    print("-------------------------------------")
    print("Unaccounted bytes:         %u" % unclaimed_size)
    print("")
    print("Now go open %s.html to view treemap" % os.path.splitext(__file__)[0])