Merge pull request #5674 from superna9999/5668-abstract-tls-mode-cleanup

Cipher cleanup: abstract TLS mode
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 141c40a..e8acc23 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -173,6 +173,12 @@
 #define MBEDTLS_SSL_SOME_SUITES_USE_MAC
 #endif
 
+/* Defined when both CBC ciphersuites and Encrypt-then-MAC are enabled. */
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC) && \
+    defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
+#define MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM
+#endif
+
 #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
 
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
@@ -2222,6 +2228,28 @@
 }
 #endif /* MBEDTLS_USE_PSA_CRYPTO || MBEDTLS_SSL_PROTO_TLS1_3 */
 
+/**
+ * \brief       TLS record protection modes
+ */
+typedef enum {
+    MBEDTLS_SSL_MODE_STREAM = 0, /* stream ("NULL") cipher, MAC added first */
+    MBEDTLS_SSL_MODE_CBC,        /* CBC cipher, MAC added before encryption */
+    MBEDTLS_SSL_MODE_CBC_ETM,    /* CBC cipher with Encrypt-then-MAC enabled */
+    MBEDTLS_SSL_MODE_AEAD        /* AEAD cipher (e.g. GCM, CCM, ChaChaPoly) */
+} mbedtls_ssl_mode_t;
+
+mbedtls_ssl_mode_t mbedtls_ssl_get_mode_from_transform(
+        const mbedtls_ssl_transform *transform );
+
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM)
+mbedtls_ssl_mode_t mbedtls_ssl_get_mode_from_ciphersuite(
+        int encrypt_then_mac, /* MBEDTLS_SSL_ETM_ENABLED selects CBC_ETM */
+        const mbedtls_ssl_ciphersuite_t *suite );
+#else /* !MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM: no Encrypt-then-MAC argument */
+mbedtls_ssl_mode_t mbedtls_ssl_get_mode_from_ciphersuite(
+        const mbedtls_ssl_ciphersuite_t *suite );
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM */
+
 #if defined(MBEDTLS_ECDH_C)
 
 int mbedtls_ssl_tls13_read_public_ecdhe_share( mbedtls_ssl_context *ssl,
diff --git a/library/ssl_msg.c b/library/ssl_msg.c
index c55c60f..083c8d2 100644
--- a/library/ssl_msg.c
+++ b/library/ssl_msg.c
@@ -523,9 +523,7 @@
                              int (*f_rng)(void *, unsigned char *, size_t),
                              void *p_rng )
 {
-#if !defined(MBEDTLS_USE_PSA_CRYPTO)
-    mbedtls_cipher_mode_t mode;
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+    mbedtls_ssl_mode_t ssl_mode;
     int auth_done = 0;
     unsigned char * data;
     unsigned char add_data[13 + 1 + MBEDTLS_SSL_CID_OUT_LEN_MAX ];
@@ -566,15 +564,13 @@
         return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
     }
 
+    ssl_mode = mbedtls_ssl_get_mode_from_transform( transform );
+
     data = rec->buf + rec->data_offset;
     post_avail = rec->buf_len - ( rec->data_len + rec->data_offset );
     MBEDTLS_SSL_DEBUG_BUF( 4, "before encrypt: output payload",
                            data, rec->data_len );
 
