psa_constant_names: support key agreement algorithms
diff --git a/programs/psa/psa_constant_names.c b/programs/psa/psa_constant_names.c
index 5514100..5240b08 100644
--- a/programs/psa/psa_constant_names.c
+++ b/programs/psa/psa_constant_names.c
@@ -84,22 +84,21 @@
     append(buffer, buffer_size, required_size, ")", 1);
 }
 
-static void append_with_hash(char **buffer, size_t buffer_size,
-                             size_t *required_size,
-                             const char *string, size_t length,
-                             psa_algorithm_t hash_alg)
+typedef const char *(*psa_get_algorithm_name_func_ptr)(psa_algorithm_t alg);
+
+static void append_with_alg(char **buffer, size_t buffer_size,
+                            size_t *required_size,
+                            psa_get_algorithm_name_func_ptr get_name,
+                            psa_algorithm_t alg)
 {
-    const char *hash_name = psa_hash_algorithm_name(hash_alg);
-    append(buffer, buffer_size, required_size, string, length);
-    append(buffer, buffer_size, required_size, "(", 1);
-    if (hash_name != NULL) {
+    const char *name = get_name(alg);
+    if (name != NULL) {
         append(buffer, buffer_size, required_size,
-               hash_name, strlen(hash_name));
+               name, strlen(name));
     } else {
         append_integer(buffer, buffer_size, required_size,
-                       "0x%08lx", hash_alg);
+                       "0x%08lx", alg);
     }
-    append(buffer, buffer_size, required_size, ")", 1);
 }
 
 #include "psa_constant_names_generated.c"
diff --git a/scripts/generate_psa_constants.py b/scripts/generate_psa_constants.py
index 382fd23..dac6003 100755
--- a/scripts/generate_psa_constants.py
+++ b/scripts/generate_psa_constants.py
@@ -30,6 +30,14 @@
     }
 }
 
+static const char *psa_ka_algorithm_name(psa_algorithm_t ka_alg)
+{
+    switch (ka_alg) {
+    %(ka_algorithm_cases)s
+    default: return NULL;
+    }
+}
+
 static int psa_snprint_key_type(char *buffer, size_t buffer_size,
                                 psa_key_type_t type)
 {
@@ -47,12 +55,13 @@
     return (int) required_size;
 }
 
+#define NO_LENGTH_MODIFIER 0xfffffffflu
 static int psa_snprint_algorithm(char *buffer, size_t buffer_size,
                                  psa_algorithm_t alg)
 {
     size_t required_size = 0;
     psa_algorithm_t core_alg = alg;
-    unsigned long length_modifier = 0;
+    unsigned long length_modifier = NO_LENGTH_MODIFIER;
     if (PSA_ALG_IS_MAC(alg)) {
         core_alg = PSA_ALG_TRUNCATED_MAC(alg, 0);
         if (core_alg != alg) {
@@ -70,6 +79,15 @@
                    "PSA_ALG_AEAD_WITH_TAG_LENGTH(", 29);
             length_modifier = PSA_AEAD_TAG_LENGTH(alg);
         }
+    } else if (PSA_ALG_IS_KEY_AGREEMENT(alg) &&
+               !PSA_ALG_IS_RAW_KEY_AGREEMENT(alg)) {
+        core_alg = PSA_ALG_KEY_AGREEMENT_GET_KDF(alg);
+        append(&buffer, buffer_size, &required_size,
+               "PSA_ALG_KEY_AGREEMENT(", 22);
+        append_with_alg(&buffer, buffer_size, &required_size,
+                        psa_ka_algorithm_name,
+                        PSA_ALG_KEY_AGREEMENT_GET_BASE(alg));
+        append(&buffer, buffer_size, &required_size, ", ", 2);
     }
     switch (core_alg) {
     %(algorithm_cases)s
@@ -81,9 +99,11 @@
         break;
     }
     if (core_alg != alg) {
-        append(&buffer, buffer_size, &required_size, ", ", 2);
-        append_integer(&buffer, buffer_size, &required_size,
-                       "%%lu", length_modifier);
+        if (length_modifier != NO_LENGTH_MODIFIER) {
+            append(&buffer, buffer_size, &required_size, ", ", 2);
+            append_integer(&buffer, buffer_size, &required_size,
+                           "%%lu", length_modifier);
+        }
         append(&buffer, buffer_size, &required_size, ")", 1);
     }
     buffer[0] = 0;
