Generated PSA wrappers: poison/unpoison buffer parameters
For now, only instrument the one function for which buffer copying has been
implemented, namely `psa_cipher_encrypt`.
Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/tests/scripts/generate_psa_wrappers.py b/tests/scripts/generate_psa_wrappers.py
index f7a193a..656fcd8 100755
--- a/tests/scripts/generate_psa_wrappers.py
+++ b/tests/scripts/generate_psa_wrappers.py
@@ -6,8 +6,9 @@
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
import argparse
+import itertools
import os
-from typing import List, Tuple
+from typing import Iterator, List, Optional, Tuple
import scripts_path #pylint: disable=unused-import
from mbedtls_dev import build_tree
@@ -16,6 +17,27 @@
from mbedtls_dev import typing_util
class BufferParameter:
    """Description of an input or output buffer parameter sequence to a PSA function."""
    #pylint: disable=too-few-public-methods

    def __init__(self, i: int, is_output: bool,
                 buffer_name: str, size_name: str) -> None:
        """Initialize the parameter information.

        i is the index of the function argument that is the pointer to the buffer.
        The size is argument i+1. For a variable-size output, the actual length
        goes in argument i+2.

        is_output is true for a writable (non-const) buffer.

        buffer_name and size_name are the names of arguments i and i+1.
        This class does not yet help with the output length.
        """
        self.index = i
        self.buffer_name = buffer_name
        self.size_name = size_name
        self.is_output = is_output

    def __repr__(self) -> str:
        """Return a constructor-like representation, for debugging the generator."""
        # Argument order matches __init__: (i, is_output, buffer_name, size_name).
        return '{}({!r}, {!r}, {!r}, {!r})'.format(
            type(self).__name__, self.index, self.is_output,
            self.buffer_name, self.size_name)
+
class PSAWrapperGenerator(c_wrapper_generator.Base):
"""Generate a C source file containing wrapper functions for PSA Crypto API calls."""
@@ -75,6 +97,59 @@
'mbedtls_psa_platform_get_builtin_key': 'defined(MBEDTLS_PSA_CRYPTO_BUILTIN_KEYS)',
})
@staticmethod
def _detect_buffer_parameters(arguments: List[c_parsing_helper.ArgumentInfo],
                              argument_names: List[str]) -> Iterator[BufferParameter]:
    """Yield each (byte pointer, size) argument pair of a function.

    An argument is reported when its type is exactly ``const uint8_t *``
    (an input buffer) or ``uint8_t *`` (an output buffer) and the argument
    right after it has type ``size_t``. A trailing length argument, if any,
    is not examined here.
    """
    # Arguments with a suffix (arrays) are blanked out so they never match.
    types = ['' if arg.suffix else arg.type for arg in arguments]
    # Each argument is paired with the type of the following argument;
    # the last one is paired with '' since it has no successor.
    following = itertools.chain(types[1:], [''])
    for index, (this_type, next_type) in enumerate(zip(types, following)):
        if this_type not in ('const uint8_t *', 'uint8_t *'):
            continue
        if next_type != 'size_t':
            continue
        writable = not this_type.startswith('const ')
        yield BufferParameter(index, writable,
                              argument_names[index],
                              argument_names[index + 1])
+
@staticmethod
def _write_poison_buffer_parameter(out: typing_util.Writable,
                                   param: BufferParameter,
                                   poison: bool) -> None:
    """Emit the C statement that poisons or unpoisons one buffer parameter.

    A true ``poison`` selects the MBEDTLS_TEST_MEMORY_POISON macro, a false
    one MBEDTLS_TEST_MEMORY_UNPOISON. The macro is given the names of the
    buffer argument and of its size argument.
    """
    if poison:
        macro_kind = 'POISON'
    else:
        macro_kind = 'UNPOISON'
    statement = ' MBEDTLS_TEST_MEMORY_{}({}, {});\n'.format(
        macro_kind, param.buffer_name, param.size_name)
    out.write(statement)
+
@staticmethod
def _parameter_should_be_copied(function_name: str,
                                _buffer_name: Optional[str]) -> bool:
    """Tell whether a buffer argument of a PSA function should be copied.

    Proof of concept: only the buffers of psa_cipher_encrypt are selected.
    The buffer name is accepted for future use but currently ignored.
    """
    return function_name == 'psa_cipher_encrypt'
+
def _write_function_call(self, out: typing_util.Writable,
                         function: c_wrapper_generator.FunctionInfo,
                         argument_names: List[str]) -> None:
    """Write the call to the wrapped function, surrounded by poisoning code.

    Buffer parameters accepted by _parameter_should_be_copied are poisoned
    before the call (via the base class's _write_function_call) and
    unpoisoned after it, both in argument order.
    """
    candidates = self._detect_buffer_parameters(function.arguments,
                                                argument_names)
    selected = [
        param for param in candidates
        if self._parameter_should_be_copied(
            function.name, function.arguments[param.index].name)
    ]
    for param in selected:
        self._write_poison_buffer_parameter(out, param, True)
    super()._write_function_call(out, function, argument_names)
    for param in selected:
        self._write_poison_buffer_parameter(out, param, False)
+
def _write_prologue(self, out: typing_util.Writable, header: bool) -> None:
super()._write_prologue(out, header)
out.write("""
@@ -82,6 +157,7 @@
#include <psa/crypto.h>
+#include <test/memory.h>
#include <test/psa_crypto_helpers.h>
#include <test/psa_test_wrappers.h>
"""