blob: 4ae865fc6a7786d4710611c9f5e883b1e0edf07c [file] [log] [blame]
Raef Coles59cf5d82024-12-09 15:41:13 +00001#!/usr/bin/env python3
2#-------------------------------------------------------------------------------
3# SPDX-FileCopyrightText: Copyright The TrustedFirmware-M Contributors
4#
5# SPDX-License-Identifier: BSD-3-Clause
6#
7#-------------------------------------------------------------------------------
8
9import clang.cindex as cl
10import struct
11from collections import Counter
12import os
13
14import logging
# Module logger named "TF-M"; when run as a script its level comes from --log_level.
logger = logging.getLogger(name="TF-M")
16
17from rich import inspect
18
# Number of spaces used to indent nested declarations and initializers.
pad = 4

# C type name -> `struct` module format character.
#
# NOTE(review): `long`, `uintptr_t` and `size_t` map to 'I' (4 bytes) and
# `long double` maps to 'd', which presumably models a 32-bit target where
# long double == double -- confirm for the platforms this is used with.
# Signed C types such as `short`/`int` map to unsigned formats, so negative
# values cannot be packed for them.
format_chars = {
    # Fixed-width integer types.
    'uint8_t': 'B',
    'uint16_t': 'H',
    'uint32_t': 'I',
    'uint64_t': 'Q',
    # Built-in C types.
    'bool': '?',
    'char': 'b',
    'short': 'H',
    'int': 'I',
    'long': 'I',
    'long long': 'Q',
    # Pointer / size types.
    'uintptr_t': 'I',
    'size_t': 'I',
    # Floating point.
    'float': 'f',
    'double': 'd',
    'long double': 'd',
}
37}
38
def _c_struct_or_union_from_cl_cursor(cursor, name, f):
    """Build a struct/union object from a clang STRUCT_DECL/UNION_DECL cursor.

    Args:
        cursor: the clang cursor of the aggregate declaration.
        name: declarator name to give the resulting object.
        f: constructor to call (C_struct or C_union) as
            ``f(name, c_type, fields, packed)``.

    `c_type` is the cursor spelling, `packed` is whether the cursor has a
    PACKED_ATTR child, and `fields` are the converted FIELD_DECL / nested
    STRUCT_DECL / UNION_DECL children with zero-sized members removed.
    """
    member_kinds = (
        cl.CursorKind.FIELD_DECL,
        cl.CursorKind.STRUCT_DECL,
        cl.CursorKind.UNION_DECL,
    )
    children = list(cursor.get_children())
    is_packed = any(child.kind == cl.CursorKind.PACKED_ATTR for child in children)

    converted = []
    for child in children:
        if child.kind not in member_kinds:
            continue
        member = _c_from_cl_cursor(child)
        if member.get_size() != 0:
            converted.append(member)

    # An aggregate type seen more than once is presumably both a nested
    # declaration and a named field of that type; keep only the named copies.
    type_counts = Counter(m.c_type for m in converted if not isinstance(m, C_variable))
    repeated_types = {t for t, count in type_counts.items() if count > 1}
    members = [m for m in converted if m.name != "" or m.c_type not in repeated_types]

    return f(name, cursor.spelling, members, is_packed)
60
61def _pad_lines(text, size):
62 lines = text.split("\n")
63 lines = [" " * size + l for l in lines]
64 return "\n".join(lines)
65
def _c_struct_or_union_get_value_str(self, struct_or_union):
    """Render a designated-initializer string for a struct/union.

    Only members whose bytes are not all zero are listed, as ``.name = value,``,
    one member per line and indented by `pad` spaces, between ``{`` and ``}``.
    `struct_or_union` is accepted for symmetry with _c_struct_or_union_str but
    is not used.

    NOTE(review): an anonymous member (name == "") is rendered as
    ``. = {...},``, which is not valid C -- confirm whether anonymous members
    need flattening.
    """
    member_lines = [
        _pad_lines(".{} = {},".format(f.name, f.get_value_str()), pad)
        for f in self._fields
        if f.to_bytes() != bytes(f.get_size())
    ]
    # Newline-join so each member sits on its own line (previously they were
    # concatenated onto a single line).
    return "{\n" + "\n".join(member_lines) + "\n}"
