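config.py: add support for the PSA crypto configuration file

Split the generic file handling out of ConfigFile into an abstract
ConfigFile base class with MbedtlsConfigFile and CryptoConfigFile
subclasses, and add MbedtlsConfig, CryptoConfig and MultiConfig so that
config.py can query and modify both mbedtls_config.h and crypto_config.h.
A new --cryptofile/-c option selects the crypto configuration file.
Unsupported, deprecated and unstable PSA features are excluded from the
"full" configuration, and setting an unsupported or unstable feature is
rejected.
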
Signed-off-by: Gabor Mezei <gabor.mezei@arm.com>
diff --git a/scripts/config.py b/scripts/config.py
index 7c32db1..ad1f787 100755
--- a/scripts/config.py
+++ b/scripts/config.py
@@ -19,6 +19,8 @@
 import os
 import re
 
+from abc import ABCMeta, abstractmethod
+
 class Setting:
     """Representation of one Mbed TLS mbedtls_config.h setting.
 
@@ -31,11 +33,12 @@
     * section: the name of the section that contains this symbol.
     """
     # pylint: disable=too-few-public-methods
-    def __init__(self, active, name, value='', section=None):
+    def __init__(self, active, name, value='', section=None, configfile=None):
         self.active = active
         self.name = name
         self.value = value
         self.section = section
+        self.configfile = configfile
 
 class Config:
     """Representation of the Mbed TLS configuration.
@@ -54,7 +57,7 @@
       name to become set.
     """
 
-    def __init__(self):
+    def __init__(self, **kw):
         self.settings = {}
 
     def __contains__(self, name):
@@ -152,7 +155,7 @@
 
 def is_full_section(section):
     """Is this section affected by "config.py full" and friends?"""
-    return section.endswith('support') or section.endswith('modules')
+    return section is None or section.endswith('support') or section.endswith('modules')
 
 def realfull_adapter(_name, active, section):
     """Activate all symbols found in the global and boolean feature sections.
@@ -168,6 +171,22 @@
         return active
     return True
 
+UNSUPPORTED_FEATURE = frozenset([
+    'PSA_WANT_ALG_CBC_MAC',
+    'PSA_WANT_ALG_XTS',
+    'PSA_WANT_KEY_TYPE_RSA_KEY_PAIR_DERIVE',
+    'PSA_WANT_KEY_TYPE_DH_KEY_PAIR_DERIVE'
+])
+
+DEPRECATED_FEATURE = frozenset([
+    'PSA_WANT_KEY_TYPE_ECC_KEY_PAIR',
+    'PSA_WANT_KEY_TYPE_RSA_KEY_PAIR'
+])
+
+UNSTABLE_FEATURE = frozenset([
+    'PSA_WANT_ECC_SECP_K1_224'
+])
+
 # The goal of the full configuration is to have everything that can be tested
 # together. This includes deprecated or insecure options. It excludes:
 # * Options that require additional build dependencies or unusual hardware.
@@ -236,7 +255,8 @@
 
 def include_in_full(name):
     """Rules for symbols in the "full" configuration."""
-    if name in EXCLUDE_FROM_FULL:
+    if name in (EXCLUDE_FROM_FULL | UNSUPPORTED_FEATURE |
+                DEPRECATED_FEATURE | UNSTABLE_FEATURE):
         return False
     if name.endswith('_ALT'):
         return is_seamless_alt(name)
@@ -368,43 +388,21 @@
         return adapter(name, active, section)
     return continuation
 
-class ConfigFile(Config):
-    """Representation of the Mbed TLS configuration read for a file.
-
-    See the documentation of the `Config` class for methods to query
-    and modify the configuration.
-    """
-
-    _path_in_tree = 'include/mbedtls/mbedtls_config.h'
-    default_path = [_path_in_tree,
-                    os.path.join(os.path.dirname(__file__),
-                                 os.pardir,
-                                 _path_in_tree),
-                    os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))),
-                                 _path_in_tree)]
-
-    def __init__(self, filename=None):
-        """Read the Mbed TLS configuration file."""
+class ConfigFile(metaclass=ABCMeta):
+    def __init__(self, default_path, filename=None, name=''):
         if filename is None:
-            for candidate in self.default_path:
+            for candidate in default_path:
                 if os.path.lexists(candidate):
                     filename = candidate
                     break
             else:
-                raise Exception('Mbed TLS configuration file not found',
-                                self.default_path)
-        super().__init__()
-        self.filename = filename
-        self.inclusion_guard = None
-        self.current_section = 'header'
-        with open(filename, 'r', encoding='utf-8') as file:
-            self.templates = [self._parse_line(line) for line in file]
-        self.current_section = None
+                raise Exception(name + ' configuration file not found',
+                                default_path)
 
-    def set(self, name, value=None):
-        if name not in self.settings:
-            self.templates.append((name, '', '#define ' + name + ' '))
-        super().set(name, value)
+        self.filename = filename
+        self.templates = []
+        self.current_section = None
+        self.inclusion_guard = None
 
     _define_line_regexp = (r'(?P<indentation>\s*)' +
                            r'(?P<commented_out>(//\s*)?)' +
@@ -420,39 +418,87 @@
                                                 _ifndef_line_regexp,
                                                 _section_line_regexp]))
     def _parse_line(self, line):
-        """Parse a line in mbedtls_config.h and return the corresponding template."""
+        """Parse a line in the config file and return the corresponding template."""
         line = line.rstrip('\r\n')
         m = re.match(self._config_line_regexp, line)
         if m is None:
-            return line
+            self.templates.append(line)
+            return None
         elif m.group('section'):
             self.current_section = m.group('section')
-            return line
+            self.templates.append(line)
+            return None
         elif m.group('inclusion_guard') and self.inclusion_guard is None:
             self.inclusion_guard = m.group('inclusion_guard')
-            return line
+            self.templates.append(line)
+            return None
         else:
             active = not m.group('commented_out')
             name = m.group('name')
             value = m.group('value')
             if name == self.inclusion_guard and value == '':
                 # The file double-inclusion guard is not an option.
-                return line
+                self.templates.append(line)
+                return None
             template = (name,
                         m.group('indentation'),
                         m.group('define') + name +
                         m.group('arguments') + m.group('separator'))
-            self.settings[name] = Setting(active, name, value,
-                                          self.current_section)
-            return template
+            self.templates.append(template)
 
-    def _format_template(self, name, indent, middle):
+            return (active, name, value, self.current_section)
+
+    def parse_file(self):
+        with open(self.filename, 'r', encoding='utf-8') as file:
+            for line in file:
+                setting = self._parse_line(line)
+                if setting is not None:
+                    yield setting
+        self.current_section = None
+
+    @abstractmethod
+    def _format_template(self, setting, name, indent, middle):
+        pass
+
+    def write_to_stream(self, settings, output):
+        """Write the whole configuration to output."""
+        for template in self.templates:
+            if isinstance(template, str):
+                line = template
+            else:
+                name, _, _ = template
+                line = self._format_template(settings[name], *template)
+            output.write(line + '\n')
+
+    def write(self, settings, filename=None):
+        """Write the whole configuration to the file it was read from.
+
+        If filename is specified, write to this file instead.
+        """
+        if filename is None:
+            filename = self.filename
+        with open(filename, 'w', encoding='utf-8') as output:
+            self.write_to_stream(settings, output)
+
+class MbedtlsConfigFile(ConfigFile):
+    _path_in_tree = 'include/mbedtls/mbedtls_config.h'
+    default_path = [_path_in_tree,
+                    os.path.join(os.path.dirname(__file__),
+                                 os.pardir,
+                                 _path_in_tree),
+                    os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))),
+                                 _path_in_tree)]
+
+    def __init__(self, filename=None):
+        super().__init__(self.default_path, filename, 'Mbed TLS')
+        self.current_section = 'header'
+
+    def _format_template(self, setting, name, indent, middle):
         """Build a line for mbedtls_config.h for the given setting.
 
         The line has the form "<indent>#define <name> <value>"
         where <middle> is "#define <name> ".
         """
-        setting = self.settings[name]
         value = setting.value
         if value is None:
             value = ''
@@ -470,24 +516,160 @@
                         middle,
                         value]).rstrip()
 
-    def write_to_stream(self, output):
-        """Write the whole configuration to output."""
-        for template in self.templates:
-            if isinstance(template, str):
-                line = template
-            else:
-                line = self._format_template(*template)
-            output.write(line + '\n')
+class CryptoConfigFile(ConfigFile):
+    """Reader and writer for the PSA crypto configuration file (crypto_config.h)."""
+    _path_in_tree = 'tf-psa-crypto/include/psa/crypto_config.h'
+    default_path = [_path_in_tree,
+                    os.path.join(os.path.dirname(__file__),
+                                 os.pardir,
+                                 _path_in_tree),
+                    os.path.join(os.path.dirname(os.path.abspath(os.path.dirname(__file__))),
+                                 _path_in_tree)]
+
+    def __init__(self, filename=None):
+        """Use filename, or else the first existing path in default_path."""
+        super().__init__(self.default_path, filename, 'Crypto')
+
+    def _format_template(self, setting, name, indent, middle):
+        """Build a line for crypto_config.h for the given setting.
+
+        The line has the form "<indent>#define <name> <value>"
+        where <middle> is "#define <name> ". A setting whose value is None
+        is written with the value 1, and an inactive setting is commented
+        out with "//".
+        """
+        value = setting.value
+        if value is None:
+            value = '1'
+        if middle[-1] not in '\t ':
+            middle += ' '
+        return ''.join([indent,
+                        '' if setting.active else '//',
+                        middle,
+                        value]).rstrip()
+
+class MbedtlsConfig(Config):
+    """Representation of the Mbed TLS configuration read for a file.
+
+    See the documentation of the `Config` class for methods to query
+    and modify the configuration.
+    """
+    def __init__(self, mbedtls_config=None, **kw):
+        """Read the Mbed TLS configuration file.
+
+        mbedtls_config is the path of the file to read, or None to search
+        MbedtlsConfigFile.default_path. Other keyword arguments are passed
+        on to the next class in the MRO (e.g. CryptoConfig in MultiConfig).
+        """
+        super().__init__(**kw)
+        self.mbedtls_config = MbedtlsConfigFile(mbedtls_config)
+        self.settings.update({name: Setting(active, name, value, section, self.mbedtls_config)
+                                for (active, name, value, section)
+                                in self.mbedtls_config.parse_file()})
+
+    def set(self, name, value=None):
+        """Set name, adding a #define to the Mbed TLS file if it is new."""
+        if name not in self.settings:
+            self.mbedtls_config.templates.append((name, '', '#define ' + name + ' '))
+        # Call Config.set directly: in MultiConfig, super() would resolve to
+        # CryptoConfig.set, which would also add a template to the crypto file.
+        Config.set(self, name, value)
 
     def write(self, filename=None):
-        """Write the whole configuration to the file it was read from.
+        """Write the Mbed TLS configuration to its file, or to filename."""
+        self.mbedtls_config.write(self.settings, filename)
 
-        If filename is specified, write to this file instead.
-        """
-        if filename is None:
-            filename = self.filename
-        with open(filename, 'w', encoding='utf-8') as output:
-            self.write_to_stream(output)
+    def filename(self, name=None):
+        """Return the path of the Mbed TLS configuration file.
+
+        name is accepted for compatibility with MultiConfig.filename and
+        is ignored.
+        """
+        return self.mbedtls_config.filename
+
+class CryptoConfig(Config):
+    """Representation of the PSA crypto configuration read for a file.
+
+    See the documentation of the `Config` class for methods to query
+    and modify the configuration.
+    """
+    def __init__(self, crypto_config=None, **kw):
+        """Read the PSA crypto configuration file.
+
+        crypto_config is the path of the file to read, or None to search
+        CryptoConfigFile.default_path. Other keyword arguments are passed
+        on to the next class in the MRO.
+        """
+        super().__init__(**kw)
+        self.crypto_config = CryptoConfigFile(crypto_config)
+        self.settings.update({name: Setting(active, name, value, section, self.crypto_config)
+                                for (active, name, value, section)
+                                in self.crypto_config.parse_file()})
+
+    def set(self, name, value=None):
+        """Set name, adding a #define to the crypto file if it is new.
+
+        Raise ValueError if name is an unsupported or unstable feature.
+        """
+        if name in UNSUPPORTED_FEATURE:
+            raise ValueError('Feature is unsupported: \'{}\''.format(name))
+        if name in UNSTABLE_FEATURE:
+            raise ValueError('Feature is unstable: \'{}\''.format(name))
+
+        if name not in self.settings:
+            # The value is added by _format_template, so the template must
+            # end at "#define <name> " (a trailing "1" here would be doubled).
+            self.crypto_config.templates.append((name, '', '#define ' + name + ' '))
+        super().set(name, value)
+
+    def write(self, filename=None):
+        """Write the crypto configuration to its file, or to filename."""
+        self.crypto_config.write(self.settings, filename)
+
+    def filename(self, name=None):
+        """Return the path of the crypto configuration file.
+
+        name is accepted for compatibility with MultiConfig.filename and
+        is ignored.
+        """
+        return self.crypto_config.filename
+
+class MultiConfig(MbedtlsConfig, CryptoConfig):
+    """Combined Mbed TLS and PSA crypto configuration.
+
+    Settings from both files share one `settings` dictionary. set() and
+    filename() route symbols whose name starts with "PSA_" to the crypto
+    configuration and all other symbols to the Mbed TLS configuration.
+    """
+
+    def __init__(self, mbedtls_config, crypto_config):
+        super().__init__(mbedtls_config=mbedtls_config, crypto_config=crypto_config)
+
+    _crypto_regexp = re.compile(r'^PSA_')
+    def _get_related_config(self, name):
+        """Return the class (MbedtlsConfig or CryptoConfig) that owns name."""
+        if re.match(self._crypto_regexp, name):
+            return CryptoConfig
+        else:
+            return MbedtlsConfig
+
+    def set(self, name, value=None):
+        """Set name in the configuration file that owns it."""
+        # Call the owner's method explicitly: super(cls, self) would start the
+        # lookup *after* cls in the MRO and so skip cls.set itself.
+        self._get_related_config(name).set(self, name, value)
+
+    def write(self, mbedtls_file=None, crypto_file=None):
+        """Write both configuration files, optionally to other paths."""
+        self.mbedtls_config.write(self.settings, mbedtls_file)
+        self.crypto_config.write(self.settings, crypto_file)
+
+    def filename(self, name):
+        """Return the path of the configuration file that owns name.
+
+        This also works for a name that is not in the configuration yet.
+        """
+        return self._get_related_config(name).filename(self, name)
 
 if __name__ == '__main__':
     def main():
@@ -498,7 +630,11 @@
         parser.add_argument('--file', '-f',
                             help="""File to read (and modify if requested).
                             Default: {}.
-                            """.format(ConfigFile.default_path))
+                            """.format(MbedtlsConfigFile.default_path))
+        parser.add_argument('--cryptofile', '-c',
+                            help="""Crypto file to read (and modify if requested).
+                            Default: {}.
+                            """.format(CryptoConfigFile.default_path))
         parser.add_argument('--force', '-o',
                             action='store_true',
                             help="""For the set command, if SYMBOL is not
@@ -576,7 +712,7 @@
                     excluding X.509 and TLS.""")
 
         args = parser.parse_args()
-        config = ConfigFile(args.file)
+        config = MultiConfig(args.file, args.cryptofile)
         if args.command is None:
             parser.print_help()
             return 1
@@ -590,7 +726,7 @@
             if not args.force and args.symbol not in config.settings:
                 sys.stderr.write("A #define for the symbol {} "
                                  "was not found in {}\n"
-                                 .format(args.symbol, config.filename))
+                                 .format(args.symbol, config.filename(args.symbol)))
                 return 1
             config.set(args.symbol, value=args.value)
         elif args.command == 'set-all':