Move to a callback interface for DTLS cookies
diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h
index 90e2596..e526098 100644
--- a/include/polarssl/ssl.h
+++ b/include/polarssl/ssl.h
@@ -878,12 +878,16 @@
 #endif
 
     /*
-     * Client id (IP/port) for DTLS hello verify
+     * Information for DTLS hello verify
      */
 #if defined(POLARSSL_SSL_DTLS_HELLO_VERIFY)
     unsigned char  *cli_id;         /*!<  transport-level ID of the client  */
     size_t          cli_id_len;     /*!<  length of cli_id                  */
-    md_context_t    hvr_hmac_ctx;   /*!<  HMAC data for HelloVerifyRequest  */
+    int (*f_cookie_write)( void *, unsigned char **, unsigned char *,
+                           const unsigned char *, size_t );
+    int (*f_cookie_check)( void *, const unsigned char *, size_t,
+                           const unsigned char *, size_t );
+    void *p_cookie;                 /*!<  context for the cookie callbacks  */
 #endif
 
     /*
@@ -1072,7 +1076,7 @@
 #if defined(POLARSSL_SSL_DTLS_HELLO_VERIFY)
 /**
  * \brief          Set client's transport-level identification info.
- *                 (Only usable on server.)
+ *                 (Server only. DTLS only.)
  *
  *                 This is usually the IP address (and port), but could be
  *                 anything identify the client depending on the underlying
@@ -1095,8 +1099,93 @@
                                  const unsigned char *info,
                                  size_t ilen );
 
-/* Temporary */
-int ssl_setup_hvr_key( ssl_context *ssl );
+/**
+ * \brief          Callback type: generate a cookie
+ *
+ * \param ctx      Context for the callback
+ * \param p        Buffer to write to,
+ *                 must be updated to point right after the cookie
+ * \param end      Pointer to one past the end of the output buffer
+ * \param info     Client ID info that was passed to
+ *                 \c ssl_set_client_transport_id()
+ * \param ilen     Length of info in bytes
+ *
+ * \return         The callback must return 0 on success,
+ *                 or a negative error code.
+ */
+typedef int ssl_cookie_write_t( void *ctx,
+                                unsigned char **p, unsigned char *end,
+                                const unsigned char *info, size_t ilen );
+
+/**
+ * \brief          Callback type: verify a cookie
+ *
+ * \param ctx      Context for the callback
+ * \param cookie   Cookie to verify
+ * \param clen     Length of cookie
+ * \param info     Client ID info that was passed to
+ *                 \c ssl_set_client_transport_id()
+ * \param ilen     Length of info in bytes
+ *
+ * \return         The callback must return 0 if cookie is valid,
+ *                 or a negative error code.
+ */
+typedef int ssl_cookie_check_t( void *ctx,
+                                const unsigned char *cookie, size_t clen,
+                                const unsigned char *info, size_t ilen );
+
+/**
+ * \brief           Register callbacks for DTLS cookies
+ *                  (Server only. DTLS only.)
+ *
+ * \param ssl               SSL context
+ * \param f_cookie_write    Cookie write callback
+ * \param f_cookie_check    Cookie check callback
+ * \param p_cookie          Context for both callbacks
+ */
+void ssl_set_dtls_cookies( ssl_context *ssl,
+                           ssl_cookie_write_t *f_cookie_write,
+                           ssl_cookie_check_t *f_cookie_check,
+                           void *p_cookie );
+
+/* Note: the next things up to endif are to be moved in a separate module */
+
+/**
+ * \brief          Default cookie generation function.
+ *                 (See description of ssl_cookie_write_t.)
+ */
+ssl_cookie_write_t ssl_cookie_write;
+
+/**
+ * \brief          Default cookie verification function.
+ *                 (See description of ssl_cookie_check_t.)
+ */
+ssl_cookie_check_t ssl_cookie_check;
+
+/**
+ * \brief          Context for the default cookie functions.
+ */
+typedef struct
+{
+    md_context_t    hmac_ctx;
+} ssl_cookie_ctx;
+
+/**
+ * \brief          Initialize cookie context
+ */
+void ssl_cookie_init( ssl_cookie_ctx *ctx );
+
+/**
+ * \brief          Setup cookie context (generate keys)
+ */
+int ssl_cookie_setup( ssl_cookie_ctx *ctx,
+                      int (*f_rng)(void *, unsigned char *, size_t),
+                      void *p_rng );
+
+/**
+ * \brief          Free cookie context
+ */
+void ssl_cookie_free( ssl_cookie_ctx *ctx );
 #endif /* POLARSSL_SSL_DTLS_HELLO_VERIFY */
 
 /**
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index 9300fa7..e22132b 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -369,6 +369,16 @@
 
     return( 0 );
 }
+
+void ssl_set_dtls_cookies( ssl_context *ssl,
+                           ssl_cookie_write_t *f_cookie_write,
+                           ssl_cookie_check_t *f_cookie_check,
+                           void *p_cookie )
+{
+    ssl->f_cookie_write = f_cookie_write;
+    ssl->f_cookie_check = f_cookie_check;
+    ssl->p_cookie       = p_cookie;
+}
 #endif /* POLARSSL_SSL_DTLS_HELLO_VERIFY */
 
 #if defined(POLARSSL_SSL_SERVER_NAME_INDICATION)
