Add test function for concurrently using the same persistent key

The thread functions can also be used in future tests for other key types
and other test scenarios

Signed-off-by: Ryan Everett <ryan.everett@arm.com>
diff --git a/tests/suites/test_suite_psa_crypto.function b/tests/suites/test_suite_psa_crypto.function
index 8fb7d44..6cb0744 100644
--- a/tests/suites/test_suite_psa_crypto.function
+++ b/tests/suites/test_suite_psa_crypto.function
@@ -1338,6 +1338,127 @@
 }
 
 #if defined(MBEDTLS_THREADING_PTHREAD)
+
+typedef struct same_key_context {
+    data_t *data;
+    mbedtls_svc_key_id_t key;
+    psa_key_attributes_t *attributes;
+    int type;
+    int bits;
+    /* The following two parameters are used to ensure that when multiple
+     * threads attempt to load/destroy the key, exactly one thread succeeds. */
+    int key_loaded;
+    mbedtls_threading_mutex_t MBEDTLS_PRIVATE(key_loaded_mutex);
+}
+same_key_context;
+
+/* Attempt to import the key in ctx. This handles any valid error codes
+ * and reports an error for any invalid codes. This function also insures
+ * that once imported by some thread, all threads can use the key. */
+void *thread_import_key(void *ctx)
+{
+    mbedtls_svc_key_id_t returned_key_id;
+    same_key_context *skc = (struct same_key_context *) ctx;
+    psa_key_attributes_t got_attributes = PSA_KEY_ATTRIBUTES_INIT;
+
+    /* Import the key, exactly one thread must succceed. */
+    psa_status_t status = psa_import_key(skc->attributes, skc->data->x,
+                                         skc->data->len, &returned_key_id);
+    switch (status) {
+        case PSA_SUCCESS:
+            if (mbedtls_mutex_lock(&skc->key_loaded_mutex) == 0) {
+                if (skc->key_loaded) {
+                    mbedtls_mutex_unlock(&skc->key_loaded_mutex);
+                    /* More than one thread has succeeded, report a failure. */
+                    TEST_EQUAL(skc->key_loaded, 0);
+                }
+                skc->key_loaded = 1;
+                mbedtls_mutex_unlock(&skc->key_loaded_mutex);
+            }
+            break;
+        case PSA_ERROR_INSUFFICIENT_MEMORY:
+            /* If all of the key slots are reserved when a thread
+             * locks the mutex to reserve a new slot, it will return
+             * PSA_ERROR_INSUFFICIENT_MEMORY; this is correct behaviour.
+             * There is a chance for this to occur here when the number of
+             * threads running this function is larger than the number of
+             * free key slots. Each thread reserves an empty key slot,
+             * unlocks the mutex, then relocks it to finalize key creation.
+             * It is at that point where the thread sees that the key
+             * already exists, releases the reserved slot,
+             * and returns PSA_ERROR_ALREADY_EXISTS.
+             * There is no guarantee that the key is loaded upon this return
+             * code, so we can't test the key information. Just stop this
+             * thread from executing, note that this is not an error. */
+            goto exit;
+            break;
+        case PSA_ERROR_ALREADY_EXISTS:
+            /* The key has been loaded by a different thread. */
+            break;
+        default:
+            PSA_ASSERT(status);
+    }
+    /* At this point the key must exist, test the key information. */
+    status = psa_get_key_attributes(skc->key, &got_attributes);
+    if (status == PSA_ERROR_INSUFFICIENT_MEMORY) {
+        /* This is not a test failure. The following sequence of events
+         * causes this to occur:
+         * 1: This thread successfuly imports a persistent key skc->key.
+         * 2: N threads reserve an empty key slot in psa_import_key,
+         *    where N is equal to the number of free key slots.
+         * 3: A final thread attempts to reserve an empty key slot, kicking
+         *    skc->key (which has no registered readers) out of its slot.
+         * 4: This thread calls psa_get_key_attributes(skc->key,...):
+         *    it sees that skc->key is not in a slot, attempts to load it and
+         *    finds that there are no free slots.
+         * This thread returns PSA_ERROR_INSUFFICIENT_MEMORY.
+         *
+         * The PSA spec allows this behaviour, it is an unavoidable consequence
+         * of allowing persistent keys to be kicked out of the key store while
+         * they are still valid. */
+        goto exit;
+    }
+    PSA_ASSERT(status);
+    TEST_EQUAL(psa_get_key_type(&got_attributes), skc->type);
+    TEST_EQUAL(psa_get_key_bits(&got_attributes), skc->bits);
+
+exit:
+    /* Key attributes may have been returned by psa_get_key_attributes(),
+     * reset them as required. */
+    psa_reset_key_attributes(&got_attributes);
+    return NULL;
+}
+
+void *thread_use_and_destroy_key(void *ctx)
+{
+    same_key_context *skc = (struct same_key_context *) ctx;
+
+    /* Do something with the key according
+     * to its type and permitted usage. */
+    TEST_ASSERT(mbedtls_test_psa_exercise_key(skc->key,
+                                              skc->attributes->policy.usage,
+                                              skc->attributes->policy.alg, 1));
+
+    psa_status_t status = psa_destroy_key(skc->key);
+    if (status == PSA_SUCCESS) {
+        if (mbedtls_mutex_lock(&skc->key_loaded_mutex) == 0) {
+            /* Ensure that we are the only thread to succeed. */
+            if (skc->key_loaded != 1) {
+                mbedtls_mutex_unlock(&skc->key_loaded_mutex);
+                //Will always fail
+                TEST_EQUAL(skc->key_loaded, 1);
+            }
+            skc->key_loaded = 0;
+            mbedtls_mutex_unlock(&skc->key_loaded_mutex);
+        }
+    } else {
+        TEST_EQUAL(status, PSA_ERROR_INVALID_HANDLE);
+    }
+
+exit:
+    return NULL;
+}
+
 typedef struct generate_key_context {
     psa_key_type_t type;
     psa_key_usage_t usage;
@@ -1824,6 +1945,78 @@
 }
 /* END_CASE */
 
