blob: a019fce814845fbd3e74cfec40799c5ebcfa0e92 [file] [log] [blame]
Raef Coles59cf5d82024-12-09 15:41:13 +00001#!/usr/bin/env python3
2#-------------------------------------------------------------------------------
3# SPDX-FileCopyrightText: Copyright The TrustedFirmware-M Contributors
4#
5# SPDX-License-Identifier: BSD-3-Clause
6#
7#-------------------------------------------------------------------------------
8
9import clang.cindex as cl
10import struct
11from collections import Counter
12import os
13
14import logging
Raef Colescfc31242025-04-04 09:38:47 +010015logger = logging.getLogger("TF-M.{}".format(__name__))
Raef Coles59cf5d82024-12-09 15:41:13 +000016
17from rich import inspect
18
# Number of spaces used for one level of indentation when rendering C text.
pad = 4

# Maps a C type spelling (with the "unsigned" keyword already removed by the
# callers in this module) to the `struct` module format character used to
# pack / unpack a value of that type.
# NOTE(review): 'long', 'size_t' and 'uintptr_t' map to a 4-byte format, which
# presumably reflects a 32-bit target - confirm before reusing elsewhere.
# NOTE(review): 'char' is the only integer entry that maps to a signed format
# character ('b'); every other integer type is treated as unsigned.
format_chars = {
    # Fixed-width integers
    'uint8_t': 'B',
    'uint16_t': 'H',
    'uint32_t': 'I',
    'uint64_t': 'Q',
    # Boolean and basic integer types
    'bool': '?',
    'char': 'b',
    'short': 'H',
    'int': 'I',
    'long': 'I',
    'long long': 'Q',
    # Pointer-sized / size types
    'uintptr_t': 'I',
    'size_t': 'I',
    # Floating point
    'float': 'f',
    'double': 'd',
    'long double': 'd',
}
38
39def _c_struct_or_union_from_cl_cursor(cursor, name, f):
40 def is_field(x):
41 if x.kind == cl.CursorKind.FIELD_DECL:
42 return True
43 if x.kind in [cl.CursorKind.STRUCT_DECL, cl.CursorKind.UNION_DECL]:
44 return True
45 return False
46
47 fields = [c for c in cursor.get_children() if is_field(c)]
48 packed = True if [c for c in cursor.get_children() if c.kind == cl.CursorKind.PACKED_ATTR] else False
49
50 c_type = cursor.spelling
51
52 fields = [_c_from_cl_cursor(f) for f in fields]
53 fields = [f for f in fields if f.get_size() != 0]
54
55 duplicate_types = Counter([f.c_type for f in fields if not isinstance(f, C_variable)])
56 duplicate_types = [k for k,v in duplicate_types.items() if v > 1]
57 fields = [f for f in fields if f.c_type not in duplicate_types or f.name != ""]
58
59 return f(name, c_type, fields, packed)
60
61def _pad_lines(text, size):
62 lines = text.split("\n")
63 lines = [" " * size + l for l in lines]
64 return "\n".join(lines)
65
66def _c_struct_or_union_get_value_str(self, struct_or_union):
67 string = "{\n"
68 fields_string = ""
69 for f in self._fields:
70 if f.to_bytes() != bytes(f.get_size()):
71 fields_string += _pad_lines(".{} = {},".format(f.name, f.get_value_str()), pad)
Raef Coles2e6539f2025-04-04 10:29:37 +010072 fields_string += "\n"
73 fields_string = fields_string[:-1]
74 if not fields_string:
75 return "{}"
Raef Coles59cf5d82024-12-09 15:41:13 +000076 string += fields_string
77 string += "\n}"
78 return string
79
80def _c_struct_or_union_str(self, struct_or_union):
81 string = ""
82 string += "{} ".format(struct_or_union)
83 string += "__attribute__((packed)) " if self._packed else ""
84 string += "{} ".format(self.c_type) if self.c_type != "" else ""
85 string += "{\n"
86
87 fields_string = "\n".join([_pad_lines(str(f), pad) for f in self._fields])
88 string += fields_string
89
90 if string[-1] != "\n":
91 string += "\n"
92
93 string += "}"
94 string += " {}".format(self.name) if self.name != "" else ""
95 string += ";"
96 return string
97
98def _c_struct_or_union_get_field_strings(self):
99 field_strings=[]
100
101 for f in self._fields:
102 name_format = "{}.".format(self.name) if self.name != "" else ""
103 field_strings += [name_format + s for s in f.get_field_strings()]
104
105 return field_strings
106
def _c_struct_or_union_field_to_bytes(field):
    """Return the (struct format fragment, value) pair used to pack `field`.

    A C_variable contributes its own format character and value. Any other
    field is packed as an opaque byte string of its own size.
    """
    if not isinstance(field, C_variable):
        return _binary_format_string(field.get_size()), field.to_bytes()
    return field._get_format_string(), field.get_value()