75
def _c_struct_or_union_str(self, struct_or_union):
    """Render a struct/union as a C declaration.

    Output shape: ``<struct_or_union> [__attribute__((packed))] [tag] {``,
    then each member's str() indented by `pad` spaces, then ``} [name];``.
    The packed attribute, the tag (`self.c_type`) and the declarator
    (`self.name`) appear only when set.
    """
    header_parts = [struct_or_union]
    if self._packed:
        header_parts.append("__attribute__((packed))")
    if self.c_type != "":
        header_parts.append(self.c_type)

    body = "\n".join(_pad_lines(str(member), pad) for member in self._fields)
    text = " ".join(header_parts) + " {\n" + body
    if not text.endswith("\n"):
        text += "\n"

    trailer = "}" if self.name == "" else "} {}".format(self.name)
    return text + trailer + ";"
93
94def _c_struct_or_union_get_field_strings(self):
95 field_strings=[]
96
97 for f in self._fields:
98 name_format = "{}.".format(self.name) if self.name != "" else ""
99 field_strings += [name_format + s for s in f.get_field_strings()]
100
101 return field_strings
102
def _c_struct_or_union_field_to_bytes(field):
    """Return ``(struct_format, pack_value)`` for packing `field` in its parent.

    A C_variable contributes its own format character and value; any other
    member is packed as an opaque ``"<size>s"`` byte string of its bytes.
    """
    if not isinstance(field, C_variable):
        return _binary_format_string(field.get_size()), field.to_bytes()
    return field._get_format_string(), field.get_value()
108
109def _c_struct_or_union_get_next_in_path(self, field_path):
110 field_separator = "."
111
112 if field_separator in field_path:
113 field_name, remainder = field_path.split(".", 1)
114 else:
115 field_name = field_path
116 remainder = None
117
118 try:
119 # Fields aren't allowed to duplicate names
120 field = [f for f in self._fields if f.name == field_name][0]
121 return field, remainder
122 except (KeyError, IndexError):
123 for f in self._fields:
124 try:
125 f.get_field(field_path)
126 return f, field_path
127 except KeyError:
128 continue
129 raise KeyError
130
131def _c_struct_or_union_get_direct_members(self):
132 subfields = []
133 for f in self._fields:
134 if f.name:
135 continue
136 subfields += _c_struct_or_union_get_direct_members(f)
137
138 return [f for f in self._fields if f.name] + subfields
139
140def _c_struct_or_union_get_field(self, field_path):
141 fields = [f for f in self._fields if f.name == field_path[:len(f.name)]]
142 for f in fields:
143 try:
144 remainder = field_path.replace(f.name + ".", "") if f.name else field_path
145 return f.get_field(remainder)
146 except (KeyError, ValueError):
147 continue
148 raise KeyError
149
150
151def _binary_format_string(size):
152 return "{}s".format(size)
153
class C_enum:
    """A C enum type whose enumerators are C_enum_member objects.

    Enumerators are reachable by name both through `dict` and as attributes.
    """

    def __init__(self, name, c_type, members):
        # Drop only a leading "enum" keyword; identifiers that merely contain
        # "enum" (e.g. "my_enum_t") must stay intact. `c_type` is accepted
        # but not stored.
        words = name.split(None, 1)
        if words and words[0] == "enum":
            name = words[1] if len(words) > 1 else ""
        self.name = name.lstrip()
        self._members = members
        self.dict = {m.name:m for m in self._members}
        self.__dict__.update({m.name:m for m in self._members})

    @staticmethod
    def from_h_file(h_file, name, includes = [], defines = []):
        return _c_from_h_file(h_file, name, includes, defines, C_enum, cl.CursorKind.ENUM_DECL)

    @staticmethod
    def _c_type_for_max_value(max_value):
        """Return the smallest unsigned fixed-width type able to hold `max_value`."""
        # Python's `^` is XOR, so the former `2^8` / `2^16` were 10 / 18.
        if max_value <= 0xFF:
            return 'uint8_t'
        if max_value <= 0xFFFF:
            return 'uint16_t'
        return 'uint32_t'

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        """Build a C_enum from a cursor whose definition is an ENUM_DECL.

        Enumerators whose names start with "_" are not kept as members.
        """
        definition = cursor.get_definition()
        assert(definition.kind == cl.CursorKind.ENUM_DECL)

        enumerators = list(definition.get_children())

        # Size from every enumerator, including "_"-prefixed ones dropped
        # below (such entries may exist only to force a wider enum).
        # NOTE(review): negative enumerators are not handled; the unsigned
        # formats chosen here cannot pack them.
        c_type = C_enum._c_type_for_max_value(max(x.enum_value for x in enumerators))

        members = [C_enum_member(x.spelling, c_type, x.enum_value) for x in enumerators]
        members = [m for m in members if m.name[0] != '_']
        return C_enum(cursor.spelling, c_type, members)

    def __str__(self):
        pad = 4

        string = "enum {} {{\n".format(self.name)
        # Enumerators must be comma-separated for the output to be valid C.
        string += ",\n".join([_pad_lines(str(m), pad) for m in self._members])

        if string[-1] != "\n":
            string += "\n"

        string += "};"
        return string
