Update Linux to v5.4.2

Change-Id: Idf6911045d9d382da2cfe01b1edff026404ac8fd
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index f8183fd..447defb 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -1,3 +1,4 @@
+// SPDX-License-Identifier: GPL-2.0-or-later
 /*
  * INET		An implementation of the TCP/IP protocol suite for the LINUX
  *		operating system.  INET is implemented using the  BSD Socket
@@ -69,19 +70,13 @@
  *					a single port at the same time.
  *	Derek Atkins <derek@ihtfp.com>: Add Encapulation Support
  *	James Chapman		:	Add L2TP encapsulation type.
- *
- *
- *		This program is free software; you can redistribute it and/or
- *		modify it under the terms of the GNU General Public License
- *		as published by the Free Software Foundation; either version
- *		2 of the License, or (at your option) any later version.
  */
 
 #define pr_fmt(fmt) "UDP: " fmt
 
 #include <linux/uaccess.h>
 #include <asm/ioctls.h>
-#include <linux/bootmem.h>
+#include <linux/memblock.h>
 #include <linux/highmem.h>
 #include <linux/swap.h>
 #include <linux/types.h>
@@ -105,6 +100,7 @@
 #include <net/net_namespace.h>
 #include <net/icmp.h>
 #include <net/inet_hashtables.h>
+#include <net/ip_tunnels.h>
 #include <net/route.h>
 #include <net/checksum.h>
 #include <net/xfrm.h>
@@ -115,6 +111,7 @@
 #include "udp_impl.h"
 #include <net/sock_reuseport.h>
 #include <net/addrconf.h>
+#include <net/udp_tunnel.h>
 
 struct udp_table udp_table __read_mostly;
 EXPORT_SYMBOL(udp_table);
@@ -128,17 +125,6 @@
 #define MAX_UDP_PORTS 65536
 #define PORTS_PER_CHAIN (MAX_UDP_PORTS / UDP_HTABLE_SIZE_MIN)
 
-/* IPCB reference means this can not be used from early demux */
-static bool udp_lib_exact_dif_match(struct net *net, struct sk_buff *skb)
-{
-#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
-	if (!net->ipv4.sysctl_udp_l3mdev_accept &&
-	    skb && ipv4_l3mdev_skb(IPCB(skb)->flags))
-		return true;
-#endif
-	return false;
-}
-
 static int udp_lib_lport_inuse(struct net *net, __u16 num,
 			       const struct udp_hslot *hslot,
 			       unsigned long *bitmap,
@@ -367,25 +353,23 @@
 static int compute_score(struct sock *sk, struct net *net,
 			 __be32 saddr, __be16 sport,
 			 __be32 daddr, unsigned short hnum,
-			 int dif, int sdif, bool exact_dif)
+			 int dif, int sdif)
 {
 	int score;
 	struct inet_sock *inet;
+	bool dev_match;
 
 	if (!net_eq(sock_net(sk), net) ||
 	    udp_sk(sk)->udp_port_hash != hnum ||
 	    ipv6_only_sock(sk))
 		return -1;
 
+	if (sk->sk_rcv_saddr != daddr)
+		return -1;
+
 	score = (sk->sk_family == PF_INET) ? 2 : 1;
+
 	inet = inet_sk(sk);
-
-	if (inet->inet_rcv_saddr) {
-		if (inet->inet_rcv_saddr != daddr)
-			return -1;
-		score += 4;
-	}
-
 	if (inet->inet_daddr) {
 		if (inet->inet_daddr != saddr)
 			return -1;
@@ -398,17 +382,13 @@
 		score += 4;
 	}
 
-	if (sk->sk_bound_dev_if || exact_dif) {
-		bool dev_match = (sk->sk_bound_dev_if == dif ||
-				  sk->sk_bound_dev_if == sdif);
+	dev_match = udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if,
+					dif, sdif);
+	if (!dev_match)
+		return -1;
+	score += 4;
 
-		if (!dev_match)
-			return -1;
-		if (sk->sk_bound_dev_if)
-			score += 4;
-	}
-
-	if (sk->sk_incoming_cpu == raw_smp_processor_id())
+	if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id())
 		score++;
 	return score;
 }
