Squashed commit upgrading to mbedtls-3.4.0

Squash merging branch import/mbedtls-3.4.0

8225713449d3 ("libmbedtls: fix unrecognized compiler option")
f03730842d7b ("core: ltc: configure internal MD5")
2b0d0c50127c ("core: ltc: configure internal SHA-1 and SHA-224")
0e48a6e17630 ("libmedtls: core: update to mbedTLS 3.4.0 API")
049882b143af ("libutee: update to mbedTLS 3.4.0 API")
982307bf6169 ("core: LTC mpi_desc.c: update to mbedTLS 3.4.0 API")
33218e9eff7b ("ta: pkcs11: update to mbedTLS 3.4.0 API")
6956420cc064 ("libmbedtls: fix cipher_wrap.c for NIST AES Key Wrap mode")
ad67ef0b43fd ("libmbedtls: fix cipher_wrap.c for chacha20 and chachapoly")
7300f4d97bbf ("libmbedtls: add fault mitigation in mbedtls_rsa_rsassa_pkcs1_v15_verify()")
cec89b62a86d ("libmbedtls: add fault mitigation in mbedtls_rsa_rsassa_pss_verify_ext()")
e7e048796c44 ("libmbedtls: add SM2 curve")
096beff2cd31 ("libmbedtls: mbedtls_mpi_exp_mod(): optimize mempool usage")
7108668efd3f ("libmbedtls: mbedtls_mpi_exp_mod(): reduce stack usage")
0ba4eb8d0572 ("libmbedtls: mbedtls_mpi_exp_mod() initialize W")
3fd6ecf00382 ("libmbedtls: fix no CRT issue")
d5ea7e9e9aa7 ("libmbedtls: add interfaces in mbedtls for context memory operation")
2b0fb3f1fa3d ("libmedtls: mpi_miller_rabin: increase count limit")
2c3301ab99bb ("libmbedtls: add mbedtls_mpi_init_mempool()")
9a111f0da04b ("libmbedtls: make mbedtls_mpi_mont*() available")
804fe3a374f5 ("mbedtls: configure mbedtls to reach for config")
b28a41531427 ("mbedtls: remove default include/mbedtls/config.h")
dfafe507bbef ("Import mbedtls-3.4.0")

Signed-off-by: Jens Wiklander <jens.wiklander@linaro.org>
Acked-by: Jerome Forissier <jerome.forissier@linaro.org>
Tested-by: Jerome Forissier <jerome.forissier@linaro.org> (vexpress-qemu_armv8a)
diff --git a/lib/libmbedtls/mbedtls/library/ccm.c b/lib/libmbedtls/mbedtls/library/ccm.c
index a21a37f..36c999e 100644
--- a/lib/libmbedtls/mbedtls/library/ccm.c
+++ b/lib/libmbedtls/mbedtls/library/ccm.c
@@ -36,154 +36,137 @@
 
 #include <string.h>
 
-#if defined(MBEDTLS_SELF_TEST) && defined(MBEDTLS_AES_C)
 #if defined(MBEDTLS_PLATFORM_C)
 #include "mbedtls/platform.h"
 #else
+#if defined(MBEDTLS_SELF_TEST) && defined(MBEDTLS_AES_C)
 #include <stdio.h>
 #define mbedtls_printf printf
-#endif /* MBEDTLS_PLATFORM_C */
 #endif /* MBEDTLS_SELF_TEST && MBEDTLS_AES_C */
+#endif /* MBEDTLS_PLATFORM_C */
 
 #if !defined(MBEDTLS_CCM_ALT)
 
-#define CCM_VALIDATE_RET( cond ) \
-    MBEDTLS_INTERNAL_VALIDATE_RET( cond, MBEDTLS_ERR_CCM_BAD_INPUT )
-#define CCM_VALIDATE( cond ) \
-    MBEDTLS_INTERNAL_VALIDATE( cond )
-
-#define CCM_ENCRYPT 0
-#define CCM_DECRYPT 1
 
 /*
  * Initialize context
  */
-void mbedtls_ccm_init( mbedtls_ccm_context *ctx )
+void mbedtls_ccm_init(mbedtls_ccm_context *ctx)
 {
-    CCM_VALIDATE( ctx != NULL );
-    memset( ctx, 0, sizeof( mbedtls_ccm_context ) );
+    memset(ctx, 0, sizeof(mbedtls_ccm_context));
 }
 
-int mbedtls_ccm_setkey( mbedtls_ccm_context *ctx,
-                        mbedtls_cipher_id_t cipher,
-                        const unsigned char *key,
-                        unsigned int keybits )
+int mbedtls_ccm_setkey(mbedtls_ccm_context *ctx,
+                       mbedtls_cipher_id_t cipher,
+                       const unsigned char *key,
+                       unsigned int keybits)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     const mbedtls_cipher_info_t *cipher_info;
 