112
113def _c_struct_or_union_get_next_in_path(self, field_path):
114 field_separator = "."
115
116 if field_separator in field_path:
117 field_name, remainder = field_path.split(".", 1)
118 else:
119 field_name = field_path
120 remainder = None
121
122 try:
123 # Fields aren't allowed to duplicate names
124 field = [f for f in self._fields if f.name == field_name][0]
125 return field, remainder
126 except (KeyError, IndexError):
127 for f in self._fields:
128 try:
129 f.get_field(field_path)
130 return f, field_path
131 except KeyError:
132 continue
133 raise KeyError
134
135def _c_struct_or_union_get_direct_members(self):
136 subfields = []
137 for f in self._fields:
138 if f.name:
139 continue
140 subfields += _c_struct_or_union_get_direct_members(f)
141
142 return [f for f in self._fields if f.name] + subfields
143
144def _c_struct_or_union_get_field(self, field_path):
145 fields = [f for f in self._fields if f.name == field_path[:len(f.name)]]
146 for f in fields:
147 try:
148 remainder = field_path.replace(f.name + ".", "") if f.name else field_path
149 return f.get_field(remainder)
150 except (KeyError, ValueError):
151 continue
152 raise KeyError
153
154
155def _binary_format_string(size):
156 return "{}s".format(size)
157
class C_enum:
    """Model of a C enum: a named collection of enum members.

    Members are reachable through `dict` (name -> member) and as attributes
    of the instance.
    """

    def __init__(self, name, c_type, members):
        # Drop only a leading "enum" keyword. Removing every "enum" substring
        # would mangle names such as "my_enum_t".
        name = name.lstrip()
        if name.split(None, 1)[:1] == ["enum"]:
            name = name[len("enum"):].lstrip()

        self.name = name
        self.c_type = c_type
        self._members = members
        self.dict = {m.name: m for m in self._members}
        self.__dict__.update(self.dict)

    @staticmethod
    def from_h_file(h_file, name, includes=None, defines=None):
        """Parse `h_file` and return the C_enum for the enum called `name`."""
        includes = [] if includes is None else includes
        defines = [] if defines is None else defines
        return _c_from_h_file(h_file, name, includes, defines, C_enum, cl.CursorKind.ENUM_DECL)

    @staticmethod
    def _c_type_for_max_value(max_value):
        """Return the narrowest unsigned type name able to hold `max_value`."""
        if max_value <= 0xFF:
            return 'uint8_t'
        if max_value <= 0xFFFF:
            return 'uint16_t'
        return 'uint32_t'

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        """Build a C_enum from a libclang cursor referring to an enum.

        The enum is named after `cursor.spelling`; `name` is accepted but not
        used. Members whose name starts with an underscore are dropped.
        """
        definition = cursor.get_definition()
        assert(definition.kind == cl.CursorKind.ENUM_DECL)

        constants = list(definition.get_children())

        # The member storage type is picked from the largest enumerator value.
        max_value = max((x.enum_value for x in constants), default=0)
        c_type = C_enum._c_type_for_max_value(max_value)

        members = [C_enum_member(x.spelling, c_type, x.enum_value) for x in constants]
        members = [m for m in members if not m.name.startswith('_')]
        return C_enum(cursor.spelling, c_type, members)

    def __str__(self):
        pad = 4

        string = "enum {} ".format(self.name)
        string += "{\n"
        string += "\n".join([_pad_lines(str(f), pad) for f in self._members])

        # Make sure the closing brace starts on its own line.
        if string[-1] != "\n":
            string += "\n"

        string += "};"
        return string
