Merge pull request #7009 from mprse/csr_write_san

Added ability to include the SubjectAltName extension to a CSR - v.2
diff --git a/ChangeLog.d/san_csr.txt b/ChangeLog.d/san_csr.txt
new file mode 100644
index 0000000..b5c6cf3
--- /dev/null
+++ b/ChangeLog.d/san_csr.txt
@@ -0,0 +1,2 @@
+Features
+   * Add support for including the SubjectAltName extension in a CSR.
diff --git a/include/mbedtls/asn1write.h b/include/mbedtls/asn1write.h
index acfc073..da73759 100644
--- a/include/mbedtls/asn1write.h
+++ b/include/mbedtls/asn1write.h
@@ -35,6 +35,15 @@
         (g) += ret;                                 \
     } while (0)
 
+#define MBEDTLS_ASN1_CHK_CLEANUP_ADD(g, f)                      \
+    do /* relies on a `ret` variable and a `cleanup:` label in the caller */ \
+    {                                                   \
+        if ((ret = (f)) < 0) /* (f) failed: its error code stays in ret */ \
+        goto cleanup;                              \
+        else /* success: accumulate the non-negative result into (g) */ \
+        (g) += ret;                                 \
+    } while (0)
+
 #ifdef __cplusplus
 extern "C" {
 #endif
diff --git a/include/mbedtls/x509_csr.h b/include/mbedtls/x509_csr.h
index e376000..f3f9e13 100644
--- a/include/mbedtls/x509_csr.h
+++ b/include/mbedtls/x509_csr.h
@@ -83,6 +83,12 @@
 }
 mbedtls_x509write_csr;
 
