Use a specific function in the PSK callback
diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h
index a72914f..e098bc9 100644
--- a/include/mbedtls/ssl.h
+++ b/include/mbedtls/ssl.h
@@ -672,6 +672,10 @@
 #if defined(MBEDTLS_ECDH_C) || defined(MBEDTLS_ECDSA_C)
     const mbedtls_ecp_curve_info **curves;      /*!<  Supported elliptic curves */
 #endif
+#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
+    unsigned char *psk;                 /*!<  PSK from the callback         */
+    size_t psk_len;                     /*!<  Length of PSK from callback   */
+#endif
 #if defined(MBEDTLS_X509_CRT_PARSE_C)
     /**
      * Current key/cert or key/cert list.
@@ -1581,8 +1585,10 @@
 
 #if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
 /**
- * \brief          Set the Pre Shared Key (PSK) and the identity name connected
- *                 to it.
+ * \brief          Set the Pre Shared Key (PSK) and the expected identity name
+ *
+ * \note           This is mainly useful for clients. Servers will usually
+ *                 want to use \c mbedtls_ssl_set_psk_cb() instead.
  *
  * \param ssl      SSL context
  * \param psk      pointer to the pre-shared key
@@ -1592,11 +1598,28 @@
  *
  * \return         0 if successful or MBEDTLS_ERR_SSL_MALLOC_FAILED
  */
-int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, const unsigned char *psk, size_t psk_len,
-                 const unsigned char *psk_identity, size_t psk_identity_len );
+int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl,
+                const unsigned char *psk, size_t psk_len,
+                const unsigned char *psk_identity, size_t psk_identity_len );
+
 
 /**
- * \brief          Set the PSK callback (server-side only) (Optional).
+ * \brief          Set the Pre Shared Key (PSK) for the current handshake
+ *
+ * \note           This should only be called inside the PSK callback,
+ *                 ie the function passed to \c mbedtls_ssl_set_psk_cb().
+ *
+ * \param ssl      SSL context
+ * \param psk      pointer to the pre-shared key
+ * \param psk_len  pre-shared key length
+ *
+ * \return         0 if successful or MBEDTLS_ERR_SSL_MALLOC_FAILED
+ */
+int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl,
+                            const unsigned char *psk, size_t psk_len );
+
+/**
+ * \brief          Set the PSK callback (server-side only).
  *
  *                 If set, the PSK callback is called for each
  *                 handshake where a PSK ciphersuite was negotiated.
@@ -1607,10 +1630,14 @@
  *                 mbedtls_ssl_context *ssl, const unsigned char *psk_identity,
  *                 size_t identity_len)
  *                 If a valid PSK identity is found, the callback should use
- *                 mbedtls_ssl_set_psk() on the ssl context to set the correct PSK and
- *                 identity and return 0.
+ *                 \c mbedtls_ssl_set_hs_psk() on the ssl context to set the
+ *                 correct PSK and return 0.
  *                 Any other return value will result in a denied PSK identity.
  *
+ * \note           If you set a PSK callback using this function, then you
+ *                 don't need to set a PSK key and identity using
+ *                 \c mbedtls_ssl_set_psk().
+ *
  * \param conf     SSL configuration
  * \param f_psk    PSK identity function
  * \param p_psk    PSK identity parameter
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index c599761..d3ec5dc 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -1066,6 +1066,15 @@
 {
     unsigned char *p = ssl->handshake->premaster;
     unsigned char *end = p + sizeof( ssl->handshake->premaster );
+    const unsigned char *psk = ssl->conf->psk;
+    size_t psk_len = ssl->conf->psk_len;
+
+    /* If the psk callback was called, use its result */
+    if( ssl->handshake->psk != NULL )
+    {
+        psk = ssl->handshake->psk;
+        psk_len = ssl->handshake->psk_len;
+    }
 
     /*
      * PMS = struct {
@@ -1077,12 +1086,12 @@
 #if defined(MBEDTLS_KEY_EXCHANGE_PSK_ENABLED)
     if( key_ex == MBEDTLS_KEY_EXCHANGE_PSK )
     {
-        if( end - p < 2 + (int) ssl->conf->psk_len )
+        if( end - p < 2 + (int) psk_len )
             return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
 
-        *(p++) = (unsigned char)( ssl->conf->psk_len >> 8 );
-        *(p++) = (unsigned char)( ssl->conf->psk_len      );
-        p += ssl->conf->psk_len;
+        *(p++) = (unsigned char)( psk_len >> 8 );
+        *(p++) = (unsigned char)( psk_len      );
+        p += psk_len;
     }
     else
 #endif /* MBEDTLS_KEY_EXCHANGE_PSK_ENABLED */