-    CCM_VALIDATE_RET( ctx != NULL );
-    CCM_VALIDATE_RET( key != NULL );
-
-    cipher_info = mbedtls_cipher_info_from_values( cipher, keybits,
-                                                   MBEDTLS_MODE_ECB );
-    if( cipher_info == NULL )
-        return( MBEDTLS_ERR_CCM_BAD_INPUT );
-
-    if( cipher_info->block_size != 16 )
-        return( MBEDTLS_ERR_CCM_BAD_INPUT );
-
-    mbedtls_cipher_free( &ctx->cipher_ctx );
-
-    if( ( ret = mbedtls_cipher_setup( &ctx->cipher_ctx, cipher_info ) ) != 0 )
-        return( ret );
-
-    if( ( ret = mbedtls_cipher_setkey( &ctx->cipher_ctx, key, keybits,
-                               MBEDTLS_ENCRYPT ) ) != 0 )
-    {
-        return( ret );
+    cipher_info = mbedtls_cipher_info_from_values(cipher, keybits,
+                                                  MBEDTLS_MODE_ECB);
+    if (cipher_info == NULL) {
+        return MBEDTLS_ERR_CCM_BAD_INPUT;
     }
 
-    return( 0 );
+    if (cipher_info->block_size != 16) {
+        return MBEDTLS_ERR_CCM_BAD_INPUT;
+    }
+
+    mbedtls_cipher_free(&ctx->cipher_ctx);
+
+    if ((ret = mbedtls_cipher_setup(&ctx->cipher_ctx, cipher_info)) != 0) {
+        return ret;
+    }
+
+    if ((ret = mbedtls_cipher_setkey(&ctx->cipher_ctx, key, keybits,
+                                     MBEDTLS_ENCRYPT)) != 0) {
+        return ret;
+    }
+
+    return 0;
 }
 
 /*
  * Free context
  */
-void mbedtls_ccm_free( mbedtls_ccm_context *ctx )
+void mbedtls_ccm_free(mbedtls_ccm_context *ctx)
 {
-    if( ctx == NULL )
+    if (ctx == NULL) {
         return;
-    mbedtls_cipher_free( &ctx->cipher_ctx );
-    mbedtls_platform_zeroize( ctx, sizeof( mbedtls_ccm_context ) );
+    }
+    mbedtls_cipher_free(&ctx->cipher_ctx);
+    mbedtls_platform_zeroize(ctx, sizeof(mbedtls_ccm_context));
 }
 
-/*
- * Macros for common operations.
- * Results in smaller compiled code than static inline functions.
- */
-
-/*
- * Update the CBC-MAC state in y using a block in b
- * (Always using b as the source helps the compiler optimise a bit better.)
- */
-#define UPDATE_CBC_MAC                                                      \
-    for( i = 0; i < 16; i++ )                                               \
-        y[i] ^= b[i];                                                       \
-                                                                            \
-    if( ( ret = mbedtls_cipher_update( &ctx->cipher_ctx, y, 16, y, &olen ) ) != 0 ) \
-        return( ret );
+#define CCM_STATE__CLEAR                0
+#define CCM_STATE__STARTED              (1 << 0)
+#define CCM_STATE__LENGTHS_SET          (1 << 1)
+#define CCM_STATE__AUTH_DATA_STARTED    (1 << 2)
+#define CCM_STATE__AUTH_DATA_FINISHED   (1 << 3)
+#define CCM_STATE__ERROR                (1 << 4)
 
 /*
  * Encrypt or decrypt a partial block with CTR
- * Warning: using b for temporary storage! src and dst must not be b!
- * This avoids allocating one more 16 bytes buffer while allowing src == dst.
  */
-#define CTR_CRYPT( dst, src, len  )                                            \
-    do                                                                  \
-    {                                                                   \
-        if( ( ret = mbedtls_cipher_update( &ctx->cipher_ctx, ctr,       \
-                                           16, b, &olen ) ) != 0 )      \
-        {                                                               \
-            return( ret );                                              \
-        }                                                               \
-                                                                        \
-        for( i = 0; i < (len); i++ )                                    \
-            (dst)[i] = (src)[i] ^ b[i];                                 \
-    } while( 0 )
+static int mbedtls_ccm_crypt(mbedtls_ccm_context *ctx,
+                             size_t offset, size_t use_len,
+                             const unsigned char *input,
+                             unsigned char *output)
+{
+    size_t olen = 0;
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    unsigned char tmp_buf[16] = { 0 };
 
-/*
- * Authenticated encryption or decryption
- */
-static int ccm_auth_crypt( mbedtls_ccm_context *ctx, int mode, size_t length,
-                           const unsigned char *iv, size_t iv_len,
-                           const unsigned char *add, size_t add_len,
-                           const unsigned char *input, unsigned char *output,
-                           unsigned char *tag, size_t tag_len )
+    if ((ret = mbedtls_cipher_update(&ctx->cipher_ctx, ctx->ctr, 16, tmp_buf,
+                                     &olen)) != 0) {
+        ctx->state |= CCM_STATE__ERROR;
+        mbedtls_platform_zeroize(tmp_buf, sizeof(tmp_buf));
+        return ret;
+    }
+
+    mbedtls_xor(output, input, tmp_buf + offset, use_len);
+
+    mbedtls_platform_zeroize(tmp_buf, sizeof(tmp_buf));
+    return ret;
+}
+
+static void mbedtls_ccm_clear_state(mbedtls_ccm_context *ctx)
+{
+    ctx->state = CCM_STATE__CLEAR;
+    memset(ctx->y, 0, 16);
+    memset(ctx->ctr, 0, 16);
+}
+
+static int ccm_calculate_first_block_if_ready(mbedtls_ccm_context *ctx)
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
     unsigned char i;
