Add knowledge of algorithms

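Add an Algorithm class that determines the category of operations
supported by an algorithm (hash, MAC, cipher, ...) based on its name,
and whether the algorithm expression is a wildcard.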

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/scripts/mbedtls_dev/crypto_knowledge.py b/scripts/mbedtls_dev/crypto_knowledge.py
index 80ad4b2..05db887 100644
--- a/scripts/mbedtls_dev/crypto_knowledge.py
+++ b/scripts/mbedtls_dev/crypto_knowledge.py
@@ -18,11 +18,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import enum
 import re
 from typing import Dict, Iterable, Optional, Pattern, Tuple
 
 from mbedtls_dev.asymmetric_key_data import ASYMMETRIC_KEY_DATA
 
+
+BLOCK_MAC_MODES = frozenset(['CBC_MAC', 'CMAC'])
+BLOCK_CIPHER_MODES = frozenset([
+    'CTR', 'CFB', 'OFB', 'XTS', 'CCM_STAR_NO_TAG',
+    'ECB_NO_PADDING', 'CBC_NO_PADDING', 'CBC_PKCS7',
+])
+BLOCK_AEAD_MODES = frozenset(['CCM', 'GCM'])
+
+
 class KeyType:
     """Knowledge about a PSA key type."""
 
@@ -151,3 +161,144 @@
         """
         # This is just temporaly solution for the implicit usage flags.
         return re.match(self.KEY_TYPE_FOR_SIGNATURE[usage], self.name) is not None
+
+
+class AlgorithmCategory(enum.Enum):
+    """PSA algorithm categories."""
+    # The numbers are the values of the category bits (PSA_ALG_CATEGORY_MASK)
+    # in the numerical encoding of algorithms.
+    HASH = 2
+    MAC = 3
+    CIPHER = 4
+    AEAD = 5
+    SIGN = 6
+    ASYMMETRIC_ENCRYPTION = 7
+    KEY_DERIVATION = 8
+    KEY_AGREEMENT = 9
+    PAKE = 10
+
+    def requires_key(self) -> bool:
+        return self not in {self.HASH, self.KEY_DERIVATION}
+
+
+class AlgorithmNotRecognized(Exception):
+    def __init__(self, expr: str) -> None:
+        super().__init__('Algorithm not recognized: ' + expr)
+        self.expr = expr
+
+
+class Algorithm:
+    """Knowledge about a PSA algorithm."""
+
+    @staticmethod
+    def determine_base(expr: str) -> str:
+        """Return an expression for the "base" of the algorithm.
+
+        This strips off variants such as MAC truncation or AEAD tag shortening.
+
+        This function does not attempt to detect invalid inputs.
+        """
+        m = re.match(r'PSA_ALG_(?:'
+                     r'(?:TRUNCATED|AT_LEAST_THIS_LENGTH)_MAC|'
+                     r'AEAD_WITH_(?:SHORTENED|AT_LEAST_THIS_LENGTH)_TAG'
+                     r')\((.*),[^,]+\)\Z', expr)
+        if m:
+            expr = m.group(1)
+        return expr
+
+    @staticmethod
+    def determine_head(expr: str) -> str:
+        """Return the head of an algorithm expression.
+
+        The head is the first (outermost) constructor, without its PSA_ALG_
+        prefix, and with some normalization of similar algorithms.
+        """
+        m = re.match(r'PSA_ALG_(?:DETERMINISTIC_)?(\w+)', expr)
+        if not m:
+            raise AlgorithmNotRecognized(expr)
+        head = m.group(1)
+        if head == 'KEY_AGREEMENT':
+            m = re.match(r'PSA_ALG_KEY_AGREEMENT\s*\(\s*PSA_ALG_(\w+)', expr)
+            if not m:
+                raise AlgorithmNotRecognized(expr)
+            head = m.group(1)
+        head = re.sub(r'_ANY\Z', r'', head)
+        if re.match(r'ED[0-9]+PH\Z', head):
+            head = 'EDDSA_PREHASH'
+        return head
+
+    CATEGORY_FROM_HEAD = {
+        'SHA': AlgorithmCategory.HASH,
+        'SHAKE256_512': AlgorithmCategory.HASH,
+        'MD': AlgorithmCategory.HASH,
+        'RIPEMD': AlgorithmCategory.HASH,
+        'ANY_HASH': AlgorithmCategory.HASH,
+        'HMAC': AlgorithmCategory.MAC,
+        'STREAM_CIPHER': AlgorithmCategory.CIPHER,
+        'CHACHA20_POLY1305': AlgorithmCategory.AEAD,
+        'DSA': AlgorithmCategory.SIGN,
+        'ECDSA': AlgorithmCategory.SIGN,
+        'EDDSA': AlgorithmCategory.SIGN,
+        'PURE_EDDSA': AlgorithmCategory.SIGN,
+        'RSA_PSS': AlgorithmCategory.SIGN,
+        'RSA_PKCS1V15_SIGN': AlgorithmCategory.SIGN,
+        'RSA_PKCS1V15_CRYPT': AlgorithmCategory.ASYMMETRIC_ENCRYPTION,
+        'RSA_OAEP': AlgorithmCategory.ASYMMETRIC_ENCRYPTION,
+        'HKDF': AlgorithmCategory.KEY_DERIVATION,
+        'TLS12_PRF': AlgorithmCategory.KEY_DERIVATION,
+        'TLS12_PSK_TO_MS': AlgorithmCategory.KEY_DERIVATION,
+        'PBKDF': AlgorithmCategory.KEY_DERIVATION,
+        'ECDH': AlgorithmCategory.KEY_AGREEMENT,
+        'FFDH': AlgorithmCategory.KEY_AGREEMENT,
+        # KEY_AGREEMENT(...) is a key derivation with a key agreement component
+        'KEY_AGREEMENT': AlgorithmCategory.KEY_DERIVATION,
+        'JPAKE': AlgorithmCategory.PAKE,
+    }
+    for x in BLOCK_MAC_MODES:
+        CATEGORY_FROM_HEAD[x] = AlgorithmCategory.MAC
+    for x in BLOCK_CIPHER_MODES:
+        CATEGORY_FROM_HEAD[x] = AlgorithmCategory.CIPHER
+    for x in BLOCK_AEAD_MODES:
+        CATEGORY_FROM_HEAD[x] = AlgorithmCategory.AEAD
+
+    def determine_category(self, expr: str, head: str) -> AlgorithmCategory:
+        """Return the category of the given algorithm expression.
+
+        This function does not attempt to detect invalid inputs.
+        """
+        prefix = head
+        while prefix:
+            if prefix in self.CATEGORY_FROM_HEAD:
+                return self.CATEGORY_FROM_HEAD[prefix]
+            if re.match(r'.*[0-9]\Z', prefix):
+                prefix = re.sub(r'_*[0-9]+\Z', r'', prefix)
+            else:
+                prefix = re.sub(r'_*[^_]*\Z', r'', prefix)
+        raise AlgorithmNotRecognized(expr)
+
+    @staticmethod
+    def determine_wildcard(expr) -> bool:
+        """Whether the given algorithm expression is a wildcard.
+
+        This function does not attempt to detect invalid inputs.
+        """
+        if re.search(r'\bPSA_ALG_ANY_HASH\b', expr):
+            return True
+        if re.search(r'_AT_LEAST_', expr):
+            return True
+        return False
+
+    def __init__(self, expr: str) -> None:
+        """Analyze an algorithm value.
+
+        The algorithm must be expressed as a C expression containing only
+        calls to PSA algorithm constructor macros and numeric literals.
+
+        This class is only programmed to handle valid expressions. Invalid
+        expressions may result in exceptions or in nonsensical results.
+        """
+        self.expression = re.sub(r'\s+', r'', expr)
+        self.base_expression = self.determine_base(self.expression)
+        self.head = self.determine_head(self.base_expression)
+        self.category = self.determine_category(self.base_expression, self.head)
+        self.is_wildcard = self.determine_wildcard(self.expression)