aes: xts: Add new context structure

Add a new context structure for XTS. Adjust the API for XTS to use the new
context structure, including tests suites and the benchmark program. Update
Doxgen documentation accordingly.
diff --git a/include/mbedtls/aes.h b/include/mbedtls/aes.h
index 33667d6..e166e9c 100644
--- a/include/mbedtls/aes.h
+++ b/include/mbedtls/aes.h
@@ -89,6 +89,19 @@
 }
 mbedtls_aes_context;
 
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+/**
+ * \brief The AES XTS context-type definition.
+ */
+typedef struct
+{
+    mbedtls_aes_context crypt; /*!< The AES context to use for AES block
+                                        encryption or decryption. */
+    mbedtls_aes_context tweak; /*!< The AES context used for tweak
+                                        computation. */
+} mbedtls_aes_xts_context;
+#endif /* MBEDTLS_CIPHER_MODE_XTS */
+
 #else  /* MBEDTLS_AES_ALT */
 #include "aes_alt.h"
 #endif /* MBEDTLS_AES_ALT */
@@ -110,6 +123,25 @@
  */
 void mbedtls_aes_free( mbedtls_aes_context *ctx );
 
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+/**
+ * \brief          This function initializes the specified AES XTS context.
+ *
+ *                 It must be the first API called before using
+ *                 the context.
+ *
+ * \param ctx      The AES XTS context to initialize.
+ */
+void mbedtls_aes_xts_init( mbedtls_aes_xts_context *ctx );
+
+/**
+ * \brief          This function releases and clears the specified AES XTS context.
+ *
+ * \param ctx      The AES XTS context to clear.
+ */
+void mbedtls_aes_xts_free( mbedtls_aes_xts_context *ctx );
+#endif /* MBEDTLS_CIPHER_MODE_XTS */
+
 /**
  * \brief          This function sets the encryption key.
  *
@@ -142,6 +174,44 @@
 int mbedtls_aes_setkey_dec( mbedtls_aes_context *ctx, const unsigned char *key,
                     unsigned int keybits );
 
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+/**
+ * \brief          This function prepares an XTS context for encryption and
+ *                 sets the encryption key.
+ *
+ * \param ctx      The AES XTS context to which the key should be bound.
+ * \param key      The encryption key. This is comprised of the XTS key1
+ *                 concatenated with the XTS key2.
+ * \param keybits  The size of \p key passed in bits. Valid options are:
+ *                 <ul><li>256 bits (each of key1 and key2 is a 128-bit key)</li>
+ *                 <li>512 bits (each of key1 and key2 is a 256-bit key)</li></ul>
+ *
+ * \return         \c 0 on success.
+ * \return         #MBEDTLS_ERR_AES_INVALID_KEY_LENGTH on failure.
+ */
+int mbedtls_aes_xts_setkey_enc( mbedtls_aes_xts_context *ctx,
+                                const unsigned char *key,
+                                unsigned int keybits );
+
+/**
+ * \brief          This function prepares an XTS context for decryption and
+ *                 sets the decryption key.
+ *
+ * \param ctx      The AES XTS context to which the key should be bound.
+ * \param key      The decryption key. This is comprised of the XTS key1
+ *                 concatenated with the XTS key2.
+ * \param keybits  The size of \p key passed in bits. Valid options are:
+ *                 <ul><li>256 bits (each of key1 and key2 is a 128-bit key)</li>
+ *                 <li>512 bits (each of key1 and key2 is a 256-bit key)</li></ul>
+ *
+ * \return         \c 0 on success.
+ * \return         #MBEDTLS_ERR_AES_INVALID_KEY_LENGTH on failure.
+ */
+int mbedtls_aes_xts_setkey_dec( mbedtls_aes_xts_context *ctx,
+                                const unsigned char *key,
+                                unsigned int keybits );
+#endif /* MBEDTLS_CIPHER_MODE_XTS */
+
 /**
  * \brief          This function performs an AES single-block encryption or
  *                 decryption operation.
@@ -215,30 +285,38 @@
 
 #if defined(MBEDTLS_CIPHER_MODE_XTS)
 /**
- * \brief           AES-XTS buffer encryption/decryption
- *                  Length should be greater or equal than the block size (16
- *                  bytes, 128 bits)
+ * \brief      This function performs an AES-XTS encryption or decryption
+ *             operation for an entire XTS data unit.
  *
- * Warning: The bits_length parameter must given be in bits, not bytes like the
- * other modes
+ *             AES-XTS encrypts or decrypts blocks based on their location as
+ *             defined by a data unit number. The data unit number must be
+ *             provided by \p iv.
  *
- * \param crypt_ctx AES context for encrypting data
- * \param tweak_ctx AES context for xor-ing with data
- * \param mode      MBEDTLS_AES_ENCRYPT or MBEDTLS_AES_DECRYPT
- * \param bits_length length of the input data (in bits)
- * \param iv        initialization vector
- * \param input     buffer holding the input data
- * \param output    buffer holding the output data
+ * \param ctx          The AES XTS context to use for AES XTS operations.
+ * \param mode         The AES operation: #MBEDTLS_AES_ENCRYPT or
+ *                     #MBEDTLS_AES_DECRYPT.
+ * \param bits_length  The length of a data unit in bits.
+ * \param iv           The address of the data unit encoded as an array of 16
+ *                     bytes in little-endian format. For disk encryption, this
+ *                     is typically the index of the block device sector that
+ *                     contains the data.
+ * \param input        The buffer holding the input data (which is an entire
+ *                     data unit). This function reads \p length bytes from \p
+ *                     input.
+ * \param output       The buffer holding the output data (which is an entire
+ *                     data unit). This function writes \p length bytes to \p
+ *                     output.
  *
- * \return         0 if successful, or MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH
+ * \return             \c 0 on success.
+ * \return             #MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH if \p length is
+ *                     smaller than an AES block in size (16 bytes).
  */