@@ -1149,13 +1158,13 @@
     }
 
     /* opaque psk<0..2^16-1>; */
-    if( end - p < 2 + (int) ssl->conf->psk_len )
+    if( end - p < 2 + (int) psk_len )
             return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
 
-    *(p++) = (unsigned char)( ssl->conf->psk_len >> 8 );
-    *(p++) = (unsigned char)( ssl->conf->psk_len      );
-    memcpy( p, ssl->conf->psk, ssl->conf->psk_len );
-    p += ssl->conf->psk_len;
+    *(p++) = (unsigned char)( psk_len >> 8 );
+    *(p++) = (unsigned char)( psk_len      );
+    memcpy( p, psk, psk_len );
+    p += psk_len;
 
     ssl->handshake->pmslen = p - ssl->handshake->premaster;
 
@@ -5353,8 +5362,9 @@
 #endif /* MBEDTLS_X509_CRT_PARSE_C */
 
 #if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
-int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, const unsigned char *psk, size_t psk_len,
-                 const unsigned char *psk_identity, size_t psk_identity_len )
+int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl,
+                const unsigned char *psk, size_t psk_len,
+                const unsigned char *psk_identity, size_t psk_identity_len )
 {
     if( psk == NULL || psk_identity == NULL )
         return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
@@ -5385,6 +5395,31 @@
     return( 0 );
 }
 
+int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl,
+                            const unsigned char *psk, size_t psk_len )
+{
+    if( psk == NULL || ssl->handshake == NULL )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    if( psk_len > MBEDTLS_PSK_MAX_LEN )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    if( ssl->handshake->psk != NULL )
+        mbedtls_free( ssl->conf->psk );
+
+    if( ( ssl->handshake->psk = mbedtls_malloc( psk_len ) ) == NULL )
+    {
+        mbedtls_free( ssl->handshake->psk );
+        ssl->handshake->psk = NULL;
+        return( MBEDTLS_ERR_SSL_MALLOC_FAILED );
+    }
+
+    ssl->handshake->psk_len = psk_len;
+    memcpy( ssl->handshake->psk, psk, ssl->handshake->psk_len );
+
+    return( 0 );
+}
+
 void mbedtls_ssl_set_psk_cb( mbedtls_ssl_config *conf,
                      int (*f_psk)(void *, mbedtls_ssl_context *, const unsigned char *,
                      size_t),
@@ -6441,6 +6476,14 @@
     mbedtls_free( (void *) handshake->curves );
 #endif
 
+#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
+    if( handshake->psk != NULL )
+    {
+        mbedtls_zeroize( handshake->psk, handshake->psk_len );
+        mbedtls_free( handshake->psk );
+    }
+#endif
+
 #if defined(MBEDTLS_X509_CRT_PARSE_C) && \
     defined(MBEDTLS_SSL_SERVER_NAME_INDICATION)
     /*
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index e3e680c..12614c1 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -678,8 +678,7 @@
         if( name_len == strlen( cur->name ) &&
             memcmp( name, cur->name, name_len ) == 0 )
         {
-            return( mbedtls_ssl_set_psk( ssl, cur->key, cur->key_len,
-                                 name, name_len ) );
+            return( mbedtls_ssl_set_hs_psk( ssl, cur->key, cur->key_len ) );
         }
 
         cur = cur->next;