@@ -429,7 +409,7 @@
 static struct sock *udp4_lib_lookup2(struct net *net,
 				     __be32 saddr, __be16 sport,
 				     __be32 daddr, unsigned int hnum,
-				     int dif, int sdif, bool exact_dif,
+				     int dif, int sdif,
 				     struct udp_hslot *hslot2,
 				     struct sk_buff *skb)
 {
@@ -441,14 +421,15 @@
 	badness = 0;
 	udp_portaddr_for_each_entry_rcu(sk, &hslot2->head) {
 		score = compute_score(sk, net, saddr, sport,
-				      daddr, hnum, dif, sdif, exact_dif);
+				      daddr, hnum, dif, sdif);
 		if (score > badness) {
-			if (sk->sk_reuseport) {
+			if (sk->sk_reuseport &&
+			    sk->sk_state != TCP_ESTABLISHED) {
 				hash = udp_ehashfn(net, daddr, hnum,
 						   saddr, sport);
 				result = reuseport_select_sock(sk, hash, skb,
 							sizeof(struct udphdr));
-				if (result)
+				if (result && !reuseport_has_conns(sk, false))
 					return result;
 			}
 			badness = score;
@@ -465,65 +446,29 @@
 		__be16 sport, __be32 daddr, __be16 dport, int dif,
 		int sdif, struct udp_table *udptable, struct sk_buff *skb)
 {
-	struct sock *sk, *result;
+	struct sock *result;
 	unsigned short hnum = ntohs(dport);
-	unsigned int hash2, slot2, slot = udp_hashfn(net, hnum, udptable->mask);
-	struct udp_hslot *hslot2, *hslot = &udptable->hash[slot];
-	bool exact_dif = udp_lib_exact_dif_match(net, skb);
-	int score, badness;
-	u32 hash = 0;
+	unsigned int hash2, slot2;
+	struct udp_hslot *hslot2;
 
-	if (hslot->count > 10) {
-		hash2 = ipv4_portaddr_hash(net, daddr, hnum);
+	hash2 = ipv4_portaddr_hash(net, daddr, hnum);
+	slot2 = hash2 & udptable->mask;
+	hslot2 = &udptable->hash2[slot2];
+
+	result = udp4_lib_lookup2(net, saddr, sport,
+				  daddr, hnum, dif, sdif,
+				  hslot2, skb);
+	if (!result) {
+		hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum);
 		slot2 = hash2 & udptable->mask;
 		hslot2 = &udptable->hash2[slot2];
-		if (hslot->count < hslot2->count)
-			goto begin;
 
 		result = udp4_lib_lookup2(net, saddr, sport,
-					  daddr, hnum, dif, sdif,
-					  exact_dif, hslot2, skb);
-		if (!result) {
-			unsigned int old_slot2 = slot2;
-			hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum);
-			slot2 = hash2 & udptable->mask;
-			/* avoid searching the same slot again. */
-			if (unlikely(slot2 == old_slot2))
-				return result;
-
-			hslot2 = &udptable->hash2[slot2];
-			if (hslot->count < hslot2->count)
-				goto begin;
-
-			result = udp4_lib_lookup2(net, saddr, sport,
-						  daddr, hnum, dif, sdif,
-						  exact_dif, hslot2, skb);
-		}
-		if (unlikely(IS_ERR(result)))
-			return NULL;
-		return result;
+					  htonl(INADDR_ANY), hnum, dif, sdif,
+					  hslot2, skb);
 	}
-begin:
-	result = NULL;
-	badness = 0;
-	sk_for_each_rcu(sk, &hslot->head) {
-		score = compute_score(sk, net, saddr, sport,
-				      daddr, hnum, dif, sdif, exact_dif);
-		if (score > badness) {
-			if (sk->sk_reuseport) {
-				hash = udp_ehashfn(net, daddr, hnum,
-						   saddr, sport);
-				result = reuseport_select_sock(sk, hash, skb,
-							sizeof(struct udphdr));
-				if (unlikely(IS_ERR(result)))
-					return NULL;
-				if (result)
-					return result;
-			}
-			result = sk;
-			badness = score;
-		}
-	}
+	if (IS_ERR(result))
+		return NULL;
 	return result;
 }
 EXPORT_SYMBOL_GPL(__udp4_lib_lookup);
