blob: 89c02b2e49a13a8b80cc49cd961555a13891826d [file] [log] [blame]
Raef Coles59cf5d82024-12-09 15:41:13 +00001#!/usr/bin/env python3
2#-------------------------------------------------------------------------------
3# SPDX-FileCopyrightText: Copyright The TrustedFirmware-M Contributors
4#
5# SPDX-License-Identifier: BSD-3-Clause
6#
7#-------------------------------------------------------------------------------
8
9import clang.cindex as cl
10import struct
11from collections import Counter
12import os
13
14import logging
# Module logger, parented under the "TF-M" logger hierarchy.
logger = logging.getLogger("TF-M." + __name__)
Raef Coles59cf5d82024-12-09 15:41:13 +000016
17from rich import inspect
18
# Number of spaces used to indent nested blocks when rendering C text.
pad = 4

# Map from C type names (with "unsigned"/"const"/"volatile"/"static" already
# removed by the parsers below) to `struct` module format characters.
# Because "unsigned" is stripped before lookup, integer types map to unsigned
# codes; 'char' therefore uses 'B' so that `unsigned char` values 128..255
# can be packed.
# NOTE(review): 'long', 'uintptr_t' and 'size_t' use 4-byte 'I', which
# assumes a 32-bit (ILP32) target; 'long double' is treated as a plain double.
format_chars = {
    'uint8_t' : 'B',
    'uint16_t' : 'H',
    'uint32_t' : 'I',
    'uint64_t' : 'Q',
    'bool' : '?',
    'char' : 'B',
    'short' : 'H',
    'int' : 'I',
    'long' : 'I',
    'long long' : 'Q',
    'uintptr_t' : 'I',
    'size_t' : 'I',
    'float' : 'f',
    'double' : 'd',
    'long double' : 'd',
}
37}
38
def _c_struct_or_union_from_cl_cursor(cursor, name, f):
    """Build a struct/union model from a libclang declaration cursor.

    :param cursor: STRUCT_DECL / UNION_DECL cursor to convert.
    :param name: instance (field) name for the result, ``""`` if none.
    :param f: constructor called as ``f(name, c_type, fields, packed)``,
        i.e. ``C_struct`` or ``C_union``.
    """
    children = list(cursor.get_children())
    packed = any(child.kind == cl.CursorKind.PACKED_ATTR for child in children)
    member_kinds = (
        cl.CursorKind.FIELD_DECL,
        cl.CursorKind.STRUCT_DECL,
        cl.CursorKind.UNION_DECL,
    )

    # Convert fields and nested declarations, dropping zero-sized results.
    converted = [_c_from_cl_cursor(child) for child in children if child.kind in member_kinds]
    converted = [member for member in converted if member.get_size() != 0]

    # Drop unnamed members whose (non-scalar) c_type occurs more than once.
    # Presumably these are bare nested declarations that are also used by a
    # named field, e.g. `struct foo {...} bar;` -- NOTE(review): confirm.
    type_counts = Counter(member.c_type for member in converted if not isinstance(member, C_variable))
    repeated_types = {c_type for c_type, count in type_counts.items() if count > 1}
    members = [member for member in converted
               if member.c_type not in repeated_types or member.name != ""]

    return f(name, cursor.spelling, members, packed)
60
61def _pad_lines(text, size):
62 lines = text.split("\n")
63 lines = [" " * size + l for l in lines]
64 return "\n".join(lines)
65
def _c_struct_or_union_get_value_str(self, struct_or_union):
    """Return a C designated-initialiser string for a struct/union value.

    Only fields whose serialised bytes are not all zero are listed. Each
    initialiser goes on its own line, indented by ``pad`` spaces, with braces
    on the first and last lines. Previously the entries were concatenated
    without a separator, so all fields ended up on a single line.

    :param self: the ``C_struct`` / ``C_union`` being rendered.
    :param struct_or_union: ``"struct"`` or ``"union"``; currently unused.
    :returns: the initialiser text, e.g. ``{`` / ``.a = 0x1,`` / ``}``.
    """
    entries = [
        _pad_lines(".{} = {},".format(f.name, f.get_value_str()), pad)
        for f in self._fields
        if f.to_bytes() != bytes(f.get_size())
    ]
    return "{\n" + "\n".join(entries) + "\n}"
