Update Linux to v5.4.2
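
This pulls the vhost drivers up to their v5.4.2 state. Notable
changes:

 - SPDX license identifiers replace the GPL boilerplate in Kconfig,
   Kconfig.vringh, net.c, test.c and vhost.c.
 - vhost_net: batched XDP transmission (vhost_net_build_xdp() and
   vhost_tx_batch()), a private page frag allocator with refcount
   biasing, unified TX/RX busy polling via vhost_net_busy_poll(),
   and experimental_zcopytx now defaults to off.
 - vhost core: the packet/byte weight limits move into a shared
   vhost_exceeds_weight() helper configured via vhost_dev_init(),
   vq_iotlb_prefetch() is renamed to vq_meta_prefetch(), dirty log
   writes are translated through the IOTLB (log_write_hva() and
   log_used()), and VHOST_SET_VRING_NUM/VHOST_SET_VRING_ADDR are
   split into helpers serialized only on the vq mutex.
 - vhost_scsi: request parsing is factored into vhost_scsi_ctx
   helpers, the control queue (TMF and asynchronous notification
   requests) is now handled, and the fabric ops switch to
   .fabric_name.
 - vhost_test gains the missing vhost_dev_stop() calls on release
   and reset, plus adaptations to newer core APIs: two-argument
   access_ok(), get_user_pages_fast() with FOLL_WRITE, and
   synchronize_rcu().

Out-of-tree callers of the changed vhost APIs need updating along
these lines (illustrative sketch only; the MY_* weight constants are
placeholders, not names from this patch):

	vhost_dev_init(dev, vqs, nvqs, UIO_MAXIOV,
		       MY_PKT_WEIGHT, MY_BYTE_WEIGHT);

	if (!vq_meta_prefetch(vq))	/* was vq_iotlb_prefetch(vq) */
		goto out;

	/* vhost_log_write() now also takes the iovec that was written */
	vhost_log_write(vq, vq_log, log, len, vq->iov, in);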

Change-Id: Idf6911045d9d382da2cfe01b1edff026404ac8fd
diff --git a/drivers/vhost/Kconfig b/drivers/vhost/Kconfig
index b580885..3d03ccb 100644
--- a/drivers/vhost/Kconfig
+++ b/drivers/vhost/Kconfig
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: GPL-2.0-only
 config VHOST_NET
 	tristate "Host kernel accelerator for virtio net"
 	depends on NET && EVENTFD && (TUN || !TUN) && (TAP || !TAP)
diff --git a/drivers/vhost/Kconfig.vringh b/drivers/vhost/Kconfig.vringh
index 6a4490c..c1fe36a 100644
--- a/drivers/vhost/Kconfig.vringh
+++ b/drivers/vhost/Kconfig.vringh
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: GPL-2.0-only
 config VHOST_RING
 	tristate
 	---help---
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 4e656f8..1a2dd53 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -1,8 +1,7 @@
+// SPDX-License-Identifier: GPL-2.0-only
 /* Copyright (C) 2009 Red Hat, Inc.
  * Author: Michael S. Tsirkin <mst@redhat.com>
  *
- * This work is licensed under the terms of the GNU GPL, version 2.
- *
  * virtio-net server in host kernel.
  */
 
@@ -36,7 +35,7 @@
 
 #include "vhost.h"
 
-static int experimental_zcopytx = 1;
+static int experimental_zcopytx = 0;
 module_param(experimental_zcopytx, int, 0444);
 MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
 		                       " 1 -Enable; 0 - Disable");
@@ -116,6 +115,8 @@
 	 * For RX, number of batched heads
 	 */
 	int done_idx;
+	/* Number of XDP frames batched */
+	int batched_xdp;
 	/* an array of userspace buffers info */
 	struct ubuf_info *ubuf_info;
 	/* Reference counting for outstanding ubufs.
@@ -123,6 +124,8 @@
 	struct vhost_net_ubuf_ref *ubufs;
 	struct ptr_ring *rx_ring;
 	struct vhost_net_buf rxq;
+	/* Batched XDP buffs */
+	struct xdp_buff *xdp;
 };
 
 struct vhost_net {
@@ -137,6 +140,10 @@
 	unsigned tx_zcopy_err;
 	/* Flush in progress. Protected by tx vq lock. */
 	bool tx_flush;
+	/* Private page frag */
+	struct page_frag page_frag;
+	/* Refcount bias of page frag */
+	int refcnt_bias;
 };
 
 static unsigned vhost_net_zcopy_mask __read_mostly;
@@ -338,6 +345,11 @@
 		sock_flag(sock->sk, SOCK_ZEROCOPY);
 }
 
+static bool vhost_sock_xdp(struct socket *sock)
+{
+	return sock_flag(sock->sk, SOCK_XDP);
+}
+
 /* In case of DMA done not in order in lower device driver for some reason.
  * upend_idx is used to track end of used idx, done_idx is used to track head
  * of used idx. Once lower device DMA done contiguously, we will signal KVM
@@ -444,32 +456,126 @@
 	nvq->done_idx = 0;
 }
 
-static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
-				    struct vhost_net_virtqueue *nvq,
-				    unsigned int *out_num, unsigned int *in_num,
-				    bool *busyloop_intr)
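+/* Flush the batched XDP buffers to the underlying socket with a single
+ * sendmsg() carrying a TUN_MSG_PTR control block, then signal the used
+ * ring for all completed descriptors.
+ */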
+static void vhost_tx_batch(struct vhost_net *net,
+			   struct vhost_net_virtqueue *nvq,
+			   struct socket *sock,
+			   struct msghdr *msghdr)
 {
-	struct vhost_virtqueue *vq = &nvq->vq;
-	unsigned long uninitialized_var(endtime);
-	int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
+	struct tun_msg_ctl ctl = {
+		.type = TUN_MSG_PTR,
+		.num = nvq->batched_xdp,
+		.ptr = nvq->xdp,
+	};
+	int err;
+
+	if (nvq->batched_xdp == 0)
+		goto signal_used;
+
+	msghdr->msg_control = &ctl;
+	err = sock->ops->sendmsg(sock, msghdr, 0);
+	if (unlikely(err < 0)) {
+		vq_err(&nvq->vq, "Failed to batch send packets\n");
+		return;
+	}
+
+signal_used:
+	vhost_net_signal_used(nvq);
+	nvq->batched_xdp = 0;
+}
+
+static int sock_has_rx_data(struct socket *sock)
+{
+	if (unlikely(!sock))
+		return 0;
+
+	if (sock->ops->peek_len)
+		return sock->ops->peek_len(sock);
+
+	return skb_queue_empty(&sock->sk->sk_receive_queue);
+}
+
+static void vhost_net_busy_poll_try_queue(struct vhost_net *net,
+					  struct vhost_virtqueue *vq)
+{
+	if (!vhost_vq_avail_empty(&net->dev, vq)) {
+		vhost_poll_queue(&vq->poll);
+	} else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
+		vhost_disable_notify(&net->dev, vq);
+		vhost_poll_queue(&vq->poll);
+	}
+}
+
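+/* Busy poll the paired virtqueue: on the RX path (poll_rx == true) this
+ * polls the TX virtqueue and vice versa. The loop bails out early when
+ * other vhost work is pending and reports that via busyloop_intr.
+ */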
+static void vhost_net_busy_poll(struct vhost_net *net,
+				struct vhost_virtqueue *rvq,
+				struct vhost_virtqueue *tvq,
+				bool *busyloop_intr,
+				bool poll_rx)
+{
+	unsigned long busyloop_timeout;
+	unsigned long endtime;
+	struct socket *sock;
+	struct vhost_virtqueue *vq = poll_rx ? tvq : rvq;
+
+	/* Try to hold the vq mutex of the paired virtqueue. We can't
+	 * use mutex_lock() here since we cannot guarantee a
+	 * consistent lock ordering.
+	 */
+	if (!mutex_trylock(&vq->mutex))
+		return;
+
+	vhost_disable_notify(&net->dev, vq);
+	sock = rvq->private_data;
+
+	busyloop_timeout = poll_rx ? rvq->busyloop_timeout:
+				     tvq->busyloop_timeout;
+
+	preempt_disable();
+	endtime = busy_clock() + busyloop_timeout;
+
+	while (vhost_can_busy_poll(endtime)) {
+		if (vhost_has_work(&net->dev)) {
+			*busyloop_intr = true;
+			break;
+		}
+
+		if ((sock_has_rx_data(sock) &&
+		     !vhost_vq_avail_empty(&net->dev, rvq)) ||
+		    !vhost_vq_avail_empty(&net->dev, tvq))
+			break;
+
+		cpu_relax();
+	}
+
+	preempt_enable();
+
+	if (poll_rx || sock_has_rx_data(sock))
+		vhost_net_busy_poll_try_queue(net, vq);
+	else if (!poll_rx) /* On tx here, sock has no rx data. */
+		vhost_enable_notify(&net->dev, rvq);
+
+	mutex_unlock(&vq->mutex);
+}
+
+static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
+				    struct vhost_net_virtqueue *tnvq,
+				    unsigned int *out_num, unsigned int *in_num,
+				    struct msghdr *msghdr, bool *busyloop_intr)
+{
+	struct vhost_net_virtqueue *rnvq = &net->vqs[VHOST_NET_VQ_RX];
+	struct vhost_virtqueue *rvq = &rnvq->vq;
+	struct vhost_virtqueue *tvq = &tnvq->vq;
+
+	int r = vhost_get_vq_desc(tvq, tvq->iov, ARRAY_SIZE(tvq->iov),
 				  out_num, in_num, NULL, NULL);
 
-	if (r == vq->num && vq->busyloop_timeout) {
-		if (!vhost_sock_zcopy(vq->private_data))
-			vhost_net_signal_used(nvq);
-		preempt_disable();
-		endtime = busy_clock() + vq->busyloop_timeout;
-		while (vhost_can_busy_poll(endtime)) {
-			if (vhost_has_work(vq->dev)) {
-				*busyloop_intr = true;
-				break;
-			}
-			if (!vhost_vq_avail_empty(vq->dev, vq))
-				break;
-			cpu_relax();
-		}
-		preempt_enable();
-		r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
+	if (r == tvq->num && tvq->busyloop_timeout) {
+		/* Flush batched packets first */
+		if (!vhost_sock_zcopy(tvq->private_data))
+			vhost_tx_batch(net, tnvq, tvq->private_data, msghdr);
+
+		vhost_net_busy_poll(net, rvq, tvq, busyloop_intr, false);
+
+		r = vhost_get_vq_desc(tvq, tvq->iov, ARRAY_SIZE(tvq->iov),
 				      out_num, in_num, NULL, NULL);
 	}
 
@@ -497,12 +603,6 @@
 	return iov_iter_count(iter);
 }
 
-static bool vhost_exceeds_weight(int pkts, int total_len)
-{
-	return total_len >= VHOST_NET_WEIGHT ||
-	       pkts >= VHOST_NET_PKT_WEIGHT;
-}
-
 static int get_tx_bufs(struct vhost_net *net,
 		       struct vhost_net_virtqueue *nvq,
 		       struct msghdr *msg,
@@ -512,7 +612,7 @@
 	struct vhost_virtqueue *vq = &nvq->vq;
 	int ret;
 
-	ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, busyloop_intr);
+	ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, msg, busyloop_intr);
 
 	if (ret < 0 || ret == vq->num)
 		return ret;
@@ -540,6 +640,120 @@
 	       !vhost_vq_avail_empty(vq->dev, vq);
 }
 
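+/* Prefer 32KB page-frag chunks for copied TX packets, falling back to
+ * single pages when high-order allocation fails.
+ */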
+#define SKB_FRAG_PAGE_ORDER     get_order(32768)
+
+static bool vhost_net_page_frag_refill(struct vhost_net *net, unsigned int sz,
+				       struct page_frag *pfrag, gfp_t gfp)
+{
+	if (pfrag->page) {
+		if (pfrag->offset + sz <= pfrag->size)
+			return true;
+		__page_frag_cache_drain(pfrag->page, net->refcnt_bias);
+	}
+
+	pfrag->offset = 0;
+	net->refcnt_bias = 0;
+	if (SKB_FRAG_PAGE_ORDER) {
+		/* Avoid direct reclaim but allow kswapd to wake */
+		pfrag->page = alloc_pages((gfp & ~__GFP_DIRECT_RECLAIM) |
+					  __GFP_COMP | __GFP_NOWARN |
+					  __GFP_NORETRY,
+					  SKB_FRAG_PAGE_ORDER);
+		if (likely(pfrag->page)) {
+			pfrag->size = PAGE_SIZE << SKB_FRAG_PAGE_ORDER;
+			goto done;
+		}
+	}
+	pfrag->page = alloc_page(gfp);
+	if (likely(pfrag->page)) {
+		pfrag->size = PAGE_SIZE;
+		goto done;
+	}
+	return false;
+
+done:
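+	/* Pre-charge the page with a large refcount bias so that each
+	 * consumed fragment costs a plain decrement instead of an atomic
+	 * page_ref operation; the unused bias is dropped through
+	 * __page_frag_cache_drain() on refill or device release.
+	 */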
+	net->refcnt_bias = USHRT_MAX;
+	page_ref_add(pfrag->page, USHRT_MAX - 1);
+	return true;
+}
+
+#define VHOST_NET_RX_PAD (NET_IP_ALIGN + NET_SKB_PAD)
+
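+/* Copy one guest TX packet into the private page frag and queue it as an
+ * xdp_buff in nvq->xdp[]. Returns -ENOSPC when the packet cannot fit in
+ * a single page, in which case the caller falls back to the regular
+ * sendmsg() path.
+ */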
+static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
+			       struct iov_iter *from)
+{
+	struct vhost_virtqueue *vq = &nvq->vq;
+	struct vhost_net *net = container_of(vq->dev, struct vhost_net,
+					     dev);
+	struct socket *sock = vq->private_data;
+	struct page_frag *alloc_frag = &net->page_frag;
+	struct virtio_net_hdr *gso;
+	struct xdp_buff *xdp = &nvq->xdp[nvq->batched_xdp];
+	struct tun_xdp_hdr *hdr;
+	size_t len = iov_iter_count(from);
+	int headroom = vhost_sock_xdp(sock) ? XDP_PACKET_HEADROOM : 0;
+	int buflen = SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
+	int pad = SKB_DATA_ALIGN(VHOST_NET_RX_PAD + headroom + nvq->sock_hlen);
+	int sock_hlen = nvq->sock_hlen;
+	void *buf;
+	int copied;
+
+	if (unlikely(len < nvq->sock_hlen))
+		return -EFAULT;
+
+	if (SKB_DATA_ALIGN(len + pad) +
+	    SKB_DATA_ALIGN(sizeof(struct skb_shared_info)) > PAGE_SIZE)
+		return -ENOSPC;
+
+	buflen += SKB_DATA_ALIGN(len + pad);
+	alloc_frag->offset = ALIGN((u64)alloc_frag->offset, SMP_CACHE_BYTES);
+	if (unlikely(!vhost_net_page_frag_refill(net, buflen,
+						 alloc_frag, GFP_KERNEL)))
+		return -ENOMEM;
+
+	buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
+	copied = copy_page_from_iter(alloc_frag->page,
+				     alloc_frag->offset +
+				     offsetof(struct tun_xdp_hdr, gso),
+				     sock_hlen, from);
+	if (copied != sock_hlen)
+		return -EFAULT;
+
+	hdr = buf;
+	gso = &hdr->gso;
+
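+	/* If checksum offload is requested, make sure hdr_len covers the
+	 * whole checksum field (csum_start + csum_offset + 2 bytes),
+	 * growing it if necessary and rejecting packets whose adjusted
+	 * header would exceed the packet length.
+	 */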
+	if ((gso->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) &&
+	    vhost16_to_cpu(vq, gso->csum_start) +
+	    vhost16_to_cpu(vq, gso->csum_offset) + 2 >
+	    vhost16_to_cpu(vq, gso->hdr_len)) {
+		gso->hdr_len = cpu_to_vhost16(vq,
+			       vhost16_to_cpu(vq, gso->csum_start) +
+			       vhost16_to_cpu(vq, gso->csum_offset) + 2);
+
+		if (vhost16_to_cpu(vq, gso->hdr_len) > len)
+			return -EINVAL;
+	}
+
+	len -= sock_hlen;
+	copied = copy_page_from_iter(alloc_frag->page,
+				     alloc_frag->offset + pad,
+				     len, from);
+	if (copied != len)
+		return -EFAULT;
+
+	xdp->data_hard_start = buf;
+	xdp->data = buf + pad;
+	xdp->data_end = xdp->data + len;
+	hdr->buflen = buflen;
+
+	--net->refcnt_bias;
+	alloc_frag->offset += buflen;
+
+	++nvq->batched_xdp;
+
+	return 0;
+}
+
 static void handle_tx_copy(struct vhost_net *net, struct socket *sock)
 {
 	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
@@ -556,10 +770,14 @@
 	size_t len, total_len = 0;
 	int err;
 	int sent_pkts = 0;
+	bool sock_can_batch = (sock->sk->sk_sndbuf == INT_MAX);
 
-	for (;;) {
+	do {
 		bool busyloop_intr = false;
 
+		if (nvq->done_idx == VHOST_NET_BATCH)
+			vhost_tx_batch(net, nvq, sock, &msg);
+
 		head = get_tx_bufs(net, nvq, &msg, &out, &in, &len,
 				   &busyloop_intr);
 		/* On error, stop handling until the next kick. */
@@ -577,14 +795,34 @@
 			break;
 		}
 
-		vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head);
-		vq->heads[nvq->done_idx].len = 0;
-
 		total_len += len;
-		if (tx_can_batch(vq, total_len))
-			msg.msg_flags |= MSG_MORE;
-		else
-			msg.msg_flags &= ~MSG_MORE;
+
+		/* For simplicity, TX batching is only enabled if
+		 * sndbuf is unlimited.
+		 */
+		if (sock_can_batch) {
+			err = vhost_net_build_xdp(nvq, &msg.msg_iter);
+			if (!err) {
+				goto done;
+			} else if (unlikely(err != -ENOSPC)) {
+				vhost_tx_batch(net, nvq, sock, &msg);
+				vhost_discard_vq_desc(vq, 1);
+				vhost_net_enable_vq(net, vq);
+				break;
+			}
+
+			/* We can't build XDP buff, go for single
+			 * packet path but let's flush batched
+			 * packets.
+			 */
+			vhost_tx_batch(net, nvq, sock, &msg);
+			msg.msg_control = NULL;
+		} else {
+			if (tx_can_batch(vq, total_len))
+				msg.msg_flags |= MSG_MORE;
+			else
+				msg.msg_flags &= ~MSG_MORE;
+		}
 
 		/* TODO: Check specific error and bomb out unless ENOBUFS? */
 		err = sock->ops->sendmsg(sock, &msg, len);
@@ -596,15 +834,13 @@
 		if (err != len)
 			pr_debug("Truncated TX packet: len %d != %zd\n",
 				 err, len);
-		if (++nvq->done_idx >= VHOST_NET_BATCH)
-			vhost_net_signal_used(nvq);
-		if (vhost_exceeds_weight(++sent_pkts, total_len)) {
-			vhost_poll_queue(&vq->poll);
-			break;
-		}
-	}
+done:
+		vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head);
+		vq->heads[nvq->done_idx].len = 0;
+		++nvq->done_idx;
+	} while (likely(!vhost_exceeds_weight(vq, ++sent_pkts, total_len)));
 
-	vhost_net_signal_used(nvq);
+	vhost_tx_batch(net, nvq, sock, &msg);
 }
 
 static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock)
@@ -620,13 +856,14 @@
 		.msg_controllen = 0,
 		.msg_flags = MSG_DONTWAIT,
 	};
+	struct tun_msg_ctl ctl;
 	size_t len, total_len = 0;
 	int err;
 	struct vhost_net_ubuf_ref *uninitialized_var(ubufs);
 	bool zcopy_used;
 	int sent_pkts = 0;
 
-	for (;;) {
+	do {
 		bool busyloop_intr;
 
 		/* Release DMAs done buffers first */
@@ -664,8 +901,10 @@
 			ubuf->ctx = nvq->ubufs;
 			ubuf->desc = nvq->upend_idx;
 			refcount_set(&ubuf->refcnt, 1);
-			msg.msg_control = ubuf;
-			msg.msg_controllen = sizeof(ubuf);
+			msg.msg_control = &ctl;
+			ctl.type = TUN_MSG_UBUF;
+			ctl.ptr = ubuf;
+			msg.msg_controllen = sizeof(ctl);
 			ubufs = nvq->ubufs;
 			atomic_inc(&ubufs->refcount);
 			nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
@@ -701,11 +940,7 @@
 		else
 			vhost_zerocopy_signal_used(net, vq);
 		vhost_net_tx_packet(net);
-		if (unlikely(vhost_exceeds_weight(++sent_pkts, total_len))) {
-			vhost_poll_queue(&vq->poll);
-			break;
-		}
-	}
+	} while (likely(!vhost_exceeds_weight(vq, ++sent_pkts, total_len)));
 }
 
 /* Expects to be always run from workqueue - which acts as
@@ -716,12 +951,12 @@
 	struct vhost_virtqueue *vq = &nvq->vq;
 	struct socket *sock;
 
-	mutex_lock(&vq->mutex);
+	mutex_lock_nested(&vq->mutex, VHOST_NET_VQ_TX);
 	sock = vq->private_data;
 	if (!sock)
 		goto out;
 
-	if (!vq_iotlb_prefetch(vq))
+	if (!vq_meta_prefetch(vq))
 		goto out;
 
 	vhost_disable_notify(&net->dev, vq);
@@ -757,16 +992,6 @@
 	return len;
 }
 
-static int sk_has_rx_data(struct sock *sk)
-{
-	struct socket *sock = sk->sk_socket;
-
-	if (sock->ops->peek_len)
-		return sock->ops->peek_len(sock);
-
-	return skb_queue_empty(&sk->sk_receive_queue);
-}
-
 static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk,
 				      bool *busyloop_intr)
 {
@@ -774,41 +999,13 @@
 	struct vhost_net_virtqueue *tnvq = &net->vqs[VHOST_NET_VQ_TX];
 	struct vhost_virtqueue *rvq = &rnvq->vq;
 	struct vhost_virtqueue *tvq = &tnvq->vq;
-	unsigned long uninitialized_var(endtime);
 	int len = peek_head_len(rnvq, sk);
 
-	if (!len && tvq->busyloop_timeout) {
+	if (!len && rvq->busyloop_timeout) {
 		/* Flush batched heads first */
 		vhost_net_signal_used(rnvq);
 		/* Both tx vq and rx socket were polled here */
-		mutex_lock_nested(&tvq->mutex, 1);
-		vhost_disable_notify(&net->dev, tvq);
-
-		preempt_disable();
-		endtime = busy_clock() + tvq->busyloop_timeout;
-
-		while (vhost_can_busy_poll(endtime)) {
-			if (vhost_has_work(&net->dev)) {
-				*busyloop_intr = true;
-				break;
-			}
-			if ((sk_has_rx_data(sk) &&
-			     !vhost_vq_avail_empty(&net->dev, rvq)) ||
-			    !vhost_vq_avail_empty(&net->dev, tvq))
-				break;
-			cpu_relax();
-		}
-
-		preempt_enable();
-
-		if (!vhost_vq_avail_empty(&net->dev, tvq)) {
-			vhost_poll_queue(&tvq->poll);
-		} else if (unlikely(vhost_enable_notify(&net->dev, tvq))) {
-			vhost_disable_notify(&net->dev, tvq);
-			vhost_poll_queue(&tvq->poll);
-		}
-
-		mutex_unlock(&tvq->mutex);
+		vhost_net_busy_poll(net, rvq, tvq, busyloop_intr, true);
 
 		len = peek_head_len(rnvq, sk);
 	}
@@ -923,12 +1120,12 @@
 	__virtio16 num_buffers;
 	int recv_pkts = 0;
 
-	mutex_lock_nested(&vq->mutex, 0);
+	mutex_lock_nested(&vq->mutex, VHOST_NET_VQ_RX);
 	sock = vq->private_data;
 	if (!sock)
 		goto out;
 
-	if (!vq_iotlb_prefetch(vq))
+	if (!vq_meta_prefetch(vq))
 		goto out;
 
 	vhost_disable_notify(&net->dev, vq);
@@ -941,8 +1138,11 @@
 		vq->log : NULL;
 	mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
 
-	while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
-						      &busyloop_intr))) {
+	do {
+		sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
+						      &busyloop_intr);
+		if (!sock_len)
+			break;
 		sock_len += sock_hlen;
 		vhost_len = sock_len + vhost_hlen;
 		headcount = get_rx_bufs(vq, vq->heads + nvq->done_idx,
@@ -1024,16 +1224,14 @@
 		if (nvq->done_idx > VHOST_NET_BATCH)
 			vhost_net_signal_used(nvq);
 		if (unlikely(vq_log))
-			vhost_log_write(vq, vq_log, log, vhost_len);
+			vhost_log_write(vq, vq_log, log, vhost_len,
+					vq->iov, in);
 		total_len += vhost_len;
-		if (unlikely(vhost_exceeds_weight(++recv_pkts, total_len))) {
-			vhost_poll_queue(&vq->poll);
-			goto out;
-		}
-	}
+	} while (likely(!vhost_exceeds_weight(vq, ++recv_pkts, total_len)));
+
 	if (unlikely(busyloop_intr))
 		vhost_poll_queue(&vq->poll);
-	else
+	else if (!sock_len)
 		vhost_net_enable_vq(net, vq);
 out:
 	vhost_net_signal_used(nvq);
@@ -1078,6 +1276,7 @@
 	struct vhost_dev *dev;
 	struct vhost_virtqueue **vqs;
 	void **queue;
+	struct xdp_buff *xdp;
 	int i;
 
 	n = kvmalloc(sizeof *n, GFP_KERNEL | __GFP_RETRY_MAYFAIL);
@@ -1098,6 +1297,15 @@
 	}
 	n->vqs[VHOST_NET_VQ_RX].rxq.queue = queue;
 
+	xdp = kmalloc_array(VHOST_NET_BATCH, sizeof(*xdp), GFP_KERNEL);
+	if (!xdp) {
+		kfree(vqs);
+		kvfree(n);
+		kfree(queue);
+		return -ENOMEM;
+	}
+	n->vqs[VHOST_NET_VQ_TX].xdp = xdp;
+
 	dev = &n->dev;
 	vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
 	vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
@@ -1108,17 +1316,22 @@
 		n->vqs[i].ubuf_info = NULL;
 		n->vqs[i].upend_idx = 0;
 		n->vqs[i].done_idx = 0;
+		n->vqs[i].batched_xdp = 0;
 		n->vqs[i].vhost_hlen = 0;
 		n->vqs[i].sock_hlen = 0;
 		n->vqs[i].rx_ring = NULL;
 		vhost_net_buf_init(&n->vqs[i].rxq);
 	}
-	vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
+	vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX,
+		       UIO_MAXIOV + VHOST_NET_BATCH,
+		       VHOST_NET_PKT_WEIGHT, VHOST_NET_WEIGHT);
 
 	vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, EPOLLOUT, dev);
 	vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev);
 
 	f->private_data = n;
+	n->page_frag.page = NULL;
+	n->refcnt_bias = 0;
 
 	return 0;
 }
@@ -1186,12 +1399,15 @@
 	if (rx_sock)
 		sockfd_put(rx_sock);
 	/* Make sure no callbacks are outstanding */
-	synchronize_rcu_bh();
+	synchronize_rcu();
 	/* We do an extra flush before freeing memory,
 	 * since jobs can re-queue themselves. */
 	vhost_net_flush(n);
 	kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue);
+	kfree(n->vqs[VHOST_NET_VQ_TX].xdp);
 	kfree(n->dev.vqs);
+	if (n->page_frag.page)
+		__page_frag_cache_drain(n->page_frag.page, n->refcnt_bias);
 	kvfree(n);
 	return 0;
 }
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index e7e3ae1..a9caf1b 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -57,6 +57,12 @@
 #define VHOST_SCSI_PREALLOC_UPAGES 2048
 #define VHOST_SCSI_PREALLOC_PROT_SGLS 2048
 
+/* Max number of requests before requeueing the job.
+ * Using this limit prevents one virtqueue from starving others with
+ * requests.
+ */
+#define VHOST_SCSI_WEIGHT 256
+
 struct vhost_scsi_inflight {
 	/* Wait for the flush operation to finish */
 	struct completion comp;
@@ -203,6 +209,19 @@
 	int vs_events_nr; /* num of pending events, protected by vq->mutex */
 };
 
+/*
+ * Context for processing request and control queue operations.
+ */
+struct vhost_scsi_ctx {
+	int head;
+	unsigned int out, in;
+	size_t req_size, rsp_size;
+	size_t out_size, in_size;
+	u8 *target, *lunp;
+	void *req;
+	struct iov_iter out_iter;
+};
+
 static struct workqueue_struct *vhost_scsi_workqueue;
 
 /* Global spinlock to protect vhost_scsi TPG list for vhost IOCTL access */
@@ -272,11 +291,6 @@
 	return 0;
 }
 
-static char *vhost_scsi_get_fabric_name(void)
-{
-	return "vhost";
-}
-
 static char *vhost_scsi_get_fabric_wwn(struct se_portal_group *se_tpg)
 {
 	struct vhost_scsi_tpg *tpg = container_of(se_tpg,
@@ -338,11 +352,6 @@
 	return 0;
 }
 
-static int vhost_scsi_write_pending_status(struct se_cmd *se_cmd)
-{
-	return 0;
-}
-
 static void vhost_scsi_set_default_node_attrs(struct se_node_acl *nacl)
 {
 	return;
@@ -800,24 +809,120 @@
 		pr_err("Faulted on virtio_scsi_cmd_resp\n");
 }
 
+static int
+vhost_scsi_get_desc(struct vhost_scsi *vs, struct vhost_virtqueue *vq,
+		    struct vhost_scsi_ctx *vc)
+{
+	int ret = -ENXIO;
+
+	vc->head = vhost_get_vq_desc(vq, vq->iov,
+				     ARRAY_SIZE(vq->iov), &vc->out, &vc->in,
+				     NULL, NULL);
+
+	pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n",
+		 vc->head, vc->out, vc->in);
+
+	/* On error, stop handling until the next kick. */
+	if (unlikely(vc->head < 0))
+		goto done;
+
+	/* Nothing new?  Wait for eventfd to tell us they refilled. */
+	if (vc->head == vq->num) {
+		if (unlikely(vhost_enable_notify(&vs->dev, vq))) {
+			vhost_disable_notify(&vs->dev, vq);
+			ret = -EAGAIN;
+		}
+		goto done;
+	}
+
+	/*
+	 * Get the size of request and response buffers.
+	 * FIXME: Not correct for BIDI operation
+	 */
+	vc->out_size = iov_length(vq->iov, vc->out);
+	vc->in_size = iov_length(&vq->iov[vc->out], vc->in);
+
+	/*
+	 * Copy over the virtio-scsi request header, which for a
+	 * ANY_LAYOUT enabled guest may span multiple iovecs, or a
+	 * single iovec may contain both the header + outgoing
+	 * WRITE payloads.
+	 *
+	 * copy_from_iter() will advance out_iter, so that it will
+	 * point at the start of the outgoing WRITE payload, if
+	 * DMA_TO_DEVICE is set.
+	 */
+	iov_iter_init(&vc->out_iter, WRITE, vq->iov, vc->out, vc->out_size);
+	ret = 0;
+
+done:
+	return ret;
+}
+
+static int
+vhost_scsi_chk_size(struct vhost_virtqueue *vq, struct vhost_scsi_ctx *vc)
+{
+	if (unlikely(vc->in_size < vc->rsp_size)) {
+		vq_err(vq,
+		       "Response buf too small, need min %zu bytes got %zu",
+		       vc->rsp_size, vc->in_size);
+		return -EINVAL;
+	} else if (unlikely(vc->out_size < vc->req_size)) {
+		vq_err(vq,
+		       "Request buf too small, need min %zu bytes got %zu",
+		       vc->req_size, vc->out_size);
+		return -EIO;
+	}
+
+	return 0;
+}
+
+static int
+vhost_scsi_get_req(struct vhost_virtqueue *vq, struct vhost_scsi_ctx *vc,
+		   struct vhost_scsi_tpg **tpgp)
+{
+	int ret = -EIO;
+
+	if (unlikely(!copy_from_iter_full(vc->req, vc->req_size,
+					  &vc->out_iter))) {
+		vq_err(vq, "Faulted on copy_from_iter_full\n");
+	} else if (unlikely(*vc->lunp != 1)) {
+		/* virtio-scsi spec requires byte 0 of the lun to be 1 */
+		vq_err(vq, "Illegal virtio-scsi lun: %u\n", *vc->lunp);
+	} else {
+		struct vhost_scsi_tpg **vs_tpg, *tpg;
+
+		vs_tpg = vq->private_data;	/* validated at handler entry */
+
+		tpg = READ_ONCE(vs_tpg[*vc->target]);
+		if (unlikely(!tpg)) {
+			vq_err(vq, "Target 0x%x does not exist\n", *vc->target);
+		} else {
+			if (tpgp)
+				*tpgp = tpg;
+			ret = 0;
+		}
+	}
+
+	return ret;
+}
+
 static void
 vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
 {
 	struct vhost_scsi_tpg **vs_tpg, *tpg;
 	struct virtio_scsi_cmd_req v_req;
 	struct virtio_scsi_cmd_req_pi v_req_pi;
+	struct vhost_scsi_ctx vc;
 	struct vhost_scsi_cmd *cmd;
-	struct iov_iter out_iter, in_iter, prot_iter, data_iter;
+	struct iov_iter in_iter, prot_iter, data_iter;
 	u64 tag;
 	u32 exp_data_len, data_direction;
-	unsigned int out = 0, in = 0;
-	int head, ret, prot_bytes;
-	size_t req_size, rsp_size = sizeof(struct virtio_scsi_cmd_resp);
-	size_t out_size, in_size;
+	int ret, prot_bytes, c = 0;
 	u16 lun;
-	u8 *target, *lunp, task_attr;
+	u8 task_attr;
 	bool t10_pi = vhost_has_feature(vq, VIRTIO_SCSI_F_T10_PI);
-	void *req, *cdb;
+	void *cdb;
 
 	mutex_lock(&vq->mutex);
 	/*
@@ -828,85 +933,47 @@
 	if (!vs_tpg)
 		goto out;
 
+	memset(&vc, 0, sizeof(vc));
+	vc.rsp_size = sizeof(struct virtio_scsi_cmd_resp);
+
 	vhost_disable_notify(&vs->dev, vq);
 
-	for (;;) {
-		head = vhost_get_vq_desc(vq, vq->iov,
-					 ARRAY_SIZE(vq->iov), &out, &in,
-					 NULL, NULL);
-		pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n",
-			 head, out, in);
-		/* On error, stop handling until the next kick. */
-		if (unlikely(head < 0))
-			break;
-		/* Nothing new?  Wait for eventfd to tell us they refilled. */
-		if (head == vq->num) {
-			if (unlikely(vhost_enable_notify(&vs->dev, vq))) {
-				vhost_disable_notify(&vs->dev, vq);
-				continue;
-			}
-			break;
-		}
-		/*
-		 * Check for a sane response buffer so we can report early
-		 * errors back to the guest.
-		 */
-		if (unlikely(vq->iov[out].iov_len < rsp_size)) {
-			vq_err(vq, "Expecting at least virtio_scsi_cmd_resp"
-				" size, got %zu bytes\n", vq->iov[out].iov_len);
-			break;
-		}
+	do {
+		ret = vhost_scsi_get_desc(vs, vq, &vc);
+		if (ret)
+			goto err;
+
 		/*
 		 * Setup pointers and values based upon different virtio-scsi
 		 * request header if T10_PI is enabled in KVM guest.
 		 */
 		if (t10_pi) {
-			req = &v_req_pi;
-			req_size = sizeof(v_req_pi);
-			lunp = &v_req_pi.lun[0];
-			target = &v_req_pi.lun[1];
+			vc.req = &v_req_pi;
+			vc.req_size = sizeof(v_req_pi);
+			vc.lunp = &v_req_pi.lun[0];
+			vc.target = &v_req_pi.lun[1];
 		} else {
-			req = &v_req;
-			req_size = sizeof(v_req);
-			lunp = &v_req.lun[0];
-			target = &v_req.lun[1];
+			vc.req = &v_req;
+			vc.req_size = sizeof(v_req);
+			vc.lunp = &v_req.lun[0];
+			vc.target = &v_req.lun[1];
 		}
-		/*
-		 * FIXME: Not correct for BIDI operation
-		 */
-		out_size = iov_length(vq->iov, out);
-		in_size = iov_length(&vq->iov[out], in);
 
 		/*
-		 * Copy over the virtio-scsi request header, which for a
-		 * ANY_LAYOUT enabled guest may span multiple iovecs, or a
-		 * single iovec may contain both the header + outgoing
-		 * WRITE payloads.
-		 *
-		 * copy_from_iter() will advance out_iter, so that it will
-		 * point at the start of the outgoing WRITE payload, if
-		 * DMA_TO_DEVICE is set.
+		 * Validate the size of request and response buffers.
+		 * Check for a sane response buffer so we can report
+		 * early errors back to the guest.
 		 */
-		iov_iter_init(&out_iter, WRITE, vq->iov, out, out_size);
+		ret = vhost_scsi_chk_size(vq, &vc);
+		if (ret)
+			goto err;
 
-		if (unlikely(!copy_from_iter_full(req, req_size, &out_iter))) {
-			vq_err(vq, "Faulted on copy_from_iter\n");
-			vhost_scsi_send_bad_target(vs, vq, head, out);
-			continue;
-		}
-		/* virtio-scsi spec requires byte 0 of the lun to be 1 */
-		if (unlikely(*lunp != 1)) {
-			vq_err(vq, "Illegal virtio-scsi lun: %u\n", *lunp);
-			vhost_scsi_send_bad_target(vs, vq, head, out);
-			continue;
-		}
+		ret = vhost_scsi_get_req(vq, &vc, &tpg);
+		if (ret)
+			goto err;
 
-		tpg = READ_ONCE(vs_tpg[*target]);
-		if (unlikely(!tpg)) {
-			/* Target does not exist, fail the request */
-			vhost_scsi_send_bad_target(vs, vq, head, out);
-			continue;
-		}
+		ret = -EIO;	/* bad target on any error from here on */
+
 		/*
 		 * Determine data_direction by calculating the total outgoing
 		 * iovec sizes + incoming iovec sizes vs. virtio-scsi request +
@@ -924,17 +991,17 @@
 		 */
 		prot_bytes = 0;
 
-		if (out_size > req_size) {
+		if (vc.out_size > vc.req_size) {
 			data_direction = DMA_TO_DEVICE;
-			exp_data_len = out_size - req_size;
-			data_iter = out_iter;
-		} else if (in_size > rsp_size) {
+			exp_data_len = vc.out_size - vc.req_size;
+			data_iter = vc.out_iter;
+		} else if (vc.in_size > vc.rsp_size) {
 			data_direction = DMA_FROM_DEVICE;
-			exp_data_len = in_size - rsp_size;
+			exp_data_len = vc.in_size - vc.rsp_size;
 
-			iov_iter_init(&in_iter, READ, &vq->iov[out], in,
-				      rsp_size + exp_data_len);
-			iov_iter_advance(&in_iter, rsp_size);
+			iov_iter_init(&in_iter, READ, &vq->iov[vc.out], vc.in,
+				      vc.rsp_size + exp_data_len);
+			iov_iter_advance(&in_iter, vc.rsp_size);
 			data_iter = in_iter;
 		} else {
 			data_direction = DMA_NONE;
@@ -950,16 +1017,14 @@
 				if (data_direction != DMA_TO_DEVICE) {
 					vq_err(vq, "Received non zero pi_bytesout,"
 						" but wrong data_direction\n");
-					vhost_scsi_send_bad_target(vs, vq, head, out);
-					continue;
+					goto err;
 				}
 				prot_bytes = vhost32_to_cpu(vq, v_req_pi.pi_bytesout);
 			} else if (v_req_pi.pi_bytesin) {
 				if (data_direction != DMA_FROM_DEVICE) {
 					vq_err(vq, "Received non zero pi_bytesin,"
 						" but wrong data_direction\n");
-					vhost_scsi_send_bad_target(vs, vq, head, out);
-					continue;
+					goto err;
 				}
 				prot_bytes = vhost32_to_cpu(vq, v_req_pi.pi_bytesin);
 			}
@@ -998,8 +1063,7 @@
 			vq_err(vq, "Received SCSI CDB with command_size: %d that"
 				" exceeds SCSI_MAX_VARLEN_CDB_SIZE: %d\n",
 				scsi_command_size(cdb), VHOST_SCSI_MAX_CDB_SIZE);
-			vhost_scsi_send_bad_target(vs, vq, head, out);
-			continue;
+			goto err;
 		}
 		cmd = vhost_scsi_get_tag(vq, tpg, cdb, tag, lun, task_attr,
 					 exp_data_len + prot_bytes,
@@ -1007,13 +1071,12 @@
 		if (IS_ERR(cmd)) {
 			vq_err(vq, "vhost_scsi_get_tag failed %ld\n",
 			       PTR_ERR(cmd));
-			vhost_scsi_send_bad_target(vs, vq, head, out);
-			continue;
+			goto err;
 		}
 		cmd->tvc_vhost = vs;
 		cmd->tvc_vq = vq;
-		cmd->tvc_resp_iov = vq->iov[out];
-		cmd->tvc_in_iovs = in;
+		cmd->tvc_resp_iov = vq->iov[vc.out];
+		cmd->tvc_in_iovs = vc.in;
 
 		pr_debug("vhost_scsi got command opcode: %#02x, lun: %d\n",
 			 cmd->tvc_cdb[0], cmd->tvc_lun);
@@ -1021,14 +1084,12 @@
 			 " %d\n", cmd, exp_data_len, prot_bytes, data_direction);
 
 		if (data_direction != DMA_NONE) {
-			ret = vhost_scsi_mapal(cmd,
-					       prot_bytes, &prot_iter,
-					       exp_data_len, &data_iter);
-			if (unlikely(ret)) {
+			if (unlikely(vhost_scsi_mapal(cmd, prot_bytes,
+						      &prot_iter, exp_data_len,
+						      &data_iter))) {
 				vq_err(vq, "Failed to map iov to sgl\n");
 				vhost_scsi_release_cmd(&cmd->tvc_se_cmd);
-				vhost_scsi_send_bad_target(vs, vq, head, out);
-				continue;
+				goto err;
 			}
 		}
 		/*
@@ -1036,7 +1097,7 @@
 		 * complete the virtio-scsi request in TCM callback context via
 		 * vhost_scsi_queue_data_in() and vhost_scsi_queue_status()
 		 */
-		cmd->tvc_vq_desc = head;
+		cmd->tvc_vq_desc = vc.head;
 		/*
 		 * Dispatch cmd descriptor for cmwq execution in process
 		 * context provided by vhost_scsi_workqueue.  This also ensures
@@ -1045,14 +1106,183 @@
 		 */
 		INIT_WORK(&cmd->work, vhost_scsi_submission_work);
 		queue_work(vhost_scsi_workqueue, &cmd->work);
-	}
+		ret = 0;
+err:
+		/*
+		 * ENXIO:  No more requests, or read error, wait for next kick
+		 * EINVAL: Invalid response buffer, drop the request
+		 * EIO:    Respond with bad target
+		 * EAGAIN: Pending request
+		 */
+		if (ret == -ENXIO)
+			break;
+		else if (ret == -EIO)
+			vhost_scsi_send_bad_target(vs, vq, vc.head, vc.out);
+	} while (likely(!vhost_exceeds_weight(vq, ++c, 0)));
+out:
+	mutex_unlock(&vq->mutex);
+}
+
+static void
+vhost_scsi_send_tmf_reject(struct vhost_scsi *vs,
+			   struct vhost_virtqueue *vq,
+			   struct vhost_scsi_ctx *vc)
+{
+	struct virtio_scsi_ctrl_tmf_resp rsp;
+	struct iov_iter iov_iter;
+	int ret;
+
+	pr_debug("%s\n", __func__);
+	memset(&rsp, 0, sizeof(rsp));
+	rsp.response = VIRTIO_SCSI_S_FUNCTION_REJECTED;
+
+	iov_iter_init(&iov_iter, READ, &vq->iov[vc->out], vc->in, sizeof(rsp));
+
+	ret = copy_to_iter(&rsp, sizeof(rsp), &iov_iter);
+	if (likely(ret == sizeof(rsp)))
+		vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0);
+	else
+		pr_err("Faulted on virtio_scsi_ctrl_tmf_resp\n");
+}
+
+static void
+vhost_scsi_send_an_resp(struct vhost_scsi *vs,
+			struct vhost_virtqueue *vq,
+			struct vhost_scsi_ctx *vc)
+{
+	struct virtio_scsi_ctrl_an_resp rsp;
+	struct iov_iter iov_iter;
+	int ret;
+
+	pr_debug("%s\n", __func__);
+	memset(&rsp, 0, sizeof(rsp));	/* event_actual = 0 */
+	rsp.response = VIRTIO_SCSI_S_OK;
+
+	iov_iter_init(&iov_iter, READ, &vq->iov[vc->out], vc->in, sizeof(rsp));
+
+	ret = copy_to_iter(&rsp, sizeof(rsp), &iov_iter);
+	if (likely(ret == sizeof(rsp)))
+		vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0);
+	else
+		pr_err("Faulted on virtio_scsi_ctrl_an_resp\n");
+}
+
+static void
+vhost_scsi_ctl_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
+{
+	union {
+		__virtio32 type;
+		struct virtio_scsi_ctrl_an_req an;
+		struct virtio_scsi_ctrl_tmf_req tmf;
+	} v_req;
+	struct vhost_scsi_ctx vc;
+	size_t typ_size;
+	int ret, c = 0;
+
+	mutex_lock(&vq->mutex);
+	/*
+	 * We can handle the vq only after the endpoint is setup by calling the
+	 * VHOST_SCSI_SET_ENDPOINT ioctl.
+	 */
+	if (!vq->private_data)
+		goto out;
+
+	memset(&vc, 0, sizeof(vc));
+
+	vhost_disable_notify(&vs->dev, vq);
+
+	do {
+		ret = vhost_scsi_get_desc(vs, vq, &vc);
+		if (ret)
+			goto err;
+
+		/*
+		 * Get the request type first in order to setup
+		 * other parameters dependent on the type.
+		 */
+		vc.req = &v_req.type;
+		typ_size = sizeof(v_req.type);
+
+		if (unlikely(!copy_from_iter_full(vc.req, typ_size,
+						  &vc.out_iter))) {
+			vq_err(vq, "Faulted on copy_from_iter tmf type\n");
+			/*
+			 * The size of the response buffer depends on the
+			 * request type and must be validated against it.
+			 * Since the request type is not known, don't send
+			 * a response.
+			 */
+			continue;
+		}
+
+		switch (v_req.type) {
+		case VIRTIO_SCSI_T_TMF:
+			vc.req = &v_req.tmf;
+			vc.req_size = sizeof(struct virtio_scsi_ctrl_tmf_req);
+			vc.rsp_size = sizeof(struct virtio_scsi_ctrl_tmf_resp);
+			vc.lunp = &v_req.tmf.lun[0];
+			vc.target = &v_req.tmf.lun[1];
+			break;
+		case VIRTIO_SCSI_T_AN_QUERY:
+		case VIRTIO_SCSI_T_AN_SUBSCRIBE:
+			vc.req = &v_req.an;
+			vc.req_size = sizeof(struct virtio_scsi_ctrl_an_req);
+			vc.rsp_size = sizeof(struct virtio_scsi_ctrl_an_resp);
+			vc.lunp = &v_req.an.lun[0];
+			vc.target = NULL;
+			break;
+		default:
+			vq_err(vq, "Unknown control request %d", v_req.type);
+			continue;
+		}
+
+		/*
+		 * Validate the size of request and response buffers.
+		 * Check for a sane response buffer so we can report
+		 * early errors back to the guest.
+		 */
+		ret = vhost_scsi_chk_size(vq, &vc);
+		if (ret)
+			goto err;
+
+		/*
+		 * Get the rest of the request now that its size is known.
+		 */
+		vc.req += typ_size;
+		vc.req_size -= typ_size;
+
+		ret = vhost_scsi_get_req(vq, &vc, NULL);
+		if (ret)
+			goto err;
+
+		if (v_req.type == VIRTIO_SCSI_T_TMF)
+			vhost_scsi_send_tmf_reject(vs, vq, &vc);
+		else
+			vhost_scsi_send_an_resp(vs, vq, &vc);
+err:
+		/*
+		 * ENXIO:  No more requests, or read error, wait for next kick
+		 * EINVAL: Invalid response buffer, drop the request
+		 * EIO:    Respond with bad target
+		 * EAGAIN: Pending request
+		 */
+		if (ret == -ENXIO)
+			break;
+		else if (ret == -EIO)
+			vhost_scsi_send_bad_target(vs, vq, vc.head, vc.out);
+	} while (likely(!vhost_exceeds_weight(vq, ++c, 0)));
 out:
 	mutex_unlock(&vq->mutex);
 }
 
 static void vhost_scsi_ctl_handle_kick(struct vhost_work *work)
 {
+	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
+						poll.work);
+	struct vhost_scsi *vs = container_of(vq->dev, struct vhost_scsi, dev);
+
 	pr_debug("%s: The handling func for control queue.\n", __func__);
+	vhost_scsi_ctl_handle_vq(vs, vq);
 }
 
 static void
@@ -1211,7 +1441,7 @@
 			se_tpg = &tpg->se_tpg;
 			ret = target_depend_item(&se_tpg->tpg_group.cg_item);
 			if (ret) {
-				pr_warn("configfs_depend_item() failed: %d\n", ret);
+				pr_warn("target_depend_item() failed: %d\n", ret);
 				kfree(vs_tpg);
 				mutex_unlock(&tpg->tv_tpg_mutex);
 				goto out;
@@ -1219,7 +1449,6 @@
 			tpg->tv_tpg_vhost_count++;
 			tpg->vhost_scsi = vs;
 			vs_tpg[tpg->tport_tpgt] = tpg;
-			smp_mb__after_atomic();
 			match = true;
 		}
 		mutex_unlock(&tpg->tv_tpg_mutex);
@@ -1398,7 +1627,8 @@
 		vqs[i] = &vs->vqs[i].vq;
 		vs->vqs[i].vq.handle_kick = vhost_scsi_handle_kick;
 	}
-	vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ);
+	vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ, UIO_MAXIOV,
+		       VHOST_SCSI_WEIGHT, 0);
 
 	vhost_scsi_init_inflight(vs, NULL);
 
@@ -2059,8 +2289,7 @@
 
 static const struct target_core_fabric_ops vhost_scsi_ops = {
 	.module				= THIS_MODULE,
-	.name				= "vhost",
-	.get_fabric_name		= vhost_scsi_get_fabric_name,
+	.fabric_name			= "vhost",
 	.tpg_get_wwn			= vhost_scsi_get_fabric_wwn,
 	.tpg_get_tag			= vhost_scsi_get_tpgt,
 	.tpg_check_demo_mode		= vhost_scsi_check_true,
@@ -2074,7 +2303,6 @@
 	.sess_get_index			= vhost_scsi_sess_get_index,
 	.sess_get_initiator_sid		= NULL,
 	.write_pending			= vhost_scsi_write_pending,
-	.write_pending_status		= vhost_scsi_write_pending_status,
 	.set_default_node_attributes	= vhost_scsi_set_default_node_attrs,
 	.get_cmd_state			= vhost_scsi_get_cmd_state,
 	.queue_data_in			= vhost_scsi_queue_data_in,
diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c
index 4058985..0563080 100644
--- a/drivers/vhost/test.c
+++ b/drivers/vhost/test.c
@@ -1,8 +1,7 @@
+// SPDX-License-Identifier: GPL-2.0-only
 /* Copyright (C) 2009 Red Hat, Inc.
  * Author: Michael S. Tsirkin <mst@redhat.com>
  *
- * This work is licensed under the terms of the GNU GPL, version 2.
- *
  * test virtio server in host kernel.
  */
 
@@ -23,6 +22,12 @@
  * Using this limit prevents one virtqueue from starving others. */
 #define VHOST_TEST_WEIGHT 0x80000
 
+/* Max number of packets transferred before requeueing the job.
+ * Using this limit prevents one virtqueue from starving others with
+ * packets.
+ */
+#define VHOST_TEST_PKT_WEIGHT 256
+
 enum {
 	VHOST_TEST_VQ = 0,
 	VHOST_TEST_VQ_MAX = 1,
@@ -81,10 +86,8 @@
 		}
 		vhost_add_used_and_signal(&n->dev, vq, head, 0);
 		total_len += len;
-		if (unlikely(total_len >= VHOST_TEST_WEIGHT)) {
-			vhost_poll_queue(&vq->poll);
+		if (unlikely(vhost_exceeds_weight(vq, 0, total_len)))
 			break;
-		}
 	}
 
 	mutex_unlock(&vq->mutex);
@@ -116,7 +119,8 @@
 	dev = &n->dev;
 	vqs[VHOST_TEST_VQ] = &n->vqs[VHOST_TEST_VQ];
 	n->vqs[VHOST_TEST_VQ].handle_kick = handle_vq_kick;
-	vhost_dev_init(dev, vqs, VHOST_TEST_VQ_MAX);
+	vhost_dev_init(dev, vqs, VHOST_TEST_VQ_MAX, UIO_MAXIOV,
+		       VHOST_TEST_PKT_WEIGHT, VHOST_TEST_WEIGHT);
 
 	f->private_data = n;
 
@@ -157,6 +161,7 @@
 
 	vhost_test_stop(n, &private);
 	vhost_test_flush(n);
+	vhost_dev_stop(&n->dev);
 	vhost_dev_cleanup(&n->dev);
 	/* We do an extra flush before freeing memory,
 	 * since jobs can re-queue themselves. */
@@ -233,6 +238,7 @@
 	}
 	vhost_test_stop(n, &priv);
 	vhost_test_flush(n);
+	vhost_dev_stop(&n->dev);
 	vhost_dev_reset_owner(&n->dev, umem);
 done:
 	mutex_unlock(&n->dev.mutex);
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index eb95daa..36ca2cf 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -1,3 +1,4 @@
+// SPDX-License-Identifier: GPL-2.0-only
 /* Copyright (C) 2009 Red Hat, Inc.
  * Copyright (C) 2006 Rusty Russell IBM Corporation
  *
@@ -6,8 +7,6 @@
  * Inspiration, some code, and most witty comments come from
  * Documentation/virtual/lguest/lguest.c, by Rusty Russell
  *
- * This work is licensed under the terms of the GNU GPL, version 2.
- *
  * Generic code for virtio server in host kernel.
  */
 
@@ -204,7 +203,6 @@
 int vhost_poll_start(struct vhost_poll *poll, struct file *file)
 {
 	__poll_t mask;
-	int ret = 0;
 
 	if (poll->wqh)
 		return 0;
@@ -214,10 +212,10 @@
 		vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask));
 	if (mask & EPOLLERR) {
 		vhost_poll_stop(poll);
-		ret = -EINVAL;
+		return -EINVAL;
 	}
 
-	return ret;
+	return 0;
 }
 EXPORT_SYMBOL_GPL(vhost_poll_start);
 
@@ -390,9 +388,9 @@
 		vq->indirect = kmalloc_array(UIO_MAXIOV,
 					     sizeof(*vq->indirect),
 					     GFP_KERNEL);
-		vq->log = kmalloc_array(UIO_MAXIOV, sizeof(*vq->log),
+		vq->log = kmalloc_array(dev->iov_limit, sizeof(*vq->log),
 					GFP_KERNEL);
-		vq->heads = kmalloc_array(UIO_MAXIOV, sizeof(*vq->heads),
+		vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
 					  GFP_KERNEL);
 		if (!vq->indirect || !vq->log || !vq->heads)
 			goto err_nomem;
@@ -413,8 +411,50 @@
 		vhost_vq_free_iovecs(dev->vqs[i]);
 }
 
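+/* Returns true and requeues the virtqueue poll once the handler has used
+ * up its per-device packet or byte budget, bounding how long one work
+ * item can monopolize the worker thread. A byte_weight of zero disables
+ * the byte-based limit.
+ */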
+bool vhost_exceeds_weight(struct vhost_virtqueue *vq,
+			  int pkts, int total_len)
+{
+	struct vhost_dev *dev = vq->dev;
+
+	if ((dev->byte_weight && total_len >= dev->byte_weight) ||
+	    pkts >= dev->weight) {
+		vhost_poll_queue(&vq->poll);
+		return true;
+	}
+
+	return false;
+}
+EXPORT_SYMBOL_GPL(vhost_exceeds_weight);
+
+static size_t vhost_get_avail_size(struct vhost_virtqueue *vq,
+				   unsigned int num)
+{
+	size_t event __maybe_unused =
+	       vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+
+	return sizeof(*vq->avail) +
+	       sizeof(*vq->avail->ring) * num + event;
+}
+
+static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
+				  unsigned int num)
+{
+	size_t event __maybe_unused =
+	       vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
+
+	return sizeof(*vq->used) +
+	       sizeof(*vq->used->ring) * num + event;
+}
+
+static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
+				  unsigned int num)
+{
+	return sizeof(*vq->desc) * num;
+}
+
 void vhost_dev_init(struct vhost_dev *dev,
-		    struct vhost_virtqueue **vqs, int nvqs)
+		    struct vhost_virtqueue **vqs, int nvqs,
+		    int iov_limit, int weight, int byte_weight)
 {
 	struct vhost_virtqueue *vq;
 	int i;
@@ -427,6 +467,9 @@
 	dev->iotlb = NULL;
 	dev->mm = NULL;
 	dev->worker = NULL;
+	dev->iov_limit = iov_limit;
+	dev->weight = weight;
+	dev->byte_weight = byte_weight;
 	init_llist_head(&dev->work_list);
 	init_waitqueue_head(&dev->wait);
 	INIT_LIST_HEAD(&dev->read_list);
@@ -655,7 +698,7 @@
 	    a + (unsigned long)log_base > ULONG_MAX)
 		return false;
 
-	return access_ok(VERIFY_WRITE, log_base + a,
+	return access_ok(log_base + a,
 			 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
 }
 
@@ -681,7 +724,7 @@
 			return false;
 
 
-		if (!access_ok(VERIFY_WRITE, (void __user *)a,
+		if (!access_ok((void __user *)a,
 				    node->size))
 			return false;
 		else if (log_all && !log_access_ok(log_base,
@@ -868,6 +911,34 @@
 	ret; \
 })
 
+static inline int vhost_put_avail_event(struct vhost_virtqueue *vq)
+{
+	return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
+			      vhost_avail_event(vq));
+}
+
+static inline int vhost_put_used(struct vhost_virtqueue *vq,
+				 struct vring_used_elem *head, int idx,
+				 int count)
+{
+	return vhost_copy_to_user(vq, vq->used->ring + idx, head,
+				  count * sizeof(*head));
+}
+
+static inline int vhost_put_used_flags(struct vhost_virtqueue *vq)
+{
+	return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
+			      &vq->used->flags);
+}
+
+static inline int vhost_put_used_idx(struct vhost_virtqueue *vq)
+{
+	return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
+			      &vq->used->idx);
+}
+
 #define vhost_get_user(vq, x, ptr, type)		\
 ({ \
 	int ret; \
@@ -906,12 +977,53 @@
 		mutex_unlock(&d->vqs[i]->mutex);
 }
 
+static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
+				      __virtio16 *idx)
+{
+	return vhost_get_avail(vq, *idx, &vq->avail->idx);
+}
+
+static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
+				       __virtio16 *head, int idx)
+{
+	return vhost_get_avail(vq, *head,
+			       &vq->avail->ring[idx & (vq->num - 1)]);
+}
+
+static inline int vhost_get_avail_flags(struct vhost_virtqueue *vq,
+					__virtio16 *flags)
+{
+	return vhost_get_avail(vq, *flags, &vq->avail->flags);
+}
+
+static inline int vhost_get_used_event(struct vhost_virtqueue *vq,
+				       __virtio16 *event)
+{
+	return vhost_get_avail(vq, *event, vhost_used_event(vq));
+}
+
+static inline int vhost_get_used_idx(struct vhost_virtqueue *vq,
+				     __virtio16 *idx)
+{
+	return vhost_get_used(vq, *idx, &vq->used->idx);
+}
+
+static inline int vhost_get_desc(struct vhost_virtqueue *vq,
+				 struct vring_desc *desc, int idx)
+{
+	return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc));
+}
+
 static int vhost_new_umem_range(struct vhost_umem *umem,
 				u64 start, u64 size, u64 end,
 				u64 userspace_addr, int perm)
 {
-	struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC);
+	struct vhost_umem_node *tmp, *node;
 
+	if (!size)
+		return -EFAULT;
+
+	node = kmalloc(sizeof(*node), GFP_ATOMIC);
 	if (!node)
 		return -ENOMEM;
 
@@ -973,10 +1085,10 @@
 		return false;
 
 	if ((access & VHOST_ACCESS_RO) &&
-	    !access_ok(VERIFY_READ, (void __user *)a, size))
+	    !access_ok((void __user *)a, size))
 		return false;
 	if ((access & VHOST_ACCESS_WO) &&
-	    !access_ok(VERIFY_WRITE, (void __user *)a, size))
+	    !access_ok((void __user *)a, size))
 		return false;
 	return true;
 }
@@ -1034,8 +1146,10 @@
 	int type, ret;
 
 	ret = copy_from_iter(&type, sizeof(type), from);
-	if (ret != sizeof(type))
+	if (ret != sizeof(type)) {
+		ret = -EINVAL;
 		goto done;
+	}
 
 	switch (type) {
 	case VHOST_IOTLB_MSG:
@@ -1054,8 +1168,10 @@
 
 	iov_iter_advance(from, offset);
 	ret = copy_from_iter(&msg, sizeof(msg), from);
-	if (ret != sizeof(msg))
+	if (ret != sizeof(msg)) {
+		ret = -EINVAL;
 		goto done;
+	}
 	if (vhost_process_iotlb_msg(dev, &msg)) {
 		ret = -EFAULT;
 		goto done;
@@ -1183,13 +1299,9 @@
 			 struct vring_used __user *used)
 
 {
-	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
-
-	return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
-	       access_ok(VERIFY_READ, avail,
-			 sizeof *avail + num * sizeof *avail->ring + s) &&
-	       access_ok(VERIFY_WRITE, used,
-			sizeof *used + num * sizeof *used->ring + s);
+	return access_ok(desc, vhost_get_desc_size(vq, num)) &&
+	       access_ok(avail, vhost_get_avail_size(vq, num)) &&
+	       access_ok(used, vhost_get_used_size(vq, num));
 }
 
 static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
@@ -1239,26 +1351,22 @@
 	return true;
 }
 
-int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
+int vq_meta_prefetch(struct vhost_virtqueue *vq)
 {
-	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
 	unsigned int num = vq->num;
 
 	if (!vq->iotlb)
 		return 1;
 
 	return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
-			       num * sizeof(*vq->desc), VHOST_ADDR_DESC) &&
+			       vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) &&
 	       iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
-			       sizeof *vq->avail +
-			       num * sizeof(*vq->avail->ring) + s,
+			       vhost_get_avail_size(vq, num),
 			       VHOST_ADDR_AVAIL) &&
 	       iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
-			       sizeof *vq->used +
-			       num * sizeof(*vq->used->ring) + s,
-			       VHOST_ADDR_USED);
+			       vhost_get_used_size(vq, num), VHOST_ADDR_USED);
 }
-EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
+EXPORT_SYMBOL_GPL(vq_meta_prefetch);
 
 /* Can we log writes? */
 /* Caller should have device mutex but not vq mutex */
@@ -1273,13 +1381,10 @@
 static bool vq_log_access_ok(struct vhost_virtqueue *vq,
 			     void __user *log_base)
 {
-	size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
-
 	return vq_memory_access_ok(log_base, vq->umem,
 				   vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
 		(!vq->log_used || log_access_ok(log_base, vq->log_addr,
-					sizeof *vq->used +
-					vq->num * sizeof *vq->used->ring + s));
+				  vhost_get_used_size(vq, vq->num)));
 }
 
 /* Can we start vq? */
@@ -1379,6 +1484,104 @@
 	return -EFAULT;
 }
 
+static long vhost_vring_set_num(struct vhost_dev *d,
+				struct vhost_virtqueue *vq,
+				void __user *argp)
+{
+	struct vhost_vring_state s;
+
+	/* Resizing ring with an active backend?
+	 * You don't want to do that. */
+	if (vq->private_data)
+		return -EBUSY;
+
+	if (copy_from_user(&s, argp, sizeof s))
+		return -EFAULT;
+
+	if (!s.num || s.num > 0xffff || (s.num & (s.num - 1)))
+		return -EINVAL;
+	vq->num = s.num;
+
+	return 0;
+}
+
+static long vhost_vring_set_addr(struct vhost_dev *d,
+				 struct vhost_virtqueue *vq,
+				 void __user *argp)
+{
+	struct vhost_vring_addr a;
+
+	if (copy_from_user(&a, argp, sizeof a))
+		return -EFAULT;
+	if (a.flags & ~(0x1 << VHOST_VRING_F_LOG))
+		return -EOPNOTSUPP;
+
+	/* For 32bit, verify that the top 32bits of the user
+	   data are set to zero. */
+	if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
+	    (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
+	    (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr)
+		return -EFAULT;
+
+	/* Make sure it's safe to cast pointers to vring types. */
+	BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
+	BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
+	if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
+	    (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
+	    (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1)))
+		return -EINVAL;
+
+	/* We only verify access here if backend is configured.
+	 * If it is not, we don't as size might not have been setup.
+	 * We will verify when backend is configured. */
+	if (vq->private_data) {
+		if (!vq_access_ok(vq, vq->num,
+			(void __user *)(unsigned long)a.desc_user_addr,
+			(void __user *)(unsigned long)a.avail_user_addr,
+			(void __user *)(unsigned long)a.used_user_addr))
+			return -EINVAL;
+
+		/* Also validate log access for used ring if enabled. */
+		if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
+			!log_access_ok(vq->log_base, a.log_guest_addr,
+				sizeof *vq->used +
+				vq->num * sizeof *vq->used->ring))
+			return -EINVAL;
+	}
+
+	vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
+	vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
+	vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
+	vq->log_addr = a.log_guest_addr;
+	vq->used = (void __user *)(unsigned long)a.used_user_addr;
+
+	return 0;
+}
+
+static long vhost_vring_set_num_addr(struct vhost_dev *d,
+				     struct vhost_virtqueue *vq,
+				     unsigned int ioctl,
+				     void __user *argp)
+{
+	long r;
+
+	mutex_lock(&vq->mutex);
+
+	switch (ioctl) {
+	case VHOST_SET_VRING_NUM:
+		r = vhost_vring_set_num(d, vq, argp);
+		break;
+	case VHOST_SET_VRING_ADDR:
+		r = vhost_vring_set_addr(d, vq, argp);
+		break;
+	default:
+		BUG();
+	}
+
+	mutex_unlock(&vq->mutex);
+
+	return r;
+}
 long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
 {
 	struct file *eventfp, *filep = NULL;
@@ -1388,7 +1591,6 @@
 	struct vhost_virtqueue *vq;
 	struct vhost_vring_state s;
 	struct vhost_vring_file f;
-	struct vhost_vring_addr a;
 	u32 idx;
 	long r;
 
@@ -1401,26 +1603,14 @@
 	idx = array_index_nospec(idx, d->nvqs);
 	vq = d->vqs[idx];
 
+	if (ioctl == VHOST_SET_VRING_NUM ||
+	    ioctl == VHOST_SET_VRING_ADDR) {
+		return vhost_vring_set_num_addr(d, vq, ioctl, argp);
+	}
+
 	mutex_lock(&vq->mutex);
 
 	switch (ioctl) {
-	case VHOST_SET_VRING_NUM:
-		/* Resizing ring with an active backend?
-		 * You don't want to do that. */
-		if (vq->private_data) {
-			r = -EBUSY;
-			break;
-		}
-		if (copy_from_user(&s, argp, sizeof s)) {
-			r = -EFAULT;
-			break;
-		}
-		if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) {
-			r = -EINVAL;
-			break;
-		}
-		vq->num = s.num;
-		break;
 	case VHOST_SET_VRING_BASE:
 		/* Moving base with an active backend?
 		 * You don't want to do that. */
@@ -1446,62 +1636,6 @@
 		if (copy_to_user(argp, &s, sizeof s))
 			r = -EFAULT;
 		break;
-	case VHOST_SET_VRING_ADDR:
-		if (copy_from_user(&a, argp, sizeof a)) {
-			r = -EFAULT;
-			break;
-		}
-		if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) {
-			r = -EOPNOTSUPP;
-			break;
-		}
-		/* For 32bit, verify that the top 32bits of the user
-		   data are set to zero. */
-		if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
-		    (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
-		    (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr) {
-			r = -EFAULT;
-			break;
-		}
-
-		/* Make sure it's safe to cast pointers to vring types. */
-		BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
-		BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
-		if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
-		    (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
-		    (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1))) {
-			r = -EINVAL;
-			break;
-		}
-
-		/* We only verify access here if backend is configured.
-		 * If it is not, we don't as size might not have been setup.
-		 * We will verify when backend is configured. */
-		if (vq->private_data) {
-			if (!vq_access_ok(vq, vq->num,
-				(void __user *)(unsigned long)a.desc_user_addr,
-				(void __user *)(unsigned long)a.avail_user_addr,
-				(void __user *)(unsigned long)a.used_user_addr)) {
-				r = -EINVAL;
-				break;
-			}
-
-			/* Also validate log access for used ring if enabled. */
-			if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
-			    !log_access_ok(vq->log_base, a.log_guest_addr,
-					   sizeof *vq->used +
-					   vq->num * sizeof *vq->used->ring)) {
-				r = -EINVAL;
-				break;
-			}
-		}
-
-		vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
-		vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
-		vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
-		vq->log_addr = a.log_guest_addr;
-		vq->used = (void __user *)(unsigned long)a.used_user_addr;
-		break;
 	case VHOST_SET_VRING_KICK:
 		if (copy_from_user(&f, argp, sizeof f)) {
 			r = -EFAULT;
@@ -1685,7 +1819,7 @@
 
 /* TODO: This is really inefficient.  We need something like get_user()
  * (instruction directly accesses the data, with an exception table entry
- * returning -EFAULT). See Documentation/x86/exception-tables.txt.
+ * returning -EFAULT). See Documentation/x86/exception-tables.rst.
  */
 static int set_bit_to_user(int nr, void __user *addr)
 {
@@ -1695,7 +1829,7 @@
 	int bit = nr + (log % PAGE_SIZE) * 8;
 	int r;
 
-	r = get_user_pages_fast(log, 1, 1, &page);
+	r = get_user_pages_fast(log, 1, FOLL_WRITE, &page);
 	if (r < 0)
 		return r;
 	BUG_ON(r != 1);
@@ -1733,13 +1867,87 @@
 	return r;
 }
 
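+/* Log a guest write given only its host virtual address: walk every umem
+ * region overlapping [hva, hva + len) and log the matching guest
+ * physical range of each, since a single HVA may back multiple GPAs.
+ */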
+static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
+{
+	struct vhost_umem *umem = vq->umem;
+	struct vhost_umem_node *u;
+	u64 start, end, l, min;
+	int r;
+	bool hit = false;
+
+	while (len) {
+		min = len;
+		/* More than one GPAs can be mapped into a single HVA. So
+		 * iterate all possible umems here to be safe.
+		 */
+		list_for_each_entry(u, &umem->umem_list, link) {
+			if (u->userspace_addr > hva - 1 + len ||
+			    u->userspace_addr - 1 + u->size < hva)
+				continue;
+			start = max(u->userspace_addr, hva);
+			end = min(u->userspace_addr - 1 + u->size,
+				  hva - 1 + len);
+			l = end - start + 1;
+			r = log_write(vq->log_base,
+				      u->start + start - u->userspace_addr,
+				      l);
+			if (r < 0)
+				return r;
+			hit = true;
+			min = min(l, min);
+		}
+
+		if (!hit)
+			return -EFAULT;
+
+		len -= min;
+		hva += min;
+	}
+
+	return 0;
+}
+
+static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
+{
+	struct iovec iov[64];
+	int i, ret;
+
+	if (!vq->iotlb)
+		return log_write(vq->log_base, vq->log_addr + used_offset, len);
+
+	ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
+			     len, iov, 64, VHOST_ACCESS_WO);
+	if (ret < 0)
+		return ret;
+
+	for (i = 0; i < ret; i++) {
+		ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+				    iov[i].iov_len);
+		if (ret)
+			return ret;
+	}
+
+	return 0;
+}
+
 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
-		    unsigned int log_num, u64 len)
+		    unsigned int log_num, u64 len, struct iovec *iov, int count)
 {
 	int i, r;
 
 	/* Make sure data written is seen before log. */
 	smp_wmb();
+
+	if (vq->iotlb) {
+		for (i = 0; i < count; i++) {
+			r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+					  iov[i].iov_len);
+			if (r < 0)
+				return r;
+		}
+		return 0;
+	}
+
 	for (i = 0; i < log_num; ++i) {
 		u64 l = min(log[i].len, len);
 		r = log_write(vq->log_base, log[i].addr, l);
@@ -1761,17 +1969,15 @@
 static int vhost_update_used_flags(struct vhost_virtqueue *vq)
 {
 	void __user *used;
-	if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
-			   &vq->used->flags) < 0)
+	if (vhost_put_used_flags(vq))
 		return -EFAULT;
 	if (unlikely(vq->log_used)) {
 		/* Make sure the flag is seen before log. */
 		smp_wmb();
 		/* Log used flag write. */
 		used = &vq->used->flags;
-		log_write(vq->log_base, vq->log_addr +
-			  (used - (void __user *)vq->used),
-			  sizeof vq->used->flags);
+		log_used(vq, (used - (void __user *)vq->used),
+			 sizeof vq->used->flags);
 		if (vq->log_ctx)
 			eventfd_signal(vq->log_ctx, 1);
 	}
@@ -1780,8 +1986,7 @@
 
 static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
 {
-	if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
-			   vhost_avail_event(vq)))
+	if (vhost_put_avail_event(vq))
 		return -EFAULT;
 	if (unlikely(vq->log_used)) {
 		void __user *used;
@@ -1789,9 +1994,8 @@
 		smp_wmb();
 		/* Log avail event write */
 		used = vhost_avail_event(vq);
-		log_write(vq->log_base, vq->log_addr +
-			  (used - (void __user *)vq->used),
-			  sizeof *vhost_avail_event(vq));
+		log_used(vq, (used - (void __user *)vq->used),
+			 sizeof *vhost_avail_event(vq));
 		if (vq->log_ctx)
 			eventfd_signal(vq->log_ctx, 1);
 	}
@@ -1814,11 +2018,11 @@
 		goto err;
 	vq->signalled_used_valid = false;
 	if (!vq->iotlb &&
-	    !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
+	    !access_ok(&vq->used->idx, sizeof vq->used->idx)) {
 		r = -EFAULT;
 		goto err;
 	}
-	r = vhost_get_used(vq, last_used_idx, &vq->used->idx);
+	r = vhost_get_used_idx(vq, &last_used_idx);
 	if (r) {
 		vq_err(vq, "Can't access used idx at %p\n",
 		       &vq->used->idx);
@@ -1974,7 +2178,7 @@
 		/* If this is an input descriptor, increment that count. */
 		if (access == VHOST_ACCESS_WO) {
 			*in_num += ret;
-			if (unlikely(log)) {
+			if (unlikely(log && ret)) {
 				log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
 				log[*log_num].len = vhost32_to_cpu(vq, desc.len);
 				++*log_num;
@@ -2017,7 +2221,7 @@
 	last_avail_idx = vq->last_avail_idx;
 
 	if (vq->avail_idx == vq->last_avail_idx) {
-		if (unlikely(vhost_get_avail(vq, avail_idx, &vq->avail->idx))) {
+		if (unlikely(vhost_get_avail_idx(vq, &avail_idx))) {
 			vq_err(vq, "Failed to access avail idx at %p\n",
 				&vq->avail->idx);
 			return -EFAULT;
@@ -2044,8 +2248,7 @@
 
 	/* Grab the next descriptor number they're advertising, and increment
 	 * the index we've seen. */
-	if (unlikely(vhost_get_avail(vq, ring_head,
-		     &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
+	if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) {
 		vq_err(vq, "Failed to read head: idx %d address %p\n",
 		       last_avail_idx,
 		       &vq->avail->ring[last_avail_idx % vq->num]);
@@ -2080,8 +2283,7 @@
 			       i, vq->num, head);
 			return -EINVAL;
 		}
-		ret = vhost_copy_from_user(vq, &desc, vq->desc + i,
-					   sizeof desc);
+		ret = vhost_get_desc(vq, &desc, i);
 		if (unlikely(ret)) {
 			vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
 			       i, vq->desc + i);
@@ -2117,7 +2319,7 @@
 			/* If this is an input descriptor,
 			 * increment that count. */
 			*in_num += ret;
-			if (unlikely(log)) {
+			if (unlikely(log && ret)) {
 				log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
 				log[*log_num].len = vhost32_to_cpu(vq, desc.len);
 				++*log_num;
@@ -2174,16 +2376,7 @@
 
 	start = vq->last_used_idx & (vq->num - 1);
 	used = vq->used->ring + start;
-	if (count == 1) {
-		if (vhost_put_user(vq, heads[0].id, &used->id)) {
-			vq_err(vq, "Failed to write used id");
-			return -EFAULT;
-		}
-		if (vhost_put_user(vq, heads[0].len, &used->len)) {
-			vq_err(vq, "Failed to write used len");
-			return -EFAULT;
-		}
-	} else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) {
+	if (vhost_put_used(vq, heads, start, count)) {
 		vq_err(vq, "Failed to write used");
 		return -EFAULT;
 	}
@@ -2191,10 +2384,8 @@
 		/* Make sure data is seen before log. */
 		smp_wmb();
 		/* Log used ring entry write. */
-		log_write(vq->log_base,
-			  vq->log_addr +
-			   ((void __user *)used - (void __user *)vq->used),
-			  count * sizeof *used);
+		log_used(vq, ((void __user *)used - (void __user *)vq->used),
+			 count * sizeof *used);
 	}
 	old = vq->last_used_idx;
 	new = (vq->last_used_idx += count);
@@ -2227,16 +2418,16 @@
 
 	/* Make sure buffer is written before we update index. */
 	smp_wmb();
-	if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
-			   &vq->used->idx)) {
+	if (vhost_put_used_idx(vq)) {
 		vq_err(vq, "Failed to increment used idx");
 		return -EFAULT;
 	}
 	if (unlikely(vq->log_used)) {
+		/* Make sure used idx is seen before log. */
+		smp_wmb();
 		/* Log used index update. */
-		log_write(vq->log_base,
-			  vq->log_addr + offsetof(struct vring_used, idx),
-			  sizeof vq->used->idx);
+		log_used(vq, offsetof(struct vring_used, idx),
+			 sizeof vq->used->idx);
 		if (vq->log_ctx)
 			eventfd_signal(vq->log_ctx, 1);
 	}
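
The smp_wmb() added above pairs the used-index store with its log entry: the index must be globally visible before the dirty bit that advertises it. Schematically:

/* Update-side ordering (sketch):
 *
 *   vhost_put_used_idx(vq);           store used->idx
 *   smp_wmb();                        index visible before the log bit
 *   log_used(vq, offsetof(...), ...); mark the covering page dirty
 *   eventfd_signal(vq->log_ctx, 1);   wake the dirty-log consumer
 *
 * A consumer that observes the log bit (with a paired read barrier)
 * is then guaranteed to observe the index it describes.
 */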
@@ -2260,7 +2451,7 @@
 
 	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
 		__virtio16 flags;
-		if (vhost_get_avail(vq, flags, &vq->avail->flags)) {
+		if (vhost_get_avail_flags(vq, &flags)) {
 			vq_err(vq, "Failed to get flags");
 			return true;
 		}
@@ -2274,7 +2465,7 @@
 	if (unlikely(!v))
 		return true;
 
-	if (vhost_get_avail(vq, event, vhost_used_event(vq))) {
+	if (vhost_get_used_event(vq, &event)) {
 		vq_err(vq, "Failed to get used event idx");
 		return true;
 	}
@@ -2319,7 +2510,7 @@
 	if (vq->avail_idx != vq->last_avail_idx)
 		return false;
 
-	r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
+	r = vhost_get_avail_idx(vq, &avail_idx);
 	if (unlikely(r))
 		return false;
 	vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
@@ -2355,7 +2546,7 @@
 	/* They could have slipped one in as we were doing that: make
 	 * sure it's written, then check again. */
 	smp_mb();
-	r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
+	r = vhost_get_avail_idx(vq, &avail_idx);
 	if (r) {
 		vq_err(vq, "Failed to check avail idx at %p: %d\n",
 		       &vq->avail->idx, r);
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 466ef75..e9ed272 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -170,9 +170,14 @@
 	struct list_head read_list;
 	struct list_head pending_list;
 	wait_queue_head_t wait;
+	int iov_limit;
+	int weight;
+	int byte_weight;
 };
 
-void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
+bool vhost_exceeds_weight(struct vhost_virtqueue *vq, int pkts, int total_len);
+void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs,
+		    int nvqs, int iov_limit, int weight, int byte_weight);
 long vhost_dev_set_owner(struct vhost_dev *dev);
 bool vhost_dev_has_owner(struct vhost_dev *dev);
 long vhost_dev_check_owner(struct vhost_dev *);
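
vhost.h only declares vhost_exceeds_weight(); given the three new vhost_dev fields above, a plausible shape for it is the sketch below. The authoritative body lives in vhost.c and is not shown in this hunk.

bool vhost_exceeds_weight(struct vhost_virtqueue *vq, int pkts, int total_len)
{
	struct vhost_dev *dev = vq->dev;

	/* Stop when either the packet or the byte budget is spent. */
	if (pkts >= dev->weight ||
	    (dev->byte_weight && total_len >= dev->byte_weight)) {
		vhost_poll_queue(&vq->poll);	/* requeue, don't monopolize */
		return true;
	}
	return false;
}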
@@ -205,8 +210,9 @@
 bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *);
 
 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
-		    unsigned int log_num, u64 len);
-int vq_iotlb_prefetch(struct vhost_virtqueue *vq);
+		    unsigned int log_num, u64 len,
+		    struct iovec *iov, int count);
+int vq_meta_prefetch(struct vhost_virtqueue *vq);
 
 struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type);
 void vhost_enqueue_msg(struct vhost_dev *dev,
diff --git a/drivers/vhost/vringh.c b/drivers/vhost/vringh.c
index a94d700..a0a2d74 100644
--- a/drivers/vhost/vringh.c
+++ b/drivers/vhost/vringh.c
@@ -1,3 +1,4 @@
+// SPDX-License-Identifier: GPL-2.0-only
 /*
  * Helpers for the host side of a virtio ring.
  *
@@ -851,6 +852,12 @@
 	return 0;
 }
 
+static inline int kern_xfer(void *dst, void *src, size_t len)
+{
+	memcpy(dst, src, len);
+	return 0;
+}
+
 /**
  * vringh_init_kern - initialize a vringh for a kernelspace vring.
  * @vrh: the vringh to initialize.
@@ -957,7 +964,7 @@
 ssize_t vringh_iov_push_kern(struct vringh_kiov *wiov,
 			     const void *src, size_t len)
 {
-	return vringh_iov_xfer(wiov, (void *)src, len, xfer_kern);
+	return vringh_iov_xfer(wiov, (void *)src, len, kern_xfer);
 }
 EXPORT_SYMBOL(vringh_iov_push_kern);
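
The xfer_kern()/kern_xfer() split fixes a copy-direction bug: vringh_iov_xfer() invokes its callback as xfer(iov_base, ptr, len). xfer_kern(), defined earlier in this file (not shown in the hunk), treats its first argument as the source, which is right for the pull direction (ring to caller) but backwards for push. kern_xfer() treats the first argument as the destination, so pushes now copy into the ring:

/* pull: xfer_kern(src = ring iov, dst = caller buf)  ->  ring -> caller
 * push: kern_xfer(dst = ring iov, src = caller buf)  ->  caller -> ring
 */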
 
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 98ed5be..9f57736 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -1,11 +1,10 @@
+// SPDX-License-Identifier: GPL-2.0-only
 /*
  * vhost transport for vsock
  *
  * Copyright (C) 2013-2015 Red Hat, Inc.
  * Author: Asias He <asias@redhat.com>
  *         Stefan Hajnoczi <stefanha@redhat.com>
- *
- * This work is licensed under the terms of the GNU GPL, version 2.
  */
 #include <linux/miscdevice.h>
 #include <linux/atomic.h>
@@ -21,20 +20,28 @@
 #include "vhost.h"
 
 #define VHOST_VSOCK_DEFAULT_HOST_CID	2
+/* Max number of bytes transferred before requeueing the job.
+ * Using this limit prevents one virtqueue from starving others.
+ */
+#define VHOST_VSOCK_WEIGHT 0x80000
+/* Max number of packets transferred before requeueing the job.
+ * Using this limit prevents one virtqueue from starving others with
+ * small pkts.
+ */
+#define VHOST_VSOCK_PKT_WEIGHT 256
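
In concrete terms (assuming 4 KiB guest buffers): VHOST_VSOCK_WEIGHT is 0x80000 bytes = 512 KiB per kick-handler pass, i.e. at most 128 full buffers, while VHOST_VSOCK_PKT_WEIGHT independently bounds a flood of minimum-size packets at 256 per pass. Whichever budget is exhausted first makes vhost_exceeds_weight() requeue the work, as the loops below show.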
 
 enum {
 	VHOST_VSOCK_FEATURES = VHOST_FEATURES,
 };
 
 /* Used to track all the vhost_vsock instances on the system. */
-static DEFINE_SPINLOCK(vhost_vsock_lock);
+static DEFINE_MUTEX(vhost_vsock_mutex);
 static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);
 
 struct vhost_vsock {
 	struct vhost_dev dev;
 	struct vhost_virtqueue vqs[2];
 
-	/* Link to global vhost_vsock_hash, writes use vhost_vsock_lock */
+	/* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
 	struct hlist_node hash;
 
 	struct vhost_work send_pkt_work;
@@ -51,7 +58,7 @@
 	return VHOST_VSOCK_DEFAULT_HOST_CID;
 }
 
-/* Callers that dereference the return value must hold vhost_vsock_lock or the
+/* Callers that dereference the return value must hold vhost_vsock_mutex or the
  * RCU read lock.
  */
 static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
@@ -78,6 +85,7 @@
 			    struct vhost_virtqueue *vq)
 {
 	struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
+	int pkts = 0, total_len = 0;
 	bool added = false;
 	bool restart_tx = false;
 
@@ -89,12 +97,12 @@
 	/* Avoid further vmexits, we're already processing the virtqueue */
 	vhost_disable_notify(&vsock->dev, vq);
 
-	for (;;) {
+	do {
 		struct virtio_vsock_pkt *pkt;
 		struct iov_iter iov_iter;
 		unsigned out, in;
 		size_t nbytes;
-		size_t len;
+		size_t iov_len, payload_len;
 		int head;
 
 		spin_lock_bh(&vsock->send_pkt_list_lock);
@@ -139,8 +147,24 @@
 			break;
 		}
 
-		len = iov_length(&vq->iov[out], in);
-		iov_iter_init(&iov_iter, READ, &vq->iov[out], in, len);
+		iov_len = iov_length(&vq->iov[out], in);
+		if (iov_len < sizeof(pkt->hdr)) {
+			virtio_transport_free_pkt(pkt);
+			vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
+			break;
+		}
+
+		iov_iter_init(&iov_iter, READ, &vq->iov[out], in, iov_len);
+		payload_len = pkt->len - pkt->off;
+
+		/* If the packet is larger than the space available in the
+		 * buffer, split it across multiple buffers.
+		 */
+		if (payload_len > iov_len - sizeof(pkt->hdr))
+			payload_len = iov_len - sizeof(pkt->hdr);
+
+		/* Set the correct length in the header */
+		pkt->hdr.len = cpu_to_le32(payload_len);
 
 		nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
 		if (nbytes != sizeof(pkt->hdr)) {
@@ -149,33 +173,48 @@
 			break;
 		}
 
-		nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter);
-		if (nbytes != pkt->len) {
+		nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
+				      &iov_iter);
+		if (nbytes != payload_len) {
 			virtio_transport_free_pkt(pkt);
 			vq_err(vq, "Faulted on copying pkt buf\n");
 			break;
 		}
 
-		vhost_add_used(vq, head, sizeof(pkt->hdr) + pkt->len);
+		vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
 		added = true;
 
-		if (pkt->reply) {
-			int val;
-
-			val = atomic_dec_return(&vsock->queued_replies);
-
-			/* Do we have resources to resume tx processing? */
-			if (val + 1 == tx_vq->num)
-				restart_tx = true;
-		}
-
 		/* Deliver to monitoring devices all correctly transmitted
 		 * packets.
 		 */
 		virtio_transport_deliver_tap_pkt(pkt);
 
-		virtio_transport_free_pkt(pkt);
-	}
+		pkt->off += payload_len;
+		total_len += payload_len;
+
+		/* If we didn't send all the payload we can requeue the packet
+		 * to send it with the next available buffer.
+		 */
+		if (pkt->off < pkt->len) {
+			spin_lock_bh(&vsock->send_pkt_list_lock);
+			list_add(&pkt->list, &vsock->send_pkt_list);
+			spin_unlock_bh(&vsock->send_pkt_list_lock);
+		} else {
+			if (pkt->reply) {
+				int val;
+
+				val = atomic_dec_return(&vsock->queued_replies);
+
+				/* Do we have resources to resume tx
+				 * processing?
+				 */
+				if (val + 1 == tx_vq->num)
+					restart_tx = true;
+			}
+
+			virtio_transport_free_pkt(pkt);
+		}
+	} while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
 	if (added)
 		vhost_signal(&vsock->dev, vq);
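
A worked pass of the splitting logic above, with invented sizes (struct virtio_vsock_hdr is 44 bytes in the current virtio definitions):

/* pkt->len = 8192, pkt->off = 0, one 4096-byte guest buffer:
 *
 *   payload_len = 4096 - 44 = 4052; pkt->hdr.len = cpu_to_le32(4052)
 *   copy hdr + 4052 payload bytes; pkt->off = 4052 < pkt->len
 *   -> list_add() puts the packet back at the *head* of send_pkt_list,
 *      so the remaining 4140 bytes go out, in order, with the next
 *      buffer(s).
 */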
 
@@ -320,6 +359,8 @@
 		return NULL;
 	}
 
+	pkt->buf_len = pkt->len;
+
 	nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
 	if (nbytes != pkt->len) {
 		vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
@@ -350,7 +391,7 @@
 	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
 						 dev);
 	struct virtio_vsock_pkt *pkt;
-	int head;
+	int head, pkts = 0, total_len = 0;
 	unsigned int out, in;
 	bool added = false;
 
@@ -360,7 +401,7 @@
 		goto out;
 
 	vhost_disable_notify(&vsock->dev, vq);
-	for (;;) {
+	do {
 		u32 len;
 
 		if (!vhost_vsock_more_replies(vsock)) {
@@ -401,9 +442,11 @@
 		else
 			virtio_transport_free_pkt(pkt);
 
-		vhost_add_used(vq, head, sizeof(pkt->hdr) + len);
+		len += sizeof(pkt->hdr);
+		vhost_add_used(vq, head, len);
+		total_len += len;
 		added = true;
-	}
+	} while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
 
 no_more_replies:
 	if (added)
@@ -531,7 +574,9 @@
 	vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
 	vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;
 
-	vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs));
+	vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
+		       UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
+		       VHOST_VSOCK_WEIGHT);
 
 	file->private_data = vsock;
 	spin_lock_init(&vsock->send_pkt_list_lock);
@@ -584,10 +629,10 @@
 {
 	struct vhost_vsock *vsock = file->private_data;
 
-	spin_lock_bh(&vhost_vsock_lock);
+	mutex_lock(&vhost_vsock_mutex);
 	if (vsock->guest_cid)
 		hash_del_rcu(&vsock->hash);
-	spin_unlock_bh(&vhost_vsock_lock);
+	mutex_unlock(&vhost_vsock_mutex);
 
 	/* Wait for other CPUs to finish using vsock */
 	synchronize_rcu();
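
The spinlock-to-mutex conversion works because lookups no longer take the lock at all: readers use RCU (per the comment on vhost_vsock_get() above), writers serialize on vhost_vsock_mutex, and release() waits out in-flight readers with synchronize_rcu(). The reader pattern implied by that comment is roughly the following sketch (the queued work is just an example use):

	struct vhost_vsock *vsock;

	rcu_read_lock();
	vsock = vhost_vsock_get(guest_cid);
	if (vsock)	/* vsock stays valid until rcu_read_unlock() */
		vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);
	rcu_read_unlock();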
@@ -631,10 +676,10 @@
 		return -EINVAL;
 
 	/* Refuse if CID is already in use */
-	spin_lock_bh(&vhost_vsock_lock);
+	mutex_lock(&vhost_vsock_mutex);
 	other = vhost_vsock_get(guest_cid);
 	if (other && other != vsock) {
-		spin_unlock_bh(&vhost_vsock_lock);
+		mutex_unlock(&vhost_vsock_mutex);
 		return -EADDRINUSE;
 	}
 
@@ -642,8 +687,8 @@
 		hash_del_rcu(&vsock->hash);
 
 	vsock->guest_cid = guest_cid;
-	hash_add_rcu(vhost_vsock_hash, &vsock->hash, guest_cid);
-	spin_unlock_bh(&vhost_vsock_lock);
+	hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
+	mutex_unlock(&vhost_vsock_mutex);
 
 	return 0;
 }