@@ -542,7 +487,11 @@
 struct sock *udp4_lib_lookup_skb(struct sk_buff *skb,
 				 __be16 sport, __be16 dport)
 {
-	return __udp4_lib_lookup_skb(skb, sport, dport, &udp_table);
+	const struct iphdr *iph = ip_hdr(skb);
+
+	return __udp4_lib_lookup(dev_net(skb->dev), iph->saddr, sport,
+				 iph->daddr, dport, inet_iif(skb),
+				 inet_sdif(skb), &udp_table, NULL);
 }
 EXPORT_SYMBOL_GPL(udp4_lib_lookup_skb);
 
@@ -577,14 +526,98 @@
 	    (inet->inet_dport != rmt_port && inet->inet_dport) ||
 	    (inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) ||
 	    ipv6_only_sock(sk) ||
-	    (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif &&
-	     sk->sk_bound_dev_if != sdif))
+	    !udp_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif))
 		return false;
 	if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif, sdif))
 		return false;
 	return true;
 }
 
+DEFINE_STATIC_KEY_FALSE(udp_encap_needed_key);
+void udp_encap_enable(void)
+{
+	static_branch_inc(&udp_encap_needed_key);
+}
+EXPORT_SYMBOL(udp_encap_enable);
+
+/* Handler for tunnels with arbitrary destination ports: no socket lookup, go
+ * through error handlers in encapsulations looking for a match.
+ */
+static int __udp4_lib_err_encap_no_sk(struct sk_buff *skb, u32 info)
+{
+	int i;
+
+	for (i = 0; i < MAX_IPTUN_ENCAP_OPS; i++) {
+		int (*handler)(struct sk_buff *skb, u32 info);
+		const struct ip_tunnel_encap_ops *encap;
+
+		encap = rcu_dereference(iptun_encaps[i]);
+		if (!encap)
+			continue;
+		handler = encap->err_handler;
+		if (handler && !handler(skb, info))
+			return 0;
+	}
+
+	return -ENOENT;
+}
+
+/* Try to match ICMP errors to UDP tunnels by looking up a socket without
+ * reversing source and destination port: this will match tunnels that force the
+ * same destination port on both endpoints (e.g. VXLAN, GENEVE). Note that
+ * lwtunnels might actually break this assumption by being configured with
+ * different destination ports on endpoints, in this case we won't be able to
+ * trace ICMP messages back to them.
+ *
+ * If this doesn't match any socket, probe tunnels with arbitrary destination
+ * ports (e.g. FoU, GUE): there, the receiving socket is useless, as the port
+ * we've sent packets to won't necessarily match the local destination port.
+ *
+ * Then ask the tunnel implementation to match the error against a valid
+ * association.
+ *
+ * Return an ERR_PTR() error if we can't find a match, the socket if we need
+ * further processing, NULL otherwise.
+ */
+static struct sock *__udp4_lib_err_encap(struct net *net,
+					 const struct iphdr *iph,
+					 struct udphdr *uh,
+					 struct udp_table *udptable,
+					 struct sk_buff *skb, u32 info)
+{
+	int network_offset, transport_offset;
+	struct sock *sk;
+
+	network_offset = skb_network_offset(skb);
+	transport_offset = skb_transport_offset(skb);
+
+	/* Network header needs to point to the outer IPv4 header inside ICMP */
+	skb_reset_network_header(skb);
+
+	/* Transport header needs to point to the UDP header */
+	skb_set_transport_header(skb, iph->ihl << 2);
+
+	sk = __udp4_lib_lookup(net, iph->daddr, uh->source,
+			       iph->saddr, uh->dest, skb->dev->ifindex, 0,
+			       udptable, NULL);
+	if (sk) {
+		int (*lookup)(struct sock *sk, struct sk_buff *skb);
+		struct udp_sock *up = udp_sk(sk);
+
+		lookup = READ_ONCE(up->encap_err_lookup);
+		if (!lookup || lookup(sk, skb))
+			sk = NULL;
+	}
+
+	if (!sk)
+		sk = ERR_PTR(__udp4_lib_err_encap_no_sk(skb, info));
+
+	skb_set_transport_header(skb, transport_offset);
+	skb_set_network_header(skb, network_offset);
+
+	return sk;
+}
+
 /*
  * This routine is called by the ICMP module when it gets some
  * sort of error condition.  If err < 0 then the socket should
@@ -596,24 +629,38 @@
  * to find the appropriate port.
  */
 