75
def _c_struct_or_union_str(self, struct_or_union):
    """Render a struct/union definition, including its fields, as C text.

    :param self: ``C_struct`` or ``C_union`` to render.
    :param struct_or_union: the C keyword to emit (``"struct"``/``"union"``).
    """
    header = struct_or_union + " "
    if self._packed:
        header += "__attribute__((packed)) "
    if self.c_type != "":
        header += self.c_type + " "

    body = "\n".join(_pad_lines(str(field), pad) for field in self._fields)
    text = header + "{\n" + body
    if not text.endswith("\n"):
        text += "\n"

    text += "}"
    if self.name != "":
        text += " " + self.name
    return text + ";"
93
94def _c_struct_or_union_get_field_strings(self):
95 field_strings=[]
96
97 for f in self._fields:
98 name_format = "{}.".format(self.name) if self.name != "" else ""
99 field_strings += [name_format + s for s in f.get_field_strings()]
100
101 return field_strings
102
def _c_struct_or_union_field_to_bytes(field):
    """Return ``(format, data)`` for packing *field* inside a parent aggregate.

    A ``C_variable`` contributes its own format character and scalar value.
    Any other field is packed as an opaque byte string of its full size.
    """
    if not isinstance(field, C_variable):
        return _binary_format_string(field.get_size()), field.to_bytes()
    return field._get_format_string(), field.get_value()
108
109def _c_struct_or_union_get_next_in_path(self, field_path):
110 field_separator = "."
111
112 if field_separator in field_path:
113 field_name, remainder = field_path.split(".", 1)
114 else:
115 field_name = field_path
116 remainder = None
117
118 try:
119 # Fields aren't allowed to duplicate names
120 field = [f for f in self._fields if f.name == field_name][0]
121 return field, remainder
122 except (KeyError, IndexError):
123 for f in self._fields:
124 try:
125 f.get_field(field_path)
126 return f, field_path
127 except KeyError:
128 continue
129 raise KeyError
130
131def _c_struct_or_union_get_direct_members(self):
132 subfields = []
133 for f in self._fields:
134 if f.name:
135 continue
136 subfields += _c_struct_or_union_get_direct_members(f)
137
138 return [f for f in self._fields if f.name] + subfields
139
140def _c_struct_or_union_get_field(self, field_path):
141 fields = [f for f in self._fields if f.name == field_path[:len(f.name)]]
142 for f in fields:
143 try:
144 remainder = field_path.replace(f.name + ".", "") if f.name else field_path
145 return f.get_field(remainder)
146 except (KeyError, ValueError):
147 continue
148 raise KeyError
149
150
151def _binary_format_string(size):
152 return "{}s".format(size)
153
class C_enum:
    """Python model of a C ``enum`` definition.

    Enumerators are available through ``self.dict`` (name -> member) and as
    attributes of the instance, e.g. ``my_enum.SOME_VALUE``.
    """

    def __init__(self, name, c_type, members):
        """
        :param name: enum name, optionally prefixed by the ``enum`` keyword.
        :param c_type: storage type chosen for the members, e.g. ``'uint8_t'``.
        :param members: enumerator objects, each with ``name`` and ``value``.
        """
        # Remove only a leading "enum" keyword. name.replace("enum", "") also
        # mangled names that merely contain "enum" ("my_enum_t" -> "my__t").
        words = name.split(None, 1)
        if words and words[0] == "enum":
            name = words[1] if len(words) > 1 else ""
        self.name = name.lstrip()
        self.c_type = c_type
        self._members = members
        self.dict = {m.name: m for m in self._members}
        self.__dict__.update({m.name: m for m in self._members})

    @staticmethod
    def from_h_file(h_file, name, includes=None, defines=None):
        """Parse *h_file* and build the C_enum declared as *name*."""
        return _c_from_h_file(h_file, name, includes or [], defines or [], C_enum, cl.CursorKind.ENUM_DECL)

    @staticmethod
    def _c_type_for_max_value(max_value):
        """Pick the smallest unsigned storage type able to hold *max_value*.

        NOTE: ``2^8`` in Python is XOR (== 10), not a power, so the bounds
        are written with ``**``. Negative values are not supported.
        """
        if max_value < 2 ** 8:
            return 'uint8_t'
        if max_value < 2 ** 16:
            return 'uint16_t'
        return 'uint32_t'

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        """Build a C_enum from a cursor whose definition is an ENUM_DECL.

        Enumerators whose name starts with ``_`` are left out of the model,
        but they still count when choosing the storage type.
        """
        definition = cursor.get_definition()
        assert definition.kind == cl.CursorKind.ENUM_DECL

        constants = list(definition.get_children())
        max_value = max((x.enum_value for x in constants), default=0)
        c_type = C_enum._c_type_for_max_value(max_value)

        members = [C_enum_member(x.spelling, c_type, x.enum_value) for x in constants]
        members = [m for m in members if not m.name.startswith('_')]
        return C_enum(cursor.spelling, c_type, members)

    def __str__(self):
        """Render the enum as C source.

        Each enumerator gets a trailing comma, as C requires, and an explicit
        value. Values must be explicit because filtered-out ``_`` enumerators
        would otherwise shift the implicit numbering.
        """
        pad = 4

        lines = ["enum {} {{".format(self.name)]
        lines += [" " * pad + "{} = {},".format(m.name, m.value) for m in self._members]
        lines.append("};")
        return "\n".join(lines)
