Add mbedtls_ssl_set_hs_psk() for use in the PSK callback
diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h
index a72914f..e098bc9 100644
--- a/include/mbedtls/ssl.h
+++ b/include/mbedtls/ssl.h
@@ -672,6 +672,10 @@
 #if defined(MBEDTLS_ECDH_C) || defined(MBEDTLS_ECDSA_C)
     const mbedtls_ecp_curve_info **curves;      /*!<  Supported elliptic curves */
 #endif
+#if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
+    unsigned char *psk;                 /*!<  PSK from the callback         */
+    size_t psk_len;                     /*!<  Length of PSK from callback   */
+#endif
 #if defined(MBEDTLS_X509_CRT_PARSE_C)
     /**
      * Current key/cert or key/cert list.
@@ -1581,8 +1585,10 @@
 
 #if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED)
 /**
- * \brief          Set the Pre Shared Key (PSK) and the identity name connected
- *                 to it.
+ * \brief          Set the Pre Shared Key (PSK) and the expected identity name
+ *
+ * \note           This is mainly useful for clients. Servers will usually
+ *                 want to use \c mbedtls_ssl_set_psk_cb() instead.
  *
  * \param ssl      SSL context
  * \param psk      pointer to the pre-shared key
@@ -1592,11 +1598,28 @@
  *
  * \return         0 if successful or MBEDTLS_ERR_SSL_MALLOC_FAILED
  */
-int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl, const unsigned char *psk, size_t psk_len,
-                 const unsigned char *psk_identity, size_t psk_identity_len );
+int mbedtls_ssl_set_psk( mbedtls_ssl_context *ssl,
+                const unsigned char *psk, size_t psk_len,
+                const unsigned char *psk_identity, size_t psk_identity_len );
+
 
 /**
- * \brief          Set the PSK callback (server-side only) (Optional).
+ * \brief          Set the Pre Shared Key (PSK) for the current handshake
+ *
+ * \note           This should only be called inside the PSK callback,
+ *                 i.e. the function passed to \c mbedtls_ssl_set_psk_cb().
+ *
+ * \param ssl      SSL context
+ * \param psk      pointer to the pre-shared key
+ * \param psk_len  pre-shared key length
+ *
+ * \return         0 if successful or MBEDTLS_ERR_SSL_MALLOC_FAILED
+ */
+int mbedtls_ssl_set_hs_psk( mbedtls_ssl_context *ssl,
+                            const unsigned char *psk, size_t psk_len );
+
+/**
+ * \brief          Set the PSK callback (server-side only).
  *
  *                 If set, the PSK callback is called for each
  *                 handshake where a PSK ciphersuite was negotiated.
@@ -1607,10 +1630,14 @@
  *                 mbedtls_ssl_context *ssl, const unsigned char *psk_identity,
  *                 size_t identity_len)
  *                 If a valid PSK identity is found, the callback should use
- *                 mbedtls_ssl_set_psk() on the ssl context to set the correct PSK and
- *                 identity and return 0.
+ *                 \c mbedtls_ssl_set_hs_psk() on the ssl context to set the
+ *                 correct PSK and return 0.
  *                 Any other return value will result in a denied PSK identity.
  *
+ * \note           If you set a PSK callback using this function, then you
+ *                 don't need to set a PSK and identity using
+ *                 \c mbedtls_ssl_set_psk().
+ *
  * \param conf     SSL configuration
  * \param f_psk    PSK identity function
  * \param p_psk    PSK identity parameter