Fix optional parameter handling in crypto service

The PSA Crypto API requires the optional 'salt' parameter passed
to psa_asymmetric_encrypt/decrypt to be NULL when there is no
salt.  This fix modifies the packed-c client encoding to omit the
salt TLV when no salt is supplied.  On the provider side, the
protobuf deserializer now defaults the salt length to zero and the
provider passes a NULL salt to the PSA API whenever the salt is
missing or zero-length.  A new test case is added to cover
asymmetric encrypt/decrypt with a salt.

Signed-off-by: Julian Hall <julian.hall@arm.com>
Change-Id: I08d23ea11a8e75bd880367dcb380805ce921033f
diff --git a/components/service/crypto/client/cpp/packed-c/packedc_crypto_client.cpp b/components/service/crypto/client/cpp/packed-c/packedc_crypto_client.cpp
index 0c570c3..3685219 100644
--- a/components/service/crypto/client/cpp/packed-c/packedc_crypto_client.cpp
+++ b/components/service/crypto/client/cpp/packed-c/packedc_crypto_client.cpp
@@ -441,22 +441,26 @@
     psa_status_t psa_status = PSA_ERROR_GENERIC_ERROR;
     struct ts_crypto_asymmetric_encrypt_in req_msg;
     size_t req_fixed_len = sizeof(ts_crypto_asymmetric_encrypt_in);
-    size_t req_len = req_fixed_len + tlv_required_space(input_length) + tlv_required_space(salt_length);
+    size_t req_len = req_fixed_len;
 
     *output_length = 0;  /* For failure case */
 
     req_msg.id = id;
     req_msg.alg = alg;
 
+    /* Mandatory parameter */
     struct tlv_record plaintext_record;
     plaintext_record.tag = TS_CRYPTO_ASYMMETRIC_ENCRYPT_IN_TAG_PLAINTEXT;
     plaintext_record.length = input_length;
     plaintext_record.value = input;
+    req_len += tlv_required_space(plaintext_record.length);
 
+    /* Optional parameter */
     struct tlv_record salt_record;
     salt_record.tag = TS_CRYPTO_ASYMMETRIC_ENCRYPT_IN_TAG_SALT;
-    salt_record.length = salt_length;
+    salt_record.length = (salt) ? salt_length : 0;
     salt_record.value = salt;
+    if (salt) req_len += tlv_required_space(salt_record.length);
 
     rpc_call_handle call_handle;
     uint8_t *req_buf;
@@ -474,7 +478,7 @@
 
         tlv_iterator_begin(&req_iter, &req_buf[req_fixed_len], req_len - req_fixed_len);
         tlv_encode(&req_iter, &plaintext_record);
-        tlv_encode(&req_iter, &salt_record);
+        if (salt) tlv_encode(&req_iter, &salt_record);
 
         m_err_rpc_status = rpc_caller_invoke(m_caller, call_handle,
                     TS_CRYPTO_OPCODE_ASYMMETRIC_ENCRYPT, &opstatus, &resp_buf, &resp_len);
@@ -522,22 +526,26 @@
     psa_status_t psa_status = PSA_ERROR_GENERIC_ERROR;
     struct ts_crypto_asymmetric_decrypt_in req_msg;
     size_t req_fixed_len = sizeof(ts_crypto_asymmetric_decrypt_in);
-    size_t req_len = req_fixed_len + tlv_required_space(input_length) + tlv_required_space(salt_length);
+    size_t req_len = req_fixed_len;
 
     *output_length = 0;  /* For failure case */
 
     req_msg.id = id;
     req_msg.alg = alg;
 
+    /* Mandatory parameter */
     struct tlv_record ciphertext_record;
     ciphertext_record.tag = TS_CRYPTO_ASYMMETRIC_DECRYPT_IN_TAG_CIPHERTEXT;
     ciphertext_record.length = input_length;
     ciphertext_record.value = input;
+    req_len += tlv_required_space(ciphertext_record.length);
 
+    /* Optional parameter */
     struct tlv_record salt_record;
     salt_record.tag = TS_CRYPTO_ASYMMETRIC_DECRYPT_IN_TAG_SALT;
-    salt_record.length = salt_length;
+    salt_record.length = (salt) ? salt_length : 0;
     salt_record.value = salt;
+    if (salt) req_len += tlv_required_space(salt_record.length);
 
     rpc_call_handle call_handle;
     uint8_t *req_buf;
