Generate all TLS 1.3 HRR test cases for compatibility mode

Signed-off-by: XiaokangQian <xiaokang.qian@arm.com>
diff --git a/tests/scripts/generate_tls13_compat_tests.py b/tests/scripts/generate_tls13_compat_tests.py
index 5c429db..451187d 100755
--- a/tests/scripts/generate_tls13_compat_tests.py
+++ b/tests/scripts/generate_tls13_compat_tests.py
@@ -70,6 +70,16 @@
     'x448': 0x1e,
 }
 
+SERVER_NAMED_GROUP_IANA_VALUE = {  # choices for the server_named_group CLI argument
+    'secp256r1': 0x17,  # NOTE(review): keys used as CLI choices; keep values in sync
+    'secp384r1': 0x18,
+    'secp521r1': 0x19,
+    'x448': 0x1e,
+}
+
+CLIENT_NAMED_GROUP_IANA_VALUE = {  # choices for the client_named_group CLI argument
+    'x25519': 0x1d,
+}
 
 class TLSProgram(metaclass=abc.ABCMeta):
     """
@@ -77,10 +87,11 @@
     """
     # pylint: disable=too-many-arguments
     def __init__(self, ciphersuite=None, signature_algorithm=None, named_group=None,
-                 cert_sig_alg=None, compat_mode=True):
+                 is_hrr=False, cert_sig_alg=None, compat_mode=True):
         self._ciphers = []
         self._sig_algs = []
         self._named_groups = []
+        self._is_hrr = is_hrr  # HRR mode; NOTE(review): precedes cert_sig_alg positionally
         self._cert_sig_algs = []
         if ciphersuite:
             self.add_ciphersuites(ciphersuite)
@@ -306,6 +317,12 @@
 
         if self._named_groups:
             named_groups = ','.join(self._named_groups)
+            if self._is_hrr:
+                # HRR tests: offer the configured group(s) first, then every other known
+                # group exactly once, so a server using another group can pick it via HRR.
+                other_groups = [group for group in NAMED_GROUP_IANA_VALUE
+                                if group not in self._named_groups]
+                named_groups = ','.join(self._named_groups + other_groups)
             ret += ["curves={named_groups}".format(named_groups=named_groups)]
 
         ret = ' '.join(ret)
@@ -344,6 +361,16 @@
         check_strings.append("Verifying peer X.509 certificate... ok")
         return ['-c "{}"'.format(i) for i in check_strings]
 
+    def post_hrr_checks(self):
+        """Return the `-c` run_test checks expected after a HelloRetryRequest handshake."""
+        return ['-c "{}"'.format(expected) for expected in (
+            "server hello, chosen ciphersuite: ( {:04x} ) - {}".format(
+                CIPHER_SUITE_IANA_VALUE[self._ciphers[0]], self.CIPHER_SUITE[self._ciphers[0]]),
+            "Certificate Verify: Signature algorithm ( {:04x} )".format(
+                SIG_ALG_IANA_VALUE[self._sig_algs[0]]),
+            "<= ssl_tls13_process_server_hello ( HelloRetryRequest )",
+            "Verifying peer X.509 certificate... ok")]
+
 
 SERVER_CLASSES = {'OpenSSL': OpenSSLServ, 'GnuTLS': GnuTLSServ}
 CLIENT_CLASSES = {'mbedTLS': MbedTLSCli}
@@ -374,6 +401,23 @@
     cmd = prefix.join(cmd)
     return '\n'.join(server_object.pre_checks() + client_object.pre_checks() + [cmd])
 
