Update Linux to v5.4.2

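The net/tls portion of this update brings the subsystem up to the v5.4.2
code: SPDX license identifiers for Kconfig and Makefile, the move of
per-direction crypto parameters into tls_prot_info, TLS 1.3 plus
AES-GCM-256 and AES-CCM-128 cipher support in the setsockopt path, a
reworked device-offload TX/RX resync and record-close flow, and locking
changes (per-context tx_lock, narrowed device_offload_lock scope, and a
spinlock for the tls_device list).
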
Change-Id: Idf6911045d9d382da2cfe01b1edff026404ac8fd
diff --git a/net/tls/Kconfig b/net/tls/Kconfig
index 73f05ec..e4328b3 100644
--- a/net/tls/Kconfig
+++ b/net/tls/Kconfig
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: GPL-2.0-only
 #
 # TLS configuration
 #
@@ -8,6 +9,7 @@
 	select CRYPTO_AES
 	select CRYPTO_GCM
 	select STREAM_PARSER
+	select NET_SOCK_MSG
 	default n
 	---help---
 	Enable kernel support for TLS protocol. This allows symmetric
diff --git a/net/tls/Makefile b/net/tls/Makefile
index 4d6b728..ef0dc74 100644
--- a/net/tls/Makefile
+++ b/net/tls/Makefile
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: GPL-2.0-only
 #
 # Makefile for the TLS subsystem.
 #
diff --git a/net/tls/tls_device.c b/net/tls/tls_device.c
index 961b07d..683d008 100644
--- a/net/tls/tls_device.c
+++ b/net/tls/tls_device.c
@@ -52,13 +52,16 @@
 
 static void tls_device_free_ctx(struct tls_context *ctx)
 {
-	if (ctx->tx_conf == TLS_HW)
+	if (ctx->tx_conf == TLS_HW) {
 		kfree(tls_offload_ctx_tx(ctx));
+		kfree(ctx->tx.rec_seq);
+		kfree(ctx->tx.iv);
+	}
 
 	if (ctx->rx_conf == TLS_HW)
 		kfree(tls_offload_ctx_rx(ctx));
 
-	kfree(ctx);
+	tls_ctx_free(NULL, ctx);
 }
 
 static void tls_device_gc_task(struct work_struct *work)
@@ -86,22 +89,6 @@
 	}
 }
 
-static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
-			      struct net_device *netdev)
-{
-	if (sk->sk_destruct != tls_device_sk_destruct) {
-		refcount_set(&ctx->refcount, 1);
-		dev_hold(netdev);
-		ctx->netdev = netdev;
-		spin_lock_irq(&tls_device_lock);
-		list_add_tail(&ctx->list, &tls_device_list);
-		spin_unlock_irq(&tls_device_lock);
-
-		ctx->sk_destruct = sk->sk_destruct;
-		sk->sk_destruct = tls_device_sk_destruct;
-	}
-}
-
 static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
 {
 	unsigned long flags;
@@ -135,13 +122,10 @@
 
 static void destroy_record(struct tls_record_info *record)
 {
-	int nr_frags = record->num_frags;
-	skb_frag_t *frag;
+	int i;
 
-	while (nr_frags-- > 0) {
-		frag = &record->frags[nr_frags];
-		__skb_frag_unref(frag);
-	}
+	for (i = 0; i < record->num_frags; i++)
+		__skb_frag_unref(&record->frags[i]);
 	kfree(record);
 }
 
@@ -172,12 +156,8 @@
 
 	spin_lock_irqsave(&ctx->lock, flags);
 	info = ctx->retransmit_hint;
-	if (info && !before(acked_seq, info->end_seq)) {
+	if (info && !before(acked_seq, info->end_seq))
 		ctx->retransmit_hint = NULL;
-		list_del(&info->list);
-		destroy_record(info);
-		deleted_records++;
-	}
 
 	list_for_each_entry_safe(info, temp, &ctx->records_list, list) {
 		if (before(acked_seq, info->end_seq))
@@ -196,7 +176,7 @@
  * socket and no in-flight SKBs associated with this
  * socket, so it is safe to free all the resources.
  */
-void tls_device_sk_destruct(struct sock *sk)
+static void tls_device_sk_destruct(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
@@ -214,7 +194,40 @@
 	if (refcount_dec_and_test(&tls_ctx->refcount))
 		tls_device_queue_ctx_destruction(tls_ctx);
 }
-EXPORT_SYMBOL(tls_device_sk_destruct);
+
+void tls_device_free_resources_tx(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+
+	tls_free_partial_record(sk, tls_ctx);
+}
+
+static void tls_device_resync_tx(struct sock *sk, struct tls_context *tls_ctx,
+				 u32 seq)
+{
+	struct net_device *netdev;
+	struct sk_buff *skb;
+	int err = 0;
+	u8 *rcd_sn;
+
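+	/* mark the current write-queue tail as end-of-record so TCP does not
+	 * merge data from the upcoming record into it
+	 */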
+	skb = tcp_write_queue_tail(sk);
+	if (skb)
+		TCP_SKB_CB(skb)->eor = 1;
+
+	rcd_sn = tls_ctx->tx.rec_seq;
+
+	down_read(&device_offload_lock);
+	netdev = tls_ctx->netdev;
+	if (netdev)
+		err = netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq,
+							 rcd_sn,
+							 TLS_OFFLOAD_CTX_DIR_TX);
+	up_read(&device_offload_lock);
+	if (err)
+		return;
+
+	clear_bit_unlock(TLS_TX_SYNC_SCHED, &tls_ctx->flags);
+}
 
 static void tls_append_frag(struct tls_record_info *record,
 			    struct page_frag *pfrag,
@@ -223,14 +236,14 @@
 	skb_frag_t *frag;
 
 	frag = &record->frags[record->num_frags - 1];
-	if (frag->page.p == pfrag->page &&
-	    frag->page_offset + frag->size == pfrag->offset) {
-		frag->size += size;
+	if (skb_frag_page(frag) == pfrag->page &&
+	    skb_frag_off(frag) + skb_frag_size(frag) == pfrag->offset) {
+		skb_frag_size_add(frag, size);
 	} else {
 		++frag;
-		frag->page.p = pfrag->page;
-		frag->page_offset = pfrag->offset;
-		frag->size = size;
+		__skb_frag_set_page(frag, pfrag->page);
+		skb_frag_off_set(frag, pfrag->offset);
+		skb_frag_size_set(frag, size);
 		++record->num_frags;
 		get_page(pfrag->page);
 	}
@@ -243,41 +256,28 @@
 			   struct tls_context *ctx,
 			   struct tls_offload_context_tx *offload_ctx,
 			   struct tls_record_info *record,
-			   struct page_frag *pfrag,
-			   int flags,
-			   unsigned char record_type)
+			   int flags)
 {
+	struct tls_prot_info *prot = &ctx->prot_info;
 	struct tcp_sock *tp = tcp_sk(sk);
-	struct page_frag dummy_tag_frag;
 	skb_frag_t *frag;
 	int i;
 
-	/* fill prepend */
-	frag = &record->frags[0];
-	tls_fill_prepend(ctx,
-			 skb_frag_address(frag),
-			 record->len - ctx->tx.prepend_size,
-			 record_type);
-
-	/* HW doesn't care about the data in the tag, because it fills it. */
-	dummy_tag_frag.page = skb_frag_page(frag);
-	dummy_tag_frag.offset = 0;
-
-	tls_append_frag(record, &dummy_tag_frag, ctx->tx.tag_size);
 	record->end_seq = tp->write_seq + record->len;
-	spin_lock_irq(&offload_ctx->lock);
-	list_add_tail(&record->list, &offload_ctx->records_list);
-	spin_unlock_irq(&offload_ctx->lock);
+	list_add_tail_rcu(&record->list, &offload_ctx->records_list);
 	offload_ctx->open_record = NULL;
-	set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
-	tls_advance_record_sn(sk, &ctx->tx);
+
+	if (test_bit(TLS_TX_SYNC_SCHED, &ctx->flags))
+		tls_device_resync_tx(sk, ctx, tp->write_seq);
+
+	tls_advance_record_sn(sk, prot, &ctx->tx);
 
 	for (i = 0; i < record->num_frags; i++) {
 		frag = &record->frags[i];
 		sg_unmark_end(&offload_ctx->sg_tx_data[i]);
 		sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag),
-			    frag->size, frag->page_offset);
-		sk_mem_charge(sk, frag->size);
+			    skb_frag_size(frag), skb_frag_off(frag));
+		sk_mem_charge(sk, skb_frag_size(frag));
 		get_page(skb_frag_page(frag));
 	}
 	sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]);
@@ -286,6 +286,38 @@
 	return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags);
 }
 
+static int tls_device_record_close(struct sock *sk,
+				   struct tls_context *ctx,
+				   struct tls_record_info *record,
+				   struct page_frag *pfrag,
+				   unsigned char record_type)
+{
+	struct tls_prot_info *prot = &ctx->prot_info;
+	int ret;
+
+	/* append tag
+	 * device will fill in the tag, we just need to append a placeholder
+	 * use socket memory to improve coalescing (re-using a single buffer
+	 * increases frag count)
+	 * if we can't allocate memory now, steal some back from data
+	 */
+	if (likely(skb_page_frag_refill(prot->tag_size, pfrag,
+					sk->sk_allocation))) {
+		ret = 0;
+		tls_append_frag(record, pfrag, prot->tag_size);
+	} else {
+		ret = prot->tag_size;
+		if (record->len <= prot->overhead_size)
+			return -ENOMEM;
+	}
+
+	/* fill prepend */
+	tls_fill_prepend(ctx, skb_frag_address(&record->frags[0]),
+			 record->len - prot->overhead_size,
+			 record_type, prot->version);
+	return ret;
+}
+
 static int tls_create_new_record(struct tls_offload_context_tx *offload_ctx,
 				 struct page_frag *pfrag,
 				 size_t prepend_size)
@@ -299,7 +331,7 @@
 
 	frag = &record->frags[0];
 	__skb_frag_set_page(frag, pfrag->page);
-	frag->page_offset = pfrag->offset;
+	skb_frag_off_set(frag, pfrag->offset);
 	skb_frag_size_set(frag, prepend_size);
 
 	get_page(pfrag->page);
@@ -340,16 +372,42 @@
 	return 0;
 }
 
+static int tls_device_copy_data(void *addr, size_t bytes, struct iov_iter *i)
+{
+	size_t pre_copy, nocache;
+
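+	/* pre_copy is the number of bytes up to the next cache-line boundary;
+	 * copy those with a regular (cached) copy so the bulk nocache copy
+	 * below starts cache-line aligned
+	 */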
+	pre_copy = ~((unsigned long)addr - 1) & (SMP_CACHE_BYTES - 1);
+	if (pre_copy) {
+		pre_copy = min(pre_copy, bytes);
+		if (copy_from_iter(addr, pre_copy, i) != pre_copy)
+			return -EFAULT;
+		bytes -= pre_copy;
+		addr += pre_copy;
+	}
+
+	nocache = round_down(bytes, SMP_CACHE_BYTES);
+	if (copy_from_iter_nocache(addr, nocache, i) != nocache)
+		return -EFAULT;
+	bytes -= nocache;
+	addr += nocache;
+
+	if (bytes && copy_from_iter(addr, bytes, i) != bytes)
+		return -EFAULT;
+
+	return 0;
+}
+
 static int tls_push_data(struct sock *sk,
 			 struct iov_iter *msg_iter,
 			 size_t size, int flags,
 			 unsigned char record_type)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_offload_context_tx *ctx = tls_offload_ctx_tx(tls_ctx);
-	int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
 	int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE);
 	struct tls_record_info *record = ctx->open_record;
+	int tls_push_record_flags;
 	struct page_frag *pfrag;
 	size_t orig_size = size;
 	u32 max_open_record_len;
@@ -364,10 +422,15 @@
 	if (sk->sk_err)
 		return -sk->sk_err;
 
+	flags |= MSG_SENDPAGE_DECRYPTED;
+	tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST;
+
 	timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
-	rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo);
-	if (rc < 0)
-		return rc;
+	if (tls_is_partially_sent_record(tls_ctx)) {
+		rc = tls_push_partial_record(sk, tls_ctx, flags);
+		if (rc < 0)
+			return rc;
+	}
 
 	pfrag = sk_page_frag(sk);
 
@@ -375,10 +438,10 @@
 	 * we need to leave room for an authentication tag.
 	 */
 	max_open_record_len = TLS_MAX_PAYLOAD_SIZE +
-			      tls_ctx->tx.prepend_size;
+			      prot->prepend_size;
 	do {
 		rc = tls_do_allocation(sk, ctx, pfrag,
-				       tls_ctx->tx.prepend_size);
+				       prot->prepend_size);
 		if (rc) {
 			rc = sk_stream_wait_memory(sk, &timeo);
 			if (!rc)
@@ -396,7 +459,7 @@
 				size = orig_size;
 				destroy_record(record);
 				ctx->open_record = NULL;
-			} else if (record->len > tls_ctx->tx.prepend_size) {
+			} else if (record->len > prot->prepend_size) {
 				goto last_record;
 			}
 
@@ -407,12 +470,10 @@
 		copy = min_t(size_t, size, (pfrag->size - pfrag->offset));
 		copy = min_t(size_t, copy, (max_open_record_len - record->len));
 
-		if (copy_from_iter_nocache(page_address(pfrag->page) +
-					       pfrag->offset,
-					   copy, msg_iter) != copy) {
-			rc = -EFAULT;
+		rc = tls_device_copy_data(page_address(pfrag->page) +
+					  pfrag->offset, copy, msg_iter);
+		if (rc)
 			goto handle_error;
-		}
 		tls_append_frag(record, pfrag, copy);
 
 		size -= copy;
@@ -421,7 +482,7 @@
 			tls_push_record_flags = flags;
 			if (more) {
 				tls_ctx->pending_open_record_frags =
-						record->num_frags;
+						!!record->num_frags;
 				break;
 			}
 
@@ -430,13 +491,24 @@
 
 		if (done || record->len >= max_open_record_len ||
 		    (record->num_frags >= MAX_SKB_FRAGS - 1)) {
+			rc = tls_device_record_close(sk, tls_ctx, record,
+						     pfrag, record_type);
+			if (rc) {
+				if (rc > 0) {
+					size += rc;
+				} else {
+					size = orig_size;
+					destroy_record(record);
+					ctx->open_record = NULL;
+					break;
+				}
+			}
+
 			rc = tls_push_record(sk,
 					     tls_ctx,
 					     ctx,
 					     record,
-					     pfrag,
-					     tls_push_record_flags,
-					     record_type);
+					     tls_push_record_flags);
 			if (rc < 0)
 				break;
 		}
@@ -451,8 +523,10 @@
 int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 {
 	unsigned char record_type = TLS_RECORD_TYPE_DATA;
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	int rc;
 
+	mutex_lock(&tls_ctx->tx_lock);
 	lock_sock(sk);
 
 	if (unlikely(msg->msg_controllen)) {
@@ -466,12 +540,14 @@
 
 out:
 	release_sock(sk);
+	mutex_unlock(&tls_ctx->tx_lock);
 	return rc;
 }
 
 int tls_device_sendpage(struct sock *sk, struct page *page,
 			int offset, size_t size, int flags)
 {
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct iov_iter	msg_iter;
 	char *kaddr = kmap(page);
 	struct kvec iov;
@@ -480,6 +556,7 @@
 	if (flags & MSG_SENDPAGE_NOTLAST)
 		flags |= MSG_MORE;
 
+	mutex_lock(&tls_ctx->tx_lock);
 	lock_sock(sk);
 
 	if (flags & MSG_OOB) {
@@ -489,13 +566,14 @@
 
 	iov.iov_base = kaddr + offset;
 	iov.iov_len = size;
-	iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size);
+	iov_iter_kvec(&msg_iter, WRITE, &iov, 1, size);
 	rc = tls_push_data(sk, &msg_iter, size,
 			   flags, TLS_RECORD_TYPE_DATA);
 	kunmap(page);
 
 out:
 	release_sock(sk);
+	mutex_unlock(&tls_ctx->tx_lock);
 	return rc;
 }
 
@@ -511,12 +589,16 @@
 		/* if retransmit_hint is irrelevant start
 		 * from the beginning of the list
 		 */
-		info = list_first_entry(&context->records_list,
-					struct tls_record_info, list);
+		info = list_first_entry_or_null(&context->records_list,
+						struct tls_record_info, list);
+		if (!info)
+			return NULL;
 		record_sn = context->unacked_record_sn;
 	}
 
-	list_for_each_entry_from(info, &context->records_list, list) {
+	/* We just need the _rcu for the READ_ONCE() */
+	rcu_read_lock();
+	list_for_each_entry_from_rcu(info, &context->records_list, list) {
 		if (before(seq, info->end_seq)) {
 			if (!context->retransmit_hint ||
 			    after(info->end_seq,
@@ -525,12 +607,15 @@
 				context->retransmit_hint = info;
 			}
 			*p_record_sn = record_sn;
-			return info;
+			goto exit_rcu_unlock;
 		}
 		record_sn++;
 	}
+	info = NULL;
 
-	return NULL;
+exit_rcu_unlock:
+	rcu_read_unlock();
+	return info;
 }
 EXPORT_SYMBOL(tls_get_record);
 
@@ -538,15 +623,45 @@
 {
 	struct iov_iter	msg_iter;
 
-	iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0);
+	iov_iter_kvec(&msg_iter, WRITE, NULL, 0, 0);
 	return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA);
 }
 
-void handle_device_resync(struct sock *sk, u32 seq, u64 rcd_sn)
+void tls_device_write_space(struct sock *sk, struct tls_context *ctx)
+{
+	if (tls_is_partially_sent_record(ctx)) {
+		gfp_t sk_allocation = sk->sk_allocation;
+
+		WARN_ON_ONCE(sk->sk_write_pending);
+
+		sk->sk_allocation = GFP_ATOMIC;
+		tls_push_partial_record(sk, ctx,
+					MSG_DONTWAIT | MSG_NOSIGNAL |
+					MSG_SENDPAGE_DECRYPTED);
+		sk->sk_allocation = sk_allocation;
+	}
+}
+
+static void tls_device_resync_rx(struct tls_context *tls_ctx,
+				 struct sock *sk, u32 seq, u8 *rcd_sn)
+{
+	struct net_device *netdev;
+
+	if (WARN_ON(test_and_set_bit(TLS_RX_SYNC_RUNNING, &tls_ctx->flags)))
+		return;
+	netdev = READ_ONCE(tls_ctx->netdev);
+	if (netdev)
+		netdev->tlsdev_ops->tls_dev_resync(netdev, sk, seq, rcd_sn,
+						   TLS_OFFLOAD_CTX_DIR_RX);
+	clear_bit_unlock(TLS_RX_SYNC_RUNNING, &tls_ctx->flags);
+}
+
+void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct net_device *netdev = tls_ctx->netdev;
 	struct tls_offload_context_rx *rx_ctx;
+	u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
+	struct tls_prot_info *prot;
 	u32 is_req_pending;
 	s64 resync_req;
 	u32 req_seq;
@@ -554,22 +669,90 @@
 	if (tls_ctx->rx_conf != TLS_HW)
 		return;
 
+	prot = &tls_ctx->prot_info;
 	rx_ctx = tls_offload_ctx_rx(tls_ctx);
-	resync_req = atomic64_read(&rx_ctx->resync_req);
-	req_seq = ntohl(resync_req >> 32) - ((u32)TLS_HEADER_SIZE - 1);
-	is_req_pending = resync_req;
+	memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);
 
-	if (unlikely(is_req_pending) && req_seq == seq &&
-	    atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
-		netdev->tlsdev_ops->tls_dev_resync_rx(netdev, sk,
-						      seq + TLS_HEADER_SIZE - 1,
-						      rcd_sn);
+	switch (rx_ctx->resync_type) {
+	case TLS_OFFLOAD_SYNC_TYPE_DRIVER_REQ:
+		resync_req = atomic64_read(&rx_ctx->resync_req);
+		req_seq = resync_req >> 32;
+		seq += TLS_HEADER_SIZE - 1;
+		is_req_pending = resync_req;
+
+		if (likely(!is_req_pending) || req_seq != seq ||
+		    !atomic64_try_cmpxchg(&rx_ctx->resync_req, &resync_req, 0))
+			return;
+		break;
+	case TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT:
+		if (likely(!rx_ctx->resync_nh_do_now))
+			return;
+
+		/* head of next rec is already in, note that the sock_inq will
+		 * include the currently parsed message when called from parser
+		 */
+		if (tcp_inq(sk) > rcd_len)
+			return;
+
+		rx_ctx->resync_nh_do_now = 0;
+		seq += rcd_len;
+		tls_bigint_increment(rcd_sn, prot->rec_seq_size);
+		break;
+	}
+
+	tls_device_resync_rx(tls_ctx, sk, seq, rcd_sn);
+}
+
+static void tls_device_core_ctrl_rx_resync(struct tls_context *tls_ctx,
+					   struct tls_offload_context_rx *ctx,
+					   struct sock *sk, struct sk_buff *skb)
+{
+	struct strp_msg *rxm;
+
+	/* device will request resyncs by itself based on stream scan */
+	if (ctx->resync_type != TLS_OFFLOAD_SYNC_TYPE_CORE_NEXT_HINT)
+		return;
+	/* already scheduled */
+	if (ctx->resync_nh_do_now)
+		return;
+	/* seen decrypted fragments since last fully-failed record */
+	if (ctx->resync_nh_reset) {
+		ctx->resync_nh_reset = 0;
+		ctx->resync_nh.decrypted_failed = 1;
+		ctx->resync_nh.decrypted_tgt = TLS_DEVICE_RESYNC_NH_START_IVAL;
+		return;
+	}
+
+	if (++ctx->resync_nh.decrypted_failed <= ctx->resync_nh.decrypted_tgt)
+		return;
+
+	/* doing resync, bump the next target in case it fails */
+	if (ctx->resync_nh.decrypted_tgt < TLS_DEVICE_RESYNC_NH_MAX_IVAL)
+		ctx->resync_nh.decrypted_tgt *= 2;
+	else
+		ctx->resync_nh.decrypted_tgt += TLS_DEVICE_RESYNC_NH_MAX_IVAL;
+
+	rxm = strp_msg(skb);
+
+	/* head of next rec is already in, parser will sync for us */
+	if (tcp_inq(sk) > rxm->full_len) {
+		ctx->resync_nh_do_now = 1;
+	} else {
+		struct tls_prot_info *prot = &tls_ctx->prot_info;
+		u8 rcd_sn[TLS_MAX_REC_SEQ_SIZE];
+
+		memcpy(rcd_sn, tls_ctx->rx.rec_seq, prot->rec_seq_size);
+		tls_bigint_increment(rcd_sn, prot->rec_seq_size);
+
+		tls_device_resync_rx(tls_ctx, sk, tcp_sk(sk)->copied_seq,
+				     rcd_sn);
+	}
 }
 
 static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb)
 {
 	struct strp_msg *rxm = strp_msg(skb);
-	int err = 0, offset = rxm->offset, copy, nsg;
+	int err = 0, offset = rxm->offset, copy, nsg, data_len, pos;
 	struct sk_buff *skb_iter, *unused;
 	struct scatterlist sg[1];
 	char *orig_buf, *buf;
@@ -590,8 +773,10 @@
 	sg_set_buf(&sg[0], buf,
 		   rxm->full_len + TLS_HEADER_SIZE +
 		   TLS_CIPHER_AES_GCM_128_IV_SIZE);
-	skb_copy_bits(skb, offset, buf,
-		      TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
+	err = skb_copy_bits(skb, offset, buf,
+			    TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE);
+	if (err)
+		goto free_buf;
 
 	/* We are interested only in the decrypted data not the auth */
 	err = decrypt_skb(sk, skb, sg);
@@ -600,27 +785,50 @@
 	else
 		err = 0;
 
-	copy = min_t(int, skb_pagelen(skb) - offset,
-		     rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE);
+	data_len = rxm->full_len - TLS_CIPHER_AES_GCM_128_TAG_SIZE;
 
-	if (skb->decrypted)
-		skb_store_bits(skb, offset, buf, copy);
+	if (skb_pagelen(skb) > offset) {
+		copy = min_t(int, skb_pagelen(skb) - offset, data_len);
 
-	offset += copy;
-	buf += copy;
-
-	skb_walk_frags(skb, skb_iter) {
-		copy = min_t(int, skb_iter->len,
-			     rxm->full_len - offset + rxm->offset -
-			     TLS_CIPHER_AES_GCM_128_TAG_SIZE);
-
-		if (skb_iter->decrypted)
-			skb_store_bits(skb_iter, offset, buf, copy);
+		if (skb->decrypted) {
+			err = skb_store_bits(skb, offset, buf, copy);
+			if (err)
+				goto free_buf;
+		}
 
 		offset += copy;
 		buf += copy;
 	}
 
+	pos = skb_pagelen(skb);
+	skb_walk_frags(skb, skb_iter) {
+		int frag_pos;
+
+		/* Practically all frags must belong to msg if reencrypt
+		 * is needed with current strparser and coalescing logic,
+		 * but strparser may "get optimized", so let's be safe.
+		 */
+		if (pos + skb_iter->len <= offset)
+			goto done_with_frag;
+		if (pos >= data_len + rxm->offset)
+			break;
+
+		frag_pos = offset - pos;
+		copy = min_t(int, skb_iter->len - frag_pos,
+			     data_len + rxm->offset - offset);
+
+		if (skb_iter->decrypted) {
+			err = skb_store_bits(skb_iter, frag_pos, buf, copy);
+			if (err)
+				goto free_buf;
+		}
+
+		offset += copy;
+		buf += copy;
+done_with_frag:
+		pos += skb_iter->len;
+	}
+
 free_buf:
 	kfree(orig_buf);
 	return err;
@@ -634,10 +842,6 @@
 	int is_encrypted = !is_decrypted;
 	struct sk_buff *skb_iter;
 
-	/* Skip if it is already decrypted */
-	if (ctx->sw.decrypted)
-		return 0;
-
 	/* Check if all the data is decrypted already */
 	skb_walk_frags(skb, skb_iter) {
 		is_decrypted &= skb_iter->decrypted;
@@ -646,39 +850,62 @@
 
 	ctx->sw.decrypted |= is_decrypted;
 
-	/* Return immedeatly if the record is either entirely plaintext or
+	/* Return immediately if the record is either entirely plaintext or
 	 * entirely ciphertext. Otherwise handle reencrypt partially decrypted
 	 * record.
 	 */
-	return (is_encrypted || is_decrypted) ? 0 :
-		tls_device_reencrypt(sk, skb);
+	if (is_decrypted) {
+		ctx->resync_nh_reset = 1;
+		return 0;
+	}
+	if (is_encrypted) {
+		tls_device_core_ctrl_rx_resync(tls_ctx, ctx, sk, skb);
+		return 0;
+	}
+
+	ctx->resync_nh_reset = 1;
+	return tls_device_reencrypt(sk, skb);
+}
+
+static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
+			      struct net_device *netdev)
+{
+	if (sk->sk_destruct != tls_device_sk_destruct) {
+		refcount_set(&ctx->refcount, 1);
+		dev_hold(netdev);
+		ctx->netdev = netdev;
+		spin_lock_irq(&tls_device_lock);
+		list_add_tail(&ctx->list, &tls_device_list);
+		spin_unlock_irq(&tls_device_lock);
+
+		ctx->sk_destruct = sk->sk_destruct;
+		sk->sk_destruct = tls_device_sk_destruct;
+	}
 }
 
 int tls_set_device_offload(struct sock *sk, struct tls_context *ctx)
 {
 	u16 nonce_size, tag_size, iv_size, rec_seq_size;
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_record_info *start_marker_record;
 	struct tls_offload_context_tx *offload_ctx;
 	struct tls_crypto_info *crypto_info;
 	struct net_device *netdev;
 	char *iv, *rec_seq;
 	struct sk_buff *skb;
-	int rc = -EINVAL;
 	__be64 rcd_sn;
+	int rc;
 
 	if (!ctx)
-		goto out;
+		return -EINVAL;
 
-	if (ctx->priv_ctx_tx) {
-		rc = -EEXIST;
-		goto out;
-	}
+	if (ctx->priv_ctx_tx)
+		return -EEXIST;
 
 	start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
-	if (!start_marker_record) {
-		rc = -ENOMEM;
-		goto out;
-	}
+	if (!start_marker_record)
+		return -ENOMEM;
 
 	offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_TX, GFP_KERNEL);
 	if (!offload_ctx) {
@@ -687,6 +914,11 @@
 	}
 
 	crypto_info = &ctx->crypto_send.info;
+	if (crypto_info->version != TLS_1_2_VERSION) {
+		rc = -EOPNOTSUPP;
+		goto free_offload_ctx;
+	}
+
 	switch (crypto_info->cipher_type) {
 	case TLS_CIPHER_AES_GCM_128:
 		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
@@ -702,10 +934,18 @@
 		goto free_offload_ctx;
 	}
 
-	ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size;
-	ctx->tx.tag_size = tag_size;
-	ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size;
-	ctx->tx.iv_size = iv_size;
+	/* Sanity-check the rec_seq_size for stack allocations */
+	if (rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
+		rc = -EINVAL;
+		goto free_offload_ctx;
+	}
+
+	prot->version = crypto_info->version;
+	prot->cipher_type = crypto_info->cipher_type;
+	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
+	prot->tag_size = tag_size;
+	prot->overhead_size = prot->prepend_size + prot->tag_size;
+	prot->iv_size = iv_size;
 	ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
 			     GFP_KERNEL);
 	if (!ctx->tx.iv) {
@@ -715,7 +955,7 @@
 
 	memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
 
-	ctx->tx.rec_seq_size = rec_seq_size;
+	prot->rec_seq_size = rec_seq_size;
 	ctx->tx.rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
 	if (!ctx->tx.rec_seq) {
 		rc = -ENOMEM;
@@ -751,17 +991,11 @@
 	if (skb)
 		TCP_SKB_CB(skb)->eor = 1;
 
-	/* We support starting offload on multiple sockets
-	 * concurrently, so we only need a read lock here.
-	 * This lock must precede get_netdev_for_sock to prevent races between
-	 * NETDEV_DOWN and setsockopt.
-	 */
-	down_read(&device_offload_lock);
 	netdev = get_netdev_for_sock(sk);
 	if (!netdev) {
 		pr_err_ratelimited("%s: netdev not found\n", __func__);
 		rc = -EINVAL;
-		goto release_lock;
+		goto disable_cad;
 	}
 
 	if (!(netdev->features & NETIF_F_HW_TLS_TX)) {
@@ -772,10 +1006,15 @@
 	/* Avoid offloading if the device is down
 	 * We don't want to offload new flows after
 	 * the NETDEV_DOWN event
+	 *
+	 * device_offload_lock is taken in tls_device's NETDEV_DOWN
+	 * handler, thus protecting against the device going down before
+	 * ctx was added to tls_device_list.
 	 */
+	down_read(&device_offload_lock);
 	if (!(netdev->flags & IFF_UP)) {
 		rc = -EINVAL;
-		goto release_netdev;
+		goto release_lock;
 	}
 
 	ctx->priv_ctx_tx = offload_ctx;
@@ -783,9 +1022,10 @@
 					     &ctx->crypto_send.info,
 					     tcp_sk(sk)->write_seq);
 	if (rc)
-		goto release_netdev;
+		goto release_lock;
 
 	tls_device_attach(ctx, sk, netdev);
+	up_read(&device_offload_lock);
 
 	/* following this assignment tls_is_sk_tx_device_offloaded
 	 * will return true and the context might be accessed
@@ -793,13 +1033,14 @@
 	 */
 	smp_store_release(&sk->sk_validate_xmit_skb, tls_validate_xmit_skb);
 	dev_put(netdev);
-	up_read(&device_offload_lock);
-	goto out;
 
-release_netdev:
-	dev_put(netdev);
+	return 0;
+
 release_lock:
 	up_read(&device_offload_lock);
+release_netdev:
+	dev_put(netdev);
+disable_cad:
 	clean_acked_data_disable(inet_csk(sk));
 	crypto_free_aead(offload_ctx->aead_send);
 free_rec_seq:
@@ -811,7 +1052,6 @@
 	ctx->priv_ctx_tx = NULL;
 free_marker_record:
 	kfree(start_marker_record);
-out:
 	return rc;
 }
 
@@ -821,22 +1061,16 @@
 	struct net_device *netdev;
 	int rc = 0;
 
-	/* We support starting offload on multiple sockets
-	 * concurrently, so we only need a read lock here.
-	 * This lock must precede get_netdev_for_sock to prevent races between
-	 * NETDEV_DOWN and setsockopt.
-	 */
-	down_read(&device_offload_lock);
+	if (ctx->crypto_recv.info.version != TLS_1_2_VERSION)
+		return -EOPNOTSUPP;
+
 	netdev = get_netdev_for_sock(sk);
 	if (!netdev) {
 		pr_err_ratelimited("%s: netdev not found\n", __func__);
-		rc = -EINVAL;
-		goto release_lock;
+		return -EINVAL;
 	}
 
 	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
-		pr_err_ratelimited("%s: netdev %s with no TLS offload\n",
-				   __func__, netdev->name);
 		rc = -ENOTSUPP;
 		goto release_netdev;
 	}
@@ -844,17 +1078,23 @@
 	/* Avoid offloading if the device is down
 	 * We don't want to offload new flows after
 	 * the NETDEV_DOWN event
+	 *
+	 * device_offload_lock is taken in tls_device's NETDEV_DOWN
+	 * handler, thus protecting against the device going down before
+	 * ctx was added to tls_device_list.
 	 */
+	down_read(&device_offload_lock);
 	if (!(netdev->flags & IFF_UP)) {
 		rc = -EINVAL;
-		goto release_netdev;
+		goto release_lock;
 	}
 
 	context = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE_RX, GFP_KERNEL);
 	if (!context) {
 		rc = -ENOMEM;
-		goto release_netdev;
+		goto release_lock;
 	}
+	context->resync_nh_reset = 1;
 
 	ctx->priv_ctx_rx = context;
 	rc = tls_set_sw_offload(sk, ctx, 0);
@@ -864,23 +1104,26 @@
 	rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_RX,
 					     &ctx->crypto_recv.info,
 					     tcp_sk(sk)->copied_seq);
-	if (rc) {
-		pr_err_ratelimited("%s: The netdev has refused to offload this socket\n",
-				   __func__);
+	if (rc)
 		goto free_sw_resources;
-	}
 
 	tls_device_attach(ctx, sk, netdev);
-	goto release_netdev;
+	up_read(&device_offload_lock);
+
+	dev_put(netdev);
+
+	return 0;
 
 free_sw_resources:
+	up_read(&device_offload_lock);
 	tls_sw_free_resources_rx(sk);
+	down_read(&device_offload_lock);
 release_ctx:
 	ctx->priv_ctx_rx = NULL;
-release_netdev:
-	dev_put(netdev);
 release_lock:
 	up_read(&device_offload_lock);
+release_netdev:
+	dev_put(netdev);
 	return rc;
 }
 
@@ -894,12 +1137,6 @@
 	if (!netdev)
 		goto out;
 
-	if (!(netdev->features & NETIF_F_HW_TLS_RX)) {
-		pr_err_ratelimited("%s: device is missing NETIF_F_HW_TLS_RX cap\n",
-				   __func__);
-		goto out;
-	}
-
 	netdev->tlsdev_ops->tls_dev_del(netdev, tls_ctx,
 					TLS_OFFLOAD_CTX_DIR_RX);
 
@@ -909,8 +1146,6 @@
 	}
 out:
 	up_read(&device_offload_lock);
-	kfree(tls_ctx->rx.rec_seq);
-	kfree(tls_ctx->rx.iv);
 	tls_sw_release_resources_rx(sk);
 }
 
@@ -940,7 +1175,10 @@
 		if (ctx->rx_conf == TLS_HW)
 			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
 							TLS_OFFLOAD_CTX_DIR_RX);
-		ctx->netdev = NULL;
+		WRITE_ONCE(ctx->netdev, NULL);
+		smp_mb__before_atomic(); /* pairs with test_and_set_bit() */
+		while (test_bit(TLS_RX_SYNC_RUNNING, &ctx->flags))
+			usleep_range(10, 200);
 		dev_put(netdev);
 		list_del_init(&ctx->list);
 
@@ -960,14 +1198,15 @@
 {
 	struct net_device *dev = netdev_notifier_info_to_dev(ptr);
 
-	if (!(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
+	if (!dev->tlsdev_ops &&
+	    !(dev->features & (NETIF_F_HW_TLS_RX | NETIF_F_HW_TLS_TX)))
 		return NOTIFY_DONE;
 
 	switch (event) {
 	case NETDEV_REGISTER:
 	case NETDEV_FEAT_CHANGE:
 		if ((dev->features & NETIF_F_HW_TLS_RX) &&
-		    !dev->tlsdev_ops->tls_dev_resync_rx)
+		    !dev->tlsdev_ops->tls_dev_resync)
 			return NOTIFY_BAD;
 
 		if  (dev->tlsdev_ops &&
@@ -995,4 +1234,5 @@
 {
 	unregister_netdevice_notifier(&tls_dev_notifier);
 	flush_work(&tls_device_gc_work);
+	clean_acked_data_flush();
 }
diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c
index 450a6db..2889533 100644
--- a/net/tls/tls_device_fallback.c
+++ b/net/tls/tls_device_fallback.c
@@ -73,7 +73,8 @@
 	len -= TLS_CIPHER_AES_GCM_128_IV_SIZE;
 
 	tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE,
-		     (char *)&rcd_sn, sizeof(rcd_sn), buf[0]);
+		     (char *)&rcd_sn, sizeof(rcd_sn), buf[0],
+		     TLS_1_2_VERSION);
 
 	memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE,
 	       TLS_CIPHER_AES_GCM_128_IV_SIZE);
@@ -193,18 +194,30 @@
 
 static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln)
 {
+	struct sock *sk = skb->sk;
+	int delta;
+
 	skb_copy_header(nskb, skb);
 
 	skb_put(nskb, skb->len);
 	memcpy(nskb->data, skb->data, headln);
-	update_chksum(nskb, headln);
 
 	nskb->destructor = skb->destructor;
-	nskb->sk = skb->sk;
+	nskb->sk = sk;
 	skb->destructor = NULL;
 	skb->sk = NULL;
-	refcount_add(nskb->truesize - skb->truesize,
-		     &nskb->sk->sk_wmem_alloc);
+
+	update_chksum(nskb, headln);
+
+	/* sock_efree means skb must have gone through skb_orphan_partial() */
+	if (nskb->destructor == sock_efree)
+		return;
+
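+	/* nskb has taken over the original skb's destructor and socket
+	 * reference; adjust sk_wmem_alloc for the difference in truesize
+	 */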
+	delta = nskb->truesize - skb->truesize;
+	if (likely(delta < 0))
+		WARN_ON_ONCE(refcount_sub_and_test(-delta, &sk->sk_wmem_alloc));
+	else if (delta)
+		refcount_add(delta, &sk->sk_wmem_alloc);
 }
 
 /* This function may be called after the user socket is already
@@ -231,7 +244,6 @@
 	record = tls_get_record(ctx, tcp_seq, rcd_sn);
 	if (!record) {
 		spin_unlock_irqrestore(&ctx->lock, flags);
-		WARN(1, "Record not found for seq %u\n", tcp_seq);
 		return -EINVAL;
 	}
 
@@ -261,7 +273,7 @@
 
 		__skb_frag_ref(frag);
 		sg_set_page(sg_in + i, skb_frag_page(frag),
-			    skb_frag_size(frag), frag->page_offset);
+			    skb_frag_size(frag), skb_frag_off(frag));
 
 		remaining -= skb_frag_size(frag);
 
@@ -400,7 +412,10 @@
 		put_page(sg_page(&sg_in[--resync_sgs]));
 	kfree(sg_in);
 free_orig:
-	kfree_skb(skb);
+	if (nskb)
+		consume_skb(skb);
+	else
+		kfree_skb(skb);
 	return nskb;
 }
 
@@ -415,6 +430,12 @@
 }
 EXPORT_SYMBOL_GPL(tls_validate_xmit_skb);
 
+struct sk_buff *tls_encrypt_skb(struct sk_buff *skb)
+{
+	return tls_sw_fallback(skb->sk, skb);
+}
+EXPORT_SYMBOL_GPL(tls_encrypt_skb);
+
 int tls_sw_fallback_init(struct sock *sk,
 			 struct tls_offload_context_tx *offload_ctx,
 			 struct tls_crypto_info *crypto_info)
diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 523622d..eff4442 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -39,6 +39,7 @@
 #include <linux/netdevice.h>
 #include <linux/sched/signal.h>
 #include <linux/inetdevice.h>
+#include <linux/inet_diag.h>
 
 #include <net/tls.h>
 
@@ -55,10 +56,14 @@
 
 static struct proto *saved_tcpv6_prot;
 static DEFINE_MUTEX(tcpv6_prot_mutex);
+static struct proto *saved_tcpv4_prot;
+static DEFINE_MUTEX(tcpv4_prot_mutex);
 static LIST_HEAD(device_list);
-static DEFINE_MUTEX(device_mutex);
+static DEFINE_SPINLOCK(device_spinlock);
 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
 static struct proto_ops tls_sw_proto_ops;
+static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
+			 struct proto *base);
 
 static void update_sk_prot(struct sock *sk, struct tls_context *ctx)
 {
@@ -141,9 +146,7 @@
 		size = sg->length;
 	}
 
-	clear_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
 	ctx->in_tcp_sendpages = false;
-	ctx->sk_write_space(sk);
 
 	return 0;
 }
@@ -193,15 +196,12 @@
 	return rc;
 }
 
-int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx,
-				   int flags, long *timeo)
+int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
+			    int flags)
 {
 	struct scatterlist *sg;
 	u16 offset;
 
-	if (!tls_is_partially_sent_record(ctx))
-		return ctx->push_pending_record(sk, flags);
-
 	sg = ctx->partially_sent_record;
 	offset = ctx->partially_sent_offset;
 
@@ -209,6 +209,17 @@
 	return tls_push_sg(sk, ctx, sg, offset, flags);
 }
 
+void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
+{
+	struct scatterlist *sg;
+
+	for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
+		put_page(sg_page(sg));
+		sk_mem_uncharge(sk, sg->length);
+	}
+	ctx->partially_sent_record = NULL;
+}
+
 static void tls_write_space(struct sock *sk)
 {
 	struct tls_context *ctx = tls_get_ctx(sk);
@@ -222,100 +233,95 @@
 		return;
 	}
 
-	if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) {
-		gfp_t sk_allocation = sk->sk_allocation;
-		int rc;
-		long timeo = 0;
-
-		sk->sk_allocation = GFP_ATOMIC;
-		rc = tls_push_pending_closed_record(sk, ctx,
-						    MSG_DONTWAIT |
-						    MSG_NOSIGNAL,
-						    &timeo);
-		sk->sk_allocation = sk_allocation;
-
-		if (rc < 0)
-			return;
-	}
+#ifdef CONFIG_TLS_DEVICE
+	if (ctx->tx_conf == TLS_HW)
+		tls_device_write_space(sk, ctx);
+	else
+#endif
+		tls_sw_write_space(sk, ctx);
 
 	ctx->sk_write_space(sk);
 }
 
-static void tls_ctx_free(struct tls_context *ctx)
+/**
+ * tls_ctx_free() - free TLS ULP context
+ * @sk:  socket to which @ctx is attached
+ * @ctx: TLS context structure
+ *
+ * Free TLS context. If @sk is %NULL caller guarantees that the socket
+ * to which @ctx was attached has no outstanding references.
+ */
+void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
 {
 	if (!ctx)
 		return;
 
 	memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
 	memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
-	kfree(ctx);
+	mutex_destroy(&ctx->tx_lock);
+
+	if (sk)
+		kfree_rcu(ctx, rcu);
+	else
+		kfree(ctx);
 }
 
-static void tls_sk_proto_close(struct sock *sk, long timeout)
+static void tls_sk_proto_cleanup(struct sock *sk,
+				 struct tls_context *ctx, long timeo)
 {
-	struct tls_context *ctx = tls_get_ctx(sk);
-	long timeo = sock_sndtimeo(sk, 0);
-	void (*sk_proto_close)(struct sock *sk, long timeout);
-	bool free_ctx = false;
-
-	lock_sock(sk);
-	sk_proto_close = ctx->sk_proto_close;
-
-	if ((ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) ||
-	    (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE)) {
-		free_ctx = true;
-		goto skip_tx_cleanup;
-	}
-
-	if (!tls_complete_pending_work(sk, ctx, 0, &timeo))
+	if (unlikely(sk->sk_write_pending) &&
+	    !wait_on_pending_writer(sk, &timeo))
 		tls_handle_open_record(sk, 0);
 
-	if (ctx->partially_sent_record) {
-		struct scatterlist *sg = ctx->partially_sent_record;
-
-		while (1) {
-			put_page(sg_page(sg));
-			sk_mem_uncharge(sk, sg->length);
-
-			if (sg_is_last(sg))
-				break;
-			sg++;
-		}
-	}
-
 	/* We need these for tls_sw_fallback handling of other packets */
 	if (ctx->tx_conf == TLS_SW) {
 		kfree(ctx->tx.rec_seq);
 		kfree(ctx->tx.iv);
-		tls_sw_free_resources_tx(sk);
+		tls_sw_release_resources_tx(sk);
+	} else if (ctx->tx_conf == TLS_HW) {
+		tls_device_free_resources_tx(sk);
 	}
 
-	if (ctx->rx_conf == TLS_SW) {
-		kfree(ctx->rx.rec_seq);
-		kfree(ctx->rx.iv);
-		tls_sw_free_resources_rx(sk);
-	}
-
-#ifdef CONFIG_TLS_DEVICE
-	if (ctx->rx_conf == TLS_HW)
+	if (ctx->rx_conf == TLS_SW)
+		tls_sw_release_resources_rx(sk);
+	else if (ctx->rx_conf == TLS_HW)
 		tls_device_offload_cleanup_rx(sk);
+}
 
-	if (ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW) {
-#else
-	{
-#endif
-		tls_ctx_free(ctx);
-		ctx = NULL;
-	}
+static void tls_sk_proto_close(struct sock *sk, long timeout)
+{
+	struct inet_connection_sock *icsk = inet_csk(sk);
+	struct tls_context *ctx = tls_get_ctx(sk);
+	long timeo = sock_sndtimeo(sk, 0);
+	bool free_ctx;
 
-skip_tx_cleanup:
-	release_sock(sk);
-	sk_proto_close(sk, timeout);
-	/* free ctx for TLS_HW_RECORD, used by tcp_set_state
-	 * for sk->sk_prot->unhash [tls_hw_unhash]
-	 */
+	if (ctx->tx_conf == TLS_SW)
+		tls_sw_cancel_work_tx(ctx);
+
+	lock_sock(sk);
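+	/* if either direction uses HW offload, the ctx is freed later by the
+	 * tls_device refcount release path rather than here
+	 */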
+	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;
+
+	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
+		tls_sk_proto_cleanup(sk, ctx, timeo);
+
+	write_lock_bh(&sk->sk_callback_lock);
 	if (free_ctx)
-		tls_ctx_free(ctx);
+		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
+	sk->sk_prot = ctx->sk_proto;
+	if (sk->sk_write_space == tls_write_space)
+		sk->sk_write_space = ctx->sk_write_space;
+	write_unlock_bh(&sk->sk_callback_lock);
+	release_sock(sk);
+	if (ctx->tx_conf == TLS_SW)
+		tls_sw_free_ctx_tx(ctx);
+	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
+		tls_sw_strparser_done(ctx);
+	if (ctx->rx_conf == TLS_SW)
+		tls_sw_free_ctx_rx(ctx);
+	ctx->sk_proto->close(sk, timeout);
+
+	if (free_ctx)
+		tls_ctx_free(sk, ctx);
 }
 
 static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
@@ -378,6 +384,30 @@
 			rc = -EFAULT;
 		break;
 	}
+	case TLS_CIPHER_AES_GCM_256: {
+		struct tls12_crypto_info_aes_gcm_256 *
+		  crypto_info_aes_gcm_256 =
+		  container_of(crypto_info,
+			       struct tls12_crypto_info_aes_gcm_256,
+			       info);
+
+		if (len != sizeof(*crypto_info_aes_gcm_256)) {
+			rc = -EINVAL;
+			goto out;
+		}
+		lock_sock(sk);
+		memcpy(crypto_info_aes_gcm_256->iv,
+		       ctx->tx.iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
+		       TLS_CIPHER_AES_GCM_256_IV_SIZE);
+		memcpy(crypto_info_aes_gcm_256->rec_seq, ctx->tx.rec_seq,
+		       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
+		release_sock(sk);
+		if (copy_to_user(optval,
+				 crypto_info_aes_gcm_256,
+				 sizeof(*crypto_info_aes_gcm_256)))
+			rc = -EFAULT;
+		break;
+	}
 	default:
 		rc = -EINVAL;
 	}
@@ -408,7 +438,8 @@
 	struct tls_context *ctx = tls_get_ctx(sk);
 
 	if (level != SOL_TLS)
-		return ctx->getsockopt(sk, level, optname, optval, optlen);
+		return ctx->sk_proto->getsockopt(sk, level,
+						 optname, optval, optlen);
 
 	return do_tls_getsockopt(sk, optname, optval, optlen);
 }
@@ -417,7 +448,9 @@
 				  unsigned int optlen, int tx)
 {
 	struct tls_crypto_info *crypto_info;
+	struct tls_crypto_info *alt_crypto_info;
 	struct tls_context *ctx = tls_get_ctx(sk);
+	size_t optsize;
 	int rc = 0;
 	int conf;
 
@@ -426,10 +459,13 @@
 		goto out;
 	}
 
-	if (tx)
+	if (tx) {
 		crypto_info = &ctx->crypto_send.info;
-	else
+		alt_crypto_info = &ctx->crypto_recv.info;
+	} else {
 		crypto_info = &ctx->crypto_recv.info;
+		alt_crypto_info = &ctx->crypto_send.info;
+	}
 
 	/* Currently we don't support setting crypto info more than once */
 	if (TLS_CRYPTO_INFO_READY(crypto_info)) {
@@ -444,57 +480,70 @@
 	}
 
 	/* check version */
-	if (crypto_info->version != TLS_1_2_VERSION) {
+	if (crypto_info->version != TLS_1_2_VERSION &&
+	    crypto_info->version != TLS_1_3_VERSION) {
 		rc = -ENOTSUPP;
 		goto err_crypto_info;
 	}
 
-	switch (crypto_info->cipher_type) {
-	case TLS_CIPHER_AES_GCM_128: {
-		if (optlen != sizeof(struct tls12_crypto_info_aes_gcm_128)) {
+	/* Ensure that TLS version and ciphers are the same in both directions */
+	if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
+		if (alt_crypto_info->version != crypto_info->version ||
+		    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
 			rc = -EINVAL;
 			goto err_crypto_info;
 		}
-		rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info),
-				    optlen - sizeof(*crypto_info));
-		if (rc) {
-			rc = -EFAULT;
-			goto err_crypto_info;
-		}
+	}
+
+	switch (crypto_info->cipher_type) {
+	case TLS_CIPHER_AES_GCM_128:
+		optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
+		break;
+	case TLS_CIPHER_AES_GCM_256: {
+		optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
 		break;
 	}
+	case TLS_CIPHER_AES_CCM_128:
+		optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
+		break;
 	default:
 		rc = -EINVAL;
 		goto err_crypto_info;
 	}
 
+	if (optlen != optsize) {
+		rc = -EINVAL;
+		goto err_crypto_info;
+	}
+
+	rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info),
+			    optlen - sizeof(*crypto_info));
+	if (rc) {
+		rc = -EFAULT;
+		goto err_crypto_info;
+	}
+
 	if (tx) {
-#ifdef CONFIG_TLS_DEVICE
 		rc = tls_set_device_offload(sk, ctx);
 		conf = TLS_HW;
 		if (rc) {
-#else
-		{
-#endif
 			rc = tls_set_sw_offload(sk, ctx, 1);
+			if (rc)
+				goto err_crypto_info;
 			conf = TLS_SW;
 		}
 	} else {
-#ifdef CONFIG_TLS_DEVICE
 		rc = tls_set_device_offload_rx(sk, ctx);
 		conf = TLS_HW;
 		if (rc) {
-#else
-		{
-#endif
 			rc = tls_set_sw_offload(sk, ctx, 0);
+			if (rc)
+				goto err_crypto_info;
 			conf = TLS_SW;
 		}
+		tls_sw_strparser_arm(sk, ctx);
 	}
 
-	if (rc)
-		goto err_crypto_info;
-
 	if (tx)
 		ctx->tx_conf = conf;
 	else
@@ -540,7 +589,8 @@
 	struct tls_context *ctx = tls_get_ctx(sk);
 
 	if (level != SOL_TLS)
-		return ctx->setsockopt(sk, level, optname, optval, optlen);
+		return ctx->sk_proto->setsockopt(sk, level, optname, optval,
+						 optlen);
 
 	return do_tls_setsockopt(sk, optname, optval, optlen);
 }
@@ -550,39 +600,80 @@
 	struct inet_connection_sock *icsk = inet_csk(sk);
 	struct tls_context *ctx;
 
-	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
+	ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
 	if (!ctx)
 		return NULL;
 
-	icsk->icsk_ulp_data = ctx;
+	mutex_init(&ctx->tx_lock);
+	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
+	ctx->sk_proto = sk->sk_prot;
 	return ctx;
 }
 
+static void tls_build_proto(struct sock *sk)
+{
+	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
+
+	/* Build IPv6 TLS whenever the address of tcpv6_prot changes */
+	if (ip_ver == TLSV6 &&
+	    unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
+		mutex_lock(&tcpv6_prot_mutex);
+		if (likely(sk->sk_prot != saved_tcpv6_prot)) {
+			build_protos(tls_prots[TLSV6], sk->sk_prot);
+			smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
+		}
+		mutex_unlock(&tcpv6_prot_mutex);
+	}
+
+	if (ip_ver == TLSV4 &&
+	    unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) {
+		mutex_lock(&tcpv4_prot_mutex);
+		if (likely(sk->sk_prot != saved_tcpv4_prot)) {
+			build_protos(tls_prots[TLSV4], sk->sk_prot);
+			smp_store_release(&saved_tcpv4_prot, sk->sk_prot);
+		}
+		mutex_unlock(&tcpv4_prot_mutex);
+	}
+}
+
+static void tls_hw_sk_destruct(struct sock *sk)
+{
+	struct tls_context *ctx = tls_get_ctx(sk);
+	struct inet_connection_sock *icsk = inet_csk(sk);
+
+	ctx->sk_destruct(sk);
+	/* Free ctx */
+	rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
+	tls_ctx_free(sk, ctx);
+}
+
 static int tls_hw_prot(struct sock *sk)
 {
 	struct tls_context *ctx;
 	struct tls_device *dev;
 	int rc = 0;
 
-	mutex_lock(&device_mutex);
+	spin_lock_bh(&device_spinlock);
 	list_for_each_entry(dev, &device_list, dev_list) {
 		if (dev->feature && dev->feature(dev)) {
 			ctx = create_ctx(sk);
 			if (!ctx)
 				goto out;
 
-			ctx->hash = sk->sk_prot->hash;
-			ctx->unhash = sk->sk_prot->unhash;
-			ctx->sk_proto_close = sk->sk_prot->close;
+			spin_unlock_bh(&device_spinlock);
+			tls_build_proto(sk);
+			ctx->sk_destruct = sk->sk_destruct;
+			sk->sk_destruct = tls_hw_sk_destruct;
 			ctx->rx_conf = TLS_HW_RECORD;
 			ctx->tx_conf = TLS_HW_RECORD;
 			update_sk_prot(sk, ctx);
+			spin_lock_bh(&device_spinlock);
 			rc = 1;
 			break;
 		}
 	}
 out:
-	mutex_unlock(&device_mutex);
+	spin_unlock_bh(&device_spinlock);
 	return rc;
 }
 
@@ -591,13 +682,18 @@
 	struct tls_context *ctx = tls_get_ctx(sk);
 	struct tls_device *dev;
 
-	mutex_lock(&device_mutex);
+	spin_lock_bh(&device_spinlock);
 	list_for_each_entry(dev, &device_list, dev_list) {
-		if (dev->unhash)
+		if (dev->unhash) {
+			kref_get(&dev->kref);
+			spin_unlock_bh(&device_spinlock);
 			dev->unhash(dev, sk);
+			kref_put(&dev->kref, dev->release);
+			spin_lock_bh(&device_spinlock);
+		}
 	}
-	mutex_unlock(&device_mutex);
-	ctx->unhash(sk);
+	spin_unlock_bh(&device_spinlock);
+	ctx->sk_proto->unhash(sk);
 }
 
 static int tls_hw_hash(struct sock *sk)
@@ -606,13 +702,18 @@
 	struct tls_device *dev;
 	int err;
 
-	err = ctx->hash(sk);
-	mutex_lock(&device_mutex);
+	err = ctx->sk_proto->hash(sk);
+	spin_lock_bh(&device_spinlock);
 	list_for_each_entry(dev, &device_list, dev_list) {
-		if (dev->hash)
+		if (dev->hash) {
+			kref_get(&dev->kref);
+			spin_unlock_bh(&device_spinlock);
 			err |= dev->hash(dev, sk);
+			kref_put(&dev->kref, dev->release);
+			spin_lock_bh(&device_spinlock);
+		}
 	}
-	mutex_unlock(&device_mutex);
+	spin_unlock_bh(&device_spinlock);
 
 	if (err)
 		tls_hw_unhash(sk);
@@ -632,12 +733,14 @@
 	prot[TLS_SW][TLS_BASE].sendpage		= tls_sw_sendpage;
 
 	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
-	prot[TLS_BASE][TLS_SW].recvmsg		= tls_sw_recvmsg;
-	prot[TLS_BASE][TLS_SW].close		= tls_sk_proto_close;
+	prot[TLS_BASE][TLS_SW].recvmsg		  = tls_sw_recvmsg;
+	prot[TLS_BASE][TLS_SW].stream_memory_read = tls_sw_stream_read;
+	prot[TLS_BASE][TLS_SW].close		  = tls_sk_proto_close;
 
 	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
-	prot[TLS_SW][TLS_SW].recvmsg	= tls_sw_recvmsg;
-	prot[TLS_SW][TLS_SW].close	= tls_sk_proto_close;
+	prot[TLS_SW][TLS_SW].recvmsg		= tls_sw_recvmsg;
+	prot[TLS_SW][TLS_SW].stream_memory_read	= tls_sw_stream_read;
+	prot[TLS_SW][TLS_SW].close		= tls_sk_proto_close;
 
 #ifdef CONFIG_TLS_DEVICE
 	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
@@ -658,17 +761,15 @@
 	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
 	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_hw_hash;
 	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_hw_unhash;
-	prot[TLS_HW_RECORD][TLS_HW_RECORD].close	= tls_sk_proto_close;
 }
 
 static int tls_init(struct sock *sk)
 {
-	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
 	struct tls_context *ctx;
 	int rc = 0;
 
 	if (tls_hw_prot(sk))
-		goto out;
+		return 0;
 
 	/* The TLS ulp is currently supported only for TCP sockets
 	 * in ESTABLISHED state.
@@ -679,69 +780,128 @@
 	if (sk->sk_state != TCP_ESTABLISHED)
 		return -ENOTSUPP;
 
+	tls_build_proto(sk);
+
 	/* allocate tls context */
+	write_lock_bh(&sk->sk_callback_lock);
 	ctx = create_ctx(sk);
 	if (!ctx) {
 		rc = -ENOMEM;
 		goto out;
 	}
-	ctx->setsockopt = sk->sk_prot->setsockopt;
-	ctx->getsockopt = sk->sk_prot->getsockopt;
-	ctx->sk_proto_close = sk->sk_prot->close;
-
-	/* Build IPv6 TLS whenever the address of tcpv6	_prot changes */
-	if (ip_ver == TLSV6 &&
-	    unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
-		mutex_lock(&tcpv6_prot_mutex);
-		if (likely(sk->sk_prot != saved_tcpv6_prot)) {
-			build_protos(tls_prots[TLSV6], sk->sk_prot);
-			smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
-		}
-		mutex_unlock(&tcpv6_prot_mutex);
-	}
 
 	ctx->tx_conf = TLS_BASE;
 	ctx->rx_conf = TLS_BASE;
 	update_sk_prot(sk, ctx);
 out:
+	write_unlock_bh(&sk->sk_callback_lock);
 	return rc;
 }
 
+static void tls_update(struct sock *sk, struct proto *p)
+{
+	struct tls_context *ctx;
+
+	ctx = tls_get_ctx(sk);
+	if (likely(ctx))
+		ctx->sk_proto = p;
+	else
+		sk->sk_prot = p;
+}
+
+static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
+{
+	u16 version, cipher_type;
+	struct tls_context *ctx;
+	struct nlattr *start;
+	int err;
+
+	start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
+	if (!start)
+		return -EMSGSIZE;
+
+	rcu_read_lock();
+	ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
+	if (!ctx) {
+		err = 0;
+		goto nla_failure;
+	}
+	version = ctx->prot_info.version;
+	if (version) {
+		err = nla_put_u16(skb, TLS_INFO_VERSION, version);
+		if (err)
+			goto nla_failure;
+	}
+	cipher_type = ctx->prot_info.cipher_type;
+	if (cipher_type) {
+		err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
+		if (err)
+			goto nla_failure;
+	}
+	err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
+	if (err)
+		goto nla_failure;
+
+	err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
+	if (err)
+		goto nla_failure;
+
+	rcu_read_unlock();
+	nla_nest_end(skb, start);
+	return 0;
+
+nla_failure:
+	rcu_read_unlock();
+	nla_nest_cancel(skb, start);
+	return err;
+}
+
+static size_t tls_get_info_size(const struct sock *sk)
+{
+	size_t size = 0;
+
+	size += nla_total_size(0) +		/* INET_ULP_INFO_TLS */
+		nla_total_size(sizeof(u16)) +	/* TLS_INFO_VERSION */
+		nla_total_size(sizeof(u16)) +	/* TLS_INFO_CIPHER */
+		nla_total_size(sizeof(u16)) +	/* TLS_INFO_RXCONF */
+		nla_total_size(sizeof(u16)) +	/* TLS_INFO_TXCONF */
+		0;
+
+	return size;
+}
+
 void tls_register_device(struct tls_device *device)
 {
-	mutex_lock(&device_mutex);
+	spin_lock_bh(&device_spinlock);
 	list_add_tail(&device->dev_list, &device_list);
-	mutex_unlock(&device_mutex);
+	spin_unlock_bh(&device_spinlock);
 }
 EXPORT_SYMBOL(tls_register_device);
 
 void tls_unregister_device(struct tls_device *device)
 {
-	mutex_lock(&device_mutex);
+	spin_lock_bh(&device_spinlock);
 	list_del(&device->dev_list);
-	mutex_unlock(&device_mutex);
+	spin_unlock_bh(&device_spinlock);
 }
 EXPORT_SYMBOL(tls_unregister_device);
 
 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
 	.name			= "tls",
-	.uid			= TCP_ULP_TLS,
-	.user_visible		= true,
 	.owner			= THIS_MODULE,
 	.init			= tls_init,
+	.update			= tls_update,
+	.get_info		= tls_get_info,
+	.get_info_size		= tls_get_info_size,
 };
 
 static int __init tls_register(void)
 {
-	build_protos(tls_prots[TLSV4], &tcp_prot);
-
 	tls_sw_proto_ops = inet_stream_ops;
-	tls_sw_proto_ops.poll = tls_sw_poll;
 	tls_sw_proto_ops.splice_read = tls_sw_splice_read;
+	tls_sw_proto_ops.sendpage_locked   = tls_sw_sendpage_locked;
 
-#ifdef CONFIG_TLS_DEVICE
 	tls_device_init();
-#endif
 	tcp_register_ulp(&tcp_tls_ulp_ops);
 
 	return 0;
@@ -750,9 +910,7 @@
 static void __exit tls_unregister(void)
 {
 	tcp_unregister_ulp(&tcp_tls_ulp_ops);
-#ifdef CONFIG_TLS_DEVICE
 	tls_device_cleanup();
-#endif
 }
 
 module_init(tls_register);
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index b9c6ecf..5dd0f01 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -4,6 +4,7 @@
  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
+ * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
  *
  * This software is available to you under a choice of one of two
  * licenses.  You may choose to be licensed under the terms of the GNU
@@ -41,237 +42,1257 @@
 #include <net/strparser.h>
 #include <net/tls.h>
 
-#define MAX_IV_SIZE	TLS_CIPHER_AES_GCM_128_IV_SIZE
+static int __skb_nsg(struct sk_buff *skb, int offset, int len,
+                     unsigned int recursion_level)
+{
+        int start = skb_headlen(skb);
+        int i, chunk = start - offset;
+        struct sk_buff *frag_iter;
+        int elt = 0;
+
+        if (unlikely(recursion_level >= 24))
+                return -EMSGSIZE;
+
+        if (chunk > 0) {
+                if (chunk > len)
+                        chunk = len;
+                elt++;
+                len -= chunk;
+                if (len == 0)
+                        return elt;
+                offset += chunk;
+        }
+
+        for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) {
+                int end;
+
+                WARN_ON(start > offset + len);
+
+                end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]);
+                chunk = end - offset;
+                if (chunk > 0) {
+                        if (chunk > len)
+                                chunk = len;
+                        elt++;
+                        len -= chunk;
+                        if (len == 0)
+                                return elt;
+                        offset += chunk;
+                }
+                start = end;
+        }
+
+        if (unlikely(skb_has_frag_list(skb))) {
+                skb_walk_frags(skb, frag_iter) {
+                        int end, ret;
+
+                        WARN_ON(start > offset + len);
+
+                        end = start + frag_iter->len;
+                        chunk = end - offset;
+                        if (chunk > 0) {
+                                if (chunk > len)
+                                        chunk = len;
+                                ret = __skb_nsg(frag_iter, offset - start, chunk,
+                                                recursion_level + 1);
+                                if (unlikely(ret < 0))
+                                        return ret;
+                                elt += ret;
+                                len -= chunk;
+                                if (len == 0)
+                                        return elt;
+                                offset += chunk;
+                        }
+                        start = end;
+                }
+        }
+        BUG_ON(len);
+        return elt;
+}
+
+/* Return the number of scatterlist elements required to completely map the
+ * skb, or -EMSGSIZE if the recursion depth is exceeded.
+ */
+static int skb_nsg(struct sk_buff *skb, int offset, int len)
+{
+        return __skb_nsg(skb, offset, len, 0);
+}
+
+static int padding_length(struct tls_sw_context_rx *ctx,
+			  struct tls_prot_info *prot, struct sk_buff *skb)
+{
+	struct strp_msg *rxm = strp_msg(skb);
+	int sub = 0;
+
+	/* Determine zero-padding length */
+	if (prot->version == TLS_1_3_VERSION) {
+		char content_type = 0;
+		int err;
+		int back = 17;
+
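+		/* back starts past the 16-byte auth tag plus one byte; walk
+		 * backwards over any zero padding until the TLS 1.3 inner
+		 * content type byte is found
+		 */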
+		while (content_type == 0) {
+			if (back > rxm->full_len - prot->prepend_size)
+				return -EBADMSG;
+			err = skb_copy_bits(skb,
+					    rxm->offset + rxm->full_len - back,
+					    &content_type, 1);
+			if (err)
+				return err;
+			if (content_type)
+				break;
+			sub++;
+			back++;
+		}
+		ctx->control = content_type;
+	}
+	return sub;
+}
+
+static void tls_decrypt_done(struct crypto_async_request *req, int err)
+{
+	struct aead_request *aead_req = (struct aead_request *)req;
+	struct scatterlist *sgout = aead_req->dst;
+	struct scatterlist *sgin = aead_req->src;
+	struct tls_sw_context_rx *ctx;
+	struct tls_context *tls_ctx;
+	struct tls_prot_info *prot;
+	struct scatterlist *sg;
+	struct sk_buff *skb;
+	unsigned int pages;
+	int pending;
+
+	skb = (struct sk_buff *)req->data;
+	tls_ctx = tls_get_ctx(skb->sk);
+	ctx = tls_sw_ctx_rx(tls_ctx);
+	prot = &tls_ctx->prot_info;
+
+	/* Propagate if there was an err */
+	if (err) {
+		ctx->async_wait.err = err;
+		tls_err_abort(skb->sk, err);
+	} else {
+		struct strp_msg *rxm = strp_msg(skb);
+		int pad;
+
+		pad = padding_length(ctx, prot, skb);
+		if (pad < 0) {
+			ctx->async_wait.err = pad;
+			tls_err_abort(skb->sk, pad);
+		} else {
+			rxm->full_len -= pad;
+			rxm->offset += prot->prepend_size;
+			rxm->full_len -= prot->overhead_size;
+		}
+	}
+
+	/* After using skb->sk to propagate sk through crypto async callback
+	 * we need to NULL it again.
+	 */
+	skb->sk = NULL;
+
+	/* Free the destination pages if skb was not decrypted inplace */
+	if (sgout != sgin) {
+		/* Skip the first S/G entry as it points to AAD */
+		for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) {
+			if (!sg)
+				break;
+			put_page(sg_page(sg));
+		}
+	}
+
+	kfree(aead_req);
+
+	pending = atomic_dec_return(&ctx->decrypt_pending);
+
+	if (!pending && READ_ONCE(ctx->async_notify))
+		complete(&ctx->async_wait.completion);
+}
 
 static int tls_do_decryption(struct sock *sk,
+			     struct sk_buff *skb,
 			     struct scatterlist *sgin,
 			     struct scatterlist *sgout,
 			     char *iv_recv,
 			     size_t data_len,
-			     struct aead_request *aead_req)
+			     struct aead_request *aead_req,
+			     bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 	int ret;
 
 	aead_request_set_tfm(aead_req, ctx->aead_recv);
-	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+	aead_request_set_ad(aead_req, prot->aad_size);
 	aead_request_set_crypt(aead_req, sgin, sgout,
-			       data_len + tls_ctx->rx.tag_size,
+			       data_len + prot->tag_size,
 			       (u8 *)iv_recv);
-	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
-				  crypto_req_done, &ctx->async_wait);
 
-	ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
+	if (async) {
+		/* Using skb->sk to push sk through to crypto async callback
+		 * handler. This allows propagating errors up to the socket
+		 * if needed. It _must_ be cleared in the async handler
+		 * before consume_skb is called. We _know_ skb->sk is NULL
+		 * because it is a clone from strparser.
+		 */
+		skb->sk = sk;
+		aead_request_set_callback(aead_req,
+					  CRYPTO_TFM_REQ_MAY_BACKLOG,
+					  tls_decrypt_done, skb);
+		atomic_inc(&ctx->decrypt_pending);
+	} else {
+		aead_request_set_callback(aead_req,
+					  CRYPTO_TFM_REQ_MAY_BACKLOG,
+					  crypto_req_done, &ctx->async_wait);
+	}
+
+	ret = crypto_aead_decrypt(aead_req);
+	if (ret == -EINPROGRESS) {
+		if (async)
+			return ret;
+
+		ret = crypto_wait_req(ret, &ctx->async_wait);
+	}
+
+	if (async)
+		atomic_dec(&ctx->decrypt_pending);
+
 	return ret;
 }
 
-static void trim_sg(struct sock *sk, struct scatterlist *sg,
-		    int *sg_num_elem, unsigned int *sg_size, int target_size)
-{
-	int i = *sg_num_elem - 1;
-	int trim = *sg_size - target_size;
-
-	if (trim <= 0) {
-		WARN_ON(trim < 0);
-		return;
-	}
-
-	*sg_size = target_size;
-	while (trim >= sg[i].length) {
-		trim -= sg[i].length;
-		sk_mem_uncharge(sk, sg[i].length);
-		put_page(sg_page(&sg[i]));
-		i--;
-
-		if (i < 0)
-			goto out;
-	}
-
-	sg[i].length -= trim;
-	sk_mem_uncharge(sk, trim);
-
-out:
-	*sg_num_elem = i + 1;
-}
-
-static void trim_both_sgl(struct sock *sk, int target_size)
+static void tls_trim_both_msgs(struct sock *sk, int target_size)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec;
 
-	trim_sg(sk, ctx->sg_plaintext_data,
-		&ctx->sg_plaintext_num_elem,
-		&ctx->sg_plaintext_size,
-		target_size);
-
+	sk_msg_trim(sk, &rec->msg_plaintext, target_size);
 	if (target_size > 0)
-		target_size += tls_ctx->tx.overhead_size;
-
-	trim_sg(sk, ctx->sg_encrypted_data,
-		&ctx->sg_encrypted_num_elem,
-		&ctx->sg_encrypted_size,
-		target_size);
+		target_size += prot->overhead_size;
+	sk_msg_trim(sk, &rec->msg_encrypted, target_size);
 }
 
-static int alloc_encrypted_sg(struct sock *sk, int len)
+static int tls_alloc_encrypted_msg(struct sock *sk, int len)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	int rc = 0;
+	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_en = &rec->msg_encrypted;
 
-	rc = sk_alloc_sg(sk, len,
-			 ctx->sg_encrypted_data, 0,
-			 &ctx->sg_encrypted_num_elem,
-			 &ctx->sg_encrypted_size, 0);
-
-	if (rc == -ENOSPC)
-		ctx->sg_encrypted_num_elem = ARRAY_SIZE(ctx->sg_encrypted_data);
-
-	return rc;
+	return sk_msg_alloc(sk, msg_en, len, 0);
 }
 
-static int alloc_plaintext_sg(struct sock *sk, int len)
+static int tls_clone_plaintext_msg(struct sock *sk, int required)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_pl = &rec->msg_plaintext;
+	struct sk_msg *msg_en = &rec->msg_encrypted;
+	int skip, len;
+
+	/* We add page references worth len bytes from the encrypted sg
+	 * at the end of the plaintext sg. It is guaranteed that msg_en
+	 * has enough room for this (ensured by the caller).
+	 */
+	len = required - msg_pl->sg.size;
+
+	/* Skip initial bytes in msg_en's data to be able to use
+	 * same offset of both plain and encrypted data.
+	 */
+	skip = prot->prepend_size + msg_pl->sg.size;
+
+	return sk_msg_clone(sk, msg_pl, msg_en, skip, len);
+}
+
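+/* Allocate a new open record: extra space for the AEAD request context is
+ * co-allocated behind struct tls_rec, and both AEAD scatterlists start with
+ * the AAD buffer so the plaintext and ciphertext sg lists can be chained
+ * behind it later in tls_push_record().
+ */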
+static struct tls_rec *tls_get_rec(struct sock *sk)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct sk_msg *msg_pl, *msg_en;
+	struct tls_rec *rec;
+	int mem_size;
+
+	mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send);
+
+	rec = kzalloc(mem_size, sk->sk_allocation);
+	if (!rec)
+		return NULL;
+
+	msg_pl = &rec->msg_plaintext;
+	msg_en = &rec->msg_encrypted;
+
+	sk_msg_init(msg_pl);
+	sk_msg_init(msg_en);
+
+	sg_init_table(rec->sg_aead_in, 2);
+	sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size);
+	sg_unmark_end(&rec->sg_aead_in[1]);
+
+	sg_init_table(rec->sg_aead_out, 2);
+	sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size);
+	sg_unmark_end(&rec->sg_aead_out[1]);
+
+	return rec;
+}
+
+static void tls_free_rec(struct sock *sk, struct tls_rec *rec)
+{
+	sk_msg_free(sk, &rec->msg_encrypted);
+	sk_msg_free(sk, &rec->msg_plaintext);
+	kfree(rec);
+}
+
+static void tls_free_open_rec(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	int rc = 0;
+	struct tls_rec *rec = ctx->open_rec;
 
-	rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
-			 &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
-			 tls_ctx->pending_open_record_frags);
-
-	if (rc == -ENOSPC)
-		ctx->sg_plaintext_num_elem = ARRAY_SIZE(ctx->sg_plaintext_data);
-
-	return rc;
-}
-
-static void free_sg(struct sock *sk, struct scatterlist *sg,
-		    int *sg_num_elem, unsigned int *sg_size)
-{
-	int i, n = *sg_num_elem;
-
-	for (i = 0; i < n; ++i) {
-		sk_mem_uncharge(sk, sg[i].length);
-		put_page(sg_page(&sg[i]));
+	if (rec) {
+		tls_free_rec(sk, rec);
+		ctx->open_rec = NULL;
 	}
-	*sg_num_elem = 0;
-	*sg_size = 0;
 }
 
-static void tls_free_both_sg(struct sock *sk)
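+/* Transmit the records on tx_list: first finish a partially sent record if
+ * there is one, then push every record whose encryption has completed
+ * (tx_ready), stopping at the first record that is still pending.
+ */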
+int tls_tx_records(struct sock *sk, int flags)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec, *tmp;
+	struct sk_msg *msg_en;
+	int tx_flags, rc = 0;
 
-	free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
-		&ctx->sg_encrypted_size);
+	if (tls_is_partially_sent_record(tls_ctx)) {
+		rec = list_first_entry(&ctx->tx_list,
+				       struct tls_rec, list);
 
-	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
-		&ctx->sg_plaintext_size);
+		if (flags == -1)
+			tx_flags = rec->tx_flags;
+		else
+			tx_flags = flags;
+
+		rc = tls_push_partial_record(sk, tls_ctx, tx_flags);
+		if (rc)
+			goto tx_err;
+
+		/* Full record has been transmitted.
+		 * Remove the head of tx_list
+		 */
+		list_del(&rec->list);
+		sk_msg_free(sk, &rec->msg_plaintext);
+		kfree(rec);
+	}
+
+	/* Tx all ready records */
+	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
+		if (READ_ONCE(rec->tx_ready)) {
+			if (flags == -1)
+				tx_flags = rec->tx_flags;
+			else
+				tx_flags = flags;
+
+			msg_en = &rec->msg_encrypted;
+			rc = tls_push_sg(sk, tls_ctx,
+					 &msg_en->sg.data[msg_en->sg.curr],
+					 0, tx_flags);
+			if (rc)
+				goto tx_err;
+
+			list_del(&rec->list);
+			sk_msg_free(sk, &rec->msg_plaintext);
+			kfree(rec);
+		} else {
+			break;
+		}
+	}
+
+tx_err:
+	if (rc < 0 && rc != -EAGAIN)
+		tls_err_abort(sk, EBADMSG);
+
+	return rc;
 }
 
-static int tls_do_encryption(struct tls_context *tls_ctx,
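+/* Completion callback for async record encryption: restore the TLS header
+ * (prepend) space on the current sg entry, mark the record ready for
+ * transmission and, if it is at the head of tx_list, schedule the tx work
+ * to push it out.
+ */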
+static void tls_encrypt_done(struct crypto_async_request *req, int err)
+{
+	struct aead_request *aead_req = (struct aead_request *)req;
+	struct sock *sk = req->data;
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct scatterlist *sge;
+	struct sk_msg *msg_en;
+	struct tls_rec *rec;
+	bool ready = false;
+	int pending;
+
+	rec = container_of(aead_req, struct tls_rec, aead_req);
+	msg_en = &rec->msg_encrypted;
+
+	sge = sk_msg_elem(msg_en, msg_en->sg.curr);
+	sge->offset -= prot->prepend_size;
+	sge->length += prot->prepend_size;
+
+	/* Check if an error was previously set on the socket */
+	if (err || sk->sk_err) {
+		rec = NULL;
+
+		/* If err is already set on socket, return the same code */
+		if (sk->sk_err) {
+			ctx->async_wait.err = sk->sk_err;
+		} else {
+			ctx->async_wait.err = err;
+			tls_err_abort(sk, err);
+		}
+	}
+
+	if (rec) {
+		struct tls_rec *first_rec;
+
+		/* Mark the record as ready for transmission */
+		smp_store_mb(rec->tx_ready, true);
+
+		/* If received record is at head of tx_list, schedule tx */
+		first_rec = list_first_entry(&ctx->tx_list,
+					     struct tls_rec, list);
+		if (rec == first_rec)
+			ready = true;
+	}
+
+	pending = atomic_dec_return(&ctx->encrypt_pending);
+
+	if (!pending && READ_ONCE(ctx->async_notify))
+		complete(&ctx->async_wait.completion);
+
+	if (!ready)
+		return;
+
+	/* Schedule the transmission */
+	if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
+		schedule_delayed_work(&ctx->tx_work.work, 1);
+}
+
+static int tls_do_encryption(struct sock *sk,
+			     struct tls_context *tls_ctx,
 			     struct tls_sw_context_tx *ctx,
 			     struct aead_request *aead_req,
-			     size_t data_len)
+			     size_t data_len, u32 start)
 {
-	int rc;
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_en = &rec->msg_encrypted;
+	struct scatterlist *sge = sk_msg_elem(msg_en, start);
+	int rc, iv_offset = 0;
 
-	ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
-	ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
+	/* For CCM based ciphers, first byte of IV is a constant */
+	if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
+		rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE;
+		iv_offset = 1;
+	}
+
+	memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv,
+	       prot->iv_size + prot->salt_size);
+
+	xor_iv_with_seq(prot->version, rec->iv_data, tls_ctx->tx.rec_seq);
+
+	sge->offset += prot->prepend_size;
+	sge->length -= prot->prepend_size;
+
+	msg_en->sg.curr = start;
 
 	aead_request_set_tfm(aead_req, ctx->aead_send);
-	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
-	aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
-			       data_len, tls_ctx->tx.iv);
+	aead_request_set_ad(aead_req, prot->aad_size);
+	aead_request_set_crypt(aead_req, rec->sg_aead_in,
+			       rec->sg_aead_out,
+			       data_len, rec->iv_data);
 
 	aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
-				  crypto_req_done, &ctx->async_wait);
+				  tls_encrypt_done, sk);
 
-	rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);
+	/* Add the record in tx_list */
+	list_add_tail((struct list_head *)&rec->list, &ctx->tx_list);
+	atomic_inc(&ctx->encrypt_pending);
 
-	ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
-	ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
+	rc = crypto_aead_encrypt(aead_req);
+	if (!rc || rc != -EINPROGRESS) {
+		atomic_dec(&ctx->encrypt_pending);
+		sge->offset -= prot->prepend_size;
+		sge->length += prot->prepend_size;
+	}
 
+	if (!rc) {
+		WRITE_ONCE(rec->tx_ready, true);
+	} else if (rc != -EINPROGRESS) {
+		list_del(&rec->list);
+		return rc;
+	}
+
+	/* Unhook the record from the context if encryption did not fail */
+	ctx->open_rec = NULL;
+	tls_advance_record_sn(sk, prot, &tls_ctx->tx);
 	return rc;
 }
 
+static int tls_split_open_record(struct sock *sk, struct tls_rec *from,
+				 struct tls_rec **to, struct sk_msg *msg_opl,
+				 struct sk_msg *msg_oen, u32 split_point,
+				 u32 tx_overhead_size, u32 *orig_end)
+{
+	u32 i, j, bytes = 0, apply = msg_opl->apply_bytes;
+	struct scatterlist *sge, *osge, *nsge;
+	u32 orig_size = msg_opl->sg.size;
+	struct scatterlist tmp = { };
+	struct sk_msg *msg_npl;
+	struct tls_rec *new;
+	int ret;
+
+	new = tls_get_rec(sk);
+	if (!new)
+		return -ENOMEM;
+	ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size +
+			   tx_overhead_size, 0);
+	if (ret < 0) {
+		tls_free_rec(sk, new);
+		return ret;
+	}
+
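+	/* Walk the plaintext sg list: the first apply_bytes stay in the
+	 * original record; everything past the split point is moved into
+	 * the new record further below.
+	 */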
+	*orig_end = msg_opl->sg.end;
+	i = msg_opl->sg.start;
+	sge = sk_msg_elem(msg_opl, i);
+	while (apply && sge->length) {
+		if (sge->length > apply) {
+			u32 len = sge->length - apply;
+
+			get_page(sg_page(sge));
+			sg_set_page(&tmp, sg_page(sge), len,
+				    sge->offset + apply);
+			sge->length = apply;
+			bytes += apply;
+			apply = 0;
+		} else {
+			apply -= sge->length;
+			bytes += sge->length;
+		}
+
+		sk_msg_iter_var_next(i);
+		if (i == msg_opl->sg.end)
+			break;
+		sge = sk_msg_elem(msg_opl, i);
+	}
+
+	msg_opl->sg.end = i;
+	msg_opl->sg.curr = i;
+	msg_opl->sg.copybreak = 0;
+	msg_opl->apply_bytes = 0;
+	msg_opl->sg.size = bytes;
+
+	msg_npl = &new->msg_plaintext;
+	msg_npl->apply_bytes = apply;
+	msg_npl->sg.size = orig_size - bytes;
+
+	j = msg_npl->sg.start;
+	nsge = sk_msg_elem(msg_npl, j);
+	if (tmp.length) {
+		memcpy(nsge, &tmp, sizeof(*nsge));
+		sk_msg_iter_var_next(j);
+		nsge = sk_msg_elem(msg_npl, j);
+	}
+
+	osge = sk_msg_elem(msg_opl, i);
+	while (osge->length) {
+		memcpy(nsge, osge, sizeof(*nsge));
+		sg_unmark_end(nsge);
+		sk_msg_iter_var_next(i);
+		sk_msg_iter_var_next(j);
+		if (i == *orig_end)
+			break;
+		osge = sk_msg_elem(msg_opl, i);
+		nsge = sk_msg_elem(msg_npl, j);
+	}
+
+	msg_npl->sg.end = j;
+	msg_npl->sg.curr = j;
+	msg_npl->sg.copybreak = 0;
+
+	*to = new;
+	return 0;
+}
+
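+/* Undo tls_split_open_record(): move the remaining plaintext from 'from'
+ * back into 'to', coalescing the boundary sg entries when they are adjacent
+ * in the same page, and restore the original sg end and apply_bytes.
+ */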
+static void tls_merge_open_record(struct sock *sk, struct tls_rec *to,
+				  struct tls_rec *from, u32 orig_end)
+{
+	struct sk_msg *msg_npl = &from->msg_plaintext;
+	struct sk_msg *msg_opl = &to->msg_plaintext;
+	struct scatterlist *osge, *nsge;
+	u32 i, j;
+
+	i = msg_opl->sg.end;
+	sk_msg_iter_var_prev(i);
+	j = msg_npl->sg.start;
+
+	osge = sk_msg_elem(msg_opl, i);
+	nsge = sk_msg_elem(msg_npl, j);
+
+	if (sg_page(osge) == sg_page(nsge) &&
+	    osge->offset + osge->length == nsge->offset) {
+		osge->length += nsge->length;
+		put_page(sg_page(nsge));
+	}
+
+	msg_opl->sg.end = orig_end;
+	msg_opl->sg.curr = orig_end;
+	msg_opl->sg.copybreak = 0;
+	msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size;
+	msg_opl->sg.size += msg_npl->sg.size;
+
+	sk_msg_free(sk, &to->msg_encrypted);
+	sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted);
+
+	kfree(from);
+}
+
 static int tls_push_record(struct sock *sk, int flags,
 			   unsigned char record_type)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec, *tmp = NULL;
+	u32 i, split_point, uninitialized_var(orig_end);
+	struct sk_msg *msg_pl, *msg_en;
 	struct aead_request *req;
+	bool split;
 	int rc;
 
-	req = aead_request_alloc(ctx->aead_send, sk->sk_allocation);
-	if (!req)
-		return -ENOMEM;
+	if (!rec)
+		return 0;
 
-	sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
-	sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
+	msg_pl = &rec->msg_plaintext;
+	msg_en = &rec->msg_encrypted;
 
-	tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
-		     tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
-		     record_type);
-
-	tls_fill_prepend(tls_ctx,
-			 page_address(sg_page(&ctx->sg_encrypted_data[0])) +
-			 ctx->sg_encrypted_data[0].offset,
-			 ctx->sg_plaintext_size, record_type);
-
-	tls_ctx->pending_open_record_frags = 0;
-	set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);
-
-	rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size);
-	if (rc < 0) {
-		/* If we are called from write_space and
-		 * we fail, we need to set this SOCK_NOSPACE
-		 * to trigger another write_space in the future.
-		 */
-		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
-		goto out_req;
+	split_point = msg_pl->apply_bytes;
+	split = split_point && split_point < msg_pl->sg.size;
+	if (split) {
+		rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en,
+					   split_point, prot->overhead_size,
+					   &orig_end);
+		if (rc < 0)
+			return rc;
+		sk_msg_trim(sk, msg_en, msg_pl->sg.size +
+			    prot->overhead_size);
 	}
 
-	free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
-		&ctx->sg_plaintext_size);
+	rec->tx_flags = flags;
+	req = &rec->aead_req;
 
-	ctx->sg_encrypted_num_elem = 0;
-	ctx->sg_encrypted_size = 0;
+	i = msg_pl->sg.end;
+	sk_msg_iter_var_prev(i);
 
-	/* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
-	rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
-	if (rc < 0 && rc != -EAGAIN)
-		tls_err_abort(sk, EBADMSG);
+	rec->content_type = record_type;
+	if (prot->version == TLS_1_3_VERSION) {
+		/* Add content type to end of message.  No padding added */
+		sg_set_buf(&rec->sg_content_type, &rec->content_type, 1);
+		sg_mark_end(&rec->sg_content_type);
+		sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1,
+			 &rec->sg_content_type);
+	} else {
+		sg_mark_end(sk_msg_elem(msg_pl, i));
+	}
 
-	tls_advance_record_sn(sk, &tls_ctx->tx);
-out_req:
-	aead_request_free(req);
-	return rc;
+	i = msg_pl->sg.start;
+	sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]);
+
+	i = msg_en->sg.end;
+	sk_msg_iter_var_prev(i);
+	sg_mark_end(sk_msg_elem(msg_en, i));
+
+	i = msg_en->sg.start;
+	sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]);
+
+	tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size,
+		     tls_ctx->tx.rec_seq, prot->rec_seq_size,
+		     record_type, prot->version);
+
+	tls_fill_prepend(tls_ctx,
+			 page_address(sg_page(&msg_en->sg.data[i])) +
+			 msg_en->sg.data[i].offset,
+			 msg_pl->sg.size + prot->tail_size,
+			 record_type, prot->version);
+
+	tls_ctx->pending_open_record_frags = false;
+
+	rc = tls_do_encryption(sk, tls_ctx, ctx, req,
+			       msg_pl->sg.size + prot->tail_size, i);
+	if (rc < 0) {
+		if (rc != -EINPROGRESS) {
+			tls_err_abort(sk, EBADMSG);
+			if (split) {
+				tls_ctx->pending_open_record_frags = true;
+				tls_merge_open_record(sk, rec, tmp, orig_end);
+			}
+		}
+		ctx->async_capable = 1;
+		return rc;
+	} else if (split) {
+		msg_pl = &tmp->msg_plaintext;
+		msg_en = &tmp->msg_encrypted;
+		sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size);
+		tls_ctx->pending_open_record_frags = true;
+		ctx->open_rec = tmp;
+	}
+
+	return tls_tx_records(sk, flags);
+}
+
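+/* Run the psock's BPF verdict program on the plaintext message and act on
+ * the result: __SK_PASS pushes the TLS record, __SK_REDIRECT hands the
+ * plaintext to another socket, __SK_DROP frees it. Without a psock (or with
+ * MSG_SENDPAGE_NOPOLICY) the record is pushed unconditionally.
+ */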
+static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk,
+			       bool full_record, u8 record_type,
+			       size_t *copied, int flags)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct sk_msg msg_redir = { };
+	struct sk_psock *psock;
+	struct sock *sk_redir;
+	struct tls_rec *rec;
+	bool enospc, policy;
+	int err = 0, send;
+	u32 delta = 0;
+
+	policy = !(flags & MSG_SENDPAGE_NOPOLICY);
+	psock = sk_psock_get(sk);
+	if (!psock || !policy) {
+		err = tls_push_record(sk, flags, record_type);
+		if (err) {
+			*copied -= sk_msg_free(sk, msg);
+			tls_free_open_rec(sk);
+		}
+		return err;
+	}
+more_data:
+	enospc = sk_msg_full(msg);
+	if (psock->eval == __SK_NONE) {
+		delta = msg->sg.size;
+		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
+		if (delta < msg->sg.size)
+			delta -= msg->sg.size;
+		else
+			delta = 0;
+	}
+	if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
+	    !enospc && !full_record) {
+		err = -ENOSPC;
+		goto out_err;
+	}
+	msg->cork_bytes = 0;
+	send = msg->sg.size;
+	if (msg->apply_bytes && msg->apply_bytes < send)
+		send = msg->apply_bytes;
+
+	switch (psock->eval) {
+	case __SK_PASS:
+		err = tls_push_record(sk, flags, record_type);
+		if (err < 0) {
+			*copied -= sk_msg_free(sk, msg);
+			tls_free_open_rec(sk);
+			goto out_err;
+		}
+		break;
+	case __SK_REDIRECT:
+		sk_redir = psock->sk_redir;
+		memcpy(&msg_redir, msg, sizeof(*msg));
+		if (msg->apply_bytes < send)
+			msg->apply_bytes = 0;
+		else
+			msg->apply_bytes -= send;
+		sk_msg_return_zero(sk, msg, send);
+		msg->sg.size -= send;
+		release_sock(sk);
+		err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, flags);
+		lock_sock(sk);
+		if (err < 0) {
+			*copied -= sk_msg_free_nocharge(sk, &msg_redir);
+			msg->sg.size = 0;
+		}
+		if (msg->sg.size == 0)
+			tls_free_open_rec(sk);
+		break;
+	case __SK_DROP:
+	default:
+		sk_msg_free_partial(sk, msg, send);
+		if (msg->apply_bytes < send)
+			msg->apply_bytes = 0;
+		else
+			msg->apply_bytes -= send;
+		if (msg->sg.size == 0)
+			tls_free_open_rec(sk);
+		*copied -= (send + delta);
+		err = -EACCES;
+	}
+
+	if (likely(!err)) {
+		bool reset_eval = !ctx->open_rec;
+
+		rec = ctx->open_rec;
+		if (rec) {
+			msg = &rec->msg_plaintext;
+			if (!msg->apply_bytes)
+				reset_eval = true;
+		}
+		if (reset_eval) {
+			psock->eval = __SK_NONE;
+			if (psock->sk_redir) {
+				sock_put(psock->sk_redir);
+				psock->sk_redir = NULL;
+			}
+		}
+		if (rec)
+			goto more_data;
+	}
+ out_err:
+	sk_psock_put(sk, psock);
+	return err;
 }
 
 static int tls_sw_push_pending_record(struct sock *sk, int flags)
 {
-	return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec = ctx->open_rec;
+	struct sk_msg *msg_pl;
+	size_t copied;
+
+	if (!rec)
+		return 0;
+
+	msg_pl = &rec->msg_plaintext;
+	copied = msg_pl->sg.size;
+	if (!copied)
+		return 0;
+
+	return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA,
+				   &copied, flags);
 }
 
-static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
-			      int length, int *pages_used,
-			      unsigned int *size_used,
-			      struct scatterlist *to, int to_max_pages,
-			      bool charge)
+int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 {
-	struct page *pages[MAX_SKB_FRAGS];
+	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	bool async_capable = ctx->async_capable;
+	unsigned char record_type = TLS_RECORD_TYPE_DATA;
+	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
+	bool eor = !(msg->msg_flags & MSG_MORE);
+	size_t try_to_copy, copied = 0;
+	struct sk_msg *msg_pl, *msg_en;
+	struct tls_rec *rec;
+	int required_size;
+	int num_async = 0;
+	bool full_record;
+	int record_room;
+	int num_zc = 0;
+	int orig_size;
+	int ret = 0;
 
-	size_t offset;
-	ssize_t copied, use;
-	int i = 0;
+	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
+		return -ENOTSUPP;
+
+	mutex_lock(&tls_ctx->tx_lock);
+	lock_sock(sk);
+
+	if (unlikely(msg->msg_controllen)) {
+		ret = tls_proccess_cmsg(sk, msg, &record_type);
+		if (ret) {
+			if (ret == -EINPROGRESS)
+				num_async++;
+			else if (ret != -EAGAIN)
+				goto send_end;
+		}
+	}
+
+	while (msg_data_left(msg)) {
+		if (sk->sk_err) {
+			ret = -sk->sk_err;
+			goto send_end;
+		}
+
+		if (ctx->open_rec)
+			rec = ctx->open_rec;
+		else
+			rec = ctx->open_rec = tls_get_rec(sk);
+		if (!rec) {
+			ret = -ENOMEM;
+			goto send_end;
+		}
+
+		msg_pl = &rec->msg_plaintext;
+		msg_en = &rec->msg_encrypted;
+
+		orig_size = msg_pl->sg.size;
+		full_record = false;
+		try_to_copy = msg_data_left(msg);
+		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
+		if (try_to_copy >= record_room) {
+			try_to_copy = record_room;
+			full_record = true;
+		}
+
+		required_size = msg_pl->sg.size + try_to_copy +
+				prot->overhead_size;
+
+		if (!sk_stream_memory_free(sk))
+			goto wait_for_sndbuf;
+
+alloc_encrypted:
+		ret = tls_alloc_encrypted_msg(sk, required_size);
+		if (ret) {
+			if (ret != -ENOSPC)
+				goto wait_for_memory;
+
+			/* Adjust try_to_copy according to the amount that was
+			 * actually allocated. The difference is due
+			 * to max sg elements limit
+			 */
+			try_to_copy -= required_size - msg_en->sg.size;
+			full_record = true;
+		}
+
+		if (!is_kvec && (full_record || eor) && !async_capable) {
+			u32 first = msg_pl->sg.end;
+
+			ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter,
+							msg_pl, try_to_copy);
+			if (ret)
+				goto fallback_to_reg_send;
+
+			num_zc++;
+			copied += try_to_copy;
+
+			sk_msg_sg_copy_set(msg_pl, first);
+			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+						  record_type, &copied,
+						  msg->msg_flags);
+			if (ret) {
+				if (ret == -EINPROGRESS)
+					num_async++;
+				else if (ret == -ENOMEM)
+					goto wait_for_memory;
+				else if (ctx->open_rec && ret == -ENOSPC)
+					goto rollback_iter;
+				else if (ret != -EAGAIN)
+					goto send_end;
+			}
+			continue;
+rollback_iter:
+			copied -= try_to_copy;
+			sk_msg_sg_copy_clear(msg_pl, first);
+			iov_iter_revert(&msg->msg_iter,
+					msg_pl->sg.size - orig_size);
+fallback_to_reg_send:
+			sk_msg_trim(sk, msg_pl, orig_size);
+		}
+
+		required_size = msg_pl->sg.size + try_to_copy;
+
+		ret = tls_clone_plaintext_msg(sk, required_size);
+		if (ret) {
+			if (ret != -ENOSPC)
+				goto send_end;
+
+			/* Adjust try_to_copy according to the amount that was
+			 * actually allocated. The difference is due
+			 * to max sg elements limit
+			 */
+			try_to_copy -= required_size - msg_pl->sg.size;
+			full_record = true;
+			sk_msg_trim(sk, msg_en,
+				    msg_pl->sg.size + prot->overhead_size);
+		}
+
+		if (try_to_copy) {
+			ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter,
+						       msg_pl, try_to_copy);
+			if (ret < 0)
+				goto trim_sgl;
+		}
+
+		/* Mark the open record frags as pending only if the copy
+		 * succeeded; otherwise we would trim the sg but not reset
+		 * the open record frags.
+		 */
+		tls_ctx->pending_open_record_frags = true;
+		copied += try_to_copy;
+		if (full_record || eor) {
+			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+						  record_type, &copied,
+						  msg->msg_flags);
+			if (ret) {
+				if (ret == -EINPROGRESS)
+					num_async++;
+				else if (ret == -ENOMEM)
+					goto wait_for_memory;
+				else if (ret != -EAGAIN) {
+					if (ret == -ENOSPC)
+						ret = 0;
+					goto send_end;
+				}
+			}
+		}
+
+		continue;
+
+wait_for_sndbuf:
+		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+wait_for_memory:
+		ret = sk_stream_wait_memory(sk, &timeo);
+		if (ret) {
+trim_sgl:
+			if (ctx->open_rec)
+				tls_trim_both_msgs(sk, orig_size);
+			goto send_end;
+		}
+
+		if (ctx->open_rec && msg_en->sg.size < required_size)
+			goto alloc_encrypted;
+	}
+
+	if (!num_async) {
+		goto send_end;
+	} else if (num_zc) {
+		/* Wait for pending encryptions to get completed */
+		smp_store_mb(ctx->async_notify, true);
+
+		if (atomic_read(&ctx->encrypt_pending))
+			crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+		else
+			reinit_completion(&ctx->async_wait.completion);
+
+		WRITE_ONCE(ctx->async_notify, false);
+
+		if (ctx->async_wait.err) {
+			ret = ctx->async_wait.err;
+			copied = 0;
+		}
+	}
+
+	/* Transmit if any encryptions have completed */
+	if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
+		cancel_delayed_work(&ctx->tx_work.work);
+		tls_tx_records(sk, msg->msg_flags);
+	}
+
+send_end:
+	ret = sk_stream_error(sk, msg->msg_flags, ret);
+
+	release_sock(sk);
+	mutex_unlock(&tls_ctx->tx_lock);
+	return copied ? copied : ret;
+}
+
+static int tls_sw_do_sendpage(struct sock *sk, struct page *page,
+			      int offset, size_t size, int flags)
+{
+	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	unsigned char record_type = TLS_RECORD_TYPE_DATA;
+	struct sk_msg *msg_pl;
+	struct tls_rec *rec;
+	int num_async = 0;
+	size_t copied = 0;
+	bool full_record;
+	int record_room;
+	int ret = 0;
+	bool eor;
+
+	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
+	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
+
+	/* Call the sk_stream functions to manage the sndbuf mem. */
+	while (size > 0) {
+		size_t copy, required_size;
+
+		if (sk->sk_err) {
+			ret = -sk->sk_err;
+			goto sendpage_end;
+		}
+
+		if (ctx->open_rec)
+			rec = ctx->open_rec;
+		else
+			rec = ctx->open_rec = tls_get_rec(sk);
+		if (!rec) {
+			ret = -ENOMEM;
+			goto sendpage_end;
+		}
+
+		msg_pl = &rec->msg_plaintext;
+
+		full_record = false;
+		record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size;
+		copy = size;
+		if (copy >= record_room) {
+			copy = record_room;
+			full_record = true;
+		}
+
+		required_size = msg_pl->sg.size + copy + prot->overhead_size;
+
+		if (!sk_stream_memory_free(sk))
+			goto wait_for_sndbuf;
+alloc_payload:
+		ret = tls_alloc_encrypted_msg(sk, required_size);
+		if (ret) {
+			if (ret != -ENOSPC)
+				goto wait_for_memory;
+
+			/* Adjust copy according to the amount that was
+			 * actually allocated. The difference is due
+			 * to max sg elements limit
+			 */
+			copy -= required_size - msg_pl->sg.size;
+			full_record = true;
+		}
+
+		sk_msg_page_add(msg_pl, page, copy, offset);
+		sk_mem_charge(sk, copy);
+
+		offset += copy;
+		size -= copy;
+		copied += copy;
+
+		tls_ctx->pending_open_record_frags = true;
+		if (full_record || eor || sk_msg_full(msg_pl)) {
+			ret = bpf_exec_tx_verdict(msg_pl, sk, full_record,
+						  record_type, &copied, flags);
+			if (ret) {
+				if (ret == -EINPROGRESS)
+					num_async++;
+				else if (ret == -ENOMEM)
+					goto wait_for_memory;
+				else if (ret != -EAGAIN) {
+					if (ret == -ENOSPC)
+						ret = 0;
+					goto sendpage_end;
+				}
+			}
+		}
+		continue;
+wait_for_sndbuf:
+		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
+wait_for_memory:
+		ret = sk_stream_wait_memory(sk, &timeo);
+		if (ret) {
+			if (ctx->open_rec)
+				tls_trim_both_msgs(sk, msg_pl->sg.size);
+			goto sendpage_end;
+		}
+
+		if (ctx->open_rec)
+			goto alloc_payload;
+	}
+
+	if (num_async) {
+		/* Transmit if any encryptions have completed */
+		if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) {
+			cancel_delayed_work(&ctx->tx_work.work);
+			tls_tx_records(sk, flags);
+		}
+	}
+sendpage_end:
+	ret = sk_stream_error(sk, flags, ret);
+	return copied ? copied : ret;
+}
+
+int tls_sw_sendpage_locked(struct sock *sk, struct page *page,
+			   int offset, size_t size, int flags)
+{
+	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
+		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY |
+		      MSG_NO_SHARED_FRAGS))
+		return -ENOTSUPP;
+
+	return tls_sw_do_sendpage(sk, page, offset, size, flags);
+}
+
+int tls_sw_sendpage(struct sock *sk, struct page *page,
+		    int offset, size_t size, int flags)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	int ret;
+
+	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
+		      MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY))
+		return -ENOTSUPP;
+
+	mutex_lock(&tls_ctx->tx_lock);
+	lock_sock(sk);
+	ret = tls_sw_do_sendpage(sk, page, offset, size, flags);
+	release_sock(sk);
+	mutex_unlock(&tls_ctx->tx_lock);
+	return ret;
+}
+
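+/* Wait until the stream parser has a full record (ctx->recv_pkt) or the
+ * psock ingress queue has data. Returns NULL on error, shutdown, timeout or
+ * signal, with *err set where applicable.
+ */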
+static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock,
+				     int flags, long timeo, int *err)
+{
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct sk_buff *skb;
+	DEFINE_WAIT_FUNC(wait, woken_wake_function);
+
+	while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) {
+		if (sk->sk_err) {
+			*err = sock_error(sk);
+			return NULL;
+		}
+
+		if (sk->sk_shutdown & RCV_SHUTDOWN)
+			return NULL;
+
+		if (sock_flag(sk, SOCK_DONE))
+			return NULL;
+
+		if ((flags & MSG_DONTWAIT) || !timeo) {
+			*err = -EAGAIN;
+			return NULL;
+		}
+
+		add_wait_queue(sk_sleep(sk), &wait);
+		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+		sk_wait_event(sk, &timeo,
+			      ctx->recv_pkt != skb ||
+			      !sk_psock_queue_empty(psock),
+			      &wait);
+		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+		remove_wait_queue(sk_sleep(sk), &wait);
+
+		/* Handle signals */
+		if (signal_pending(current)) {
+			*err = sock_intr_errno(timeo);
+			return NULL;
+		}
+	}
+
+	return skb;
+}
+
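+/* Map up to 'length' bytes from the iov iterator into the 'to' scatterlist;
+ * socket memory is not charged for these pages.
+ */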
+static int tls_setup_from_iter(struct sock *sk, struct iov_iter *from,
+			       int length, int *pages_used,
+			       unsigned int *size_used,
+			       struct scatterlist *to,
+			       int to_max_pages)
+{
+	int rc = 0, i = 0, num_elem = *pages_used, maxpages;
+	struct page *pages[MAX_SKB_FRAGS];
 	unsigned int size = *size_used;
-	int num_elem = *pages_used;
-	int rc = 0;
-	int maxpages;
+	ssize_t copied, use;
+	size_t offset;
 
 	while (length > 0) {
 		i = 0;
@@ -298,17 +1319,15 @@
 			sg_set_page(&to[num_elem],
 				    pages[i], use, offset);
 			sg_unmark_end(&to[num_elem]);
-			if (charge)
-				sk_mem_charge(sk, use);
+			/* We do not uncharge memory from this API */
 
 			offset = 0;
 			copied -= use;
 
-			++i;
-			++num_elem;
+			i++;
+			num_elem++;
 		}
 	}
-
 	/* Mark the end in the last sg entry if newly added */
 	if (num_elem > *pages_used)
 		sg_mark_end(&to[num_elem - 1]);
@@ -321,340 +1340,6 @@
 	return rc;
 }
 
-static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
-			     int bytes)
-{
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	struct scatterlist *sg = ctx->sg_plaintext_data;
-	int copy, i, rc = 0;
-
-	for (i = tls_ctx->pending_open_record_frags;
-	     i < ctx->sg_plaintext_num_elem; ++i) {
-		copy = sg[i].length;
-		if (copy_from_iter(
-				page_address(sg_page(&sg[i])) + sg[i].offset,
-				copy, from) != copy) {
-			rc = -EFAULT;
-			goto out;
-		}
-		bytes -= copy;
-
-		++tls_ctx->pending_open_record_frags;
-
-		if (!bytes)
-			break;
-	}
-
-out:
-	return rc;
-}
-
-int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
-{
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	int ret = 0;
-	int required_size;
-	long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
-	bool eor = !(msg->msg_flags & MSG_MORE);
-	size_t try_to_copy, copied = 0;
-	unsigned char record_type = TLS_RECORD_TYPE_DATA;
-	int record_room;
-	bool full_record;
-	int orig_size;
-	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
-
-	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
-		return -ENOTSUPP;
-
-	lock_sock(sk);
-
-	if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo))
-		goto send_end;
-
-	if (unlikely(msg->msg_controllen)) {
-		ret = tls_proccess_cmsg(sk, msg, &record_type);
-		if (ret)
-			goto send_end;
-	}
-
-	while (msg_data_left(msg)) {
-		if (sk->sk_err) {
-			ret = -sk->sk_err;
-			goto send_end;
-		}
-
-		orig_size = ctx->sg_plaintext_size;
-		full_record = false;
-		try_to_copy = msg_data_left(msg);
-		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
-		if (try_to_copy >= record_room) {
-			try_to_copy = record_room;
-			full_record = true;
-		}
-
-		required_size = ctx->sg_plaintext_size + try_to_copy +
-				tls_ctx->tx.overhead_size;
-
-		if (!sk_stream_memory_free(sk))
-			goto wait_for_sndbuf;
-alloc_encrypted:
-		ret = alloc_encrypted_sg(sk, required_size);
-		if (ret) {
-			if (ret != -ENOSPC)
-				goto wait_for_memory;
-
-			/* Adjust try_to_copy according to the amount that was
-			 * actually allocated. The difference is due
-			 * to max sg elements limit
-			 */
-			try_to_copy -= required_size - ctx->sg_encrypted_size;
-			full_record = true;
-		}
-		if (!is_kvec && (full_record || eor)) {
-			ret = zerocopy_from_iter(sk, &msg->msg_iter,
-				try_to_copy, &ctx->sg_plaintext_num_elem,
-				&ctx->sg_plaintext_size,
-				ctx->sg_plaintext_data,
-				ARRAY_SIZE(ctx->sg_plaintext_data),
-				true);
-			if (ret)
-				goto fallback_to_reg_send;
-
-			copied += try_to_copy;
-			ret = tls_push_record(sk, msg->msg_flags, record_type);
-			if (ret)
-				goto send_end;
-			continue;
-
-fallback_to_reg_send:
-			trim_sg(sk, ctx->sg_plaintext_data,
-				&ctx->sg_plaintext_num_elem,
-				&ctx->sg_plaintext_size,
-				orig_size);
-		}
-
-		required_size = ctx->sg_plaintext_size + try_to_copy;
-alloc_plaintext:
-		ret = alloc_plaintext_sg(sk, required_size);
-		if (ret) {
-			if (ret != -ENOSPC)
-				goto wait_for_memory;
-
-			/* Adjust try_to_copy according to the amount that was
-			 * actually allocated. The difference is due
-			 * to max sg elements limit
-			 */
-			try_to_copy -= required_size - ctx->sg_plaintext_size;
-			full_record = true;
-
-			trim_sg(sk, ctx->sg_encrypted_data,
-				&ctx->sg_encrypted_num_elem,
-				&ctx->sg_encrypted_size,
-				ctx->sg_plaintext_size +
-				tls_ctx->tx.overhead_size);
-		}
-
-		ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
-		if (ret)
-			goto trim_sgl;
-
-		copied += try_to_copy;
-		if (full_record || eor) {
-push_record:
-			ret = tls_push_record(sk, msg->msg_flags, record_type);
-			if (ret) {
-				if (ret == -ENOMEM)
-					goto wait_for_memory;
-
-				goto send_end;
-			}
-		}
-
-		continue;
-
-wait_for_sndbuf:
-		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
-wait_for_memory:
-		ret = sk_stream_wait_memory(sk, &timeo);
-		if (ret) {
-trim_sgl:
-			trim_both_sgl(sk, orig_size);
-			goto send_end;
-		}
-
-		if (tls_is_pending_closed_record(tls_ctx))
-			goto push_record;
-
-		if (ctx->sg_encrypted_size < required_size)
-			goto alloc_encrypted;
-
-		goto alloc_plaintext;
-	}
-
-send_end:
-	ret = sk_stream_error(sk, msg->msg_flags, ret);
-
-	release_sock(sk);
-	return copied ? copied : ret;
-}
-
-int tls_sw_sendpage(struct sock *sk, struct page *page,
-		    int offset, size_t size, int flags)
-{
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
-	int ret = 0;
-	long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
-	bool eor;
-	size_t orig_size = size;
-	unsigned char record_type = TLS_RECORD_TYPE_DATA;
-	struct scatterlist *sg;
-	bool full_record;
-	int record_room;
-
-	if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
-		      MSG_SENDPAGE_NOTLAST))
-		return -ENOTSUPP;
-
-	/* No MSG_EOR from splice, only look at MSG_MORE */
-	eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
-
-	lock_sock(sk);
-
-	sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
-
-	if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo))
-		goto sendpage_end;
-
-	/* Call the sk_stream functions to manage the sndbuf mem. */
-	while (size > 0) {
-		size_t copy, required_size;
-
-		if (sk->sk_err) {
-			ret = -sk->sk_err;
-			goto sendpage_end;
-		}
-
-		full_record = false;
-		record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
-		copy = size;
-		if (copy >= record_room) {
-			copy = record_room;
-			full_record = true;
-		}
-		required_size = ctx->sg_plaintext_size + copy +
-			      tls_ctx->tx.overhead_size;
-
-		if (!sk_stream_memory_free(sk))
-			goto wait_for_sndbuf;
-alloc_payload:
-		ret = alloc_encrypted_sg(sk, required_size);
-		if (ret) {
-			if (ret != -ENOSPC)
-				goto wait_for_memory;
-
-			/* Adjust copy according to the amount that was
-			 * actually allocated. The difference is due
-			 * to max sg elements limit
-			 */
-			copy -= required_size - ctx->sg_plaintext_size;
-			full_record = true;
-		}
-
-		get_page(page);
-		sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
-		sg_set_page(sg, page, copy, offset);
-		sg_unmark_end(sg);
-
-		ctx->sg_plaintext_num_elem++;
-
-		sk_mem_charge(sk, copy);
-		offset += copy;
-		size -= copy;
-		ctx->sg_plaintext_size += copy;
-		tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;
-
-		if (full_record || eor ||
-		    ctx->sg_plaintext_num_elem ==
-		    ARRAY_SIZE(ctx->sg_plaintext_data)) {
-push_record:
-			ret = tls_push_record(sk, flags, record_type);
-			if (ret) {
-				if (ret == -ENOMEM)
-					goto wait_for_memory;
-
-				goto sendpage_end;
-			}
-		}
-		continue;
-wait_for_sndbuf:
-		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
-wait_for_memory:
-		ret = sk_stream_wait_memory(sk, &timeo);
-		if (ret) {
-			trim_both_sgl(sk, ctx->sg_plaintext_size);
-			goto sendpage_end;
-		}
-
-		if (tls_is_pending_closed_record(tls_ctx))
-			goto push_record;
-
-		goto alloc_payload;
-	}
-
-sendpage_end:
-	if (orig_size > size)
-		ret = orig_size - size;
-	else
-		ret = sk_stream_error(sk, flags, ret);
-
-	release_sock(sk);
-	return ret;
-}
-
-static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
-				     long timeo, int *err)
-{
-	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-	struct sk_buff *skb;
-	DEFINE_WAIT_FUNC(wait, woken_wake_function);
-
-	while (!(skb = ctx->recv_pkt)) {
-		if (sk->sk_err) {
-			*err = sock_error(sk);
-			return NULL;
-		}
-
-		if (sk->sk_shutdown & RCV_SHUTDOWN)
-			return NULL;
-
-		if (sock_flag(sk, SOCK_DONE))
-			return NULL;
-
-		if ((flags & MSG_DONTWAIT) || !timeo) {
-			*err = -EAGAIN;
-			return NULL;
-		}
-
-		add_wait_queue(sk_sleep(sk), &wait);
-		sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-		sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
-		sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-		remove_wait_queue(sk_sleep(sk), &wait);
-
-		/* Handle signals */
-		if (signal_pending(current)) {
-			*err = sock_intr_errno(timeo);
-			return NULL;
-		}
-	}
-
-	return skb;
-}
-
 /* This function decrypts the input skb into either out_iov or in out_sg
  * or in skb buffers itself. The input parameter 'zc' indicates if
  * zero-copy mode needs to be tried or not. With zero-copy mode, either
@@ -666,10 +1351,11 @@
 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
 			    struct iov_iter *out_iov,
 			    struct scatterlist *out_sg,
-			    int *chunk, bool *zc)
+			    int *chunk, bool *zc, bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct strp_msg *rxm = strp_msg(skb);
 	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
 	struct aead_request *aead_req;
@@ -677,19 +1363,23 @@
 	u8 *aad, *iv, *mem = NULL;
 	struct scatterlist *sgin = NULL;
 	struct scatterlist *sgout = NULL;
-	const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
+	const int data_len = rxm->full_len - prot->overhead_size +
+			     prot->tail_size;
+	int iv_offset = 0;
 
 	if (*zc && (out_iov || out_sg)) {
 		if (out_iov)
 			n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
 		else
 			n_sgout = sg_nents(out_sg);
+		n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size,
+				 rxm->full_len - prot->prepend_size);
 	} else {
 		n_sgout = 0;
 		*zc = false;
+		n_sgin = skb_cow_data(skb, 0, &unused);
 	}
 
-	n_sgin = skb_cow_data(skb, 0, &unused);
 	if (n_sgin < 1)
 		return -EBADMSG;
 
@@ -700,7 +1390,7 @@
 
 	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
 	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
-	mem_size = mem_size + TLS_AAD_SPACE_SIZE;
+	mem_size = mem_size + prot->aad_size;
 	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
 
 	/* Allocate a single block of memory which contains
@@ -716,29 +1406,42 @@
 	sgin = (struct scatterlist *)(mem + aead_size);
 	sgout = sgin + n_sgin;
 	aad = (u8 *)(sgout + n_sgout);
-	iv = aad + TLS_AAD_SPACE_SIZE;
+	iv = aad + prot->aad_size;
+
+	/* For CCM based ciphers, first byte of nonce+iv is always '2' */
+	if (prot->cipher_type == TLS_CIPHER_AES_CCM_128) {
+		iv[0] = 2;
+		iv_offset = 1;
+	}
 
 	/* Prepare IV */
 	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
-			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-			    tls_ctx->rx.iv_size);
+			    iv + iv_offset + prot->salt_size,
+			    prot->iv_size);
 	if (err < 0) {
 		kfree(mem);
 		return err;
 	}
-	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+	if (prot->version == TLS_1_3_VERSION)
+		memcpy(iv + iv_offset, tls_ctx->rx.iv,
+		       crypto_aead_ivsize(ctx->aead_recv));
+	else
+		memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size);
+
+	xor_iv_with_seq(prot->version, iv, tls_ctx->rx.rec_seq);
 
 	/* Prepare AAD */
-	tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
-		     tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
-		     ctx->control);
+	tls_make_aad(aad, rxm->full_len - prot->overhead_size +
+		     prot->tail_size,
+		     tls_ctx->rx.rec_seq, prot->rec_seq_size,
+		     ctx->control, prot->version);
 
 	/* Prepare sgin */
 	sg_init_table(sgin, n_sgin);
-	sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
+	sg_set_buf(&sgin[0], aad, prot->aad_size);
 	err = skb_to_sgvec(skb, &sgin[1],
-			   rxm->offset + tls_ctx->rx.prepend_size,
-			   rxm->full_len - tls_ctx->rx.prepend_size);
+			   rxm->offset + prot->prepend_size,
+			   rxm->full_len - prot->prepend_size);
 	if (err < 0) {
 		kfree(mem);
 		return err;
@@ -747,12 +1450,12 @@
 	if (n_sgout) {
 		if (out_iov) {
 			sg_init_table(sgout, n_sgout);
-			sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
+			sg_set_buf(&sgout[0], aad, prot->aad_size);
 
 			*chunk = 0;
-			err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
-						 chunk, &sgout[1],
-						 (n_sgout - 1), false);
+			err = tls_setup_from_iter(sk, out_iov, data_len,
+						  &pages, chunk, &sgout[1],
+						  (n_sgout - 1));
 			if (err < 0)
 				goto fallback_to_reg_recv;
 		} else if (out_sg) {
@@ -764,12 +1467,15 @@
 fallback_to_reg_recv:
 		sgout = sgin;
 		pages = 0;
-		*chunk = 0;
+		*chunk = data_len;
 		*zc = false;
 	}
 
 	/* Prepare and submit AEAD request */
-	err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req);
+	err = tls_do_decryption(sk, skb, sgin, sgout, iv,
+				data_len, aead_req, async);
+	if (err == -EINPROGRESS)
+		return err;
 
 	/* Release the pages in case iov was mapped to pages */
 	for (; pages > 0; pages--)
@@ -780,32 +1486,51 @@
 }
 
 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
-			      struct iov_iter *dest, int *chunk, bool *zc)
+			      struct iov_iter *dest, int *chunk, bool *zc,
+			      bool async)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct strp_msg *rxm = strp_msg(skb);
-	int err = 0;
+	int pad, err = 0;
 
-#ifdef CONFIG_TLS_DEVICE
-	err = tls_device_decrypted(sk, skb);
-	if (err < 0)
-		return err;
-#endif
 	if (!ctx->decrypted) {
-		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
-		if (err < 0)
-			return err;
+		if (tls_ctx->rx_conf == TLS_HW) {
+			err = tls_device_decrypted(sk, skb);
+			if (err < 0)
+				return err;
+		}
+
+		/* Still not decrypted after tls_device */
+		if (!ctx->decrypted) {
+			err = decrypt_internal(sk, skb, dest, NULL, chunk, zc,
+					       async);
+			if (err < 0) {
+				if (err == -EINPROGRESS)
+					tls_advance_record_sn(sk, prot,
+							      &tls_ctx->rx);
+
+				return err;
+			}
+		} else {
+			*zc = false;
+		}
+
+		pad = padding_length(ctx, prot, skb);
+		if (pad < 0)
+			return pad;
+
+		rxm->full_len -= pad;
+		rxm->offset += prot->prepend_size;
+		rxm->full_len -= prot->overhead_size;
+		tls_advance_record_sn(sk, prot, &tls_ctx->rx);
+		ctx->decrypted = true;
+		ctx->saved_data_ready(sk);
 	} else {
 		*zc = false;
 	}
 
-	rxm->offset += tls_ctx->rx.prepend_size;
-	rxm->full_len -= tls_ctx->rx.overhead_size;
-	tls_advance_record_sn(sk, &tls_ctx->rx);
-	ctx->decrypted = true;
-	ctx->saved_data_ready(sk);
-
 	return err;
 }
 
@@ -815,7 +1540,7 @@
 	bool zc = true;
 	int chunk;
 
-	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
+	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc, false);
 }
 
 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
@@ -823,23 +1548,134 @@
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-	struct strp_msg *rxm = strp_msg(skb);
 
-	if (len < rxm->full_len) {
-		rxm->offset += len;
-		rxm->full_len -= len;
+	if (skb) {
+		struct strp_msg *rxm = strp_msg(skb);
 
-		return false;
+		if (len < rxm->full_len) {
+			rxm->offset += len;
+			rxm->full_len -= len;
+			return false;
+		}
+		consume_skb(skb);
 	}
 
 	/* Finished with message */
 	ctx->recv_pkt = NULL;
-	kfree_skb(skb);
 	__strp_unpause(&ctx->strp);
 
 	return true;
 }
 
+/* This function traverses the rx_list in the TLS receive context and copies
+ * the decrypted records into the buffer provided by the caller when zero
+ * copy is not true. Further, the records are removed from the rx_list if it
+ * is not a peek case and the record has been consumed completely.
+ */
+static int process_rx_list(struct tls_sw_context_rx *ctx,
+			   struct msghdr *msg,
+			   u8 *control,
+			   bool *cmsg,
+			   size_t skip,
+			   size_t len,
+			   bool zc,
+			   bool is_peek)
+{
+	struct sk_buff *skb = skb_peek(&ctx->rx_list);
+	u8 ctrl = *control;
+	u8 msgc = *cmsg;
+	struct tls_msg *tlm;
+	ssize_t copied = 0;
+
+	/* Set the record type in 'control' if caller didn't pass it */
+	if (!ctrl && skb) {
+		tlm = tls_msg(skb);
+		ctrl = tlm->control;
+	}
+
+	while (skip && skb) {
+		struct strp_msg *rxm = strp_msg(skb);
+		tlm = tls_msg(skb);
+
+		/* Cannot process a record of different type */
+		if (ctrl != tlm->control)
+			return 0;
+
+		if (skip < rxm->full_len)
+			break;
+
+		skip = skip - rxm->full_len;
+		skb = skb_peek_next(skb, &ctx->rx_list);
+	}
+
+	while (len && skb) {
+		struct sk_buff *next_skb;
+		struct strp_msg *rxm = strp_msg(skb);
+		int chunk = min_t(unsigned int, rxm->full_len - skip, len);
+
+		tlm = tls_msg(skb);
+
+		/* Cannot process a record of different type */
+		if (ctrl != tlm->control)
+			return 0;
+
+		/* Set record type if not already done. For a non-data record,
+		 * do not proceed if record type could not be copied.
+		 */
+		if (!msgc) {
+			int cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
+					    sizeof(ctrl), &ctrl);
+			msgc = true;
+			if (ctrl != TLS_RECORD_TYPE_DATA) {
+				if (cerr || msg->msg_flags & MSG_CTRUNC)
+					return -EIO;
+
+				*cmsg = msgc;
+			}
+		}
+
+		if (!zc || (rxm->full_len - skip) > len) {
+			int err = skb_copy_datagram_msg(skb, rxm->offset + skip,
+						    msg, chunk);
+			if (err < 0)
+				return err;
+		}
+
+		len = len - chunk;
+		copied = copied + chunk;
+
+		/* Consume the data from the record in the non-peek case */
+		if (!is_peek) {
+			rxm->offset = rxm->offset + chunk;
+			rxm->full_len = rxm->full_len - chunk;
+
+			/* Return if there is unconsumed data in the record */
+			if (rxm->full_len - skip)
+				break;
+		}
+
+		/* The remaining skip-bytes must lie in 1st record in rx_list.
+		 * So from the 2nd record, 'skip' should be 0.
+		 */
+		skip = 0;
+
+		if (msg)
+			msg->msg_flags |= MSG_EOR;
+
+		next_skb = skb_peek_next(skb, &ctx->rx_list);
+
+		if (!is_peek) {
+			skb_unlink(skb, &ctx->rx_list);
+			consume_skb(skb);
+		}
+
+		skb = next_skb;
+	}
+
+	*control = ctrl;
+	return copied;
+}
+
 int tls_sw_recvmsg(struct sock *sk,
 		   struct msghdr *msg,
 		   size_t len,
@@ -849,104 +1685,217 @@
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-	unsigned char control;
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
+	struct sk_psock *psock;
+	unsigned char control = 0;
+	ssize_t decrypted = 0;
 	struct strp_msg *rxm;
+	struct tls_msg *tlm;
 	struct sk_buff *skb;
 	ssize_t copied = 0;
 	bool cmsg = false;
 	int target, err = 0;
 	long timeo;
-	bool is_kvec = msg->msg_iter.type & ITER_KVEC;
+	bool is_kvec = iov_iter_is_kvec(&msg->msg_iter);
+	bool is_peek = flags & MSG_PEEK;
+	int num_async = 0;
 
 	flags |= nonblock;
 
 	if (unlikely(flags & MSG_ERRQUEUE))
 		return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
 
+	psock = sk_psock_get(sk);
 	lock_sock(sk);
 
-	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
-	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
-	do {
-		bool zc = false;
-		int chunk = 0;
+	/* Process pending decrypted records. It must be non-zero-copy */
+	err = process_rx_list(ctx, msg, &control, &cmsg, 0, len, false,
+			      is_peek);
+	if (err < 0) {
+		tls_err_abort(sk, err);
+		goto end;
+	} else {
+		copied = err;
+	}
 
-		skb = tls_wait_data(sk, flags, timeo, &err);
-		if (!skb)
+	if (len <= copied)
+		goto recv_end;
+
+	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
+	len = len - copied;
+	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+
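+	/* Main receive loop: process records until the request is satisfied;
+	 * once the rcvlowat target is met, stop as soon as no further record
+	 * has already been parsed by the stream parser.
+	 */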
+	while (len && (decrypted + copied < target || ctx->recv_pkt)) {
+		bool retain_skb = false;
+		bool zc = false;
+		int to_decrypt;
+		int chunk = 0;
+		bool async_capable;
+		bool async = false;
+
+		skb = tls_wait_data(sk, psock, flags, timeo, &err);
+		if (!skb) {
+			if (psock) {
+				int ret = __tcp_bpf_recvmsg(sk, psock,
+							    msg, len, flags);
+
+				if (ret > 0) {
+					decrypted += ret;
+					len -= ret;
+					continue;
+				}
+			}
 			goto recv_end;
+		} else {
+			tlm = tls_msg(skb);
+			if (prot->version == TLS_1_3_VERSION)
+				tlm->control = 0;
+			else
+				tlm->control = ctx->control;
+		}
 
 		rxm = strp_msg(skb);
+
+		to_decrypt = rxm->full_len - prot->overhead_size;
+
+		if (to_decrypt <= len && !is_kvec && !is_peek &&
+		    ctx->control == TLS_RECORD_TYPE_DATA &&
+		    prot->version != TLS_1_3_VERSION)
+			zc = true;
+
+		/* Do not use async mode if record is non-data */
+		if (ctx->control == TLS_RECORD_TYPE_DATA)
+			async_capable = ctx->async_capable;
+		else
+			async_capable = false;
+
+		err = decrypt_skb_update(sk, skb, &msg->msg_iter,
+					 &chunk, &zc, async_capable);
+		if (err < 0 && err != -EINPROGRESS) {
+			tls_err_abort(sk, EBADMSG);
+			goto recv_end;
+		}
+
+		if (err == -EINPROGRESS) {
+			async = true;
+			num_async++;
+		} else if (prot->version == TLS_1_3_VERSION) {
+			tlm->control = ctx->control;
+		}
+
+		/* If the type of records being processed is not known yet,
+		 * set it to record type just dequeued. If it is already known,
+		 * but does not match the record type just dequeued, go to end.
+		 * We always get record type here since for tls1.2, record type
+		 * is known just after record is dequeued from stream parser.
+		 * For tls1.3, we disable async.
+		 */
+
+		if (!control)
+			control = tlm->control;
+		else if (control != tlm->control)
+			goto recv_end;
+
 		if (!cmsg) {
 			int cerr;
 
 			cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
-					sizeof(ctx->control), &ctx->control);
+					sizeof(control), &control);
 			cmsg = true;
-			control = ctx->control;
-			if (ctx->control != TLS_RECORD_TYPE_DATA) {
+			if (control != TLS_RECORD_TYPE_DATA) {
 				if (cerr || msg->msg_flags & MSG_CTRUNC) {
 					err = -EIO;
 					goto recv_end;
 				}
 			}
-		} else if (control != ctx->control) {
-			goto recv_end;
 		}
 
-		if (!ctx->decrypted) {
-			int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
-
-			if (!is_kvec && to_copy <= len &&
-			    likely(!(flags & MSG_PEEK)))
-				zc = true;
-
-			err = decrypt_skb_update(sk, skb, &msg->msg_iter,
-						 &chunk, &zc);
-			if (err < 0) {
-				tls_err_abort(sk, EBADMSG);
-				goto recv_end;
-			}
-			ctx->decrypted = true;
-		}
+		if (async)
+			goto pick_next_record;
 
 		if (!zc) {
-			chunk = min_t(unsigned int, rxm->full_len, len);
-			err = skb_copy_datagram_msg(skb, rxm->offset, msg,
-						    chunk);
+			if (rxm->full_len > len) {
+				retain_skb = true;
+				chunk = len;
+			} else {
+				chunk = rxm->full_len;
+			}
+
+			err = skb_copy_datagram_msg(skb, rxm->offset,
+						    msg, chunk);
 			if (err < 0)
 				goto recv_end;
-		}
 
-		copied += chunk;
-		len -= chunk;
-		if (likely(!(flags & MSG_PEEK))) {
-			u8 control = ctx->control;
-
-			if (tls_sw_advance_skb(sk, skb, chunk)) {
-				/* Return full control message to
-				 * userspace before trying to parse
-				 * another message type
-				 */
-				msg->msg_flags |= MSG_EOR;
-				if (control != TLS_RECORD_TYPE_DATA)
-					goto recv_end;
+			if (!is_peek) {
+				rxm->offset = rxm->offset + chunk;
+				rxm->full_len = rxm->full_len - chunk;
 			}
-		} else {
-			/* MSG_PEEK right now cannot look beyond current skb
-			 * from strparser, meaning we cannot advance skb here
-			 * and thus unpause strparser since we'd loose original
-			 * one.
-			 */
-			break;
 		}
 
-		/* If we have a new message from strparser, continue now. */
-		if (copied >= target && !ctx->recv_pkt)
+pick_next_record:
+		if (chunk > len)
+			chunk = len;
+
+		decrypted += chunk;
+		len -= chunk;
+
+		/* For async or peek case, queue the current skb */
+		if (async || is_peek || retain_skb) {
+			skb_queue_tail(&ctx->rx_list, skb);
+			skb = NULL;
+		}
+
+		if (tls_sw_advance_skb(sk, skb, chunk)) {
+			/* Return full control message to
+			 * userspace before trying to parse
+			 * another message type
+			 */
+			msg->msg_flags |= MSG_EOR;
+			if (ctx->control != TLS_RECORD_TYPE_DATA)
+				goto recv_end;
+		} else {
 			break;
-	} while (len);
+		}
+	}
 
 recv_end:
+	if (num_async) {
+		/* Wait for all previously submitted records to be decrypted */
+		smp_store_mb(ctx->async_notify, true);
+		if (atomic_read(&ctx->decrypt_pending)) {
+			err = crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+			if (err) {
+				/* one of the async decrypts failed */
+				tls_err_abort(sk, err);
+				copied = 0;
+				decrypted = 0;
+				goto end;
+			}
+		} else {
+			reinit_completion(&ctx->async_wait.completion);
+		}
+		WRITE_ONCE(ctx->async_notify, false);
+
+		/* Drain records from the rx_list & copy if required */
+		if (is_peek || is_kvec)
+			err = process_rx_list(ctx, msg, &control, &cmsg, copied,
+					      decrypted, false, is_peek);
+		else
+			err = process_rx_list(ctx, msg, &control, &cmsg, 0,
+					      decrypted, true, is_peek);
+		if (err < 0) {
+			tls_err_abort(sk, err);
+			copied = 0;
+			goto end;
+		}
+	}
+
+	copied += decrypted;
+
+end:
 	release_sock(sk);
+	if (psock)
+		sk_psock_put(sk, psock);
 	return copied ? : err;
 }
 
@@ -969,18 +1918,18 @@
 
 	timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
 
-	skb = tls_wait_data(sk, flags, timeo, &err);
+	skb = tls_wait_data(sk, NULL, flags, timeo, &err);
 	if (!skb)
 		goto splice_read_end;
 
-	/* splice does not support reading control messages */
-	if (ctx->control != TLS_RECORD_TYPE_DATA) {
-		err = -ENOTSUPP;
-		goto splice_read_end;
-	}
-
 	if (!ctx->decrypted) {
-		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
+		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc, false);
+
+		/* splice does not support reading control messages */
+		if (ctx->control != TLS_RECORD_TYPE_DATA) {
+			err = -ENOTSUPP;
+			goto splice_read_end;
+		}
 
 		if (err < 0) {
 			tls_err_abort(sk, EBADMSG);
@@ -1003,29 +1952,28 @@
 	return copied ? : err;
 }
 
-unsigned int tls_sw_poll(struct file *file, struct socket *sock,
-			 struct poll_table_struct *wait)
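+/* Report whether readable data is queued anywhere for this socket: on the
+ * BPF psock ingress list, as a record handed over by the stream parser
+ * (recv_pkt), or as records queued on rx_list.
+ */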
+bool tls_sw_stream_read(const struct sock *sk)
 {
-	unsigned int ret;
-	struct sock *sk = sock->sk;
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	bool ingress_empty = true;
+	struct sk_psock *psock;
 
-	/* Grab POLLOUT and POLLHUP from the underlying socket */
-	ret = ctx->sk_poll(file, sock, wait);
+	rcu_read_lock();
+	psock = sk_psock(sk);
+	if (psock)
+		ingress_empty = list_empty(&psock->ingress_msg);
+	rcu_read_unlock();
 
-	/* Clear POLLIN bits, and set based on recv_pkt */
-	ret &= ~(POLLIN | POLLRDNORM);
-	if (ctx->recv_pkt)
-		ret |= POLLIN | POLLRDNORM;
-
-	return ret;
+	return !ingress_empty || ctx->recv_pkt ||
+		!skb_queue_empty(&ctx->rx_list);
 }
 
 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
 	struct strp_msg *rxm = strp_msg(skb);
 	size_t cipher_overhead;
@@ -1033,17 +1981,17 @@
 	int ret;
 
 	/* Verify that we have a full TLS header, or wait for more data */
-	if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
+	if (rxm->offset + prot->prepend_size > skb->len)
 		return 0;
 
 	/* Sanity-check size of on-stack buffer. */
-	if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
+	if (WARN_ON(prot->prepend_size > sizeof(header))) {
 		ret = -EINVAL;
 		goto read_failure;
 	}
 
 	/* Linearize header to local buffer */
-	ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
+	ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
 
 	if (ret < 0)
 		goto read_failure;
@@ -1052,9 +2000,12 @@
 
 	data_len = ((header[4] & 0xFF) | (header[3] << 8));
 
-	cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
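+	/* TLS 1.3 records carry no explicit IV on the wire, only the AEAD tag. */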
+	cipher_overhead = prot->tag_size;
+	if (prot->version != TLS_1_3_VERSION)
+		cipher_overhead += prot->iv_size;
 
-	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
+	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
+	    prot->tail_size) {
 		ret = -EMSGSIZE;
 		goto read_failure;
 	}
@@ -1063,16 +2014,15 @@
 		goto read_failure;
 	}
 
-	if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.info.version) ||
-	    header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.info.version)) {
+	/* Note that both TLS 1.3 and TLS 1.2 use the TLS 1.2 version number here */
+	if (header[1] != TLS_1_2_VERSION_MINOR ||
+	    header[2] != TLS_1_2_VERSION_MAJOR) {
 		ret = -EINVAL;
 		goto read_failure;
 	}
 
-#ifdef CONFIG_TLS_DEVICE
-	handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
-			     *(u64*)tls_ctx->rx.rec_seq);
-#endif
+	tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
+				     TCP_SKB_CB(skb)->seq + rxm->offset);
 	return data_len + TLS_HEADER_SIZE;
 
 read_failure:
@@ -1098,17 +2048,65 @@
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+	struct sk_psock *psock;
 
 	strp_data_ready(&ctx->strp);
+
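+	/* Data redirected by a BPF sockmap program lands on the psock ingress
+	 * list instead of the strparser, so also wake readers through the
+	 * original data_ready callback.
+	 */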
+	psock = sk_psock_get(sk);
+	if (psock && !list_empty(&psock->ingress_msg)) {
+		ctx->saved_data_ready(sk);
+		sk_psock_put(sk, psock);
+	}
 }
 
-void tls_sw_free_resources_tx(struct sock *sk)
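+/* Stop the delayed TX work: mark the context as closing so the handler
+ * becomes a no-op and no new runs are scheduled, then wait for any run
+ * already in flight to finish.
+ */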
+void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+
+	set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
+	set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
+	cancel_delayed_work_sync(&ctx->tx_work.work);
+}
+
+void tls_sw_release_resources_tx(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
+	struct tls_rec *rec, *tmp;
+
+	/* Wait for any pending async encryptions to complete */
+	smp_store_mb(ctx->async_notify, true);
+	if (atomic_read(&ctx->encrypt_pending))
+		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);
+
+	tls_tx_records(sk, -1);
+
+	/* Free up unsent records in tx_list. First, free the partially sent
+	 * record, if any, at the head of tx_list.
+	 */
+	if (tls_ctx->partially_sent_record) {
+		tls_free_partial_record(sk, tls_ctx);
+		rec = list_first_entry(&ctx->tx_list,
+				       struct tls_rec, list);
+		list_del(&rec->list);
+		sk_msg_free(sk, &rec->msg_plaintext);
+		kfree(rec);
+	}
+
+	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
+		list_del(&rec->list);
+		sk_msg_free(sk, &rec->msg_encrypted);
+		sk_msg_free(sk, &rec->msg_plaintext);
+		kfree(rec);
+	}
 
 	crypto_free_aead(ctx->aead_send);
-	tls_free_both_sg(sk);
+	tls_free_open_rec(sk);
+}
+
+void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
 
 	kfree(ctx);
 }
@@ -1118,41 +2116,114 @@
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
 	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 
+	kfree(tls_ctx->rx.rec_seq);
+	kfree(tls_ctx->rx.iv);
+
 	if (ctx->aead_recv) {
 		kfree_skb(ctx->recv_pkt);
 		ctx->recv_pkt = NULL;
+		skb_queue_purge(&ctx->rx_list);
 		crypto_free_aead(ctx->aead_recv);
 		strp_stop(&ctx->strp);
-		write_lock_bh(&sk->sk_callback_lock);
-		sk->sk_data_ready = ctx->saved_data_ready;
-		write_unlock_bh(&sk->sk_callback_lock);
-		release_sock(sk);
-		strp_done(&ctx->strp);
-		lock_sock(sk);
+		/* If tls_sw_strparser_arm() was not called (cleanup paths)
+		 * we still want to strp_stop(), but sk->sk_data_ready was
+		 * never swapped.
+		 */
+		if (ctx->saved_data_ready) {
+			write_lock_bh(&sk->sk_callback_lock);
+			sk->sk_data_ready = ctx->saved_data_ready;
+			write_unlock_bh(&sk->sk_callback_lock);
+		}
 	}
 }
 
+void tls_sw_strparser_done(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+
+	strp_done(&ctx->strp);
+}
+
+void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+
+	kfree(ctx);
+}
+
 void tls_sw_free_resources_rx(struct sock *sk)
 {
 	struct tls_context *tls_ctx = tls_get_ctx(sk);
-	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
 
 	tls_sw_release_resources_rx(sk);
+	tls_sw_free_ctx_rx(tls_ctx);
+}
 
-	kfree(ctx);
+/* The work handler to transmit the encrypted records in tx_list */
+static void tx_work_handler(struct work_struct *work)
+{
+	struct delayed_work *delayed_work = to_delayed_work(work);
+	struct tx_work *tx_work = container_of(delayed_work,
+					       struct tx_work, work);
+	struct sock *sk = tx_work->sk;
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_sw_context_tx *ctx;
+
+	if (unlikely(!tls_ctx))
+		return;
+
+	ctx = tls_sw_ctx_tx(tls_ctx);
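+	/* Bail out if the context is being torn down or if no transmission
+	 * was actually scheduled.
+	 */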
+	if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
+		return;
+
+	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
+		return;
+	mutex_lock(&tls_ctx->tx_lock);
+	lock_sock(sk);
+	tls_tx_records(sk, -1);
+	release_sock(sk);
+	mutex_unlock(&tls_ctx->tx_lock);
+}
+
+void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
+{
+	struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);
+
+	/* Schedule the transmission if tx list is ready */
+	if (is_tx_ready(tx_ctx) &&
+	    !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
+		schedule_delayed_work(&tx_ctx->tx_work.work, 0);
+}
+
+void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
+{
+	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);
+
+	write_lock_bh(&sk->sk_callback_lock);
+	rx_ctx->saved_data_ready = sk->sk_data_ready;
+	sk->sk_data_ready = tls_data_ready;
+	write_unlock_bh(&sk->sk_callback_lock);
+
+	strp_check_rcv(&rx_ctx->strp);
 }
 
 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
 {
+	struct tls_context *tls_ctx = tls_get_ctx(sk);
+	struct tls_prot_info *prot = &tls_ctx->prot_info;
 	struct tls_crypto_info *crypto_info;
 	struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
+	struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;
+	struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;
 	struct tls_sw_context_tx *sw_ctx_tx = NULL;
 	struct tls_sw_context_rx *sw_ctx_rx = NULL;
 	struct cipher_context *cctx;
 	struct crypto_aead **aead;
 	struct strp_callbacks cb;
-	u16 nonce_size, tag_size, iv_size, rec_seq_size;
-	char *iv, *rec_seq;
+	u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
+	struct crypto_tfm *tfm;
+	char *iv, *rec_seq, *key, *salt, *cipher_name;
+	size_t keysize;
 	int rc = 0;
 
 	if (!ctx) {
@@ -1191,10 +2262,14 @@
 		crypto_info = &ctx->crypto_send.info;
 		cctx = &ctx->tx;
 		aead = &sw_ctx_tx->aead_send;
+		INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
+		INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
+		sw_ctx_tx->tx_work.sk = sk;
 	} else {
 		crypto_init_wait(&sw_ctx_rx->async_wait);
 		crypto_info = &ctx->crypto_recv.info;
 		cctx = &ctx->rx;
+		skb_queue_head_init(&sw_ctx_rx->rx_list);
 		aead = &sw_ctx_rx->aead_recv;
 	}
 
@@ -1209,6 +2284,45 @@
 		 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
 		gcm_128_info =
 			(struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
+		keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
+		key = gcm_128_info->key;
+		salt = gcm_128_info->salt;
+		salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
+		cipher_name = "gcm(aes)";
+		break;
+	}
+	case TLS_CIPHER_AES_GCM_256: {
+		nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
+		tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
+		iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
+		iv = ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->iv;
+		rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
+		rec_seq =
+		 ((struct tls12_crypto_info_aes_gcm_256 *)crypto_info)->rec_seq;
+		gcm_256_info =
+			(struct tls12_crypto_info_aes_gcm_256 *)crypto_info;
+		keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
+		key = gcm_256_info->key;
+		salt = gcm_256_info->salt;
+		salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
+		cipher_name = "gcm(aes)";
+		break;
+	}
+	case TLS_CIPHER_AES_CCM_128: {
+		nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
+		tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
+		iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
+		iv = ((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->iv;
+		rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
+		rec_seq =
+		((struct tls12_crypto_info_aes_ccm_128 *)crypto_info)->rec_seq;
+		ccm_128_info =
+		(struct tls12_crypto_info_aes_ccm_128 *)crypto_info;
+		keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
+		key = ccm_128_info->key;
+		salt = ccm_128_info->salt;
+		salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
+		cipher_name = "ccm(aes)";
 		break;
 	}
 	default:
@@ -1216,53 +2330,47 @@
 		goto free_priv;
 	}
 
-	/* Sanity-check the IV size for stack allocations. */
-	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
+	/* Sanity-check the sizes for stack allocations. */
+	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
+	    rec_seq_size > TLS_MAX_REC_SEQ_SIZE) {
 		rc = -EINVAL;
 		goto free_priv;
 	}
 
-	cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
-	cctx->tag_size = tag_size;
-	cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
-	cctx->iv_size = iv_size;
-	cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
-			   GFP_KERNEL);
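+	/* TLS 1.3: no explicit per-record nonce is sent, the AAD is only the
+	 * 5-byte record header, and the real content type travels in a
+	 * one-byte encrypted tail.
+	 */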
+	if (crypto_info->version == TLS_1_3_VERSION) {
+		nonce_size = 0;
+		prot->aad_size = TLS_HEADER_SIZE;
+		prot->tail_size = 1;
+	} else {
+		prot->aad_size = TLS_AAD_SPACE_SIZE;
+		prot->tail_size = 0;
+	}
+
+	prot->version = crypto_info->version;
+	prot->cipher_type = crypto_info->cipher_type;
+	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
+	prot->tag_size = tag_size;
+	prot->overhead_size = prot->prepend_size +
+			      prot->tag_size + prot->tail_size;
+	prot->iv_size = iv_size;
+	prot->salt_size = salt_size;
+	cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
 	if (!cctx->iv) {
 		rc = -ENOMEM;
 		goto free_priv;
 	}
-	memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
-	memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
-	cctx->rec_seq_size = rec_seq_size;
+	/* Note: 128 & 256 bit salt are the same size */
+	prot->rec_seq_size = rec_seq_size;
+	memcpy(cctx->iv, salt, salt_size);
+	memcpy(cctx->iv + salt_size, iv, iv_size);
 	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
 	if (!cctx->rec_seq) {
 		rc = -ENOMEM;
 		goto free_iv;
 	}
 
-	if (sw_ctx_tx) {
-		sg_init_table(sw_ctx_tx->sg_encrypted_data,
-			      ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data));
-		sg_init_table(sw_ctx_tx->sg_plaintext_data,
-			      ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data));
-
-		sg_init_table(sw_ctx_tx->sg_aead_in, 2);
-		sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space,
-			   sizeof(sw_ctx_tx->aad_space));
-		sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]);
-		sg_chain(sw_ctx_tx->sg_aead_in, 2,
-			 sw_ctx_tx->sg_plaintext_data);
-		sg_init_table(sw_ctx_tx->sg_aead_out, 2);
-		sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space,
-			   sizeof(sw_ctx_tx->aad_space));
-		sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]);
-		sg_chain(sw_ctx_tx->sg_aead_out, 2,
-			 sw_ctx_tx->sg_encrypted_data);
-	}
-
 	if (!*aead) {
-		*aead = crypto_alloc_aead("gcm(aes)", 0, 0);
+		*aead = crypto_alloc_aead(cipher_name, 0, 0);
 		if (IS_ERR(*aead)) {
 			rc = PTR_ERR(*aead);
 			*aead = NULL;
@@ -1272,31 +2380,30 @@
 
 	ctx->push_pending_record = tls_sw_push_pending_record;
 
-	rc = crypto_aead_setkey(*aead, gcm_128_info->key,
-				TLS_CIPHER_AES_GCM_128_KEY_SIZE);
+	rc = crypto_aead_setkey(*aead, key, keysize);
+
 	if (rc)
 		goto free_aead;
 
-	rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
+	rc = crypto_aead_setauthsize(*aead, prot->tag_size);
 	if (rc)
 		goto free_aead;
 
 	if (sw_ctx_rx) {
+		tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);
+
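+		/* With TLS 1.3 the record type is only known after decryption
+		 * (it sits in the encrypted tail), so asynchronous decryption
+		 * cannot be used on the receive side.
+		 */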
+		if (crypto_info->version == TLS_1_3_VERSION)
+			sw_ctx_rx->async_capable = false;
+		else
+			sw_ctx_rx->async_capable =
+				tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC;
+
 		/* Set up strparser */
 		memset(&cb, 0, sizeof(cb));
 		cb.rcv_msg = tls_queue;
 		cb.parse_msg = tls_read_size;
 
 		strp_init(&sw_ctx_rx->strp, sk, &cb);
-
-		write_lock_bh(&sk->sk_callback_lock);
-		sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
-		sk->sk_data_ready = tls_data_ready;
-		write_unlock_bh(&sk->sk_callback_lock);
-
-		sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
-
-		strp_check_rcv(&sw_ctx_rx->strp);
 	}
 
 	goto out;