#!/usr/bin/env python3
"""Generate wrapper functions for PSA function calls.
"""
4
# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later

### WARNING: the code in this file has not been extensively reviewed yet.
### We do not think it is harmful, but it may be below our normal standards
### for robustness and maintainability.
11
Gilles Peskine8519dc92024-01-04 16:38:17 +010012import argparse
Gilles Peskinec8b22d02023-12-06 19:32:52 +010013import itertools
Gilles Peskine8519dc92024-01-04 16:38:17 +010014import os
Gilles Peskinec8b22d02023-12-06 19:32:52 +010015from typing import Iterator, List, Optional, Tuple
Gilles Peskine8519dc92024-01-04 16:38:17 +010016
17import scripts_path #pylint: disable=unused-import
18from mbedtls_dev import build_tree
19from mbedtls_dev import c_parsing_helper
20from mbedtls_dev import c_wrapper_generator
21from mbedtls_dev import typing_util
22
23
class BufferParameter:
    """Description of an input or output buffer parameter sequence to a PSA function."""
    #pylint: disable=too-few-public-methods

    def __init__(self, i: int, is_output: bool,
                 buffer_name: str, size_name: str) -> None:
        """Initialize the parameter information.

        i is the index of the function argument that is the pointer to the buffer.
        The size is argument i+1. For a variable-size output, the actual length
        goes in argument i+2.

        is_output is true for an output buffer and false for an input buffer.

        buffer_name and size_name are the names of arguments i and i+1.
        This class does not yet help with the output length.
        """
        # Position of the buffer pointer in the function's argument list.
        self.index = i
        # Names of the pointer argument and of the size argument that follows it.
        self.buffer_name = buffer_name
        self.size_name = size_name
        # Direction of the buffer, stored exactly as passed by the caller.
        self.is_output = is_output
