blob: f7a193adade29978af5e4b349669bf4c2386bd9c [file] [log] [blame]
#!/usr/bin/env python3
"""Generate wrapper functions for PSA function calls.
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
7
8import argparse
9import os
10from typing import List, Tuple
11
12import scripts_path #pylint: disable=unused-import
13from mbedtls_dev import build_tree
14from mbedtls_dev import c_parsing_helper
15from mbedtls_dev import c_wrapper_generator
16from mbedtls_dev import typing_util
17
18
class PSAWrapperGenerator(c_wrapper_generator.Base):
    """Generate a C source file containing wrapper functions for PSA Crypto API calls.

    Function declarations are read from the PSA headers by ``gather_data``.
    ``_skip_function`` reports which of them are excluded: anything that
    does not return ``psa_status_t`` and anything named in
    ``_SKIP_FUNCTIONS``.
    """

    # Preprocessor condition wrapped around the generated content
    # (opened in _write_prologue, closed in _write_epilogue).
    _CPP_GUARDS = 'defined(MBEDTLS_PSA_CRYPTO_C) && defined(MBEDTLS_TEST_HOOKS)'
    _WRAPPER_NAME_PREFIX = 'mbedtls_test_wrap_'
    _WRAPPER_NAME_SUFFIX = ''

    def gather_data(self) -> None:
        """Read the function declarations from the PSA crypto headers.

        The headers ``crypto.h`` and ``crypto_extra.h`` are looked up in
        ``include/psa`` under the guessed Mbed TLS root, and their
        declarations are collected into ``self.functions``.
        """
        psa_include_dir = os.path.join(build_tree.guess_mbedtls_root(),
                                       'include', 'psa')
        for basename in ['crypto.h', 'crypto_extra.h']:
            c_parsing_helper.read_function_declarations(
                self.functions, os.path.join(psa_include_dir, basename))

    # Names for which _skip_function returns True regardless of their
    # return type.
    _SKIP_FUNCTIONS = frozenset([
        'mbedtls_psa_external_get_random', # not a library function
        'psa_aead_abort', # not implemented yet
        'psa_aead_decrypt_setup', # not implemented yet
        'psa_aead_encrypt_setup', # not implemented yet
        'psa_aead_finish', # not implemented yet
        'psa_aead_generate_nonce', # not implemented yet
        'psa_aead_set_lengths', # not implemented yet
        'psa_aead_set_nonce', # not implemented yet
        'psa_aead_update', # not implemented yet
        'psa_aead_update_ad', # not implemented yet
        'psa_aead_verify', # not implemented yet
        'psa_get_key_domain_parameters', # client-side function
        'psa_get_key_slot_number', # client-side function
        'psa_set_key_domain_parameters', # client-side function
    ])

    def _skip_function(self, function: c_wrapper_generator.FunctionInfo) -> bool:
        """Return True if `function` does not return ``psa_status_t`` or
        if its name is listed in ``_SKIP_FUNCTIONS``."""
        return (function.return_type != 'psa_status_t' or
                function.name in self._SKIP_FUNCTIONS)

    # PAKE stuff: not implemented yet
    _PAKE_STUFF = frozenset([
        'psa_crypto_driver_pake_inputs_t *',
        'psa_pake_cipher_suite_t *',
    ])

    def _return_variable_name(self,
                              function: c_wrapper_generator.FunctionInfo) -> str:
        """The name of the variable that will contain the return value.

        This is ``status`` for functions returning ``psa_status_t``; any
        other return type is delegated to the base class.
        """
        if function.return_type != 'psa_status_t':
            return super()._return_variable_name(function)
        return 'status'

    # Function name -> preprocessor condition. Work on a copy so that the
    # table of the base class itself is left unmodified.
    _FUNCTION_GUARDS = (
        c_wrapper_generator.Base._FUNCTION_GUARDS.copy() #pylint: disable=protected-access
    )
    _FUNCTION_GUARDS.update({
        'mbedtls_psa_register_se_key': 'defined(MBEDTLS_PSA_CRYPTO_SE_C)',
        'mbedtls_psa_inject_entropy': 'defined(MBEDTLS_PSA_INJECT_ENTROPY)',
        'mbedtls_psa_external_get_random': 'defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG)',
        'mbedtls_psa_platform_get_builtin_key': 'defined(MBEDTLS_PSA_CRYPTO_BUILTIN_KEYS)',
    })

    def _write_prologue(self, out: typing_util.Writable, header: bool) -> None:
        """Write the base class prologue, then open the ``_CPP_GUARDS``
        conditional and emit the PSA and test helper includes.

        `header` is only forwarded to the base class.
        """
        super()._write_prologue(out, header)
        guard_and_includes = """
#if {}

