Merge pull request #8130 from davidhorstmann-arm/fix-unnecessary-include-prefixes

Fix unnecessary header prefixes in tests
diff --git a/.travis.yml b/.travis.yml
index 719654c..f411ec3 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -171,7 +171,7 @@
 env:
   global:
     - SEED=1
-    - secure: "JECCru6HASpKZ0OLfHh8f/KXhKkdrCwjquZghd/qbA4ksxsWImjR7KEPERcaPndXEilzhDbKwuFvJiQX2duVgTGoq745YGhLZIjzo1i8tySkceCVd48P8WceYGz+F/bmY7r+m6fFNuxDSoGGSVeA4Lnjvmm8PFUP45YodDV9no4="
+    - secure: "GF/Fde5fkm15T/RNykrjrPV5Uh1KJ70cP308igL6Xkk3eJmqkkmWCe9JqRH12J3TeWw2fu9PYPHt6iFSg6jasgqysfUyg+W03knRT5QNn3h5eHgt36cQJiJr6t3whPrRaiM6U9omE0evm+c0cAwlkA3GGSMw8Z+na4EnKI6OFCo="
 
 install:
   - $PYTHON scripts/min_requirements.py
diff --git a/library/sha3.c b/library/sha3.c
index 4b97a85..dca5790 100644
--- a/library/sha3.c
+++ b/library/sha3.c
@@ -259,10 +259,13 @@
 int mbedtls_sha3_finish(mbedtls_sha3_context *ctx,
                         uint8_t *output, size_t olen)
 {
+    int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+
     /* Catch SHA-3 families, with fixed output length */
     if (ctx->olen > 0) {
         if (ctx->olen > olen) {
-            return MBEDTLS_ERR_SHA3_BAD_INPUT_DATA;
+            ret = MBEDTLS_ERR_SHA3_BAD_INPUT_DATA;
+            goto exit;
         }
         olen = ctx->olen;
     }
@@ -280,7 +283,11 @@
         }
     }
 
