Cleanup the code

Use conditional compilation for psa and mbedtls code (MBEDTLS_USE_PSA_CRYPTO).

Signed-off-by: Przemyslaw Stekiel <przemyslaw.stekiel@mobica.com>
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 68cc4f0..a6439dc 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -937,14 +937,15 @@
 
 #endif /* MBEDTLS_SSL_SOME_SUITES_USE_MAC */
 
-    mbedtls_cipher_context_t cipher_ctx_enc;    /*!<  encryption context      */
-    mbedtls_cipher_context_t cipher_ctx_dec;    /*!<  decryption context      */
     int minor_ver;
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     mbedtls_svc_key_id_t psa_key_enc;           /*!<  psa encryption key      */
     mbedtls_svc_key_id_t psa_key_dec;           /*!<  psa decryption key      */
     psa_algorithm_t psa_alg;                    /*!<  psa algorithm           */
+#else
+    mbedtls_cipher_context_t cipher_ctx_enc;    /*!<  encryption context      */
+    mbedtls_cipher_context_t cipher_ctx_dec;    /*!<  decryption context      */
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
diff --git a/library/ssl_msg.c b/library/ssl_msg.c
index c9f75de..2353c5e 100644
--- a/library/ssl_msg.c
+++ b/library/ssl_msg.c
@@ -522,7 +522,9 @@
                              int (*f_rng)(void *, unsigned char *, size_t),
                              void *p_rng )
 {
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
     mbedtls_cipher_mode_t mode;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     int auth_done = 0;
     unsigned char * data;
     unsigned char add_data[13 + 1 + MBEDTLS_SSL_CID_OUT_LEN_MAX ];
@@ -568,7 +570,9 @@
     MBEDTLS_SSL_DEBUG_BUF( 4, "before encrypt: output payload",
                            data, rec->data_len );
 
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
     mode = mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_enc );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
     if( rec->data_len > MBEDTLS_SSL_OUT_CONTENT_LEN )
     {
@@ -649,8 +653,13 @@
      * Add MAC before if needed
      */
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if ( transform->psa_alg == MBEDTLS_SSL_NULL_CIPHER  ||
+            ( transform->psa_alg == PSA_ALG_CBC_NO_PADDING
+#else
     if( mode == MBEDTLS_MODE_STREAM ||
         ( mode == MBEDTLS_MODE_CBC
+#endif
 #if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
           && transform->encrypt_then_mac == MBEDTLS_SSL_ETM_DISABLED
 #endif
@@ -707,7 +716,11 @@
      * Encrypt
      */
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_STREAM)
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if ( transform->psa_alg == MBEDTLS_SSL_NULL_CIPHER )
+#else
     if( mode == MBEDTLS_MODE_STREAM )
+#endif
     {
         size_t olen;
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
@@ -779,9 +792,18 @@
 #if defined(MBEDTLS_GCM_C) || \
     defined(MBEDTLS_CCM_C) || \
     defined(MBEDTLS_CHACHAPOLY_C)
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if ( transform->psa_alg == PSA_ALG_GCM ||
+         /* PSA_ALG_IS_AEAD( transform->psa_alg ) corresponds to
+            psa_alg == PSA_ALG_CCM || psa_alg == PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CCM, 8 )
+            in tls context (TLS only uses the default taglen or 8) */
+         PSA_ALG_IS_AEAD( transform->psa_alg ) ||
+         transform->psa_alg == PSA_ALG_CHACHA20_POLY1305 )
+#else
     if( mode == MBEDTLS_MODE_GCM ||
         mode == MBEDTLS_MODE_CCM ||
         mode == MBEDTLS_MODE_CHACHAPOLY )
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     {
         unsigned char iv[12];
         unsigned char *dynamic_iv;
@@ -897,7 +919,11 @@
     else
 #endif /* MBEDTLS_GCM_C || MBEDTLS_CCM_C || MBEDTLS_CHACHAPOLY_C */
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC)
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if ( transform->psa_alg == PSA_ALG_CBC_NO_PADDING )
+#else
     if( mode == MBEDTLS_MODE_CBC )
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     {
         int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
         size_t padlen, i;
@@ -1092,7 +1118,9 @@
                              mbedtls_record *rec )
 {
     size_t olen;
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
     mbedtls_cipher_mode_t mode;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     int ret, auth_done = 0;
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
     size_t padlen = 0, correct = 1;
@@ -1117,7 +1145,9 @@
     }
 
     data = rec->buf + rec->data_offset;
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
     mode = mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_dec );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
     /*
@@ -1131,7 +1161,11 @@
 #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
 
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_STREAM)
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if ( transform->psa_alg == MBEDTLS_SSL_NULL_CIPHER )
+#else
     if( mode == MBEDTLS_MODE_STREAM )
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     {
         padlen = 0;
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
@@ -1198,9 +1232,18 @@
 #if defined(MBEDTLS_GCM_C) || \
     defined(MBEDTLS_CCM_C) || \
     defined(MBEDTLS_CHACHAPOLY_C)
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if ( transform->psa_alg == PSA_ALG_GCM ||
+         /* PSA_ALG_IS_AEAD( transform->psa_alg ) corresponds to
+            psa_alg == PSA_ALG_CCM || psa_alg == PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CCM, 8 )
+            in tls context (TLS only uses the default taglen or 8) */
+         PSA_ALG_IS_AEAD( transform->psa_alg ) ||
+         transform->psa_alg == PSA_ALG_CHACHA20_POLY1305 )
+#else
     if( mode == MBEDTLS_MODE_GCM ||
         mode == MBEDTLS_MODE_CCM ||
         mode == MBEDTLS_MODE_CHACHAPOLY )
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     {
         unsigned char iv[12];
         unsigned char *dynamic_iv;
@@ -1322,7 +1365,11 @@
     else
 #endif /* MBEDTLS_GCM_C || MBEDTLS_CCM_C */
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC)
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if ( transform->psa_alg == PSA_ALG_CBC_NO_PADDING )
+#else
     if( mode == MBEDTLS_MODE_CBC )
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     {
         size_t minlen = 0;
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
@@ -5047,12 +5094,62 @@
     size_t transform_expansion = 0;
     const mbedtls_ssl_transform *transform = ssl->transform_out;
     unsigned block_size;
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_key_attributes_t attr = PSA_KEY_ATTRIBUTES_INIT;
+    psa_key_type_t key_type;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
     size_t out_hdr_len = mbedtls_ssl_out_hdr_len( ssl );
 
     if( transform == NULL )
         return( (int) out_hdr_len );
 
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    switch( transform->psa_alg )
+    {
+        case PSA_ALG_GCM:
+        case PSA_ALG_CHACHA20_POLY1305:
+        case MBEDTLS_SSL_NULL_CIPHER:
+            transform_expansion = transform->minlen;
+            break;
+
+        case PSA_ALG_CBC_NO_PADDING:
+            (void) psa_get_key_attributes( transform->psa_key_enc, &attr );
+            key_type = psa_get_key_type( &attr );
+
+            block_size = PSA_BLOCK_CIPHER_BLOCK_LENGTH( key_type );
+
+            /* Expansion due to the addition of the MAC. */
+            transform_expansion += transform->maclen;
+
+            /* Expansion due to the addition of CBC padding;
+             * Theoretically up to 256 bytes, but we never use
+             * more than the block size of the underlying cipher. */
+            transform_expansion += block_size;
+
+            /* For TLS 1.2 or higher, an explicit IV is added
+             * after the record header. */
+#if defined(MBEDTLS_SSL_PROTO_TLS1_2)
+            transform_expansion += block_size;
+#endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
+            break;
+
+        default:
+         /* Handle CCM case in default:
+            PSA_ALG_IS_AEAD( transform->psa_alg ) corresponds to
+            psa_alg == PSA_ALG_CCM || psa_alg == PSA_ALG_AEAD_WITH_SHORTENED_TAG( PSA_ALG_CCM, 8 )
+            in tls context (TLS only uses the default taglen or 8) */
+            if ( PSA_ALG_IS_AEAD( transform->psa_alg ) )
+            {
+                transform_expansion = transform->minlen;
+                break;
+            }
+
+            MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) );
+            return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+    }
+#else
     switch( mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_enc ) )
     {
         case MBEDTLS_MODE_GCM:
@@ -5087,6 +5184,7 @@
             MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) );
             return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
     }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
     if( transform->out_cid_len != 0 )
