Gilles Peskine | 5294bb3 | 2024-01-04 16:38:17 +0100 | [diff] [blame] | 1 | #!/usr/bin/env python3 |
| 2 | """Generate wrapper functions for PSA function calls. |
| 3 | """ |
| 4 | |
| 5 | # Copyright The Mbed TLS Contributors |
| 6 | # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later |
| 7 | |
Gilles Peskine | a1871f3 | 2024-01-04 17:28:59 +0100 | [diff] [blame] | 8 | ### WARNING: the code in this file has not been extensively reviewed yet. |
| 9 | ### We do not think it is harmful, but it may be below our normal standards |
| 10 | ### for robustness and maintainability. |
| 11 | |
Gilles Peskine | 5294bb3 | 2024-01-04 16:38:17 +0100 | [diff] [blame] | 12 | import argparse |
Gilles Peskine | 4adacac | 2023-12-06 19:32:52 +0100 | [diff] [blame] | 13 | import itertools |
Gilles Peskine | 5294bb3 | 2024-01-04 16:38:17 +0100 | [diff] [blame] | 14 | import os |
Gilles Peskine | 4adacac | 2023-12-06 19:32:52 +0100 | [diff] [blame] | 15 | from typing import Iterator, List, Optional, Tuple |
Gilles Peskine | 5294bb3 | 2024-01-04 16:38:17 +0100 | [diff] [blame] | 16 | |
| 17 | import scripts_path #pylint: disable=unused-import |
| 18 | from mbedtls_dev import build_tree |
| 19 | from mbedtls_dev import c_parsing_helper |
| 20 | from mbedtls_dev import c_wrapper_generator |
| 21 | from mbedtls_dev import typing_util |
| 22 | |
| 23 | |
class BufferParameter:
    """A (pointer, size) buffer argument pair of a PSA function.

    Instances only carry data: where the buffer pointer sits in the
    argument list, the names of the pointer and size arguments, and
    whether the buffer is an output.
    """
    #pylint: disable=too-few-public-methods

    def __init__(self, i: int, is_output: bool,
                 buffer_name: str, size_name: str) -> None:
        """Record the description of one buffer parameter.

        i is the position, in the function's argument list, of the pointer
        to the buffer. The buffer size is argument i+1. For a variable-size
        output, the actual length is argument i+2; this class does not
        record anything about that length argument yet.

        is_output tells whether the buffer is an output buffer.

        buffer_name is the name of argument i and size_name is the name of
        argument i+1.
        """
        # Names first, then position and direction; the assignments are
        # independent of each other.
        self.buffer_name = buffer_name
        self.size_name = size_name
        self.index = i
        self.is_output = is_output