@@ -1159,22 +1169,31 @@
 #error "DTLS hello verify needs SHA-1 or SHA-2"
 #endif
 
-/*
- * Generate server key for HelloVerifyRequest
- */
-int ssl_setup_hvr_key( ssl_context *ssl )
+void ssl_cookie_init( ssl_cookie_ctx *ctx )
+{
+    md_init( &ctx->hmac_ctx );
+}
+
+void ssl_cookie_free( ssl_cookie_ctx *ctx )
+{
+    md_free( &ctx->hmac_ctx );
+}
+
+int ssl_cookie_setup( ssl_cookie_ctx *ctx,
+                      int (*f_rng)(void *, unsigned char *, size_t),
+                      void *p_rng )
 {
     int ret;
     unsigned char key[HVR_MD_LEN];
 
-    if( ( ret = ssl->f_rng( ssl->p_rng, key, sizeof( key ) ) ) != 0 )
+    if( ( ret = f_rng( p_rng, key, sizeof( key ) ) ) != 0 )
         return( ret );
 
-    ret = md_init_ctx( &ssl->hvr_hmac_ctx, md_info_from_type( HVR_MD ) );
+    ret = md_init_ctx( &ctx->hmac_ctx, md_info_from_type( HVR_MD ) );
     if( ret != 0 )
         return( ret );
 
-    ret = md_hmac_starts( &ssl->hvr_hmac_ctx, key, sizeof( key ) );
+    ret = md_hmac_starts( &ctx->hmac_ctx, key, sizeof( key ) );
     if( ret != 0 )
         return( ret );
 
@@ -1186,9 +1205,9 @@
 /*
  * Generate cookie for DTLS ClientHello verification
  */
-static int ssl_cookie_write( void *ctx,
-                             unsigned char **p, unsigned char *end,
-                             const unsigned char *cli_id, size_t cli_id_len )
+int ssl_cookie_write( void *ctx,
+                      unsigned char **p, unsigned char *end,
+                      const unsigned char *cli_id, size_t cli_id_len )
 {
     int ret;
     unsigned char hmac_out[HVR_MD_LEN];
@@ -1213,9 +1232,9 @@
 /*
  * Check a cookie
  */
-static int ssl_cookie_check( void *ctx,
-                             const unsigned char *cookie, size_t cookie_len,
-                             const unsigned char *cli_id, size_t cli_id_len )
+int ssl_cookie_check( void *ctx,
+                      const unsigned char *cookie, size_t cookie_len,
+                      const unsigned char *cli_id, size_t cli_id_len )
 {
     unsigned char ref_cookie[HVR_MD_USE];
     unsigned char *p = ref_cookie;
@@ -1531,9 +1550,9 @@
                        buf + cookie_offset + 1, cookie_len );
 
 #if defined(POLARSSL_SSL_DTLS_HELLO_VERIFY)
-        if( ssl_cookie_check( &ssl->hvr_hmac_ctx,
-                              buf + cookie_offset + 1, cookie_len,
-                              ssl->cli_id, ssl->cli_id_len ) != 0 )
+        if( ssl->f_cookie_check( ssl->p_cookie,
+                                 buf + cookie_offset + 1, cookie_len,
+                                 ssl->cli_id, ssl->cli_id_len ) != 0 )
         {
             SSL_DEBUG_MSG( 2, ( "client hello, cookie verification failed" ) );
             ssl->handshake->verify_cookie_len = 1;
@@ -2075,11 +2094,11 @@
     /* Skip length byte until we know the length */
     cookie_len_byte = p++;
 
-    if( ( ret = ssl_cookie_write( &ssl->hvr_hmac_ctx,
-                                  &p, ssl->out_buf + SSL_BUFFER_LEN,
-                                  ssl->cli_id, ssl->cli_id_len ) ) != 0 )
+    if( ( ret = ssl->f_cookie_write( ssl->p_cookie,
+                                     &p, ssl->out_buf + SSL_BUFFER_LEN,
+                                     ssl->cli_id, ssl->cli_id_len ) ) != 0 )
     {
-        SSL_DEBUG_RET( 1, "ssl_cookie_generate", ret );
+        SSL_DEBUG_RET( 1, "f_cookie_write", ret );
         return( ret );
     }
 
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 4a111c6..288ef69 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -5041,7 +5041,6 @@
 
 #if defined(POLARSSL_SSL_DTLS_HELLO_VERIFY)
     polarssl_free( ssl->cli_id );
-    md_free( &ssl->hvr_hmac_ctx );
 #endif
 
     SSL_DEBUG_MSG( 2, ( "<= free" ) );
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index df88502..67661ea 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -601,6 +601,9 @@
 #endif
     const char *pers = "ssl_server2";
     unsigned char client_ip[16] = { 0 };
+#if defined(POLARSSL_SSL_DTLS_HELLO_VERIFY)
+    ssl_cookie_ctx cookie_ctx;
+#endif
 
     entropy_context entropy;
     ctr_drbg_context ctr_drbg;
@@ -658,6 +661,9 @@
 #if defined(POLARSSL_SSL_ALPN)
     memset( (void *) alpn_list, 0, sizeof( alpn_list ) );
 #endif
+#if defined(POLARSSL_SSL_DTLS_HELLO_VERIFY)
+    ssl_cookie_init( &cookie_ctx );
+#endif
 
 #if !defined(_WIN32)
     /* Abort cleanly on SIGTERM */
@@ -1345,11 +1351,17 @@
 #endif
 
 #if defined(POLARSSL_SSL_DTLS_HELLO_VERIFY)
-    if( opt.transport == SSL_TRANSPORT_DATAGRAM &&
-        ( ret = ssl_setup_hvr_key( &ssl ) ) != 0 )
+    if( opt.transport == SSL_TRANSPORT_DATAGRAM )
     {
-        printf( " failed\n  ! ssl_setup_hvr_key returned %d\n\n", ret );
-        goto exit;
+        if( ( ret = ssl_cookie_setup( &cookie_ctx,
+                                      ctr_drbg_random, &ctr_drbg ) ) != 0 )
+        {
+            printf( " failed\n  ! ssl_setup_hvr_key returned %d\n\n", ret );
+            goto exit;
+        }
+
+        ssl_set_dtls_cookies( &ssl, ssl_cookie_write, ssl_cookie_check,
+                                   &cookie_ctx );
     }
 #endif
 
@@ -1844,6 +1856,9 @@
 #if defined(POLARSSL_SSL_CACHE_C)
     ssl_cache_free( &cache );
 #endif
+#if defined(POLARSSL_SSL_DTLS_HELLO_VERIFY)
+    ssl_cookie_free( &cookie_ctx );
+#endif
 
 #if defined(POLARSSL_MEMORY_BUFFER_ALLOC_C)
 #if defined(POLARSSL_MEMORY_DEBUG)