-    unsigned char q;
     size_t len_left, olen;
-    unsigned char b[16];
-    unsigned char y[16];
-    unsigned char ctr[16];
-    const unsigned char *src;
-    unsigned char *dst;
 
-    /*
-     * Check length requirements: SP800-38C A.1
-     * Additional requirement: a < 2^16 - 2^8 to simplify the code.
-     * 'length' checked later (when writing it to the first block)
-     *
-     * Also, loosen the requirements to enable support for CCM* (IEEE 802.15.4).
+    /* length calculation can be done only after both
+     * mbedtls_ccm_starts() and mbedtls_ccm_set_lengths() have been executed
      */
-    if( tag_len == 2 || tag_len > 16 || tag_len % 2 != 0 )
-        return( MBEDTLS_ERR_CCM_BAD_INPUT );
+    if (!(ctx->state & CCM_STATE__STARTED) || !(ctx->state & CCM_STATE__LENGTHS_SET)) {
+        return 0;
+    }
 
-    /* Also implies q is within bounds */
-    if( iv_len < 7 || iv_len > 13 )
-        return( MBEDTLS_ERR_CCM_BAD_INPUT );
-
-    if( add_len >= 0xFF00 )
-        return( MBEDTLS_ERR_CCM_BAD_INPUT );
-
-    q = 16 - 1 - (unsigned char) iv_len;
+    /* CCM expects non-empty tag.
+     * CCM* allows empty tag. For CCM* without tag, ignore plaintext length.
+     */
+    if (ctx->tag_len == 0) {
+        if (ctx->mode == MBEDTLS_CCM_STAR_ENCRYPT || ctx->mode == MBEDTLS_CCM_STAR_DECRYPT) {
+            ctx->plaintext_len = 0;
+        } else {
+            return MBEDTLS_ERR_CCM_BAD_INPUT;
+        }
+    }
 
     /*
-     * First block B_0:
+     * First block:
      * 0        .. 0        flags
-     * 1        .. iv_len   nonce (aka iv)
+     * 1        .. iv_len   nonce (aka iv)  - set by: mbedtls_ccm_starts()
      * iv_len+1 .. 15       length
      *
      * With flags as (bits):
@@ -192,57 +175,41 @@
      * 5 .. 3   (t - 2) / 2
      * 2 .. 0   q - 1
      */
-    b[0] = 0;
-    b[0] |= ( add_len > 0 ) << 6;
-    b[0] |= ( ( tag_len - 2 ) / 2 ) << 3;
-    b[0] |= q - 1;
+    ctx->y[0] |= (ctx->add_len > 0) << 6;
+    ctx->y[0] |= ((ctx->tag_len - 2) / 2) << 3;
+    ctx->y[0] |= ctx->q - 1;
 