204
class C_enum_member:
    """A single enumerator: its name, storage type and integer value."""

    def __init__(self, name, c_type, value):
        self.name = name
        self.c_type = c_type
        self.value = value
        fmt = self._get_format_string()
        self._format_string = fmt
        self._size = struct.calcsize(fmt)

    def _get_format_string(self):
        """Return the struct format character for this member's C type."""
        return format_chars[self.c_type]

    def get_value(self):
        """Return the enumerator's value."""
        return self.value

    def to_bytes(self):
        """Return the value packed little-endian in the member's storage type."""
        little_endian_format = "<" + self._format_string
        return struct.pack(little_endian_format, self.get_value())

    def get_size(self):
        """Return the size in bytes of the member's storage type."""
        return self._size

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name
230
class C_array:
    """Model of a C array field, possibly multi-dimensional.

    Each element is itself a model object (C_variable, nested C_array, or a
    struct/union) named "<array name>_<index>".
    """

    def __init__(self, name, c_type, members):
        self.name = name
        self.c_type = c_type
        self._members = members

        if self._members:
            self._dimensions = [len(self._members)]
            # A nested array contributes its own dimensions after ours.
            if isinstance(self._members[0], C_array):
                self._dimensions += self._members[0]._dimensions
        else:
            self._dimensions = [0]

        self._size = struct.calcsize(self._get_format_string())

    @staticmethod
    def from_h_file(h_file, name, includes=None, defines=None):
        """Parse `h_file` and return the C_array for the declaration `name`."""
        includes = [] if includes is None else includes
        defines = [] if defines is None else defines
        # NOTE(review): confirm that clang.cindex.CursorKind really provides
        # TYPE_DECL; it is used here exactly as in the original code.
        return _c_from_h_file(h_file, name, includes, defines, C_array, cl.CursorKind.TYPE_DECL)

    @staticmethod
    def _split_type_spelling(spelling):
        """Split a type spelling like "const uint8_t[2][3]" into its parts.

        Returns (element type, dimensions). The qualifiers unsigned / volatile
        / const / static are removed as whole words only, so identifiers that
        merely contain them are left alone. An empty extent "[]" gives 0.
        """
        import re

        cleaned = re.sub(r"\b(?:unsigned|volatile|const|static)\b", "", spelling)
        base, *extents = cleaned.split("[")
        dimensions = [0 if e.strip() == "]" else int(e.replace("]", "")) for e in extents]
        return " ".join(base.split()), dimensions

    @staticmethod
    def from_cl_cursor(cursor, name="", dimensions=None):
        """Build a C_array from a libclang cursor whose type is an array.

        `dimensions` is only passed by the recursive call that builds the
        inner arrays of a multi-dimensional array. When omitted, the
        dimensions are parsed from the type spelling.
        """
        # Always separate the element type from the brackets, so the inner
        # arrays of a multi-dimensional array see the bare element type too.
        c_type, parsed_dimensions = C_array._split_type_spelling(cursor.type.spelling)
        if not dimensions:
            dimensions = parsed_dimensions

        assert(len(dimensions) > 0)

        if len(dimensions) > 1:
            # More dimensions left: every element is itself an array.
            make = lambda x: C_array.from_cl_cursor(cursor, x, dimensions[1:])
        elif c_type not in format_chars:
            # Not a known scalar type: resolve the element type's definition.
            make = lambda x: _c_from_cl_cursor(list(cursor.get_children())[0].get_definition(), x)
        else:
            make = lambda x: C_variable(x, c_type)

        members = [make(name + "_{}".format(i)) for i in range(dimensions[0])]

        return C_array(name, c_type, members)

    def __getitem__(self, index):
        return self._members[index]

    def _get_format_string(self):
        """Return the little-endian struct format covering every element."""
        return "<" + "".join([_c_struct_or_union_field_to_bytes(m)[0] for m in self._members])

    def get_value(self):
        """Return the array contents as bytes."""
        return self.to_bytes()

    def set_value(self, value):
        """Set the array contents from bytes (see set_value_from_bytes)."""
        self.set_value_from_bytes(value)

    def set_value_from_bytes(self, value):
        """Distribute `value` over the elements in order.

        `value` may be shorter than the array; elements after the point where
        the data runs out are left untouched.
        """
        assert(len(value) <= self._size), "{} of size {} cannot be set to value {} of size {}".format(self, self._size, value.hex(), len(value))
        value_used = 0
        for m in self._members:
            if value_used == len(value):
                break

            m.set_value_from_bytes(value[value_used:value_used + m.get_size()])
            value_used += m.get_size()

    def get_field_strings(self):
        """Return this array's name plus the paths of non-scalar elements."""
        return [self.name] + [f for m in self._members if not isinstance(m, C_variable) for f in m.get_field_strings()]

    def get_field(self, field_path):
        """Resolve `field_path` to this array, an element, or a nested field.

        Accepted forms are "<name>", "<name>_<index>" and
        "<name>_<index>.<rest>" or "<name>_<index>_<rest>", with or without
        the leading "<name>_". Raises ValueError when the index part is not
        an integer.
        """
        if field_path == self.name:
            return self

        # Strip only the leading "<name>_". A global replace would also
        # corrupt nested field names that start with the array's name.
        prefix = self.name + "_"
        if field_path.startswith(prefix):
            field_path = field_path[len(prefix):]

        # The index ends at the first "." or "_", whichever comes first.
        splits = [x for x in (field_path.find('.'), field_path.find("_")) if x != -1]

        if splits:
            split_idx = min(splits)
            index = field_path[:split_idx]
            remainder = field_path[split_idx + 1:]
        else:
            index = field_path
            remainder = None

        member = self._members[int(index)]

        if remainder:
            return member.get_field(remainder)
        return member

    def to_bytes(self):
        """Return all elements packed back to back, little-endian."""
        # Same byte-order prefix as _get_format_string(), so the packed length
        # always agrees with get_size().
        format_str = "<"
        values = []

        for m in self._members:
            field_string, field_data = _c_struct_or_union_field_to_bytes(m)
            format_str += field_string
            values.append(field_data)

        return struct.pack(format_str, *values)

    def get_size(self):
        """Return the size of the array in bytes."""
        return self._size

    def get_value_str(self):
        """Render the contents as a C initialiser list ("{}" if all zero)."""
        if self.to_bytes() == bytes(self.get_size()):
            return "{}"

        m_string = ""
        for idx, m in enumerate(self._members):
            m_string += m.get_value_str() + ", "
            # Break the line after a nested initialiser or every 8th element.
            if m_string[-3:] == "}, " or idx % 8 == 7:
                m_string = m_string[:-1] + "\n"
        # Drop the trailing space / newline left by the last element.
        m_string = m_string[:-1]

        return "{\n" + _pad_lines(m_string, pad) + "\n}"

    def __str__(self):
        string = "{} {}".format(self.c_type, self.name)
        string += "".join(["[{}]".format(a) for a in self._dimensions])

        if self.to_bytes() != bytes(self.get_size()):
            string += " = " + self.get_value_str()

        string += ";"
        return string