200
class C_enum_member:
    """A single enumerator: its name, storage type and integer value."""

    def __init__(self, name, c_type, value):
        self.name = name
        self.c_type = c_type
        self.value = value
        fmt = self._get_format_string()
        self._format_string = fmt
        self._size = struct.calcsize(fmt)

    def _get_format_string(self):
        """Return the ``struct`` format character for ``self.c_type``."""
        return format_chars[self.c_type]

    def get_value(self):
        return self.value

    def to_bytes(self):
        """Serialise the value little-endian using the member's format."""
        little_endian_format = "<" + self._format_string
        return struct.pack(little_endian_format, self.get_value())

    def get_size(self):
        return self._size

    def __str__(self):
        return self.name

    def __repr__(self):
        return self.name
226
class C_array:
    """Python model of a (possibly multi-dimensional) C array.

    Element ``i`` of an array called ``name`` is itself named ``name_i``. A
    multi-dimensional array is modelled as an array of ``C_array`` members.
    """

    def __init__(self, name, c_type, members):
        """
        :param name: array (field) name.
        :param c_type: element type, qualifiers and ``[N]`` suffixes removed.
        :param members: element objects (``C_variable``, ``C_array``,
            ``C_struct`` or ``C_union``).
        """
        self.name = name
        self.c_type = c_type
        self._members = members

        if self._members:
            self._dimensions = [len(self._members)]
            if isinstance(self._members[0], C_array):
                self._dimensions += self._members[0]._dimensions
        else:
            self._dimensions = [0]

        self._size = struct.calcsize(self._get_format_string())

    @staticmethod
    def from_h_file(h_file, name, includes=None, defines=None):
        """Parse *h_file* and build the C_array declared as *name*."""
        # NOTE(review): cl.CursorKind.TYPE_DECL is kept from the original;
        # confirm this member exists in the clang.cindex bindings in use.
        return _c_from_h_file(h_file, name, includes or [], defines or [], C_array, cl.CursorKind.TYPE_DECL)

    @staticmethod
    def from_cl_cursor(cursor, name="", dimensions=None):
        """Build a C_array from a libclang cursor whose type is an array.

        :param cursor: cursor whose ``type.spelling`` looks like ``T[N]...``.
        :param name: array name; elements are named ``name_<i>``.
        :param dimensions: remaining dimensions when recursing into a
            multi-dimensional array; parsed from the type spelling otherwise.
        """
        c_type = cursor.type.spelling
        for qualifier in ("unsigned", "volatile", "const", "static"):
            c_type = c_type.replace(qualifier, "")

        # Always split the element type from its "[N]" suffixes, even when
        # recursing. Otherwise inner dimensions see "T[N][M]" as the type,
        # which is never a format_chars key.
        c_type, *spelled_dimensions = c_type.split("[")
        c_type = c_type.strip()
        if not dimensions:
            # "[]" (unknown size) is modelled as zero elements.
            dimensions = [0 if d == "]" else int(d.replace("]", "")) for d in spelled_dimensions]

        assert len(dimensions) > 0

        if len(dimensions) > 1:
            make_member = lambda x: C_array.from_cl_cursor(cursor, x, dimensions[1:])
        elif c_type not in format_chars:
            make_member = lambda x: _c_from_cl_cursor(list(cursor.get_children())[0].get_definition(), x)
        else:
            make_member = lambda x: C_variable(x, c_type)

        members = [make_member("{}_{}".format(name, i)) for i in range(dimensions[0])]
        return C_array(name, c_type, members)

    def __getitem__(self, index):
        return self._members[index]

    def _get_format_string(self):
        """Little-endian ``struct`` format covering every element."""
        return "<" + "".join([_c_struct_or_union_field_to_bytes(m)[0] for m in self._members])

    def get_value(self):
        return self.to_bytes()

    def set_value(self, value):
        self.set_value_from_bytes(value)

    def set_value_from_bytes(self, value):
        """Load leading elements from *value*; later elements keep their values."""
        assert(len(value) <= self._size), "{} of size {} cannot be set to value {} of size {}".format(self, self._size, value.hex(), len(value))
        value_used = 0
        for m in self._members:
            if value_used == len(value):
                break

            m.set_value_from_bytes(value[value_used:value_used + m.get_size()])
            value_used += m.get_size()

    def get_field_strings(self):
        return [self.name] + [f for m in self._members if not isinstance(m, C_variable) for f in m.get_field_strings()]

    def get_field(self, field_path):
        """Resolve ``"<name>"``, ``"<name>_<i>"`` or ``"<name>_<i>.<rest>"``.

        :raises ValueError: if the index component is not an integer.
        :raises IndexError: if the index is out of range.
        """
        if field_path == self.name:
            return self

        # Strip only the leading "<name>_" prefix. Replacing every occurrence
        # corrupted nested paths such as "a_0.ba_1".
        prefix = self.name + "_"
        if field_path.startswith(prefix):
            field_path = field_path[len(prefix):]

        separators = [i for i in (field_path.find('.'), field_path.find('_')) if i != -1]
        if separators:
            split_idx = min(separators)
            index, remainder = field_path[:split_idx], field_path[split_idx + 1:]
        else:
            index, remainder = field_path, None

        member = self._members[int(index)]
        return member.get_field(remainder) if remainder else member

    def to_bytes(self):
        """Pack all elements little-endian, matching ``_get_format_string``."""
        # The "<" prefix keeps the byte order and size consistent with
        # self._size; without it struct.pack used native byte order.
        format_str = "<"
        values = []

        for m in self._members:
            member_format, member_data = _c_struct_or_union_field_to_bytes(m)
            format_str += member_format
            values.append(member_data)

        return struct.pack(format_str, *values)

    def get_size(self):
        return self._size

    def get_value_str(self):
        """Brace-enclosed initialiser, or ``""`` when all bytes are zero."""
        if self.to_bytes() == bytes(self.get_size()):
            return ""
        members = "".join(m.get_value_str() + ", " for m in self._members)
        return "{\n" + _pad_lines(members, pad) + "\n}"

    def __str__(self):
        string = "{} {}".format(self.c_type, self.name)
        string += "".join(["[{}]".format(a) for a in self._dimensions])

        if self.to_bytes() != bytes(self.get_size()):
            string += " = " + self.get_value_str()

        string += ";"
        return string