-void __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)
+int __udp4_lib_err(struct sk_buff *skb, u32 info, struct udp_table *udptable)
 {
 	struct inet_sock *inet;
 	const struct iphdr *iph = (const struct iphdr *)skb->data;
 	struct udphdr *uh = (struct udphdr *)(skb->data+(iph->ihl<<2));
 	const int type = icmp_hdr(skb)->type;
 	const int code = icmp_hdr(skb)->code;
+	bool tunnel = false;
 	struct sock *sk;
 	int harderr;
 	int err;
 	struct net *net = dev_net(skb->dev);
 
 	sk = __udp4_lib_lookup(net, iph->daddr, uh->dest,
-			       iph->saddr, uh->source, skb->dev->ifindex, 0,
-			       udptable, NULL);
+			       iph->saddr, uh->source, skb->dev->ifindex,
+			       inet_sdif(skb), udptable, NULL);
 	if (!sk) {
-		__ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
-		return;	/* No socket for error */
+		/* No socket for error: try tunnels before discarding */
+		sk = ERR_PTR(-ENOENT);
+		if (static_branch_unlikely(&udp_encap_needed_key)) {
+			sk = __udp4_lib_err_encap(net, iph, uh, udptable, skb,
+						  info);
+			if (!sk)
+				return 0;
+		}
+
+		if (IS_ERR(sk)) {
+			__ICMP_INC_STATS(net, ICMP_MIB_INERRORS);
+			return PTR_ERR(sk);
+		}
+
+		tunnel = true;
 	}
 
 	err = 0;
@@ -656,6 +703,10 @@
 	 *      RFC1122: OK.  Passes ICMP errors back to application, as per
 	 *	4.1.3.3.
 	 */
+	if (tunnel) {
+		/* ...not for tunnels though: we don't have a sending socket */
+		goto out;
+	}
 	if (!inet->recverr) {
 		if (!harderr || sk->sk_state != TCP_ESTABLISHED)
 			goto out;
@@ -665,12 +716,12 @@
 	sk->sk_err = err;
 	sk->sk_error_report(sk);
 out:
-	return;
+	return 0;
 }
 