@@ -5591,13 +5689,13 @@
     if( transform == NULL )
         return;
 
-    mbedtls_cipher_free( &transform->cipher_ctx_enc );
-    mbedtls_cipher_free( &transform->cipher_ctx_dec );
-
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     psa_destroy_key( transform->psa_key_enc );
     psa_destroy_key( transform->psa_key_dec );
-#endif
+#else
+    mbedtls_cipher_free( &transform->cipher_ctx_enc );
+    mbedtls_cipher_free( &transform->cipher_ctx_dec );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
     mbedtls_md_free( &transform->md_ctx_enc );
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 6191d63..4266af4 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -705,9 +705,6 @@
                                    const mbedtls_ssl_context *ssl )
 {
     int ret = 0;
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    int psa_fallthrough;
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
     unsigned char keyblk[256];
     unsigned char *key1;
     unsigned char *key2;
@@ -1012,80 +1009,6 @@
     }
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-    ret = mbedtls_cipher_setup_psa( &transform->cipher_ctx_enc,
-                                    cipher_info, transform->taglen );
-    if( ret != 0 && ret != MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE )
-    {
-        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup_psa", ret );
-        goto end;
-    }
-
-    if( ret == 0 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 3, ( "Successfully setup PSA-based encryption cipher context" ) );
-        psa_fallthrough = 0;
-    }
-    else
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "Failed to setup PSA-based cipher context for record encryption - fall through to default setup." ) );
-        psa_fallthrough = 1;
-    }
-
-    if( psa_fallthrough == 1 )
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
-    if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx_enc,
-                                 cipher_info ) ) != 0 )
-    {
-        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup", ret );
-        goto end;
-    }
-
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    ret = mbedtls_cipher_setup_psa( &transform->cipher_ctx_dec,
-                                    cipher_info, transform->taglen );
-    if( ret != 0 && ret != MBEDTLS_ERR_CIPHER_FEATURE_UNAVAILABLE )
-    {
-        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup_psa", ret );
-        goto end;
-    }
-
-    if( ret == 0 )
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 3, ( "Successfully setup PSA-based decryption cipher context" ) );
-        psa_fallthrough = 0;
-    }
-    else
-    {
-        MBEDTLS_SSL_DEBUG_MSG( 1, ( "Failed to setup PSA-based cipher context for record decryption - fall through to default setup." ) );
-        psa_fallthrough = 1;
-    }
-
-    if( psa_fallthrough == 1 )
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
-    if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx_dec,
-                                 cipher_info ) ) != 0 )
-    {
-        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup", ret );
-        goto end;
-    }
-
-    if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx_enc, key1,
-                               (int) mbedtls_cipher_info_get_key_bitlen( cipher_info ),
-                               MBEDTLS_ENCRYPT ) ) != 0 )
-    {
-        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
-        goto end;
-    }
-
-    if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx_dec, key2,
-                               (int) mbedtls_cipher_info_get_key_bitlen( cipher_info ),
-                               MBEDTLS_DECRYPT ) ) != 0 )
-    {
-        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
-        goto end;
-    }
-
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( ( status = mbedtls_cipher_to_psa( cipher_info->type,
                                  transform->taglen,
                                  &alg,
@@ -1099,6 +1022,7 @@
 
     psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_ENCRYPT );
     psa_set_key_algorithm( &attributes, alg );
+    psa_set_key_type( &attributes, key_type );
 
     transform->psa_alg = alg;
 
@@ -1123,7 +1047,36 @@
         MBEDTLS_SSL_DEBUG_RET( 1, "psa_import_key", ret );
         goto end;
     }
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+#else
+    if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx_enc,
+                                 cipher_info ) ) != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup", ret );
+        goto end;
+    }
+
+    if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx_dec,
+                                 cipher_info ) ) != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup", ret );
+        goto end;
+    }
+
+    if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx_enc, key1,
+                               (int) mbedtls_cipher_info_get_key_bitlen( cipher_info ),
+                               MBEDTLS_ENCRYPT ) ) != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+        goto end;
+    }
+
+    if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx_dec, key2,
+                               (int) mbedtls_cipher_info_get_key_bitlen( cipher_info ),
+                               MBEDTLS_DECRYPT ) ) != 0 )
+    {
+        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
+        goto end;
+    }
 
 #if defined(MBEDTLS_CIPHER_MODE_CBC)
     if( mbedtls_cipher_info_get_mode( cipher_info ) == MBEDTLS_MODE_CBC )