43
44
class PSAWrapperGenerator(c_wrapper_generator.Base):
    """Generate a C source file containing wrapper functions for PSA Crypto API calls."""

    # Preprocessor condition wrapped around the whole generated file
    # (see _write_prologue and _write_epilogue).
    _CPP_GUARDS = 'defined(MBEDTLS_PSA_CRYPTO_C) && defined(MBEDTLS_TEST_HOOKS)'
    _WRAPPER_NAME_PREFIX = 'mbedtls_test_wrap_'
    _WRAPPER_NAME_SUFFIX = ''

    def gather_data(self) -> None:
        """Read function declarations from the PSA headers into self.functions."""
        root_dir = build_tree.guess_mbedtls_root()
        for header_name in ['crypto.h', 'crypto_extra.h']:
            header_path = os.path.join(root_dir, 'include', 'psa', header_name)
            c_parsing_helper.read_function_declarations(self.functions, header_path)

    # Functions that _skip_function() rejects by name.
    _SKIP_FUNCTIONS = frozenset([
        'mbedtls_psa_external_get_random', # not a library function
        'psa_aead_abort', # not implemented yet
        'psa_aead_decrypt_setup', # not implemented yet
        'psa_aead_encrypt_setup', # not implemented yet
        'psa_aead_finish', # not implemented yet
        'psa_aead_generate_nonce', # not implemented yet
        'psa_aead_set_lengths', # not implemented yet
        'psa_aead_set_nonce', # not implemented yet
        'psa_aead_update', # not implemented yet
        'psa_aead_update_ad', # not implemented yet
        'psa_aead_verify', # not implemented yet
        'psa_get_key_domain_parameters', # client-side function
        'psa_get_key_slot_number', # client-side function
        'psa_set_key_domain_parameters', # client-side function
    ])

    def _skip_function(self, function: c_wrapper_generator.FunctionInfo) -> bool:
        """Whether to skip this function.

        A function is skipped if it does not return psa_status_t or if its
        name is listed in _SKIP_FUNCTIONS.
        """
        return (function.return_type != 'psa_status_t' or
                function.name in self._SKIP_FUNCTIONS)

    # PAKE stuff: not implemented yet
    _PAKE_STUFF = frozenset([
        'psa_crypto_driver_pake_inputs_t *',
        'psa_pake_cipher_suite_t *',
    ])

    def _return_variable_name(self,
                              function: c_wrapper_generator.FunctionInfo) -> str:
        """The name of the variable that will contain the return value."""
        if function.return_type == 'psa_status_t':
            return 'status'
        return super()._return_variable_name(function)

    # Start from a copy of the base class's guards so that the update below
    # does not modify the base class's dictionary.
    _FUNCTION_GUARDS = c_wrapper_generator.Base._FUNCTION_GUARDS.copy() \
        #pylint: disable=protected-access
    _FUNCTION_GUARDS.update({
        'mbedtls_psa_register_se_key': 'defined(MBEDTLS_PSA_CRYPTO_SE_C)',
        'mbedtls_psa_inject_entropy': 'defined(MBEDTLS_PSA_INJECT_ENTROPY)',
        'mbedtls_psa_external_get_random': 'defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG)',
        'mbedtls_psa_platform_get_builtin_key': 'defined(MBEDTLS_PSA_CRYPTO_BUILTIN_KEYS)',
    })

    @staticmethod
    def _detect_buffer_parameters(arguments: List[c_parsing_helper.ArgumentInfo],
                                  argument_names: List[str]) -> Iterator[BufferParameter]:
        """Detect function arguments that are buffers (pointer, size [,length]).

        Yield one BufferParameter for each argument of type 'uint8_t *' or
        'const uint8_t *' that is immediately followed by a 'size_t' argument.
        The buffer is considered an output unless the pointer type starts
        with 'const '.
        """
        types = ['' if arg.suffix else arg.type for arg in arguments]
        # pairs = list of (type_of_arg_N, type_of_arg_N+1)
        # where each type_of_arg_X is the empty string if the type is an array
        # or there is no argument X.
        pairs = enumerate(itertools.zip_longest(types, types[1:], fillvalue=''))
        for i, (pointer_type, size_type) in pairs:
            if pointer_type in ('const uint8_t *', 'uint8_t *') and \
               size_type == 'size_t':
                yield BufferParameter(i, not pointer_type.startswith('const '),
                                      argument_names[i], argument_names[i+1])

    @staticmethod
    def _write_poison_buffer_parameter(out: typing_util.Writable,
                                       param: BufferParameter,
                                       poison: bool) -> None:
        """Write poisoning or unpoisoning code for a buffer parameter.

        Write poisoning code if poison is true, unpoisoning code otherwise.
        """
        # NOTE(review): the leading whitespace inside this literal is kept
        # exactly as found; it may have been collapsed along with the rest of
        # the file's indentation. Confirm against the reference copy.
        out.write(' MBEDTLS_TEST_MEMORY_{}({}, {});\n'.format(
            'POISON' if poison else 'UNPOISON',
            param.buffer_name, param.size_name
        ))

    def _write_poison_buffer_parameters(self, out: typing_util.Writable,
                                        buffer_parameters: List[BufferParameter],
                                        poison: bool) -> None:
        """Write poisoning or unpoisoning code for the buffer parameters.

        Write poisoning code if poison is true, unpoisoning code otherwise.
        Write nothing at all if buffer_parameters is empty.
        """
        if not buffer_parameters:
            return
        out.write('#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)\n')
        for param in buffer_parameters:
            self._write_poison_buffer_parameter(out, param, poison)
        out.write('#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */\n')

    @staticmethod
    def _parameter_should_be_copied(function_name: str,
                                    _buffer_name: Optional[str]) -> bool:
        """Whether the specified buffer argument to a PSA function should be copied.
        """
        # Proof-of-concept: just instrument one function for now
        return function_name == 'psa_cipher_encrypt'

    def _write_function_call(self, out: typing_util.Writable,
                             function: c_wrapper_generator.FunctionInfo,
                             argument_names: List[str]) -> None:
        """Write the call to the wrapped function, surrounded by buffer poisoning.

        The buffers selected by _parameter_should_be_copied() are poisoned
        before the call and unpoisoned after it.
        """
        buffer_parameters = [
            param
            for param in self._detect_buffer_parameters(function.arguments,
                                                        argument_names)
            if self._parameter_should_be_copied(function.name,
                                                function.arguments[param.index].name)
        ]
        self._write_poison_buffer_parameters(out, buffer_parameters, True)
        super()._write_function_call(out, function, argument_names)
        self._write_poison_buffer_parameters(out, buffer_parameters, False)

    def _write_prologue(self, out: typing_util.Writable, header: bool) -> None:
        """Write the base prologue, then open the #if guard and emit the includes."""
        super()._write_prologue(out, header)
        out.write("""
#if {}

#include <psa/crypto.h>

#include <test/memory.h>
#include <test/psa_crypto_helpers.h>
#include <test/psa_test_wrappers.h>
"""
                  .format(self._CPP_GUARDS))

    def _write_epilogue(self, out: typing_util.Writable, header: bool) -> None:
        """Close the #if guard opened by _write_prologue, then write the base epilogue."""
        out.write("""
#endif /* {} */
"""
                  .format(self._CPP_GUARDS))
        super()._write_epilogue(out, header)