-void udp_err(struct sk_buff *skb, u32 info)
+int udp_err(struct sk_buff *skb, u32 info)
 {
-	__udp4_lib_err(skb, info, &udp_table);
+	return __udp4_lib_err(skb, info, &udp_table);
 }
 
 /*
@@ -770,6 +821,7 @@
 	int is_udplite = IS_UDPLITE(sk);
 	int offset = skb_transport_offset(skb);
 	int len = skb->len - offset;
+	int datalen = len - sizeof(*uh);
 	__wsum csum = 0;
 
 	/*
@@ -785,20 +837,30 @@
 		const int hlen = skb_network_header_len(skb) +
 				 sizeof(struct udphdr);
 
-		if (hlen + cork->gso_size > cork->fragsize)
+		if (hlen + cork->gso_size > cork->fragsize) {
+			kfree_skb(skb);
 			return -EINVAL;
-		if (skb->len > cork->gso_size * UDP_MAX_SEGMENTS)
+		}
+		if (skb->len > cork->gso_size * UDP_MAX_SEGMENTS) {
+			kfree_skb(skb);
 			return -EINVAL;
-		if (sk->sk_no_check_tx)
+		}
+		if (sk->sk_no_check_tx) {
+			kfree_skb(skb);
 			return -EINVAL;
+		}
 		if (skb->ip_summed != CHECKSUM_PARTIAL || is_udplite ||
-		    dst_xfrm(skb_dst(skb)))
+		    dst_xfrm(skb_dst(skb))) {
+			kfree_skb(skb);
 			return -EIO;
+		}
 
-		skb_shinfo(skb)->gso_size = cork->gso_size;
-		skb_shinfo(skb)->gso_type = SKB_GSO_UDP_L4;
-		skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(len - sizeof(uh),
-							 cork->gso_size);
+		if (datalen > cork->gso_size) {
+			skb_shinfo(skb)->gso_size = cork->gso_size;
+			skb_shinfo(skb)->gso_type = SKB_GSO_UDP_L4;
+			skb_shinfo(skb)->gso_segs = DIV_ROUND_UP(datalen,
+								 cork->gso_size);
+		}
 		goto csum_partial;
 	}
 
@@ -1042,7 +1104,7 @@
 	}
 
 	if (ipv4_is_multicast(daddr)) {
-		if (!ipc.oif)
+		if (!ipc.oif || netif_index_is_l3_master(sock_net(sk), ipc.oif))
 			ipc.oif = inet->mc_index;
 		if (!saddr)
 			saddr = inet->mc_addr;
@@ -1072,7 +1134,7 @@
 
 		fl4 = &fl4_stack;
 
-		flowi4_init_output(fl4, ipc.oif, sk->sk_mark, tos,
+		flowi4_init_output(fl4, ipc.oif, ipc.sockc.mark, tos,
 				   RT_SCOPE_UNIVERSE, sk->sk_protocol,
 				   flow_flags,
 				   faddr, saddr, dport, inet->inet_sport,
@@ -1235,6 +1297,27 @@
 
 #define UDP_SKB_IS_STATELESS 0x80000000
 
+/* all head states (dst, sk, nf conntrack) except skb extensions are
+ * cleared by udp_rcv().
+ *
+ * We need to preserve secpath, if present, to eventually process
+ * IP_CMSG_PASSSEC at recvmsg() time.
+ *
+ * Other extensions can be cleared.
+ */
+static bool udp_try_make_stateless(struct sk_buff *skb)
+{
+	if (!skb_has_extensions(skb))
+		return true;
+
+	if (!secpath_exists(skb)) {
+		skb_ext_reset(skb);
+		return true;
+	}
+
+	return false;
+}
+
 static void udp_set_dev_scratch(struct sk_buff *skb)
 {
 	struct udp_dev_scratch *scratch = udp_skb_scratch(skb);
@@ -1246,14 +1329,24 @@
 	scratch->csum_unnecessary = !!skb_csum_unnecessary(skb);
 	scratch->is_linear = !skb_is_nonlinear(skb);
 #endif
-	/* all head states execept sp (dst, sk, nf) are always cleared by
-	 * udp_rcv() and we need to preserve secpath, if present, to eventually
-	 * process IP_CMSG_PASSSEC at recvmsg() time
-	 */
-	if (likely(!skb_sec_path(skb)))
+	if (udp_try_make_stateless(skb))
 		scratch->_tsize_state |= UDP_SKB_IS_STATELESS;
 }
 
