blob: 3fa4135cb0b4d76f2bbd31e4e5ddcec0f1bf1157 [file] [log] [blame]
Gilles Peskine5294bb32024-01-04 16:38:17 +01001#!/usr/bin/env python3
2"""Generate wrapper functions for PSA function calls.
3"""
4
5# Copyright The Mbed TLS Contributors
6# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
7
8import argparse
Gilles Peskine4adacac2023-12-06 19:32:52 +01009import itertools
Gilles Peskine5294bb32024-01-04 16:38:17 +010010import os
Gilles Peskine4adacac2023-12-06 19:32:52 +010011from typing import Iterator, List, Optional, Tuple
Gilles Peskine5294bb32024-01-04 16:38:17 +010012
13import scripts_path #pylint: disable=unused-import
14from mbedtls_dev import build_tree
15from mbedtls_dev import c_parsing_helper
16from mbedtls_dev import c_wrapper_generator
17from mbedtls_dev import typing_util
18
19
class BufferParameter:
    """Description of an input or output buffer parameter sequence to a PSA function."""
    #pylint: disable=too-few-public-methods

    def __init__(self, i: int, is_output: bool,
                 buffer_name: str, size_name: str) -> None:
        """Initialize the parameter information.

        i is the index of the function argument that is the pointer to the buffer.
        The size is argument i+1. For a variable-size output, the actual length
        goes in argument i+2.

        is_output is true if the buffer is written by the function.
        buffer_name and size_name are the names of arguments i and i+1.
        This class does not yet help with the output length.
        """
        self.index = i
        self.buffer_name = buffer_name
        self.size_name = size_name
        self.is_output = is_output

    def __repr__(self) -> str:
        """Show the constructor arguments, to ease debugging of the generator."""
        return '{}({!r}, {!r}, {!r}, {!r})'.format(
            type(self).__name__, self.index, self.is_output,
            self.buffer_name, self.size_name)