188
189
class PSALoggingWrapperGenerator(PSAWrapperGenerator, c_wrapper_generator.Logging):
    """Generate a C source file containing wrapper functions that log PSA Crypto API calls."""

    def __init__(self, stream: str) -> None:
        """Set up the generator; stream is passed on to set_stream()."""
        super().__init__()
        self.set_stream(stream)

    # Extend a copy of the inherited table so that the base class's
    # dictionary is left untouched.
    _PRINTF_TYPE_CAST = c_wrapper_generator.Logging._PRINTF_TYPE_CAST.copy()
    _PRINTF_TYPE_CAST.update({
        'mbedtls_svc_key_id_t': 'unsigned',
        'psa_algorithm_t': 'unsigned',
        'psa_drv_slot_number_t': 'unsigned long long',
        'psa_key_derivation_step_t': 'int',
        'psa_key_id_t': 'unsigned',
        'psa_key_slot_number_t': 'unsigned long long',
        'psa_key_lifetime_t': 'unsigned',
        'psa_key_type_t': 'unsigned',
        'psa_key_usage_flags_t': 'unsigned',
        'psa_pake_role_t': 'int',
        'psa_pake_step_t': 'int',
        'psa_status_t': 'int',
    })

    def _printf_parameters(self, typ: str, var: str) -> Tuple[str, List[str]]:
        """Return the printf format and the list of arguments for one value.

        typ is the C type of the value and var is the name of the variable
        holding it. A leading 'const ' qualifier on typ is ignored. Return
        an empty format and an empty argument list for types that are not
        logged.
        """
        if typ.startswith('const '):
            typ = typ[6:]
        # Not logged: buffers, operation objects, and PAKE types.
        if typ == 'uint8_t *' or \
           typ.endswith('operation_t *') or \
           typ in self._PAKE_STUFF:
            return '', []
        if typ == 'psa_key_attributes_t *':
            # Log each attribute through its psa_get_key_xxx() accessor.
            fmt = var + '={id=%u, lifetime=0x%08x, type=0x%08x, bits=%u, alg=%08x, usage=%08x}'
            fields = ['id', 'lifetime', 'type', 'bits', 'algorithm', 'usage_flags']
            values = ['(unsigned) psa_get_key_{}({})'.format(field, var)
                      for field in fields]
            return fmt, values
        return super()._printf_parameters(typ, var)
228
229
# Default output paths, used by main() when --output-c / --output-h are not given.
DEFAULT_C_OUTPUT_FILE_NAME = 'tests/src/psa_test_wrappers.c'
DEFAULT_H_OUTPUT_FILE_NAME = 'tests/include/test/psa_test_wrappers.h'
232
def main() -> None:
    """Command line entry point: parse the options and write the output files.

    With --log, the generated wrappers include logging code. An empty
    --output-c or --output-h value skips the corresponding file.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--log',
                        help='Stream to log to (default: no logging code)')
    parser.add_argument('--output-c',
                        metavar='FILENAME',
                        default=DEFAULT_C_OUTPUT_FILE_NAME,
                        help=('Output .c file path (default: {}; skip .c output if empty)'
                              .format(DEFAULT_C_OUTPUT_FILE_NAME)))
    parser.add_argument('--output-h',
                        metavar='FILENAME',
                        default=DEFAULT_H_OUTPUT_FILE_NAME,
                        help=('Output .h file path (default: {}; skip .h output if empty)'
                              .format(DEFAULT_H_OUTPUT_FILE_NAME)))
    options = parser.parse_args()
    if options.log:
        generator = PSALoggingWrapperGenerator(options.log) #type: PSAWrapperGenerator
    else:
        generator = PSAWrapperGenerator()
    generator.gather_data()
    if options.output_h:
        generator.write_h_file(options.output_h)
    if options.output_c:
        generator.write_c_file(options.output_c)

if __name__ == '__main__':
    main()