Add engine field to context structure

For multi-part operations, we want to make the decision to use PSA or
not only once, during setup(), and remember it afterwards. This supports
the introduction, in the next few commits, of a dynamic component to
that decision: has the PSA driver sub-system been initialized yet?

Signed-off-by: Manuel Pégourié-Gonnard <manuel.pegourie-gonnard@arm.com>
diff --git a/include/mbedtls/md.h b/include/mbedtls/md.h
index bd44b64..ada7ad9 100644
--- a/include/mbedtls/md.h
+++ b/include/mbedtls/md.h
@@ -181,13 +181,28 @@
 typedef struct mbedtls_md_info_t mbedtls_md_info_t;
 
 /**
+ * Used internally to indicate whether a context uses legacy or PSA.
+ *
+ * Internal use only.
+ */
+typedef enum {
+    MBEDTLS_MD_ENGINE_LEGACY = 0,
+    MBEDTLS_MD_ENGINE_PSA,
+} mbedtls_md_engine_t;
+
+/**
  * The generic message-digest context.
  */
 typedef struct mbedtls_md_context_t {
     /** Information about the associated message digest. */
     const mbedtls_md_info_t *MBEDTLS_PRIVATE(md_info);
 
-    /** The digest-specific context. */
+#if defined(MBEDTLS_MD_SOME_PSA)
+    /** Are hash operations dispatched to PSA or legacy? */
+    mbedtls_md_engine_t MBEDTLS_PRIVATE(engine);
+#endif
+
+    /** The digest-specific context (legacy) or the PSA operation. */
     void *MBEDTLS_PRIVATE(md_ctx);
 
     /** The HMAC part of the context. */
diff --git a/library/md.c b/library/md.c
index 20bfd23..5b61b51 100644
--- a/library/md.c
+++ b/library/md.c
@@ -222,6 +222,7 @@
 
 void mbedtls_md_init(mbedtls_md_context_t *ctx)
 {
+    /* Note: this sets engine (if present) to MBEDTLS_MD_ENGINE_LEGACY */
     memset(ctx, 0, sizeof(mbedtls_md_context_t));
 }
 
@@ -233,7 +234,7 @@
 
     if (ctx->md_ctx != NULL) {
 #if defined(MBEDTLS_MD_SOME_PSA)
-        if (md_uses_psa(ctx->md_info) && ctx->md_ctx != NULL) {
+        if (ctx->engine == MBEDTLS_MD_ENGINE_PSA) {
             psa_hash_abort(ctx->md_ctx);
         } else
 #endif
@@ -299,7 +300,15 @@
     }
 
 #if defined(MBEDTLS_MD_SOME_PSA)
-    if (md_uses_psa(src->md_info)) {
+    if (src->engine != dst->engine) {
+        /* This can happen with src set to legacy because PSA wasn't ready
+         * yet, and dst to PSA because it became ready in the meantime.
+         * We currently don't support that case (we'd need to re-allocate
+         * md_ctx to the size of the appropriate MD context). */
+        return MBEDTLS_ERR_MD_FEATURE_UNAVAILABLE;
+    }
+
+    if (src->engine == MBEDTLS_MD_ENGINE_PSA) {
         psa_status_t status = psa_hash_clone(src->md_ctx, dst->md_ctx);
         return mbedtls_md_error_from_psa(status);
     }
@@ -373,6 +382,7 @@
         if (ctx->md_ctx == NULL) {
             return MBEDTLS_ERR_MD_ALLOC_FAILED;
         }
+        ctx->engine = MBEDTLS_MD_ENGINE_PSA;
     } else
 #endif
     switch (md_info->type) {
@@ -434,8 +444,8 @@
     }
 
 #if defined(MBEDTLS_MD_SOME_PSA)
-    psa_algorithm_t alg = psa_alg_of_md(ctx->md_info);
-    if (alg != PSA_ALG_NONE) {
+    if (ctx->engine == MBEDTLS_MD_ENGINE_PSA) {
+        psa_algorithm_t alg = psa_alg_of_md(ctx->md_info);
         psa_hash_abort(ctx->md_ctx);
         psa_status_t status = psa_hash_setup(ctx->md_ctx, alg);
         return mbedtls_md_error_from_psa(status);
@@ -483,7 +493,7 @@
     }
 
 #if defined(MBEDTLS_MD_SOME_PSA)
-    if (md_uses_psa(ctx->md_info)) {
+    if (ctx->engine == MBEDTLS_MD_ENGINE_PSA) {
         psa_status_t status = psa_hash_update(ctx->md_ctx, input, ilen);
         return mbedtls_md_error_from_psa(status);
     }
@@ -530,7 +540,7 @@
     }
 
 #if defined(MBEDTLS_MD_SOME_PSA)
-    if (md_uses_psa(ctx->md_info)) {
+    if (ctx->engine == MBEDTLS_MD_ENGINE_PSA) {
         size_t size = ctx->md_info->size;
         psa_status_t status = psa_hash_finish(ctx->md_ctx,
                                               output, size, &size);
@@ -580,10 +590,9 @@
     }
 
 #if defined(MBEDTLS_MD_SOME_PSA)
-    psa_algorithm_t alg = psa_alg_of_md(md_info);
-    if (alg != PSA_ALG_NONE) {
+    if (md_uses_psa(md_info)) {
         size_t size = md_info->size;
-        psa_status_t status = psa_hash_compute(alg,
+        psa_status_t status = psa_hash_compute(psa_alg_of_md(md_info),
                                                input, ilen,
                                                output, size, &size);
         return mbedtls_md_error_from_psa(status);