blob: 4e8ee6c8075836ae1a2fc823c4dd42950d0c125d [file] [log] [blame]
Gilles Peskine5294bb32024-01-04 16:38:17 +01001#!/usr/bin/env python3
2"""Generate wrapper functions for PSA function calls.
3"""
4
5# Copyright The Mbed TLS Contributors
6# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
7
Gilles Peskinea1871f32024-01-04 17:28:59 +01008### WARNING: the code in this file has not been extensively reviewed yet.
9### We do not think it is harmful, but it may be below our normal standards
10### for robustness and maintainability.
11
Gilles Peskine5294bb32024-01-04 16:38:17 +010012import argparse
Gilles Peskine4adacac2023-12-06 19:32:52 +010013import itertools
Gilles Peskine5294bb32024-01-04 16:38:17 +010014import os
Gilles Peskine4adacac2023-12-06 19:32:52 +010015from typing import Iterator, List, Optional, Tuple
Gilles Peskine5294bb32024-01-04 16:38:17 +010016
17import scripts_path #pylint: disable=unused-import
18from mbedtls_dev import build_tree
19from mbedtls_dev import c_parsing_helper
20from mbedtls_dev import c_wrapper_generator
21from mbedtls_dev import typing_util
22
23
class BufferParameter:
    """One buffer argument of a PSA function: a pointer followed by its size.

    The buffer may be an input (read by the function) or an output
    (written by the function).
    """
    #pylint: disable=too-few-public-methods

    def __init__(self, i: int, is_output: bool,
                 buffer_name: str, size_name: str) -> None:
        """Record the position and names of a buffer argument.

        i: index of the argument holding the buffer pointer. The buffer
           size is argument i+1. For a variable-size output, the actual
           length goes in argument i+2; this class does not track that
           output length yet.
        is_output: whether the buffer is an output rather than an input.
        buffer_name: name of argument i.
        size_name: name of argument i+1.
        """
        # Position in the function's argument list.
        self.index = i
        # Argument names, as they will appear in the generated C code.
        self.buffer_name = buffer_name
        self.size_name = size_name
        # Direction of the buffer.
        self.is_output = is_output
