Enable support for user/peer for JPAKE

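This is only partial support: the only values accepted for user and peer
are "client" and "server", and one of each is required (user "client" with
peer "server" selects the client role, and vice versa).
Remove support for psa_pake_set_role(): it no longer stores a role and
returns PSA_ERROR_NOT_SUPPORTED; the role is derived from user/peer instead.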

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index 5724025..31df082 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7323,7 +7323,6 @@
     size_t user_id_len)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    (void) user_id;
 
     if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
         status = PSA_ERROR_BAD_STATE;
@@ -7335,7 +7334,28 @@
         goto exit;
     }
 
-    return PSA_ERROR_NOT_SUPPORTED;
+    if (operation->data.inputs.user_len != 0) {
+        status = PSA_ERROR_BAD_STATE;
+        goto exit;
+    }
+
+    /* Accept exactly "client" or "server"; assumes the ID macros are C strings. */
+    if ((user_id_len != strlen((const char *) PSA_JPAKE_SERVER_ID) ||
+         memcmp(user_id, PSA_JPAKE_SERVER_ID, user_id_len) != 0) &&
+        (user_id_len != strlen((const char *) PSA_JPAKE_CLIENT_ID) ||
+         memcmp(user_id, PSA_JPAKE_CLIENT_ID, user_id_len) != 0)) {
+        status = PSA_ERROR_NOT_SUPPORTED;
+        goto exit;
+    }
+
+    operation->data.inputs.user = mbedtls_calloc(1, user_id_len);
+    if (operation->data.inputs.user == NULL) {
+        status = PSA_ERROR_INSUFFICIENT_MEMORY;
+        goto exit;
+    }
+    memcpy(operation->data.inputs.user, user_id, user_id_len);
+    operation->data.inputs.user_len = user_id_len;
+    return PSA_SUCCESS;
 exit:
     psa_pake_abort(operation);
     return status;
@@ -7347,7 +7367,6 @@
     size_t peer_id_len)
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    (void) peer_id;
 
     if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
         status = PSA_ERROR_BAD_STATE;
@@ -7359,7 +7378,28 @@
         goto exit;
     }
 
-    return PSA_ERROR_NOT_SUPPORTED;
+    if (operation->data.inputs.peer_len != 0) {
+        status = PSA_ERROR_BAD_STATE;
+        goto exit;
+    }
+
+    /* Accept exactly "client" or "server"; assumes the ID macros are C strings. */
+    if ((peer_id_len != strlen((const char *) PSA_JPAKE_SERVER_ID) ||
+         memcmp(peer_id, PSA_JPAKE_SERVER_ID, peer_id_len) != 0) &&
+        (peer_id_len != strlen((const char *) PSA_JPAKE_CLIENT_ID) ||
+         memcmp(peer_id, PSA_JPAKE_CLIENT_ID, peer_id_len) != 0)) {
+        status = PSA_ERROR_NOT_SUPPORTED;
+        goto exit;
+    }
+
+    operation->data.inputs.peer = mbedtls_calloc(1, peer_id_len);
+    if (operation->data.inputs.peer == NULL) {
+        status = PSA_ERROR_INSUFFICIENT_MEMORY;
+        goto exit;
+    }
+    memcpy(operation->data.inputs.peer, peer_id, peer_id_len);
+    operation->data.inputs.peer_len = peer_id_len;
+    return PSA_SUCCESS;
 exit:
     psa_pake_abort(operation);
     return status;
@@ -7372,7 +7412,7 @@
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
 
     if (operation->stage != PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
-        status =  PSA_ERROR_BAD_STATE;
+        status = PSA_ERROR_BAD_STATE;
         goto exit;
     }
 
@@ -7385,9 +7425,7 @@
         goto exit;
     }
 
-    operation->data.inputs.role = role;
-
-    return PSA_SUCCESS;
+    status = PSA_ERROR_NOT_SUPPORTED;
 exit:
     psa_pake_abort(operation);
     return status;
@@ -7458,14 +7496,25 @@
     psa_crypto_driver_pake_inputs_t inputs = operation->data.inputs;
 
     if (inputs.password_len == 0 ||
-        inputs.role == PSA_PAKE_ROLE_NONE) {
+        inputs.user_len == 0 ||
+        inputs.peer_len == 0) {
         return PSA_ERROR_BAD_STATE;
     }
 
-    if (operation->alg == PSA_ALG_JPAKE &&
-        inputs.role != PSA_PAKE_ROLE_CLIENT &&
-        inputs.role != PSA_PAKE_ROLE_SERVER) {
-        return PSA_ERROR_NOT_SUPPORTED;
+    if (operation->alg == PSA_ALG_JPAKE) {
+        if (memcmp(inputs.user, PSA_JPAKE_CLIENT_ID, inputs.user_len) == 0 &&
+            memcmp(inputs.peer, PSA_JPAKE_SERVER_ID, inputs.peer_len) == 0) {
+            inputs.role = PSA_PAKE_ROLE_CLIENT;
+        } else
+        if (memcmp(inputs.user, PSA_JPAKE_SERVER_ID, inputs.user_len) == 0 &&
+            memcmp(inputs.peer, PSA_JPAKE_CLIENT_ID, inputs.peer_len) == 0) {
+            inputs.role = PSA_PAKE_ROLE_SERVER;
+        }
+
+        if (inputs.role != PSA_PAKE_ROLE_CLIENT &&
+            inputs.role != PSA_PAKE_ROLE_SERVER) {
+            return PSA_ERROR_NOT_SUPPORTED;
+        }
     }
 
     /* Clear driver context */
@@ -7477,6 +7526,10 @@
     mbedtls_platform_zeroize(inputs.password, inputs.password_len);
     mbedtls_free(inputs.password);
 
+    /* User and peer are translated to role. */
+    mbedtls_free(inputs.user);
+    mbedtls_free(inputs.peer);
+
     if (status == PSA_SUCCESS) {
 #if defined(PSA_WANT_ALG_JPAKE)
         if (operation->alg == PSA_ALG_JPAKE) {
@@ -7885,13 +7938,20 @@
         status = psa_driver_wrapper_pake_abort(operation);
     }
 
-    if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS &&
-        operation->data.inputs.password != NULL) {
-        mbedtls_platform_zeroize(operation->data.inputs.password,
-                                 operation->data.inputs.password_len);
-        mbedtls_free(operation->data.inputs.password);
+    if (operation->stage == PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS) {
+        if (operation->data.inputs.password != NULL) {
+            mbedtls_platform_zeroize(operation->data.inputs.password,
+                                     operation->data.inputs.password_len);
+            mbedtls_free(operation->data.inputs.password);
+        }
+        if (operation->data.inputs.user != NULL) {
+            mbedtls_free(operation->data.inputs.user);
+        }
+        if (operation->data.inputs.peer != NULL) {
+            mbedtls_free(operation->data.inputs.peer);
+        }
     }
 
     memset(operation, 0, sizeof(psa_pake_operation_t));
 
     return status;