-    return 0;
+    ret = 0;
+
+exit:
+    mbedtls_sha3_free(ctx);
+    return ret;
 }
 
 /*
diff --git a/tests/compat.sh b/tests/compat.sh
index 2e03e44..b070e71 100755
--- a/tests/compat.sh
+++ b/tests/compat.sh
@@ -126,10 +126,41 @@
     printf "            \tAlso available: GnuTLS (needs v3.2.15 or higher)\n"
     printf "  -M|--memcheck\tCheck memory leaks and errors.\n"
     printf "  -v|--verbose\tSet verbose output.\n"
+    printf "     --list-test-case\tList all potential test cases (No Execution)\n"
     printf "     --outcome-file\tFile where test outcomes are written\n"
     printf "                   \t(default: \$MBEDTLS_TEST_OUTCOME_FILE, none if empty)\n"
 }
 
+# print_test_case <CLIENT> <SERVER> <STANDARD_CIPHER_SUITES>: print one test case title per suite in $3
+print_test_case() {
+    for i in $3; do  # $3 is unquoted: the list arrives as one argument and is word-split here
+        uniform_title $1 $2 $i  # sets $TITLE
+        echo $TITLE
+    done
+}
+
+# list_test_case: print the title of every potential compat.sh test case without running any test
+list_test_case() {
+    reset_ciphersuites
+    for TYPE in $TYPES; do
+        add_common_ciphersuites
+        add_openssl_ciphersuites
+        add_gnutls_ciphersuites
+        add_mbedtls_ciphersuites
+    done
+
+    for VERIFY in $VERIFIES; do
+        VERIF=$(echo $VERIFY | tr '[:upper:]' '[:lower:]')  # lower-case $VERIFY; read by uniform_title
+        for MODE in $MODES; do  # $MODE is read by uniform_title
+            print_test_case m O "$O_CIPHERS"  # m/O/G: presumably Mbed TLS/OpenSSL/GnuTLS peers (confirm)
+            print_test_case O m "$O_CIPHERS"
+            print_test_case m G "$G_CIPHERS"
+            print_test_case G m "$G_CIPHERS"
+            print_test_case m m "$M_CIPHERS"
+        done
+    done
+}
+
 get_options() {
     while [ $# -gt 0 ]; do
         case "$1" in
@@ -157,6 +188,12 @@
             -M|--memcheck)
                 MEMCHECK=1
                 ;;
+            # If you change the --list-test-case option, update
+            # scripts/check_test_cases.py accordingly.
+            --list-test-case)
+                list_test_case
+                exit $?
+                ;;
             --outcome-file)
                 shift; MBEDTLS_TEST_OUTCOME_FILE=$1
                 ;;
@@ -826,6 +863,14 @@
     echo "EXIT: $EXIT" >> $CLI_OUT
 }
 
+# uniform_title <CLIENT> <SERVER> <STANDARD_CIPHER_SUITE>
+# Set $TITLE, the test case description used both by --list-test-case and in
+# MBEDTLS_TEST_OUTCOME_FILE. Building it in one place keeps the format of
+# every test case description identical. Reads the globals $MODE and $VERIF.
+uniform_title() {
+    TITLE="$1->$2 $MODE,$VERIF $3"
+}
+
 # record_outcome <outcome> [<failure-reason>]
 record_outcome() {
     echo "$1"
@@ -863,8 +908,7 @@
 run_client() {
     # announce what we're going to do
     TESTS=$(( $TESTS + 1 ))
-    TITLE="${1%"${1#?}"}->${SERVER_NAME%"${SERVER_NAME#?}"}"
-    TITLE="$TITLE $MODE,$VERIF $2"
+    uniform_title "${1%"${1#?}"}" "${SERVER_NAME%"${SERVER_NAME#?}"}" $2
     DOTS72="........................................................................"
     printf "%s %.*s " "$TITLE" "$((71 - ${#TITLE}))" "$DOTS72"
 
diff --git a/tests/scripts/check_test_cases.py b/tests/scripts/check_test_cases.py
index d84ed04..1395d4d 100755
--- a/tests/scripts/check_test_cases.py
+++ b/tests/scripts/check_test_cases.py
@@ -25,6 +25,7 @@
 import glob
 import os
 import re
+import subprocess
 import sys
 
 class Results:
@@ -111,6 +112,19 @@
                 self.process_test_case(descriptions,
                                        file_name, line_number, description)
 
+    def walk_compat_sh(self, file_name):
+        """Iterate over the test cases that compat.sh lists with --list-test-case."""
+        descriptions = self.new_per_file_state() # pylint: disable=assignment-from-none
+        compat_cmd = ['sh', file_name, '--list-test-case']
+        compat_output = subprocess.check_output(compat_cmd)
+        # Rely on compat.sh to print test case descriptions in the same format
+        # for --list-test-case as it writes to its outcome file (OUTCOME.CSV).
+        description = compat_output.strip().split(b'\n')
+        # compat.sh has no line number for each test case, so idx, the index of
+        # the test case in the listing, is passed in place of a line number.
+        for idx, descrip in enumerate(description):
+            self.process_test_case(descriptions, file_name, idx, descrip)
+
     @staticmethod
     def collect_test_directories():
         """Get the relative path for the TLS and Crypto test directories."""
@@ -136,6 +150,9 @@
             for ssl_opt_file_name in glob.glob(os.path.join(directory, 'opt-testcases',
                                                             '*.sh')):
                 self.walk_ssl_opt_sh(ssl_opt_file_name)
+            compat_sh = os.path.join(directory, 'compat.sh')
+            if os.path.exists(compat_sh):
+                self.walk_compat_sh(compat_sh)
 
 class TestDescriptions(TestDescriptionExplorer):
     """Collect the available test cases."""
diff --git a/tests/suites/test_suite_shax.function b/tests/suites/test_suite_shax.function
index 7dd9166..629e281 100644
--- a/tests/suites/test_suite_shax.function
+++ b/tests/suites/test_suite_shax.function
@@ -176,9 +176,12 @@
     TEST_EQUAL(mbedtls_sha3_starts(&ctx, MBEDTLS_SHA3_NONE), MBEDTLS_ERR_SHA3_BAD_INPUT_DATA);
 
     TEST_EQUAL(mbedtls_sha3_starts(&ctx, MBEDTLS_SHA3_256), 0);
-
     TEST_EQUAL(mbedtls_sha3_finish(&ctx, output, 0), MBEDTLS_ERR_SHA3_BAD_INPUT_DATA);
+
+    TEST_EQUAL(mbedtls_sha3_starts(&ctx, MBEDTLS_SHA3_256), 0);
     TEST_EQUAL(mbedtls_sha3_finish(&ctx, output, 31), MBEDTLS_ERR_SHA3_BAD_INPUT_DATA);
+
+    TEST_EQUAL(mbedtls_sha3_starts(&ctx, MBEDTLS_SHA3_256), 0);
     TEST_EQUAL(mbedtls_sha3_finish(&ctx, output, 32), 0);
 
 exit: