Update Linux to v5.10.109

Sourced from [1]

[1] https://cdn.kernel.org/pub/linux/kernel/v5.x/linux-5.10.109.tar.xz
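
As visible in this hunk set for net/tls/tls_main.c, the update pulls in
the upstream TLS ULP rework:

 - TLS_TOE (TCP offload engine) handling moves out of tls_main.c: the
   device_list/device_spinlock registry and the tls_hw_* helpers are
   replaced by tls_toe_bypass(), tls_toe_hash() and tls_toe_unhash()
   from <net/tls_toe.h>, guarded by CONFIG_TLS_TOE.
 - Per-netns TLS MIB statistics are introduced (TLS_INC_STATS()/
   TLS_DEC_STATS() and the tls_proc_ops pernet subsystem); upstream
   exports them via /proc/net/tls_stat.
 - The setsockopt() path is converted from char __user * to sockptr_t,
   using copy_from_sockptr() and copy_from_sockptr_offset().
 - The single tls_sw_proto_ops template is replaced by a per-family,
   per-configuration tls_proto_ops table built by build_proto_ops(),
   and update_sk_prot() now swaps sk->sk_socket->ops as well as
   sk->sk_prot.
 - sk->sk_prot accesses are annotated with READ_ONCE()/WRITE_ONCE() to
   pair with lockless readers such as sk_clone_lock().
 - getsockopt() accepts TLS_RX in addition to TLS_TX through the common
   do_tls_getsockopt_conf() helper.

As an illustration (a minimal userspace sketch, not part of this
patch), the receive-side crypto state can now be read back much like
the transmit side:

	#include <sys/socket.h>
	#include <linux/tls.h>

	/* SOL_TLS is not exported by all libc headers. */
	#ifndef SOL_TLS
	#define SOL_TLS 282
	#endif

	/* Assumes "fd" is a TCP socket with a TLS_RX context already
	 * installed via setsockopt(fd, SOL_TLS, TLS_RX, ...). */
	struct tls12_crypto_info_aes_gcm_128 info;
	socklen_t len = sizeof(info);

	if (getsockopt(fd, SOL_TLS, TLS_RX, &info, &len) == 0) {
		/* info.iv and info.rec_seq hold the RX IV and record
		 * sequence number for AES-GCM-128 sessions. */
	}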

Change-Id: I19bca9fc6762d4e63bcf3e4cba88bbe560d9c76c
Signed-off-by: Olivier Deprez <olivier.deprez@arm.com>
---

diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c
index 7aba4ee..58d22d6 100644
--- a/net/tls/tls_main.c
+++ b/net/tls/tls_main.c
@@ -41,7 +41,9 @@
 #include <linux/inetdevice.h>
 #include <linux/inet_diag.h>
 
+#include <net/snmp.h>
 #include <net/tls.h>
+#include <net/tls_toe.h>
 
 MODULE_AUTHOR("Mellanox Technologies");
 MODULE_DESCRIPTION("Transport Layer Security Support");
@@ -54,22 +56,23 @@
 	TLS_NUM_PROTS,
 };
 
-static struct proto *saved_tcpv6_prot;
+static const struct proto *saved_tcpv6_prot;
 static DEFINE_MUTEX(tcpv6_prot_mutex);
-static struct proto *saved_tcpv4_prot;
+static const struct proto *saved_tcpv4_prot;
 static DEFINE_MUTEX(tcpv4_prot_mutex);
-static LIST_HEAD(device_list);
-static DEFINE_SPINLOCK(device_spinlock);
 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
-static struct proto_ops tls_sw_proto_ops;
+static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
-			 struct proto *base);
+			 const struct proto *base);
 
-static void update_sk_prot(struct sock *sk, struct tls_context *ctx)
+void update_sk_prot(struct sock *sk, struct tls_context *ctx)
 {
 	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
 
-	sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf];
+	WRITE_ONCE(sk->sk_prot,
+		   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
+	WRITE_ONCE(sk->sk_socket->ops,
+		   &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
 }
 
 int wait_on_pending_writer(struct sock *sk, long *timeo)
@@ -278,14 +281,19 @@
 		kfree(ctx->tx.rec_seq);
 		kfree(ctx->tx.iv);
 		tls_sw_release_resources_tx(sk);
+		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
 	} else if (ctx->tx_conf == TLS_HW) {
 		tls_device_free_resources_tx(sk);
+		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
 	}
 
-	if (ctx->rx_conf == TLS_SW)
+	if (ctx->rx_conf == TLS_SW) {
 		tls_sw_release_resources_rx(sk);
-	else if (ctx->rx_conf == TLS_HW)
+		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
+	} else if (ctx->rx_conf == TLS_HW) {
 		tls_device_offload_cleanup_rx(sk);
+		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
+	}
 }
 
 static void tls_sk_proto_close(struct sock *sk, long timeout)
@@ -307,7 +315,7 @@
 	write_lock_bh(&sk->sk_callback_lock);
 	if (free_ctx)
 		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
-	sk->sk_prot = ctx->sk_proto;
+	WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
 	if (sk->sk_write_space == tls_write_space)
 		sk->sk_write_space = ctx->sk_write_space;
 	write_unlock_bh(&sk->sk_callback_lock);
@@ -324,12 +332,13 @@
 		tls_ctx_free(sk, ctx);
 }
 
-static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
-				int __user *optlen)
+static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
+				  int __user *optlen, int tx)
 {
 	int rc = 0;
 	struct tls_context *ctx = tls_get_ctx(sk);
 	struct tls_crypto_info *crypto_info;
+	struct cipher_context *cctx;
 	int len;
 
 	if (get_user(len, optlen))
@@ -346,7 +355,13 @@
 	}
 
 	/* get user crypto info */
-	crypto_info = &ctx->crypto_send.info;
+	if (tx) {
+		crypto_info = &ctx->crypto_send.info;
+		cctx = &ctx->tx;
+	} else {
+		crypto_info = &ctx->crypto_recv.info;
+		cctx = &ctx->rx;
+	}
 
 	if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
 		rc = -EBUSY;
@@ -373,9 +388,9 @@
 		}
 		lock_sock(sk);
 		memcpy(crypto_info_aes_gcm_128->iv,
-		       ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+		       cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
 		       TLS_CIPHER_AES_GCM_128_IV_SIZE);
-		memcpy(crypto_info_aes_gcm_128->rec_seq, ctx->tx.rec_seq,
+		memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
 		       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
 		release_sock(sk);
 		if (copy_to_user(optval,
@@ -397,9 +412,9 @@
 		}
 		lock_sock(sk);
 		memcpy(crypto_info_aes_gcm_256->iv,
-		       ctx->tx.iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
+		       cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
 		       TLS_CIPHER_AES_GCM_256_IV_SIZE);
-		memcpy(crypto_info_aes_gcm_256->rec_seq, ctx->tx.rec_seq,
+		memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
 		       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
 		release_sock(sk);
 		if (copy_to_user(optval,
@@ -423,7 +438,9 @@
 
 	switch (optname) {
 	case TLS_TX:
-		rc = do_tls_getsockopt_tx(sk, optval, optlen);
+	case TLS_RX:
+		rc = do_tls_getsockopt_conf(sk, optval, optlen,
+					    optname == TLS_TX);
 		break;
 	default:
 		rc = -ENOPROTOOPT;
@@ -444,7 +461,7 @@
 	return do_tls_getsockopt(sk, optname, optval, optlen);
 }
 
-static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
+static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
 				  unsigned int optlen, int tx)
 {
 	struct tls_crypto_info *crypto_info;
@@ -454,7 +471,7 @@
 	int rc = 0;
 	int conf;
 
-	if (!optval || (optlen < sizeof(*crypto_info))) {
+	if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info))) {
 		rc = -EINVAL;
 		goto out;
 	}
@@ -473,7 +490,7 @@
 		goto out;
 	}
 
-	rc = copy_from_user(crypto_info, optval, sizeof(*crypto_info));
+	rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
 	if (rc) {
 		rc = -EFAULT;
 		goto err_crypto_info;
@@ -516,8 +533,9 @@
 		goto err_crypto_info;
 	}
 
-	rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info),
-			    optlen - sizeof(*crypto_info));
+	rc = copy_from_sockptr_offset(crypto_info + 1, optval,
+				      sizeof(*crypto_info),
+				      optlen - sizeof(*crypto_info));
 	if (rc) {
 		rc = -EFAULT;
 		goto err_crypto_info;
@@ -526,19 +544,29 @@
 	if (tx) {
 		rc = tls_set_device_offload(sk, ctx);
 		conf = TLS_HW;
-		if (rc) {
+		if (!rc) {
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
+		} else {
 			rc = tls_set_sw_offload(sk, ctx, 1);
 			if (rc)
 				goto err_crypto_info;
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
 			conf = TLS_SW;
 		}
 	} else {
 		rc = tls_set_device_offload_rx(sk, ctx);
 		conf = TLS_HW;
-		if (rc) {
+		if (!rc) {
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
+		} else {
 			rc = tls_set_sw_offload(sk, ctx, 0);
 			if (rc)
 				goto err_crypto_info;
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
+			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
 			conf = TLS_SW;
 		}
 		tls_sw_strparser_arm(sk, ctx);
@@ -552,8 +580,6 @@
 	if (tx) {
 		ctx->sk_write_space = sk->sk_write_space;
 		sk->sk_write_space = tls_write_space;
-	} else {
-		sk->sk_socket->ops = &tls_sw_proto_ops;
 	}
 	goto out;
 
@@ -563,8 +589,8 @@
 	return rc;
 }
 
-static int do_tls_setsockopt(struct sock *sk, int optname,
-			     char __user *optval, unsigned int optlen)
+static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
+			     unsigned int optlen)
 {
 	int rc = 0;
 
@@ -584,7 +610,7 @@
 }
 
 static int tls_setsockopt(struct sock *sk, int level, int optname,
-			  char __user *optval, unsigned int optlen)
+			  sockptr_t optval, unsigned int optlen)
 {
 	struct tls_context *ctx = tls_get_ctx(sk);
 
@@ -595,7 +621,7 @@
 	return do_tls_setsockopt(sk, optname, optval, optlen);
 }
 
-static struct tls_context *create_ctx(struct sock *sk)
+struct tls_context *tls_ctx_create(struct sock *sk)
 {
 	struct inet_connection_sock *icsk = inet_csk(sk);
 	struct tls_context *ctx;
@@ -606,122 +632,77 @@
 
 	mutex_init(&ctx->tx_lock);
 	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
-	ctx->sk_proto = sk->sk_prot;
+	ctx->sk_proto = READ_ONCE(sk->sk_prot);
+	ctx->sk = sk;
 	return ctx;
 }
 
+static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
+			    const struct proto_ops *base)
+{
+	ops[TLS_BASE][TLS_BASE] = *base;
+
+	ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
+	ops[TLS_SW  ][TLS_BASE].sendpage_locked	= tls_sw_sendpage_locked;
+
+	ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
+	ops[TLS_BASE][TLS_SW  ].splice_read	= tls_sw_splice_read;
+
+	ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
+	ops[TLS_SW  ][TLS_SW  ].splice_read	= tls_sw_splice_read;
+
+#ifdef CONFIG_TLS_DEVICE
+	ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
+	ops[TLS_HW  ][TLS_BASE].sendpage_locked	= NULL;
+
+	ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];
+	ops[TLS_HW  ][TLS_SW  ].sendpage_locked	= NULL;
+
+	ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];
+
+	ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];
+
+	ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
+	ops[TLS_HW  ][TLS_HW  ].sendpage_locked	= NULL;
+#endif
+#ifdef CONFIG_TLS_TOE
+	ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
+#endif
+}
+
 static void tls_build_proto(struct sock *sk)
 {
 	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
+	struct proto *prot = READ_ONCE(sk->sk_prot);
 
 	/* Build IPv6 TLS whenever the address of tcpv6 _prot changes */
 	if (ip_ver == TLSV6 &&
-	    unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
+	    unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
 		mutex_lock(&tcpv6_prot_mutex);
-		if (likely(sk->sk_prot != saved_tcpv6_prot)) {
-			build_protos(tls_prots[TLSV6], sk->sk_prot);
-			smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
+		if (likely(prot != saved_tcpv6_prot)) {
+			build_protos(tls_prots[TLSV6], prot);
+			build_proto_ops(tls_proto_ops[TLSV6],
+					sk->sk_socket->ops);
+			smp_store_release(&saved_tcpv6_prot, prot);
 		}
 		mutex_unlock(&tcpv6_prot_mutex);
 	}
 
 	if (ip_ver == TLSV4 &&
-	    unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv4_prot))) {
+	    unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
 		mutex_lock(&tcpv4_prot_mutex);
-		if (likely(sk->sk_prot != saved_tcpv4_prot)) {
-			build_protos(tls_prots[TLSV4], sk->sk_prot);
-			smp_store_release(&saved_tcpv4_prot, sk->sk_prot);
+		if (likely(prot != saved_tcpv4_prot)) {
+			build_protos(tls_prots[TLSV4], prot);
+			build_proto_ops(tls_proto_ops[TLSV4],
+					sk->sk_socket->ops);
+			smp_store_release(&saved_tcpv4_prot, prot);
 		}
 		mutex_unlock(&tcpv4_prot_mutex);
 	}
 }
 
-static void tls_hw_sk_destruct(struct sock *sk)
-{
-	struct tls_context *ctx = tls_get_ctx(sk);
-	struct inet_connection_sock *icsk = inet_csk(sk);
-
-	ctx->sk_destruct(sk);
-	/* Free ctx */
-	rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
-	tls_ctx_free(sk, ctx);
-}
-
-static int tls_hw_prot(struct sock *sk)
-{
-	struct tls_context *ctx;
-	struct tls_device *dev;
-	int rc = 0;
-
-	spin_lock_bh(&device_spinlock);
-	list_for_each_entry(dev, &device_list, dev_list) {
-		if (dev->feature && dev->feature(dev)) {
-			ctx = create_ctx(sk);
-			if (!ctx)
-				goto out;
-
-			spin_unlock_bh(&device_spinlock);
-			tls_build_proto(sk);
-			ctx->sk_destruct = sk->sk_destruct;
-			sk->sk_destruct = tls_hw_sk_destruct;
-			ctx->rx_conf = TLS_HW_RECORD;
-			ctx->tx_conf = TLS_HW_RECORD;
-			update_sk_prot(sk, ctx);
-			spin_lock_bh(&device_spinlock);
-			rc = 1;
-			break;
-		}
-	}
-out:
-	spin_unlock_bh(&device_spinlock);
-	return rc;
-}
-
-static void tls_hw_unhash(struct sock *sk)
-{
-	struct tls_context *ctx = tls_get_ctx(sk);
-	struct tls_device *dev;
-
-	spin_lock_bh(&device_spinlock);
-	list_for_each_entry(dev, &device_list, dev_list) {
-		if (dev->unhash) {
-			kref_get(&dev->kref);
-			spin_unlock_bh(&device_spinlock);
-			dev->unhash(dev, sk);
-			kref_put(&dev->kref, dev->release);
-			spin_lock_bh(&device_spinlock);
-		}
-	}
-	spin_unlock_bh(&device_spinlock);
-	ctx->sk_proto->unhash(sk);
-}
-
-static int tls_hw_hash(struct sock *sk)
-{
-	struct tls_context *ctx = tls_get_ctx(sk);
-	struct tls_device *dev;
-	int err;
-
-	err = ctx->sk_proto->hash(sk);
-	spin_lock_bh(&device_spinlock);
-	list_for_each_entry(dev, &device_list, dev_list) {
-		if (dev->hash) {
-			kref_get(&dev->kref);
-			spin_unlock_bh(&device_spinlock);
-			err |= dev->hash(dev, sk);
-			kref_put(&dev->kref, dev->release);
-			spin_lock_bh(&device_spinlock);
-		}
-	}
-	spin_unlock_bh(&device_spinlock);
-
-	if (err)
-		tls_hw_unhash(sk);
-	return err;
-}
-
 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
-			 struct proto *base)
+			 const struct proto *base)
 {
 	prot[TLS_BASE][TLS_BASE] = *base;
 	prot[TLS_BASE][TLS_BASE].setsockopt	= tls_setsockopt;
@@ -757,10 +738,11 @@
 
 	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
 #endif
-
+#ifdef CONFIG_TLS_TOE
 	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
-	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_hw_hash;
-	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_hw_unhash;
+	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_toe_hash;
+	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_toe_unhash;
+#endif
 }
 
 static int tls_init(struct sock *sk)
@@ -768,8 +750,12 @@
 	struct tls_context *ctx;
 	int rc = 0;
 
-	if (tls_hw_prot(sk))
+	tls_build_proto(sk);
+
+#ifdef CONFIG_TLS_TOE
+	if (tls_toe_bypass(sk))
 		return 0;
+#endif
 
 	/* The TLS ulp is currently supported only for TCP sockets
 	 * in ESTABLISHED state.
@@ -780,11 +766,9 @@
 	if (sk->sk_state != TCP_ESTABLISHED)
 		return -ENOTCONN;
 
-	tls_build_proto(sk);
-
 	/* allocate tls context */
 	write_lock_bh(&sk->sk_callback_lock);
-	ctx = create_ctx(sk);
+	ctx = tls_ctx_create(sk);
 	if (!ctx) {
 		rc = -ENOMEM;
 		goto out;
@@ -808,7 +792,8 @@
 		ctx->sk_write_space = write_space;
 		ctx->sk_proto = p;
 	} else {
-		sk->sk_prot = p;
+		/* Pairs with lockless read in sk_clone_lock(). */
+		WRITE_ONCE(sk->sk_prot, p);
 		sk->sk_write_space = write_space;
 	}
 }
@@ -874,21 +859,34 @@
 	return size;
 }
 
-void tls_register_device(struct tls_device *device)
+static int __net_init tls_init_net(struct net *net)
 {
-	spin_lock_bh(&device_spinlock);
-	list_add_tail(&device->dev_list, &device_list);
-	spin_unlock_bh(&device_spinlock);
-}
-EXPORT_SYMBOL(tls_register_device);
+	int err;
 
-void tls_unregister_device(struct tls_device *device)
-{
-	spin_lock_bh(&device_spinlock);
-	list_del(&device->dev_list);
-	spin_unlock_bh(&device_spinlock);
+	net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
+	if (!net->mib.tls_statistics)
+		return -ENOMEM;
+
+	err = tls_proc_init(net);
+	if (err)
+		goto err_free_stats;
+
+	return 0;
+err_free_stats:
+	free_percpu(net->mib.tls_statistics);
+	return err;
 }
-EXPORT_SYMBOL(tls_unregister_device);
+
+static void __net_exit tls_exit_net(struct net *net)
+{
+	tls_proc_fini(net);
+	free_percpu(net->mib.tls_statistics);
+}
+
+static struct pernet_operations tls_proc_ops = {
+	.init = tls_init_net,
+	.exit = tls_exit_net,
+};
 
 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
 	.name			= "tls",
@@ -901,9 +899,11 @@
 
 static int __init tls_register(void)
 {
-	tls_sw_proto_ops = inet_stream_ops;
-	tls_sw_proto_ops.splice_read = tls_sw_splice_read;
-	tls_sw_proto_ops.sendpage_locked   = tls_sw_sendpage_locked,
+	int err;
+
+	err = register_pernet_subsys(&tls_proc_ops);
+	if (err)
+		return err;
 
 	tls_device_init();
 	tcp_register_ulp(&tcp_tls_ulp_ops);
@@ -915,6 +915,7 @@
 {
 	tcp_unregister_ulp(&tcp_tls_ulp_ops);
 	tls_device_cleanup();
+	unregister_pernet_subsys(&tls_proc_ops);
 }
 
 module_init(tls_register);