200
class C_enum_member:
    """One enumerator of a C_enum: a name and an integer value stored as `c_type`."""

    def __init__(self, name, c_type, value):
        self.name = name
        self.c_type = c_type
        self.value = value
        format_string = self._get_format_string()
        self._format_string = format_string
        self._size = struct.calcsize(format_string)

    def _get_format_string(self):
        # `struct` format character for this member's storage type.
        return format_chars[self.c_type]

    def get_value(self):
        return self.value

    def to_bytes(self):
        """Pack the value little-endian using the member's storage type."""
        return struct.pack("<{}".format(self._format_string), self.get_value())

    def get_size(self):
        return self._size

    def __str__(self):
        return self.name

    __repr__ = __str__
226
class C_array:
    """A C array whose elements are C_* objects.

    Elements are named "<name>_<index>"; a multi-dimensional array is stored
    as nested C_array objects, so element paths look like "arr_1", "arr_0_2"
    or "arr_1.field" (for arrays of aggregates).
    """

    def __init__(self, name, c_type, members):
        self.name = name
        self.c_type = c_type
        self._members = members

        # Outer length followed by the inner arrays' dimensions; an array with
        # no members is recorded as a single zero-length dimension.
        if self._members:
            self._dimensions = [len(self._members)]
            if isinstance(self._members[0], C_array):
                self._dimensions += self._members[0]._dimensions
        else:
            self._dimensions = [0]

        self._size = struct.calcsize(self._get_format_string())

    @staticmethod
    def from_h_file(h_file, name, includes = [], defines = []):
        # NOTE(review): cl.CursorKind.TYPE_DECL does not appear to be a libclang
        # cursor kind (TYPEDEF_DECL / VAR_DECL are) -- confirm this lookup works.
        return _c_from_h_file(h_file, name, includes, defines, C_array, cl.CursorKind.TYPE_DECL)

    @staticmethod
    def from_cl_cursor(cursor, name="", dimensions=None):
        """Build a C_array from a clang cursor whose type spelling contains "[...]".

        `dimensions` is supplied by the recursive calls for multi-dimensional
        arrays; when omitted it is parsed from the spelling ("[]" counts as 0).
        """
        c_type = cursor.type.spelling
        c_type = c_type.replace("unsigned", "")
        c_type = c_type.replace("volatile", "")
        c_type = c_type.replace("const", "")
        c_type = c_type.replace("static", "")

        # "uint8_t [2][3]" -> element type "uint8_t", dimension tokens "2]", "3]".
        # The element type is taken from the spelling on every call, so the
        # recursive calls see "uint8_t" rather than the whole "uint8_t [2][3]".
        element_type, *dimension_tokens = c_type.split("[")
        element_type = element_type.strip()
        if not dimensions:
            dimensions = [0 if d == "]" else int(d.replace("]", "")) for d in dimension_tokens]

        assert(len(dimensions) > 0)

        if len(dimensions) > 1:
            # Each element is itself an array over the remaining dimensions.
            def make_member(member_name):
                return C_array.from_cl_cursor(cursor, member_name, dimensions[1:])
        elif element_type not in format_chars.keys():
            # Aggregate element: presumably the first child refers to the
            # element type (e.g. a TYPE_REF); build from its definition.
            def make_member(member_name):
                return _c_from_cl_cursor(list(cursor.get_children())[0].get_definition(), member_name)
        else:
            def make_member(member_name):
                return C_variable(member_name, element_type)

        members = [make_member(name + "_{}".format(i)) for i in range(dimensions[0])]

        return C_array(name, element_type, members)

    def __getitem__(self, index):
        return self._members[index]

    def _get_format_string(self):
        # Little-endian standard sizes, elements packed back to back.
        return "<" + "".join([_c_struct_or_union_field_to_bytes(m)[0] for m in self._members])

    def get_value(self):
        return self.to_bytes()

    def set_value(self, value):
        self.set_value_from_bytes(value)

    def set_value_from_bytes(self, value):
        """Set elements in order from `value`.

        Elements past the end of a shorter `value` keep their current values;
        a `value` that ends mid-element hands that element a short slice.
        """
        assert(len(value) <= self._size), "{} of size {} cannot be set to value {} of size {}".format(self, self._size, value.hex(), len(value))
        value_used = 0
        for m in self._members:
            if (value_used == len(value)):
                break

            m.set_value_from_bytes(value[value_used:value_used + m.get_size()])
            value_used += m.get_size()

    def get_field_strings(self):
        """Return this array's name followed by the field paths of non-scalar elements."""
        return [self.name] + [f for m in self._members if not isinstance(m, C_variable) for f in m.get_field_strings()]

    def get_field(self, field_path):
        """Resolve "<name>", "<name>_<i>", "<name>_<i>_<j>" or "<name>_<i>.<field>".

        Raises ValueError for a non-integer index and IndexError for an
        out-of-range one.
        """
        if field_path == self.name:
            return self

        # Strip only the leading "<name>_": element members whose own names
        # contain "<name>_" (e.g. "arr_len") must not be altered.
        prefix = self.name + "_"
        if field_path.startswith(prefix):
            field_path = field_path[len(prefix):]

        splits = [x for x in [field_path.find('.'), field_path.find("_")] if x != -1]

        if not splits:
            index = field_path
            remainder = None
        else:
            split_idx = min(splits)
            index = field_path[:split_idx]
            remainder = field_path[split_idx + 1:]

        member = self._members[int(index)]
        return member.get_field(remainder) if remainder else member

    def to_bytes(self):
        # Same "<" prefix as _get_format_string so the result is _size bytes
        # regardless of host byte order.
        format_str = "<"
        values = []

        for m in self._members:
            field_string, field_data = _c_struct_or_union_field_to_bytes(m)
            format_str += field_string
            values.append(field_data)

        return struct.pack(format_str, *values)

    def get_size(self):
        return self._size

    def get_value_str(self):
        """Return a brace initializer for the elements, or "" when all bytes are zero."""
        string = ""
        if self.to_bytes() != bytes(self.get_size()):
            string += "{\n"
            m_string = ""
            for m in self._members:
                m_string += m.get_value_str() + ", "
            string += _pad_lines(m_string, pad)
            string += "\n}"
        return string

    def __str__(self):
        string = "{} {}".format(self.c_type, self.name)
        string += "".join(["[{}]".format(a) for a in self._dimensions])

        if self.to_bytes() != bytes(self.get_size()):
            string += " = " + self.get_value_str()

        string += ";"
        return string