367
class C_variable:
    """Model of a scalar C variable with a name, a type and an optional value."""

    def __init__(self, name, c_type, value = None):
        self.name = name
        self.c_type = c_type
        self._format_string = self._get_format_string()
        self._size = struct.calcsize(self._format_string)
        self.value = value

    @staticmethod
    def from_h_file(h_file, name, includes=None, defines=None):
        """Parse `h_file` and return the model for the declaration `name`."""
        includes = [] if includes is None else includes
        defines = [] if defines is None else defines
        # NOTE(review): confirm that clang.cindex.CursorKind really provides
        # TYPE_DECL; it is used here exactly as in the original code.
        return _c_from_h_file(h_file, name, includes, defines, C_variable, cl.CursorKind.TYPE_DECL)

    @staticmethod
    def _normalise_c_type(spelling):
        """Strip qualifiers from a type spelling and tidy up the whitespace.

        unsigned / volatile / const / static are removed as whole words only.
        The leftover whitespace is collapsed so that e.g. "const unsigned int"
        becomes "int", which is the form used as a key of format_chars.
        """
        import re

        cleaned = re.sub(r"\b(?:unsigned|volatile|const|static)\b", "", spelling)
        return " ".join(cleaned.split())

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        """Build a model from a libclang cursor.

        Returns a C_array when the cursor's type spelling contains "[",
        otherwise a C_variable.
        """
        c_type = C_variable._normalise_c_type(cursor.type.spelling)

        if "[" in c_type:
            return C_array.from_cl_cursor(cursor, name)

        return C_variable(name, c_type)

    def _get_format_string(self):
        """Return the struct format character for this variable's type.

        Any type whose spelling contains "enum" is stored as a uint32_t.
        Raises KeyError for types missing from format_chars.
        """
        if 'enum' in self.c_type:
            return format_chars['uint32_t']

        return format_chars[self.c_type]

    def get_value(self):
        """Return the value, with an unset (or falsy) value reported as 0."""
        return self.value if self.value else 0

    def set_value(self, value):
        """Set the value; strings are parsed with int(value, 0).

        Raises struct.error if the value cannot be packed into the type.
        """
        if isinstance(value, str):
            self.value = int(value, 0)
        else:
            self.value = value

        # Sanity check the value
        self.to_bytes()

    def set_value_from_bytes(self, value):
        """Set the value from its little-endian byte representation."""
        # Little-endian, mirroring to_bytes(), so the two always round-trip.
        self.set_value(struct.unpack("<" + self._format_string, value)[0])

    def get_field_strings(self):
        """Return the field paths of this variable: just its own name."""
        return [self.name]

    def get_field(self, field_path):
        """Return self if `field_path` is this variable's name, else raise KeyError."""
        if field_path != self.name:
            raise KeyError
        return self

    def to_bytes(self):
        """Return the value packed little-endian in the variable's type."""
        return struct.pack("<" + self._format_string, self.get_value())

    def get_size(self):
        """Return the size of the variable in bytes."""
        return self._size

    def get_value_str(self):
        """Return the value formatted as a hexadecimal literal."""
        value = self.value or 0
        return f"0x{value:02x}"

    def __str__(self):
        string = "{} {}".format(self.c_type, self.name)

        if self.value is not None:
            string += " = {}".format(self.get_value_str())

        string += ";"
        return string