357
class C_variable:
    """Python model of a scalar C variable or struct field."""

    def __init__(self, name, c_type, value=None):
        """
        :param name: variable / field name.
        :param c_type: C type name; must be a ``format_chars`` key or contain
            ``"enum"`` (enums are stored as ``uint32_t``).
        :param value: initial value; ``None`` means "not set".
        """
        self.name = name
        self.c_type = c_type
        self._format_string = self._get_format_string()
        self._size = struct.calcsize(self._format_string)
        self.value = value

    @staticmethod
    def from_h_file(h_file, name, includes=None, defines=None):
        """Parse *h_file* and build the C_variable declared as *name*."""
        # NOTE(review): cl.CursorKind.TYPE_DECL is kept from the original;
        # confirm this member exists in the clang.cindex bindings in use.
        return _c_from_h_file(h_file, name, includes or [], defines or [], C_variable, cl.CursorKind.TYPE_DECL)

    @staticmethod
    def _normalize_c_type(spelling):
        """Remove qualifier keywords from a clang type spelling.

        Qualifiers are dropped as whole words and the remaining words are
        joined with single spaces. ``"const unsigned int"`` therefore becomes
        ``"int"`` rather than ``"  int"``, which is not a ``format_chars``
        key. Type names merely containing "const" etc. are left intact.
        """
        qualifiers = ("unsigned", "volatile", "const", "static")
        return " ".join(word for word in spelling.split() if word not in qualifiers)

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        """Build a C_variable (or a C_array for array types) from a cursor."""
        c_type = C_variable._normalize_c_type(cursor.type.spelling)

        if "[" in c_type:
            return C_array.from_cl_cursor(cursor, name)

        return C_variable(name, c_type)

    def _get_format_string(self):
        if 'enum' in self.c_type:
            return format_chars['uint32_t']

        return format_chars[self.c_type]

    def get_value(self):
        """Return the value, or 0 when it is unset or falsy."""
        return self.value if self.value else 0

    def set_value(self, value):
        """Set the value from a number or a numeric string.

        Strings are parsed with ``int(value, 0)``, so ``"0x10"`` works, or
        with ``float()`` for float/double variables.

        :raises struct.error: if the value does not fit the C type.
        """
        if isinstance(value, str):
            value = float(value) if self._format_string in ("f", "d") else int(value, 0)
        self.value = value

        # Sanity check the value
        self.to_bytes()

    def set_value_from_bytes(self, value):
        """Decode *value* with the same little-endian layout as ``to_bytes``."""
        self.set_value(struct.unpack("<" + self._format_string, value)[0])

    def get_field_strings(self):
        return [self.name]

    def get_field(self, field_path):
        """Return self if *field_path* is this variable's name.

        :raises KeyError: otherwise.
        """
        if field_path != self.name:
            raise KeyError(field_path)
        return self

    def to_bytes(self):
        return struct.pack("<" + self._format_string, self.get_value())

    def get_size(self):
        return self._size

    def get_value_str(self):
        """Return a C literal for the value: hex for integers, repr for floats."""
        value = self.get_value()
        if isinstance(value, float):
            # hex() rejects floats; repr() gives a valid C floating literal.
            return repr(value)
        return hex(value)

    def __str__(self):
        string = "{} {}".format(self.c_type, self.name)

        if self.value is not None:
            string += " = {}".format(self.get_value_str())

        string += ";"
        return string