-    memcpy( b + 1, iv, iv_len );
-
-    for( i = 0, len_left = length; i < q; i++, len_left >>= 8 )
-        b[15-i] = MBEDTLS_BYTE_0( len_left );
-
-    if( len_left > 0 )
-        return( MBEDTLS_ERR_CCM_BAD_INPUT );
-
-
-    /* Start CBC-MAC with first block */
-    memset( y, 0, 16 );
-    UPDATE_CBC_MAC;
-
-    /*
-     * If there is additional data, update CBC-MAC with
-     * add_len, add, 0 (padding to a block boundary)
-     */
-    if( add_len > 0 )
-    {
-        size_t use_len;
-        len_left = add_len;
-        src = add;
-
-        memset( b, 0, 16 );
-        MBEDTLS_PUT_UINT16_BE( add_len, b, 0 );
-
-        use_len = len_left < 16 - 2 ? len_left : 16 - 2;
-        memcpy( b + 2, src, use_len );
-        len_left -= use_len;
-        src += use_len;
-
-        UPDATE_CBC_MAC;
-
-        while( len_left > 0 )
-        {
-            use_len = len_left > 16 ? 16 : len_left;
-
-            memset( b, 0, 16 );
-            memcpy( b, src, use_len );
-            UPDATE_CBC_MAC;
-
-            len_left -= use_len;
-            src += use_len;
-        }
+    for (i = 0, len_left = ctx->plaintext_len; i < ctx->q; i++, len_left >>= 8) {
+        ctx->y[15-i] = MBEDTLS_BYTE_0(len_left);
     }
 
+    if (len_left > 0) {
+        ctx->state |= CCM_STATE__ERROR;
+        return MBEDTLS_ERR_CCM_BAD_INPUT;
+    }
+
+    /* Start CBC-MAC with first block*/
+    if ((ret = mbedtls_cipher_update(&ctx->cipher_ctx, ctx->y, 16, ctx->y, &olen)) != 0) {
+        ctx->state |= CCM_STATE__ERROR;
+        return ret;
+    }
+
+    return 0;
+}
+
+int mbedtls_ccm_starts(mbedtls_ccm_context *ctx,
+                       int mode,
+                       const unsigned char *iv,
+                       size_t iv_len)
+{
+    /* Also implies q is within bounds */
+    if (iv_len < 7 || iv_len > 13) {
+        return MBEDTLS_ERR_CCM_BAD_INPUT;
+    }
+
+    ctx->mode = mode;
+    ctx->q = 16 - 1 - (unsigned char) iv_len;
+
     /*
      * Prepare counter block for encryption:
      * 0        .. 0        flags
@@ -253,163 +220,377 @@
      * 7 .. 3   0
      * 2 .. 0   q - 1
      */
-    ctr[0] = q - 1;
-    memcpy( ctr + 1, iv, iv_len );
-    memset( ctr + 1 + iv_len, 0, q );
-    ctr[15] = 1;
+    memset(ctx->ctr, 0, 16);
+    ctx->ctr[0] = ctx->q - 1;
+    memcpy(ctx->ctr + 1, iv, iv_len);
+    memset(ctx->ctr + 1 + iv_len, 0, ctx->q);
+    ctx->ctr[15] = 1;
 
     /*
-     * Authenticate and {en,de}crypt the message.
-     *
-     * The only difference between encryption and decryption is
-     * the respective order of authentication and {en,de}cryption.
+     * See ccm_calculate_first_block_if_ready() for block layout description
      */
-    len_left = length;
-    src = input;
-    dst = output;
+    memcpy(ctx->y + 1, iv, iv_len);
 
-    while( len_left > 0 )
-    {
-        size_t use_len = len_left > 16 ? 16 : len_left;
+    ctx->state |= CCM_STATE__STARTED;
+    return ccm_calculate_first_block_if_ready(ctx);
+}
 
-        if( mode == CCM_ENCRYPT )
-        {
-            memset( b, 0, 16 );
-            memcpy( b, src, use_len );
-            UPDATE_CBC_MAC;
+int mbedtls_ccm_set_lengths(mbedtls_ccm_context *ctx,
+                            size_t total_ad_len,
+                            size_t plaintext_len,
+                            size_t tag_len)
+{
+    /*
+     * Check length requirements: SP800-38C A.1
+     * Additional requirement: a < 2^16 - 2^8 to simplify the code.
+     * 'length' checked later (when writing it to the first block)
+     *
+     * Also, loosen the requirements to enable support for CCM* (IEEE 802.15.4).
+     */
+    if (tag_len == 2 || tag_len > 16 || tag_len % 2 != 0) {
+        return MBEDTLS_ERR_CCM_BAD_INPUT;
+    }
+
+    if (total_ad_len >= 0xFF00) {
+        return MBEDTLS_ERR_CCM_BAD_INPUT;
+    }
+
+    ctx->plaintext_len = plaintext_len;
+    ctx->add_len = total_ad_len;
+    ctx->tag_len = tag_len;
+    ctx->processed = 0;
+
+    ctx->state |= CCM_STATE__LENGTHS_SET;
+    return ccm_calculate_first_block_if_ready(ctx);
+}
+
+int mbedtls_ccm_update_ad(mbedtls_ccm_context *ctx,
+                          const unsigned char *add,
+                          size_t add_len)
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    size_t olen, use_len, offset;
+
+    if (ctx->state & CCM_STATE__ERROR) {
+        return MBEDTLS_ERR_CCM_BAD_INPUT;
+    }
+
+    if (add_len > 0) {
+        if (ctx->state & CCM_STATE__AUTH_DATA_FINISHED) {
+            return MBEDTLS_ERR_CCM_BAD_INPUT;
         }
 
-        CTR_CRYPT( dst, src, use_len );
+        if (!(ctx->state & CCM_STATE__AUTH_DATA_STARTED)) {
+            if (add_len > ctx->add_len) {
+                return MBEDTLS_ERR_CCM_BAD_INPUT;
+            }
 
-        if( mode == CCM_DECRYPT )
-        {
-            memset( b, 0, 16 );
-            memcpy( b, dst, use_len );
-            UPDATE_CBC_MAC;
+            ctx->y[0] ^= (unsigned char) ((ctx->add_len >> 8) & 0xFF);
+            ctx->y[1] ^= (unsigned char) ((ctx->add_len) & 0xFF);
+
+            ctx->state |= CCM_STATE__AUTH_DATA_STARTED;
+        } else if (ctx->processed + add_len > ctx->add_len) {
+            return MBEDTLS_ERR_CCM_BAD_INPUT;
         }
 
-        dst += use_len;
-        src += use_len;
-        len_left -= use_len;
+        while (add_len > 0) {
+            offset = (ctx->processed + 2) % 16; /* account for y[0] and y[1]
+                                                 * holding total auth data length */
+            use_len = 16 - offset;
 
-        /*
-         * Increment counter.
-         * No need to check for overflow thanks to the length check above.
-         */
-        for( i = 0; i < q; i++ )
-            if( ++ctr[15-i] != 0 )
-                break;
+            if (use_len > add_len) {
+                use_len = add_len;
+            }
+
+            mbedtls_xor(ctx->y + offset, ctx->y + offset, add, use_len);
+
+            ctx->processed += use_len;
+            add_len -= use_len;
+            add += use_len;
+
+            if (use_len + offset == 16 || ctx->processed == ctx->add_len) {
+                if ((ret =
+                         mbedtls_cipher_update(&ctx->cipher_ctx, ctx->y, 16, ctx->y, &olen)) != 0) {
+                    ctx->state |= CCM_STATE__ERROR;
+                    return ret;
+                }
+            }
+        }
+
+        if (ctx->processed == ctx->add_len) {
+            ctx->state |= CCM_STATE__AUTH_DATA_FINISHED;
+            ctx->processed = 0; // prepare for mbedtls_ccm_update()
+        }
+    }
+
+    return 0;
+}
+
+int mbedtls_ccm_update(mbedtls_ccm_context *ctx,
+                       const unsigned char *input, size_t input_len,
+                       unsigned char *output, size_t output_size,
+                       size_t *output_len)
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    unsigned char i;
+    size_t use_len, offset, olen;
+
+    unsigned char local_output[16];
+
+    if (ctx->state & CCM_STATE__ERROR) {
+        return MBEDTLS_ERR_CCM_BAD_INPUT;
+    }
+
+    /* Check against plaintext length only if performing operation with
+     * authentication
+     */
+    if (ctx->tag_len != 0 && ctx->processed + input_len > ctx->plaintext_len) {
+        return MBEDTLS_ERR_CCM_BAD_INPUT;
+    }
+
+    if (output_size < input_len) {
+        return MBEDTLS_ERR_CCM_BAD_INPUT;
+    }
+    *output_len = input_len;
+
+    ret = 0;
+
+    while (input_len > 0) {
+        offset = ctx->processed % 16;
+
+        use_len = 16 - offset;
+
+        if (use_len > input_len) {
+            use_len = input_len;
+        }
+
+        ctx->processed += use_len;
+
+        if (ctx->mode == MBEDTLS_CCM_ENCRYPT || \
+            ctx->mode == MBEDTLS_CCM_STAR_ENCRYPT) {
+            mbedtls_xor(ctx->y + offset, ctx->y + offset, input, use_len);
+
+            if (use_len + offset == 16 || ctx->processed == ctx->plaintext_len) {
+                if ((ret =
+                         mbedtls_cipher_update(&ctx->cipher_ctx, ctx->y, 16, ctx->y, &olen)) != 0) {
+                    ctx->state |= CCM_STATE__ERROR;
+                    goto exit;
+                }
+            }
+
+            ret = mbedtls_ccm_crypt(ctx, offset, use_len, input, output);
+            if (ret != 0) {
+                goto exit;
+            }
+        }
+
+        if (ctx->mode == MBEDTLS_CCM_DECRYPT || \
+            ctx->mode == MBEDTLS_CCM_STAR_DECRYPT) {
+            /* Since output may be in shared memory, we cannot be sure that
+             * it will contain what we wrote to it. Therefore, we should avoid using
+             * it as input to any operations.
+             * Write decrypted data to local_output to avoid using output variable as
+             * input in the XOR operation for Y.
+             */
+            ret = mbedtls_ccm_crypt(ctx, offset, use_len, input, local_output);
+            if (ret != 0) {
+                goto exit;
+            }
+
+            mbedtls_xor(ctx->y + offset, ctx->y + offset, local_output, use_len);
+
+            memcpy(output, local_output, use_len);
+            mbedtls_platform_zeroize(local_output, 16);
+
+            if (use_len + offset == 16 || ctx->processed == ctx->plaintext_len) {
+                if ((ret =
+                         mbedtls_cipher_update(&ctx->cipher_ctx, ctx->y, 16, ctx->y, &olen)) != 0) {
+                    ctx->state |= CCM_STATE__ERROR;
+                    goto exit;
+                }
+            }
+        }
+
+        if (use_len + offset == 16 || ctx->processed == ctx->plaintext_len) {
+            for (i = 0; i < ctx->q; i++) {
+                if (++(ctx->ctr)[15-i] != 0) {
+                    break;
+                }
+            }
+        }
+
+        input_len -= use_len;
+        input += use_len;
+        output += use_len;
+    }
+
+exit:
+    mbedtls_platform_zeroize(local_output, 16);
+
+    return ret;
+}
+
+int mbedtls_ccm_finish(mbedtls_ccm_context *ctx,
+                       unsigned char *tag, size_t tag_len)
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    unsigned char i;
+
+    if (ctx->state & CCM_STATE__ERROR) {
+        return MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    }
+
+    if (ctx->add_len > 0 && !(ctx->state & CCM_STATE__AUTH_DATA_FINISHED)) {
+        return MBEDTLS_ERR_CCM_BAD_INPUT;
+    }
+
+    if (ctx->plaintext_len > 0 && ctx->processed != ctx->plaintext_len) {
+        return MBEDTLS_ERR_CCM_BAD_INPUT;
     }
 
     /*
      * Authentication: reset counter and crypt/mask internal tag
      */
-    for( i = 0; i < q; i++ )
-        ctr[15-i] = 0;
+    for (i = 0; i < ctx->q; i++) {
+        ctx->ctr[15-i] = 0;
+    }
 
-    CTR_CRYPT( y, y, 16 );
-    memcpy( tag, y, tag_len );
+    ret = mbedtls_ccm_crypt(ctx, 0, 16, ctx->y, ctx->y);
+    if (ret != 0) {
+        return ret;
+    }
+    if (tag != NULL) {
+        memcpy(tag, ctx->y, tag_len);
+    }
+    mbedtls_ccm_clear_state(ctx);
 
-    return( 0 );
+    return 0;
+}
+
+/*
+ * Authenticated encryption or decryption
+ */
+static int ccm_auth_crypt(mbedtls_ccm_context *ctx, int mode, size_t length,
+                          const unsigned char *iv, size_t iv_len,
+                          const unsigned char *add, size_t add_len,
+                          const unsigned char *input, unsigned char *output,
+                          unsigned char *tag, size_t tag_len)
+{
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    size_t olen;
+
+    if ((ret = mbedtls_ccm_starts(ctx, mode, iv, iv_len)) != 0) {
+        return ret;
+    }
+
+    if ((ret = mbedtls_ccm_set_lengths(ctx, add_len, length, tag_len)) != 0) {
+        return ret;
+    }
+
+    if ((ret = mbedtls_ccm_update_ad(ctx, add, add_len)) != 0) {
+        return ret;
+    }
+
+    if ((ret = mbedtls_ccm_update(ctx, input, length,
+                                  output, length, &olen)) != 0) {
+        return ret;
+    }
+
+    if ((ret = mbedtls_ccm_finish(ctx, tag, tag_len)) != 0) {
+        return ret;
+    }
+
+    return 0;
 }
 
 /*
  * Authenticated encryption
  */