39
40
class PSAWrapperGenerator(c_wrapper_generator.Base):
    """Generate a C source file containing wrapper functions for PSA Crypto API calls."""

    # Preprocessor condition opened in _write_prologue and closed in
    # _write_epilogue, around the whole generated code.
    _CPP_GUARDS = 'defined(MBEDTLS_PSA_CRYPTO_C) && defined(MBEDTLS_TEST_HOOKS)'
    # Affixes for the generated wrapper names (presumably applied by the base class).
    _WRAPPER_NAME_PREFIX = 'mbedtls_test_wrap_'
    _WRAPPER_NAME_SUFFIX = ''

    def gather_data(self) -> None:
        """Parse the function declarations of the public PSA headers.

        include/psa/crypto.h and include/psa/crypto_extra.h under the guessed
        Mbed TLS root are handed to c_parsing_helper.read_function_declarations
        together with self.functions (which presumably collects the results).
        """
        psa_include_dir = os.path.join(build_tree.guess_mbedtls_root(),
                                       'include', 'psa')
        for header in ('crypto.h', 'crypto_extra.h'):
            c_parsing_helper.read_function_declarations(
                self.functions, os.path.join(psa_include_dir, header))

    # Functions that get no wrapper even though they return psa_status_t.
    _SKIP_FUNCTIONS = frozenset([
        'mbedtls_psa_external_get_random', # not a library function
        'psa_get_key_domain_parameters', # client-side function
        'psa_get_key_slot_number', # client-side function
        'psa_key_derivation_verify_bytes', # not implemented yet
        'psa_key_derivation_verify_key', # not implemented yet
        'psa_set_key_domain_parameters', # client-side function
    ])

    def _skip_function(self, function: c_wrapper_generator.FunctionInfo) -> bool:
        """Whether no wrapper is generated for function.

        Only functions returning psa_status_t are wrapped, minus the ones
        listed in _SKIP_FUNCTIONS.
        """
        return (function.return_type != 'psa_status_t' or
                function.name in self._SKIP_FUNCTIONS)

    # PAKE-related pointer types (PAKE support is not implemented yet).
    _PAKE_STUFF = frozenset([
        'psa_crypto_driver_pake_inputs_t *',
        'psa_pake_cipher_suite_t *',
    ])

    def _return_variable_name(self,
                              function: c_wrapper_generator.FunctionInfo) -> str:
        """The name of the variable that will contain the return value.

        A psa_status_t result is stored in `status`; any other return type
        uses the base class's choice.
        """
        if function.return_type != 'psa_status_t':
            return super()._return_variable_name(function)
        return 'status'

    # Preprocessor conditions for functions that only exist in some
    # configurations, on top of the guards inherited from the base class.
    _FUNCTION_GUARDS = (
        c_wrapper_generator.Base._FUNCTION_GUARDS.copy() #pylint: disable=protected-access
    )
    _FUNCTION_GUARDS.update({
        'mbedtls_psa_register_se_key': 'defined(MBEDTLS_PSA_CRYPTO_SE_C)',
        'mbedtls_psa_inject_entropy': 'defined(MBEDTLS_PSA_INJECT_ENTROPY)',
        'mbedtls_psa_external_get_random': 'defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG)',
        'mbedtls_psa_platform_get_builtin_key': 'defined(MBEDTLS_PSA_CRYPTO_BUILTIN_KEYS)',
    })

    @staticmethod
    def _detect_buffer_parameters(arguments: List[c_parsing_helper.ArgumentInfo],
                                  argument_names: List[str]) -> Iterator[BufferParameter]:
        """Detect function arguments that are buffers (pointer, size [,length]).

        An argument of type `const uint8_t *` (input) or `uint8_t *` (output)
        that is immediately followed by a `size_t` argument is reported as a
        buffer. Array arguments (those with a suffix) never match.
        """
        # Array arguments are blanked out so that they match no pattern.
        types = ['' if arg.suffix else arg.type for arg in arguments]
        pointer_types = ('const uint8_t *', 'uint8_t *')
        # Each argument type is paired with the next one ('' after the last).
        consecutive = itertools.zip_longest(types, types[1:], fillvalue='')
        for index, (pointer_type, next_type) in enumerate(consecutive):
            if pointer_type in pointer_types and next_type == 'size_t':
                yield BufferParameter(index, not pointer_type.startswith('const '),
                                      argument_names[index], argument_names[index + 1])

    @staticmethod
    def _write_poison_buffer_parameter(out: typing_util.Writable,
                                       param: BufferParameter,
                                       poison: bool) -> None:
        """Write the poisoning (poison true) or unpoisoning (poison false) macro call.

        The emitted macro receives the buffer name and the size name of param.
        """
        macro_suffix = 'POISON' if poison else 'UNPOISON'
        out.write(' MBEDTLS_TEST_MEMORY_{}({}, {});\n'.format(
            macro_suffix, param.buffer_name, param.size_name))

    @staticmethod
    def _parameter_should_be_copied(function_name: str,
                                    _buffer_name: Optional[str]) -> bool:
        """Whether the specified buffer argument to a PSA function should be copied.

        The buffer name is currently ignored.
        """
        # Proof-of-concept: only one function is instrumented for now.
        return function_name == 'psa_cipher_encrypt'

    def _write_function_call(self, out: typing_util.Writable,
                             function: c_wrapper_generator.FunctionInfo,
                             argument_names: List[str]) -> None:
        """Write the call to the wrapped function, bracketed by buffer poisoning.

        Buffer arguments selected by _parameter_should_be_copied are poisoned
        before the call and unpoisoned after it.
        """
        candidates = self._detect_buffer_parameters(function.arguments,
                                                    argument_names)
        instrumented = [
            candidate for candidate in candidates
            if self._parameter_should_be_copied(
                function.name, function.arguments[candidate.index].name)
        ]
        for buffer_param in instrumented:
            self._write_poison_buffer_parameter(out, buffer_param, True)
        super()._write_function_call(out, function, argument_names)
        for buffer_param in instrumented:
            self._write_poison_buffer_parameter(out, buffer_param, False)

    def _write_prologue(self, out: typing_util.Writable, header: bool) -> None:
        """Write the base prologue, then open _CPP_GUARDS and include the headers."""
        super()._write_prologue(out, header)
        guarded_includes = """
#if {}

#include <psa/crypto.h>

#include <test/memory.h>
#include <test/psa_crypto_helpers.h>
#include <test/psa_test_wrappers.h>
"""
        out.write(guarded_includes.format(self._CPP_GUARDS))

    def _write_epilogue(self, out: typing_util.Writable, header: bool) -> None:
        """Close the _CPP_GUARDS conditional, then write the base epilogue."""
        out.write('\n#endif /* {} */\n'.format(self._CPP_GUARDS))
        super()._write_epilogue(out, header)
