blob: 2598c541ba95f3bc1941b2f925e068f2fb860399 [file] [log] [blame]
#!/usr/bin/env python3
#-------------------------------------------------------------------------------
# SPDX-FileCopyrightText: Copyright The TrustedFirmware-M Contributors
#
# SPDX-License-Identifier: BSD-3-Clause
#
#-------------------------------------------------------------------------------
import clang.cindex as cl
import struct
from collections import Counter
import os
import subprocess
import platform
import logging
logger = logging.getLogger("TF-M.{}".format(__name__))
from rich import inspect
# Number of spaces added per nesting level when pretty-printing declarations.
pad = 4

# Map from a (qualifier-stripped) C type name to the `struct` module format
# character used to pack/unpack a value of that type.
# NOTE(review): 'char' maps to the signed code 'b' while every other integer
# maps to an unsigned code, and 'long' is assumed to be 4 bytes ('I') —
# presumably a 32-bit target; confirm both are intended.
format_chars = dict([
    # fixed-width integer typedefs
    ('uint8_t', 'B'),
    ('uint16_t', 'H'),
    ('uint32_t', 'I'),
    ('uint64_t', 'Q'),
    # builtin scalar types
    ('bool', '?'),
    ('char', 'b'),
    ('short', 'H'),
    ('int', 'I'),
    ('long', 'I'),
    ('long long', 'Q'),
    # pointer-sized typedefs
    ('uintptr_t', 'I'),
    ('size_t', 'I'),
    # floating point
    ('float', 'f'),
    ('double', 'd'),
    ('long double', 'd'),
])
def get_macos_sdk_path():
    """Get the path to the macOS SDK via xcrun.

    Runs ``xcrun --sdk macosx --show-sdk-path`` and returns its stdout with
    surrounding whitespace removed.

    Raises:
        RuntimeError: if xcrun exits with a non-zero status; its stderr is
            included in the message.
    """
    command = ['xcrun', '--sdk', 'macosx', '--show-sdk-path']
    try:
        completed = subprocess.run(command, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as err:
        raise RuntimeError(f"Failed to get macOS SDK path: {err.stderr}")
    return completed.stdout.strip()
def _c_struct_or_union_from_cl_cursor(cursor, name, f):
    """Build a struct/union model from a clang STRUCT_DECL/UNION_DECL cursor.

    Args:
        cursor: clang cursor of the aggregate declaration.
        name: instance (field) name to give the result; may be "".
        f: the class to construct (C_struct or C_union), called as
            ``f(name, c_type, fields, packed)``.

    Children that are FIELD_DECL, STRUCT_DECL or UNION_DECL become fields.
    Zero-sized fields are dropped. When a nested aggregate type appears more
    than once among the non-scalar fields, its unnamed occurrences (the bare
    nested declaration) are dropped so only the named field remains.
    """
    member_kinds = (cl.CursorKind.FIELD_DECL, cl.CursorKind.STRUCT_DECL, cl.CursorKind.UNION_DECL)
    children = list(cursor.get_children())

    is_packed = any(child.kind == cl.CursorKind.PACKED_ATTR for child in children)

    members = [_c_from_cl_cursor(child) for child in children if child.kind in member_kinds]
    members = [member for member in members if member.get_size() != 0]

    # Scalars (C_variable) are never considered duplicates.
    type_counts = Counter(member.c_type for member in members if not isinstance(member, C_variable))
    repeated_types = {c_type for c_type, count in type_counts.items() if count > 1}
    members = [member for member in members
               if member.c_type not in repeated_types or member.name != ""]

    return f(name, cursor.spelling, members, is_packed)
def _pad_lines(text, size):
lines = text.split("\n")
lines = [" " * size + l for l in lines]
return "\n".join(lines)
def _c_struct_or_union_get_value_str(self, struct_or_union):
string = "{\n"
fields_string = ""
for f in self._fields:
if f.to_bytes() != bytes(f.get_size()):
fields_string += _pad_lines(".{} = {},".format(f.name, f.get_value_str()), pad)
fields_string += "\n"
fields_string = fields_string[:-1]
if not fields_string:
return "{}"
string += fields_string
string += "\n}"
return string
def _c_struct_or_union_str(self, struct_or_union):
string = ""
string += "{} ".format(struct_or_union)
string += "__attribute__((packed)) " if self._packed else ""
string += "{} ".format(self.c_type) if self.c_type != "" else ""
string += "{\n"
fields_string = "\n".join([_pad_lines(str(f), pad) for f in self._fields])
string += fields_string
if string[-1] != "\n":
string += "\n"
string += "}"
string += " {}".format(self.name) if self.name != "" else ""
string += ";"
return string
def _c_struct_or_union_get_field_strings(self):
field_strings=[]
for f in self._fields:
name_format = "{}.".format(self.name) if self.name != "" else ""
field_strings += [name_format + s for s in f.get_field_strings()]
return field_strings
def _c_struct_or_union_field_to_bytes(field):
    """Return the ``(format fragment, value)`` pair used to pack ``field``.

    A C_variable contributes its scalar format character and numeric value.
    Any other field (struct, union, array) is embedded as an opaque ``"<n>s"``
    byte string holding its serialized bytes.
    """
    if not isinstance(field, C_variable):
        return _binary_format_string(field.get_size()), field.to_bytes()
    return field._get_format_string(), field.get_value()
def _c_struct_or_union_get_next_in_path(self, field_path):
field_separator = "."
if field_separator in field_path:
field_name, remainder = field_path.split(".", 1)
else:
field_name = field_path
remainder = None
try:
# Fields aren't allowed to duplicate names
field = [f for f in self._fields if f.name == field_name][0]
return field, remainder
except (KeyError, IndexError):
for f in self._fields:
try:
f.get_field(field_path)
return f, field_path
except KeyError:
continue
raise KeyError
def _c_struct_or_union_get_direct_members(self):
subfields = []
for f in self._fields:
if f.name:
continue
subfields += _c_struct_or_union_get_direct_members(f)
return [f for f in self._fields if f.name] + subfields
def _c_struct_or_union_get_field(self, field_path):
fields = [f for f in self._fields if f.name == field_path[:len(f.name)]]
for f in fields:
try:
remainder = field_path.replace(f.name + ".", "") if f.name else field_path
return f.get_field(remainder)
except (KeyError, ValueError):
continue
raise KeyError
def _binary_format_string(size):
return "{}s".format(size)
class C_enum:
    """A C ``enum`` type: a named collection of C_enum_member objects.

    Members are reachable through ``self.dict`` and also as attributes
    (``my_enum.MEMBER``).
    """

    def __init__(self, name, c_type, members):
        # c_type is accepted for signature parity with the other C_* classes but
        # is not used here; each member carries its own c_type.
        self.name = C_enum._strip_enum_keyword(name)
        self._members = members
        self.dict = {m.name: m for m in self._members}
        self.__dict__.update({m.name: m for m in self._members})

    @staticmethod
    def _strip_enum_keyword(name):
        """Remove a leading ``enum`` keyword (whole word only) and leading whitespace.

        Unlike ``str.replace``, this leaves names such as ``my_enum_t`` intact.
        """
        parts = name.split(None, 1)
        if parts and parts[0] == "enum":
            return parts[1] if len(parts) > 1 else ""
        return name.lstrip()

    @staticmethod
    def _c_type_for_max_value(max_value):
        """Pick the narrowest unsigned fixed-width type whose range includes ``max_value``.

        NOTE: negative enumerator values are not handled; they cannot be packed
        with the unsigned format characters these types map to.
        """
        if max_value < 2**8:
            return 'uint8_t'
        if max_value < 2**16:
            return 'uint16_t'
        return 'uint32_t'

    @staticmethod
    def from_h_file(h_file, name, includes=(), defines=()):
        """Parse ``h_file`` and build the enum declared as ``name``."""
        return _c_from_h_file(h_file, name, includes, defines, C_enum, cl.CursorKind.ENUM_DECL)

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        """Build a C_enum from a clang cursor referring to an enum declaration.

        All members share one c_type, sized from the largest enumerator value.
        Enumerators whose name starts with '_' are omitted. ``name`` is unused;
        the enum's name comes from ``cursor.spelling``.
        """
        definition = cursor.get_definition()
        assert definition.kind == cl.CursorKind.ENUM_DECL
        # Only enumerator constants carry enum_value; skip e.g. attribute children.
        enumerators = [x for x in definition.get_children()
                       if x.kind == cl.CursorKind.ENUM_CONSTANT_DECL]
        max_value = max((x.enum_value for x in enumerators), default=0)
        c_type = C_enum._c_type_for_max_value(max_value)
        members = [C_enum_member(x.spelling, c_type, x.enum_value) for x in enumerators]
        members = [m for m in members if not m.name.startswith('_')]
        return C_enum(cursor.spelling, c_type, members)

    def __str__(self):
        pad = 4
        string = "enum {} {{\n".format(self.name)
        string += "\n".join(_pad_lines(str(m), pad) for m in self._members)
        if string[-1] != "\n":
            string += "\n"
        string += "};"
        return string
class C_enum_member:
    """One enumerator of a C ``enum``: a name, a fixed-width C type and an integer value."""

    def __init__(self, name, c_type, value):
        self.name = name
        self.c_type = c_type
        self.value = value
        fmt = self._get_format_string()
        self._format_string = fmt
        self._size = struct.calcsize(fmt)

    def _get_format_string(self):
        """Look up the `struct` format character for ``self.c_type``."""
        return format_chars[self.c_type]

    def get_value(self):
        return self.value

    def to_bytes(self):
        """Pack the value little-endian with the member's format character."""
        return struct.pack(f"<{self._format_string}", self.get_value())

    def get_size(self):
        return self._size

    def __str__(self):
        return self.name

    __repr__ = __str__
class C_array:
    """A (possibly multi-dimensional) C array of scalars, structs or unions.

    A multi-dimensional array is modelled as a C_array whose members are
    C_array rows. Members are named ``<array name>_<index>``. Elements are
    serialized back to back with a '<' (no padding, little-endian) format.
    """

    def __init__(self, name, c_type, members):
        self.name = name
        self.c_type = c_type
        self._members = members
        if self._members:
            self._dimensions = [len(self._members)]
            if isinstance(self._members[0], C_array):
                self._dimensions += self._members[0]._dimensions
        else:
            self._dimensions = [0]
        self._size = struct.calcsize(self._get_format_string())

    @staticmethod
    def from_h_file(h_file, name, includes=(), defines=()):
        # NOTE(review): confirm cl.CursorKind.TYPE_DECL exists in the clang bindings in use.
        return _c_from_h_file(h_file, name, includes, defines, C_array, cl.CursorKind.TYPE_DECL)

    @staticmethod
    def from_cl_cursor(cursor, name="", dimensions=None):
        """Build a C_array from a clang cursor whose type spelling contains "[N]".

        Args:
            cursor: clang cursor of the array field.
            name: array name; members are named ``name_<i>``.
            dimensions: remaining dimensions for recursive calls. When empty or
                None, they are parsed from the type spelling ("[]" becomes 0).
        """
        c_type = cursor.type.spelling
        for qualifier in ("unsigned", "volatile", "const", "static"):
            c_type = c_type.replace(qualifier, "")
        # Always split the "[N]..." suffix off, so c_type is the element type in
        # the recursive calls for inner dimensions too.
        c_type, *parsed_dimensions = c_type.split("[")
        if not dimensions:
            dimensions = [0 if d == "]" else int(d.replace("]", "")) for d in parsed_dimensions]
        c_type = c_type.strip()
        assert len(dimensions) > 0

        if len(dimensions) > 1:
            def make_member(member_name):
                return C_array.from_cl_cursor(cursor, member_name, dimensions[1:])
        elif c_type not in format_chars:
            # Aggregate element type: follow the type reference to its declaration.
            def make_member(member_name):
                element_decl = list(cursor.get_children())[0].get_definition()
                return _c_from_cl_cursor(element_decl, member_name)
        else:
            def make_member(member_name):
                return C_variable(member_name, c_type)

        members = [make_member(name + "_{}".format(i)) for i in range(dimensions[0])]
        return C_array(name, c_type, members)

    def __getitem__(self, index):
        return self._members[index]

    def _get_format_string(self):
        return "<" + "".join(_c_struct_or_union_field_to_bytes(m)[0] for m in self._members)

    def get_value(self):
        return self.to_bytes()

    def set_value(self, value):
        self.set_value_from_bytes(value)

    def set_value_from_bytes(self, value):
        """Assign members in order from ``value``.

        ``value`` may be shorter than the array; remaining members are left untouched.
        """
        assert len(value) <= self._size, "{} of size {} cannot be set to value {} of size {}".format(self, self._size, value.hex(), len(value))
        value_used = 0
        for m in self._members:
            if value_used == len(value):
                break
            m.set_value_from_bytes(value[value_used:value_used + m.get_size()])
            value_used += m.get_size()

    def get_field_strings(self):
        return [self.name] + [f for m in self._members if not isinstance(m, C_variable) for f in m.get_field_strings()]

    def get_field(self, field_path):
        """Resolve ``field_path`` ("<name>_<i>[_<j>...][.<member>]") to an element or sub-field."""
        if field_path == self.name:
            return self
        prefix = self.name + "_"
        # Strip only the leading "<name>_"; str.replace would also hit nested names.
        if field_path.startswith(prefix):
            field_path = field_path[len(prefix):]
        separators = [i for i in (field_path.find("."), field_path.find("_")) if i != -1]
        if separators:
            split_idx = min(separators)
            index, remainder = field_path[:split_idx], field_path[split_idx + 1:]
        else:
            index, remainder = field_path, None
        member = self._members[int(index)]
        if remainder:
            return member.get_field(remainder)
        return member

    def to_bytes(self):
        # Use the same '<' layout as _get_format_string so len() == self._size on any host.
        format_str = "<"
        values = []
        for m in self._members:
            field_format, field_data = _c_struct_or_union_field_to_bytes(m)
            format_str += field_format
            values.append(field_data)
        return struct.pack(format_str, *values)

    def get_size(self):
        return self._size

    def get_value_str(self):
        if self.to_bytes() == bytes(self.get_size()):
            return "{}"
        m_string = ""
        for idx, m in enumerate(self._members):
            m_string += m.get_value_str() + ", "
            # Line break after each nested aggregate and after every 8 elements.
            if m_string[-3:] == "}, " or idx % 8 == 7:
                m_string = m_string[:-1] + "\n"
        m_string = m_string[:-1]
        return "{\n" + _pad_lines(m_string, pad) + "\n}"

    def __str__(self):
        string = "{} {}".format(self.c_type, self.name)
        string += "".join("[{}]".format(a) for a in self._dimensions)
        if self.to_bytes() != bytes(self.get_size()):
            string += " = " + self.get_value_str()
        string += ";"
        return string
class C_variable:
    """A scalar C variable (integer, bool, float or enum) with an optional value.

    The value is serialized little-endian with the format character looked up
    in ``format_chars``. Enum types are treated as ``uint32_t``.
    """

    # Qualifier/sign keywords dropped (as whole words) from clang type spellings.
    _IGNORED_TYPE_WORDS = ("unsigned", "signed", "volatile", "const", "static")

    def __init__(self, name, c_type, value=None):
        self.name = name
        self.c_type = c_type
        self._format_string = self._get_format_string()
        self._size = struct.calcsize(self._format_string)
        self.value = value

    @staticmethod
    def from_h_file(h_file, name, includes=(), defines=()):
        # NOTE(review): confirm cl.CursorKind.TYPE_DECL exists in the clang bindings in use.
        return _c_from_h_file(h_file, name, includes, defines, C_variable, cl.CursorKind.TYPE_DECL)

    @staticmethod
    def _normalise_c_type(spelling):
        """Drop qualifier words and collapse whitespace, e.g. "const unsigned int" -> "int".

        Works on whole words, so type names merely containing "const" etc. are
        untouched, and no stray leading spaces remain for the format_chars lookup.
        """
        words = [w for w in spelling.split() if w not in C_variable._IGNORED_TYPE_WORDS]
        return " ".join(words)

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        """Build a C_variable, or a C_array if the cursor's type is an array."""
        c_type = C_variable._normalise_c_type(cursor.type.spelling)
        if "[" in c_type:
            return C_array.from_cl_cursor(cursor, name)
        return C_variable(name, c_type)

    def _get_format_string(self):
        if 'enum' in self.c_type:
            return format_chars['uint32_t']
        return format_chars[self.c_type]

    def get_value(self):
        """Return the value, treating an unset (or falsy) value as 0."""
        return self.value if self.value else 0

    def set_value(self, value):
        """Set the value; strings are parsed with base auto-detection (e.g. "0x10")."""
        self.value = int(value, 0) if isinstance(value, str) else value
        # Sanity check: raises struct.error if the value does not fit the type.
        self.to_bytes()

    def set_value_from_bytes(self, value):
        # Same explicit little-endian layout as to_bytes(), so round trips do
        # not depend on the host byte order.
        self.set_value(struct.unpack("<" + self._format_string, value)[0])

    def get_field_strings(self):
        return [self.name]

    def get_field(self, field_path):
        if field_path != self.name:
            raise KeyError(field_path)
        return self

    def to_bytes(self):
        return struct.pack("<" + self._format_string, self.get_value())

    def get_size(self):
        return self._size

    def get_value_str(self):
        """Render the value as a C literal: hex for integers, repr() for floats."""
        value = self.value or 0
        if isinstance(value, float):
            return repr(value)
        if value < 0:
            return f"-0x{-value:02x}"
        return f"0x{value:02x}"

    def __str__(self):
        string = "{} {}".format(self.c_type, self.name)
        if self.value is not None:
            string += " = {}".format(self.get_value_str())
        string += ";"
        return string
class C_union:
    """Python model of a C ``union``, whose fields all alias one shared buffer.

    The authoritative bytes live in ``_actual_value``, sized to the largest
    field. ``_ensure_consistency`` runs before most reads. It copies the bytes
    of whichever single field was modified into ``_actual_value``, then
    reloads every field from it.
    """

    def __init__(self, name="", c_type="", fields=None, packed=False):
        self.name = name
        self.c_type = c_type.replace("union ", "")
        self._fields = fields if fields is not None else []
        self._packed = packed
        self._format_strings = self._get_format_strings()
        # A union is as large as its largest field; an empty union has size 0.
        self._size = max(map(struct.calcsize, self._format_strings), default=0)
        self._actual_value = bytes(self._size)
        # NOTE(review): a field named e.g. "name" or "c_type" shadows the attribute set above.
        self.__dict__.update({f.name: f for f in _c_struct_or_union_get_direct_members(self)})

    @staticmethod
    def from_h_file(h_file, name, includes=(), defines=()):
        return _c_from_h_file(h_file, name, includes, defines, C_union, cl.CursorKind.UNION_DECL)

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        return _c_struct_or_union_from_cl_cursor(cursor, name, C_union)

    def _ensure_consistency(self):
        binaries = [f.to_bytes() for f in self._fields]
        # Fields whose bytes differ from the shared buffer were modified since the last sync.
        changed = list({b for b in binaries if b != self._actual_value[:len(b)]})
        assert len(changed) < 2, "conflicting writes to more than one field of union {}".format(self.name)
        if changed:
            self._actual_value = changed[0] + self._actual_value[len(changed[0]):]
        for f in self._fields:
            f.set_value_from_bytes(self._actual_value[:f.get_size()])

    def _get_format_strings(self):
        format_endianness = "<" if self._packed else "@"
        return ["{}{}".format(format_endianness, _binary_format_string(f.get_size())) for f in self._fields]

    def set_value_from_bytes(self, value):
        """Overlay ``value`` at offset 0 of the union storage and reload every field.

        ``value`` may be shorter than the union; trailing bytes keep their
        previous contents.
        """
        value = bytes(value)
        assert len(value) <= self._size, "{} of size {} cannot be set to value of size {}".format(self.name, self._size, len(value))
        self._actual_value = value + self._actual_value[len(value):]
        # Each field only sees as many bytes as it occupies.
        for f in self._fields:
            f.set_value_from_bytes(self._actual_value[:f.get_size()])
        self._ensure_consistency()

    def set_value(self, field_path, value):
        self.get_field(field_path).set_value(value)
        self._ensure_consistency()

    def get_value(self, field_path):
        self._ensure_consistency()
        return self.get_field(field_path).value

    def get_field_strings(self):
        self._ensure_consistency()
        return _c_struct_or_union_get_field_strings(self)

    def get_field(self, field_path):
        self._ensure_consistency()
        return _c_struct_or_union_get_field(self, field_path)

    def to_bytes(self):
        self._ensure_consistency()
        # Every field is now a prefix of the shared buffer, so the largest
        # (longest) one is the full union value.
        return max({f.to_bytes() for f in self._fields}, default=self._actual_value)

    def get_size(self):
        return self._size

    def get_value_str(self):
        self._ensure_consistency()
        return _c_struct_or_union_get_value_str(self, "union")

    def __str__(self):
        self._ensure_consistency()
        return _c_struct_or_union_str(self, "union")
class C_struct:
    """Python model of a C ``struct``.

    Fields are laid out in declaration order using a `struct` format string.
    Packed structs use '<' (little-endian, no padding); other structs use '@'
    (native byte order and alignment). Non-scalar fields (nested structs,
    unions, arrays) are embedded as opaque byte strings of their own size.
    Named direct members, including those of anonymous nested aggregates, are
    also exposed as instance attributes.
    """

    def __init__(self, name="", c_type="", fields=None, packed=False):
        self.name = name
        self.c_type = c_type.replace("struct ", "")
        self._fields = fields if fields is not None else []
        self._packed = packed
        self._format_string = self._get_format_string()
        self._size = struct.calcsize(self._format_string)
        # NOTE(review): a field named e.g. "name" or "c_type" shadows the attribute set above.
        self.__dict__.update({f.name: f for f in _c_struct_or_union_get_direct_members(self)})

    @staticmethod
    def from_h_file(h_file, name, includes=(), defines=()):
        return _c_from_h_file(h_file, name, includes, defines, C_struct, cl.CursorKind.STRUCT_DECL)

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        return _c_struct_or_union_from_cl_cursor(cursor, name, C_struct)

    def _format_prefix(self):
        """'<' for packed structs, '@' (native alignment) otherwise."""
        return "<" if self._packed else "@"

    def _get_format_string(self):
        return self._format_prefix() + "".join(_c_struct_or_union_field_to_bytes(f)[0] for f in self._fields)

    def set_value_from_bytes(self, value):
        """Load every field from ``value``, which must be exactly ``get_size()`` bytes.

        Decoding uses the struct's own format string, so alignment padding
        between fields of a non-packed struct is skipped correctly.
        """
        assert len(value) == self._size, "{} of size {} cannot be set from {} bytes".format(self.c_type or self.name, self._size, len(value))
        for field, item in zip(self._fields, struct.unpack(self._format_string, value)):
            # Scalars unpack to numbers/bools; aggregates were packed as "<n>s" bytes.
            if isinstance(item, bytes):
                field.set_value_from_bytes(item)
            else:
                field.set_value(item)

    def set_value(self, field_path, value):
        self.get_field(field_path).set_value(value)

    def get_value(self, field_path):
        return self.get_field(field_path).value

    def get_field_strings(self):
        return _c_struct_or_union_get_field_strings(self)

    def get_field(self, field_path):
        return _c_struct_or_union_get_field(self, field_path)

    def to_bytes(self):
        # The layout prefix must be applied here too; otherwise packed structs
        # would be serialized with native padding and disagree with get_size().
        format_str = self._format_prefix()
        values = []
        for f in self._fields:
            field_format, field_data = _c_struct_or_union_field_to_bytes(f)
            format_str += field_format
            values.append(field_data)
        return struct.pack(format_str, *values)

    def get_size(self):
        return self._size

    def get_docs_table(self):
        # NOTE(review): _c_struct_or_union_get_docs_table is not defined in this
        # file as shown; this raises NameError unless it is provided elsewhere.
        return _c_struct_or_union_get_docs_table(self)

    def get_value_str(self):
        return _c_struct_or_union_get_value_str(self, "struct")

    def __str__(self):
        return _c_struct_or_union_str(self, "struct")
def _parse_field_dec(cursor):
return list(cursor.get_children())[0], cursor.spelling
def _parse_type_ref(cursor, name):
return cursor.get_definition(), name
def _c_from_cl_cursor(cursor, name=""):
    """Convert a clang cursor into the matching C_* model object.

    STRUCT_DECL/UNION_DECL build aggregates. TYPEDEF_DECL/ENUM_DECL build
    scalars. FIELD_DECL is resolved through its type (constant arrays directly).
    TYPE_REF is followed to its definition.

    Raises:
        NotImplementedError: for any other cursor kind (the message names it).
    """
    kind = cursor.kind
    if kind == cl.CursorKind.STRUCT_DECL:
        return C_struct.from_cl_cursor(cursor, name)
    if kind == cl.CursorKind.UNION_DECL:
        return C_union.from_cl_cursor(cursor, name)
    if kind in (cl.CursorKind.TYPEDEF_DECL, cl.CursorKind.ENUM_DECL):
        return C_variable.from_cl_cursor(cursor, name)
    if kind == cl.CursorKind.FIELD_DECL:
        if cursor.type.kind == cl.TypeKind.CONSTANTARRAY:
            return C_variable.from_cl_cursor(cursor, cursor.spelling)
        if next(iter(cursor.get_children()), None) is None:
            # No child to follow (presumably a builtin type such as ``int``, which
            # has no TYPE_REF); build the scalar from the field's own type.
            return C_variable.from_cl_cursor(cursor, cursor.spelling)
        return _c_from_cl_cursor(*_parse_field_dec(cursor))
    if kind == cl.CursorKind.TYPE_REF:
        return _c_from_cl_cursor(*_parse_type_ref(cursor, name))
    raise NotImplementedError("Unsupported clang cursor kind {} for '{}'".format(kind, cursor.spelling))
def _c_from_h_file(h_file, name, includes, defines, f, kind):
    """Parse ``h_file`` with libclang and build a C_* object for declaration ``name``.

    Args:
        h_file: path of the header to parse.
        name: declaration name; "struct ", "union " and "enum " keywords are removed.
        includes: include directories; entries that are not existing directories are skipped.
        defines: preprocessor definitions passed as ``-D`` options.
        f: C_* class whose ``from_cl_cursor`` builds the result.
        kind: clang CursorKind the declaration must have.

    Raises:
        FileNotFoundError: if ``h_file`` is not an existing file.
        SystemExit: (status 1) if clang reports errors or ``name`` is not found.
    """
    import sys

    for keyword in ("struct ", "union ", "enum "):
        name = name.replace(keyword, "")
    if not os.path.isfile(h_file):
        # Previously the exception class was returned instead of raised.
        raise FileNotFoundError("Header file not found: {}".format(h_file))

    args = ["-I{}".format(i) for i in includes if os.path.isdir(i)]
    args += ["-D{}".format(d) for d in defines]
    if platform.system() == 'Darwin':
        args.extend(['-isysroot', get_macos_sdk_path()])

    tu = cl.Index.create().parse(h_file, args=args)

    # Map every token spelled `name` to its cursor; keep those of the requested kind.
    candidates = [cl.Cursor.from_location(tu, token.location)
                  for token in tu.cursor.get_tokens() if token.spelling == name]
    candidates = [c for c in candidates if c.kind == kind]

    def describe(diag):
        location = diag.location
        # Some diagnostics (e.g. from command-line options) have no file.
        file_name = location.file.name if location.file is not None else "<unknown>"
        return "{}: {} at {}:{}".format(diag.category_name, diag.spelling, file_name, location.line)

    diagnostics = list(tu.diagnostics)
    errors = [describe(d) for d in diagnostics if d.severity >= cl.Diagnostic.Error]
    warnings = [describe(d) for d in diagnostics if d.severity == cl.Diagnostic.Warning]
    for w in warnings:
        print(w)
    if errors:
        for e in errors:
            print(e)
        sys.exit(1)
    if not candidates:
        print("Failed to find {} in {}".format(name, h_file))
        sys.exit(1)

    # A forward declaration and the full definition can both match; prefer the definition.
    definitions = [c for c in candidates if c.is_definition()]
    if definitions:
        candidates = definitions
    assert len(candidates) == 1, "Found {} declarations of {} in {}".format(len(candidates), name, h_file)
    return f.from_cl_cursor(candidates[0], name)
if __name__ == '__main__':
    import argparse
    import c_include

    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--h_file", help="header file to parse", required=True)
    parser.add_argument("--struct_name", help="struct name to evaluate", required=True)
    parser.add_argument("--compile_commands_file",
                        help="compile commands file (e.g. compile_commands.json) from which "
                             "include paths and defines are read",
                        required=True)
    parser.add_argument("--c_file_to_mirror_includes_from",
                        help="C file whose include paths and defines, as recorded in the "
                             "compile commands file, are used to parse the header",
                        required=True)
    parser.add_argument("--log_level", help="log level", required=False, default="ERROR",
                        choices=logging._levelToName.values())
    args = parser.parse_args()

    logging.getLogger("TF-M").setLevel(args.log_level)
    logger.addHandler(logging.StreamHandler())

    # Parse the header with the same include paths and defines the given C file uses.
    includes = c_include.get_includes(args.compile_commands_file, args.c_file_to_mirror_includes_from)
    defines = c_include.get_defines(args.compile_commands_file, args.c_file_to_mirror_includes_from)
    s = C_struct.from_h_file(args.h_file, args.struct_name, includes, defines)
    print(s)