| 43 | |
| 44 | |
class PSAWrapperGenerator(c_wrapper_generator.Base):
    """Generate a C source file containing wrapper functions for PSA Crypto API calls.

    On top of the base generator, this class reads the PSA headers, decides
    which functions to leave out, and surrounds selected function calls with
    memory poisoning/unpoisoning statements for their buffer arguments.
    """

    # Preprocessor condition written around the generated content by
    # _write_prologue() and _write_epilogue().
    _CPP_GUARDS = 'defined(MBEDTLS_PSA_CRYPTO_C) && defined(MBEDTLS_TEST_HOOKS)'
    _WRAPPER_NAME_PREFIX = 'mbedtls_test_wrap_'
    _WRAPPER_NAME_SUFFIX = ''

    def gather_data(self) -> None:
        """Collect function declarations from the PSA crypto headers.

        The declarations found in include/psa/crypto.h and
        include/psa/crypto_extra.h under the guessed Mbed TLS root are
        read into self.functions.
        """
        psa_include_dir = os.path.join(build_tree.guess_mbedtls_root(),
                                       'include', 'psa')
        for header_name in ('crypto.h', 'crypto_extra.h'):
            c_parsing_helper.read_function_declarations(
                self.functions,
                os.path.join(psa_include_dir, header_name))

    # Functions that _skip_function() rejects by name.
    _SKIP_FUNCTIONS = frozenset({
        'mbedtls_psa_external_get_random', # not a library function
        'psa_get_key_domain_parameters', # client-side function
        'psa_get_key_slot_number', # client-side function
        'psa_key_derivation_verify_bytes', # not implemented yet
        'psa_key_derivation_verify_key', # not implemented yet
        'psa_set_key_domain_parameters', # client-side function
    })

    def _skip_function(self, function: c_wrapper_generator.FunctionInfo) -> bool:
        """Tell whether the given function is to be skipped.

        A function is skipped if it does not return psa_status_t or if
        its name is listed in _SKIP_FUNCTIONS.
        """
        return (function.return_type != 'psa_status_t' or
                function.name in self._SKIP_FUNCTIONS)

    # PAKE stuff: not implemented yet
    _PAKE_STUFF = frozenset({
        'psa_crypto_driver_pake_inputs_t *',
        'psa_pake_cipher_suite_t *',
    })

    def _return_variable_name(self,
                              function: c_wrapper_generator.FunctionInfo) -> str:
        """The name of the variable that will contain the return value.

        This is 'status' for functions returning psa_status_t, and whatever
        the parent class chooses for any other return type.
        """
        if function.return_type != 'psa_status_t':
            return super()._return_variable_name(function)
        return 'status'

    # Start from a copy of the parent's table so that the additions below
    # do not modify the parent class's dictionary.
    _FUNCTION_GUARDS = (
        c_wrapper_generator.Base._FUNCTION_GUARDS.copy() #pylint: disable=protected-access
    )
    _FUNCTION_GUARDS.update({
        'mbedtls_psa_register_se_key': 'defined(MBEDTLS_PSA_CRYPTO_SE_C)',
        'mbedtls_psa_inject_entropy': 'defined(MBEDTLS_PSA_INJECT_ENTROPY)',
        'mbedtls_psa_external_get_random': 'defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG)',
        'mbedtls_psa_platform_get_builtin_key': 'defined(MBEDTLS_PSA_CRYPTO_BUILTIN_KEYS)',
    })

    @staticmethod
    def _detect_buffer_parameters(arguments: List[c_parsing_helper.ArgumentInfo],
                                  argument_names: List[str]) -> Iterator[BufferParameter]:
        """Detect function arguments that are buffers (pointer, size [,length]).

        Yield one BufferParameter for each argument of type
        'const uint8_t *' or 'uint8_t *' that is immediately followed by an
        argument of type 'size_t'. A buffer counts as an output when its
        pointer type is not const-qualified.
        """
        # An argument with a non-empty suffix (an array) gets the empty
        # string as its type, so it can never match below.
        types = [arg.type if not arg.suffix else '' for arg in arguments]
        # Pair each type with the type of the following argument; the last
        # argument is paired with the empty string.
        with_next = itertools.zip_longest(types, types[1:], fillvalue='')
        for index, (this_type, next_type) in enumerate(with_next):
            if next_type != 'size_t':
                continue
            if this_type not in ('const uint8_t *', 'uint8_t *'):
                continue
            yield BufferParameter(index, this_type == 'uint8_t *',
                                  argument_names[index],
                                  argument_names[index + 1])

    @staticmethod
    def _write_poison_buffer_parameter(out: typing_util.Writable,
                                       param: BufferParameter,
                                       poison: bool) -> None:
        """Write poisoning or unpoisoning code for a buffer parameter.

        The statement written is a call to MBEDTLS_TEST_MEMORY_POISON if
        poison is true and to MBEDTLS_TEST_MEMORY_UNPOISON otherwise, with
        the buffer and size argument names as its arguments.
        """
        action = 'POISON' if poison else 'UNPOISON'
        # NOTE(review): the leading whitespace inside this literal sets the
        # indentation of the generated C statement -- confirm it matches the
        # indentation of the surrounding generated code.
        statement = ' MBEDTLS_TEST_MEMORY_{}({}, {});\n'.format(
            action, param.buffer_name, param.size_name)
        out.write(statement)

    def _write_poison_buffer_parameters(self, out: typing_util.Writable,
                                        buffer_parameters: List[BufferParameter],
                                        poison: bool) -> None:
        """Write poisoning or unpoisoning code for the buffer parameters.

        Write poisoning code if poison is true, unpoisoning code otherwise.
        The statements are enclosed in a preprocessor conditional on
        MBEDTLS_PSA_COPY_CALLER_BUFFERS. Nothing at all is written when
        buffer_parameters is empty.
        """
        if buffer_parameters:
            out.write('#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS)\n')
            for buffer_parameter in buffer_parameters:
                self._write_poison_buffer_parameter(out, buffer_parameter, poison)
            out.write('#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */\n')

    @staticmethod
    def _parameter_should_be_copied(function_name: str,
                                    _buffer_name: Optional[str]) -> bool:
        """Whether the specified buffer argument to a PSA function should be copied.

        The buffer name is currently ignored: the decision depends only on
        the function name.
        """
        # Proof-of-concept: just instrument one function for now
        return function_name == 'psa_cipher_encrypt'

    def _write_function_call(self, out: typing_util.Writable,
                             function: c_wrapper_generator.FunctionInfo,
                             argument_names: List[str]) -> None:
        """Write the call to the wrapped function, with buffer poisoning around it.

        The buffer parameters selected by _parameter_should_be_copied() are
        poisoned before the call written by the parent class and unpoisoned
        after it.
        """
        detected = self._detect_buffer_parameters(function.arguments,
                                                  argument_names)
        buffer_parameters = [
            candidate
            for candidate in detected
            if self._parameter_should_be_copied(
                function.name,
                function.arguments[candidate.index].name)
        ]
        self._write_poison_buffer_parameters(out, buffer_parameters, True)
        super()._write_function_call(out, function, argument_names)
        self._write_poison_buffer_parameters(out, buffer_parameters, False)

    def _write_prologue(self, out: typing_util.Writable, header: bool) -> None:
        """Write the parent's prologue, then open the guard and include headers.

        The text added after the parent's prologue opens an #if on
        _CPP_GUARDS, which _write_epilogue() closes.
        """
        super()._write_prologue(out, header)
        template = ('\n'
                    '#if {}\n'
                    '\n'
                    '#include <psa/crypto.h>\n'
                    '\n'
                    '#include <test/memory.h>\n'
                    '#include <test/psa_crypto_helpers.h>\n'
                    '#include <test/psa_test_wrappers.h>\n')
        out.write(template.format(self._CPP_GUARDS))

    def _write_epilogue(self, out: typing_util.Writable, header: bool) -> None:
        """Close the #if on _CPP_GUARDS, then write the parent's epilogue."""
        template = ('\n'
                    '#endif /* {} */\n')
        out.write(template.format(self._CPP_GUARDS))
        super()._write_epilogue(out, header)