@@ -1143,7 +1096,7 @@
         }
     }
 #endif /* MBEDTLS_CIPHER_MODE_CBC */
-
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 end:
     mbedtls_platform_zeroize( keyblk, sizeof( keyblk ) );
@@ -3070,12 +3023,12 @@
 {
     memset( transform, 0, sizeof(mbedtls_ssl_transform) );
 
-    mbedtls_cipher_init( &transform->cipher_ctx_enc );
-    mbedtls_cipher_init( &transform->cipher_ctx_dec );
-
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     transform->psa_key_enc = MBEDTLS_SVC_KEY_ID_INIT;
     transform->psa_key_dec = MBEDTLS_SVC_KEY_ID_INIT;
+#else
+    mbedtls_cipher_init( &transform->cipher_ctx_enc );
+    mbedtls_cipher_init( &transform->cipher_ctx_dec );
 #endif
 
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
diff --git a/library/ssl_tls13_keys.c b/library/ssl_tls13_keys.c
index 0aade35..a3c1fe5 100644
--- a/library/ssl_tls13_keys.c
+++ b/library/ssl_tls13_keys.c
@@ -801,7 +801,9 @@
                                           mbedtls_ssl_key_set const *traffic_keys,
                                           mbedtls_ssl_context *ssl /* DEBUG ONLY */ )
 {
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
     int ret;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
     mbedtls_cipher_info_t const *cipher_info;
     const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
     unsigned char const *key_enc;
@@ -838,10 +840,10 @@
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
     }
 
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
     /*
      * Setup cipher contexts in target transform
      */
