Tools: Add python modules

Add generic python modules to TF-M to support build and image packaging
tools.

* arg_utils provides helpers to deal with argparse arguments
* c_include gets the list of include paths for a file from the
  `compile_commands.json` build database
* c_macro includes a python implementation of (most of) the C
  preprocessor.
* c_struct is a libclang-based evaluator of C datastructures (Including
  enums) which can be used to generate python representations of nested
  C datastructures which rely on complex macro configuration.
* crypto_conversion_utils provides helpers to convert various types of
  crypto keys to different and convert string representations of
  algorithms and hash functions to their python objects
* encrypt_data provides functions to encrypt bytes() objects
* file_loader provides automatic handler functions for various filetypes
  based on their extensions, primarily useful for loading crypto keys
* key_derivation provides a python implementation of HKDF and a
  SP200-108 CMAC KDF, both matching the TF-M/MbedTLS/CC3XX implementation
* sign_data provides functions to perform symmetric and asymmetric
  signatures of bytes() objects
* sign_then_encrypt_data provides combined signing and encryption,
  either via symmetric AEAD modes or a combination of the sign_data and
  encrypt_data modules
* struct_pack provides helper functions for packing bytes objects
  together.

Change-Id: I858dd8ef69c9069ec0a44e4ad3f9a1d70cc5d4da
Signed-off-by: Raef Coles <raef.coles@arm.com>
diff --git a/tools/modules/c_struct.py b/tools/modules/c_struct.py
new file mode 100644
index 0000000..4ae865f
--- /dev/null
+++ b/tools/modules/c_struct.py
@@ -0,0 +1,650 @@
+#!/usr/bin/env python3
+#-------------------------------------------------------------------------------
+# SPDX-FileCopyrightText: Copyright The TrustedFirmware-M Contributors
+#
+# SPDX-License-Identifier: BSD-3-Clause
+#
+#-------------------------------------------------------------------------------
+
+import clang.cindex as cl
+import struct
+from collections import Counter
+import os
+
+import logging
+logger = logging.getLogger("TF-M")
+
+from rich import inspect
+
+pad = 4
+
+format_chars = {
+    'uint8_t'     : 'B',
+    'uint16_t'    : 'H',
+    'uint32_t'    : 'I',
+    'uint64_t'    : 'Q',
+    'bool'        : '?',
+    'char'        : 'b',
+    'short'       : 'H',
+    'int'         : 'I',
+    'long'        : 'I',
+    'long long'   : 'Q',
+    'uintptr_t'   : 'I',
+    'size_t'      : 'I',
+    'float'       : 'f',
+    'double'      : 'd',
+    'long double' : 'd',
+}
+
+def _c_struct_or_union_from_cl_cursor(cursor, name, f):
+        def is_field(x):
+            if x.kind == cl.CursorKind.FIELD_DECL:
+                return True
+            if x.kind in [cl.CursorKind.STRUCT_DECL, cl.CursorKind.UNION_DECL]:
+                return True
+            return False
+
+        fields = [c for c in cursor.get_children() if is_field(c)]
+        packed  = True if [c for c in cursor.get_children() if c.kind == cl.CursorKind.PACKED_ATTR] else False
+
+        c_type = cursor.spelling
+
+        fields = [_c_from_cl_cursor(f) for f in fields]
+        fields = [f for f in fields if f.get_size() != 0]
+
+        duplicate_types = Counter([f.c_type for f in fields if not isinstance(f, C_variable)])
+        duplicate_types = [k for k,v in duplicate_types.items() if v > 1]
+        fields = [f for f in fields if f.c_type not in duplicate_types or f.name != ""]
+
+        return f(name, c_type, fields, packed)
+
+def _pad_lines(text, size):
+    lines = text.split("\n")
+    lines = [" " * size + l for l in lines]
+    return "\n".join(lines)
+
+def _c_struct_or_union_get_value_str(self, struct_or_union):
+    string = "{\n"
+    fields_string = ""
+    for f in self._fields:
+        if f.to_bytes() != bytes(f.get_size()):
+            fields_string += _pad_lines(".{} = {},".format(f.name, f.get_value_str()), pad)
+    string += fields_string
+    string += "\n}"
+    return string
+
+def _c_struct_or_union_str(self, struct_or_union):
+        string = ""
+        string += "{} ".format(struct_or_union)
+        string += "__attribute__((packed)) " if self._packed else ""
+        string += "{} ".format(self.c_type) if self.c_type != "" else ""
+        string += "{\n"
+
+        fields_string = "\n".join([_pad_lines(str(f), pad) for f in self._fields])
+        string += fields_string
+
+        if string[-1] != "\n":
+            string += "\n"
+
+        string += "}"
+        string += " {}".format(self.name) if self.name != "" else ""
+        string += ";"
+        return string
+
+def _c_struct_or_union_get_field_strings(self):
+    field_strings=[]
+
+    for f in self._fields:
+        name_format = "{}.".format(self.name) if self.name != "" else ""
+        field_strings += [name_format + s for s in f.get_field_strings()]
+
+    return field_strings
+
+def _c_struct_or_union_field_to_bytes(field):
+    if isinstance(field, C_variable):
+        return field._get_format_string(), field.get_value()
+    else:
+        return _binary_format_string(field.get_size()), field.to_bytes()
+
+def _c_struct_or_union_get_next_in_path(self, field_path):
+    field_separator = "."
+
+    if field_separator in field_path:
+        field_name, remainder = field_path.split(".", 1)
+    else:
+        field_name = field_path
+        remainder = None
+
+    try:
+        # Fields aren't allowed to duplicate names
+        field = [f for f in self._fields if f.name == field_name][0]
+        return field, remainder
+    except (KeyError, IndexError):
+        for f in self._fields:
+            try:
+                f.get_field(field_path)
+                return f, field_path
+            except KeyError:
+                continue
+    raise KeyError
+
+def _c_struct_or_union_get_direct_members(self):
+    subfields = []
+    for f in self._fields:
+        if f.name:
+            continue
+        subfields += _c_struct_or_union_get_direct_members(f)
+
+    return [f for f in self._fields if f.name] + subfields
+
+def _c_struct_or_union_get_field(self, field_path):
+    fields = [f for f in self._fields if f.name == field_path[:len(f.name)]]
+    for f in fields:
+        try:
+            remainder = field_path.replace(f.name + ".", "") if f.name else field_path
+            return f.get_field(remainder)
+        except (KeyError, ValueError):
+            continue
+    raise KeyError
+
+
+def _binary_format_string(size):
+    return "{}s".format(size)
+
+class C_enum:
+    def __init__(self, name, c_type, members):
+        self.name = name.replace("enum", "").lstrip()
+        self._members = members
+        self.dict = {m.name:m for m in self._members}
+        self.__dict__.update({m.name:m for m in self._members})
+
+    @staticmethod
+    def from_h_file(h_file, name, includes = [], defines = []):
+        return _c_from_h_file(h_file, name, includes, defines, C_enum, cl.CursorKind.ENUM_DECL)
+
+    @staticmethod
+    def from_cl_cursor(cursor, name=""):
+        definition = cursor.get_definition()
+        assert(definition.kind == cl.CursorKind.ENUM_DECL)
+
+        max_value = max([x.enum_value for x in definition.get_children()])
+        if (max_value <= 2^8):
+            c_type = 'uint8_t'
+        elif (max_value <= 2^16):
+            c_type = 'uint16_t'
+        else:
+            c_type = 'uint32_t'
+
+        members = [C_enum_member(x.spelling, c_type, x.enum_value) for x in definition.get_children()]
+        members = [m for m in members if m.name[0] != '_']
+        return C_enum(cursor.spelling, c_type, members)
+
+
+    def __str__(self):
+        pad = 4
+
+        string = ""
+
+        string += "enum "
+        string += "{} ".format(self.name)
+        string += "{\n"
+
+        fields_string = "\n".join([_pad_lines(str(f), pad) for f in self._members])
+        string += fields_string
+
+        if string[-1] != "\n":
+            string += "\n"
+
+        string += "};"
+        return string
+
+class C_enum_member:
+    def __init__(self, name, c_type, value):
+        self.name = name
+        self.c_type = c_type
+        self.value = value
+        self._format_string = self._get_format_string()
+        self._size = struct.calcsize(self._get_format_string());
+
+    def _get_format_string(self):
+        return format_chars[self.c_type]
+
+    def get_value(self):
+        return self.value
+
+    def to_bytes(self):
+        return struct.pack("<" + self._format_string, self.get_value())
+
+    def get_size(self):
+        return self._size
+
+    def __str__(self):
+        return self.name
+
+    def __repr__(self):
+        return self.name
+
+class C_array:
+    def __init__(self, name, c_type, members):
+        self.name = name
+        self.c_type = c_type
+        self._members = members
+
+        if self._members:
+            self._dimensions = [len(self._members)]
+            if isinstance(self._members[0], C_array):
+                self._dimensions += self._members[0]._dimensions
+        else:
+            self._dimensions = [0]
+
+        self._size = struct.calcsize(self._get_format_string());
+
+    @staticmethod
+    def from_h_file(h_file, name, includes = [], defines = []):
+        return _c_from_h_file(h_file, name, includes, defines, C_array, cl.CursorKind.TYPE_DECL)
+
+    @staticmethod
+    def from_cl_cursor(cursor, name="", dimensions = []):
+        c_type = cursor.type.spelling
+        c_type = c_type.replace("unsigned", "")
+        c_type = c_type.replace("volatile", "")
+        c_type = c_type.replace("const", "")
+        c_type = c_type.replace("static", "")
+
+        if not dimensions:
+            c_type, *dimensions = c_type.split("[")
+            dimensions = [0 if d == "]" else int(d.replace("]", "")) for d in dimensions]
+        c_type = c_type.strip()
+
+        assert(len(dimensions) > 0)
+
+        if (len(dimensions) > 1):
+            f = lambda x:C_array.from_cl_cursor(cursor, x, dimensions[1:])
+        elif c_type not in format_chars.keys():
+            f = lambda x:_c_from_cl_cursor(list(cursor.get_children())[0].get_definition(), x)
+        else:
+            f = lambda x:C_variable(x, c_type)
+
+        members = [f(name + "_{}".format(i)) for i in range(dimensions[0])]
+
+        return C_array(name, c_type, members)
+
+    def __getitem__(self, index):
+        return self._members[index]
+
+    # def __setitem__(self, index, value):
+    #     return self._members[index].set_value_from_bytes(value)
+
+    def _get_format_string(self):
+        return "<" + "".join([_c_struct_or_union_field_to_bytes(m)[0] for m in self._members])
+
+    def get_value(self):
+        return self.to_bytes()
+
+    def set_value(self, value):
+        self.set_value_from_bytes(value)
+
+    def set_value_from_bytes(self, value):
+        assert(len(value) <= self._size), "{} of size {} cannot be set to value {} of size {}".format(self, self._size, value.hex(), len(value))
+        value_used = 0
+        for m in self._members:
+            if (value_used == len(value)):
+                break
+
+            m.set_value_from_bytes(value[value_used:value_used + m.get_size()])
+            value_used += m.get_size()
+
+    def get_field_strings(self):
+        return [self.name] + [f for m in self._members if not isinstance(m, C_variable) for f in m.get_field_strings()]
+
+    def get_field(self, field_path):
+        if field_path == self.name:
+            return self
+
+        field_path = field_path.replace(self.name + "_", "")
+
+        splits = [x for x in [field_path.find('.'), field_path.find("_")] if x != -1]
+
+        if not splits:
+            index = field_path
+            remainder = None
+        else:
+            split_idx = min(splits)
+            index = field_path[:split_idx]
+            remainder = field_path[split_idx + 1:]
+
+        index = int(index)
+
+        if (remainder):
+            return self._members[index].get_field(remainder)
+        else:
+            return self._members[index]
+
+    def to_bytes(self):
+        format_str = ""
+        values = []
+
+        for m in self._members:
+            field_string, field_data = _c_struct_or_union_field_to_bytes(m)
+            format_str += field_string
+            values.append(field_data)
+
+        return struct.pack(format_str, *values)
+
+    def get_size(self):
+        return self._size
+
+    def get_value_str(self):
+        string = ""
+        if self.to_bytes() != bytes(self.get_size()):
+            string += "{\n"
+            m_string = ""
+            for m in self._members:
+                m_string += m.get_value_str() + ", "
+            string += _pad_lines(m_string, pad)
+            string += "\n}"
+        return string
+
+    def __str__(self):
+        string = "{} {}".format(self.c_type, self.name)
+        string += "".join(["[{}]".format(a) for a in self._dimensions])
+
+        if self.to_bytes() != bytes(self.get_size()):
+            string += " = " + self.get_value_str()
+
+        string += ";"
+        return string
+
+class C_variable:
+    def __init__(self, name, c_type, value = None):
+        self.name = name
+        self.c_type = c_type
+        self._format_string = self._get_format_string()
+        self._size = struct.calcsize(self._format_string)
+        self.value = value
+
+    @staticmethod
+    def from_h_file(h_file, name, includes = [], defines = []):
+        return _c_from_h_file(h_file, name, includes, defines, C_variable, cl.CursorKind.TYPE_DECL)
+
+    @staticmethod
+    def from_cl_cursor(cursor, name=""):
+        c_type = cursor.type.spelling
+        c_type = c_type.replace("unsigned", "")
+        c_type = c_type.replace("volatile", "")
+        c_type = c_type.replace("const", "")
+        c_type = c_type.replace("static", "")
+
+        if "[" in c_type:
+            return C_array.from_cl_cursor(cursor, name)
+
+        return C_variable(name, c_type)
+
+    def _get_format_string(self):
+        if 'enum' in self.c_type:
+            return format_chars['uint32_t']
+
+        return format_chars[self.c_type]
+
+    def get_value(self):
+        value = self.value if self.value else 0
+        return value
+
+    def set_value(self, value):
+        if isinstance(value, str):
+            self.value = int(value, 0)
+        else:
+            self.value = value
+
+        #Sanity check the value
+        self.to_bytes()
+
+    def set_value_from_bytes(self, value):
+        self.set_value(struct.unpack(self._format_string, value)[0])
+
+    def get_field_strings(self):
+        return [self.name]
+
+    def get_field(self, field_path):
+        return self
+
+    def to_bytes(self):
+        return struct.pack("<" + self._format_string, self.get_value())
+
+    def get_size(self):
+        return self._size
+
+    def get_value_str(self):
+        return hex(self.value)
+
+    def __str__(self):
+        string = "{} {}".format(self.c_type, self.name)
+
+        if self.value != None:
+            string += " = {}".format(self.get_value_str())
+
+        string += ";"
+        return string
+
+class C_union:
+    def __init__(self, name="", c_type="", fields=[], packed=False):
+        self.name = name
+        self.c_type = c_type.replace("union ", "")
+        self._fields = fields
+        self._packed = packed
+        self._format_strings = self._get_format_strings()
+        self._size = max(map(struct.calcsize, self._format_strings))
+        self._actual_value = bytes(self._size)
+        self.__dict__.update({f.name:f for f in _c_struct_or_union_get_direct_members(self)})
+
+    @staticmethod
+    def from_h_file(h_file, name, includes = [], defines = []):
+        return _c_from_h_file(h_file, name, includes, defines, C_union, cl.CursorKind.UNION_DECL)
+
+    @staticmethod
+    def from_cl_cursor(cursor, name=""):
+        return _c_struct_or_union_from_cl_cursor(cursor, name, C_union)
+
+    def _ensure_consistency(self):
+        binaries = [f.to_bytes() for f in self._fields]
+        new_binaries = [b for b in binaries if b != self._actual_value[:len(b)]]
+        new_binaries = list(set(new_binaries))
+
+        assert(len(new_binaries) < 2)
+
+        if new_binaries:
+            self._actual_value = list(new_binaries)[0]
+
+        for f in self._fields:
+            f.set_value_from_bytes(self._actual_value[:f._size])
+
+    def _get_format_strings(self):
+        format_endianness = "<" if self._packed else "@"
+        return ["{}{}".format(format_endianness, _binary_format_string(f._size)) for f in self._fields]
+
+    def set_value_from_bytes(self, value):
+        for f in self._fields:
+            f.set_value_from_bytes(value)
+        self._ensure_consistency()
+
+    def set_value(self, field_path, value):
+        field.set_value(self.get_field(field_path))
+        self._ensure_consistency()
+
+    def get_value(self, field_path):
+        self._ensure_consistency()
+        return self.get_field(field_path).value
+
+    def get_field_strings(self):
+        self._ensure_consistency()
+        return _c_struct_or_union_get_field_strings(self)
+
+    def get_field(self, field_path):
+        self._ensure_consistency()
+        return _c_struct_or_union_get_field(self, field_path)
+
+    def to_bytes(self):
+        self._ensure_consistency()
+        binaries = [f.to_bytes() for f in self._fields]
+
+        binaries = set(binaries)
+        return max(list(binaries))
+
+    def get_size(self):
+        return self._size
+
+    def get_value_str(self):
+        self._ensure_consistency()
+        return _c_struct_or_union_get_value_str(self, "union")
+
+    def __str__(self):
+        self._ensure_consistency()
+        return _c_struct_or_union_str(self, "union")
+
+class C_struct:
+    def __init__(self, name="", c_type="", fields=[], packed=False):
+        self.name = name
+        self.c_type = c_type.replace("struct ", "")
+        self._fields = fields
+        self._packed = packed
+        self._format_string = self._get_format_string()
+        self._size = struct.calcsize(self._format_string)
+        self.__dict__.update({f.name:f for f in _c_struct_or_union_get_direct_members(self)})
+
+    @staticmethod
+    def from_h_file(h_file, name, includes = [], defines = []):
+        return _c_from_h_file(h_file, name, includes, defines, C_struct, cl.CursorKind.STRUCT_DECL)
+
+    @staticmethod
+    def from_cl_cursor(cursor, name=""):
+        return _c_struct_or_union_from_cl_cursor(cursor, name, C_struct)
+
+    def _get_format_string(self):
+        format_endianness = "<" if self._packed else "@"
+        return format_endianness + "".join([_c_struct_or_union_field_to_bytes(f)[0] for f in self._fields])
+
+    def set_value_from_bytes(self, value):
+        value_used = 0
+        for f in self._fields:
+            f.set_value_from_bytes(value[value_used:value_used + f.get_size()])
+            value_used += f.get_size()
+        assert(value_used == len(value))
+
+    def set_value(self, field_path, value):
+        self.get_field(field_path).set_value(value)
+
+    def get_value(self, field_path):
+        return self.get_field(field_path).value
+
+    def get_field_strings(self):
+        return _c_struct_or_union_get_field_strings(self)
+
+    def get_field(self, field_path):
+        return _c_struct_or_union_get_field(self, field_path)
+
+    def to_bytes(self):
+        format_endianness = "<" if self._packed else "@"
+        format_str = ""
+        values = []
+
+        for f in self._fields:
+            field_string, field_data = _c_struct_or_union_field_to_bytes(f)
+            format_str += field_string
+            values.append(field_data)
+
+        return struct.pack(format_str, *values)
+
+    def get_size(self):
+        return self._size
+
+    def get_docs_table(self):
+        return _c_struct_or_union_get_docs_table(self)
+
+    def get_value_str(self):
+        return _c_struct_or_union_get_value_str(self, "struct")
+
+    def __str__(self):
+        return _c_struct_or_union_str(self, "struct")
+
+def _parse_field_dec(cursor):
+    return list(cursor.get_children())[0], cursor.spelling
+
+def _parse_type_ref(cursor, name):
+    return cursor.get_definition(), name
+
+def _c_from_cl_cursor(cursor, name = ""):
+    if cursor.kind == cl.CursorKind.STRUCT_DECL:
+        return C_struct.from_cl_cursor(cursor, name)
+
+    elif cursor.kind == cl.CursorKind.UNION_DECL:
+        return C_union.from_cl_cursor(cursor, name)
+
+    elif cursor.kind in [cl.CursorKind.TYPEDEF_DECL, cl.CursorKind.ENUM_DECL]:
+        return C_variable.from_cl_cursor(cursor, name)
+
+    elif cursor.kind == cl.CursorKind.FIELD_DECL:
+        if cursor.type.kind == cl.TypeKind.CONSTANTARRAY:
+            return C_variable.from_cl_cursor(cursor, cursor.spelling)
+
+        return _c_from_cl_cursor(*_parse_field_dec(cursor))
+
+    elif cursor.kind == cl.CursorKind.TYPE_REF:
+        return _c_from_cl_cursor(*_parse_type_ref(cursor, name))
+
+    raise NotImplementedError
+
+def _c_from_h_file(h_file, name, includes, defines, f, kind):
+    name = name.replace("struct ", "")
+    name = name.replace("union ", "")
+
+    args = ["-I{}".format(i) for i in includes if os.path.isdir(i)]
+    args += ["-D{}".format(d) for d in defines]
+
+    if not os.path.isfile(h_file):
+        return FileNotFoundError
+
+    idx = cl.Index.create()
+    tu = idx.parse(h_file, args=args)
+
+    t = [cl.Cursor().from_location(tu, t.location) for t in tu.cursor.get_tokens() if t.spelling == name]
+    t = [x for x in t if x.kind == kind]
+
+    errors = ["{}: {} at {}:{}".format(d.category_name, d.spelling, d.location.file.name, d.location.line) for d in tu.diagnostics if d.severity > 2]
+    warnings = ["{}: {} at {}:{}".format(d.category_name, d.spelling, d.location.file.name, d.location.line) for d in tu.diagnostics if d.severity > 3]
+
+    for w in warnings:
+        print(w)
+
+    if errors:
+        for e in errors:
+            print(e)
+        exit(1)
+
+    if len(t) == 0:
+        print("Failed to find {} in {}".format(name, h_file))
+        exit(1)
+
+    assert(len(t) == 1)
+    t = t[0]
+
+    return f.from_cl_cursor(t, name)
+
+
+if __name__ == '__main__':
+    import argparse
+    import c_include
+
+    parser = argparse.ArgumentParser(allow_abbrev=False)
+    parser.add_argument("--h_file", help="header file to parse", required=True)
+    parser.add_argument("--struct_name", help="struct name to evaluate", required=True)
+    parser.add_argument("--compile_commands_file", help="header file to parse", required=True)
+    parser.add_argument("--c_file_to_mirror_includes_from", help="name of the c file to take", required=True)
+    parser.add_argument("--log_level", help="log level", required=False, default="ERROR", choices=logging._levelToName.values())
+    args = parser.parse_args()
+    logger.setLevel(args.log_level)
+
+    includes = c_include.get_includes(args.compile_commands_file, args.c_file_to_mirror_includes_from)
+    defines = c_include.get_defines(args.compile_commands_file, args.c_file_to_mirror_includes_from)
+
+    s = C_struct.from_h_file(args.h_file, args.struct_name, includes, defines)
+    print(s)