Update Linux to v5.4.2
Change-Id: Idf6911045d9d382da2cfe01b1edff026404ac8fd
diff --git a/drivers/vhost/Kconfig b/drivers/vhost/Kconfig
index b580885..3d03ccb 100644
--- a/drivers/vhost/Kconfig
+++ b/drivers/vhost/Kconfig
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: GPL-2.0-only
config VHOST_NET
tristate "Host kernel accelerator for virtio net"
depends on NET && EVENTFD && (TUN || !TUN) && (TAP || !TAP)
diff --git a/drivers/vhost/Kconfig.vringh b/drivers/vhost/Kconfig.vringh
index 6a4490c..c1fe36a 100644
--- a/drivers/vhost/Kconfig.vringh
+++ b/drivers/vhost/Kconfig.vringh
@@ -1,3 +1,4 @@
+# SPDX-License-Identifier: GPL-2.0-only
config VHOST_RING
tristate
---help---
diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 4e656f8..1a2dd53 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -1,8 +1,7 @@
+// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (C) 2009 Red Hat, Inc.
* Author: Michael S. Tsirkin <mst@redhat.com>
*
- * This work is licensed under the terms of the GNU GPL, version 2.
- *
* virtio-net server in host kernel.
*/
@@ -36,7 +35,7 @@
#include "vhost.h"
-static int experimental_zcopytx = 1;
+static int experimental_zcopytx = 0;
module_param(experimental_zcopytx, int, 0444);
MODULE_PARM_DESC(experimental_zcopytx, "Enable Zero Copy TX;"
" 1 -Enable; 0 - Disable");
@@ -116,6 +115,8 @@
* For RX, number of batched heads
*/
int done_idx;
+ /* Number of XDP frames batched */
+ int batched_xdp;
/* an array of userspace buffers info */
struct ubuf_info *ubuf_info;
/* Reference counting for outstanding ubufs.
@@ -123,6 +124,8 @@
struct vhost_net_ubuf_ref *ubufs;
struct ptr_ring *rx_ring;
struct vhost_net_buf rxq;
+ /* Batched XDP buffs */
+ struct xdp_buff *xdp;
};
struct vhost_net {
@@ -137,6 +140,10 @@
unsigned tx_zcopy_err;
/* Flush in progress. Protected by tx vq lock. */
bool tx_flush;
+ /* Private page frag */
+ struct page_frag page_frag;
+ /* Refcount bias of page frag */
+ int refcnt_bias;
};
static unsigned vhost_net_zcopy_mask __read_mostly;
@@ -338,6 +345,11 @@
sock_flag(sock->sk, SOCK_ZEROCOPY);
}
+static bool vhost_sock_xdp(struct socket *sock)
+{
+ return sock_flag(sock->sk, SOCK_XDP);
+}
+
/* In case of DMA done not in order in lower device driver for some reason.
* upend_idx is used to track end of used idx, done_idx is used to track head
* of used idx. Once lower device DMA done contiguously, we will signal KVM
@@ -444,32 +456,126 @@
nvq->done_idx = 0;
}
-static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
- struct vhost_net_virtqueue *nvq,
- unsigned int *out_num, unsigned int *in_num,
- bool *busyloop_intr)
+static void vhost_tx_batch(struct vhost_net *net,
+ struct vhost_net_virtqueue *nvq,
+ struct socket *sock,
+ struct msghdr *msghdr)
{
- struct vhost_virtqueue *vq = &nvq->vq;
- unsigned long uninitialized_var(endtime);
- int r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
+ struct tun_msg_ctl ctl = {
+ .type = TUN_MSG_PTR,
+ .num = nvq->batched_xdp,
+ .ptr = nvq->xdp,
+ };
+ int err;
+
+ if (nvq->batched_xdp == 0)
+ goto signal_used;
+
+ msghdr->msg_control = &ctl;
+ err = sock->ops->sendmsg(sock, msghdr, 0);
+ if (unlikely(err < 0)) {
+ vq_err(&nvq->vq, "Fail to batch sending packets\n");
+ return;
+ }
+
+signal_used:
+ vhost_net_signal_used(nvq);
+ nvq->batched_xdp = 0;
+}
+
+static int sock_has_rx_data(struct socket *sock)
+{
+ if (unlikely(!sock))
+ return 0;
+
+ if (sock->ops->peek_len)
+ return sock->ops->peek_len(sock);
+
+ return skb_queue_empty(&sock->sk->sk_receive_queue);
+}
+
+static void vhost_net_busy_poll_try_queue(struct vhost_net *net,
+ struct vhost_virtqueue *vq)
+{
+ if (!vhost_vq_avail_empty(&net->dev, vq)) {
+ vhost_poll_queue(&vq->poll);
+ } else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
+ vhost_disable_notify(&net->dev, vq);
+ vhost_poll_queue(&vq->poll);
+ }
+}
+
+static void vhost_net_busy_poll(struct vhost_net *net,
+ struct vhost_virtqueue *rvq,
+ struct vhost_virtqueue *tvq,
+ bool *busyloop_intr,
+ bool poll_rx)
+{
+ unsigned long busyloop_timeout;
+ unsigned long endtime;
+ struct socket *sock;
+ struct vhost_virtqueue *vq = poll_rx ? tvq : rvq;
+
+ /* Try to hold the vq mutex of the paired virtqueue. We can't
+ * use mutex_lock() here since we could not guarantee a
+ * consistent lock ordering.
+ */
+ if (!mutex_trylock(&vq->mutex))
+ return;
+
+ vhost_disable_notify(&net->dev, vq);
+ sock = rvq->private_data;
+
+ busyloop_timeout = poll_rx ? rvq->busyloop_timeout:
+ tvq->busyloop_timeout;
+
+ preempt_disable();
+ endtime = busy_clock() + busyloop_timeout;
+
+ while (vhost_can_busy_poll(endtime)) {
+ if (vhost_has_work(&net->dev)) {
+ *busyloop_intr = true;
+ break;
+ }
+
+ if ((sock_has_rx_data(sock) &&
+ !vhost_vq_avail_empty(&net->dev, rvq)) ||
+ !vhost_vq_avail_empty(&net->dev, tvq))
+ break;
+
+ cpu_relax();
+ }
+
+ preempt_enable();
+
+ if (poll_rx || sock_has_rx_data(sock))
+ vhost_net_busy_poll_try_queue(net, vq);
+ else if (!poll_rx) /* On tx here, sock has no rx data. */
+ vhost_enable_notify(&net->dev, rvq);
+
+ mutex_unlock(&vq->mutex);
+}
+
+static int vhost_net_tx_get_vq_desc(struct vhost_net *net,
+ struct vhost_net_virtqueue *tnvq,
+ unsigned int *out_num, unsigned int *in_num,
+ struct msghdr *msghdr, bool *busyloop_intr)
+{
+ struct vhost_net_virtqueue *rnvq = &net->vqs[VHOST_NET_VQ_RX];
+ struct vhost_virtqueue *rvq = &rnvq->vq;
+ struct vhost_virtqueue *tvq = &tnvq->vq;
+
+ int r = vhost_get_vq_desc(tvq, tvq->iov, ARRAY_SIZE(tvq->iov),
out_num, in_num, NULL, NULL);
- if (r == vq->num && vq->busyloop_timeout) {
- if (!vhost_sock_zcopy(vq->private_data))
- vhost_net_signal_used(nvq);
- preempt_disable();
- endtime = busy_clock() + vq->busyloop_timeout;
- while (vhost_can_busy_poll(endtime)) {
- if (vhost_has_work(vq->dev)) {
- *busyloop_intr = true;
- break;
- }
- if (!vhost_vq_avail_empty(vq->dev, vq))
- break;
- cpu_relax();
- }
- preempt_enable();
- r = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
+ if (r == tvq->num && tvq->busyloop_timeout) {
+ /* Flush batched packets first */
+ if (!vhost_sock_zcopy(tvq->private_data))
+ vhost_tx_batch(net, tnvq, tvq->private_data, msghdr);
+
+ vhost_net_busy_poll(net, rvq, tvq, busyloop_intr, false);
+
+ r = vhost_get_vq_desc(tvq, tvq->iov, ARRAY_SIZE(tvq->iov),
out_num, in_num, NULL, NULL);
}
@@ -497,12 +603,6 @@
return iov_iter_count(iter);
}
-static bool vhost_exceeds_weight(int pkts, int total_len)
-{
- return total_len >= VHOST_NET_WEIGHT ||
- pkts >= VHOST_NET_PKT_WEIGHT;
-}
-
static int get_tx_bufs(struct vhost_net *net,
struct vhost_net_virtqueue *nvq,
struct msghdr *msg,
@@ -512,7 +612,7 @@
struct vhost_virtqueue *vq = &nvq->vq;
int ret;
- ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, busyloop_intr);
+ ret = vhost_net_tx_get_vq_desc(net, nvq, out, in, msg, busyloop_intr);
if (ret < 0 || ret == vq->num)
return ret;
@@ -540,6 +640,120 @@
!vhost_vq_avail_empty(vq->dev, vq);
}
+#define SKB_FRAG_PAGE_ORDER get_order(32768)
+
+static bool vhost_net_page_frag_refill(struct vhost_net *net, unsigned int sz,
+ struct page_frag *pfrag, gfp_t gfp)
+{
+ if (pfrag->page) {
+ if (pfrag->offset + sz <= pfrag->size)
+ return true;
+ __page_frag_cache_drain(pfrag->page, net->refcnt_bias);
+ }
+
+ pfrag->offset = 0;
+ net->refcnt_bias = 0;
+ if (SKB_FRAG_PAGE_ORDER) {
+ /* Avoid direct reclaim but allow kswapd to wake */
+ pfrag->page = alloc_pages((gfp & ~__GFP_DIRECT_RECLAIM) |
+ __GFP_COMP | __GFP_NOWARN |
+ __GFP_NORETRY,
+ SKB_FRAG_PAGE_ORDER);
+ if (likely(pfrag->page)) {
+ pfrag->size = PAGE_SIZE << SKB_FRAG_PAGE_ORDER;
+ goto done;
+ }
+ }
+ pfrag->page = alloc_page(gfp);
+ if (likely(pfrag->page)) {
+ pfrag->size = PAGE_SIZE;
+ goto done;
+ }
+ return false;
+
+done:
+ net->refcnt_bias = USHRT_MAX;
+ page_ref_add(pfrag->page, USHRT_MAX - 1);
+ return true;
+}
+
+#define VHOST_NET_RX_PAD (NET_IP_ALIGN + NET_SKB_PAD)
+
+static int vhost_net_build_xdp(struct vhost_net_virtqueue *nvq,
+ struct iov_iter *from)
+{
+ struct vhost_virtqueue *vq = &nvq->vq;
+ struct vhost_net *net = container_of(vq->dev, struct vhost_net,
+ dev);
+ struct socket *sock = vq->private_data;
+ struct page_frag *alloc_frag = &net->page_frag;
+ struct virtio_net_hdr *gso;
+ struct xdp_buff *xdp = &nvq->xdp[nvq->batched_xdp];
+ struct tun_xdp_hdr *hdr;
+ size_t len = iov_iter_count(from);
+ int headroom = vhost_sock_xdp(sock) ? XDP_PACKET_HEADROOM : 0;
+ int buflen = SKB_DATA_ALIGN(sizeof(struct skb_shared_info));
+ int pad = SKB_DATA_ALIGN(VHOST_NET_RX_PAD + headroom + nvq->sock_hlen);
+ int sock_hlen = nvq->sock_hlen;
+ void *buf;
+ int copied;
+
+ if (unlikely(len < nvq->sock_hlen))
+ return -EFAULT;
+
+ if (SKB_DATA_ALIGN(len + pad) +
+ SKB_DATA_ALIGN(sizeof(struct skb_shared_info)) > PAGE_SIZE)
+ return -ENOSPC;
+
+ buflen += SKB_DATA_ALIGN(len + pad);
+ alloc_frag->offset = ALIGN((u64)alloc_frag->offset, SMP_CACHE_BYTES);
+ if (unlikely(!vhost_net_page_frag_refill(net, buflen,
+ alloc_frag, GFP_KERNEL)))
+ return -ENOMEM;
+
+ buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
+ copied = copy_page_from_iter(alloc_frag->page,
+ alloc_frag->offset +
+ offsetof(struct tun_xdp_hdr, gso),
+ sock_hlen, from);
+ if (copied != sock_hlen)
+ return -EFAULT;
+
+ hdr = buf;
+ gso = &hdr->gso;
+
+ if ((gso->flags & VIRTIO_NET_HDR_F_NEEDS_CSUM) &&
+ vhost16_to_cpu(vq, gso->csum_start) +
+ vhost16_to_cpu(vq, gso->csum_offset) + 2 >
+ vhost16_to_cpu(vq, gso->hdr_len)) {
+ gso->hdr_len = cpu_to_vhost16(vq,
+ vhost16_to_cpu(vq, gso->csum_start) +
+ vhost16_to_cpu(vq, gso->csum_offset) + 2);
+
+ if (vhost16_to_cpu(vq, gso->hdr_len) > len)
+ return -EINVAL;
+ }
+
+ len -= sock_hlen;
+ copied = copy_page_from_iter(alloc_frag->page,
+ alloc_frag->offset + pad,
+ len, from);
+ if (copied != len)
+ return -EFAULT;
+
+ xdp->data_hard_start = buf;
+ xdp->data = buf + pad;
+ xdp->data_end = xdp->data + len;
+ hdr->buflen = buflen;
+
+ --net->refcnt_bias;
+ alloc_frag->offset += buflen;
+
+ ++nvq->batched_xdp;
+
+ return 0;
+}
+
static void handle_tx_copy(struct vhost_net *net, struct socket *sock)
{
struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
@@ -556,10 +770,14 @@
size_t len, total_len = 0;
int err;
int sent_pkts = 0;
+ bool sock_can_batch = (sock->sk->sk_sndbuf == INT_MAX);
- for (;;) {
+ do {
bool busyloop_intr = false;
+ if (nvq->done_idx == VHOST_NET_BATCH)
+ vhost_tx_batch(net, nvq, sock, &msg);
+
head = get_tx_bufs(net, nvq, &msg, &out, &in, &len,
&busyloop_intr);
/* On error, stop handling until the next kick. */
@@ -577,14 +795,34 @@
break;
}
- vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head);
- vq->heads[nvq->done_idx].len = 0;
-
total_len += len;
- if (tx_can_batch(vq, total_len))
- msg.msg_flags |= MSG_MORE;
- else
- msg.msg_flags &= ~MSG_MORE;
+
+ /* For simplicity, TX batching is only enabled if
+ * sndbuf is unlimited.
+ */
+ if (sock_can_batch) {
+ err = vhost_net_build_xdp(nvq, &msg.msg_iter);
+ if (!err) {
+ goto done;
+ } else if (unlikely(err != -ENOSPC)) {
+ vhost_tx_batch(net, nvq, sock, &msg);
+ vhost_discard_vq_desc(vq, 1);
+ vhost_net_enable_vq(net, vq);
+ break;
+ }
+
+ /* We can't build XDP buff, go for single
+ * packet path but let's flush batched
+ * packets.
+ */
+ vhost_tx_batch(net, nvq, sock, &msg);
+ msg.msg_control = NULL;
+ } else {
+ if (tx_can_batch(vq, total_len))
+ msg.msg_flags |= MSG_MORE;
+ else
+ msg.msg_flags &= ~MSG_MORE;
+ }
/* TODO: Check specific error and bomb out unless ENOBUFS? */
err = sock->ops->sendmsg(sock, &msg, len);
@@ -596,15 +834,13 @@
if (err != len)
pr_debug("Truncated TX packet: len %d != %zd\n",
err, len);
- if (++nvq->done_idx >= VHOST_NET_BATCH)
- vhost_net_signal_used(nvq);
- if (vhost_exceeds_weight(++sent_pkts, total_len)) {
- vhost_poll_queue(&vq->poll);
- break;
- }
- }
+done:
+ vq->heads[nvq->done_idx].id = cpu_to_vhost32(vq, head);
+ vq->heads[nvq->done_idx].len = 0;
+ ++nvq->done_idx;
+ } while (likely(!vhost_exceeds_weight(vq, ++sent_pkts, total_len)));
- vhost_net_signal_used(nvq);
+ vhost_tx_batch(net, nvq, sock, &msg);
}
static void handle_tx_zerocopy(struct vhost_net *net, struct socket *sock)
@@ -620,13 +856,14 @@
.msg_controllen = 0,
.msg_flags = MSG_DONTWAIT,
};
+ struct tun_msg_ctl ctl;
size_t len, total_len = 0;
int err;
struct vhost_net_ubuf_ref *uninitialized_var(ubufs);
bool zcopy_used;
int sent_pkts = 0;
- for (;;) {
+ do {
bool busyloop_intr;
/* Release DMAs done buffers first */
@@ -664,8 +901,10 @@
ubuf->ctx = nvq->ubufs;
ubuf->desc = nvq->upend_idx;
refcount_set(&ubuf->refcnt, 1);
- msg.msg_control = ubuf;
- msg.msg_controllen = sizeof(ubuf);
+ msg.msg_control = &ctl;
+ ctl.type = TUN_MSG_UBUF;
+ ctl.ptr = ubuf;
+ msg.msg_controllen = sizeof(ctl);
ubufs = nvq->ubufs;
atomic_inc(&ubufs->refcount);
nvq->upend_idx = (nvq->upend_idx + 1) % UIO_MAXIOV;
@@ -701,11 +940,7 @@
else
vhost_zerocopy_signal_used(net, vq);
vhost_net_tx_packet(net);
- if (unlikely(vhost_exceeds_weight(++sent_pkts, total_len))) {
- vhost_poll_queue(&vq->poll);
- break;
- }
- }
+ } while (likely(!vhost_exceeds_weight(vq, ++sent_pkts, total_len)));
}
/* Expects to be always run from workqueue - which acts as
@@ -716,12 +951,12 @@
struct vhost_virtqueue *vq = &nvq->vq;
struct socket *sock;
- mutex_lock(&vq->mutex);
+ mutex_lock_nested(&vq->mutex, VHOST_NET_VQ_TX);
sock = vq->private_data;
if (!sock)
goto out;
- if (!vq_iotlb_prefetch(vq))
+ if (!vq_meta_prefetch(vq))
goto out;
vhost_disable_notify(&net->dev, vq);
@@ -757,16 +992,6 @@
return len;
}
-static int sk_has_rx_data(struct sock *sk)
-{
- struct socket *sock = sk->sk_socket;
-
- if (sock->ops->peek_len)
- return sock->ops->peek_len(sock);
-
- return skb_queue_empty(&sk->sk_receive_queue);
-}
-
static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk,
bool *busyloop_intr)
{
@@ -774,41 +999,13 @@
struct vhost_net_virtqueue *tnvq = &net->vqs[VHOST_NET_VQ_TX];
struct vhost_virtqueue *rvq = &rnvq->vq;
struct vhost_virtqueue *tvq = &tnvq->vq;
- unsigned long uninitialized_var(endtime);
int len = peek_head_len(rnvq, sk);
- if (!len && tvq->busyloop_timeout) {
+ if (!len && rvq->busyloop_timeout) {
/* Flush batched heads first */
vhost_net_signal_used(rnvq);
/* Both tx vq and rx socket were polled here */
- mutex_lock_nested(&tvq->mutex, 1);
- vhost_disable_notify(&net->dev, tvq);
-
- preempt_disable();
- endtime = busy_clock() + tvq->busyloop_timeout;
-
- while (vhost_can_busy_poll(endtime)) {
- if (vhost_has_work(&net->dev)) {
- *busyloop_intr = true;
- break;
- }
- if ((sk_has_rx_data(sk) &&
- !vhost_vq_avail_empty(&net->dev, rvq)) ||
- !vhost_vq_avail_empty(&net->dev, tvq))
- break;
- cpu_relax();
- }
-
- preempt_enable();
-
- if (!vhost_vq_avail_empty(&net->dev, tvq)) {
- vhost_poll_queue(&tvq->poll);
- } else if (unlikely(vhost_enable_notify(&net->dev, tvq))) {
- vhost_disable_notify(&net->dev, tvq);
- vhost_poll_queue(&tvq->poll);
- }
-
- mutex_unlock(&tvq->mutex);
+ vhost_net_busy_poll(net, rvq, tvq, busyloop_intr, true);
len = peek_head_len(rnvq, sk);
}
@@ -923,12 +1120,12 @@
__virtio16 num_buffers;
int recv_pkts = 0;
- mutex_lock_nested(&vq->mutex, 0);
+ mutex_lock_nested(&vq->mutex, VHOST_NET_VQ_RX);
sock = vq->private_data;
if (!sock)
goto out;
- if (!vq_iotlb_prefetch(vq))
+ if (!vq_meta_prefetch(vq))
goto out;
vhost_disable_notify(&net->dev, vq);
@@ -941,8 +1138,11 @@
vq->log : NULL;
mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
- while ((sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
- &busyloop_intr))) {
+ do {
+ sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
+ &busyloop_intr);
+ if (!sock_len)
+ break;
sock_len += sock_hlen;
vhost_len = sock_len + vhost_hlen;
headcount = get_rx_bufs(vq, vq->heads + nvq->done_idx,
@@ -1024,16 +1224,14 @@
if (nvq->done_idx > VHOST_NET_BATCH)
vhost_net_signal_used(nvq);
if (unlikely(vq_log))
- vhost_log_write(vq, vq_log, log, vhost_len);
+ vhost_log_write(vq, vq_log, log, vhost_len,
+ vq->iov, in);
total_len += vhost_len;
- if (unlikely(vhost_exceeds_weight(++recv_pkts, total_len))) {
- vhost_poll_queue(&vq->poll);
- goto out;
- }
- }
+ } while (likely(!vhost_exceeds_weight(vq, ++recv_pkts, total_len)));
+
if (unlikely(busyloop_intr))
vhost_poll_queue(&vq->poll);
- else
+ else if (!sock_len)
vhost_net_enable_vq(net, vq);
out:
vhost_net_signal_used(nvq);
@@ -1078,6 +1276,7 @@
struct vhost_dev *dev;
struct vhost_virtqueue **vqs;
void **queue;
+ struct xdp_buff *xdp;
int i;
n = kvmalloc(sizeof *n, GFP_KERNEL | __GFP_RETRY_MAYFAIL);
@@ -1098,6 +1297,15 @@
}
n->vqs[VHOST_NET_VQ_RX].rxq.queue = queue;
+ xdp = kmalloc_array(VHOST_NET_BATCH, sizeof(*xdp), GFP_KERNEL);
+ if (!xdp) {
+ kfree(vqs);
+ kvfree(n);
+ kfree(queue);
+ return -ENOMEM;
+ }
+ n->vqs[VHOST_NET_VQ_TX].xdp = xdp;
+
dev = &n->dev;
vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
@@ -1108,17 +1316,22 @@
n->vqs[i].ubuf_info = NULL;
n->vqs[i].upend_idx = 0;
n->vqs[i].done_idx = 0;
+ n->vqs[i].batched_xdp = 0;
n->vqs[i].vhost_hlen = 0;
n->vqs[i].sock_hlen = 0;
n->vqs[i].rx_ring = NULL;
vhost_net_buf_init(&n->vqs[i].rxq);
}
- vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
+ vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX,
+ UIO_MAXIOV + VHOST_NET_BATCH,
+ VHOST_NET_PKT_WEIGHT, VHOST_NET_WEIGHT);
vhost_poll_init(n->poll + VHOST_NET_VQ_TX, handle_tx_net, EPOLLOUT, dev);
vhost_poll_init(n->poll + VHOST_NET_VQ_RX, handle_rx_net, EPOLLIN, dev);
f->private_data = n;
+ n->page_frag.page = NULL;
+ n->refcnt_bias = 0;
return 0;
}
@@ -1186,12 +1399,15 @@
if (rx_sock)
sockfd_put(rx_sock);
/* Make sure no callbacks are outstanding */
- synchronize_rcu_bh();
+ synchronize_rcu();
/* We do an extra flush before freeing memory,
* since jobs can re-queue themselves. */
vhost_net_flush(n);
kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue);
+ kfree(n->vqs[VHOST_NET_VQ_TX].xdp);
kfree(n->dev.vqs);
+ if (n->page_frag.page)
+ __page_frag_cache_drain(n->page_frag.page, n->refcnt_bias);
kvfree(n);
return 0;
}
diff --git a/drivers/vhost/scsi.c b/drivers/vhost/scsi.c
index e7e3ae1..a9caf1b 100644
--- a/drivers/vhost/scsi.c
+++ b/drivers/vhost/scsi.c
@@ -57,6 +57,12 @@
#define VHOST_SCSI_PREALLOC_UPAGES 2048
#define VHOST_SCSI_PREALLOC_PROT_SGLS 2048
+/* Max number of requests before requeueing the job.
+ * Using this limit prevents one virtqueue from starving others with
+ * requests.
+ */
+#define VHOST_SCSI_WEIGHT 256
+
struct vhost_scsi_inflight {
/* Wait for the flush operation to finish */
struct completion comp;
@@ -203,6 +209,19 @@
int vs_events_nr; /* num of pending events, protected by vq->mutex */
};
+/*
+ * Context for processing request and control queue operations.
+ */
+struct vhost_scsi_ctx {
+ int head;
+ unsigned int out, in;
+ size_t req_size, rsp_size;
+ size_t out_size, in_size;
+ u8 *target, *lunp;
+ void *req;
+ struct iov_iter out_iter;
+};
+
static struct workqueue_struct *vhost_scsi_workqueue;
/* Global spinlock to protect vhost_scsi TPG list for vhost IOCTL access */
@@ -272,11 +291,6 @@
return 0;
}
-static char *vhost_scsi_get_fabric_name(void)
-{
- return "vhost";
-}
-
static char *vhost_scsi_get_fabric_wwn(struct se_portal_group *se_tpg)
{
struct vhost_scsi_tpg *tpg = container_of(se_tpg,
@@ -338,11 +352,6 @@
return 0;
}
-static int vhost_scsi_write_pending_status(struct se_cmd *se_cmd)
-{
- return 0;
-}
-
static void vhost_scsi_set_default_node_attrs(struct se_node_acl *nacl)
{
return;
@@ -800,24 +809,120 @@
pr_err("Faulted on virtio_scsi_cmd_resp\n");
}
+static int
+vhost_scsi_get_desc(struct vhost_scsi *vs, struct vhost_virtqueue *vq,
+ struct vhost_scsi_ctx *vc)
+{
+ int ret = -ENXIO;
+
+ vc->head = vhost_get_vq_desc(vq, vq->iov,
+ ARRAY_SIZE(vq->iov), &vc->out, &vc->in,
+ NULL, NULL);
+
+ pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n",
+ vc->head, vc->out, vc->in);
+
+ /* On error, stop handling until the next kick. */
+ if (unlikely(vc->head < 0))
+ goto done;
+
+ /* Nothing new? Wait for eventfd to tell us they refilled. */
+ if (vc->head == vq->num) {
+ if (unlikely(vhost_enable_notify(&vs->dev, vq))) {
+ vhost_disable_notify(&vs->dev, vq);
+ ret = -EAGAIN;
+ }
+ goto done;
+ }
+
+ /*
+ * Get the size of request and response buffers.
+ * FIXME: Not correct for BIDI operation
+ */
+ vc->out_size = iov_length(vq->iov, vc->out);
+ vc->in_size = iov_length(&vq->iov[vc->out], vc->in);
+
+ /*
+ * Copy over the virtio-scsi request header, which for an
+ * ANY_LAYOUT enabled guest may span multiple iovecs, or a
+ * single iovec may contain both the header + outgoing
+ * WRITE payloads.
+ *
+ * copy_from_iter() will advance out_iter, so that it will
+ * point at the start of the outgoing WRITE payload, if
+ * DMA_TO_DEVICE is set.
+ */
+ iov_iter_init(&vc->out_iter, WRITE, vq->iov, vc->out, vc->out_size);
+ ret = 0;
+
+done:
+ return ret;
+}
+
+static int
+vhost_scsi_chk_size(struct vhost_virtqueue *vq, struct vhost_scsi_ctx *vc)
+{
+ if (unlikely(vc->in_size < vc->rsp_size)) {
+ vq_err(vq,
+ "Response buf too small, need min %zu bytes got %zu",
+ vc->rsp_size, vc->in_size);
+ return -EINVAL;
+ } else if (unlikely(vc->out_size < vc->req_size)) {
+ vq_err(vq,
+ "Request buf too small, need min %zu bytes got %zu",
+ vc->req_size, vc->out_size);
+ return -EIO;
+ }
+
+ return 0;
+}
+
+/* Copy in the request and look up its target, if any; returns 0 or -EIO. */
+static int
+vhost_scsi_get_req(struct vhost_virtqueue *vq, struct vhost_scsi_ctx *vc,
+		   struct vhost_scsi_tpg **tpgp)
+{
+	struct vhost_scsi_tpg **vs_tpg = vq->private_data, *tpg = NULL;
+
+	if (unlikely(!copy_from_iter_full(vc->req, vc->req_size,
+					  &vc->out_iter))) {
+		vq_err(vq, "Faulted on copy_from_iter_full\n");
+		return -EIO;
+	}
+	/* virtio-scsi spec requires byte 0 of the lun to be 1 */
+	if (unlikely(*vc->lunp != 1)) {
+		vq_err(vq, "Illegal virtio-scsi lun: %u\n", *vc->lunp);
+		return -EIO;
+	}
+	/* AN requests leave vc->target NULL, so there is nothing to look up */
+	if (vc->target) {
+		tpg = READ_ONCE(vs_tpg[*vc->target]);
+		if (unlikely(!tpg)) {
+			vq_err(vq, "Target 0x%x does not exist\n", *vc->target);
+			return -EIO;
+		}
+	}
+	if (tpgp)
+		*tpgp = tpg;
+	return 0;
+}
+
static void
vhost_scsi_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
{
struct vhost_scsi_tpg **vs_tpg, *tpg;
struct virtio_scsi_cmd_req v_req;
struct virtio_scsi_cmd_req_pi v_req_pi;
+ struct vhost_scsi_ctx vc;
struct vhost_scsi_cmd *cmd;
- struct iov_iter out_iter, in_iter, prot_iter, data_iter;
+ struct iov_iter in_iter, prot_iter, data_iter;
u64 tag;
u32 exp_data_len, data_direction;
- unsigned int out = 0, in = 0;
- int head, ret, prot_bytes;
- size_t req_size, rsp_size = sizeof(struct virtio_scsi_cmd_resp);
- size_t out_size, in_size;
+ int ret, prot_bytes, c = 0;
u16 lun;
- u8 *target, *lunp, task_attr;
+ u8 task_attr;
bool t10_pi = vhost_has_feature(vq, VIRTIO_SCSI_F_T10_PI);
- void *req, *cdb;
+ void *cdb;
mutex_lock(&vq->mutex);
/*
@@ -828,85 +933,47 @@
if (!vs_tpg)
goto out;
+ memset(&vc, 0, sizeof(vc));
+ vc.rsp_size = sizeof(struct virtio_scsi_cmd_resp);
+
vhost_disable_notify(&vs->dev, vq);
- for (;;) {
- head = vhost_get_vq_desc(vq, vq->iov,
- ARRAY_SIZE(vq->iov), &out, &in,
- NULL, NULL);
- pr_debug("vhost_get_vq_desc: head: %d, out: %u in: %u\n",
- head, out, in);
- /* On error, stop handling until the next kick. */
- if (unlikely(head < 0))
- break;
- /* Nothing new? Wait for eventfd to tell us they refilled. */
- if (head == vq->num) {
- if (unlikely(vhost_enable_notify(&vs->dev, vq))) {
- vhost_disable_notify(&vs->dev, vq);
- continue;
- }
- break;
- }
- /*
- * Check for a sane response buffer so we can report early
- * errors back to the guest.
- */
- if (unlikely(vq->iov[out].iov_len < rsp_size)) {
- vq_err(vq, "Expecting at least virtio_scsi_cmd_resp"
- " size, got %zu bytes\n", vq->iov[out].iov_len);
- break;
- }
+ do {
+ ret = vhost_scsi_get_desc(vs, vq, &vc);
+ if (ret)
+ goto err;
+
/*
* Setup pointers and values based upon different virtio-scsi
* request header if T10_PI is enabled in KVM guest.
*/
if (t10_pi) {
- req = &v_req_pi;
- req_size = sizeof(v_req_pi);
- lunp = &v_req_pi.lun[0];
- target = &v_req_pi.lun[1];
+ vc.req = &v_req_pi;
+ vc.req_size = sizeof(v_req_pi);
+ vc.lunp = &v_req_pi.lun[0];
+ vc.target = &v_req_pi.lun[1];
} else {
- req = &v_req;
- req_size = sizeof(v_req);
- lunp = &v_req.lun[0];
- target = &v_req.lun[1];
+ vc.req = &v_req;
+ vc.req_size = sizeof(v_req);
+ vc.lunp = &v_req.lun[0];
+ vc.target = &v_req.lun[1];
}
- /*
- * FIXME: Not correct for BIDI operation
- */
- out_size = iov_length(vq->iov, out);
- in_size = iov_length(&vq->iov[out], in);
/*
- * Copy over the virtio-scsi request header, which for a
- * ANY_LAYOUT enabled guest may span multiple iovecs, or a
- * single iovec may contain both the header + outgoing
- * WRITE payloads.
- *
- * copy_from_iter() will advance out_iter, so that it will
- * point at the start of the outgoing WRITE payload, if
- * DMA_TO_DEVICE is set.
+ * Validate the size of request and response buffers.
+ * Check for a sane response buffer so we can report
+ * early errors back to the guest.
*/
- iov_iter_init(&out_iter, WRITE, vq->iov, out, out_size);
+ ret = vhost_scsi_chk_size(vq, &vc);
+ if (ret)
+ goto err;
- if (unlikely(!copy_from_iter_full(req, req_size, &out_iter))) {
- vq_err(vq, "Faulted on copy_from_iter\n");
- vhost_scsi_send_bad_target(vs, vq, head, out);
- continue;
- }
- /* virtio-scsi spec requires byte 0 of the lun to be 1 */
- if (unlikely(*lunp != 1)) {
- vq_err(vq, "Illegal virtio-scsi lun: %u\n", *lunp);
- vhost_scsi_send_bad_target(vs, vq, head, out);
- continue;
- }
+ ret = vhost_scsi_get_req(vq, &vc, &tpg);
+ if (ret)
+ goto err;
- tpg = READ_ONCE(vs_tpg[*target]);
- if (unlikely(!tpg)) {
- /* Target does not exist, fail the request */
- vhost_scsi_send_bad_target(vs, vq, head, out);
- continue;
- }
+ ret = -EIO; /* bad target on any error from here on */
+
/*
* Determine data_direction by calculating the total outgoing
* iovec sizes + incoming iovec sizes vs. virtio-scsi request +
@@ -924,17 +991,17 @@
*/
prot_bytes = 0;
- if (out_size > req_size) {
+ if (vc.out_size > vc.req_size) {
data_direction = DMA_TO_DEVICE;
- exp_data_len = out_size - req_size;
- data_iter = out_iter;
- } else if (in_size > rsp_size) {
+ exp_data_len = vc.out_size - vc.req_size;
+ data_iter = vc.out_iter;
+ } else if (vc.in_size > vc.rsp_size) {
data_direction = DMA_FROM_DEVICE;
- exp_data_len = in_size - rsp_size;
+ exp_data_len = vc.in_size - vc.rsp_size;
- iov_iter_init(&in_iter, READ, &vq->iov[out], in,
- rsp_size + exp_data_len);
- iov_iter_advance(&in_iter, rsp_size);
+ iov_iter_init(&in_iter, READ, &vq->iov[vc.out], vc.in,
+ vc.rsp_size + exp_data_len);
+ iov_iter_advance(&in_iter, vc.rsp_size);
data_iter = in_iter;
} else {
data_direction = DMA_NONE;
@@ -950,16 +1017,14 @@
if (data_direction != DMA_TO_DEVICE) {
vq_err(vq, "Received non zero pi_bytesout,"
" but wrong data_direction\n");
- vhost_scsi_send_bad_target(vs, vq, head, out);
- continue;
+ goto err;
}
prot_bytes = vhost32_to_cpu(vq, v_req_pi.pi_bytesout);
} else if (v_req_pi.pi_bytesin) {
if (data_direction != DMA_FROM_DEVICE) {
vq_err(vq, "Received non zero pi_bytesin,"
" but wrong data_direction\n");
- vhost_scsi_send_bad_target(vs, vq, head, out);
- continue;
+ goto err;
}
prot_bytes = vhost32_to_cpu(vq, v_req_pi.pi_bytesin);
}
@@ -998,8 +1063,7 @@
vq_err(vq, "Received SCSI CDB with command_size: %d that"
" exceeds SCSI_MAX_VARLEN_CDB_SIZE: %d\n",
scsi_command_size(cdb), VHOST_SCSI_MAX_CDB_SIZE);
- vhost_scsi_send_bad_target(vs, vq, head, out);
- continue;
+ goto err;
}
cmd = vhost_scsi_get_tag(vq, tpg, cdb, tag, lun, task_attr,
exp_data_len + prot_bytes,
@@ -1007,13 +1071,12 @@
if (IS_ERR(cmd)) {
vq_err(vq, "vhost_scsi_get_tag failed %ld\n",
PTR_ERR(cmd));
- vhost_scsi_send_bad_target(vs, vq, head, out);
- continue;
+ goto err;
}
cmd->tvc_vhost = vs;
cmd->tvc_vq = vq;
- cmd->tvc_resp_iov = vq->iov[out];
- cmd->tvc_in_iovs = in;
+ cmd->tvc_resp_iov = vq->iov[vc.out];
+ cmd->tvc_in_iovs = vc.in;
pr_debug("vhost_scsi got command opcode: %#02x, lun: %d\n",
cmd->tvc_cdb[0], cmd->tvc_lun);
@@ -1021,14 +1084,12 @@
" %d\n", cmd, exp_data_len, prot_bytes, data_direction);
if (data_direction != DMA_NONE) {
- ret = vhost_scsi_mapal(cmd,
- prot_bytes, &prot_iter,
- exp_data_len, &data_iter);
- if (unlikely(ret)) {
+ if (unlikely(vhost_scsi_mapal(cmd, prot_bytes,
+ &prot_iter, exp_data_len,
+ &data_iter))) {
vq_err(vq, "Failed to map iov to sgl\n");
vhost_scsi_release_cmd(&cmd->tvc_se_cmd);
- vhost_scsi_send_bad_target(vs, vq, head, out);
- continue;
+ goto err;
}
}
/*
@@ -1036,7 +1097,7 @@
* complete the virtio-scsi request in TCM callback context via
* vhost_scsi_queue_data_in() and vhost_scsi_queue_status()
*/
- cmd->tvc_vq_desc = head;
+ cmd->tvc_vq_desc = vc.head;
/*
* Dispatch cmd descriptor for cmwq execution in process
* context provided by vhost_scsi_workqueue. This also ensures
@@ -1045,14 +1106,183 @@
*/
INIT_WORK(&cmd->work, vhost_scsi_submission_work);
queue_work(vhost_scsi_workqueue, &cmd->work);
- }
+ ret = 0;
+err:
+ /*
+ * ENXIO: No more requests, or read error, wait for next kick
+ * EINVAL: Invalid response buffer, drop the request
+ * EIO: Respond with bad target
+ * EAGAIN: Pending request
+ */
+ if (ret == -ENXIO)
+ break;
+ else if (ret == -EIO)
+ vhost_scsi_send_bad_target(vs, vq, vc.head, vc.out);
+ } while (likely(!vhost_exceeds_weight(vq, ++c, 0)));
+out:
+ mutex_unlock(&vq->mutex);
+}
+
+static void
+vhost_scsi_send_tmf_reject(struct vhost_scsi *vs,
+ struct vhost_virtqueue *vq,
+ struct vhost_scsi_ctx *vc)
+{
+ struct virtio_scsi_ctrl_tmf_resp rsp;
+ struct iov_iter iov_iter;
+ int ret;
+
+ pr_debug("%s\n", __func__);
+ memset(&rsp, 0, sizeof(rsp));
+ rsp.response = VIRTIO_SCSI_S_FUNCTION_REJECTED;
+
+ iov_iter_init(&iov_iter, READ, &vq->iov[vc->out], vc->in, sizeof(rsp));
+
+ ret = copy_to_iter(&rsp, sizeof(rsp), &iov_iter);
+ if (likely(ret == sizeof(rsp)))
+ vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0);
+ else
+ pr_err("Faulted on virtio_scsi_ctrl_tmf_resp\n");
+}
+
+static void
+vhost_scsi_send_an_resp(struct vhost_scsi *vs,
+ struct vhost_virtqueue *vq,
+ struct vhost_scsi_ctx *vc)
+{
+ struct virtio_scsi_ctrl_an_resp rsp;
+ struct iov_iter iov_iter;
+ int ret;
+
+ pr_debug("%s\n", __func__);
+ memset(&rsp, 0, sizeof(rsp)); /* event_actual = 0 */
+ rsp.response = VIRTIO_SCSI_S_OK;
+
+ iov_iter_init(&iov_iter, READ, &vq->iov[vc->out], vc->in, sizeof(rsp));
+
+ ret = copy_to_iter(&rsp, sizeof(rsp), &iov_iter);
+ if (likely(ret == sizeof(rsp)))
+ vhost_add_used_and_signal(&vs->dev, vq, vc->head, 0);
+ else
+ pr_err("Faulted on virtio_scsi_ctrl_an_resp\n");
+}
+
+static void
+vhost_scsi_ctl_handle_vq(struct vhost_scsi *vs, struct vhost_virtqueue *vq)
+{
+ union {
+ __virtio32 type;
+ struct virtio_scsi_ctrl_an_req an;
+ struct virtio_scsi_ctrl_tmf_req tmf;
+ } v_req;
+ struct vhost_scsi_ctx vc;
+ size_t typ_size;
+ int ret, c = 0;
+
+ mutex_lock(&vq->mutex);
+ /*
+ * We can handle the vq only after the endpoint is setup by calling the
+ * VHOST_SCSI_SET_ENDPOINT ioctl.
+ */
+ if (!vq->private_data)
+ goto out;
+
+ memset(&vc, 0, sizeof(vc));
+
+ vhost_disable_notify(&vs->dev, vq);
+
+ do {
+ ret = vhost_scsi_get_desc(vs, vq, &vc);
+ if (ret)
+ goto err;
+
+ /*
+ * Get the request type first in order to setup
+ * other parameters dependent on the type.
+ */
+ vc.req = &v_req.type;
+ typ_size = sizeof(v_req.type);
+
+ if (unlikely(!copy_from_iter_full(vc.req, typ_size,
+ &vc.out_iter))) {
+ vq_err(vq, "Faulted on copy_from_iter tmf type\n");
+ /*
+ * The size of the response buffer depends on the
+ * request type and must be validated against it.
+ * Since the request type is not known, don't send
+ * a response.
+ */
+ continue;
+ }
+
+ switch (v_req.type) {
+ case VIRTIO_SCSI_T_TMF:
+ vc.req = &v_req.tmf;
+ vc.req_size = sizeof(struct virtio_scsi_ctrl_tmf_req);
+ vc.rsp_size = sizeof(struct virtio_scsi_ctrl_tmf_resp);
+ vc.lunp = &v_req.tmf.lun[0];
+ vc.target = &v_req.tmf.lun[1];
+ break;
+ case VIRTIO_SCSI_T_AN_QUERY:
+ case VIRTIO_SCSI_T_AN_SUBSCRIBE:
+ vc.req = &v_req.an;
+ vc.req_size = sizeof(struct virtio_scsi_ctrl_an_req);
+ vc.rsp_size = sizeof(struct virtio_scsi_ctrl_an_resp);
+ vc.lunp = &v_req.an.lun[0];
+ vc.target = NULL;
+ break;
+ default:
+ vq_err(vq, "Unknown control request %d", v_req.type);
+ continue;
+ }
+
+ /*
+ * Validate the size of request and response buffers.
+ * Check for a sane response buffer so we can report
+ * early errors back to the guest.
+ */
+ ret = vhost_scsi_chk_size(vq, &vc);
+ if (ret)
+ goto err;
+
+ /*
+ * Get the rest of the request now that its size is known.
+ */
+ vc.req += typ_size;
+ vc.req_size -= typ_size;
+
+ ret = vhost_scsi_get_req(vq, &vc, NULL);
+ if (ret)
+ goto err;
+
+ if (v_req.type == VIRTIO_SCSI_T_TMF)
+ vhost_scsi_send_tmf_reject(vs, vq, &vc);
+ else
+ vhost_scsi_send_an_resp(vs, vq, &vc);
+err:
+ /*
+ * ENXIO: No more requests, or read error, wait for next kick
+ * EINVAL: Invalid response buffer, drop the request
+ * EIO: Respond with bad target
+ * EAGAIN: Pending request
+ */
+ if (ret == -ENXIO)
+ break;
+ else if (ret == -EIO)
+ vhost_scsi_send_bad_target(vs, vq, vc.head, vc.out);
+ } while (likely(!vhost_exceeds_weight(vq, ++c, 0)));
out:
mutex_unlock(&vq->mutex);
}
+/* Kick handler for the control virtqueue.
+ *
+ * Recovers the virtqueue from the embedded poll work item and the owning
+ * vhost_scsi instance from the virtqueue's device, then hands both to
+ * vhost_scsi_ctl_handle_vq().
+ */
static void vhost_scsi_ctl_handle_kick(struct vhost_work *work)
{
+ struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
+ poll.work);
+ struct vhost_scsi *vs = container_of(vq->dev, struct vhost_scsi, dev);
+
pr_debug("%s: The handling func for control queue.\n", __func__);
+ vhost_scsi_ctl_handle_vq(vs, vq);
}
static void
@@ -1211,7 +1441,7 @@
se_tpg = &tpg->se_tpg;
ret = target_depend_item(&se_tpg->tpg_group.cg_item);
if (ret) {
- pr_warn("configfs_depend_item() failed: %d\n", ret);
+ pr_warn("target_depend_item() failed: %d\n", ret);
kfree(vs_tpg);
mutex_unlock(&tpg->tv_tpg_mutex);
goto out;
@@ -1219,7 +1449,6 @@
tpg->tv_tpg_vhost_count++;
tpg->vhost_scsi = vs;
vs_tpg[tpg->tport_tpgt] = tpg;
- smp_mb__after_atomic();
match = true;
}
mutex_unlock(&tpg->tv_tpg_mutex);
@@ -1398,7 +1627,8 @@
vqs[i] = &vs->vqs[i].vq;
vs->vqs[i].vq.handle_kick = vhost_scsi_handle_kick;
}
- vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ);
+ vhost_dev_init(&vs->dev, vqs, VHOST_SCSI_MAX_VQ, UIO_MAXIOV,
+ VHOST_SCSI_WEIGHT, 0);
vhost_scsi_init_inflight(vs, NULL);
@@ -2059,8 +2289,7 @@
static const struct target_core_fabric_ops vhost_scsi_ops = {
.module = THIS_MODULE,
- .name = "vhost",
- .get_fabric_name = vhost_scsi_get_fabric_name,
+ .fabric_name = "vhost",
.tpg_get_wwn = vhost_scsi_get_fabric_wwn,
.tpg_get_tag = vhost_scsi_get_tpgt,
.tpg_check_demo_mode = vhost_scsi_check_true,
@@ -2074,7 +2303,6 @@
.sess_get_index = vhost_scsi_sess_get_index,
.sess_get_initiator_sid = NULL,
.write_pending = vhost_scsi_write_pending,
- .write_pending_status = vhost_scsi_write_pending_status,
.set_default_node_attributes = vhost_scsi_set_default_node_attrs,
.get_cmd_state = vhost_scsi_get_cmd_state,
.queue_data_in = vhost_scsi_queue_data_in,
diff --git a/drivers/vhost/test.c b/drivers/vhost/test.c
index 4058985..0563080 100644
--- a/drivers/vhost/test.c
+++ b/drivers/vhost/test.c
@@ -1,8 +1,7 @@
+// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (C) 2009 Red Hat, Inc.
* Author: Michael S. Tsirkin <mst@redhat.com>
*
- * This work is licensed under the terms of the GNU GPL, version 2.
- *
* test virtio server in host kernel.
*/
@@ -23,6 +22,12 @@
* Using this limit prevents one virtqueue from starving others. */
#define VHOST_TEST_WEIGHT 0x80000
+/* Max number of packets transferred before requeueing the job.
+ * Using this limit prevents one virtqueue from starving others with
+ * small pkts.
+ */
+#define VHOST_TEST_PKT_WEIGHT 256
+
enum {
VHOST_TEST_VQ = 0,
VHOST_TEST_VQ_MAX = 1,
@@ -81,10 +86,8 @@
}
vhost_add_used_and_signal(&n->dev, vq, head, 0);
total_len += len;
- if (unlikely(total_len >= VHOST_TEST_WEIGHT)) {
- vhost_poll_queue(&vq->poll);
+ if (unlikely(vhost_exceeds_weight(vq, 0, total_len)))
break;
- }
}
mutex_unlock(&vq->mutex);
@@ -116,7 +119,8 @@
dev = &n->dev;
vqs[VHOST_TEST_VQ] = &n->vqs[VHOST_TEST_VQ];
n->vqs[VHOST_TEST_VQ].handle_kick = handle_vq_kick;
- vhost_dev_init(dev, vqs, VHOST_TEST_VQ_MAX);
+ vhost_dev_init(dev, vqs, VHOST_TEST_VQ_MAX, UIO_MAXIOV,
+ VHOST_TEST_PKT_WEIGHT, VHOST_TEST_WEIGHT);
f->private_data = n;
@@ -157,6 +161,7 @@
vhost_test_stop(n, &private);
vhost_test_flush(n);
+ vhost_dev_stop(&n->dev);
vhost_dev_cleanup(&n->dev);
/* We do an extra flush before freeing memory,
* since jobs can re-queue themselves. */
@@ -233,6 +238,7 @@
}
vhost_test_stop(n, &priv);
vhost_test_flush(n);
+ vhost_dev_stop(&n->dev);
vhost_dev_reset_owner(&n->dev, umem);
done:
mutex_unlock(&n->dev.mutex);
diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c
index eb95daa..36ca2cf 100644
--- a/drivers/vhost/vhost.c
+++ b/drivers/vhost/vhost.c
@@ -1,3 +1,4 @@
+// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (C) 2009 Red Hat, Inc.
* Copyright (C) 2006 Rusty Russell IBM Corporation
*
@@ -6,8 +7,6 @@
* Inspiration, some code, and most witty comments come from
* Documentation/virtual/lguest/lguest.c, by Rusty Russell
*
- * This work is licensed under the terms of the GNU GPL, version 2.
- *
* Generic code for virtio server in host kernel.
*/
@@ -204,7 +203,6 @@
int vhost_poll_start(struct vhost_poll *poll, struct file *file)
{
__poll_t mask;
- int ret = 0;
if (poll->wqh)
return 0;
@@ -214,10 +212,10 @@
vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask));
if (mask & EPOLLERR) {
vhost_poll_stop(poll);
- ret = -EINVAL;
+ return -EINVAL;
}
- return ret;
+ return 0;
}
EXPORT_SYMBOL_GPL(vhost_poll_start);
@@ -390,9 +388,9 @@
vq->indirect = kmalloc_array(UIO_MAXIOV,
sizeof(*vq->indirect),
GFP_KERNEL);
- vq->log = kmalloc_array(UIO_MAXIOV, sizeof(*vq->log),
+ vq->log = kmalloc_array(dev->iov_limit, sizeof(*vq->log),
GFP_KERNEL);
- vq->heads = kmalloc_array(UIO_MAXIOV, sizeof(*vq->heads),
+ vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
GFP_KERNEL);
if (!vq->indirect || !vq->log || !vq->heads)
goto err_nomem;
@@ -413,8 +411,50 @@
vhost_vq_free_iovecs(dev->vqs[i]);
}
+/* Decide whether a virtqueue handler has used up its budget.
+ *
+ * @vq:        virtqueue being serviced
+ * @pkts:      number of packets handled so far
+ * @total_len: number of bytes handled so far
+ *
+ * The budget is exceeded when @pkts reaches dev->weight, or when
+ * dev->byte_weight is non-zero and @total_len reaches it. In that case the
+ * virtqueue's poll work is queued again and true is returned so that the
+ * caller can leave its loop; otherwise false is returned.
+ */
+bool vhost_exceeds_weight(struct vhost_virtqueue *vq,
+			  int pkts, int total_len)
+{
+	struct vhost_dev *owner = vq->dev;
+	bool bytes_spent = owner->byte_weight &&
+			   total_len >= owner->byte_weight;
+	bool pkts_spent = pkts >= owner->weight;
+
+	if (!bytes_spent && !pkts_spent)
+		return false;
+
+	vhost_poll_queue(&vq->poll);
+	return true;
+}
+EXPORT_SYMBOL_GPL(vhost_exceeds_weight);
+
+/* Size in bytes of an avail ring holding @num entries.
+ *
+ * This is the avail header plus @num ring slots, plus 2 trailing bytes
+ * when VIRTIO_RING_F_EVENT_IDX has been negotiated on @vq.
+ */
+static size_t vhost_get_avail_size(struct vhost_virtqueue *vq,
+				   unsigned int num)
+{
+	size_t bytes = sizeof(*vq->avail) + sizeof(*vq->avail->ring) * num;
+
+	if (vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX))
+		bytes += 2;
+
+	return bytes;
+}
+
+/* Size in bytes of a used ring holding @num entries.
+ *
+ * This is the used header plus @num ring elements, plus 2 trailing bytes
+ * when VIRTIO_RING_F_EVENT_IDX has been negotiated on @vq.
+ */
+static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
+				  unsigned int num)
+{
+	size_t bytes = sizeof(*vq->used) + sizeof(*vq->used->ring) * num;
+
+	if (vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX))
+		bytes += 2;
+
+	return bytes;
+}
+
+/* Size in bytes of a descriptor table with @num entries. */
+static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
+ unsigned int num)
+{
+ return sizeof(*vq->desc) * num;
+}
+
void vhost_dev_init(struct vhost_dev *dev,
- struct vhost_virtqueue **vqs, int nvqs)
+ struct vhost_virtqueue **vqs, int nvqs,
+ int iov_limit, int weight, int byte_weight)
{
struct vhost_virtqueue *vq;
int i;
@@ -427,6 +467,9 @@
dev->iotlb = NULL;
dev->mm = NULL;
dev->worker = NULL;
+ dev->iov_limit = iov_limit;
+ dev->weight = weight;
+ dev->byte_weight = byte_weight;
init_llist_head(&dev->work_list);
init_waitqueue_head(&dev->wait);
INIT_LIST_HEAD(&dev->read_list);
@@ -655,7 +698,7 @@
a + (unsigned long)log_base > ULONG_MAX)
return false;
- return access_ok(VERIFY_WRITE, log_base + a,
+ return access_ok(log_base + a,
(sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
}
@@ -681,7 +724,7 @@
return false;
- if (!access_ok(VERIFY_WRITE, (void __user *)a,
+ if (!access_ok((void __user *)a,
node->size))
return false;
else if (log_all && !log_access_ok(log_base,
@@ -868,6 +911,34 @@
ret; \
})
+/* Store vq->avail_idx, converted with cpu_to_vhost16(), at the location
+ * returned by vhost_avail_event(). Returns the result of vhost_put_user().
+ */
+static inline int vhost_put_avail_event(struct vhost_virtqueue *vq)
+{
+ return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
+ vhost_avail_event(vq));
+}
+
+/* Copy @count used elements starting at @head into the used ring, beginning
+ * at ring slot @idx. Returns the result of vhost_copy_to_user().
+ */
+static inline int vhost_put_used(struct vhost_virtqueue *vq,
+ struct vring_used_elem *head, int idx,
+ int count)
+{
+ return vhost_copy_to_user(vq, vq->used->ring + idx, head,
+ count * sizeof(*head));
+}
+
+/* Store vq->used_flags, converted with cpu_to_vhost16(), into the flags
+ * field of the used ring. Returns the result of vhost_put_user().
+ */
+static inline int vhost_put_used_flags(struct vhost_virtqueue *vq)
+
+{
+ return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
+ &vq->used->flags);
+}
+
+/* Store vq->last_used_idx, converted with cpu_to_vhost16(), into the idx
+ * field of the used ring. Returns the result of vhost_put_user().
+ */
+static inline int vhost_put_used_idx(struct vhost_virtqueue *vq)
+
+{
+ return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
+ &vq->used->idx);
+}
+
#define vhost_get_user(vq, x, ptr, type) \
({ \
int ret; \
@@ -906,12 +977,53 @@
mutex_unlock(&d->vqs[i]->mutex);
}
+/* Read the idx field of the avail ring into *@idx through vhost_get_avail().
+ * The value is handed back as a __virtio16; no byte-order conversion is done
+ * in this helper.
+ */
+static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
+ __virtio16 *idx)
+{
+ return vhost_get_avail(vq, *idx, &vq->avail->idx);
+}
+
+/* Read one avail ring entry into *@head through vhost_get_avail().
+ * @idx is reduced to a ring slot with "& (vq->num - 1)", which relies on
+ * vq->num being a power of two.
+ */
+static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
+ __virtio16 *head, int idx)
+{
+ return vhost_get_avail(vq, *head,
+ &vq->avail->ring[idx & (vq->num - 1)]);
+}
+
+/* Read the flags field of the avail ring into *@flags through
+ * vhost_get_avail().
+ */
+static inline int vhost_get_avail_flags(struct vhost_virtqueue *vq,
+ __virtio16 *flags)
+{
+ return vhost_get_avail(vq, *flags, &vq->avail->flags);
+}
+
+/* Read the value at the location returned by vhost_used_event() into
+ * *@event. The read goes through the vhost_get_avail() accessor.
+ */
+static inline int vhost_get_used_event(struct vhost_virtqueue *vq,
+ __virtio16 *event)
+{
+ return vhost_get_avail(vq, *event, vhost_used_event(vq));
+}
+
+/* Read the idx field of the used ring into *@idx through vhost_get_used(). */
+static inline int vhost_get_used_idx(struct vhost_virtqueue *vq,
+ __virtio16 *idx)
+{
+ return vhost_get_used(vq, *idx, &vq->used->idx);
+}
+
+/* Copy descriptor number @idx from the descriptor table into *@desc.
+ * Returns the result of vhost_copy_from_user().
+ */
+static inline int vhost_get_desc(struct vhost_virtqueue *vq,
+ struct vring_desc *desc, int idx)
+{
+ return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc));
+}
+
static int vhost_new_umem_range(struct vhost_umem *umem,
u64 start, u64 size, u64 end,
u64 userspace_addr, int perm)
{
- struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC);
+ struct vhost_umem_node *tmp, *node;
+ if (!size)
+ return -EFAULT;
+
+ node = kmalloc(sizeof(*node), GFP_ATOMIC);
if (!node)
return -ENOMEM;
@@ -973,10 +1085,10 @@
return false;
if ((access & VHOST_ACCESS_RO) &&
- !access_ok(VERIFY_READ, (void __user *)a, size))
+ !access_ok((void __user *)a, size))
return false;
if ((access & VHOST_ACCESS_WO) &&
- !access_ok(VERIFY_WRITE, (void __user *)a, size))
+ !access_ok((void __user *)a, size))
return false;
return true;
}
@@ -1034,8 +1146,10 @@
int type, ret;
ret = copy_from_iter(&type, sizeof(type), from);
- if (ret != sizeof(type))
+ if (ret != sizeof(type)) {
+ ret = -EINVAL;
goto done;
+ }
switch (type) {
case VHOST_IOTLB_MSG:
@@ -1054,8 +1168,10 @@
iov_iter_advance(from, offset);
ret = copy_from_iter(&msg, sizeof(msg), from);
- if (ret != sizeof(msg))
+ if (ret != sizeof(msg)) {
+ ret = -EINVAL;
goto done;
+ }
if (vhost_process_iotlb_msg(dev, &msg)) {
ret = -EFAULT;
goto done;
@@ -1183,13 +1299,9 @@
struct vring_used __user *used)
{
- size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
-
- return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
- access_ok(VERIFY_READ, avail,
- sizeof *avail + num * sizeof *avail->ring + s) &&
- access_ok(VERIFY_WRITE, used,
- sizeof *used + num * sizeof *used->ring + s);
+ return access_ok(desc, vhost_get_desc_size(vq, num)) &&
+ access_ok(avail, vhost_get_avail_size(vq, num)) &&
+ access_ok(used, vhost_get_used_size(vq, num));
}
static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
@@ -1239,26 +1351,22 @@
return true;
}
-int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
+int vq_meta_prefetch(struct vhost_virtqueue *vq)
{
- size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
unsigned int num = vq->num;
if (!vq->iotlb)
return 1;
return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
- num * sizeof(*vq->desc), VHOST_ADDR_DESC) &&
+ vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) &&
iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
- sizeof *vq->avail +
- num * sizeof(*vq->avail->ring) + s,
+ vhost_get_avail_size(vq, num),
VHOST_ADDR_AVAIL) &&
iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
- sizeof *vq->used +
- num * sizeof(*vq->used->ring) + s,
- VHOST_ADDR_USED);
+ vhost_get_used_size(vq, num), VHOST_ADDR_USED);
}
-EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
+EXPORT_SYMBOL_GPL(vq_meta_prefetch);
/* Can we log writes? */
/* Caller should have device mutex but not vq mutex */
@@ -1273,13 +1381,10 @@
static bool vq_log_access_ok(struct vhost_virtqueue *vq,
void __user *log_base)
{
- size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
-
return vq_memory_access_ok(log_base, vq->umem,
vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
(!vq->log_used || log_access_ok(log_base, vq->log_addr,
- sizeof *vq->used +
- vq->num * sizeof *vq->used->ring + s));
+ vhost_get_used_size(vq, vq->num)));
}
/* Can we start vq? */
@@ -1379,6 +1484,104 @@
return -EFAULT;
}
+/* Handle VHOST_SET_VRING_NUM: take the ring size for @vq from user space.
+ *
+ * @d is not used. @argp points to a struct vhost_vring_state whose num
+ * field carries the requested size.
+ *
+ * Returns 0 on success, -EBUSY when vq->private_data is set, -EFAULT when
+ * the argument cannot be copied in, and -EINVAL when the size is zero,
+ * above 0xffff, or not a power of two.
+ */
+static long vhost_vring_set_num(struct vhost_dev *d,
+				struct vhost_virtqueue *vq,
+				void __user *argp)
+{
+	struct vhost_vring_state state;
+	bool pow2;
+
+	/* The ring must not be resized while a backend is active. */
+	if (vq->private_data)
+		return -EBUSY;
+
+	if (copy_from_user(&state, argp, sizeof(state)))
+		return -EFAULT;
+
+	pow2 = (state.num & (state.num - 1)) == 0;
+	if (state.num == 0 || state.num > 0xffff || !pow2)
+		return -EINVAL;
+
+	vq->num = state.num;
+	return 0;
+}
+
+/* Handle VHOST_SET_VRING_ADDR: validate and install the ring addresses of
+ * @vq.
+ *
+ * @d is not used. @argp points to a struct vhost_vring_addr supplied by
+ * user space.
+ *
+ * Returns 0 on success, -EFAULT when the argument cannot be copied in or an
+ * address does not fit in an unsigned long, -EOPNOTSUPP when a flag other
+ * than VHOST_VRING_F_LOG is set, and -EINVAL when an address is misaligned
+ * or, with a backend configured, an access check fails.
+ */
+static long vhost_vring_set_addr(struct vhost_dev *d,
+				 struct vhost_virtqueue *vq,
+				 void __user *argp)
+{
+	struct vhost_vring_addr a;
+
+	if (copy_from_user(&a, argp, sizeof a))
+		return -EFAULT;
+	if (a.flags & ~(0x1 << VHOST_VRING_F_LOG))
+		return -EOPNOTSUPP;
+
+	/* For 32bit, verify that the top 32bits of the user
+	   data are set to zero. */
+	if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
+	    (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
+	    (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr)
+		return -EFAULT;
+
+	/* Make sure it's safe to cast pointers to vring types. */
+	BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
+	BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
+	if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
+	    (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
+	    (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1)))
+		return -EINVAL;
+
+	/* We only verify access here if backend is configured.
+	 * If it is not, we don't as size might not have been setup.
+	 * We will verify when backend is configured. */
+	if (vq->private_data) {
+		if (!vq_access_ok(vq, vq->num,
+			(void __user *)(unsigned long)a.desc_user_addr,
+			(void __user *)(unsigned long)a.avail_user_addr,
+			(void __user *)(unsigned long)a.used_user_addr))
+			return -EINVAL;
+
+		/* Also validate log access for used ring if enabled.
+		 * The size comes from vhost_get_used_size() so that the
+		 * bytes added for VIRTIO_RING_F_EVENT_IDX are covered, as
+		 * they are in vq_access_ok() and vq_log_access_ok().
+		 */
+		if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
+		    !log_access_ok(vq->log_base, a.log_guest_addr,
+				   vhost_get_used_size(vq, vq->num)))
+			return -EINVAL;
+	}
+
+	vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
+	vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
+	vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
+	vq->log_addr = a.log_guest_addr;
+	vq->used = (void __user *)(unsigned long)a.used_user_addr;
+
+	return 0;
+}
+
+/* Run the VHOST_SET_VRING_NUM or VHOST_SET_VRING_ADDR handler for @vq with
+ * vq->mutex held for the duration of the call.
+ *
+ * @ioctl must be one of those two commands; anything else hits BUG().
+ * Returns whatever the selected handler returns.
+ */
+static long vhost_vring_set_num_addr(struct vhost_dev *d,
+				     struct vhost_virtqueue *vq,
+				     unsigned int ioctl,
+				     void __user *argp)
+{
+	long r;
+
+	mutex_lock(&vq->mutex);
+
+	if (ioctl == VHOST_SET_VRING_NUM)
+		r = vhost_vring_set_num(d, vq, argp);
+	else if (ioctl == VHOST_SET_VRING_ADDR)
+		r = vhost_vring_set_addr(d, vq, argp);
+	else
+		BUG();
+
+	mutex_unlock(&vq->mutex);
+
+	return r;
+}
long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
{
struct file *eventfp, *filep = NULL;
@@ -1388,7 +1591,6 @@
struct vhost_virtqueue *vq;
struct vhost_vring_state s;
struct vhost_vring_file f;
- struct vhost_vring_addr a;
u32 idx;
long r;
@@ -1401,26 +1603,14 @@
idx = array_index_nospec(idx, d->nvqs);
vq = d->vqs[idx];
+ if (ioctl == VHOST_SET_VRING_NUM ||
+ ioctl == VHOST_SET_VRING_ADDR) {
+ return vhost_vring_set_num_addr(d, vq, ioctl, argp);
+ }
+
mutex_lock(&vq->mutex);
switch (ioctl) {
- case VHOST_SET_VRING_NUM:
- /* Resizing ring with an active backend?
- * You don't want to do that. */
- if (vq->private_data) {
- r = -EBUSY;
- break;
- }
- if (copy_from_user(&s, argp, sizeof s)) {
- r = -EFAULT;
- break;
- }
- if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) {
- r = -EINVAL;
- break;
- }
- vq->num = s.num;
- break;
case VHOST_SET_VRING_BASE:
/* Moving base with an active backend?
* You don't want to do that. */
@@ -1446,62 +1636,6 @@
if (copy_to_user(argp, &s, sizeof s))
r = -EFAULT;
break;
- case VHOST_SET_VRING_ADDR:
- if (copy_from_user(&a, argp, sizeof a)) {
- r = -EFAULT;
- break;
- }
- if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) {
- r = -EOPNOTSUPP;
- break;
- }
- /* For 32bit, verify that the top 32bits of the user
- data are set to zero. */
- if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
- (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
- (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr) {
- r = -EFAULT;
- break;
- }
-
- /* Make sure it's safe to cast pointers to vring types. */
- BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
- BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
- if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
- (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
- (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1))) {
- r = -EINVAL;
- break;
- }
-
- /* We only verify access here if backend is configured.
- * If it is not, we don't as size might not have been setup.
- * We will verify when backend is configured. */
- if (vq->private_data) {
- if (!vq_access_ok(vq, vq->num,
- (void __user *)(unsigned long)a.desc_user_addr,
- (void __user *)(unsigned long)a.avail_user_addr,
- (void __user *)(unsigned long)a.used_user_addr)) {
- r = -EINVAL;
- break;
- }
-
- /* Also validate log access for used ring if enabled. */
- if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
- !log_access_ok(vq->log_base, a.log_guest_addr,
- sizeof *vq->used +
- vq->num * sizeof *vq->used->ring)) {
- r = -EINVAL;
- break;
- }
- }
-
- vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
- vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
- vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
- vq->log_addr = a.log_guest_addr;
- vq->used = (void __user *)(unsigned long)a.used_user_addr;
- break;
case VHOST_SET_VRING_KICK:
if (copy_from_user(&f, argp, sizeof f)) {
r = -EFAULT;
@@ -1685,7 +1819,7 @@
/* TODO: This is really inefficient. We need something like get_user()
* (instruction directly accesses the data, with an exception table entry
- * returning -EFAULT). See Documentation/x86/exception-tables.txt.
+ * returning -EFAULT). See Documentation/x86/exception-tables.rst.
*/
static int set_bit_to_user(int nr, void __user *addr)
{
@@ -1695,7 +1829,7 @@
int bit = nr + (log % PAGE_SIZE) * 8;
int r;
- r = get_user_pages_fast(log, 1, 1, &page);
+ r = get_user_pages_fast(log, 1, FOLL_WRITE, &page);
if (r < 0)
return r;
BUG_ON(r != 1);
@@ -1733,13 +1867,87 @@
return r;
}
+/* Log a write to the host virtual address range of @len bytes at @hva.
+ *
+ * Each pass of the outer loop walks every node on vq->umem->umem_list. For
+ * every node whose userspace range overlaps the remaining range, the overlap
+ * is rebased onto the node's start address and passed to log_write() against
+ * vq->log_base. The remaining range is then advanced by the shortest overlap
+ * seen in that pass.
+ *
+ * Returns 0 on success, a negative value from log_write(), or -EFAULT when
+ * no node has overlapped the range.
+ *
+ * NOTE(review): 'hit' is never cleared between passes, so -EFAULT can only
+ * be returned from the first pass - confirm that this is intended.
+ */
+static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
+{
+ struct vhost_umem *umem = vq->umem;
+ struct vhost_umem_node *u;
+ /* 'min' the variable coexists with the min() macro used below. */
+ u64 start, end, l, min;
+ int r;
+ bool hit = false;
+
+ while (len) {
+ min = len;
+ /* More than one GPA can be mapped into a single HVA, so
+ * iterate over all umem nodes here to be safe.
+ */
+ list_for_each_entry(u, &umem->umem_list, link) {
+ /* Skip nodes lying entirely above or below the range. */
+ if (u->userspace_addr > hva - 1 + len ||
+ u->userspace_addr - 1 + u->size < hva)
+ continue;
+ /* Inclusive bounds of the overlap. */
+ start = max(u->userspace_addr, hva);
+ end = min(u->userspace_addr - 1 + u->size,
+ hva - 1 + len);
+ l = end - start + 1;
+ r = log_write(vq->log_base,
+ u->start + start - u->userspace_addr,
+ l);
+ if (r < 0)
+ return r;
+ hit = true;
+ min = min(l, min);
+ }
+
+ if (!hit)
+ return -EFAULT;
+
+ len -= min;
+ hva += min;
+ }
+
+ return 0;
+}
+
+/* Log a write of @len bytes at byte offset @used_offset inside the used
+ * ring.
+ *
+ * Without an IOTLB the write is logged directly at
+ * vq->log_addr + @used_offset. With an IOTLB the range starting at
+ * vq->used + @used_offset is translated into at most 64 iovecs and every
+ * one of them is logged through log_write_hva().
+ *
+ * Returns 0 on success, a negative value from translate_desc(), or the
+ * first non-zero status from log_write() / log_write_hva().
+ */
+static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
+{
+	struct iovec iov[64];
+	int i, nr_iov, ret;
+
+	if (!vq->iotlb)
+		return log_write(vq->log_base, vq->log_addr + used_offset, len);
+
+	nr_iov = translate_desc(vq, (uintptr_t)vq->used + used_offset,
+				len, iov, 64, VHOST_ACCESS_WO);
+	if (nr_iov < 0)
+		return nr_iov;
+
+	/* Keep the iovec count apart from the per-iovec status: sharing one
+	 * variable would end the loop after the first successful call and
+	 * leave the remaining segments unlogged.
+	 */
+	for (i = 0; i < nr_iov; i++) {
+		ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+				    iov[i].iov_len);
+		if (ret)
+			return ret;
+	}
+
+	return 0;
+}
+
int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
- unsigned int log_num, u64 len)
+ unsigned int log_num, u64 len, struct iovec *iov, int count)
{
int i, r;
/* Make sure data written is seen before log. */
smp_wmb();
+
+ if (vq->iotlb) {
+ for (i = 0; i < count; i++) {
+ r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+ iov[i].iov_len);
+ if (r < 0)
+ return r;
+ }
+ return 0;
+ }
+
for (i = 0; i < log_num; ++i) {
u64 l = min(log[i].len, len);
r = log_write(vq->log_base, log[i].addr, l);
@@ -1761,17 +1969,15 @@
static int vhost_update_used_flags(struct vhost_virtqueue *vq)
{
void __user *used;
- if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
- &vq->used->flags) < 0)
+ if (vhost_put_used_flags(vq))
return -EFAULT;
if (unlikely(vq->log_used)) {
/* Make sure the flag is seen before log. */
smp_wmb();
/* Log used flag write. */
used = &vq->used->flags;
- log_write(vq->log_base, vq->log_addr +
- (used - (void __user *)vq->used),
- sizeof vq->used->flags);
+ log_used(vq, (used - (void __user *)vq->used),
+ sizeof vq->used->flags);
if (vq->log_ctx)
eventfd_signal(vq->log_ctx, 1);
}
@@ -1780,8 +1986,7 @@
static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
{
- if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
- vhost_avail_event(vq)))
+ if (vhost_put_avail_event(vq))
return -EFAULT;
if (unlikely(vq->log_used)) {
void __user *used;
@@ -1789,9 +1994,8 @@
smp_wmb();
/* Log avail event write */
used = vhost_avail_event(vq);
- log_write(vq->log_base, vq->log_addr +
- (used - (void __user *)vq->used),
- sizeof *vhost_avail_event(vq));
+ log_used(vq, (used - (void __user *)vq->used),
+ sizeof *vhost_avail_event(vq));
if (vq->log_ctx)
eventfd_signal(vq->log_ctx, 1);
}
@@ -1814,11 +2018,11 @@
goto err;
vq->signalled_used_valid = false;
if (!vq->iotlb &&
- !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) {
+ !access_ok(&vq->used->idx, sizeof vq->used->idx)) {
r = -EFAULT;
goto err;
}
- r = vhost_get_used(vq, last_used_idx, &vq->used->idx);
+ r = vhost_get_used_idx(vq, &last_used_idx);
if (r) {
vq_err(vq, "Can't access used idx at %p\n",
&vq->used->idx);
@@ -1974,7 +2178,7 @@
/* If this is an input descriptor, increment that count. */
if (access == VHOST_ACCESS_WO) {
*in_num += ret;
- if (unlikely(log)) {
+ if (unlikely(log && ret)) {
log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
log[*log_num].len = vhost32_to_cpu(vq, desc.len);
++*log_num;
@@ -2017,7 +2221,7 @@
last_avail_idx = vq->last_avail_idx;
if (vq->avail_idx == vq->last_avail_idx) {
- if (unlikely(vhost_get_avail(vq, avail_idx, &vq->avail->idx))) {
+ if (unlikely(vhost_get_avail_idx(vq, &avail_idx))) {
vq_err(vq, "Failed to access avail idx at %p\n",
&vq->avail->idx);
return -EFAULT;
@@ -2044,8 +2248,7 @@
/* Grab the next descriptor number they're advertising, and increment
* the index we've seen. */
- if (unlikely(vhost_get_avail(vq, ring_head,
- &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
+ if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) {
vq_err(vq, "Failed to read head: idx %d address %p\n",
last_avail_idx,
&vq->avail->ring[last_avail_idx % vq->num]);
@@ -2080,8 +2283,7 @@
i, vq->num, head);
return -EINVAL;
}
- ret = vhost_copy_from_user(vq, &desc, vq->desc + i,
- sizeof desc);
+ ret = vhost_get_desc(vq, &desc, i);
if (unlikely(ret)) {
vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
i, vq->desc + i);
@@ -2117,7 +2319,7 @@
/* If this is an input descriptor,
* increment that count. */
*in_num += ret;
- if (unlikely(log)) {
+ if (unlikely(log && ret)) {
log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
log[*log_num].len = vhost32_to_cpu(vq, desc.len);
++*log_num;
@@ -2174,16 +2376,7 @@
start = vq->last_used_idx & (vq->num - 1);
used = vq->used->ring + start;
- if (count == 1) {
- if (vhost_put_user(vq, heads[0].id, &used->id)) {
- vq_err(vq, "Failed to write used id");
- return -EFAULT;
- }
- if (vhost_put_user(vq, heads[0].len, &used->len)) {
- vq_err(vq, "Failed to write used len");
- return -EFAULT;
- }
- } else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) {
+ if (vhost_put_used(vq, heads, start, count)) {
vq_err(vq, "Failed to write used");
return -EFAULT;
}
@@ -2191,10 +2384,8 @@
/* Make sure data is seen before log. */
smp_wmb();
/* Log used ring entry write. */
- log_write(vq->log_base,
- vq->log_addr +
- ((void __user *)used - (void __user *)vq->used),
- count * sizeof *used);
+ log_used(vq, ((void __user *)used - (void __user *)vq->used),
+ count * sizeof *used);
}
old = vq->last_used_idx;
new = (vq->last_used_idx += count);
@@ -2227,16 +2418,16 @@
/* Make sure buffer is written before we update index. */
smp_wmb();
- if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
- &vq->used->idx)) {
+ if (vhost_put_used_idx(vq)) {
vq_err(vq, "Failed to increment used idx");
return -EFAULT;
}
if (unlikely(vq->log_used)) {
+ /* Make sure used idx is seen before log. */
+ smp_wmb();
/* Log used index update. */
- log_write(vq->log_base,
- vq->log_addr + offsetof(struct vring_used, idx),
- sizeof vq->used->idx);
+ log_used(vq, offsetof(struct vring_used, idx),
+ sizeof vq->used->idx);
if (vq->log_ctx)
eventfd_signal(vq->log_ctx, 1);
}
@@ -2260,7 +2451,7 @@
if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
__virtio16 flags;
- if (vhost_get_avail(vq, flags, &vq->avail->flags)) {
+ if (vhost_get_avail_flags(vq, &flags)) {
vq_err(vq, "Failed to get flags");
return true;
}
@@ -2274,7 +2465,7 @@
if (unlikely(!v))
return true;
- if (vhost_get_avail(vq, event, vhost_used_event(vq))) {
+ if (vhost_get_used_event(vq, &event)) {
vq_err(vq, "Failed to get used event idx");
return true;
}
@@ -2319,7 +2510,7 @@
if (vq->avail_idx != vq->last_avail_idx)
return false;
- r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
+ r = vhost_get_avail_idx(vq, &avail_idx);
if (unlikely(r))
return false;
vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
@@ -2355,7 +2546,7 @@
/* They could have slipped one in as we were doing that: make
* sure it's written, then check again. */
smp_mb();
- r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
+ r = vhost_get_avail_idx(vq, &avail_idx);
if (r) {
vq_err(vq, "Failed to check avail idx at %p: %d\n",
&vq->avail->idx, r);
diff --git a/drivers/vhost/vhost.h b/drivers/vhost/vhost.h
index 466ef75..e9ed272 100644
--- a/drivers/vhost/vhost.h
+++ b/drivers/vhost/vhost.h
@@ -170,9 +170,14 @@
struct list_head read_list;
struct list_head pending_list;
wait_queue_head_t wait;
+ int iov_limit;
+ int weight;
+ int byte_weight;
};
-void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs, int nvqs);
+bool vhost_exceeds_weight(struct vhost_virtqueue *vq, int pkts, int total_len);
+void vhost_dev_init(struct vhost_dev *, struct vhost_virtqueue **vqs,
+ int nvqs, int iov_limit, int weight, int byte_weight);
long vhost_dev_set_owner(struct vhost_dev *dev);
bool vhost_dev_has_owner(struct vhost_dev *dev);
long vhost_dev_check_owner(struct vhost_dev *);
@@ -205,8 +210,9 @@
bool vhost_enable_notify(struct vhost_dev *, struct vhost_virtqueue *);
int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
- unsigned int log_num, u64 len);
-int vq_iotlb_prefetch(struct vhost_virtqueue *vq);
+ unsigned int log_num, u64 len,
+ struct iovec *iov, int count);
+int vq_meta_prefetch(struct vhost_virtqueue *vq);
struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type);
void vhost_enqueue_msg(struct vhost_dev *dev,
diff --git a/drivers/vhost/vringh.c b/drivers/vhost/vringh.c
index a94d700..a0a2d74 100644
--- a/drivers/vhost/vringh.c
+++ b/drivers/vhost/vringh.c
@@ -1,3 +1,4 @@
+// SPDX-License-Identifier: GPL-2.0-only
/*
* Helpers for the host side of a virtio ring.
*
@@ -851,6 +852,12 @@
return 0;
}
+/* Transfer callback for kernel-space buffers: copies @len bytes from @src to
+ * @dst with memcpy() and always returns 0.
+ */
+static inline int kern_xfer(void *dst, void *src, size_t len)
+{
+ memcpy(dst, src, len);
+ return 0;
+}
+
/**
* vringh_init_kern - initialize a vringh for a kernelspace vring.
* @vrh: the vringh to initialize.
@@ -957,7 +964,7 @@
ssize_t vringh_iov_push_kern(struct vringh_kiov *wiov,
const void *src, size_t len)
{
- return vringh_iov_xfer(wiov, (void *)src, len, xfer_kern);
+ return vringh_iov_xfer(wiov, (void *)src, len, kern_xfer);
}
EXPORT_SYMBOL(vringh_iov_push_kern);
diff --git a/drivers/vhost/vsock.c b/drivers/vhost/vsock.c
index 98ed5be..9f57736 100644
--- a/drivers/vhost/vsock.c
+++ b/drivers/vhost/vsock.c
@@ -1,11 +1,10 @@
+// SPDX-License-Identifier: GPL-2.0-only
/*
* vhost transport for vsock
*
* Copyright (C) 2013-2015 Red Hat, Inc.
* Author: Asias He <asias@redhat.com>
* Stefan Hajnoczi <stefanha@redhat.com>
- *
- * This work is licensed under the terms of the GNU GPL, version 2.
*/
#include <linux/miscdevice.h>
#include <linux/atomic.h>
@@ -21,20 +20,28 @@
#include "vhost.h"
#define VHOST_VSOCK_DEFAULT_HOST_CID 2
+/* Max number of bytes transferred before requeueing the job.
+ * Using this limit prevents one virtqueue from starving others. */
+#define VHOST_VSOCK_WEIGHT 0x80000
+/* Max number of packets transferred before requeueing the job.
+ * Using this limit prevents one virtqueue from starving others with
+ * small pkts.
+ */
+#define VHOST_VSOCK_PKT_WEIGHT 256
enum {
VHOST_VSOCK_FEATURES = VHOST_FEATURES,
};
/* Used to track all the vhost_vsock instances on the system. */
-static DEFINE_SPINLOCK(vhost_vsock_lock);
+static DEFINE_MUTEX(vhost_vsock_mutex);
static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);
struct vhost_vsock {
struct vhost_dev dev;
struct vhost_virtqueue vqs[2];
- /* Link to global vhost_vsock_hash, writes use vhost_vsock_lock */
+ /* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
struct hlist_node hash;
struct vhost_work send_pkt_work;
@@ -51,7 +58,7 @@
return VHOST_VSOCK_DEFAULT_HOST_CID;
}
-/* Callers that dereference the return value must hold vhost_vsock_lock or the
+/* Callers that dereference the return value must hold vhost_vsock_mutex or the
* RCU read lock.
*/
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
@@ -78,6 +85,7 @@
struct vhost_virtqueue *vq)
{
struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
+ int pkts = 0, total_len = 0;
bool added = false;
bool restart_tx = false;
@@ -89,12 +97,12 @@
/* Avoid further vmexits, we're already processing the virtqueue */
vhost_disable_notify(&vsock->dev, vq);
- for (;;) {
+ do {
struct virtio_vsock_pkt *pkt;
struct iov_iter iov_iter;
unsigned out, in;
size_t nbytes;
- size_t len;
+ size_t iov_len, payload_len;
int head;
spin_lock_bh(&vsock->send_pkt_list_lock);
@@ -139,8 +147,24 @@
break;
}
- len = iov_length(&vq->iov[out], in);
- iov_iter_init(&iov_iter, READ, &vq->iov[out], in, len);
+ iov_len = iov_length(&vq->iov[out], in);
+ if (iov_len < sizeof(pkt->hdr)) {
+ virtio_transport_free_pkt(pkt);
+ vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
+ break;
+ }
+
+ iov_iter_init(&iov_iter, READ, &vq->iov[out], in, iov_len);
+ payload_len = pkt->len - pkt->off;
+
+ /* If the packet is greater than the space available in the
+ * buffer, we split it using multiple buffers.
+ */
+ if (payload_len > iov_len - sizeof(pkt->hdr))
+ payload_len = iov_len - sizeof(pkt->hdr);
+
+ /* Set the correct length in the header */
+ pkt->hdr.len = cpu_to_le32(payload_len);
nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
if (nbytes != sizeof(pkt->hdr)) {
@@ -149,33 +173,48 @@
break;
}
- nbytes = copy_to_iter(pkt->buf, pkt->len, &iov_iter);
- if (nbytes != pkt->len) {
+ nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
+ &iov_iter);
+ if (nbytes != payload_len) {
virtio_transport_free_pkt(pkt);
vq_err(vq, "Faulted on copying pkt buf\n");
break;
}
- vhost_add_used(vq, head, sizeof(pkt->hdr) + pkt->len);
+ vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
added = true;
- if (pkt->reply) {
- int val;
-
- val = atomic_dec_return(&vsock->queued_replies);
-
- /* Do we have resources to resume tx processing? */
- if (val + 1 == tx_vq->num)
- restart_tx = true;
- }
-
/* Deliver to monitoring devices all correctly transmitted
* packets.
*/
virtio_transport_deliver_tap_pkt(pkt);
- virtio_transport_free_pkt(pkt);
- }
+ pkt->off += payload_len;
+ total_len += payload_len;
+
+ /* If we didn't send all the payload we can requeue the packet
+ * to send it with the next available buffer.
+ */
+ if (pkt->off < pkt->len) {
+ spin_lock_bh(&vsock->send_pkt_list_lock);
+ list_add(&pkt->list, &vsock->send_pkt_list);
+ spin_unlock_bh(&vsock->send_pkt_list_lock);
+ } else {
+ if (pkt->reply) {
+ int val;
+
+ val = atomic_dec_return(&vsock->queued_replies);
+
+ /* Do we have resources to resume tx
+ * processing?
+ */
+ if (val + 1 == tx_vq->num)
+ restart_tx = true;
+ }
+
+ virtio_transport_free_pkt(pkt);
+ }
+ } while(likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
if (added)
vhost_signal(&vsock->dev, vq);
@@ -320,6 +359,8 @@
return NULL;
}
+ pkt->buf_len = pkt->len;
+
nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
if (nbytes != pkt->len) {
vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
@@ -350,7 +391,7 @@
struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
dev);
struct virtio_vsock_pkt *pkt;
- int head;
+ int head, pkts = 0, total_len = 0;
unsigned int out, in;
bool added = false;
@@ -360,7 +401,7 @@
goto out;
vhost_disable_notify(&vsock->dev, vq);
- for (;;) {
+ do {
u32 len;
if (!vhost_vsock_more_replies(vsock)) {
@@ -401,9 +442,11 @@
else
virtio_transport_free_pkt(pkt);
- vhost_add_used(vq, head, sizeof(pkt->hdr) + len);
+ len += sizeof(pkt->hdr);
+ vhost_add_used(vq, head, len);
+ total_len += len;
added = true;
- }
+ } while(likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
no_more_replies:
if (added)
@@ -531,7 +574,9 @@
vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;
- vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs));
+ vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
+ UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
+ VHOST_VSOCK_WEIGHT);
file->private_data = vsock;
spin_lock_init(&vsock->send_pkt_list_lock);
@@ -584,10 +629,10 @@
{
struct vhost_vsock *vsock = file->private_data;
- spin_lock_bh(&vhost_vsock_lock);
+ mutex_lock(&vhost_vsock_mutex);
if (vsock->guest_cid)
hash_del_rcu(&vsock->hash);
- spin_unlock_bh(&vhost_vsock_lock);
+ mutex_unlock(&vhost_vsock_mutex);
/* Wait for other CPUs to finish using vsock */
synchronize_rcu();
@@ -631,10 +676,10 @@
return -EINVAL;
/* Refuse if CID is already in use */
- spin_lock_bh(&vhost_vsock_lock);
+ mutex_lock(&vhost_vsock_mutex);
other = vhost_vsock_get(guest_cid);
if (other && other != vsock) {
- spin_unlock_bh(&vhost_vsock_lock);
+ mutex_unlock(&vhost_vsock_mutex);
return -EADDRINUSE;
}
@@ -642,8 +687,8 @@
hash_del_rcu(&vsock->hash);
vsock->guest_cid = guest_cid;
- hash_add_rcu(vhost_vsock_hash, &vsock->hash, guest_cid);
- spin_unlock_bh(&vhost_vsock_lock);
+ hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
+ mutex_unlock(&vhost_vsock_mutex);
return 0;
}