-int mbedtls_ccm_star_encrypt_and_tag( mbedtls_ccm_context *ctx, size_t length,
-                         const unsigned char *iv, size_t iv_len,
-                         const unsigned char *add, size_t add_len,
-                         const unsigned char *input, unsigned char *output,
-                         unsigned char *tag, size_t tag_len )
+int mbedtls_ccm_star_encrypt_and_tag(mbedtls_ccm_context *ctx, size_t length,
+                                     const unsigned char *iv, size_t iv_len,
+                                     const unsigned char *add, size_t add_len,
+                                     const unsigned char *input, unsigned char *output,
+                                     unsigned char *tag, size_t tag_len)
 {
-    CCM_VALIDATE_RET( ctx != NULL );
-    CCM_VALIDATE_RET( iv != NULL );
-    CCM_VALIDATE_RET( add_len == 0 || add != NULL );
-    CCM_VALIDATE_RET( length == 0 || input != NULL );
-    CCM_VALIDATE_RET( length == 0 || output != NULL );
-    CCM_VALIDATE_RET( tag_len == 0 || tag != NULL );
-    return( ccm_auth_crypt( ctx, CCM_ENCRYPT, length, iv, iv_len,
-                            add, add_len, input, output, tag, tag_len ) );
+    return ccm_auth_crypt(ctx, MBEDTLS_CCM_STAR_ENCRYPT, length, iv, iv_len,
+                          add, add_len, input, output, tag, tag_len);
 }
 