-int mbedtls_aes_crypt_xts( mbedtls_aes_context *crypt_ctx,
-                    mbedtls_aes_context *tweak_ctx,
-                    int mode,
-                    size_t bits_length,
-                    unsigned char iv[16],
-                    const unsigned char *input,
-                    unsigned char *output );
+int mbedtls_aes_crypt_xts( mbedtls_aes_xts_context *ctx,
+                           int mode,
+                           size_t bits_length,
+                           const unsigned char iv[16],
+                           const unsigned char *input,
+                           unsigned char *output );
 #endif /* MBEDTLS_CIPHER_MODE_XTS */
 
 #if defined(MBEDTLS_CIPHER_MODE_CFB)
diff --git a/library/aes.c b/library/aes.c
index 9e7b248..ed260a9 100644
--- a/library/aes.c
+++ b/library/aes.c
@@ -521,6 +521,20 @@
     mbedtls_platform_zeroize( ctx, sizeof( mbedtls_aes_context ) );
 }
 
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+void mbedtls_aes_xts_init( mbedtls_aes_xts_context *ctx )
+{
+    mbedtls_aes_init( &ctx->crypt );
+    mbedtls_aes_init( &ctx->tweak );
+}
+
+void mbedtls_aes_xts_free( mbedtls_aes_xts_context *ctx )
+{
+    mbedtls_aes_free( &ctx->crypt );
+    mbedtls_aes_free( &ctx->tweak );
+}
+#endif /* MBEDTLS_CIPHER_MODE_XTS */
+
 /*
  * AES key schedule (encryption)
  */
@@ -702,6 +716,78 @@
 
     return( ret );
 }