164
165
class PSALoggingWrapperGenerator(PSAWrapperGenerator, c_wrapper_generator.Logging):
    """Generate a C source file containing wrapper functions that log PSA Crypto API calls."""

    def __init__(self, stream: str) -> None:
        """Create a logging wrapper generator.

        stream is the logging destination; it is passed to set_stream().
        """
        super().__init__()
        self.set_stream(stream)

    # Presumably maps each PSA type to the C type its values are cast to when
    # passed to printf; extends the table inherited from Logging.
    _PRINTF_TYPE_CAST = c_wrapper_generator.Logging._PRINTF_TYPE_CAST.copy()
    _PRINTF_TYPE_CAST.update({
        'mbedtls_svc_key_id_t': 'unsigned',
        'psa_algorithm_t': 'unsigned',
        'psa_drv_slot_number_t': 'unsigned long long',
        'psa_key_derivation_step_t': 'int',
        'psa_key_id_t': 'unsigned',
        'psa_key_slot_number_t': 'unsigned long long',
        'psa_key_lifetime_t': 'unsigned',
        'psa_key_type_t': 'unsigned',
        'psa_key_usage_flags_t': 'unsigned',
        'psa_pake_role_t': 'int',
        'psa_pake_step_t': 'int',
        'psa_status_t': 'int',
    })

    def _printf_parameters(self, typ: str, var: str) -> Tuple[str, List[str]]:
        """Return a printf format fragment and its argument expressions for a parameter.

        typ is the C type of the parameter and var the C expression naming it.
        An empty format with no arguments is returned for byte buffers,
        operation objects and PAKE-related structures, which are skipped.
        Key attributes are expanded field by field. Other types are delegated
        to the base class.
        """
        if typ.startswith('const '):
            typ = typ[6:]
        if typ == 'uint8_t *':
            # Skip buffers
            return '', []
        if typ.endswith('operation_t *'):
            # Skip operation objects
            return '', []
        if typ in self._PAKE_STUFF:
            return '', []
        if typ == 'psa_key_attributes_t *':
            # All hexadecimal fields carry a 0x prefix so that they cannot be
            # mistaken for decimal values in the log.
            fields = ['id', 'lifetime', 'type', 'bits', 'algorithm', 'usage_flags']
            return (var + '={id=%u, lifetime=0x%08x, type=0x%08x, '
                    'bits=%u, alg=0x%08x, usage=0x%08x}',
                    ['(unsigned) psa_get_key_{}({})'.format(field, var)
                     for field in fields])
        return super()._printf_parameters(typ, var)
204
205
DEFAULT_C_OUTPUT_FILE_NAME = 'tests/src/psa_test_wrappers.c'
DEFAULT_H_OUTPUT_FILE_NAME = 'tests/include/test/psa_test_wrappers.h'


def main() -> None:
    """Command line entry point: parse the options and write the output files.

    With --log, a logging generator is used with the given stream; otherwise
    plain wrappers are generated. An empty --output-h or --output-c value
    suppresses the corresponding file.
    """
    arg_parser = argparse.ArgumentParser(description=__doc__)
    arg_parser.add_argument(
        '--log',
        help='Stream to log to (default: no logging code)')
    c_help = 'Output .c file path (default: {}; skip .c output if empty)'
    arg_parser.add_argument(
        '--output-c', metavar='FILENAME', default=DEFAULT_C_OUTPUT_FILE_NAME,
        help=c_help.format(DEFAULT_C_OUTPUT_FILE_NAME))
    h_help = 'Output .h file path (default: {}; skip .h output if empty)'
    arg_parser.add_argument(
        '--output-h', metavar='FILENAME', default=DEFAULT_H_OUTPUT_FILE_NAME,
        help=h_help.format(DEFAULT_H_OUTPUT_FILE_NAME))
    args = arg_parser.parse_args()

    log_stream = args.log
    if not log_stream:
        generator = PSAWrapperGenerator() #type: PSAWrapperGenerator
    else:
        generator = PSALoggingWrapperGenerator(log_stream)
    generator.gather_data()

    # The header is written before the C file; an empty path skips that file.
    outputs = ((args.output_h, generator.write_h_file),
               (args.output_c, generator.write_c_file))
    for path, write_file in outputs:
        if path:
            write_file(path)


if __name__ == '__main__':
    main()