Update Linux to v5.10.109

Sourced from the upstream kernel.org release tarball [1].

[1] https://cdn.kernel.org/pub/linux/kernel/v5.x/linux-5.10.109.tar.xz

Change-Id: I19bca9fc6762d4e63bcf3e4cba88bbe560d9c76c
Signed-off-by: Olivier Deprez <olivier.deprez@arm.com>
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 6a0c432..6b745ce 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -10,38 +10,6 @@
 #include <net/inet_common.h>
 #include <net/tls.h>
 
-static bool tcp_bpf_stream_read(const struct sock *sk)
-{
-	struct sk_psock *psock;
-	bool empty = true;
-
-	rcu_read_lock();
-	psock = sk_psock(sk);
-	if (likely(psock))
-		empty = list_empty(&psock->ingress_msg);
-	rcu_read_unlock();
-	return !empty;
-}
-
-static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
-			     int flags, long timeo, int *err)
-{
-	DEFINE_WAIT_FUNC(wait, woken_wake_function);
-	int ret = 0;
-
-	if (!timeo)
-		return ret;
-
-	add_wait_queue(sk_sleep(sk), &wait);
-	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-	ret = sk_wait_event(sk, &timeo,
-			    !list_empty(&psock->ingress_msg) ||
-			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
-	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
-	remove_wait_queue(sk_sleep(sk), &wait);
-	return ret;
-}
-
 int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
 		      struct msghdr *msg, int len, int flags)
 {
@@ -122,52 +90,6 @@
 }
 EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);
 
-int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
-		    int nonblock, int flags, int *addr_len)
-{
-	struct sk_psock *psock;
-	int copied, ret;
-
-	if (unlikely(flags & MSG_ERRQUEUE))
-		return inet_recv_error(sk, msg, len, addr_len);
-
-	psock = sk_psock_get(sk);
-	if (unlikely(!psock))
-		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
-	if (!skb_queue_empty(&sk->sk_receive_queue) &&
-	    sk_psock_queue_empty(psock)) {
-		sk_psock_put(sk, psock);
-		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
-	}
-	lock_sock(sk);
-msg_bytes_ready:
-	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
-	if (!copied) {
-		int data, err = 0;
-		long timeo;
-
-		timeo = sock_rcvtimeo(sk, nonblock);
-		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
-		if (data) {
-			if (!sk_psock_queue_empty(psock))
-				goto msg_bytes_ready;
-			release_sock(sk);
-			sk_psock_put(sk, psock);
-			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
-		}
-		if (err) {
-			ret = err;
-			goto out;
-		}
-		copied = -EAGAIN;
-	}
-	ret = copied;
-out:
-	release_sock(sk);
-	sk_psock_put(sk, psock);
-	return ret;
-}
-
 static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
 			   struct sk_msg *msg, u32 apply_bytes, int flags)
 {
@@ -307,12 +229,95 @@
 }
 EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
 