357
class C_variable:
    """A scalar C object (integer, bool, float or enum) packed with `struct`.

    `value` is None until set; an unset value packs and prints as 0.
    """

    def __init__(self, name, c_type, value = None):
        self.name = name
        self.c_type = c_type
        self._format_string = self._get_format_string()
        self._size = struct.calcsize(self._format_string)
        self.value = value

    @staticmethod
    def from_h_file(h_file, name, includes = [], defines = []):
        # NOTE(review): cl.CursorKind.TYPE_DECL does not appear to be a libclang
        # cursor kind (TYPEDEF_DECL / VAR_DECL are) -- confirm this lookup works.
        return _c_from_h_file(h_file, name, includes, defines, C_variable, cl.CursorKind.TYPE_DECL)

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        """Build a C_variable, or a C_array when the type spelling contains "["."""
        c_type = cursor.type.spelling
        c_type = c_type.replace("unsigned", "")
        c_type = c_type.replace("volatile", "")
        c_type = c_type.replace("const", "")
        c_type = c_type.replace("static", "")

        if "[" in c_type:
            return C_array.from_cl_cursor(cursor, name)

        # Removing qualifiers leaves stray spaces (" int" from "unsigned int")
        # that would not match any format_chars key, so normalize whitespace.
        # NOTE(review): dropping "unsigned" also loses signedness, e.g.
        # "unsigned char" becomes "char", which is packed as signed 'b'.
        c_type = " ".join(c_type.split())
        return C_variable(name, c_type)

    def _get_format_string(self):
        # Every enum is stored as 4 bytes (uint32_t) here.
        if 'enum' in self.c_type:
            return format_chars['uint32_t']

        return format_chars[self.c_type]

    def get_value(self):
        value = self.value if self.value else 0
        return value

    def set_value(self, value):
        """Set the value; strings are parsed with base auto-detection ("0x10", "16")."""
        if isinstance(value, str):
            self.value = int(value, 0)
        else:
            self.value = value

        # Sanity check: struct.pack raises struct.error if the value does not fit.
        self.to_bytes()

    def set_value_from_bytes(self, value):
        # Little-endian, matching to_bytes().
        self.set_value(struct.unpack("<" + self._format_string, value)[0])

    def get_field_strings(self):
        return [self.name]

    def get_field(self, field_path):
        # A scalar has no sub-fields: any path resolves to the variable itself.
        return self

    def to_bytes(self):
        return struct.pack("<" + self._format_string, self.get_value())

    def get_size(self):
        return self._size

    def get_value_str(self):
        """Return the value as a hex literal (floats as their repr); unset prints as 0x0."""
        value = self.get_value()
        # hex() rejects floats and None; get_value() already maps None to 0.
        if isinstance(value, float):
            return repr(value)
        return hex(value)

    def __str__(self):
        string = "{} {}".format(self.c_type, self.name)

        if self.value is not None:
            string += " = {}".format(self.get_value_str())

        string += ";"
        return string