430
class C_union:
    """Python model of a C ``union``.

    All fields alias one byte buffer, ``self._actual_value``. Public accessors
    call ``_ensure_consistency`` first, so a write made through one field is
    propagated to every other field.
    """

    def __init__(self, name="", c_type="", fields=None, packed=False):
        """
        :param name: instance (field) name, ``""`` if anonymous.
        :param c_type: union tag; a ``"union "`` prefix is removed.
        :param fields: member objects; ``None`` means no members.
        :param packed: whether the union carries ``__attribute__((packed))``.
        """
        self.name = name
        self.c_type = c_type.replace("union ", "")
        # Fresh list per instance instead of a shared mutable default.
        self._fields = [] if fields is None else fields
        self._packed = packed
        self._format_strings = self._get_format_strings()
        # default=0 keeps a union with no (non-empty) members constructible.
        self._size = max(map(struct.calcsize, self._format_strings), default=0)
        self._actual_value = bytes(self._size)
        self.__dict__.update({f.name: f for f in _c_struct_or_union_get_direct_members(self)})

    @staticmethod
    def from_h_file(h_file, name, includes=None, defines=None):
        """Parse *h_file* and build the C_union declared as *name*."""
        return _c_from_h_file(h_file, name, includes or [], defines or [], C_union, cl.CursorKind.UNION_DECL)

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        return _c_struct_or_union_from_cl_cursor(cursor, name, C_union)

    def _ensure_consistency(self):
        """Propagate a write made through a single field to all fields.

        A field whose bytes no longer match the start of ``_actual_value``
        has been modified. Its bytes are overlaid onto the buffer, and every
        field is then reloaded from the buffer.
        """
        binaries = {f.to_bytes() for f in self._fields}
        changed = [b for b in binaries if b != self._actual_value[:len(b)]]

        # Writes through two different fields since the last sync are ambiguous.
        assert len(changed) < 2, "Conflicting writes to union {}".format(self.name)

        if changed:
            new = changed[0]
            # Keep the bytes beyond the written field. Replacing the whole
            # buffer shrank it and broke reloading the larger fields.
            self._actual_value = new + self._actual_value[len(new):]

        for f in self._fields:
            f.set_value_from_bytes(self._actual_value[:f.get_size()])

    def _get_format_strings(self):
        """One opaque-bytes ``struct`` format per field, sized to that field."""
        format_endianness = "<" if self._packed else "@"
        return ["{}{}".format(format_endianness, _binary_format_string(f.get_size())) for f in self._fields]

    def set_value_from_bytes(self, value):
        """Load the raw union image *value* into every field.

        *value* may be shorter than the union; bytes beyond it keep their
        previous contents. Each field receives only the bytes it occupies.
        """
        value = bytes(value)
        assert len(value) <= self._size, "{} of size {} cannot be set to value {} of size {}".format(self.name, self._size, value.hex(), len(value))
        self._actual_value = value + self._actual_value[len(value):]
        for f in self._fields:
            f.set_value_from_bytes(self._actual_value[:f.get_size()])

    def set_value(self, field_path, value):
        """Set the field at *field_path* and propagate it to aliasing fields."""
        self.get_field(field_path).set_value(value)
        self._ensure_consistency()

    def get_value(self, field_path):
        self._ensure_consistency()
        return self.get_field(field_path).value

    def get_field_strings(self):
        self._ensure_consistency()
        return _c_struct_or_union_get_field_strings(self)

    def get_field(self, field_path):
        self._ensure_consistency()
        return _c_struct_or_union_get_field(self, field_path)

    def to_bytes(self):
        """Return the shared storage buffer, after syncing pending writes."""
        self._ensure_consistency()
        return bytes(self._actual_value)

    def get_size(self):
        return self._size

    def get_value_str(self):
        self._ensure_consistency()
        return _c_struct_or_union_get_value_str(self, "union")

    def __str__(self):
        self._ensure_consistency()
        return _c_struct_or_union_str(self, "union")