+
+#if defined(MBEDTLS_CIPHER_MODE_XTS)
+static int mbedtls_aes_xts_decode_keys( const unsigned char *key,
+                                        unsigned int keybits,
+                                        const unsigned char **key1,
+                                        unsigned int *key1bits,
+                                        const unsigned char **key2,
+                                        unsigned int *key2bits )
+{
+    const unsigned int half_keybits = keybits / 2;
+    const unsigned int half_keybytes = half_keybits / 8;
+
+    switch( keybits )
+    {
+        case 256: break;
+        case 512: break;
+        default : return( MBEDTLS_ERR_AES_INVALID_KEY_LENGTH );
+    }
+
+    *key1bits = half_keybits;
+    *key2bits = half_keybits;
+    *key1 = &key[0];
+    *key2 = &key[half_keybytes];
+
+    return 0;
+}
+
+int mbedtls_aes_xts_setkey_enc( mbedtls_aes_xts_context *ctx,
+                                const unsigned char *key,
+                                unsigned int keybits)
+{
+    int ret;
+    const unsigned char *key1, *key2;
+    unsigned int key1bits, key2bits;
+
+    ret = mbedtls_aes_xts_decode_keys( key, keybits, &key1, &key1bits,
+                                       &key2, &key2bits );
+    if( ret != 0 )
+        return( ret );
+
+    /* Set the tweak key. Always set tweak key for the encryption mode. */
+    ret = mbedtls_aes_setkey_enc( &ctx->tweak, key2, key2bits );
+    if( ret != 0 )
+        return( ret );
+
+    /* Set crypt key for encryption. */
+    return mbedtls_aes_setkey_enc( &ctx->crypt, key1, key1bits );
+}
+
+int mbedtls_aes_xts_setkey_dec( mbedtls_aes_xts_context *ctx,
+                                const unsigned char *key,
+                                unsigned int keybits)
+{
+    int ret;
+    const unsigned char *key1, *key2;
+    unsigned int key1bits, key2bits;
+
+    ret = mbedtls_aes_xts_decode_keys( key, keybits, &key1, &key1bits,
+                                       &key2, &key2bits );
+    if( ret != 0 )
+        return( ret );
+
+    /* Set the tweak key. Always set tweak key for encryption. */
+    ret = mbedtls_aes_setkey_enc( &ctx->tweak, key2, key2bits );
+    if( ret != 0 )
+        return( ret );
+
+    /* Set crypt key for decryption. */
+    return mbedtls_aes_setkey_dec( &ctx->crypt, key1, key1bits );
+}
+#endif /* MBEDTLS_CIPHER_MODE_XTS */
+
 #endif /* !MBEDTLS_AES_SETKEY_DEC_ALT */
 
 #define AES_FROUND(X0,X1,X2,X3,Y0,Y1,Y2,Y3)         \
@@ -1042,13 +1128,12 @@
 /*
  * AES-XTS buffer encryption/decryption
  */