-int mbedtls_ccm_encrypt_and_tag( mbedtls_ccm_context *ctx, size_t length,
-                         const unsigned char *iv, size_t iv_len,
-                         const unsigned char *add, size_t add_len,
-                         const unsigned char *input, unsigned char *output,
-                         unsigned char *tag, size_t tag_len )
+int mbedtls_ccm_encrypt_and_tag(mbedtls_ccm_context *ctx, size_t length,
+                                const unsigned char *iv, size_t iv_len,
+                                const unsigned char *add, size_t add_len,
+                                const unsigned char *input, unsigned char *output,
+                                unsigned char *tag, size_t tag_len)
 {
-    CCM_VALIDATE_RET( ctx != NULL );
-    CCM_VALIDATE_RET( iv != NULL );
-    CCM_VALIDATE_RET( add_len == 0 || add != NULL );
-    CCM_VALIDATE_RET( length == 0 || input != NULL );
-    CCM_VALIDATE_RET( length == 0 || output != NULL );
-    CCM_VALIDATE_RET( tag_len == 0 || tag != NULL );
-    if( tag_len == 0 )
-        return( MBEDTLS_ERR_CCM_BAD_INPUT );
-
-    return( mbedtls_ccm_star_encrypt_and_tag( ctx, length, iv, iv_len, add,
-                add_len, input, output, tag, tag_len ) );
+    return ccm_auth_crypt(ctx, MBEDTLS_CCM_ENCRYPT, length, iv, iv_len,
+                          add, add_len, input, output, tag, tag_len);
 }
 
 /*
  * Authenticated decryption
  */
