Split tag handling out of cipher_finish()
diff --git a/include/polarssl/cipher.h b/include/polarssl/cipher.h
index 7dea1e2..dc5a41c 100644
--- a/include/polarssl/cipher.h
+++ b/include/polarssl/cipher.h
@@ -519,11 +519,6 @@
* \param ctx Generic cipher context
* \param output buffer to write data to. Needs block_size available.
* \param olen length of the data written to the output buffer.
- * \param tag Ignore by non-AEAD ciphers. For AEAD ciphers:
- * - on encryption: buffer to write the tag;
- * - on decryption: tag to verify.
- * May be NULL if tag_len is zero.
- * \param tag_len Length of the tag to write/check for AEAD ciphers.
*
* \returns 0 on success, POLARSSL_ERR_CIPHER_BAD_INPUT_DATA if
* parameter verification fails,
@@ -533,8 +528,34 @@
* while decrypting or a cipher specific error code.
*/
int cipher_finish( cipher_context_t *ctx,
- unsigned char *output, size_t *olen,
- unsigned char *tag, size_t tag_len );
+ unsigned char *output, size_t *olen );
+
+/**
+ * \brief Write tag for AEAD ciphers.
+ * No effect for other ciphers.
+ * Must be called after cipher_finish().
+ *
+ * \param tag buffer to write the tag
+ * \param tag_len Length of the tag to write
+ *
+ * \return 0 on success, or a specific error code.
+ */
+int cipher_write_tag( cipher_context_t *ctx,
+ unsigned char *tag, size_t tag_len );
+
+/**
+ * \brief Check tag for AEAD ciphers.
+ * No effect for other ciphers.
+ * Calling time depends on the cipher:
+ * for GCM, must be called after cipher_finish().
+ *
+ * \param tag Buffer holding the tag
+ * \param tag_len Length of the tag to check
+ *
+ * \return 0 on success, or a specific error code.
+ */
+int cipher_check_tag( cipher_context_t *ctx,
+ const unsigned char *tag, size_t tag_len );
/**
* \brief Checkup routine
diff --git a/library/cipher.c b/library/cipher.c
index f8e2841..a90e2dc 100644
--- a/library/cipher.c
+++ b/library/cipher.c
@@ -777,8 +777,7 @@
}
int cipher_finish( cipher_context_t *ctx,
- unsigned char *output, size_t *olen,
- unsigned char *tag, size_t tag_len )
+ unsigned char *output, size_t *olen )
{
int ret = 0;
@@ -797,10 +796,6 @@
#if defined(POLARSSL_GCM_C)
if( POLARSSL_MODE_GCM == ctx->cipher_info->mode )
{
- unsigned char check_tag[16];
- size_t i;
- int diff;
-
if( 0 != ( ret = gcm_update( ctx->cipher_ctx,
ctx->unprocessed_len, ctx->unprocessed_data,
output ) ) )
@@ -810,29 +805,8 @@
*olen += ctx->unprocessed_len;
- if( 0 != ( ret = gcm_finish( ctx->cipher_ctx, check_tag, tag_len ) ) )
- return( ret );
-
- /* On encryption, write the tag */
- if( POLARSSL_ENCRYPT == ctx->operation )
- {
- if( tag_len != 0 )
- memcpy( tag, check_tag, tag_len );
- return( 0 );
- }
-
- /* On decryption, check the tag (in "constant-time") */
- for( diff = 0, i = 0; i < tag_len; i++ )
- diff |= tag[i] ^ check_tag[i];
-
- if( diff != 0 )
- return( POLARSSL_ERR_GCM_AUTH_FAILED );
-
return( 0 );
}
-#else
- ((void) tag);
- ((void) tag_len);
#endif
if( POLARSSL_MODE_CBC == ctx->cipher_info->mode )
@@ -930,6 +904,51 @@
return 0;
}
+int cipher_write_tag( cipher_context_t *ctx,
+ unsigned char *tag, size_t tag_len )
+{
+ if( NULL == ctx || NULL == ctx->cipher_info || NULL == tag )
+ return POLARSSL_ERR_CIPHER_BAD_INPUT_DATA;
+
+ if( POLARSSL_MODE_GCM != ctx->cipher_info->mode )
+ return( 0 );
+
+ if( POLARSSL_ENCRYPT != ctx->operation )
+ return POLARSSL_ERR_CIPHER_BAD_INPUT_DATA;
+
+ return gcm_finish( ctx->cipher_ctx, tag, tag_len );
+}
+
+int cipher_check_tag( cipher_context_t *ctx,
+ const unsigned char *tag, size_t tag_len )
+{
+ int ret;
+ unsigned char check_tag[16];
+ size_t i;
+ int diff;
+
+ if( NULL == ctx || NULL == ctx->cipher_info )
+ return POLARSSL_ERR_CIPHER_BAD_INPUT_DATA;
+
+ if( POLARSSL_MODE_GCM != ctx->cipher_info->mode )
+ return( 0 );
+
+ if( POLARSSL_DECRYPT != ctx->operation || tag_len > sizeof( check_tag ) )
+ return POLARSSL_ERR_CIPHER_BAD_INPUT_DATA;
+
+ if( 0 != ( ret = gcm_finish( ctx->cipher_ctx, check_tag, tag_len ) ) )
+ return( ret );
+
+ /* On decryption, check the tag (in "constant-time") */
+ for( diff = 0, i = 0; i < tag_len; i++ )
+ diff |= tag[i] ^ check_tag[i];
+
+ if( diff != 0 )
+ return( POLARSSL_ERR_GCM_AUTH_FAILED );
+
+ return( 0 );
+}
+
#if defined(POLARSSL_SELF_TEST)
#include <stdio.h>
diff --git a/library/pkcs12.c b/library/pkcs12.c
index 98ebd88..335af7e 100644
--- a/library/pkcs12.c
+++ b/library/pkcs12.c
@@ -196,11 +196,8 @@
goto exit;
}
- if( ( ret = cipher_finish( &cipher_ctx, output + olen, &olen, NULL, 0 ) )
- != 0 )
- {
+ if( ( ret = cipher_finish( &cipher_ctx, output + olen, &olen ) ) != 0 )
ret = POLARSSL_ERR_PKCS12_PASSWORD_MISMATCH;
- }
exit:
cipher_free_ctx( &cipher_ctx );
diff --git a/library/pkcs5.c b/library/pkcs5.c
index a27d4fb..0b9830d 100644
--- a/library/pkcs5.c
+++ b/library/pkcs5.c
@@ -199,11 +199,8 @@
goto exit;
}
- if( ( ret = cipher_finish( &cipher_ctx, output + olen, &olen, NULL, 0 ) )
- != 0 )
- {
+ if( ( ret = cipher_finish( &cipher_ctx, output + olen, &olen ) ) != 0 )
ret = POLARSSL_ERR_PKCS5_PASSWORD_MISMATCH;
- }
exit:
md_free_ctx( &md_ctx );
diff --git a/programs/aes/crypt_and_hash.c b/programs/aes/crypt_and_hash.c
index c46713d..6caaad8 100644
--- a/programs/aes/crypt_and_hash.c
+++ b/programs/aes/crypt_and_hash.c
@@ -343,7 +343,7 @@
}
}
- if( cipher_finish( &cipher_ctx, output, &olen, NULL, 0 ) != 0 )
+ if( cipher_finish( &cipher_ctx, output, &olen ) != 0 )
{
fprintf( stderr, "cipher_finish() returned error\n" );
goto exit;
@@ -461,7 +461,7 @@
/*
* Write the final block of data
*/
- cipher_finish( &cipher_ctx, output, &olen, NULL, 0 );
+ cipher_finish( &cipher_ctx, output, &olen );
if( fwrite( output, 1, olen, fout ) != olen )
{
diff --git a/tests/suites/test_suite_cipher.function b/tests/suites/test_suite_cipher.function
index aa82daa..5d32bc3 100644
--- a/tests/suites/test_suite_cipher.function
+++ b/tests/suites/test_suite_cipher.function
@@ -76,10 +76,11 @@
total_len < length &&
total_len + cipher_get_block_size( &ctx_enc ) > length ) );
- TEST_ASSERT( 0 == cipher_finish( &ctx_enc, encbuf + outlen, &outlen,
- tag, 16 ) );
+ TEST_ASSERT( 0 == cipher_finish( &ctx_enc, encbuf + outlen, &outlen ) );
total_len += outlen;
+ TEST_ASSERT( 0 == cipher_write_tag( &ctx_enc, tag, 16 ) );
+
TEST_ASSERT( total_len == length ||
( total_len % cipher_get_block_size( &ctx_enc ) == 0 &&
total_len > length &&
@@ -94,10 +95,11 @@
total_len < length &&
total_len + cipher_get_block_size( &ctx_dec ) >= length ) );
- TEST_ASSERT( 0 == cipher_finish( &ctx_dec, decbuf + outlen, &outlen,
- tag, 16 ) );
+ TEST_ASSERT( 0 == cipher_finish( &ctx_dec, decbuf + outlen, &outlen ) );
total_len += outlen;
+ TEST_ASSERT( 0 == cipher_check_tag( &ctx_dec, tag, 16 ) );
+
TEST_ASSERT( total_len == length );
TEST_ASSERT( 0 == memcmp(inbuf, decbuf, length) );
@@ -145,7 +147,7 @@
/* encode length number of bytes from inbuf */
TEST_ASSERT( 0 == cipher_update( &ctx, inbuf, length, encbuf, &outlen ) );
- TEST_ASSERT( ret == cipher_finish( &ctx, encbuf + outlen, &outlen, NULL, 0 ) );
+ TEST_ASSERT( ret == cipher_finish( &ctx, encbuf + outlen, &outlen ) );
/* done */
TEST_ASSERT( 0 == cipher_free_ctx( &ctx ) );
@@ -192,7 +194,7 @@
TEST_ASSERT( 0 == cipher_update( &ctx_dec, encbuf, 0, decbuf, &outlen ) );
TEST_ASSERT( 0 == outlen );
TEST_ASSERT( POLARSSL_ERR_CIPHER_FULL_BLOCK_EXPECTED == cipher_finish(
- &ctx_dec, decbuf + outlen, &outlen, NULL, 0 ) );
+ &ctx_dec, decbuf + outlen, &outlen ) );
TEST_ASSERT( 0 == outlen );
TEST_ASSERT( 0 == cipher_free_ctx( &ctx_dec ) );
@@ -259,8 +261,7 @@
totaloutlen < length &&
totaloutlen + cipher_get_block_size( &ctx_enc ) > length ) );
- TEST_ASSERT( 0 == cipher_finish( &ctx_enc, encbuf + totaloutlen, &outlen,
- NULL, 0 ) );
+ TEST_ASSERT( 0 == cipher_finish( &ctx_enc, encbuf + totaloutlen, &outlen ) );
totaloutlen += outlen;
TEST_ASSERT( totaloutlen == length ||
( totaloutlen % cipher_get_block_size( &ctx_enc ) == 0 &&
@@ -276,8 +277,7 @@
totaloutlen < length &&
totaloutlen + cipher_get_block_size( &ctx_dec ) >= length ) );
- TEST_ASSERT( 0 == cipher_finish( &ctx_dec, decbuf + outlen, &outlen,
- NULL, 0 ) );
+ TEST_ASSERT( 0 == cipher_finish( &ctx_dec, decbuf + outlen, &outlen ) );
totaloutlen += outlen;
TEST_ASSERT( totaloutlen == length );