-#if !defined(MBEDTLS_USE_PSA_CRYPTO)
-    mode = mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_enc );
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
-
     if( rec->data_len > MBEDTLS_SSL_OUT_CONTENT_LEN )
     {
         MBEDTLS_SSL_DEBUG_MSG( 1, ( "Record content %" MBEDTLS_PRINTF_SIZET
@@ -654,17 +650,8 @@
      * Add MAC before if needed
      */
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    if ( transform->psa_alg == MBEDTLS_SSL_NULL_CIPHER  ||
-            ( transform->psa_alg == PSA_ALG_CBC_NO_PADDING
-#else
-    if( mode == MBEDTLS_MODE_STREAM ||
-        ( mode == MBEDTLS_MODE_CBC
-#endif
-#if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
-          && transform->encrypt_then_mac == MBEDTLS_SSL_ETM_DISABLED
-#endif
-        ) )
+    if( ssl_mode == MBEDTLS_SSL_MODE_STREAM ||
+        ssl_mode == MBEDTLS_SSL_MODE_CBC )
     {
         if( post_avail < transform->maclen )
         {
@@ -748,11 +735,7 @@
      * Encrypt
      */
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_STREAM)
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    if ( transform->psa_alg == MBEDTLS_SSL_NULL_CIPHER )
-#else
-    if( mode == MBEDTLS_MODE_STREAM )
-#endif
+    if( ssl_mode == MBEDTLS_SSL_MODE_STREAM )
     {
         MBEDTLS_SSL_DEBUG_MSG( 3, ( "before encrypt: msglen = %" MBEDTLS_PRINTF_SIZET ", "
                                     "including %d bytes of padding",
@@ -767,13 +750,7 @@
 #if defined(MBEDTLS_GCM_C) || \
     defined(MBEDTLS_CCM_C) || \
     defined(MBEDTLS_CHACHAPOLY_C)
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    if ( PSA_ALG_IS_AEAD( transform->psa_alg ) )
-#else
-    if( mode == MBEDTLS_MODE_GCM ||
-        mode == MBEDTLS_MODE_CCM ||
-        mode == MBEDTLS_MODE_CHACHAPOLY )
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+    if( ssl_mode == MBEDTLS_SSL_MODE_AEAD )
     {
         unsigned char iv[12];
         unsigned char *dynamic_iv;
@@ -891,11 +868,8 @@
     else
 #endif /* MBEDTLS_GCM_C || MBEDTLS_CCM_C || MBEDTLS_CHACHAPOLY_C */
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC)
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    if ( transform->psa_alg == PSA_ALG_CBC_NO_PADDING )
-#else
-    if( mode == MBEDTLS_MODE_CBC )
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+    if( ssl_mode == MBEDTLS_SSL_MODE_CBC ||
+        ssl_mode == MBEDTLS_SSL_MODE_CBC_ETM )
     {
         int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
         size_t padlen, i;
@@ -1139,14 +1113,9 @@
                              mbedtls_record *rec )
 {
     size_t olen;
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    mbedtls_ssl_mode_t ssl_mode;
     int ret;
 
-#else
-    mbedtls_cipher_mode_t mode;
-    int ret;
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
-
     int auth_done = 0;
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
     size_t padlen = 0, correct = 1;
@@ -1171,9 +1140,7 @@
     }
 
     data = rec->buf + rec->data_offset;
-#if !defined(MBEDTLS_USE_PSA_CRYPTO)
-    mode = mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_dec );
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+    ssl_mode = mbedtls_ssl_get_mode_from_transform( transform );
 
 #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
     /*
@@ -1187,11 +1154,7 @@
 #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
 
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_STREAM)
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    if ( transform->psa_alg == MBEDTLS_SSL_NULL_CIPHER )
-#else
-    if( mode == MBEDTLS_MODE_STREAM )
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+    if( ssl_mode == MBEDTLS_SSL_MODE_STREAM )
     {
         /* The only supported stream cipher is "NULL",
          * so there's nothing to do here.*/
@@ -1201,13 +1164,7 @@
 #if defined(MBEDTLS_GCM_C) || \
     defined(MBEDTLS_CCM_C) || \
     defined(MBEDTLS_CHACHAPOLY_C)
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    if ( PSA_ALG_IS_AEAD( transform->psa_alg ) )
-#else
-    if( mode == MBEDTLS_MODE_GCM ||
-        mode == MBEDTLS_MODE_CCM ||
-        mode == MBEDTLS_MODE_CHACHAPOLY )
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+    if( ssl_mode == MBEDTLS_SSL_MODE_AEAD )
     {
         unsigned char iv[12];
         unsigned char *dynamic_iv;
@@ -1333,11 +1290,8 @@
     else
 #endif /* MBEDTLS_GCM_C || MBEDTLS_CCM_C */
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC)
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    if ( transform->psa_alg == PSA_ALG_CBC_NO_PADDING )
-#else
-    if( mode == MBEDTLS_MODE_CBC )
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
+    if( ssl_mode == MBEDTLS_SSL_MODE_CBC ||
+        ssl_mode == MBEDTLS_SSL_MODE_CBC_ETM )
     {
         size_t minlen = 0;
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
@@ -1391,7 +1345,7 @@
          * Authenticate before decrypt if enabled
          */
 #if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
-        if( transform->encrypt_then_mac == MBEDTLS_SSL_ETM_ENABLED )
+        if( ssl_mode == MBEDTLS_SSL_MODE_CBC_ETM )
         {
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
             psa_mac_operation_t operation = PSA_MAC_OPERATION_INIT;
diff --git a/library/ssl_ticket.c b/library/ssl_ticket.c
index 7f65849..1c05001 100644
--- a/library/ssl_ticket.c
+++ b/library/ssl_ticket.c
@@ -216,19 +216,23 @@
     uint32_t lifetime )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    const mbedtls_cipher_info_t *cipher_info;
+    size_t key_bits;
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     psa_algorithm_t alg;
     psa_key_type_t key_type;
-    size_t key_bits;
-#endif
+#else
+    const mbedtls_cipher_info_t *cipher_info;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
-    ctx->f_rng = f_rng;
-    ctx->p_rng = p_rng;
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if( mbedtls_ssl_cipher_to_psa( cipher, TICKET_AUTH_TAG_BYTES,
+                                   &alg, &key_type, &key_bits ) != PSA_SUCCESS )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
 
-    ctx->ticket_lifetime = lifetime;
-
+    if( PSA_ALG_IS_AEAD( alg ) == 0 )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+#else
     cipher_info = mbedtls_cipher_info_from_type( cipher );
 
     if( mbedtls_cipher_info_get_mode( cipher_info ) != MBEDTLS_MODE_GCM &&
@@ -238,14 +242,18 @@
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
     }
 
-    if( mbedtls_cipher_info_get_key_bitlen( cipher_info ) > 8 * MAX_KEY_BYTES )
+    key_bits = mbedtls_cipher_info_get_key_bitlen( cipher_info );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
+    if( key_bits > 8 * MAX_KEY_BYTES )
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
 
+    ctx->f_rng = f_rng;
+    ctx->p_rng = p_rng;
+
+    ctx->ticket_lifetime = lifetime;
+
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-    if( mbedtls_ssl_cipher_to_psa( cipher_info->type, TICKET_AUTH_TAG_BYTES,
-                                   &alg, &key_type, &key_bits ) != PSA_SUCCESS )
-        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
-
     ctx->keys[0].alg = alg;
     ctx->keys[0].key_type = key_type;
     ctx->keys[0].key_bits = key_bits;
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 7804cb7..1702885 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -383,11 +383,9 @@
 static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform,
                                    int ciphersuite,
                                    const unsigned char master[48],
-#if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC) && \
-    defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM)
                                    int encrypt_then_mac,
-#endif /* MBEDTLS_SSL_ENCRYPT_THEN_MAC &&
-          MBEDTLS_SSL_SOME_SUITES_USE_MAC */
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM */
                                    ssl_tls_prf_t tls_prf,
                                    const unsigned char randbytes[64],
                                    mbedtls_ssl_protocol_version tls_version,
@@ -1716,6 +1714,109 @@
 #endif /* MBEDTLS_KEY_EXCHANGE_SOME_PSK_ENABLED */
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
+static mbedtls_ssl_mode_t mbedtls_ssl_get_base_mode(
+        psa_algorithm_t alg )
+{ /* PSA build: classify a PSA cipher algorithm, ignoring Encrypt-then-MAC. */
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
+    if( alg == PSA_ALG_CBC_NO_PADDING )
+        return( MBEDTLS_SSL_MODE_CBC );
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_MAC */
+    if( PSA_ALG_IS_AEAD( alg ) )
+        return( MBEDTLS_SSL_MODE_AEAD );
+    return( MBEDTLS_SSL_MODE_STREAM ); /* anything else is treated as stream */
+}
+
+#else /* MBEDTLS_USE_PSA_CRYPTO */
+
+static mbedtls_ssl_mode_t mbedtls_ssl_get_base_mode(
+        mbedtls_cipher_mode_t mode )
+{ /* Non-PSA build: classify a cipher mode, ignoring Encrypt-then-MAC. */
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
+    if( mode == MBEDTLS_MODE_CBC )
+        return( MBEDTLS_SSL_MODE_CBC );
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_MAC */
+
+#if defined(MBEDTLS_GCM_C) || \
+    defined(MBEDTLS_CCM_C) || \
+    defined(MBEDTLS_CHACHAPOLY_C)
+    if( mode == MBEDTLS_MODE_GCM ||
+        mode == MBEDTLS_MODE_CCM ||
+        mode == MBEDTLS_MODE_CHACHAPOLY )
+        return( MBEDTLS_SSL_MODE_AEAD );
+#endif /* MBEDTLS_GCM_C || MBEDTLS_CCM_C || MBEDTLS_CHACHAPOLY_C */
+
+    return( MBEDTLS_SSL_MODE_STREAM ); /* anything else is treated as stream */
+}
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
+static mbedtls_ssl_mode_t mbedtls_ssl_get_actual_mode(
+    mbedtls_ssl_mode_t base_mode,
+    int encrypt_then_mac )
+{ /* Refine base_mode: CBC becomes CBC_ETM when Encrypt-then-MAC is enabled. */
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM)
+    if( encrypt_then_mac == MBEDTLS_SSL_ETM_ENABLED &&
+        base_mode == MBEDTLS_SSL_MODE_CBC )
+    {
+        return( MBEDTLS_SSL_MODE_CBC_ETM );
+    }
+#else
+    (void) encrypt_then_mac; /* unused when CBC EtM is compiled out */
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM */
+    return( base_mode ); /* every other case keeps its base mode */
+}
+
+mbedtls_ssl_mode_t mbedtls_ssl_get_mode_from_transform(
+                    const mbedtls_ssl_transform *transform )
+{ /* Record protection mode of a transform, including the EtM refinement. */
+    mbedtls_ssl_mode_t base_mode = mbedtls_ssl_get_base_mode(
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+            transform->psa_alg
+#else
+            mbedtls_cipher_get_cipher_mode( &transform->cipher_ctx_enc )
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+        );
+
+    int encrypt_then_mac = 0; /* stays 0 unless CBC EtM is compiled in */
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM)
+    encrypt_then_mac = transform->encrypt_then_mac;
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM */
+    return( mbedtls_ssl_get_actual_mode( base_mode, encrypt_then_mac ) );
+}
+
+mbedtls_ssl_mode_t mbedtls_ssl_get_mode_from_ciphersuite(
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM)
+        int encrypt_then_mac,
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM */
+        const mbedtls_ssl_ciphersuite_t *suite )
+{ /* Mode implied by a ciphersuite; STREAM if its cipher cannot be resolved. */
+    mbedtls_ssl_mode_t base_mode = MBEDTLS_SSL_MODE_STREAM; /* fallback */
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    psa_status_t status;
+    psa_algorithm_t alg;
+    psa_key_type_t type; /* passed to the call below, otherwise unused */
+    size_t size; /* passed to the call below, otherwise unused */
+    status = mbedtls_ssl_cipher_to_psa( suite->cipher, 0, &alg, &type, &size );
+    if( status == PSA_SUCCESS )
+        base_mode = mbedtls_ssl_get_base_mode( alg );
+#else /* MBEDTLS_USE_PSA_CRYPTO */
+    const mbedtls_cipher_info_t *cipher =
+            mbedtls_cipher_info_from_type( suite->cipher );
+    if( cipher != NULL )
+    {
+        base_mode =
+            mbedtls_ssl_get_base_mode(
+                mbedtls_cipher_info_get_mode( cipher ) );
+    }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
+#if !defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM)
+    int encrypt_then_mac = 0; /* no such parameter in this configuration */
+#endif /* !MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM */
+    return( mbedtls_ssl_get_actual_mode( base_mode, encrypt_then_mac ) );
+}
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
 psa_status_t mbedtls_ssl_cipher_to_psa( mbedtls_cipher_type_t mbedtls_cipher_type,
                                     size_t taglen,
                                     psa_algorithm_t *alg,
@@ -3615,11 +3716,9 @@
     ret = ssl_tls12_populate_transform( ssl->transform,
                   ssl->session->ciphersuite,
                   ssl->session->master,
-#if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC) && \
-    defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM)
                   ssl->session->encrypt_then_mac,
-#endif /* MBEDTLS_SSL_ENCRYPT_THEN_MAC &&
-          MBEDTLS_SSL_SOME_SUITES_USE_MAC */
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM */
                   ssl_tls12prf_from_cs( ssl->session->ciphersuite ),
                   p, /* currently pointing to randbytes */
                   MBEDTLS_SSL_VERSION_TLS1_2, /* (D)TLS 1.2 is forced */
@@ -5193,11 +5292,9 @@
     ret = ssl_tls12_populate_transform( ssl->transform_negotiate,
                                         ssl->session_negotiate->ciphersuite,
                                         ssl->session_negotiate->master,
-#if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC) && \
-    defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM)
                                         ssl->session_negotiate->encrypt_then_mac,
-#endif /* MBEDTLS_SSL_ENCRYPT_THEN_MAC &&
-          MBEDTLS_SSL_SOME_SUITES_USE_MAC */
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM */
                                         ssl->handshake->tls_prf,
                                         ssl->handshake->randbytes,
                                         ssl->tls_version,
@@ -6789,11 +6886,9 @@
 static int ssl_tls12_populate_transform( mbedtls_ssl_transform *transform,
                                    int ciphersuite,
                                    const unsigned char master[48],
-#if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC) && \
-    defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM)
                                    int encrypt_then_mac,
-#endif /* MBEDTLS_SSL_ENCRYPT_THEN_MAC &&
-          MBEDTLS_SSL_SOME_SUITES_USE_MAC */
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM */
                                    ssl_tls_prf_t tls_prf,
                                    const unsigned char randbytes[64],
                                    mbedtls_ssl_protocol_version tls_version,
@@ -6810,8 +6905,9 @@
     size_t iv_copy_len;
     size_t keylen;
     const mbedtls_ssl_ciphersuite_t *ciphersuite_info;
-    const mbedtls_cipher_info_t *cipher_info;
+    mbedtls_ssl_mode_t ssl_mode;
 #if !defined(MBEDTLS_USE_PSA_CRYPTO)
+    const mbedtls_cipher_info_t *cipher_info;
     const mbedtls_md_info_t *md_info;
 #endif /* !MBEDTLS_USE_PSA_CRYPTO */
 
@@ -6836,10 +6932,9 @@
     /*
      * Some data just needs copying into the structure
      */
-#if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC) && \
-    defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM)
     transform->encrypt_then_mac = encrypt_then_mac;
-#endif
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM */
     transform->tls_version = tls_version;
 
 #if defined(MBEDTLS_SSL_CONTEXT_SERIALIZATION)
@@ -6866,6 +6961,28 @@
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
     }
 
+    ssl_mode = mbedtls_ssl_get_mode_from_ciphersuite(
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM)
+                                        encrypt_then_mac,
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM */
+                                        ciphersuite_info );
+
+    if( ssl_mode == MBEDTLS_SSL_MODE_AEAD )
+        transform->taglen =
+            ciphersuite_info->flags & MBEDTLS_CIPHERSUITE_SHORT_TAG ? 8 : 16;
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    if( ( status = mbedtls_ssl_cipher_to_psa( ciphersuite_info->cipher,
+                                 transform->taglen,
+                                 &alg,
+                                 &key_type,
+                                 &key_bits ) ) != PSA_SUCCESS )
+    {
+        ret = psa_ssl_status_to_mbedtls( status );
+        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_cipher_to_psa", ret );
+        goto end;
+    }
+#else
     cipher_info = mbedtls_cipher_info_from_type( ciphersuite_info->cipher );
     if( cipher_info == NULL )
     {
@@ -6873,6 +6990,7 @@
                                     ciphersuite_info->cipher ) );
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
     }
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     mac_alg = mbedtls_psa_translate_md( ciphersuite_info->mac );
@@ -6932,21 +7050,21 @@
      * Determine the appropriate key, IV and MAC length.
      */
 
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+    keylen = PSA_BITS_TO_BYTES(key_bits);
+#else
     keylen = mbedtls_cipher_info_get_key_bitlen( cipher_info ) / 8;
+#endif
 
 #if defined(MBEDTLS_GCM_C) ||                           \
     defined(MBEDTLS_CCM_C) ||                           \
     defined(MBEDTLS_CHACHAPOLY_C)
-    if( mbedtls_cipher_info_get_mode( cipher_info ) == MBEDTLS_MODE_GCM ||
-        mbedtls_cipher_info_get_mode( cipher_info ) == MBEDTLS_MODE_CCM ||
-        mbedtls_cipher_info_get_mode( cipher_info ) == MBEDTLS_MODE_CHACHAPOLY )
+    if( ssl_mode == MBEDTLS_SSL_MODE_AEAD )
     {
         size_t explicit_ivlen;
 
         transform->maclen = 0;
         mac_key_len = 0;
-        transform->taglen =
-            ciphersuite_info->flags & MBEDTLS_CIPHERSUITE_SHORT_TAG ? 8 : 16;
 
         /* All modes haves 96-bit IVs, but the length of the static parts vary
          * with mode and version:
@@ -6957,7 +7075,11 @@
          *   sequence number).
          */
         transform->ivlen = 12;
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        if( key_type == PSA_KEY_TYPE_CHACHA20 )
+#else
         if( mbedtls_cipher_info_get_mode( cipher_info ) == MBEDTLS_MODE_CHACHAPOLY )
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
             transform->fixed_ivlen = 12;
         else
             transform->fixed_ivlen = 4;
@@ -6969,10 +7091,17 @@
     else
 #endif /* MBEDTLS_GCM_C || MBEDTLS_CCM_C || MBEDTLS_CHACHAPOLY_C */
 #if defined(MBEDTLS_SSL_SOME_SUITES_USE_MAC)
-    if( mbedtls_cipher_info_get_mode( cipher_info ) == MBEDTLS_MODE_STREAM ||
-        mbedtls_cipher_info_get_mode( cipher_info ) == MBEDTLS_MODE_CBC )
+    if( ssl_mode == MBEDTLS_SSL_MODE_STREAM ||
+        ssl_mode == MBEDTLS_SSL_MODE_CBC ||
+        ssl_mode == MBEDTLS_SSL_MODE_CBC_ETM )
     {
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
+        size_t block_size = PSA_BLOCK_CIPHER_BLOCK_LENGTH( key_type );
+#else
+        size_t block_size = cipher_info->block_size;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
         /* Get MAC length */
         mac_key_len = PSA_HASH_LENGTH(mac_alg);
 #else
@@ -6990,10 +7119,14 @@
         transform->maclen = mac_key_len;
 
         /* IV length */
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+        transform->ivlen = PSA_CIPHER_IV_LENGTH( key_type, alg );
+#else
         transform->ivlen = cipher_info->iv_size;
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
         /* Minimum length */
-        if( mbedtls_cipher_info_get_mode( cipher_info ) == MBEDTLS_MODE_STREAM )
+        if( ssl_mode == MBEDTLS_SSL_MODE_STREAM )
             transform->minlen = transform->maclen;
         else
         {
@@ -7004,17 +7137,17 @@
              * 2. IV
              */
 #if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
-            if( encrypt_then_mac == MBEDTLS_SSL_ETM_ENABLED )
+            if( ssl_mode == MBEDTLS_SSL_MODE_CBC_ETM )
             {
                 transform->minlen = transform->maclen
-                                  + cipher_info->block_size;
+                                  + block_size;
             }
             else
 #endif
             {
                 transform->minlen = transform->maclen
-                                  + cipher_info->block_size
-                                  - transform->maclen % cipher_info->block_size;
+                                  + block_size
+                                  - transform->maclen % block_size;
             }
 
             if( tls_version == MBEDTLS_SSL_VERSION_TLS1_2 )
@@ -7096,17 +7229,6 @@
     }
 
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
-    if( ( status = mbedtls_ssl_cipher_to_psa( cipher_info->type,
-                                 transform->taglen,
-                                 &alg,
-                                 &key_type,
-                                 &key_bits ) ) != PSA_SUCCESS )
-    {
-        ret = psa_ssl_status_to_mbedtls( status );
-        MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_cipher_to_psa", ret );
-        goto end;
-    }
-
     transform->psa_alg = alg;
 
     if ( alg != MBEDTLS_SSL_NULL_CIPHER )
diff --git a/library/ssl_tls12_server.c b/library/ssl_tls12_server.c
index 6fd916f..8866d4f 100644
--- a/library/ssl_tls12_server.c
+++ b/library/ssl_tls12_server.c
@@ -1973,20 +1973,13 @@
 }
 #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
 
-#if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM)
 static void ssl_write_encrypt_then_mac_ext( mbedtls_ssl_context *ssl,
                                             unsigned char *buf,
                                             size_t *olen )
 {
     unsigned char *p = buf;
     const mbedtls_ssl_ciphersuite_t *suite = NULL;
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-    psa_key_type_t key_type;
-    psa_algorithm_t alg;
-    size_t key_bits;
-#else
-    const mbedtls_cipher_info_t *cipher = NULL;
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
 
     /*
      * RFC 7366: "If a server receives an encrypt-then-MAC request extension
@@ -1994,18 +1987,19 @@
      * with Associated Data (AEAD) ciphersuite, it MUST NOT send an
      * encrypt-then-MAC response extension back to the client."
      */
-    if( ( suite = mbedtls_ssl_ciphersuite_from_id(
-                    ssl->session_negotiate->ciphersuite ) ) == NULL ||
-#if defined(MBEDTLS_USE_PSA_CRYPTO)
-        ( mbedtls_ssl_cipher_to_psa( suite->cipher, 0, &alg,
-                            &key_type, &key_bits ) != PSA_SUCCESS ) ||
-        alg != PSA_ALG_CBC_NO_PADDING )
-#else
-        ( cipher = mbedtls_cipher_info_from_type( suite->cipher ) ) == NULL ||
-        cipher->mode != MBEDTLS_MODE_CBC )
-#endif /* MBEDTLS_USE_PSA_CRYPTO */
-    {
+    suite = mbedtls_ssl_ciphersuite_from_id(
+                    ssl->session_negotiate->ciphersuite );
+    if( suite == NULL )
         ssl->session_negotiate->encrypt_then_mac = MBEDTLS_SSL_ETM_DISABLED;
+    else
+    {
+        mbedtls_ssl_mode_t ssl_mode =
+            mbedtls_ssl_get_mode_from_ciphersuite(
+                ssl->session_negotiate->encrypt_then_mac,
+                suite );
+
+        if( ssl_mode != MBEDTLS_SSL_MODE_CBC_ETM )
+            ssl->session_negotiate->encrypt_then_mac = MBEDTLS_SSL_ETM_DISABLED;
     }
 
     if( ssl->session_negotiate->encrypt_then_mac == MBEDTLS_SSL_ETM_DISABLED )
@@ -2024,7 +2018,7 @@
 
     *olen = 4;
 }
-#endif /* MBEDTLS_SSL_ENCRYPT_THEN_MAC */
+#endif /* MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM */
 
 #if defined(MBEDTLS_SSL_EXTENDED_MASTER_SECRET)
 static void ssl_write_extended_ms_ext( mbedtls_ssl_context *ssl,
@@ -2593,7 +2587,7 @@
     ext_len += olen;
 #endif
 
-#if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC)
+#if defined(MBEDTLS_SSL_SOME_SUITES_USE_CBC_ETM)
     ssl_write_encrypt_then_mac_ext( ssl, p + 2 + ext_len, &olen );
     ext_len += olen;
 #endif