-int mbedtls_ccm_star_auth_decrypt( mbedtls_ccm_context *ctx, size_t length,
-                      const unsigned char *iv, size_t iv_len,
-                      const unsigned char *add, size_t add_len,
-                      const unsigned char *input, unsigned char *output,
-                      const unsigned char *tag, size_t tag_len )
+static int mbedtls_ccm_compare_tags(const unsigned char *tag1,
+                                    const unsigned char *tag2,
+                                    size_t tag_len)
 {
-    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    unsigned char check_tag[16];
     unsigned char i;
     int diff;
 
-    CCM_VALIDATE_RET( ctx != NULL );
-    CCM_VALIDATE_RET( iv != NULL );
-    CCM_VALIDATE_RET( add_len == 0 || add != NULL );
-    CCM_VALIDATE_RET( length == 0 || input != NULL );
-    CCM_VALIDATE_RET( length == 0 || output != NULL );
-    CCM_VALIDATE_RET( tag_len == 0 || tag != NULL );
-
-    if( ( ret = ccm_auth_crypt( ctx, CCM_DECRYPT, length,
-                                iv, iv_len, add, add_len,
-                                input, output, check_tag, tag_len ) ) != 0 )
-    {
-        return( ret );
-    }
-
     /* Check tag in "constant-time" */
-    for( diff = 0, i = 0; i < tag_len; i++ )
-        diff |= tag[i] ^ check_tag[i];
-
-    if( diff != 0 )
-    {
-        mbedtls_platform_zeroize( output, length );
-        return( MBEDTLS_ERR_CCM_AUTH_FAILED );
+    for (diff = 0, i = 0; i < tag_len; i++) {
+        diff |= tag1[i] ^ tag2[i];
     }
 
-    return( 0 );
+    if (diff != 0) {
+        return MBEDTLS_ERR_CCM_AUTH_FAILED;
+    }
+
+    return 0;
 }
 
-int mbedtls_ccm_auth_decrypt( mbedtls_ccm_context *ctx, size_t length,
-                      const unsigned char *iv, size_t iv_len,
-                      const unsigned char *add, size_t add_len,
-                      const unsigned char *input, unsigned char *output,
-                      const unsigned char *tag, size_t tag_len )
+static int ccm_auth_decrypt(mbedtls_ccm_context *ctx, int mode, size_t length,
+                            const unsigned char *iv, size_t iv_len,
+                            const unsigned char *add, size_t add_len,
+                            const unsigned char *input, unsigned char *output,
+                            const unsigned char *tag, size_t tag_len)
 {
-    CCM_VALIDATE_RET( ctx != NULL );
-    CCM_VALIDATE_RET( iv != NULL );
-    CCM_VALIDATE_RET( add_len == 0 || add != NULL );
-    CCM_VALIDATE_RET( length == 0 || input != NULL );
-    CCM_VALIDATE_RET( length == 0 || output != NULL );
-    CCM_VALIDATE_RET( tag_len == 0 || tag != NULL );
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+    unsigned char check_tag[16];
 
-    if( tag_len == 0 )
-        return( MBEDTLS_ERR_CCM_BAD_INPUT );
+    if ((ret = ccm_auth_crypt(ctx, mode, length,
+                              iv, iv_len, add, add_len,
+                              input, output, check_tag, tag_len)) != 0) {
+        return ret;
+    }
 
-    return( mbedtls_ccm_star_auth_decrypt( ctx, length, iv, iv_len, add,
-                add_len, input, output, tag, tag_len ) );
+    if ((ret = mbedtls_ccm_compare_tags(tag, check_tag, tag_len)) != 0) {
+        mbedtls_platform_zeroize(output, length);
+        return ret;
+    }
+
+    return 0;
+}
+
+int mbedtls_ccm_star_auth_decrypt(mbedtls_ccm_context *ctx, size_t length,
+                                  const unsigned char *iv, size_t iv_len,
+                                  const unsigned char *add, size_t add_len,
+                                  const unsigned char *input, unsigned char *output,
+                                  const unsigned char *tag, size_t tag_len)
+{
+    return ccm_auth_decrypt(ctx, MBEDTLS_CCM_STAR_DECRYPT, length,
+                            iv, iv_len, add, add_len,
+                            input, output, tag, tag_len);
+}
+
+int mbedtls_ccm_auth_decrypt(mbedtls_ccm_context *ctx, size_t length,
+                             const unsigned char *iv, size_t iv_len,
+                             const unsigned char *add, size_t add_len,
+                             const unsigned char *input, unsigned char *output,
+                             const unsigned char *tag, size_t tag_len)
+{
+    return ccm_auth_decrypt(ctx, MBEDTLS_CCM_DECRYPT, length,
+                            iv, iv_len, add, add_len,
+                            input, output, tag, tag_len);
 }
 #endif /* !MBEDTLS_CCM_ALT */
 
@@ -446,7 +627,7 @@
     0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37,
 };
 
-static const size_t iv_len_test_data [NB_TESTS] = { 7, 8,  12 };
+static const size_t iv_len_test_data[NB_TESTS] = { 7, 8,  12 };
 static const size_t add_len_test_data[NB_TESTS] = { 8, 16, 20 };
 static const size_t msg_len_test_data[NB_TESTS] = { 4, 16, 24 };
 static const size_t tag_len_test_data[NB_TESTS] = { 4, 6,  8  };
@@ -462,7 +643,7 @@
         0x48, 0x43, 0x92, 0xfb, 0xc1, 0xb0, 0x99, 0x51 }
 };
 