+static void udp_skb_csum_unnecessary_set(struct sk_buff *skb)
+{
+	/* We come here after udp_lib_checksum_complete() returned 0.
+	 * This means that __skb_checksum_complete() might have
+	 * set skb->csum_valid to 1.
+	 * On 64bit platforms, we can set csum_unnecessary
+	 * to true, but only if the skb is not shared.
+	 */
+#if BITS_PER_LONG == 64
+	if (!skb_shared(skb))
+		udp_skb_scratch(skb)->csum_unnecessary = true;
+#endif
+}
+
 static int udp_skb_truesize(struct sk_buff *skb)
 {
 	return udp_skb_scratch(skb)->_tsize_state & ~UDP_SKB_IS_STATELESS;
@@ -1488,10 +1581,7 @@
 			*total += skb->truesize;
 			kfree_skb(skb);
 		} else {
-			/* the csum related bits could be changed, refresh
-			 * the scratch area
-			 */
-			udp_set_dev_scratch(skb);
+			udp_skb_csum_unnecessary_set(skb);
 			break;
 		}
 	}
@@ -1515,7 +1605,7 @@
 
 	spin_lock_bh(&rcvq->lock);
 	skb = __first_packet_length(sk, rcvq, &total);
-	if (!skb && !skb_queue_empty(sk_queue)) {
+	if (!skb && !skb_queue_empty_lockless(sk_queue)) {
 		spin_lock(&sk_queue->lock);
 		skb_queue_splice_tail_init(sk_queue, rcvq);
 		spin_unlock(&sk_queue->lock);
@@ -1559,7 +1649,7 @@
 EXPORT_SYMBOL(udp_ioctl);
 
 struct sk_buff *__skb_recv_udp(struct sock *sk, unsigned int flags,
-			       int noblock, int *peeked, int *off, int *err)
+			       int noblock, int *off, int *err)
 {
 	struct sk_buff_head *sk_queue = &sk->sk_receive_queue;
 	struct sk_buff_head *queue;
@@ -1578,19 +1668,17 @@
 			break;
 
 		error = -EAGAIN;
-		*peeked = 0;
 		do {
 			spin_lock_bh(&queue->lock);
 			skb = __skb_try_recv_from_queue(sk, queue, flags,
 							udp_skb_destructor,
-							peeked, off, err,
-							&last);
+							off, err, &last);
 			if (skb) {
 				spin_unlock_bh(&queue->lock);
 				return skb;
 			}
 
-			if (skb_queue_empty(sk_queue)) {
+			if (skb_queue_empty_lockless(sk_queue)) {
 				spin_unlock_bh(&queue->lock);
 				goto busy_check;
 			}
@@ -1605,8 +1693,7 @@
 
 			skb = __skb_try_recv_from_queue(sk, queue, flags,
 							udp_skb_dtor_locked,
-							peeked, off, err,
-							&last);
+							off, err, &last);
 			spin_unlock(&sk_queue->lock);
 			spin_unlock_bh(&queue->lock);
 			if (skb)
@@ -1617,7 +1704,7 @@
 				break;
 
 			sk_busy_loop(sk, flags & MSG_DONTWAIT);
-		} while (!skb_queue_empty(sk_queue));
+		} while (!skb_queue_empty_lockless(sk_queue));
 
 		/* sk_queue is empty, reader_queue may contain peeked packets */
 	} while (timeo &&
@@ -1641,8 +1728,7 @@
 	DECLARE_SOCKADDR(struct sockaddr_in *, sin, msg->msg_name);
 	struct sk_buff *skb;
 	unsigned int ulen, copied;
-	int peeked, peeking, off;
-	int err;
+	int off, err, peeking = flags & MSG_PEEK;
 	int is_udplite = IS_UDPLITE(sk);
 	bool checksum_valid = false;
 
@@ -1650,9 +1736,8 @@
 		return ip_recv_error(sk, msg, len, addr_len);
 
 try_again:
-	peeking = flags & MSG_PEEK;
 	off = sk_peek_offset(sk, flags);
-	skb = __skb_recv_udp(sk, flags, noblock, &peeked, &off, &err);
+	skb = __skb_recv_udp(sk, flags, noblock, &off, &err);
 	if (!skb)
 		return err;
 
@@ -1690,7 +1775,7 @@
 	}
 
 	if (unlikely(err)) {
-		if (!peeked) {
+		if (!peeking) {
 			atomic_inc(&sk->sk_drops);
 			UDP_INC_STATS(sock_net(sk),
 				      UDP_MIB_INERRORS, is_udplite);
@@ -1699,7 +1784,7 @@
 		return err;
 	}
 
-	if (!peeked)
+	if (!peeking)
 		UDP_INC_STATS(sock_net(sk),
 			      UDP_MIB_INDATAGRAMS, is_udplite);
 
@@ -1712,7 +1797,15 @@
 		sin->sin_addr.s_addr = ip_hdr(skb)->saddr;
 		memset(sin->sin_zero, 0, sizeof(sin->sin_zero));
 		*addr_len = sizeof(*sin);
+
+		if (cgroup_bpf_enabled)
+			BPF_CGROUP_RUN_PROG_UDP4_RECVMSG_LOCK(sk,
+							(struct sockaddr *)sin);
 	}
+
+	if (udp_sk(sk)->gro_enabled)
+		udp_cmsg_recv(msg, sk, skb);
+
 	if (inet->cmsg_flags)
 		ip_cmsg_recv_offset(msg, sk, skb, sizeof(struct udphdr), off);
 
@@ -1852,7 +1945,7 @@
 }
 EXPORT_SYMBOL(udp_lib_rehash);
 