-int mbedtls_aes_crypt_xts( mbedtls_aes_context *crypt_ctx,
-                    mbedtls_aes_context *tweak_ctx,
-                    int mode,
-                    size_t bits_length,
-                    unsigned char iv[16],
-                    const unsigned char *input,
-                    unsigned char *output )
+int mbedtls_aes_crypt_xts( mbedtls_aes_xts_context *ctx,
+                           int mode,
+                           size_t bits_length,
+                           const unsigned char iv[16],
+                           const unsigned char *input,
+                           unsigned char *output )
 {
     union xts_buf128 {
         uint8_t  u8[16];
@@ -1075,7 +1160,7 @@
         return( MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH );
 
 
-    mbedtls_aes_crypt_ecb( tweak_ctx, MBEDTLS_AES_ENCRYPT, iv, t_buf.u8 );
+    mbedtls_aes_crypt_ecb( &ctx->tweak, MBEDTLS_AES_ENCRYPT, iv, t_buf.u8 );
 
     if( mode == MBEDTLS_AES_DECRYPT && remn )
     {
@@ -1096,7 +1181,7 @@
         scratch.u64[1] = (uint64_t)( inbuf->u64[1] ^ t_buf.u64[1] );
 
         /* CC <- E(Key2,PP) */
-        mbedtls_aes_crypt_ecb( crypt_ctx, mode, scratch.u8, outbuf->u8 );
+        mbedtls_aes_crypt_ecb( &ctx->crypt, mode, scratch.u8, outbuf->u8 );
 
         /* C <- T xor CC */
         outbuf->u64[0] = (uint64_t)( outbuf->u64[0] ^ t_buf.u64[0] );
@@ -1127,7 +1212,7 @@
             scratch.u64[1] = (uint64_t)( cts_scratch.u64[1] ^ t_buf.u64[1] );
 
             /* CC <- E(Key2,PP) */
-            mbedtls_aes_crypt_ecb( crypt_ctx, mode, scratch.u8, scratch.u8 );
+            mbedtls_aes_crypt_ecb( &ctx->crypt, mode, scratch.u8, scratch.u8 );
 
             /* C <- T xor CC */
             outbuf[nblk - 1].u64[0] = (uint64_t)( scratch.u64[0] ^ t_buf.u64[0] );
@@ -1148,7 +1233,7 @@
             scratch.u64[1] = (uint64_t)( inbuf[nblk - 1].u64[1] ^ t_buf.u64[1] );
 
             /* CC <- E(Key2,PP) */
-            mbedtls_aes_crypt_ecb( crypt_ctx, mode, scratch.u8, scratch.u8 );
+            mbedtls_aes_crypt_ecb( &ctx->crypt, mode, scratch.u8, scratch.u8 );
 
             /* C <- T xor CC */
             cts_scratch.u64[0] = (uint64_t)( scratch.u64[0] ^ t_buf.u64[0] );
@@ -1165,7 +1250,7 @@
             scratch.u64[1] = (uint64_t)( inbuf[nblk - 1].u64[1] ^ cts_t_buf.u64[1] );
 
             /* CC <- E(Key2,PP) */
-            mbedtls_aes_crypt_ecb( crypt_ctx, mode, scratch.u8, scratch.u8 );
+            mbedtls_aes_crypt_ecb( &ctx->crypt, mode, scratch.u8, scratch.u8 );
 
             /* C <- T xor CC */
             outbuf[nblk - 1].u64[0] = (uint64_t)( scratch.u64[0] ^ cts_t_buf.u64[0] );
diff --git a/programs/test/benchmark.c b/programs/test/benchmark.c
index 47d36ff..ef83dc1 100644
--- a/programs/test/benchmark.c
+++ b/programs/test/benchmark.c
@@ -432,23 +432,23 @@
     if( todo.aes_xts )
     {
         int keysize;
-        mbedtls_aes_context crypt_ctx, tweak_ctx;
-        mbedtls_aes_init( &crypt_ctx );
-        mbedtls_aes_init( &tweak_ctx );
-        for( keysize = 128; keysize <= 256; keysize += 64 )
+        mbedtls_aes_xts_context ctx;
+
+        mbedtls_aes_xts_init( &ctx );
+        for( keysize = 128; keysize <= 256; keysize += 128 )
         {
             mbedtls_snprintf( title, sizeof( title ), "AES-XTS-%d", keysize );
 
             memset( buf, 0, sizeof( buf ) );
             memset( tmp, 0, sizeof( tmp ) );
-            mbedtls_aes_setkey_enc( &crypt_ctx, tmp, keysize );
-            mbedtls_aes_setkey_enc( &tweak_ctx, tmp, keysize );
+            mbedtls_aes_xts_setkey_enc( &ctx, tmp, keysize * 2 );
 
             TIME_AND_TSC( title,
-                mbedtls_aes_crypt_xts( &crypt_ctx, &tweak_ctx, MBEDTLS_AES_ENCRYPT, BUFSIZE * 8, tmp, buf, buf ) );
+                    mbedtls_aes_crypt_xts( &ctx, MBEDTLS_AES_ENCRYPT, BUFSIZE,
+                                           tmp, buf, buf ) );
+
+            mbedtls_aes_xts_free( &ctx );
         }
-        mbedtls_aes_free( &crypt_ctx );
-        mbedtls_aes_free( &tweak_ctx );
     }
 #endif
 #if defined(MBEDTLS_GCM_C)
diff --git a/tests/suites/test_suite_aes.function b/tests/suites/test_suite_aes.function
index 91f5fa2..e998795 100644
--- a/tests/suites/test_suite_aes.function
+++ b/tests/suites/test_suite_aes.function
@@ -161,20 +161,18 @@
     unsigned char src_str[100] = { 0, };
     unsigned char dst_str[100] = { 0, };
     unsigned char output[100]  = { 0, };
-    mbedtls_aes_context crypt_ctx, tweak_ctx;
+    mbedtls_aes_xts_context ctx;
     int key_len, data_len;
 
-    mbedtls_aes_init( &crypt_ctx );
-    mbedtls_aes_init( &tweak_ctx );
+    mbedtls_aes_xts_init( &ctx );
 
     key_len = unhexify( key_str, hex_key_string );
     unhexify( iv_str, hex_iv_string );
     data_len = unhexify( src_str, hex_src_string );
 
-    mbedtls_aes_setkey_enc( &crypt_ctx, key_str,               ( key_len * 8 ) / 2 );
-    mbedtls_aes_setkey_enc( &tweak_ctx, key_str + key_len / 2, ( key_len * 8 ) / 2 );
+    mbedtls_aes_xts_setkey_enc( &ctx, key_str, key_len * 8 );
 
-    TEST_ASSERT( mbedtls_aes_crypt_xts( &crypt_ctx, &tweak_ctx, MBEDTLS_AES_ENCRYPT, data_unit_len, iv_str, src_str, output ) == xts_result );
+    TEST_ASSERT( mbedtls_aes_crypt_xts( &ctx, MBEDTLS_AES_ENCRYPT, data_unit_len, iv_str, src_str, output ) == xts_result );
     if( xts_result == 0 )
     {
         hexify( dst_str, output, data_len );
@@ -183,8 +181,7 @@
     }
 
 exit:
-    mbedtls_aes_free( &crypt_ctx );
-    mbedtls_aes_free( &tweak_ctx );
+    mbedtls_aes_xts_free( &ctx );
 }
 /* END_CASE */
 
@@ -198,20 +195,18 @@
     unsigned char src_str[100] = { 0, };
     unsigned char dst_str[100] = { 0, };
     unsigned char output[100]  = { 0, };
-    mbedtls_aes_context crypt_ctx, tweak_ctx;
+    mbedtls_aes_xts_context ctx;
     int key_len, data_len;
 
-    mbedtls_aes_init( &crypt_ctx );
-    mbedtls_aes_init( &tweak_ctx );
+    mbedtls_aes_xts_init( &ctx );
 
     key_len = unhexify( key_str, hex_key_string );
     unhexify( iv_str, hex_iv_string );
     data_len = unhexify( src_str, hex_src_string );
 
-    mbedtls_aes_setkey_dec( &crypt_ctx, key_str,               ( key_len * 8 ) / 2 );
-    mbedtls_aes_setkey_enc( &tweak_ctx, key_str + key_len / 2, ( key_len * 8 ) / 2 );
+    mbedtls_aes_xts_setkey_dec( &ctx, key_str, key_len * 8 );
 
-	TEST_ASSERT( mbedtls_aes_crypt_xts( &crypt_ctx, &tweak_ctx, MBEDTLS_AES_DECRYPT, data_unit_len, iv_str, src_str, output ) == xts_result );
+    TEST_ASSERT( mbedtls_aes_crypt_xts( &ctx, MBEDTLS_AES_DECRYPT, data_unit_len, iv_str, src_str, output ) == xts_result );
     if( xts_result == 0 )
     {
         hexify( dst_str, output, data_len );
@@ -220,8 +215,7 @@
     }
 
 exit:
-    mbedtls_aes_free( &crypt_ctx );
-    mbedtls_aes_free( &tweak_ctx );
+    mbedtls_aes_xts_free( &ctx );
 }
 /* END_CASE */