505
class C_struct:
    """Python model of a C ``struct``.

    Each field owns its own storage. The struct's byte image is produced by
    packing the fields with :mod:`struct`: little-endian with no padding when
    packed, native byte order and alignment ("@") otherwise.
    """

    def __init__(self, name="", c_type="", fields=None, packed=False):
        """
        :param name: instance (field) name, ``""`` if anonymous.
        :param c_type: struct tag; a ``"struct "`` prefix is removed.
        :param fields: member objects; ``None`` means no members.
        :param packed: whether the struct carries ``__attribute__((packed))``.
        """
        self.name = name
        self.c_type = c_type.replace("struct ", "")
        # Fresh list per instance instead of a shared mutable default.
        self._fields = [] if fields is None else fields
        self._packed = packed
        self._format_string = self._get_format_string()
        self._size = struct.calcsize(self._format_string)
        self.__dict__.update({f.name: f for f in _c_struct_or_union_get_direct_members(self)})

    @staticmethod
    def from_h_file(h_file, name, includes=None, defines=None):
        """Parse *h_file* and build the C_struct declared as *name*."""
        return _c_from_h_file(h_file, name, includes or [], defines or [], C_struct, cl.CursorKind.STRUCT_DECL)

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        return _c_struct_or_union_from_cl_cursor(cursor, name, C_struct)

    def _get_format_string(self):
        format_endianness = "<" if self._packed else "@"
        return format_endianness + "".join([_c_struct_or_union_field_to_bytes(f)[0] for f in self._fields])

    def set_value_from_bytes(self, value):
        """Populate every field from the raw struct image *value*.

        *value* must be exactly ``get_size()`` bytes long. It is decoded with
        the struct's own format string, so alignment padding inserted by
        ``to_bytes`` for non-packed structs is skipped here as well.

        :raises struct.error: if *value* has the wrong length.
        """
        decoded = struct.unpack(self._format_string, bytes(value))
        for field, item in zip(self._fields, decoded):
            if isinstance(item, bytes):
                # Aggregate field (struct / union / array) packed as "Ns".
                field.set_value_from_bytes(item)
            else:
                field.set_value(item)

    def set_value(self, field_path, value):
        self.get_field(field_path).set_value(value)

    def get_value(self, field_path):
        return self.get_field(field_path).value

    def get_field_strings(self):
        return _c_struct_or_union_get_field_strings(self)

    def get_field(self, field_path):
        return _c_struct_or_union_get_field(self, field_path)

    def to_bytes(self):
        """Pack all fields using the same layout as ``_get_format_string``."""
        # The byte-order/alignment prefix was computed but never used, so
        # packed structs were serialised with native padding (and a size
        # different from self._size).
        format_str = "<" if self._packed else "@"
        values = []

        for f in self._fields:
            field_format, field_data = _c_struct_or_union_field_to_bytes(f)
            format_str += field_format
            values.append(field_data)

        return struct.pack(format_str, *values)

    def get_size(self):
        return self._size

    def get_docs_table(self):
        # NOTE(review): _c_struct_or_union_get_docs_table is not defined in
        # this file; this raises NameError unless it is provided elsewhere.
        return _c_struct_or_union_get_docs_table(self)

    def get_value_str(self):
        return _c_struct_or_union_get_value_str(self, "struct")

    def __str__(self):
        return _c_struct_or_union_str(self, "struct")
