Generated PSA wrappers: poison/unpoison buffer parameters

For now, only instrument the one function for which buffer copying has been
implemented, namely `psa_cipher_encrypt`.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/tests/scripts/generate_psa_wrappers.py b/tests/scripts/generate_psa_wrappers.py
index f7a193a..656fcd8 100755
--- a/tests/scripts/generate_psa_wrappers.py
+++ b/tests/scripts/generate_psa_wrappers.py
@@ -6,8 +6,9 @@
 # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
 
 import argparse
+import itertools
 import os
-from typing import List, Tuple
+from typing import Iterator, List, Optional, Tuple
 
 import scripts_path #pylint: disable=unused-import
 from mbedtls_dev import build_tree
@@ -16,6 +17,27 @@
 from mbedtls_dev import typing_util
 
 
+class BufferParameter:
+    """An input or output buffer of a PSA function: a pointer followed by its size."""
+    #pylint: disable=too-few-public-methods
+
+    def __init__(self, i: int, is_output: bool,
+                 buffer_name: str, size_name: str) -> None:
+        """Record where a (pointer, size) argument pair sits in a call.
+
+        i: position of the pointer argument; its size is argument i+1.
+        is_output: true when the buffer is written to (non-const pointer).
+        buffer_name: name of argument i.
+        size_name: name of argument i+1.
+        A variable-size output also has an actual-length argument at i+2;
+        this class does not record it yet.
+        """
+        self.index = i
+        self.is_output = is_output
+        self.buffer_name = buffer_name
+        self.size_name = size_name
+
+
 class PSAWrapperGenerator(c_wrapper_generator.Base):
     """Generate a C source file containing wrapper functions for PSA Crypto API calls."""
 
@@ -75,6 +97,59 @@
         'mbedtls_psa_platform_get_builtin_key': 'defined(MBEDTLS_PSA_CRYPTO_BUILTIN_KEYS)',
     })
 
+    @staticmethod
+    def _detect_buffer_parameters(arguments: List[c_parsing_helper.ArgumentInfo],
+                                  argument_names: List[str]) -> Iterator[BufferParameter]:
+        """Yield every (uint8_t pointer, size_t size) argument pair."""
+        # Arrays (non-empty suffix) are never buffer pointers or sizes, so they
+        # get the same '' placeholder as the missing successor of the last argument.
+        types = ['' if arg.suffix else arg.type for arg in arguments]
+        pointer_types = ('const uint8_t *', 'uint8_t *')
+        pairs = itertools.zip_longest(types, types[1:], fillvalue='')
+        for i, (this_type, next_type) in enumerate(pairs):
+            if this_type not in pointer_types or next_type != 'size_t':
+                continue
+            yield BufferParameter(i, this_type == 'uint8_t *',
+                                  argument_names[i], argument_names[i+1])
+
+    @staticmethod
+    def _write_poison_buffer_parameter(out: typing_util.Writable,
+                                       param: BufferParameter,
+                                       poison: bool) -> None:
+        """Emit one C statement that (un)poisons the buffer described by param.
+
+        The statement is MBEDTLS_TEST_MEMORY_POISON when poison is true,
+        otherwise MBEDTLS_TEST_MEMORY_UNPOISON; both get (buffer, size).
+        """
+        macro = 'MBEDTLS_TEST_MEMORY_' + ('POISON' if poison else 'UNPOISON')
+        out.write('    {}({}, {});\n'.format(macro, param.buffer_name,
+                                              param.size_name))
+
+    @staticmethod
+    def _parameter_should_be_copied(function_name: str,
+                                    _buffer_name: Optional[str]) -> bool:
+        """Tell whether a buffer argument of a PSA function must be copied.
+
+        The buffer name is currently ignored: as a proof of concept, every
+        buffer of psa_cipher_encrypt is selected and nothing else is.
+        """
+        return function_name == 'psa_cipher_encrypt'
+
+    def _write_function_call(self, out: typing_util.Writable,
+                             function: c_wrapper_generator.FunctionInfo,
+                             argument_names: List[str]) -> None:
+        """Write the wrapped call, poisoning selected buffers before it and unpoisoning after."""
+        name = function.name
+        selected = [param for param in
+                    self._detect_buffer_parameters(function.arguments, argument_names)
+                    if self._parameter_should_be_copied(
+                        name, function.arguments[param.index].name)]
+        for param in selected:
+            self._write_poison_buffer_parameter(out, param, True)
+        super()._write_function_call(out, function, argument_names)
+        for param in selected:
+            self._write_poison_buffer_parameter(out, param, False)
+
     def _write_prologue(self, out: typing_util.Writable, header: bool) -> None:
         super()._write_prologue(out, header)
         out.write("""
@@ -82,6 +157,7 @@
 
 #include <psa/crypto.h>
 
+#include <test/memory.h>
 #include <test/psa_crypto_helpers.h>
 #include <test/psa_test_wrappers.h>
 """