441
class C_union:
    """Model of a C union: several fields sharing the same bytes.

    `_actual_value` holds the authoritative contents. Before anything is read
    the fields are re-synchronised with it (see _ensure_consistency).
    """

    def __init__(self, name="", c_type="", fields=None, packed=False):
        self.name = name
        self.c_type = c_type.replace("union ", "")
        self._fields = [] if fields is None else fields
        self._packed = packed
        self._format_strings = self._get_format_strings()
        # A union is as large as its largest member (0 if it has none).
        self._size = max(map(struct.calcsize, self._format_strings), default=0)
        self._actual_value = bytes(self._size)
        self.__dict__.update({f.name:f for f in _c_struct_or_union_get_direct_members(self)})

    @staticmethod
    def from_h_file(h_file, name, includes=None, defines=None):
        """Parse `h_file` and return the C_union for the union called `name`."""
        includes = [] if includes is None else includes
        defines = [] if defines is None else defines
        return _c_from_h_file(h_file, name, includes, defines, C_union, cl.CursorKind.UNION_DECL)

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        """Build a C_union from a libclang union declaration cursor."""
        return _c_struct_or_union_from_cl_cursor(cursor, name, C_union)

    def _ensure_consistency(self):
        """Propagate a change made through one field to all the others.

        At most one distinct new byte pattern may exist among the fields. It
        is merged into `_actual_value`, which is then written back to every
        field.
        """
        binaries = [f.to_bytes() for f in self._fields]
        changed = {b for b in binaries if b != self._actual_value[:len(b)]}

        assert(len(changed) < 2)

        if changed:
            new_value = next(iter(changed))
            # Overlay the changed bytes; keep the tail of the old value.
            self._actual_value = new_value + self._actual_value[len(new_value):]

        for f in self._fields:
            f.set_value_from_bytes(self._actual_value[:f.get_size()])

    def _get_format_strings(self):
        """Return one opaque-bytes struct format per field."""
        format_endianness = "<" if self._packed else "@"
        return ["{}{}".format(format_endianness, _binary_format_string(f.get_size())) for f in self._fields]

    def set_value_from_bytes(self, value):
        """Set the union's contents from raw bytes.

        `value` may be shorter than the union, in which case only the leading
        bytes are overwritten. Each field receives exactly the prefix that
        matches its own size. Raises ValueError if `value` is larger than the
        union.
        """
        value = bytes(value)
        if len(value) > self._size:
            raise ValueError("value of size {} does not fit union {} of size {}".format(len(value), self.name, self._size))

        self._actual_value = value + self._actual_value[len(value):]
        for f in self._fields:
            f.set_value_from_bytes(self._actual_value[:f.get_size()])

    def set_value(self, field_path, value):
        """Set the field at `field_path` to `value` and re-sync the others."""
        self.get_field(field_path).set_value(value)
        self._ensure_consistency()

    def get_value(self, field_path):
        """Return the `value` attribute of the field at `field_path`."""
        self._ensure_consistency()
        return self.get_field(field_path).value

    def get_field_strings(self):
        """Return the qualified paths of all fields."""
        self._ensure_consistency()
        return _c_struct_or_union_get_field_strings(self)

    def get_field(self, field_path):
        """Resolve a dotted `field_path`; raises KeyError if it does not exist."""
        self._ensure_consistency()
        return _c_struct_or_union_get_field(self, field_path)

    def to_bytes(self):
        """Return the union's contents as bytes (its largest field's bytes)."""
        self._ensure_consistency()
        binaries = [f.to_bytes() for f in self._fields]

        # Pick by length explicitly rather than by lexicographic order.
        return max(binaries, key=len, default=b"")

    def get_size(self):
        """Return the size of the union in bytes."""
        return self._size

    def get_value_str(self):
        """Render the non-zero fields as a C designated initialiser."""
        self._ensure_consistency()
        return _c_struct_or_union_get_value_str(self, "union")

    def __str__(self):
        self._ensure_consistency()
        return _c_struct_or_union_str(self, "union")