570
571def _parse_field_dec(cursor):
572 return list(cursor.get_children())[0], cursor.spelling
573
574def _parse_type_ref(cursor, name):
575 return cursor.get_definition(), name
576
def _c_from_cl_cursor(cursor, name = ""):
    """Convert a libclang cursor into the matching C-type model object.

    :param cursor: declaration or TYPE_REF cursor.
    :param name: name for the resulting object. FIELD_DECLs use their own
        spelling instead.
    :returns: a ``C_struct``, ``C_union``, ``C_variable`` or ``C_array``.
    :raises NotImplementedError: for unsupported cursor kinds.
    """
    kind = cursor.kind

    if kind == cl.CursorKind.STRUCT_DECL:
        return C_struct.from_cl_cursor(cursor, name)

    if kind == cl.CursorKind.UNION_DECL:
        return C_union.from_cl_cursor(cursor, name)

    if kind in (cl.CursorKind.TYPEDEF_DECL, cl.CursorKind.ENUM_DECL):
        return C_variable.from_cl_cursor(cursor, name)

    if kind == cl.CursorKind.FIELD_DECL:
        # Arrays are resolved from the field's own type spelling. A field of
        # builtin type (e.g. `int x;`) may have no child cursor to follow
        # (NOTE(review): libclang behaviour, confirm). Such fields are also
        # resolved from the field itself instead of letting _parse_field_dec
        # raise IndexError.
        if cursor.type.kind == cl.TypeKind.CONSTANTARRAY or not list(cursor.get_children()):
            return C_variable.from_cl_cursor(cursor, cursor.spelling)

        return _c_from_cl_cursor(*_parse_field_dec(cursor))

    if kind == cl.CursorKind.TYPE_REF:
        return _c_from_cl_cursor(*_parse_type_ref(cursor, name))

    raise NotImplementedError("Unsupported cursor kind {} for '{}'".format(kind, cursor.spelling))
