Add new function mbedtls_asn1_write_named_bitstring()

Add a new function mbedtls_asn1_write_named_bitstring() that removes
trailing 0s at the end of DER encoded bitstrings. The function is
implemented according to Hanno Becker's suggestions.

This commit also changes the functions x509write_crt_set_ns_cert_type
and crt_set_key_usage to call the new function as the use named
bitstrings instead of the regular bitstrings.
diff --git a/include/mbedtls/asn1write.h b/include/mbedtls/asn1write.h
index 76c1780..80b31c3 100644
--- a/include/mbedtls/asn1write.h
+++ b/include/mbedtls/asn1write.h
@@ -277,6 +277,26 @@
                                   const unsigned char *buf, size_t bits );
 
 /**
+ * \brief           Write a named bitstring tag (MBEDTLS_ASN1_BIT_STRING) and
+ *                  value in ASN.1 format
+ *                  Note: function works backwards in data buffer
+ *
+ *                  As stated in RFC5280 Appending B, trailing zeroes are
+ *                  omitted when encoding named bitstrings in DER.
+ *
+ * \param p         Reference to current position pointer.
+ * \param start     Start of the buffer (for bounds-checking).
+ * \param buf       The bitstring.
+ * \param bits      The total number of bits in the bitstring.
+ *
+ * \return          The length written or a negative error code.
+ */
+int mbedtls_asn1_write_named_bitstring( unsigned char **p,
+                                        unsigned char *start,
+                                        const unsigned char *buf,
+                                        size_t bits );
+
+/**
  * \brief           Write an octet string tag (#MBEDTLS_ASN1_OCTET_STRING)
  *                  and value in ASN.1 format.
  *
diff --git a/library/asn1write.c b/library/asn1write.c
index a4d23f6..b54e26b 100644
--- a/library/asn1write.c
+++ b/library/asn1write.c
@@ -290,26 +290,75 @@
     return( mbedtls_asn1_write_tagged_string(p, start, MBEDTLS_ASN1_IA5_STRING, text, text_len) );
 }
 
+int mbedtls_asn1_write_named_bitstring( unsigned char **p,
+                                        unsigned char *start,
+                                        const unsigned char *buf,
+                                        size_t bits )
+{
+    size_t unused_bits, byte_len;
+    const unsigned char *cur_byte;
+    unsigned char cur_byte_shifted;
+    unsigned char bit;
+
+    byte_len = ( bits + 7 ) / 8;
+    unused_bits = ( byte_len * 8 ) - bits;
+
+    /*
+     * Named bitstrings require that trailing 0s are excluded in the encoding
+     * of the bitstring. Trailing 0s are considered part of the 'unused' bits
+     * when encoding this value in the first content octet
+     */
+    if( bits != 0 )
+    {
+        cur_byte = buf + byte_len - 1;
+        cur_byte_shifted = *cur_byte >> unused_bits;
+
+        for( ; ; )
+        {
+            bit = cur_byte_shifted & 0x1;
+            cur_byte_shifted >>= 1;
+
+            if( bit != 0 )
+                break;
+
+            bits--;
+            if( bits == 0 )
+                break;
+
+            if( bits % 8 == 0 )
+                cur_byte_shifted = *--cur_byte;
+        }
+    }
+
+    return( mbedtls_asn1_write_bitstring( p, start, buf, bits ) );
+}
+
 int mbedtls_asn1_write_bitstring( unsigned char **p, unsigned char *start,
                           const unsigned char *buf, size_t bits )
 {
     int ret;
-    size_t len = 0, size;
+    size_t len = 0;
+    size_t unused_bits, byte_len;
 
-    size = ( bits / 8 ) + ( ( bits % 8 ) ? 1 : 0 );
+    byte_len = ( bits + 7 ) / 8;
+    unused_bits = ( byte_len * 8 ) - bits;
 
-    // Calculate byte length
-    //
-    if( *p < start || (size_t)( *p - start ) < size + 1 )
+    if( *p < start || (size_t)( *p - start ) < byte_len + 1 )
         return( MBEDTLS_ERR_ASN1_BUF_TOO_SMALL );
 
-    len = size + 1;
-    (*p) -= size;
-    memcpy( *p, buf, size );
+    len = byte_len + 1;
 
-    // Write unused bits
-    //
-    *--(*p) = (unsigned char) (size * 8 - bits);
+    /* Write the bitstring. Ensure the unused bits are zeroed */
+    if( byte_len > 0 )
+    {
+        byte_len--;
+        *--( *p ) = buf[byte_len] & ~( ( 0x1 << unused_bits ) - 1 );
+        ( *p ) -= byte_len;
+        memcpy( *p, buf, byte_len );
+    }
+
+    /* Write unused bits */
+    *--( *p ) = (unsigned char)unused_bits;
 
     MBEDTLS_ASN1_CHK_ADD( len, mbedtls_asn1_write_len( p, start, len ) );
     MBEDTLS_ASN1_CHK_ADD( len, mbedtls_asn1_write_tag( p, start, MBEDTLS_ASN1_BIT_STRING ) );
diff --git a/library/x509write_crt.c b/library/x509write_crt.c
index b1ef216..b6cb745 100644
--- a/library/x509write_crt.c
+++ b/library/x509write_crt.c
@@ -221,23 +221,36 @@
 int mbedtls_x509write_crt_set_key_usage( mbedtls_x509write_cert *ctx,
                                          unsigned int key_usage )
 {
-    unsigned char buf[4], ku;
+    unsigned char buf[5], ku[2];
     unsigned char *c;
     int ret;
+    const unsigned int allowed_bits = MBEDTLS_X509_KU_DIGITAL_SIGNATURE |
+        MBEDTLS_X509_KU_NON_REPUDIATION   |
+        MBEDTLS_X509_KU_KEY_ENCIPHERMENT  |
+        MBEDTLS_X509_KU_DATA_ENCIPHERMENT |
+        MBEDTLS_X509_KU_KEY_AGREEMENT     |
+        MBEDTLS_X509_KU_KEY_CERT_SIGN     |
+        MBEDTLS_X509_KU_CRL_SIGN          |
+        MBEDTLS_X509_KU_ENCIPHER_ONLY     |
+        MBEDTLS_X509_KU_DECIPHER_ONLY;
 
-    /* We currently only support 7 bits, from 0x80 to 0x02 */
-    if( ( key_usage & ~0xfe ) != 0 )
+    /* Check that nothing other than the allowed flags is set */
+    if( ( key_usage & ~allowed_bits ) != 0 )
         return( MBEDTLS_ERR_X509_FEATURE_UNAVAILABLE );
 
-    c = buf + 4;
-    ku = (unsigned char) key_usage;
+    c = buf + 5;
+    ku[0] = (unsigned char)( key_usage      );
+    ku[1] = (unsigned char)( key_usage >> 8 );
+    ret = mbedtls_asn1_write_named_bitstring( &c, buf, ku, 9 );
 
-    if( ( ret = mbedtls_asn1_write_bitstring( &c, buf, &ku, 7 ) ) != 4 )
+    if( ret < 0 )
         return( ret );
+    else if( ret < 3 || ret > 5 )
+        return( MBEDTLS_ERR_X509_INVALID_FORMAT );
 
     ret = mbedtls_x509write_crt_set_extension( ctx, MBEDTLS_OID_KEY_USAGE,
                                        MBEDTLS_OID_SIZE( MBEDTLS_OID_KEY_USAGE ),
-                                       1, buf, 4 );
+                                       1, c, (size_t)ret );
     if( ret != 0 )
         return( ret );
 
@@ -253,12 +266,13 @@
 
     c = buf + 4;
 
-    if( ( ret = mbedtls_asn1_write_bitstring( &c, buf, &ns_cert_type, 8 ) ) != 4 )
+    ret = mbedtls_asn1_write_named_bitstring( &c, buf, &ns_cert_type, 8 );
+    if( ret < 3 || ret > 4 )
         return( ret );
 
     ret = mbedtls_x509write_crt_set_extension( ctx, MBEDTLS_OID_NS_CERT_TYPE,
                                        MBEDTLS_OID_SIZE( MBEDTLS_OID_NS_CERT_TYPE ),
-                                       0, buf, 4 );
+                                       0, c, (size_t)ret );
     if( ret != 0 )
         return( ret );
 