428
class C_union:
    """A C union whose members all overlay the same bytes.

    `_actual_value` holds the union's `_size` bytes (the size of its largest
    member). `_ensure_consistency`, called before most operations, copies
    the bytes of any member changed since the last sync into `_actual_value`
    and then reloads every member from it.
    """

    def __init__(self, name="", c_type="", fields=None, packed=False):
        self.name = name
        self.c_type = c_type.replace("union ", "")
        # Fresh list per instance rather than a shared mutable default.
        self._fields = [] if fields is None else fields
        self._packed = packed
        self._format_strings = self._get_format_strings()
        # A union is as large as its largest member (0 with no members).
        self._size = max(map(struct.calcsize, self._format_strings), default=0)
        self._actual_value = bytes(self._size)
        self.__dict__.update({f.name:f for f in _c_struct_or_union_get_direct_members(self)})

    @staticmethod
    def from_h_file(h_file, name, includes = [], defines = []):
        return _c_from_h_file(h_file, name, includes, defines, C_union, cl.CursorKind.UNION_DECL)

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        return _c_struct_or_union_from_cl_cursor(cursor, name, C_union)

    def _ensure_consistency(self):
        """Fold a changed member's bytes into `_actual_value` and resync members.

        Raises AssertionError if members were changed to more than one
        distinct value since the last sync.
        """
        changed = {b for b in (f.to_bytes() for f in self._fields)
                   if b != self._actual_value[:len(b)]}

        assert(len(changed) < 2)

        if changed:
            new_bytes = changed.pop()
            # As in a C union, writing a smaller member leaves the bytes past
            # it untouched: overlay instead of replacing, keeping `_size` bytes
            # so larger members can still be reloaded.
            self._actual_value = new_bytes + self._actual_value[len(new_bytes):]

        for f in self._fields:
            f.set_value_from_bytes(self._actual_value[:f.get_size()])

    def _get_format_strings(self):
        format_endianness = "<" if self._packed else "@"
        return ["{}{}".format(format_endianness, _binary_format_string(f._size)) for f in self._fields]

    def set_value_from_bytes(self, value):
        """Set the union's bytes from `value` and reload every member.

        Each member receives only its own leading slice. A `value` shorter
        than the union leaves the remaining bytes unchanged.
        """
        assert(len(value) <= self._size), "union of size {} cannot be set to value of size {}".format(self._size, len(value))
        self._actual_value = bytes(value) + self._actual_value[len(value):]
        for f in self._fields:
            f.set_value_from_bytes(self._actual_value[:f.get_size()])
        self._ensure_consistency()

    def set_value(self, field_path, value):
        self.get_field(field_path).set_value(value)
        self._ensure_consistency()

    def get_value(self, field_path):
        self._ensure_consistency()
        return self.get_field(field_path).value

    def get_field_strings(self):
        self._ensure_consistency()
        return _c_struct_or_union_get_field_strings(self)

    def get_field(self, field_path):
        self._ensure_consistency()
        return _c_struct_or_union_get_field(self, field_path)

    def to_bytes(self):
        self._ensure_consistency()
        # After syncing, every member mirrors a prefix of `_actual_value`,
        # which is the union's full byte image.
        return self._actual_value

    def get_size(self):
        return self._size

    def get_value_str(self):
        self._ensure_consistency()
        return _c_struct_or_union_get_value_str(self, "union")

    def __str__(self):
        self._ensure_consistency()
        return _c_struct_or_union_str(self, "union")