597
598def _c_from_h_file(h_file, name, includes, defines, f, kind):
599 name = name.replace("struct ", "")
600 name = name.replace("union ", "")
601
602 args = ["-I{}".format(i) for i in includes if os.path.isdir(i)]
603 args += ["-D{}".format(d) for d in defines]
604
605 if not os.path.isfile(h_file):
606 return FileNotFoundError
607
608 idx = cl.Index.create()
609 tu = idx.parse(h_file, args=args)
610
611 t = [cl.Cursor().from_location(tu, t.location) for t in tu.cursor.get_tokens() if t.spelling == name]
612 t = [x for x in t if x.kind == kind]
613
614 errors = ["{}: {} at {}:{}".format(d.category_name, d.spelling, d.location.file.name, d.location.line) for d in tu.diagnostics if d.severity > 2]
615 warnings = ["{}: {} at {}:{}".format(d.category_name, d.spelling, d.location.file.name, d.location.line) for d in tu.diagnostics if d.severity > 3]
616
617 for w in warnings:
618 print(w)
619
620 if errors:
621 for e in errors:
622 print(e)
623 exit(1)
624
625 if len(t) == 0:
626 print("Failed to find {} in {}".format(name, h_file))
627 exit(1)
628
629 assert(len(t) == 1)
630 t = t[0]
631
632 return f.from_cl_cursor(t, name)
633
634
if __name__ == '__main__':
    # Command-line entry point: parse a header with the include paths and
    # defines used for a given C file, then print the rendered struct.
    import argparse
    import c_include

    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--h_file", help="header file to parse", required=True)
    parser.add_argument("--struct_name", help="struct name to evaluate", required=True)
    parser.add_argument("--compile_commands_file",
                        help="compile commands file to read include paths and defines from",
                        required=True)
    parser.add_argument("--c_file_to_mirror_includes_from",
                        help="c file whose include paths and defines are used when parsing the header",
                        required=True)
    parser.add_argument("--log_level", help="log level", required=False, default="ERROR",
                        choices=logging._levelToName.values())
    args = parser.parse_args()
    logging.getLogger("TF-M").setLevel(args.log_level)
    logger.addHandler(logging.StreamHandler())

    includes = c_include.get_includes(args.compile_commands_file, args.c_file_to_mirror_includes_from)
    defines = c_include.get_defines(args.compile_commands_file, args.c_file_to_mirror_includes_from)

    s = C_struct.from_h_file(args.h_file, args.struct_name, includes, defines)
    print(s)