-static void udp_v4_rehash(struct sock *sk)
+void udp_v4_rehash(struct sock *sk)
 {
 	u16 new_hash = ipv4_portaddr_hash(sock_net(sk),
 					  inet_sk(sk)->inet_rcv_saddr,
@@ -1889,13 +1982,6 @@
 	return 0;
 }
 
-static DEFINE_STATIC_KEY_FALSE(udp_encap_needed_key);
-void udp_encap_enable(void)
-{
-	static_branch_enable(&udp_encap_needed_key);
-}
-EXPORT_SYMBOL(udp_encap_enable);
-
 /* returns:
  *  -1: error
  *   0: success
@@ -1904,7 +1990,7 @@
  * Note that in the success and error cases, the skb is assumed to
  * have either been requeued or freed.
  */
-static int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
+static int udp_queue_rcv_one_skb(struct sock *sk, struct sk_buff *skb)
 {
 	struct udp_sock *up = udp_sk(sk);
 	int is_udplite = IS_UDPLITE(sk);
@@ -1914,7 +2000,7 @@
 	 */
 	if (!xfrm4_policy_check(sk, XFRM_POLICY_IN, skb))
 		goto drop;
-	nf_reset(skb);
+	nf_reset_ct(skb);
 
 	if (static_branch_unlikely(&udp_encap_needed_key) && up->encap_type) {
 		int (*encap_rcv)(struct sock *sk, struct sk_buff *skb);
@@ -2007,6 +2093,27 @@
 	return -1;
 }
 
+static int udp_queue_rcv_skb(struct sock *sk, struct sk_buff *skb)
+{
+	struct sk_buff *next, *segs;
+	int ret;
+
+	if (likely(!udp_unexpected_gso(sk, skb)))
+		return udp_queue_rcv_one_skb(sk, skb);
+
+	BUILD_BUG_ON(sizeof(struct udp_skb_cb) > SKB_SGO_CB_OFFSET);
+	__skb_push(skb, -skb_mac_offset(skb));
+	segs = udp_rcv_segment(sk, skb, true);
+	for (skb = segs; skb; skb = next) {
+		next = skb->next;
+		__skb_pull(skb, skb_transport_offset(skb));
+		ret = udp_queue_rcv_one_skb(sk, skb);
+		if (ret > 0)
+			ip_protocol_deliver_rcu(dev_net(skb->dev), skb, -ret);
+	}
+	return 0;
+}
+
 /* For TCP sockets, sk_rx_dst is protected by socket lock
  * For UDP, we use xchg() to guard against concurrent changes.
  */
@@ -2095,7 +2202,7 @@
 
 /* Initialize UDP checksum. If exited with zero value (success),
  * CHECKSUM_UNNECESSARY means, that no more checks are required.
- * Otherwise, csum completion requires chacksumming packet body,
+ * Otherwise, csum completion requires checksumming packet body,
  * including udp header and folding it to skb->csum.
  */
 static inline int udp4_csum_init(struct sk_buff *skb, struct udphdr *uh,
@@ -2149,8 +2256,7 @@
 	int ret;
 
 	if (inet_get_convert_csum(sk) && uh->check && !IS_UDPLITE(sk))
-		skb_checksum_try_convert(skb, IPPROTO_UDP, uh->check,
-					 inet_compute_pseudo);
+		skb_checksum_try_convert(skb, IPPROTO_UDP, inet_compute_pseudo);
 
 	ret = udp_queue_rcv_skb(sk, skb);
 
@@ -2223,7 +2329,7 @@
 
 	if (!xfrm4_policy_check(NULL, XFRM_POLICY_IN, skb))
 		goto drop;
-	nf_reset(skb);
+	nf_reset_ct(skb);
 
 	/* No socket. Drop packet silently, if checksum is wrong */
 	if (udp_lib_checksum_complete(skb))
@@ -2398,11 +2504,15 @@
 	bool slow = lock_sock_fast(sk);
 	udp_flush_pending_frames(sk);
 	unlock_sock_fast(sk, slow);
-	if (static_branch_unlikely(&udp_encap_needed_key) && up->encap_type) {
-		void (*encap_destroy)(struct sock *sk);
-		encap_destroy = READ_ONCE(up->encap_destroy);
-		if (encap_destroy)
-			encap_destroy(sk);
+	if (static_branch_unlikely(&udp_encap_needed_key)) {
+		if (up->encap_type) {
+			void (*encap_destroy)(struct sock *sk);
+			encap_destroy = READ_ONCE(up->encap_destroy);
+			if (encap_destroy)
+				encap_destroy(sk);
+		}
+		if (up->encap_enabled)
+			static_branch_dec(&udp_encap_needed_key);
 	}
 }
 
@@ -2447,7 +2557,9 @@
 			/* FALLTHROUGH */
 		case UDP_ENCAP_L2TPINUDP:
 			up->encap_type = val;
-			udp_encap_enable();
+			lock_sock(sk);
+			udp_tunnel_encap_enable(sk->sk_socket);
+			release_sock(sk);
 			break;
 		default:
 			err = -ENOPROTOOPT;
@@ -2469,6 +2581,14 @@
 		up->gso_size = val;
 		break;
 
+	case UDP_GRO:
+		lock_sock(sk);
+		if (valbool)
+			udp_tunnel_encap_enable(sk->sk_socket);
+		up->gro_enabled = valbool;
+		release_sock(sk);
+		break;
+
 	/*
 	 * 	UDP-Lite's partial checksum coverage (RFC 3828).
 	 */
@@ -2620,7 +2740,7 @@
 	__poll_t mask = datagram_poll(file, sock, wait);
 	struct sock *sk = sock->sk;
 
-	if (!skb_queue_empty(&udp_sk(sk)->reader_queue))
+	if (!skb_queue_empty_lockless(&udp_sk(sk)->reader_queue))
 		mask |= EPOLLIN | EPOLLRDNORM;
 
 	/* Check for false positives due to checksum errors */
@@ -2784,7 +2904,7 @@
 	__u16 srcp	  = ntohs(inet->inet_sport);
 
 	seq_printf(f, "%5d: %08X:%04X %08X:%04X"
-		" %02X %08X:%08X %02X:%08lX %08X %5u %8d %lu %d %pK %d",
+		" %02X %08X:%08X %02X:%08lX %08X %5u %8d %lu %d %pK %u",
 		bucket, src, srcp, dest, destp, sp->sk_state,
 		sk_wmem_alloc_get(sp),
 		udp_rqueue_get(sp),