-int mbedtls_ccm_self_test( int verbose )
+int mbedtls_ccm_self_test(int verbose)
 {
     mbedtls_ccm_context ctx;
     /*
@@ -475,70 +656,72 @@
     size_t i;
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
 
-    mbedtls_ccm_init( &ctx );
+    mbedtls_ccm_init(&ctx);
 
-    if( mbedtls_ccm_setkey( &ctx, MBEDTLS_CIPHER_ID_AES, key_test_data,
-                            8 * sizeof key_test_data ) != 0 )
-    {
-        if( verbose != 0 )
-            mbedtls_printf( "  CCM: setup failed" );
-
-        return( 1 );
-    }
-
-    for( i = 0; i < NB_TESTS; i++ )
-    {
-        if( verbose != 0 )
-            mbedtls_printf( "  CCM-AES #%u: ", (unsigned int) i + 1 );
-
-        memset( plaintext, 0, CCM_SELFTEST_PT_MAX_LEN );
-        memset( ciphertext, 0, CCM_SELFTEST_CT_MAX_LEN );
-        memcpy( plaintext, msg_test_data, msg_len_test_data[i] );
-
-        ret = mbedtls_ccm_encrypt_and_tag( &ctx, msg_len_test_data[i],
-                                           iv_test_data, iv_len_test_data[i],
-                                           ad_test_data, add_len_test_data[i],
-                                           plaintext, ciphertext,
-                                           ciphertext + msg_len_test_data[i],
-                                           tag_len_test_data[i] );
-
-        if( ret != 0 ||
-            memcmp( ciphertext, res_test_data[i],
-                    msg_len_test_data[i] + tag_len_test_data[i] ) != 0 )
-        {
-            if( verbose != 0 )
-                mbedtls_printf( "failed\n" );
-
-            return( 1 );
-        }
-        memset( plaintext, 0, CCM_SELFTEST_PT_MAX_LEN );
-
-        ret = mbedtls_ccm_auth_decrypt( &ctx, msg_len_test_data[i],
-                                        iv_test_data, iv_len_test_data[i],
-                                        ad_test_data, add_len_test_data[i],
-                                        ciphertext, plaintext,
-                                        ciphertext + msg_len_test_data[i],
-                                        tag_len_test_data[i] );
-
-        if( ret != 0 ||
-            memcmp( plaintext, msg_test_data, msg_len_test_data[i] ) != 0 )
-        {
-            if( verbose != 0 )
-                mbedtls_printf( "failed\n" );
-
-            return( 1 );
+    if (mbedtls_ccm_setkey(&ctx, MBEDTLS_CIPHER_ID_AES, key_test_data,
+                           8 * sizeof(key_test_data)) != 0) {
+        if (verbose != 0) {
+            mbedtls_printf("  CCM: setup failed");
         }
 
-        if( verbose != 0 )
-            mbedtls_printf( "passed\n" );
+        return 1;
     }
 
-    mbedtls_ccm_free( &ctx );
+    for (i = 0; i < NB_TESTS; i++) {
+        if (verbose != 0) {
+            mbedtls_printf("  CCM-AES #%u: ", (unsigned int) i + 1);
+        }
 
-    if( verbose != 0 )
-        mbedtls_printf( "\n" );
+        memset(plaintext, 0, CCM_SELFTEST_PT_MAX_LEN);
+        memset(ciphertext, 0, CCM_SELFTEST_CT_MAX_LEN);
+        memcpy(plaintext, msg_test_data, msg_len_test_data[i]);
 
-    return( 0 );
+        ret = mbedtls_ccm_encrypt_and_tag(&ctx, msg_len_test_data[i],
+                                          iv_test_data, iv_len_test_data[i],
+                                          ad_test_data, add_len_test_data[i],
+                                          plaintext, ciphertext,
+                                          ciphertext + msg_len_test_data[i],
+                                          tag_len_test_data[i]);
+
+        if (ret != 0 ||
+            memcmp(ciphertext, res_test_data[i],
+                   msg_len_test_data[i] + tag_len_test_data[i]) != 0) {
+            if (verbose != 0) {
+                mbedtls_printf("failed\n");
+            }
+
+            return 1;
+        }
+        memset(plaintext, 0, CCM_SELFTEST_PT_MAX_LEN);
+
+        ret = mbedtls_ccm_auth_decrypt(&ctx, msg_len_test_data[i],
+                                       iv_test_data, iv_len_test_data[i],
+                                       ad_test_data, add_len_test_data[i],
+                                       ciphertext, plaintext,
+                                       ciphertext + msg_len_test_data[i],
+                                       tag_len_test_data[i]);
+
+        if (ret != 0 ||
+            memcmp(plaintext, msg_test_data, msg_len_test_data[i]) != 0) {
+            if (verbose != 0) {
+                mbedtls_printf("failed\n");
+            }
+
+            return 1;
+        }
+
+        if (verbose != 0) {
+            mbedtls_printf("passed\n");
+        }
+    }
+
+    mbedtls_ccm_free(&ctx);
+
+    if (verbose != 0) {
+        mbedtls_printf("\n");
+    }
+
+    return 0;
 }
 
 #endif /* MBEDTLS_SELF_TEST && MBEDTLS_AES_C */