Add tls13 session save/load

Signed-off-by: Jerry Yu <jerry.h.yu@arm.com>
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 6d91011..d055794 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -440,12 +440,12 @@
 static void ssl_calc_finished_tls_sha384( mbedtls_ssl_context *, unsigned char *, int );
 #endif /* MBEDTLS_SHA384_C */
 
-static size_t ssl_session_save_tls12( const mbedtls_ssl_session *session,
+static size_t ssl_tls12_session_save( const mbedtls_ssl_session *session,
                                       unsigned char *buf,
                                       size_t buf_len );
 
 MBEDTLS_CHECK_RETURN_CRITICAL
-static int ssl_session_load_tls12( mbedtls_ssl_session *session,
+static int ssl_tls12_session_load( mbedtls_ssl_session *session,
                                    const unsigned char *buf,
                                    size_t len );
 #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
@@ -1888,7 +1888,170 @@
 #if defined(MBEDTLS_USE_PSA_CRYPTO) || defined(MBEDTLS_SSL_PROTO_TLS1_3)
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
-static size_t ssl_session_save_tls13( const mbedtls_ssl_session *session,
+/* Serialization of TLS 1.3 sessions:
+ *
+ *     struct {
+ *       uint64 ticket_received;
+ *       uint32 ticket_lifetime;
+ *       opaque ticket<0..2^16-1>;
+ *     } ClientOnlyData;
+ *
+ *     struct {
+ *       uint8 endpoint;
+ *       uint8 ciphersuite[2];
+ *       uint32 ticket_age_add;
+ *       uint8 ticket_flags;
+ *       opaque resumption_key<0..255>;
+ *       select ( endpoint ) {
+ *            case client: ClientOnlyData;
+ *            case server: uint64 start_time;
+ *       };
+ *     } serialized_session_tls13;
+ * ticket_received and start_time are only present if MBEDTLS_HAVE_TIME is set.
+ */
+#if defined(MBEDTLS_SSL_SESSION_TICKETS)
+static size_t ssl_tls13_session_save( const mbedtls_ssl_session *session,
+                                      unsigned char *buf,
+                                      size_t buf_len )
+{ /* Write the TLS 1.3 session to buf; the size needed is always returned. */
+    unsigned char *p = buf;
+    size_t needed =   1                 /* endpoint */
+                    + 2                 /* ciphersuite */
+                    + 4                 /* ticket_age_add */
+                    + 2                 /* ticket_flags (1) + key_len (1) */
+                    + session->key_len; /* key */
+    /* NOTE(review): counted for any endpoint; written only per role below. */
+#if defined(MBEDTLS_HAVE_TIME)
+    needed += 8; /* start_time or ticket_received */
+#endif
+
+#if defined(MBEDTLS_SSL_CLI_C)
+    if( session->endpoint == MBEDTLS_SSL_IS_CLIENT )
+    {
+        needed +=   4                       /* ticket_lifetime */
+                  + 2                       /* ticket_len */
+                  + session->ticket_len;    /* ticket */
+    }
+#endif /* MBEDTLS_SSL_CLI_C */
+    /* Buffer too small: report the required size without writing anything. */
+    if( needed > buf_len )
+        return( needed );
+    /* Fixed-offset fields fill bytes 0..8; integers go through *_BE macros. */
+    p[0] = session->endpoint;
+    MBEDTLS_PUT_UINT16_BE( session->ciphersuite, p, 1 );
+    MBEDTLS_PUT_UINT32_BE( session->ticket_age_add, p, 3 );
+    p[7] = session->ticket_flags;
+
+    /* save resumption_key: one length byte followed by the key bytes */
+    p[8] = session->key_len;
+    p += 9;
+    memcpy( p, session->key, session->key_len );
+    p += session->key_len;
+
+#if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_SRV_C)
+    if( session->endpoint == MBEDTLS_SSL_IS_SERVER )
+    {
+        MBEDTLS_PUT_UINT64_BE( (uint64_t) session->start, p, 0 );
+        p += 8;
+    }
+#endif /* MBEDTLS_HAVE_TIME && MBEDTLS_SSL_SRV_C */
+
+#if defined(MBEDTLS_SSL_CLI_C)
+    if( session->endpoint == MBEDTLS_SSL_IS_CLIENT )
+    {
+#if defined(MBEDTLS_HAVE_TIME)
+        MBEDTLS_PUT_UINT64_BE( (uint64_t) session->ticket_received, p, 0 );
+        p += 8;
+#endif
+        MBEDTLS_PUT_UINT32_BE( session->ticket_lifetime, p, 0 );
+        p += 4;
+
+        MBEDTLS_PUT_UINT16_BE( session->ticket_len, p, 0 );
+        p += 2;
+        if( session->ticket_len > 0 )
+        {
+            memcpy( p, session->ticket, session->ticket_len );
+            p += session->ticket_len;
+        }
+    }
+#endif /* MBEDTLS_SSL_CLI_C */
+    return( needed );
+}
+
+MBEDTLS_CHECK_RETURN_CRITICAL
+static int ssl_tls13_session_load( mbedtls_ssl_session *session,
+                                   const unsigned char *buf,
+                                   size_t len )
+{ /* Parse len bytes at buf into session; 0 on success, else an error code. */
+    const unsigned char *p = buf;
+    const unsigned char *end = buf + len;
+    /* Fixed part: endpoint, ciphersuite, ticket_age_add, ticket_flags, key_len */
+    if( end - p < 9 )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+    session->endpoint = p[0];
+    session->ciphersuite = MBEDTLS_GET_UINT16_BE( p, 1 );
+    session->ticket_age_add = MBEDTLS_GET_UINT32_BE( p, 3 );
+    session->ticket_flags = p[7];
+
+    /* load resumption_key: one length byte followed by the key bytes */
+    session->key_len = p[8];
+    p += 9;
+
+    if( end - p < session->key_len )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+    /* Never copy more than sizeof( session->key ) bytes into session->key. */
+    if( sizeof( session->key ) < session->key_len)
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+    memcpy( session->key, p, session->key_len );
+    p += session->key_len;
+
+#if defined(MBEDTLS_HAVE_TIME) && defined(MBEDTLS_SSL_SRV_C)
+    if( session->endpoint == MBEDTLS_SSL_IS_SERVER )
+    {
+        if( end - p < 8 )
+            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+        session->start = MBEDTLS_GET_UINT64_BE( p, 0 );
+        p += 8;
+    }
+#endif /* MBEDTLS_HAVE_TIME && MBEDTLS_SSL_SRV_C */
+
+#if defined(MBEDTLS_SSL_CLI_C)
+    if( session->endpoint == MBEDTLS_SSL_IS_CLIENT )
+    {
+#if defined(MBEDTLS_HAVE_TIME)
+        if( end - p < 8 )
+            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+        session->ticket_received = MBEDTLS_GET_UINT64_BE( p, 0 );
+        p += 8;
+#endif
+        if( end - p < 4 )
+            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+        session->ticket_lifetime = MBEDTLS_GET_UINT32_BE( p, 0 );
+        p += 4;
+
+        if( end - p <  2 )
+            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+        session->ticket_len = MBEDTLS_GET_UINT16_BE( p, 0 );
+        p += 2;
+        /* The whole ticket must be present; it is copied to a new allocation. */
+        if( end - p < ( long int )session->ticket_len )
+            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+        if( session->ticket_len > 0 )
+        {
+            session->ticket = mbedtls_calloc( 1, session->ticket_len );
+            if( session->ticket == NULL )
+                return( MBEDTLS_ERR_SSL_ALLOC_FAILED );
+            memcpy( session->ticket, p, session->ticket_len );
+            p += session->ticket_len;
+        }
+    }
+#endif /* MBEDTLS_SSL_CLI_C */
+    /* NOTE(review): bytes left over after the parsed fields are not rejected. */
+    return( 0 );
+
+}
+#else /* MBEDTLS_SSL_SESSION_TICKETS */
+static size_t ssl_tls13_session_save( const mbedtls_ssl_session *session,
                                       unsigned char *buf,
                                       size_t buf_len )
 {
@@ -1897,6 +2060,17 @@
     ((void) buf_len);
     return( 0 );
 }
+
+static int ssl_tls13_session_load( const mbedtls_ssl_session *session,
+                                   const unsigned char *buf,
+                                   size_t buf_len )
+{ /* Stub without MBEDTLS_SSL_SESSION_TICKETS: nothing is read from buf. */
+    ((void) session);
+    ((void) buf);
+    ((void) buf_len);
+    return( MBEDTLS_ERR_SSL_FEATURE_UNAVAILABLE );
+}
+#endif /* !MBEDTLS_SSL_SESSION_TICKETS */
 #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
 
 psa_status_t mbedtls_ssl_cipher_to_psa( mbedtls_cipher_type_t mbedtls_cipher_type,
@@ -2827,12 +3001,14 @@
     size_t used = 0;
     size_t remaining_len;
 
+    if( session == NULL )
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
+
     if( !omit_header )
     {
         /*
          * Add Mbed TLS version identifier
          */
-
         used += sizeof( ssl_serialized_session_header );
 
         if( used <= buf_len )
@@ -2858,18 +3034,14 @@
     {
 #if defined(MBEDTLS_SSL_PROTO_TLS1_2)
     case MBEDTLS_SSL_VERSION_TLS1_2:
-    {
-        used += ssl_session_save_tls12( session, p, remaining_len );
+        used += ssl_tls12_session_save( session, p, remaining_len );
         break;
-    }
 #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
     case MBEDTLS_SSL_VERSION_TLS1_3:
-    {
-        used += ssl_session_save_tls13( session, p, remaining_len );
+        used += ssl_tls13_session_save( session, p, remaining_len );
         break;
-    }
 #endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
 
     default:
@@ -2908,6 +3080,11 @@
 {
     const unsigned char *p = buf;
     const unsigned char * const end = buf + len;
+    size_t remaining_len;
+
+
+    if( session == NULL )
+        return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
 
     if( !omit_header )
     {
@@ -2934,16 +3111,19 @@
     session->tls_version = 0x0300 | *p++;
 
     /* Dispatch according to TLS version. */
+    remaining_len = ( end - p );
     switch( session->tls_version )
     {
 #if defined(MBEDTLS_SSL_PROTO_TLS1_2)
     case MBEDTLS_SSL_VERSION_TLS1_2:
-    {
-        size_t remaining_len = ( end - p );
-        return( ssl_session_load_tls12( session, p, remaining_len ) );
-    }
+        return( ssl_tls12_session_load( session, p, remaining_len ) );
 #endif /* MBEDTLS_SSL_PROTO_TLS1_2 */
 
+#if defined(MBEDTLS_SSL_PROTO_TLS1_3)
+    case MBEDTLS_SSL_VERSION_TLS1_3:
+        return( ssl_tls13_session_load( session, p, remaining_len ) );
+#endif /* MBEDTLS_SSL_PROTO_TLS1_3 */
+
     default:
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
     }
@@ -7826,7 +8006,7 @@
  * } serialized_session_tls12;
  *
  */
-static size_t ssl_session_save_tls12( const mbedtls_ssl_session *session,
+static size_t ssl_tls12_session_save( const mbedtls_ssl_session *session,
                                       unsigned char *buf,
                                       size_t buf_len )
 {
@@ -7978,7 +8158,7 @@
 }
 
 MBEDTLS_CHECK_RETURN_CRITICAL
-static int ssl_session_load_tls12( mbedtls_ssl_session *session,
+static int ssl_tls12_session_load( mbedtls_ssl_session *session,
                                    const unsigned char *buf,
                                    size_t len )
 {