#include <psa/crypto.h>

#include <test/psa_crypto_helpers.h>
#include <test/psa_test_wrappers.h>
"""
        out.write(guard_and_includes.format(self._CPP_GUARDS))

    def _write_epilogue(self, out: typing_util.Writable, header: bool) -> None:
        """Close the ``_CPP_GUARDS`` conditional, then write the base class
        epilogue.

        `header` is only forwarded to the base class.
        """
        guard_end = """
#endif /* {} */
"""
        out.write(guard_end.format(self._CPP_GUARDS))
        super()._write_epilogue(out, header)
96
97
class PSALoggingWrapperGenerator(PSAWrapperGenerator, c_wrapper_generator.Logging):
    """Generate a C source file containing wrapper functions that log PSA Crypto API calls."""

    def __init__(self, stream: str) -> None:
        """Initialize the generator and select `stream` via ``set_stream``."""
        super().__init__()
        self.set_stream(stream)

    # PSA type name -> C type named by the cast. Work on a copy so that the
    # table of c_wrapper_generator.Logging itself is left unmodified.
    _PRINTF_TYPE_CAST = c_wrapper_generator.Logging._PRINTF_TYPE_CAST.copy()
    _PRINTF_TYPE_CAST.update({
        'mbedtls_svc_key_id_t': 'unsigned',
        'psa_algorithm_t': 'unsigned',
        'psa_drv_slot_number_t': 'unsigned long long',
        'psa_key_derivation_step_t': 'int',
        'psa_key_id_t': 'unsigned',
        'psa_key_slot_number_t': 'unsigned long long',
        'psa_key_lifetime_t': 'unsigned',
        'psa_key_type_t': 'unsigned',
        'psa_key_usage_flags_t': 'unsigned',
        'psa_pake_role_t': 'int',
        'psa_pake_step_t': 'int',
        'psa_status_t': 'int',
    })

    def _printf_parameters(self, typ: str, var: str) -> Tuple[str, List[str]]:
        """Return the printf format fragment and argument expressions for `var`.

        A leading ``const `` qualifier on `typ` is ignored. Byte buffers
        (``uint8_t *``), pointers to operation objects and the PAKE types in
        ``_PAKE_STUFF`` yield an empty format and no arguments. Key
        attributes are expanded field by field through the
        ``psa_get_key_xxx`` getters. Every other type is delegated to the
        parent class, with the ``const `` qualifier already removed.
        """
        const_prefix = 'const '
        base_type = typ[len(const_prefix):] if typ.startswith(const_prefix) else typ
        not_logged = (base_type == 'uint8_t *' or  # skip buffers
                      base_type.endswith('operation_t *') or
                      base_type in self._PAKE_STUFF)
        if not_logged:
            return '', []
        if base_type != 'psa_key_attributes_t *':
            return super()._printf_parameters(base_type, var)
        getters = ['id', 'lifetime', 'type', 'bits', 'algorithm', 'usage_flags']
        fmt = var + '={id=%u, lifetime=0x%08x, type=0x%08x, bits=%u, alg=%08x, usage=%08x}'
        args = ['(unsigned) psa_get_key_{}({})'.format(getter, var)
                for getter in getters]
        return fmt, args
136
137
# Default locations of the generated files (used as the defaults of
# --output-c and --output-h).
DEFAULT_C_OUTPUT_FILE_NAME = 'tests/src/psa_test_wrappers.c'
DEFAULT_H_OUTPUT_FILE_NAME = 'tests/include/test/psa_test_wrappers.h'

def main() -> None:
    """Command line entry point.

    Parse the options, build a logging generator if --log is given and a
    plain one otherwise, gather the function declarations, then write the
    .h file followed by the .c file. An output whose path is empty is
    skipped.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--log',
                        help='Stream to log to (default: no logging code)')
    for extension, default_path in (('.c', DEFAULT_C_OUTPUT_FILE_NAME),
                                    ('.h', DEFAULT_H_OUTPUT_FILE_NAME)):
        help_text = ('Output {ext} file path (default: {path}; skip {ext} output if empty)'
                     .format(ext=extension, path=default_path))
        # '.c' -> '--output-c', '.h' -> '--output-h'
        parser.add_argument('--output-' + extension[1:],
                            metavar='FILENAME',
                            default=default_path,
                            help=help_text)
    options = parser.parse_args()
    if not options.log:
        generator = PSAWrapperGenerator() #type: PSAWrapperGenerator
    else:
        generator = PSALoggingWrapperGenerator(options.log)
    generator.gather_data()
    # The header is written before the C file.
    for path, write_file in ((options.output_h, generator.write_h_file),
                             (options.output_c, generator.write_c_file)):
        if path:
            write_file(path)

if __name__ == '__main__':
    main()