+typedef struct mbedtls_x509_san_list {
+    mbedtls_x509_subject_alternative_name node; /**< This entry's type and value. */
+    struct mbedtls_x509_san_list *next;         /**< Next entry; NULL ends the list. */
+}
+mbedtls_x509_san_list;
+
 #if defined(MBEDTLS_X509_CSR_PARSE_C)
 /**
  * \brief          Load a Certificate Signing Request (CSR) in DER format
@@ -229,6 +235,20 @@
 int mbedtls_x509write_csr_set_key_usage(mbedtls_x509write_csr *ctx, unsigned char key_usage);
 
 /**
+ * \brief           Set the Subject Alternative Name extension of the CSR
+ *
+ * \param ctx       CSR context to use
+ * \param san_list  List of SAN values; entries of unsupported types are skipped
+ *
+ * \return          0 if successful, or a negative error code (for example
+ *                  MBEDTLS_ERR_ASN1_ALLOC_FAILED if allocation fails)
+ * \note            Only "dNSName", "uniformResourceIdentifier" and "iPAddress",
+ *                  as defined in RFC 5280, are supported.
+ */
+int mbedtls_x509write_csr_set_subject_alternative_name(mbedtls_x509write_csr *ctx,
+                                                       const mbedtls_x509_san_list *san_list);
+
+/**
  * \brief           Set the Netscape Cert Type flags
  *                  (e.g. MBEDTLS_X509_NS_CERT_TYPE_SSL_CLIENT | MBEDTLS_X509_NS_CERT_TYPE_EMAIL)
  *
diff --git a/library/x509write_csr.c b/library/x509write_csr.c
index d8d8e99..deb6617 100644
--- a/library/x509write_csr.c
+++ b/library/x509write_csr.c
@@ -26,6 +26,7 @@
 
 #if defined(MBEDTLS_X509_CSR_WRITE_C)
 
+#include "mbedtls/x509.h"
 #include "mbedtls/x509_csr.h"
 #include "mbedtls/asn1write.h"
 #include "mbedtls/error.h"
@@ -85,6 +86,105 @@
                                       critical, val, val_len);
 }
 
+/*
+ * Encode san_list as a DER SubjectAltName value (a SEQUENCE of GeneralName) and
+ * store it in ctx as a non-critical extension. Only dNSName, URI and iPAddress
+ * entries are written; other types are skipped. Returns 0 or a negative error code.
+ */
+int mbedtls_x509write_csr_set_subject_alternative_name(mbedtls_x509write_csr *ctx,
+                                                       const mbedtls_x509_san_list *san_list)
+{
+    int ret = 0;
+    const mbedtls_x509_san_list *cur;
+    unsigned char *buf;
+    unsigned char *p;
+    size_t len;
+    size_t buflen = 0;
+
+    /* Determine the maximum size of the SubjectAltName list */
+    for (cur = san_list; cur != NULL; cur = cur->next) {
+        /* Calculate size of the required buffer */
+        switch (cur->node.type) {
+            case MBEDTLS_X509_SAN_DNS_NAME:
+            case MBEDTLS_X509_SAN_UNIFORM_RESOURCE_IDENTIFIER:
+            case MBEDTLS_X509_SAN_IP_ADDRESS:
+                /* length of value for each name entry,
+                 * maximum 4 bytes for the length field,
+                 * 1 byte for the tag/type.
+                 */
+                buflen += cur->node.san.unstructured_name.len + 4 + 1;
+                break;
+
+            default:
+                /* Not supported - skip. */
+                break;
+        }
+    }
+
+    /* Add the extra length field and tag */
+    buflen += 4 + 1;
+
+    /* Allocate buffer */
+    buf = mbedtls_calloc(1, buflen);
+    if (buf == NULL) {
+        return MBEDTLS_ERR_ASN1_ALLOC_FAILED;
+    }
+
+    /* mbedtls_calloc() already zeroed the buffer. The bounds-checked ASN.1
+     * writers fill it backwards, so start at its end. */
+    p = buf + buflen;
+
+    /* Write ASN.1-based structure */
+    cur = san_list;
+    len = 0;
+    while (cur != NULL) {
+        switch (cur->node.type) {
+            case MBEDTLS_X509_SAN_DNS_NAME:
+            case MBEDTLS_X509_SAN_UNIFORM_RESOURCE_IDENTIFIER:
+            case MBEDTLS_X509_SAN_IP_ADDRESS:
+            {
+                const unsigned char *unstructured_name =
+                    (const unsigned char *) cur->node.san.unstructured_name.p;
+                size_t unstructured_name_len = cur->node.san.unstructured_name.len;
+
+                MBEDTLS_ASN1_CHK_CLEANUP_ADD(len,
+                                             mbedtls_asn1_write_raw_buffer(
+                                                 &p, buf,
+                                                 unstructured_name, unstructured_name_len));
+                MBEDTLS_ASN1_CHK_CLEANUP_ADD(len, mbedtls_asn1_write_len(
+                                                 &p, buf, unstructured_name_len));
+                MBEDTLS_ASN1_CHK_CLEANUP_ADD(len,
+                                             mbedtls_asn1_write_tag(
+                                                 &p, buf,
+                                                 MBEDTLS_ASN1_CONTEXT_SPECIFIC | cur->node.type));
+            }
+            break;
+            default:
+                /* Skip unsupported names. */
+                break;
+        }
+        cur = cur->next;
+    }
+
+    MBEDTLS_ASN1_CHK_CLEANUP_ADD(len, mbedtls_asn1_write_len(&p, buf, len));
+    MBEDTLS_ASN1_CHK_CLEANUP_ADD(len,
+                                 mbedtls_asn1_write_tag(&p, buf,
+                                                        MBEDTLS_ASN1_CONSTRUCTED |
+                                                        MBEDTLS_ASN1_SEQUENCE));
+
+    ret = mbedtls_x509write_csr_set_extension(
+        ctx,
+        MBEDTLS_OID_SUBJECT_ALT_NAME,
+        MBEDTLS_OID_SIZE(MBEDTLS_OID_SUBJECT_ALT_NAME),
+        0,
+        buf + buflen - len,
+        len);
+
+cleanup:
+    mbedtls_free(buf);
+    return ret;
+}
+
 int mbedtls_x509write_csr_set_key_usage(mbedtls_x509write_csr *ctx, unsigned char key_usage)
 {
     unsigned char buf[4] = { 0 };
diff --git a/programs/x509/cert_req.c b/programs/x509/cert_req.c
index 8ef5932..5241438 100644
--- a/programs/x509/cert_req.c
+++ b/programs/x509/cert_req.c
@@ -63,6 +63,11 @@
     "    debug_level=%%d      default: 0 (disabled)\n"  \
     "    output_file=%%s      default: cert.req\n"      \
     "    subject_name=%%s     default: CN=Cert,O=mbed TLS,C=UK\n"   \
+    "    san=%%s              default: (none)\n"       \
+    "                        Semicolon-separated-list of values:\n" /* main() splits on ';' */ \
+    "                          DNS:value\n"            \
+    "                          URI:value\n"            \
+    "                          IP:value (Only IPv4 is supported)\n"             \
     "    key_usage=%%s        default: (empty)\n"       \
     "                        Comma-separated-list of values:\n"     \
     "                          digital_signature\n"     \
@@ -96,18 +101,31 @@
  * global options
  */
 struct options {
-    const char *filename;       /* filename of the key file             */
-    const char *password;       /* password for the key file            */
-    int debug_level;            /* level of debugging                   */
-    const char *output_file;    /* where to store the constructed key file  */
-    const char *subject_name;   /* subject name for certificate request */
-    unsigned char key_usage;    /* key usage flags                      */
-    int force_key_usage;        /* Force adding the KeyUsage extension  */
-    unsigned char ns_cert_type; /* NS cert type                         */
-    int force_ns_cert_type;     /* Force adding NsCertType extension    */
-    mbedtls_md_type_t md_alg;   /* Hash algorithm used for signature.   */
+    const char *filename;             /* filename of the key file             */
+    const char *password;             /* password for the key file            */
+    int debug_level;                  /* level of debugging                   */
+    const char *output_file;          /* where to store the generated CSR     */
+    const char *subject_name;         /* subject name for certificate request   */
+    mbedtls_x509_san_list *san_list;  /* subjectAltName for certificate request */
+    unsigned char key_usage;          /* key usage flags                      */
+    int force_key_usage;              /* Force adding the KeyUsage extension  */
+    unsigned char ns_cert_type;       /* NS cert type                         */
+    int force_ns_cert_type;           /* Force adding NsCertType extension    */
+    mbedtls_md_type_t md_alg;         /* Hash algorithm used for signature.   */
 } opt;
 
+static void ip_string_to_bytes(const char *str, uint8_t *bytes, int maxBytes)
+{
+    /* Store up to maxBytes '.'-separated decimal octets of str; a NULL str is a no-op. */
+    for (int i = 0; str != NULL && i < maxBytes; i++) {
+        bytes[i] = (uint8_t) strtoul(str, NULL, 10); /* octets are decimal; not range-checked */
+        str = strchr(str, '.');
+        if (str != NULL) {
+            str++; /* step over the '.' to the next octet */
+        }
+    }
+}
+
 int write_certificate_request(mbedtls_x509write_csr *req, const char *output_file,
                               int (*f_rng)(void *, unsigned char *, size_t),
                               void *p_rng)
@@ -145,11 +163,12 @@
     mbedtls_pk_context key;
     char buf[1024];
     int i;
-    char *p, *q, *r;
+    char *p, *q, *r, *r2;
     mbedtls_x509write_csr req;
     mbedtls_entropy_context entropy;
     mbedtls_ctr_drbg_context ctr_drbg;
     const char *pers = "csr example app";
+    mbedtls_x509_san_list *cur, *prev;
 
     /*
      * Set to sane values
@@ -175,15 +194,14 @@
     opt.ns_cert_type        = DFL_NS_CERT_TYPE;
     opt.force_ns_cert_type  = DFL_FORCE_NS_CERT_TYPE;
     opt.md_alg              = DFL_MD_ALG;
+    opt.san_list            = NULL;
 
     for (i = 1; i < argc; i++) {
-
         p = argv[i];
         if ((q = strchr(p, '=')) == NULL) {
             goto usage;
         }
         *q++ = '\0';
-
         if (strcmp(p, "filename") == 0) {
             opt.filename = q;
         } else if (strcmp(p, "password") == 0) {
@@ -197,6 +215,59 @@
             }
         } else if (strcmp(p, "subject_name") == 0) {
             opt.subject_name = q;
+        } else if (strcmp(p, "san") == 0) {
+            prev = NULL;
+
+            while (q != NULL) {
+                /* Entries are separated by ';'; r is the next entry or NULL. */
+                if ((r = strchr(q, ';')) != NULL) {
+                    *r++ = '\0';
+                }
+
+                /* Each entry is "TYPE:value"; an entry without ':' is invalid. */
+                if ((r2 = strchr(q, ':')) == NULL) {
+                    goto usage;
+                }
+                *r2++ = '\0';
+
+                /* Reserve 4 bytes behind the node for an IPv4 address so the value
+                 * lives as long as the node (a loop-local array would dangle). */
+                cur = mbedtls_calloc(1, sizeof(mbedtls_x509_san_list) + 4);
+                if (cur == NULL) {
+                    mbedtls_printf("Not enough memory for subjectAltName list\n");
+                    goto usage;
+                }
+                cur->next = NULL;
+
+                if (strcmp(q, "URI") == 0) {
+                    cur->node.type = MBEDTLS_X509_SAN_UNIFORM_RESOURCE_IDENTIFIER;
+                } else if (strcmp(q, "DNS") == 0) {
+                    cur->node.type = MBEDTLS_X509_SAN_DNS_NAME;
+                } else if (strcmp(q, "IP") == 0) {
+                    cur->node.type = MBEDTLS_X509_SAN_IP_ADDRESS;
+                } else {
+                    mbedtls_free(cur);
+                    goto usage;
+                }
+
+                if (cur->node.type == MBEDTLS_X509_SAN_IP_ADDRESS) {
+                    ip_string_to_bytes(r2, (uint8_t *) (cur + 1), 4);
+                    cur->node.san.unstructured_name.p = (unsigned char *) (cur + 1);
+                    cur->node.san.unstructured_name.len = 4;
+                } else {
+                    cur->node.san.unstructured_name.p = (unsigned char *) r2;
+                    cur->node.san.unstructured_name.len = strlen(r2);
+                }
+
+                if (prev == NULL) {
+                    opt.san_list = cur;
+                } else {
+                    prev->next = cur;
+                }
+                prev = cur;
+                q = r;
+            }
+
         } else if (strcmp(p, "md") == 0) {
             const mbedtls_md_info_t *md_info =
                 mbedtls_md_info_from_string(q);
@@ -274,14 +345,39 @@
         }
     }
 