| 180 | |
| 181 | |
class PSALoggingWrapperGenerator(PSAWrapperGenerator, c_wrapper_generator.Logging):
    """Generate a C source file containing wrapper functions that log PSA Crypto API calls."""

    def __init__(self, stream: str) -> None:
        """Initialize the generator and set the stream to log to."""
        super().__init__()
        self.set_stream(stream)

    # Extend a copy of the parent's table with casts for PSA types, leaving
    # the parent's own dictionary unmodified.
    _PRINTF_TYPE_CAST = c_wrapper_generator.Logging._PRINTF_TYPE_CAST.copy()
    _PRINTF_TYPE_CAST.update({
        'mbedtls_svc_key_id_t': 'unsigned',
        'psa_algorithm_t': 'unsigned',
        'psa_drv_slot_number_t': 'unsigned long long',
        'psa_key_derivation_step_t': 'int',
        'psa_key_id_t': 'unsigned',
        'psa_key_slot_number_t': 'unsigned long long',
        'psa_key_lifetime_t': 'unsigned',
        'psa_key_type_t': 'unsigned',
        'psa_key_usage_flags_t': 'unsigned',
        'psa_pake_role_t': 'int',
        'psa_pake_step_t': 'int',
        'psa_status_t': 'int',
    })

    def _printf_parameters(self, typ: str, var: str) -> Tuple[str, List[str]]:
        """Return the printf format and arguments for a value of type typ.

        A leading 'const ' qualifier is ignored. Byte buffers, pointers to
        operation objects and the PAKE types listed in _PAKE_STUFF produce
        an empty format with no arguments. A pointer to key attributes is
        shown field by field, using the psa_get_key_xxx() accessors. Any
        other type is left to the parent class.
        """
        const_prefix = 'const '
        if typ.startswith(const_prefix):
            typ = typ[len(const_prefix):]
        is_skipped = (typ == 'uint8_t *' or # Skip buffers
                      typ.endswith('operation_t *') or
                      typ in self._PAKE_STUFF)
        if is_skipped:
            return '', []
        if typ != 'psa_key_attributes_t *':
            return super()._printf_parameters(typ, var)
        # One accessor call per conversion in the format, in the same order.
        attribute_fields = ('id', 'lifetime', 'type', 'bits', 'algorithm', 'usage_flags')
        fmt = var + '={id=%u, lifetime=0x%08x, type=0x%08x, bits=%u, alg=%08x, usage=%08x}'
        values = ['(unsigned) psa_get_key_{}({})'.format(field, var)
                  for field in attribute_fields]
        return (fmt, values)
| 220 | |
| 221 | |
# Default output file paths: main() uses these as the default values of the
# --output-c and --output-h command line options respectively.
DEFAULT_C_OUTPUT_FILE_NAME = 'tests/src/psa_test_wrappers.c'
DEFAULT_H_OUTPUT_FILE_NAME = 'tests/include/test/psa_test_wrappers.h'
| 224 | |
def main() -> None:
    """Parse the command line and write the requested wrapper files.

    A logging generator is used if --log names a stream, a plain generator
    otherwise. The .h file is written first, then the .c file; either one
    is skipped if its output path option is empty.
    """
    arg_parser = argparse.ArgumentParser(description=__doc__)
    arg_parser.add_argument(
        '--log',
        help='Stream to log to (default: no logging code)')
    c_help = ('Output .c file path (default: {}; skip .c output if empty)'
              .format(DEFAULT_C_OUTPUT_FILE_NAME))
    arg_parser.add_argument('--output-c', metavar='FILENAME',
                            default=DEFAULT_C_OUTPUT_FILE_NAME,
                            help=c_help)
    h_help = ('Output .h file path (default: {}; skip .h output if empty)'
              .format(DEFAULT_H_OUTPUT_FILE_NAME))
    arg_parser.add_argument('--output-h', metavar='FILENAME',
                            default=DEFAULT_H_OUTPUT_FILE_NAME,
                            help=h_help)
    args = arg_parser.parse_args()

    generator = (PSALoggingWrapperGenerator(args.log) if args.log
                 else PSAWrapperGenerator()) #type: PSAWrapperGenerator
    generator.gather_data()
    # An empty path means that the corresponding file is not wanted.
    if args.output_h:
        generator.write_h_file(args.output_h)
    if args.output_c:
        generator.write_c_file(args.output_c)
| 249 | |
# Run the generator only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()