43
44
class PSAWrapperGenerator(c_wrapper_generator.Base):
    """Generate a C source file containing wrapper functions for PSA Crypto API calls.

    Only functions returning psa_status_t are wrapped (see _skip_function).
    Around the call emitted by the base class, selected buffer arguments are
    poisoned before the call and unpoisoned after it. This poisoning code is
    only compiled when MBEDTLS_PSA_COPY_CALLER_BUFFERS is defined.
    """

    # Preprocessor condition enclosing everything written between the
    # prologue and the epilogue.
    _CPP_GUARDS = ('defined(MBEDTLS_PSA_CRYPTO_C) && '
                   'defined(MBEDTLS_TEST_HOOKS) && \\\n '
                   '!defined(RECORD_PSA_STATUS_COVERAGE_LOG)')
    _WRAPPER_NAME_PREFIX = 'mbedtls_test_wrap_'
    _WRAPPER_NAME_SUFFIX = ''

    def gather_data(self) -> None:
        """Read the function declarations of the public PSA headers into self.functions."""
        psa_include_dir = os.path.join(build_tree.guess_mbedtls_root(),
                                       'include', 'psa')
        for header in ('crypto.h', 'crypto_extra.h'):
            c_parsing_helper.read_function_declarations(
                self.functions, os.path.join(psa_include_dir, header))

    # Functions returning psa_status_t that are nevertheless not wrapped.
    _SKIP_FUNCTIONS = frozenset([
        'mbedtls_psa_external_get_random', # not a library function
        'psa_get_key_domain_parameters', # client-side function
        'psa_get_key_slot_number', # client-side function
        'psa_key_derivation_verify_bytes', # not implemented yet
        'psa_key_derivation_verify_key', # not implemented yet
        'psa_set_key_domain_parameters', # client-side function
    ])

    def _skip_function(self, function: c_wrapper_generator.FunctionInfo) -> bool:
        """Return True for functions that this generator excludes.

        Excluded are functions that do not return psa_status_t and
        functions listed in _SKIP_FUNCTIONS.
        """
        return (function.return_type != 'psa_status_t' or
                function.name in self._SKIP_FUNCTIONS)

    # PAKE stuff: not implemented yet
    _PAKE_STUFF = frozenset([
        'psa_crypto_driver_pake_inputs_t *',
        'psa_pake_cipher_suite_t *',
    ])

    def _return_variable_name(self,
                              function: c_wrapper_generator.FunctionInfo) -> str:
        """The name of the variable that will contain the return value.

        psa_status_t results go in `status`; other return types use the
        base class's choice.
        """
        if function.return_type != 'psa_status_t':
            return super()._return_variable_name(function)
        return 'status'

    _FUNCTION_GUARDS = (
        c_wrapper_generator.Base._FUNCTION_GUARDS #pylint: disable=protected-access
        .copy())
    _FUNCTION_GUARDS.update({
        'mbedtls_psa_register_se_key': 'defined(MBEDTLS_PSA_CRYPTO_SE_C)',
        'mbedtls_psa_inject_entropy': 'defined(MBEDTLS_PSA_INJECT_ENTROPY)',
        'mbedtls_psa_external_get_random': 'defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG)',
        'mbedtls_psa_platform_get_builtin_key': 'defined(MBEDTLS_PSA_CRYPTO_BUILTIN_KEYS)',
    })

    @staticmethod
    def _detect_buffer_parameters(arguments: List[c_parsing_helper.ArgumentInfo],
                                  argument_names: List[str]) -> Iterator[BufferParameter]:
        """Detect function arguments that are buffers (pointer, size [,length]).

        A buffer is a `uint8_t *` or `const uint8_t *` argument immediately
        followed by a `size_t` argument. Non-const pointers are reported as
        outputs.
        """
        # Array-declared arguments (non-empty suffix) never start a buffer,
        # so their type is blanked out.
        arg_types = [arg.type if not arg.suffix else '' for arg in arguments]
        # Pair each type with the next one; the last argument is paired
        # with '' so it can never be followed by a size.
        neighbours = itertools.zip_longest(arg_types, arg_types[1:], fillvalue='')
        for position, (this_type, next_type) in enumerate(neighbours):
            if this_type not in ('const uint8_t *', 'uint8_t *'):
                continue
            if next_type != 'size_t':
                continue
            yield BufferParameter(position, this_type == 'uint8_t *',
                                  argument_names[position],
                                  argument_names[position + 1])

    @staticmethod
    def _write_poison_buffer_parameter(out: typing_util.Writable,
                                       param: BufferParameter,
                                       poison: bool) -> None:
        """Write the poisoning (poison true) or unpoisoning (poison false) macro call for one buffer."""
        action = 'POISON' if poison else 'UNPOISON'
        out.write(' MBEDTLS_TEST_MEMORY_{}({}, {});\n'.format(
            action, param.buffer_name, param.size_name))

    def _write_poison_buffer_parameters(self, out: typing_util.Writable,
                                        buffer_parameters: List[BufferParameter],
                                        poison: bool) -> None:
        """Write poisoning (poison true) or unpoisoning (poison false) code for all given buffers.

        The code is wrapped in #if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS).
        Nothing at all is written when there are no buffers.
        """
        if buffer_parameters:
            out.write('#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)\n')
            for buffer_param in buffer_parameters:
                self._write_poison_buffer_parameter(out, buffer_param, poison)
            out.write('#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */\n')

    # Functions whose buffer arguments are copied, matched by exact name...
    _BUFFER_COPY_FUNCTIONS = frozenset([
        'psa_sign_hash_start', 'psa_sign_hash_complete', 'psa_verify_hash_start',
        'psa_cipher_encrypt', 'psa_cipher_decrypt',
        'psa_cipher_update', 'psa_cipher_finish',
        'psa_cipher_generate_iv', 'psa_cipher_set_iv',
        'psa_key_derivation_output_bytes', 'psa_key_derivation_input_bytes',
        'psa_import_key', 'psa_export_key', 'psa_export_public_key',
        'psa_sign_message', 'psa_verify_message',
        'psa_sign_hash', 'psa_verify_hash',
        'psa_hash_update', 'psa_hash_finish', 'psa_hash_verify',
        'psa_hash_compute', 'psa_hash_compare',
        'psa_key_derivation_key_agreement', 'psa_raw_key_agreement',
        'psa_generate_random',
        'psa_mac_update', 'psa_mac_sign_finish', 'psa_mac_verify_finish',
        'psa_mac_compute', 'psa_mac_verify',
        'psa_asymmetric_encrypt', 'psa_asymmetric_decrypt',
    ])
    # ... or by name prefix (whole families of functions).
    _BUFFER_COPY_PREFIXES = ('psa_pake', 'psa_aead')

    @staticmethod
    def _parameter_should_be_copied(function_name: str,
                                    _buffer_name: Optional[str]) -> bool:
        """Whether the specified buffer argument to a PSA function should be copied.

        The decision currently depends only on the function name; the
        buffer name is accepted but ignored. Buffers for which this returns
        True are poisoned around the call. NOTE(review): the intent is
        presumably to catch the library accessing caller buffers directly —
        confirm.
        """
        if function_name.startswith(PSAWrapperGenerator._BUFFER_COPY_PREFIXES):
            return True
        return function_name in PSAWrapperGenerator._BUFFER_COPY_FUNCTIONS

    def _write_function_call(self, out: typing_util.Writable,
                             function: c_wrapper_generator.FunctionInfo,
                             argument_names: List[str]) -> None:
        """Write the call written by the base class, bracketed by (un)poisoning code for its copied buffers."""
        copied_buffers = [
            candidate
            for candidate in self._detect_buffer_parameters(function.arguments,
                                                            argument_names)
            if self._parameter_should_be_copied(
                function.name, function.arguments[candidate.index].name)
        ]
        self._write_poison_buffer_parameters(out, copied_buffers, True)
        super()._write_function_call(out, function, argument_names)
        self._write_poison_buffer_parameters(out, copied_buffers, False)

    def _write_prologue(self, out: typing_util.Writable, header: bool) -> None:
        """Write the base prologue, then open the _CPP_GUARDS conditional and include the headers."""
        super()._write_prologue(out, header)
        prologue_template = """
#if {}

#include <psa/crypto.h>

#include <test/memory.h>
#include <test/psa_crypto_helpers.h>
#include <test/psa_test_wrappers.h>
"""
        out.write(prologue_template.format(self._CPP_GUARDS))

    def _write_epilogue(self, out: typing_util.Writable, header: bool) -> None:
        """Close the conditional opened by _write_prologue, then write the base epilogue."""
        epilogue_template = """
#endif /* {} */
"""
        out.write(epilogue_template.format(self._CPP_GUARDS))
        super()._write_epilogue(out, header)