diff --git a/library/x509write_csr.c b/library/x509write_csr.c
index 66cee56..8b475be 100644
--- a/library/x509write_csr.c
+++ b/library/x509write_csr.c
@@ -89,12 +89,13 @@
 
     c = buf + 4;
 
-    if( ( ret = mbedtls_asn1_write_bitstring( &c, buf, &key_usage, 7 ) ) != 4 )
+    ret = mbedtls_asn1_write_named_bitstring( &c, buf, &key_usage, 8 );
+    if( ret < 3 || ret > 4 )
         return( ret );
 
     ret = mbedtls_x509write_csr_set_extension( ctx, MBEDTLS_OID_KEY_USAGE,
                                        MBEDTLS_OID_SIZE( MBEDTLS_OID_KEY_USAGE ),
-                                       buf, 4 );
+                                       c, (size_t)ret );
     if( ret != 0 )
         return( ret );
 
@@ -110,12 +111,13 @@
 
     c = buf + 4;
 
-    if( ( ret = mbedtls_asn1_write_bitstring( &c, buf, &ns_cert_type, 8 ) ) != 4 )
+    ret = mbedtls_asn1_write_named_bitstring( &c, buf, &ns_cert_type, 8 );
+    if( ret < 3 || ret > 4 )
         return( ret );
 
     ret = mbedtls_x509write_csr_set_extension( ctx, MBEDTLS_OID_NS_CERT_TYPE,
                                        MBEDTLS_OID_SIZE( MBEDTLS_OID_NS_CERT_TYPE ),
-                                       buf, 4 );
+                                       c, (size_t)ret );
     if( ret != 0 )
         return( ret );