+
+#if defined(MBEDTLS_THREADING_PTHREAD)
+/* BEGIN_CASE depends_on:MBEDTLS_THREADING_PTHREAD:MBEDTLS_PSA_CRYPTO_STORAGE_C */
+void concurrently_use_same_persistent_key(data_t *data,
+                                          int type_arg,
+                                          int bits_arg,
+                                          int alg_arg,
+                                          int thread_count_arg)
+{
+    size_t thread_count = (size_t) thread_count_arg;
+    mbedtls_test_thread_t *threads = NULL;
+    mbedtls_svc_key_id_t key_id = mbedtls_svc_key_id_make(1, 1);
+    same_key_context skc;
+    skc.data = data;
+    skc.key = key_id;
+    skc.type = type_arg;
+    skc.bits = bits_arg;
+    skc.key_loaded = 0;
+    mbedtls_mutex_init(&skc.key_loaded_mutex);
+    psa_key_usage_t usage = mbedtls_test_psa_usage_to_exercise(skc.type, alg_arg);
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+
+    PSA_ASSERT(psa_crypto_init());
+
+    psa_set_key_id(&attributes, key_id);
+    psa_set_key_lifetime(&attributes, PSA_KEY_LIFETIME_PERSISTENT);
+    psa_set_key_usage_flags(&attributes, usage);
+    psa_set_key_algorithm(&attributes, alg_arg);
+    psa_set_key_type(&attributes, type_arg);
+    psa_set_key_bits(&attributes, bits_arg);
+    skc.attributes = &attributes;
+
+    TEST_CALLOC(threads, sizeof(mbedtls_test_thread_t) * thread_count);
+
+    /* Test that when multiple threads import the same key,
+     * exactly one thread succeeds and the rest fail with valid errors.
+     * Also test that all threads can use the key as soon as it has been
+     * imported. */
+    for (size_t i = 0; i < thread_count; i++) {
+        TEST_EQUAL(
+            mbedtls_test_thread_create(&threads[i], thread_import_key,
+                                       (void *) &skc), 0);
+    }
+
+    /* Join threads. */
+    for (size_t i = 0; i < thread_count; i++) {
+        TEST_EQUAL(mbedtls_test_thread_join(&threads[i]), 0);
+    }
+
+    /* Test that when multiple threads use and destroy a key no corruption
+     * occurs, and exactly one thread succeeds when destroying the key. */
+    for (size_t i = 0; i < thread_count; i++) {
+        TEST_EQUAL(
+            mbedtls_test_thread_create(&threads[i], thread_use_and_destroy_key,
+                                       (void *) &skc), 0);
+    }
+
+    /* Join threads. */
+    for (size_t i = 0; i < thread_count; i++) {
+        TEST_EQUAL(mbedtls_test_thread_join(&threads[i]), 0);
+    }
+    /* Ensure that one thread succeeded in destroying the key. */
+    TEST_ASSERT(!skc.key_loaded);
+exit:
+    psa_reset_key_attributes(&attributes);
+    mbedtls_mutex_free(&skc.key_loaded_mutex);
+    mbedtls_free(threads);
+    PSA_DONE();
+}
+/* END_CASE */
+#endif
+
 /* BEGIN_CASE */
 void import_and_exercise_key(data_t *data,
                              int type_arg,