@@ -555,7 +563,7 @@
 
         tlv_iterator_begin(&req_iter, &req_buf[req_fixed_len], req_len - req_fixed_len);
         tlv_encode(&req_iter, &ciphertext_record);
-        tlv_encode(&req_iter, &salt_record);
+        if (salt) tlv_encode(&req_iter, &salt_record);
 
         m_err_rpc_status = rpc_caller_invoke(m_caller, call_handle,
                     TS_CRYPTO_OPCODE_ASYMMETRIC_DECRYPT, &opstatus, &resp_buf, &resp_len);
diff --git a/components/service/crypto/provider/mbedcrypto/crypto_provider.c b/components/service/crypto/provider/mbedcrypto/crypto_provider.c
index 389e8bc..1b2fffd 100644
--- a/components/service/crypto/provider/mbedcrypto/crypto_provider.c
+++ b/components/service/crypto/provider/mbedcrypto/crypto_provider.c
@@ -421,9 +421,12 @@
 
                     if (plaintext_buffer) {
 
+                        /* Salt is an optional parameter */
+                        uint8_t *salt = (salt_len) ? salt_buffer : NULL;
+
                         psa_status = psa_asymmetric_decrypt(id, alg,
                                     ciphertext_buffer, ciphertext_len,
-                                    salt_buffer, salt_len,
+                                    salt, salt_len,
                                     plaintext_buffer, max_decrypt_size, &plaintext_len);
 
                         if (psa_status == PSA_SUCCESS) {
@@ -500,9 +503,12 @@
 
                     if (ciphertext_buffer) {
 
+                        /* Salt is an optional parameter */
+                        uint8_t *salt = (salt_len) ? salt_buffer : NULL;
+
                         psa_status = psa_asymmetric_encrypt(id, alg,
                                     plaintext_buffer, plaintext_len,
-                                    salt_buffer, salt_len,
+                                    salt, salt_len,
                                     ciphertext_buffer, max_encrypt_size, &ciphertext_len);
 
                         if (psa_status == PSA_SUCCESS) {
diff --git a/components/service/crypto/provider/serializer/protobuf/pb_crypto_provider_serializer.c b/components/service/crypto/provider/serializer/protobuf/pb_crypto_provider_serializer.c
index 2e2edf8..dd9e0c5 100644
--- a/components/service/crypto/provider/serializer/protobuf/pb_crypto_provider_serializer.c
+++ b/components/service/crypto/provider/serializer/protobuf/pb_crypto_provider_serializer.c
@@ -339,6 +339,7 @@
 
     pb_bytes_array_t *salt_buffer = pb_malloc_byte_array(*salt_len);
     recv_msg.salt = pb_in_byte_array(salt_buffer);
+    *salt_len = 0;  /* Default for optional parameter */
 
     pb_istream_t istream = pb_istream_from_buffer((const uint8_t*)req_buf->data, req_buf->data_len);
 
@@ -403,6 +404,7 @@
 
     pb_bytes_array_t *salt_buffer = pb_malloc_byte_array(*salt_len);
     recv_msg.salt = pb_in_byte_array(salt_buffer);
+    *salt_len = 0;  /* Default for optional parameter */
 
     pb_istream_t istream = pb_istream_from_buffer((const uint8_t*)req_buf->data, req_buf->data_len);
 
diff --git a/components/service/crypto/test/service/crypto_service_scenarios.cpp b/components/service/crypto/test/service/crypto_service_scenarios.cpp
index bbd417b..e0fa43e 100644
--- a/components/service/crypto/test/service/crypto_service_scenarios.cpp
+++ b/components/service/crypto/test/service/crypto_service_scenarios.cpp
@@ -261,6 +261,57 @@
     CHECK_EQUAL(PSA_SUCCESS, status);
 }
 
+void crypto_service_scenarios::asymEncryptDecryptWithSalt()
+{
+    psa_status_t status;
+    psa_key_attributes_t attributes = PSA_KEY_ATTRIBUTES_INIT;
+    psa_key_handle_t key_handle;
+
+    psa_set_key_id(&attributes, 15);
+    psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_ENCRYPT | PSA_KEY_USAGE_DECRYPT);
+    psa_set_key_algorithm(&attributes,  PSA_ALG_RSA_OAEP(PSA_ALG_SHA_256));
+    psa_set_key_type(&attributes, PSA_KEY_TYPE_RSA_KEY_PAIR);
+    psa_set_key_bits(&attributes, 1024);
+
+    /* Generate a 1024-bit RSA-OAEP key pair for encrypt and decrypt */
+    status = m_crypto_client->generate_key(&attributes, &key_handle);
+    CHECK_EQUAL(PSA_SUCCESS, status);
+
+    psa_reset_key_attributes(&attributes);
+
+    /* Encrypt a message */
+    uint8_t message[] = {'q','u','i','c','k','b','r','o','w','n','f','o','x'};
+    uint8_t ciphertext[128];
+    size_t ciphertext_len = 0;
+
+    /* Non-NULL salt, so the optional salt parameter is serialized */
+    uint8_t salt[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16};
+
+    status = m_crypto_client->asymmetric_encrypt(key_handle, PSA_ALG_RSA_OAEP(PSA_ALG_SHA_256),
+                            message, sizeof(message),
+                            salt, sizeof(salt),
+                            ciphertext, sizeof(ciphertext), &ciphertext_len);
+    CHECK_EQUAL(PSA_SUCCESS, status);
+
+    /* Decrypt it, supplying the same salt that was used to encrypt */
+    uint8_t plaintext[256];
+    size_t plaintext_len = 0;
+
+    status = m_crypto_client->asymmetric_decrypt(key_handle, PSA_ALG_RSA_OAEP(PSA_ALG_SHA_256),
+                            ciphertext, ciphertext_len,
+                            salt, sizeof(salt),
+                            plaintext, sizeof(plaintext), &plaintext_len);
+    CHECK_EQUAL(PSA_SUCCESS, status);
+
+    /* Expect the encrypted/decrypted message to match the original */
+    CHECK_EQUAL(sizeof(message), plaintext_len);
+    MEMCMP_EQUAL(message, plaintext, plaintext_len);
+
+    /* Remove the key */
+    status = m_crypto_client->destroy_key(key_handle);
+    CHECK_EQUAL(PSA_SUCCESS, status);
+}
+
 void crypto_service_scenarios::generateRandomNumbers()
 {
     psa_status_t status;
diff --git a/components/service/crypto/test/service/crypto_service_scenarios.h b/components/service/crypto/test/service/crypto_service_scenarios.h
index 2c7d188..0e996aa 100644
--- a/components/service/crypto/test/service/crypto_service_scenarios.h
+++ b/components/service/crypto/test/service/crypto_service_scenarios.h
@@ -19,6 +19,7 @@
 
     void generateRandomNumbers();
     void asymEncryptDecrypt();
+    void asymEncryptDecryptWithSalt();
     void signAndVerifyHash();
     void exportAndImportKeyPair();
     void exportPublicKey();
diff --git a/components/service/crypto/test/service/packed-c/crypto_service_packedc_tests.cpp b/components/service/crypto/test/service/packed-c/crypto_service_packedc_tests.cpp
index a6cbe31..5a620b4 100644
--- a/components/service/crypto/test/service/packed-c/crypto_service_packedc_tests.cpp
+++ b/components/service/crypto/test/service/packed-c/crypto_service_packedc_tests.cpp
@@ -82,6 +82,11 @@
     m_scenarios->asymEncryptDecrypt();
 }
 
+TEST(CryptoServicePackedcTests, asymEncryptDecryptWithSalt)
+{
+    m_scenarios->asymEncryptDecryptWithSalt();
+}
+
 TEST(CryptoServicePackedcTests, generateRandomNumbers)
 {
     m_scenarios->generateRandomNumbers();
diff --git a/components/service/crypto/test/service/protobuf/crypto_service_protobuf_tests.cpp b/components/service/crypto/test/service/protobuf/crypto_service_protobuf_tests.cpp
index 3d728e2..9483f25 100644
--- a/components/service/crypto/test/service/protobuf/crypto_service_protobuf_tests.cpp
+++ b/components/service/crypto/test/service/protobuf/crypto_service_protobuf_tests.cpp
@@ -82,6 +82,11 @@
     m_scenarios->asymEncryptDecrypt();
 }
 
+TEST(CryptoServiceProtobufTests, asymEncryptDecryptWithSalt)
+{
+    m_scenarios->asymEncryptDecryptWithSalt();
+}
+
 TEST(CryptoServiceProtobufTests, generateRandomNumbers)
 {
     m_scenarios->generateRandomNumbers();