-
     if( ( ret = mbedtls_cipher_setup( &transform->cipher_ctx_enc,
                                       cipher_info ) ) != 0 )
     {
@@ -855,6 +857,7 @@
         MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setup", ret );
         return( ret );
     }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 #if defined(MBEDTLS_SSL_SRV_C)
     if( endpoint == MBEDTLS_SSL_IS_SERVER )
@@ -884,6 +887,7 @@
     memcpy( transform->iv_enc, iv_enc, traffic_keys->iv_len );
     memcpy( transform->iv_dec, iv_dec, traffic_keys->iv_len );
 
+#if !defined(MBEDTLS_USE_PSA_CRYPTO)
     if( ( ret = mbedtls_cipher_setkey( &transform->cipher_ctx_enc,
                                        key_enc, cipher_info->key_bitlen,
                                        MBEDTLS_ENCRYPT ) ) != 0 )
@@ -899,6 +903,7 @@
         MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_cipher_setkey", ret );
         return( ret );
     }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
     /*
      * Setup other fields in SSL transform
@@ -922,6 +927,9 @@
         transform->taglen + MBEDTLS_SSL_CID_TLS1_3_PADDING_GRANULARITY;
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
+    /*
+     * Setup psa keys and alg
+     */
     if( ( status = mbedtls_cipher_to_psa( cipher_info->type,
                                  transform->taglen,
                                  &alg,
@@ -934,6 +942,7 @@
 
     psa_set_key_usage_flags( &attributes, PSA_KEY_USAGE_ENCRYPT );
     psa_set_key_algorithm( &attributes, alg );
+    psa_set_key_type( &attributes, key_type );
 
     transform->psa_alg = alg;