224
225
class PSALoggingWrapperGenerator(PSAWrapperGenerator, c_wrapper_generator.Logging):
    """Generate a C source file containing wrapper functions that log PSA Crypto API calls."""

    def __init__(self, stream: str) -> None:
        """Create a logging wrapper generator.

        stream is passed to set_stream(); it names the stream to log to
        (the value of the --log command line option).
        """
        super().__init__()
        self.set_stream(stream)

    _PRINTF_TYPE_CAST = c_wrapper_generator.Logging._PRINTF_TYPE_CAST.copy()
    _PRINTF_TYPE_CAST.update({
        'mbedtls_svc_key_id_t': 'unsigned',
        'psa_algorithm_t': 'unsigned',
        'psa_drv_slot_number_t': 'unsigned long long',
        'psa_key_derivation_step_t': 'int',
        'psa_key_id_t': 'unsigned',
        'psa_key_slot_number_t': 'unsigned long long',
        'psa_key_lifetime_t': 'unsigned',
        'psa_key_type_t': 'unsigned',
        'psa_key_usage_flags_t': 'unsigned',
        'psa_pake_role_t': 'int',
        'psa_pake_step_t': 'int',
        'psa_status_t': 'int',
    })

    # Fields of a psa_key_attributes_t that are logged, as
    # (label in the log, printf conversion, suffix of the psa_get_key_xxx()
    # getter). Keeping the three together ensures the format string and the
    # argument list stay in sync. All hexadecimal fields carry a 0x prefix
    # so they cannot be mistaken for decimal values.
    _KEY_ATTRIBUTE_FIELDS = (
        ('id', '%u', 'id'),
        ('lifetime', '0x%08x', 'lifetime'),
        ('type', '0x%08x', 'type'),
        ('bits', '%u', 'bits'),
        ('alg', '0x%08x', 'algorithm'),
        ('usage', '0x%08x', 'usage_flags'),
    )

    def _printf_parameters(self, typ: str, var: str) -> Tuple[str, List[str]]:
        """Return the printf format fragment and arguments used to log `var` of C type `typ`.

        A leading `const ` is ignored. Byte buffers (`uint8_t *`), operation
        objects (`..._operation_t *`) and the PAKE types in _PAKE_STUFF are
        not logged: an empty format and no arguments are returned. Key
        attributes are logged field by field through the psa_get_key_xxx()
        getters, each cast to unsigned. Any other type is handled by the
        base class.
        """
        if typ.startswith('const '):
            typ = typ[len('const '):]
        if typ == 'uint8_t *':
            # Skip buffers
            return '', []
        if typ.endswith('operation_t *'):
            return '', []
        if typ in self._PAKE_STUFF:
            return '', []
        if typ == 'psa_key_attributes_t *':
            fields = ', '.join('{}={}'.format(label, conversion)
                               for label, conversion, _ in self._KEY_ATTRIBUTE_FIELDS)
            return (var + '={' + fields + '}',
                    ['(unsigned) psa_get_key_{}({})'.format(getter, var)
                     for _, _, getter in self._KEY_ATTRIBUTE_FIELDS])
        return super()._printf_parameters(typ, var)