516
class C_struct:
    """Model of a C struct: an ordered list of fields.

    The fields are laid out with the `struct` module's native alignment ("@"),
    or with no padding and little-endian byte order ("<") if the struct is
    packed.
    """

    def __init__(self, name="", c_type="", fields=None, packed=False):
        self.name = name
        self.c_type = c_type.replace("struct ", "")
        self._fields = [] if fields is None else fields
        self._packed = packed
        self._format_string = self._get_format_string()
        self._size = struct.calcsize(self._format_string)
        self.__dict__.update({f.name:f for f in _c_struct_or_union_get_direct_members(self)})

    @staticmethod
    def from_h_file(h_file, name, includes=None, defines=None):
        """Parse `h_file` and return the C_struct for the struct called `name`."""
        includes = [] if includes is None else includes
        defines = [] if defines is None else defines
        return _c_from_h_file(h_file, name, includes, defines, C_struct, cl.CursorKind.STRUCT_DECL)

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        """Build a C_struct from a libclang struct declaration cursor."""
        return _c_struct_or_union_from_cl_cursor(cursor, name, C_struct)

    def _get_format_string(self):
        """Return the struct format describing the layout of all fields."""
        format_endianness = "<" if self._packed else "@"
        return format_endianness + "".join([_c_struct_or_union_field_to_bytes(f)[0] for f in self._fields])

    def set_value_from_bytes(self, value):
        """Set every field from the struct's byte representation.

        `value` is split with the struct's own format string, so alignment
        padding between the fields of a non-packed struct is skipped
        correctly. Raises struct.error if `value` does not have exactly the
        size of the struct.
        """
        unpacked = struct.unpack(self._format_string, value)
        for f, item in zip(self._fields, unpacked):
            if isinstance(item, bytes):
                # Nested aggregate (packed as opaque bytes): hand the bytes on.
                f.set_value_from_bytes(item)
            else:
                # Scalar field: the value is already decoded.
                f.set_value(item)

    def set_value(self, field_path, value):
        """Set the field at `field_path` to `value`."""
        self.get_field(field_path).set_value(value)

    def get_value(self, field_path):
        """Return the `value` attribute of the field at `field_path`."""
        return self.get_field(field_path).value

    def get_field_strings(self):
        """Return the qualified paths of all fields."""
        return _c_struct_or_union_get_field_strings(self)

    def get_field(self, field_path):
        """Resolve a dotted `field_path`; raises KeyError if it does not exist."""
        return _c_struct_or_union_get_field(self, field_path)

    def to_bytes(self):
        """Return the struct packed to bytes using its layout format."""
        # Start from the byte-order / alignment prefix so that a packed struct
        # is really packed and the result matches get_size().
        format_str = "<" if self._packed else "@"
        values = []

        for f in self._fields:
            field_string, field_data = _c_struct_or_union_field_to_bytes(f)
            format_str += field_string
            values.append(field_data)

        return struct.pack(format_str, *values)

    def get_size(self):
        """Return the size of the struct in bytes."""
        return self._size

    def get_docs_table(self):
        # NOTE(review): _c_struct_or_union_get_docs_table is not defined in
        # the visible part of this module; calling this method raises
        # NameError unless it is provided elsewhere.
        return _c_struct_or_union_get_docs_table(self)

    def get_value_str(self):
        """Render the non-zero fields as a C designated initialiser."""
        return _c_struct_or_union_get_value_str(self, "struct")

    def __str__(self):
        return _c_struct_or_union_str(self, "struct")