503
class C_struct:
    """A C struct whose fields are C_* objects laid out with `struct` formats.

    Packed structs use "<" (standard sizes, no padding). Other structs use
    "@" (host-native alignment).

    NOTE(review): per the `struct` docs, native alignment adds no trailing
    padding and treats nested aggregates ("Ns") as byte-aligned, so non-packed
    layouts may differ from the target compiler's.
    """

    def __init__(self, name="", c_type="", fields=None, packed=False):
        self.name = name
        self.c_type = c_type.replace("struct ", "")
        # Fresh list per instance rather than a shared mutable default.
        self._fields = [] if fields is None else fields
        self._packed = packed
        self._format_string = self._get_format_string()
        self._size = struct.calcsize(self._format_string)
        # Expose named members (including those of anonymous sub-aggregates)
        # as attributes.
        self.__dict__.update({f.name:f for f in _c_struct_or_union_get_direct_members(self)})

    @staticmethod
    def from_h_file(h_file, name, includes = [], defines = []):
        return _c_from_h_file(h_file, name, includes, defines, C_struct, cl.CursorKind.STRUCT_DECL)

    @staticmethod
    def from_cl_cursor(cursor, name=""):
        return _c_struct_or_union_from_cl_cursor(cursor, name, C_struct)

    def _get_format_string(self):
        format_endianness = "<" if self._packed else "@"
        return format_endianness + "".join([_c_struct_or_union_field_to_bytes(f)[0] for f in self._fields])

    def set_value_from_bytes(self, value):
        """Split `value` across the fields in order; it must be exactly the fields' total size."""
        value_used = 0
        for f in self._fields:
            f.set_value_from_bytes(value[value_used:value_used + f.get_size()])
            value_used += f.get_size()
        assert(value_used == len(value))

    def set_value(self, field_path, value):
        self.get_field(field_path).set_value(value)

    def get_value(self, field_path):
        return self.get_field(field_path).value

    def get_field_strings(self):
        return _c_struct_or_union_get_field_strings(self)

    def get_field(self, field_path):
        return _c_struct_or_union_get_field(self, field_path)

    def to_bytes(self):
        # Use the same prefix as _get_format_string. Without it, packed structs
        # were packed with native alignment padding and did not match `_size`.
        format_str = "<" if self._packed else "@"
        values = []

        for f in self._fields:
            field_string, field_data = _c_struct_or_union_field_to_bytes(f)
            format_str += field_string
            values.append(field_data)

        return struct.pack(format_str, *values)

    def get_size(self):
        return self._size

    def get_docs_table(self):
        # NOTE(review): _c_struct_or_union_get_docs_table is not defined in the
        # visible source; this raises NameError unless it is provided elsewhere.
        return _c_struct_or_union_get_docs_table(self)

    def get_value_str(self):
        return _c_struct_or_union_get_value_str(self, "struct")

    def __str__(self):
        return _c_struct_or_union_str(self, "struct")