+    /* Set the MD algorithm to use for the signature in the CSR */
     mbedtls_x509write_csr_set_md_alg(&req, opt.md_alg);
 
+    /* Set the Key Usage Extension flags in the CSR */
     if (opt.key_usage || opt.force_key_usage == 1) {
-        mbedtls_x509write_csr_set_key_usage(&req, opt.key_usage);
+        ret = mbedtls_x509write_csr_set_key_usage(&req, opt.key_usage);
+
+        if (ret != 0) {
+            mbedtls_printf(" failed\n  !  mbedtls_x509write_csr_set_key_usage returned %d", ret);
+            goto exit;
+        }
     }
 
+    /* Set the Cert Type flags in the CSR */
     if (opt.ns_cert_type || opt.force_ns_cert_type == 1) {
-        mbedtls_x509write_csr_set_ns_cert_type(&req, opt.ns_cert_type);
+        ret = mbedtls_x509write_csr_set_ns_cert_type(&req, opt.ns_cert_type);
+
+        if (ret != 0) {
+            mbedtls_printf(" failed\n  !  mbedtls_x509write_csr_set_ns_cert_type returned %d", ret);
+            goto exit;
+        }
+    }
+
+    /* Set the SubjectAltName in the CSR */
+    if (opt.san_list != NULL) {
+        ret = mbedtls_x509write_csr_set_subject_alternative_name(&req, opt.san_list);
+
+        if (ret != 0) {
+            mbedtls_printf(
+                " failed\n  !  mbedtls_x509write_csr_set_subject_alternative_name returned %d",
+                ret);
+            goto exit;
+        }
     }
 
     /*
@@ -363,6 +459,14 @@
     mbedtls_ctr_drbg_free(&ctr_drbg);
     mbedtls_entropy_free(&entropy);
 
+    cur = opt.san_list;
+    while (cur != NULL) {
+        prev = cur;
+        cur = cur->next;
+        mbedtls_free(prev);
+    }
+
+
     mbedtls_exit(exit_code);
 }
 #endif /* MBEDTLS_X509_CSR_WRITE_C && MBEDTLS_PK_PARSE_C && MBEDTLS_FS_IO &&
diff --git a/tests/data_files/Makefile b/tests/data_files/Makefile
index e638caf..7cdbd24 100644
--- a/tests/data_files/Makefile
+++ b/tests/data_files/Makefile
@@ -1006,7 +1006,7 @@
 
 server1.req.sha256.ext: server1.key
 	# Generating this with OpenSSL as a comparison point to test we're getting the same result
-	openssl req -new -out $@ -key $< -subj '/C=NL/O=PolarSSL/CN=PolarSSL Server 1' -sha256 -addext "extendedKeyUsage=serverAuth"
+	openssl req -new -out $@ -key $< -subj '/C=NL/O=PolarSSL/CN=PolarSSL Server 1' -sha256 -addext "extendedKeyUsage=serverAuth" -addext "subjectAltName=URI:http://pki.example.com/,IP:127.1.1.0,DNS:example.com"
 all_final += server1.req.sha256.ext
 
 server1.req.sha384: server1.key
diff --git a/tests/data_files/server1.req.sha256.ext b/tests/data_files/server1.req.sha256.ext
index 3f26f09..c5ff5c5 100644
--- a/tests/data_files/server1.req.sha256.ext
+++ b/tests/data_files/server1.req.sha256.ext
@@ -1,17 +1,18 @@
 -----BEGIN CERTIFICATE REQUEST-----
-MIICpzCCAY8CAQAwPDELMAkGA1UEBhMCTkwxETAPBgNVBAoMCFBvbGFyU1NMMRow
+MIIC3jCCAcYCAQAwPDELMAkGA1UEBhMCTkwxETAPBgNVBAoMCFBvbGFyU1NMMRow
 GAYDVQQDDBFQb2xhclNTTCBTZXJ2ZXIgMTCCASIwDQYJKoZIhvcNAQEBBQADggEP
 ADCCAQoCggEBAKkCHz1AatVVU4v9Nu6CZS4VYV6Jv7joRZDb7ogWUtPxQ1BHlhJZ
 ZIdr/SvgRvlzvt3PkuGRW+1moG+JKXlFgNCDatVBQ3dfOXwJBEeCsFc5cO2j7BUZ
 HqgzCEfBBUKp/UzDtN/dBh9NEFFAZ3MTD0D4bYElXwqxU8YwfhU5rPla7n+SnqYF
 W+cTl4W1I5LZ1CQG1QkliXUH3aYajz8JGb6tZSxk65Wb3P5BXhem2mxbacwCuhQs
 FiScStzN0PdSZ3PxLaAj/X70McotcMqJCwTbLqZPcG6ezr1YieJTWZ5uWpJl4og/
-DJQZo93l6J2VE+0p26twEtxaymsXq1KCVLECAwEAAaAmMCQGCSqGSIb3DQEJDjEX
-MBUwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDQYJKoZIhvcNAQELBQADggEBAHi0yEGu
-Fh5tuLiLuT95UrRnly55+lTY9xchFiKtlcoEdSheybYxqk3JHuSSqojOFKZBlRdk
-oG6Azg56/aMHPWyvtCMSRQX4b+FgjeQsm9IfhYNMquQOxyPxm62vjuU3MfZIofXH
-hKdI6Ci2CDF4Fyvw50KBWniV38eE9+kjsvDLdXD3ESZJGhjjuFl8ReUiA2wdBTcP
-XEZaXUIc6B4tUnlPeqn/2zp4GBqqWzNZx6TXBpApASGG3BEJnM52FVPC7E9p+8YZ
-qIGuiF5Cz/rYZkpwffBWIfS2zZakHLm5TB8FgZkWlyReJU9Ihk2Tl/sZ1kllFdYa
-xLPnLCL82KFL1Co=
+DJQZo93l6J2VE+0p26twEtxaymsXq1KCVLECAwEAAaBdMFsGCSqGSIb3DQEJDjFO
+MEwwEwYDVR0lBAwwCgYIKwYBBQUHAwEwNQYDVR0RBC4wLIYXaHR0cDovL3BraS5l
+eGFtcGxlLmNvbS+HBH8BAQCCC2V4YW1wbGUuY29tMA0GCSqGSIb3DQEBCwUAA4IB
+AQCGmTIXEUvTqwChkzRtxPIQDDchrMnCXgUrTSxre5nvUOpjVlcIIPGWAwxRovfe
+pW6OaGZ/3xD0dRAcOW08sTD6GRUazFrubPA1eZiNC7vYdWV59qm84N5yRR/s8Hm+
+okwI47m7W9C0pfaNXchgFUQBn16TrZxPXklbCpBJ/TFV+1ODY0sJPHYiCFpYI+Jz
+YuJmadP2BHucl8wv2RyVHywOmV1sDc74i9igVrBCAh8wu+kqImMtrnkGZDxrnj/L
+5P1eDfdqG2cN+s40RnMQMosh3UfqpNV/bTgAqBPP2uluT9L1KpWcjZeuvisOgVTq
+XwFI5s34fen2DUVw6MWNfbDK
 -----END CERTIFICATE REQUEST-----
diff --git a/tests/suites/test_suite_x509write.function b/tests/suites/test_suite_x509write.function
index cd1f203..5e8230f 100644
--- a/tests/suites/test_suite_x509write.function
+++ b/tests/suites/test_suite_x509write.function
@@ -152,6 +152,27 @@
     int der_len = -1;
     const char *subject_name = "C=NL,O=PolarSSL,CN=PolarSSL Server 1";
     mbedtls_test_rnd_pseudo_info rnd_info;
+    mbedtls_x509_san_list san_ip;
+    mbedtls_x509_san_list san_dns;
+    mbedtls_x509_san_list san_uri;
+    mbedtls_x509_san_list *san_list = NULL;
+    const char san_ip_name[] = { 0x7f, 0x01, 0x01, 0x00 }; // 127.1.1.0
+    const char *san_dns_name = "example.com";
+    const char *san_uri_name = "http://pki.example.com/";
+
+    san_uri.node.type = MBEDTLS_X509_SAN_UNIFORM_RESOURCE_IDENTIFIER;
+    san_uri.node.san.unstructured_name.p = (unsigned char *) san_uri_name;
+    san_uri.node.san.unstructured_name.len = strlen(san_uri_name);
+    san_uri.next = NULL;
+    san_ip.node.type = MBEDTLS_X509_SAN_IP_ADDRESS;
+    san_ip.node.san.unstructured_name.p = (unsigned char *) san_ip_name;
+    san_ip.node.san.unstructured_name.len = sizeof(san_ip_name);
+    san_ip.next = &san_uri;
+    san_dns.node.type = MBEDTLS_X509_SAN_DNS_NAME;
+    san_dns.node.san.unstructured_name.p = (unsigned char *) san_dns_name;
+    san_dns.node.san.unstructured_name.len = strlen(san_dns_name);
+    san_dns.next = &san_ip;
+    san_list = &san_dns;
 
     memset(&rnd_info, 0x2a, sizeof(mbedtls_test_rnd_pseudo_info));
 
@@ -175,6 +196,8 @@
     if (set_extension != 0) {
         TEST_ASSERT(csr_set_extended_key_usage(&req, MBEDTLS_OID_SERVER_AUTH,
                                                MBEDTLS_OID_SIZE(MBEDTLS_OID_SERVER_AUTH)) == 0);
+
+        TEST_ASSERT(mbedtls_x509write_csr_set_subject_alternative_name(&req, san_list) == 0);
     }
 
     ret = mbedtls_x509write_csr_pem(&req, buf, sizeof(buf),