581
582def _parse_field_dec(cursor):
583 return list(cursor.get_children())[0], cursor.spelling
584
585def _parse_type_ref(cursor, name):
586 return cursor.get_definition(), name
587
def _c_from_cl_cursor(cursor, name = ""):
    """Convert a libclang cursor into the matching model object.

    Dispatches on the cursor kind:
      - struct / union declarations become C_struct / C_union,
      - typedef and enum declarations go through C_variable.from_cl_cursor,
      - field declarations and type references are resolved to the cursor
        they point at and converted recursively.

    Raises NotImplementedError for any other cursor kind.
    """
    kind = cursor.kind

    if kind == cl.CursorKind.STRUCT_DECL:
        return C_struct.from_cl_cursor(cursor, name)

    if kind == cl.CursorKind.UNION_DECL:
        return C_union.from_cl_cursor(cursor, name)

    if kind in [cl.CursorKind.TYPEDEF_DECL, cl.CursorKind.ENUM_DECL]:
        return C_variable.from_cl_cursor(cursor, name)

    if kind == cl.CursorKind.FIELD_DECL:
        # Constant-size array fields are handled from the field cursor itself.
        if cursor.type.kind == cl.TypeKind.CONSTANTARRAY:
            return C_variable.from_cl_cursor(cursor, cursor.spelling)

        target, field_name = _parse_field_dec(cursor)
        return _c_from_cl_cursor(target, field_name)

    if kind == cl.CursorKind.TYPE_REF:
        target, target_name = _parse_type_ref(cursor, name)
        return _c_from_cl_cursor(target, target_name)

    raise NotImplementedError
608
609def _c_from_h_file(h_file, name, includes, defines, f, kind):
610 name = name.replace("struct ", "")
611 name = name.replace("union ", "")
612
613 args = ["-I{}".format(i) for i in includes if os.path.isdir(i)]
614 args += ["-D{}".format(d) for d in defines]
615
616 if not os.path.isfile(h_file):
617 return FileNotFoundError
618
619 idx = cl.Index.create()
620 tu = idx.parse(h_file, args=args)
621
622 t = [cl.Cursor().from_location(tu, t.location) for t in tu.cursor.get_tokens() if t.spelling == name]
623 t = [x for x in t if x.kind == kind]
624
625 errors = ["{}: {} at {}:{}".format(d.category_name, d.spelling, d.location.file.name, d.location.line) for d in tu.diagnostics if d.severity > 2]
626 warnings = ["{}: {} at {}:{}".format(d.category_name, d.spelling, d.location.file.name, d.location.line) for d in tu.diagnostics if d.severity > 3]
627
628 for w in warnings:
629 print(w)
630
631 if errors:
632 for e in errors:
633 print(e)
634 exit(1)
635
636 if len(t) == 0:
637 print("Failed to find {} in {}".format(name, h_file))
638 exit(1)
639
640 assert(len(t) == 1)
641 t = t[0]
642
643 return f.from_cl_cursor(t, name)
644
645
if __name__ == '__main__':
    # Command-line entry point: parse one struct from a header and print it.
    import argparse
    import c_include

    # Standard logging level names, spelled out instead of read from the
    # private logging._levelToName table.
    log_levels = ["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"]

    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--h_file", help="header file to parse", required=True)
    parser.add_argument("--struct_name", help="struct name to evaluate", required=True)
    parser.add_argument("--compile_commands_file", help="compile commands file to read includes and defines from", required=True)
    parser.add_argument("--c_file_to_mirror_includes_from", help="name of the c file to take includes and defines from", required=True)
    parser.add_argument("--log_level", help="log level", required=False, default="ERROR", choices=log_levels)
    args = parser.parse_args()
    logging.getLogger("TF-M").setLevel(args.log_level)
    logger.addHandler(logging.StreamHandler())

    # Reuse the include paths and defines that the given C file is built with.
    includes = c_include.get_includes(args.compile_commands_file, args.c_file_to_mirror_includes_from)
    defines = c_include.get_defines(args.compile_commands_file, args.c_file_to_mirror_includes_from)

    s = C_struct.from_h_file(args.h_file, args.struct_name, includes, defines)
    print(s)