568
569def _parse_field_dec(cursor):
570 return list(cursor.get_children())[0], cursor.spelling
571
572def _parse_type_ref(cursor, name):
573 return cursor.get_definition(), name
574
def _c_from_cl_cursor(cursor, name = ""):
    """Convert a libclang cursor into the matching C_* object.

    * STRUCT_DECL / UNION_DECL -> C_struct / C_union named `name`.
    * TYPEDEF_DECL / ENUM_DECL -> C_variable.from_cl_cursor (which may
      return a C_array).
    * FIELD_DECL -> constant arrays via C_variable.from_cl_cursor; other
      fields by recursing into the field's first child. In both cases the
      field's own spelling is the name and `name` is ignored.
    * TYPE_REF -> recurse into the referenced definition, keeping `name`.

    Raises:
        NotImplementedError: for any other cursor kind.
    """
    kind = cursor.kind

    if kind == cl.CursorKind.STRUCT_DECL:
        return C_struct.from_cl_cursor(cursor, name)

    if kind == cl.CursorKind.UNION_DECL:
        return C_union.from_cl_cursor(cursor, name)

    if kind in [cl.CursorKind.TYPEDEF_DECL, cl.CursorKind.ENUM_DECL]:
        return C_variable.from_cl_cursor(cursor, name)

    if kind == cl.CursorKind.FIELD_DECL:
        if cursor.type.kind == cl.TypeKind.CONSTANTARRAY:
            return C_variable.from_cl_cursor(cursor, cursor.spelling)
        return _c_from_cl_cursor(*_parse_field_dec(cursor))

    if kind == cl.CursorKind.TYPE_REF:
        return _c_from_cl_cursor(*_parse_type_ref(cursor, name))

    # Name the unsupported kind so failures on new headers are diagnosable.
    raise NotImplementedError("Unsupported cursor kind {} for '{}'".format(kind, cursor.spelling))
595
596def _c_from_h_file(h_file, name, includes, defines, f, kind):
597 name = name.replace("struct ", "")
598 name = name.replace("union ", "")
599
600 args = ["-I{}".format(i) for i in includes if os.path.isdir(i)]
601 args += ["-D{}".format(d) for d in defines]
602
603 if not os.path.isfile(h_file):
604 return FileNotFoundError
605
606 idx = cl.Index.create()
607 tu = idx.parse(h_file, args=args)
608
609 t = [cl.Cursor().from_location(tu, t.location) for t in tu.cursor.get_tokens() if t.spelling == name]
610 t = [x for x in t if x.kind == kind]
611
612 errors = ["{}: {} at {}:{}".format(d.category_name, d.spelling, d.location.file.name, d.location.line) for d in tu.diagnostics if d.severity > 2]
613 warnings = ["{}: {} at {}:{}".format(d.category_name, d.spelling, d.location.file.name, d.location.line) for d in tu.diagnostics if d.severity > 3]
614
615 for w in warnings:
616 print(w)
617
618 if errors:
619 for e in errors:
620 print(e)
621 exit(1)
622
623 if len(t) == 0:
624 print("Failed to find {} in {}".format(name, h_file))
625 exit(1)
626
627 assert(len(t) == 1)
628 t = t[0]
629
630 return f.from_cl_cursor(t, name)
631
632
if __name__ == '__main__':
    # Command-line entry point: parse a struct from a header using the include
    # paths and defines taken (via c_include) for a given C file, then print
    # its C declaration.
    import argparse
    import c_include

    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--h_file", help="header file to parse", required=True)
    parser.add_argument("--struct_name", help="struct name to evaluate", required=True)
    parser.add_argument("--compile_commands_file", help="compile commands file to read include paths and defines from", required=True)
    parser.add_argument("--c_file_to_mirror_includes_from", help="C file whose include paths and defines are used when parsing the header", required=True)
    parser.add_argument("--log_level", help="log level", required=False, default="ERROR", choices=logging._levelToName.values())
    args = parser.parse_args()
    logger.setLevel(args.log_level)

    includes = c_include.get_includes(args.compile_commands_file, args.c_file_to_mirror_includes_from)
    defines = c_include.get_defines(args.compile_commands_file, args.c_file_to_mirror_includes_from)

    s = C_struct.from_h_file(args.h_file, args.struct_name, includes, defines)
    print(s)