+#ifdef CONFIG_BPF_STREAM_PARSER
+static bool tcp_bpf_stream_read(const struct sock *sk)
+{
+	struct sk_psock *psock;
+	bool empty = true;
+
+	rcu_read_lock();
+	psock = sk_psock(sk);
+	if (likely(psock))
+		empty = list_empty(&psock->ingress_msg);
+	rcu_read_unlock();
+	return !empty;
+}
+
+static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
+			     int flags, long timeo, int *err)
+{
+	DEFINE_WAIT_FUNC(wait, woken_wake_function);
+	int ret = 0;
+
+	if (sk->sk_shutdown & RCV_SHUTDOWN)
+		return 1;
+
+	if (!timeo)
+		return ret;
+
+	add_wait_queue(sk_sleep(sk), &wait);
+	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+	ret = sk_wait_event(sk, &timeo,
+			    !list_empty(&psock->ingress_msg) ||
+			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
+	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+	remove_wait_queue(sk_sleep(sk), &wait);
+	return ret;
+}
+
+static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
+		    int nonblock, int flags, int *addr_len)
+{
+	struct sk_psock *psock;
+	int copied, ret;
+
+	if (unlikely(flags & MSG_ERRQUEUE))
+		return inet_recv_error(sk, msg, len, addr_len);
+
+	psock = sk_psock_get(sk);
+	if (unlikely(!psock))
+		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+	if (!skb_queue_empty(&sk->sk_receive_queue) &&
+	    sk_psock_queue_empty(psock)) {
+		sk_psock_put(sk, psock);
+		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+	}
+	lock_sock(sk);
+msg_bytes_ready:
+	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
+	if (!copied) {
+		int data, err = 0;
+		long timeo;
+
+		timeo = sock_rcvtimeo(sk, nonblock);
+		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
+		if (data) {
+			if (!sk_psock_queue_empty(psock))
+				goto msg_bytes_ready;
+			release_sock(sk);
+			sk_psock_put(sk, psock);
+			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
+		}
+		if (err) {
+			ret = err;
+			goto out;
+		}
+		copied = -EAGAIN;
+	}
+	ret = copied;
+out:
+	release_sock(sk);
+	sk_psock_put(sk, psock);
+	return ret;
+}
+
 static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
 				struct sk_msg *msg, int *copied, int flags)
 {
 	bool cork = false, enospc = sk_msg_full(msg);
 	struct sock *sk_redir;
 	u32 tosend, delta = 0;
+	u32 eval = __SK_NONE;
 	int ret;
 
 more_data:
@@ -356,13 +361,24 @@
 	case __SK_REDIRECT:
 		sk_redir = psock->sk_redir;
 		sk_msg_apply_bytes(psock, tosend);
+		if (!psock->apply_bytes) {
+			/* Clean up before releasing the sock lock. */
+			eval = psock->eval;
+			psock->eval = __SK_NONE;
+			psock->sk_redir = NULL;
+		}
 		if (psock->cork) {
 			cork = true;
 			psock->cork = NULL;
 		}
 		sk_msg_return(sk, msg, tosend);
 		release_sock(sk);
+
 		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
+
+		if (eval == __SK_REDIRECT)
+			sock_put(sk_redir);
+
 		lock_sock(sk);
 		if (unlikely(ret < 0)) {
 			int free = sk_msg_free_nocharge(sk, msg);
@@ -537,57 +553,6 @@
 	return copied ? copied : err;
 }
 
-static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
-{
-	struct sk_psock_link *link;
-
-	while ((link = sk_psock_link_pop(psock))) {
-		sk_psock_unlink(sk, link);
-		sk_psock_free_link(link);
-	}
-}
-
-static void tcp_bpf_unhash(struct sock *sk)
-{
-	void (*saved_unhash)(struct sock *sk);
-	struct sk_psock *psock;
-
-	rcu_read_lock();
-	psock = sk_psock(sk);
-	if (unlikely(!psock)) {
-		rcu_read_unlock();
-		if (sk->sk_prot->unhash)
-			sk->sk_prot->unhash(sk);
-		return;
-	}
-
-	saved_unhash = psock->saved_unhash;
-	tcp_bpf_remove(sk, psock);
-	rcu_read_unlock();
-	saved_unhash(sk);
-}
-
-static void tcp_bpf_close(struct sock *sk, long timeout)
-{
-	void (*saved_close)(struct sock *sk, long timeout);
-	struct sk_psock *psock;
-
-	lock_sock(sk);
-	rcu_read_lock();
-	psock = sk_psock(sk);
-	if (unlikely(!psock)) {
-		rcu_read_unlock();
-		release_sock(sk);
-		return sk->sk_prot->close(sk, timeout);
-	}
-
-	saved_close = psock->saved_close;
-	tcp_bpf_remove(sk, psock);
-	rcu_read_unlock();
-	release_sock(sk);
-	saved_close(sk, timeout);
-}
-
 enum {
 	TCP_BPF_IPV4,
 	TCP_BPF_IPV6,
@@ -608,8 +573,7 @@
 				   struct proto *base)
 {
 	prot[TCP_BPF_BASE]			= *base;
-	prot[TCP_BPF_BASE].unhash		= tcp_bpf_unhash;
-	prot[TCP_BPF_BASE].close		= tcp_bpf_close;
+	prot[TCP_BPF_BASE].close		= sock_map_close;
 	prot[TCP_BPF_BASE].recvmsg		= tcp_bpf_recvmsg;
 	prot[TCP_BPF_BASE].stream_memory_read	= tcp_bpf_stream_read;
 
@@ -618,10 +582,9 @@
 	prot[TCP_BPF_TX].sendpage		= tcp_bpf_sendpage;
 }
 
-static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
+static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
 {
-	if (sk->sk_family == AF_INET6 &&
-	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
+	if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
 		spin_lock_bh(&tcpv6_prot_lock);
 		if (likely(ops != tcpv6_prot_saved)) {
 			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
@@ -638,26 +601,6 @@
 }
 late_initcall(tcp_bpf_v4_build_proto);
 
-static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
-{
-	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
-	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
-
-	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
-}
-
-static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
-{
-	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
-	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
-
-	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
-	 * or added requiring sk_prot hook updates. We keep original saved
-	 * hooks in this case.
-	 */
-	sk->sk_prot = &tcp_bpf_prots[family][config];
-}
-
 static int tcp_bpf_assert_proto_ops(struct proto *ops)
 {
 	/* In order to avoid retpoline, we make assumptions when we call
@@ -669,34 +612,32 @@
 	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
 }
 
-void tcp_bpf_reinit(struct sock *sk)
+struct proto *tcp_bpf_get_proto(struct sock *sk, struct sk_psock *psock)
 {
-	struct sk_psock *psock;
+	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
+	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;
 
-	sock_owned_by_me(sk);
+	if (sk->sk_family == AF_INET6) {
+		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
+			return ERR_PTR(-EINVAL);
 
-	rcu_read_lock();
-	psock = sk_psock(sk);
-	tcp_bpf_reinit_sk_prot(sk, psock);
-	rcu_read_unlock();
-}
-
-int tcp_bpf_init(struct sock *sk)
-{
-	struct proto *ops = READ_ONCE(sk->sk_prot);
-	struct sk_psock *psock;
-
-	sock_owned_by_me(sk);
-
-	rcu_read_lock();
-	psock = sk_psock(sk);
-	if (unlikely(!psock || psock->sk_proto ||
-		     tcp_bpf_assert_proto_ops(ops))) {
-		rcu_read_unlock();
-		return -EINVAL;
+		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
 	}
-	tcp_bpf_check_v6_needs_rebuild(sk, ops);
-	tcp_bpf_update_sk_prot(sk, psock);
-	rcu_read_unlock();
-	return 0;
+
+	return &tcp_bpf_prots[family][config];
 }
+
+/* If a child got cloned from a listening socket that had tcp_bpf
+ * protocol callbacks installed, we need to restore the callbacks to
+ * the default ones because the child does not inherit the psock state
+ * that tcp_bpf callbacks expect.
+ */
+void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
+{
+	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
+	struct proto *prot = newsk->sk_prot;
+
+	if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
+		newsk->sk_prot = sk->sk_prot_creator;
+}
+#endif /* CONFIG_BPF_STREAM_PARSER */