@@ -126,9 +146,12 @@
         } else '''
 
 algorithm_from_hash_template = '''if (%(tester)s(core_alg)) {
-            append_with_hash(&buffer, buffer_size, &required_size,
-                             "%(builder)s", %(builder_length)s,
-                             PSA_ALG_GET_HASH(core_alg));
+            append(&buffer, buffer_size, &required_size,
+                   "%(builder)s(", %(builder_length)s + 1);
+            append_with_alg(&buffer, buffer_size, &required_size,
+                            psa_hash_algorithm_name,
+                            PSA_ALG_GET_HASH(core_alg));
+            append(&buffer, buffer_size, &required_size, ")", 1);
         } else '''
 
 bit_test_template = '''\
@@ -149,6 +172,7 @@
         self.ecc_curves = set()
         self.algorithms = set()
         self.hash_algorithms = set()
+        self.ka_algorithms = set()
         self.algorithms_from_hash = {}
         self.key_usages = set()
 
@@ -193,6 +217,9 @@
             # Ad hoc detection of hash algorithms
             if re.search(r'0x010000[0-9A-Fa-f]{2}', definition):
                 self.hash_algorithms.add(name)
+            # Ad hoc detection of key agreement algorithms
+            if re.search(r'0x30[0-9A-Fa-f]{2}0000', definition):
+                self.ka_algorithms.add(name)
         elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
             if name in ['PSA_ALG_DSA', 'PSA_ALG_ECDSA']:
                 # A naming irregularity
@@ -256,6 +283,10 @@
         return '\n    '.join(map(self.make_return_case,
                                  sorted(self.hash_algorithms)))
 
+    def make_ka_algorithm_cases(self):
+        return '\n    '.join(map(self.make_return_case,
+                                 sorted(self.ka_algorithms)))
+
     def make_algorithm_cases(self):
         return '\n    '.join(map(self.make_append_case,
                                  sorted(self.algorithms)))
@@ -281,6 +312,7 @@
         data['key_type_cases'] = self.make_key_type_cases()
         data['key_type_code'] = self.make_key_type_code()
         data['hash_algorithm_cases'] = self.make_hash_algorithm_cases()
+        data['ka_algorithm_cases'] = self.make_ka_algorithm_cases()
         data['algorithm_cases'] = self.make_algorithm_cases()
         data['algorithm_code'] = self.make_algorithm_code()
         data['key_usage_code'] = self.make_key_usage_code()
diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py
index 5e128eb..421cf4e 100755
--- a/tests/scripts/test_psa_constant_names.py
+++ b/tests/scripts/test_psa_constant_names.py
@@ -63,7 +63,8 @@
         # Hard-coded value for unknown algorithms
         self.hash_algorithms = set(['0x010000fe'])
         self.mac_algorithms = set(['0x02ff00ff'])
-        self.kdf_algorithms = set(['0x300000ff', '0x310000ff'])
+        self.ka_algorithms = set(['0x30fc0000'])
+        self.kdf_algorithms = set(['0x200000ff'])
         # For AEAD algorithms, the only variability is over the tag length,
         # and this only applies to known algorithms, so don't test an
         # unknown algorithm.
@@ -89,6 +90,7 @@
 Call this after parsing all the inputs.'''
         self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
         self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
+        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
         self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
         self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
         self.arguments_for['curve'] = sorted(self.ecc_curves)