+def generate_compat_hrr_test(server=None, client=None, cipher=None, sig_alg=None,
+                             client_named_group=None, server_named_group=None):
+    """Generate a HelloRetryRequest test case with `ssl-opt.sh` format."""
+    # Include both groups so names stay unique across client/server group pairs.
+    name = 'TLS 1.3 {c[0]}->{s[0]}: {cipher},{c_group}->{s_group},{sig_alg}, force hrr'.format(
+        c=client, s=server, cipher=cipher, sig_alg=sig_alg,
+        c_group=client_named_group, s_group=server_named_group)
+    server_object = SERVER_CLASSES[server](cipher, sig_alg, server_named_group)
+    client_object = CLIENT_CLASSES[client](cipher, sig_alg, client_named_group, is_hrr=True)
+    cmd = ['run_test "{}"'.format(name), '"{}"'.format(
+        server_object.cmd()), '"{}"'.format(client_object.cmd()), '0']
+    cmd += server_object.post_checks()
+    cmd += client_object.post_hrr_checks()
+    prefix = ' \\\n' + (' '*9)
+    cmd = prefix.join(cmd)
+    return '\n'.join(server_object.pre_checks() + client_object.pre_checks() + [cmd])
+
 
 SSL_OUTPUT_HEADER = '''#!/bin/sh
 
@@ -417,6 +461,9 @@
     parser.add_argument('-a', '--generate-all-tls13-compat-tests', action='store_true',
                         default=False, help='Generate all available tls13 compat tests')
 
+    parser.add_argument('-r', '--generate-hrr-tls13-compat-tests', action='store_true',
+                        default=False, help='Generate all available tls13 HRR compat tests')
+
     parser.add_argument('--list-ciphers', action='store_true',
                         default=False, help='List supported ciphersuites')
 
@@ -448,6 +495,14 @@
                         default=list(NAMED_GROUP_IANA_VALUE.keys())[0],
                         help='Choose cipher suite for test')
 
+    parser.add_argument('client_named_group', choices=CLIENT_NAMED_GROUP_IANA_VALUE.keys(),
+                        nargs='?', default=list(CLIENT_NAMED_GROUP_IANA_VALUE.keys())[0],
+                        help='Choose client named group for HRR test')
+
+    parser.add_argument('server_named_group', choices=SERVER_NAMED_GROUP_IANA_VALUE.keys(),
+                        nargs='?', default=list(SERVER_NAMED_GROUP_IANA_VALUE.keys())[0],
+                        help='Choose server named group for HRR test')
+
     args = parser.parse_args()
 
     def get_all_test_cases():
@@ -461,6 +516,16 @@
             yield generate_compat_test(cipher=cipher, sig_alg=sig_alg, named_group=named_group,
                                        server=server, client=client)
 
+    def get_hrr_test_cases():
+        # Ordered pairs of distinct named groups: the group the client offers first
+        # never equals the group the server is configured with.
+        group_pairs = itertools.permutations(NAMED_GROUP_IANA_VALUE.keys(), 2)
+        for cipher, sig_alg, (client_group, server_group), server, client in \
+                itertools.product(CIPHER_SUITE_IANA_VALUE.keys(), SIG_ALG_IANA_VALUE.keys(),
+                                  group_pairs, SERVER_CLASSES.keys(), CLIENT_CLASSES.keys()):
+            yield generate_compat_hrr_test(cipher=cipher, sig_alg=sig_alg, server=server,
+                                           client=client, client_named_group=client_group,
+                                           server_named_group=server_group)
 
     if args.generate_all_tls13_compat_tests:
         if args.output:
@@ -473,6 +538,17 @@
             print('\n\n'.join(get_all_test_cases()))
         return 0
 
+    if args.generate_hrr_tls13_compat_tests:
+        # Separate test cases with a blank line, as the all-tests mode does above.
+        if args.output:
+            with open(args.output, 'w', encoding="utf-8") as f:
+                f.write(SSL_OUTPUT_HEADER.format(filename=os.path.basename(args.output)))
+                f.write('\n\n'.join(get_hrr_test_cases()))
+                f.write('\n')
+        else:
+            print('\n\n'.join(get_hrr_test_cases()))
+        return 0
+
     if args.list_ciphers or args.list_sig_algs or args.list_named_groups \
             or args.list_servers or args.list_clients:
         if args.list_ciphers:
@@ -487,8 +563,15 @@
             print(*CLIENT_CLASSES.keys())
         return 0
 
-    print(generate_compat_test(server=args.server, client=args.client, sig_alg=args.sig_alg,
-                               cipher=args.cipher, named_group=args.named_group))
+    # Default mode: print one regular (non-HRR) test case from the positional
+    # arguments. --generate-all-tls13-compat-tests and
+    # --generate-hrr-tls13-compat-tests both return earlier, so this call must
+    # not be guarded by either flag or the default mode would print nothing.
+    # NOTE(review): client_named_group/server_named_group are not used by any
+    # mode yet; emitting a single HRR test needs a dedicated option.
+    print(generate_compat_test(server=args.server, client=args.client,
+                               sig_alg=args.sig_alg, cipher=args.cipher,
+                               named_group=args.named_group))
     return 0