264
265
DEFAULT_C_OUTPUT_FILE_NAME = 'tests/src/psa_test_wrappers.c'
DEFAULT_H_OUTPUT_FILE_NAME = 'tests/include/test/psa_test_wrappers.h'


def main() -> None:
    """Command-line entry point: parse the options and write the output files.

    Without --log, plain wrappers are generated. With --log, the logging
    variant is used and receives the stream name. An empty --output-c or
    --output-h value skips writing that file.
    """
    # Inside a function, the name __doc__ resolves to the module docstring,
    # not to this function's docstring.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        '--log',
        help='Stream to log to (default: no logging code)')
    parser.add_argument(
        '--output-c',
        metavar='FILENAME',
        default=DEFAULT_C_OUTPUT_FILE_NAME,
        help='Output .c file path (default: {}; skip .c output if empty)'.format(
            DEFAULT_C_OUTPUT_FILE_NAME))
    parser.add_argument(
        '--output-h',
        metavar='FILENAME',
        default=DEFAULT_H_OUTPUT_FILE_NAME,
        help='Output .h file path (default: {}; skip .h output if empty)'.format(
            DEFAULT_H_OUTPUT_FILE_NAME))
    options = parser.parse_args()

    if not options.log:
        generator = PSAWrapperGenerator() #type: PSAWrapperGenerator
    else:
        generator = PSALoggingWrapperGenerator(options.log)
    generator.gather_data()

    # The header is written before the C file.
    if options.output_h:
        generator.write_h_file(options.output_h)
    if options.output_c:
        generator.write_c_file(options.output_c)


if __name__ == '__main__':
    main()