diff --git a/include/net/sock.h b/include/net/sock.h
index e3bf213be6259ab00a1057cb58b44ad5e80515cb..79532540201b16cccf2f4dd5f74dd4085d4278d8 100644
--- a/include/net/sock.h
+++ b/include/net/sock.h
@@ -218,7 +218,7 @@ struct cg_proto;
   *	@sk_lock:	synchronizer
   *	@sk_rcvbuf: size of receive buffer in bytes
   *	@sk_wq: sock wait queue and async head
-  *	@sk_rx_dst: receive input route used by early tcp demux
+  *	@sk_rx_dst: receive input route used by early demux
   *	@sk_dst_cache: destination cache
   *	@sk_dst_lock: destination cache lock
   *	@sk_policy: flow policy
diff --git a/include/net/udp.h b/include/net/udp.h
index 510b8cb4d1a77429b4e41f341b2ef8caaa83d962..fe4ba9f32429ab7c2fc327cafd29bafadeb3561b 100644
--- a/include/net/udp.h
+++ b/include/net/udp.h
@@ -175,6 +175,7 @@ int udp_lib_get_port(struct sock *sk, unsigned short snum,
 		     unsigned int hash2_nulladdr);
 
 /* net/ipv4/udp.c */
+void udp_v4_early_demux(struct sk_buff *skb);
 int udp_get_port(struct sock *sk, unsigned short snum,
 		 int (*saddr_cmp)(const struct sock *,
 				  const struct sock *));
diff --git a/net/ipv4/af_inet.c b/net/ipv4/af_inet.c
index cfeb85cff4f02abc28570b267ba6e64784595fab..35913fb77dc8672ac30518078eaa08378d32617d 100644
--- a/net/ipv4/af_inet.c
+++ b/net/ipv4/af_inet.c
@@ -1546,6 +1546,7 @@ static const struct net_protocol tcp_protocol = {
 };
 
 static const struct net_protocol udp_protocol = {
+	.early_demux =	udp_v4_early_demux,
 	.handler =	udp_rcv,
 	.err_handler =	udp_err,
 	.no_policy =	1,
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index 5950e12bd3abb33f6fefd06f3c196bdb1e02810d..262ea3929da65a7963c2b3bbbb53ef6dc5eb4e6d 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -103,6 +103,7 @@
 #include <linux/seq_file.h>
 #include <net/net_namespace.h>
 #include <net/icmp.h>
+#include <net/inet_hashtables.h>
 #include <net/route.h>
 #include <net/checksum.h>
 #include <net/xfrm.h>
@@ -565,6 +566,26 @@ struct sock *udp4_lib_lookup(struct net *net, __be32 saddr, __be16 sport,
 }
 EXPORT_SYMBOL_GPL(udp4_lib_lookup);
 
+static inline bool __udp_is_mcast_sock(struct net *net, struct sock *sk,
+				       __be16 loc_port, __be32 loc_addr,
+				       __be16 rmt_port, __be32 rmt_addr,
+				       int dif, unsigned short hnum)
+{
+	struct inet_sock *inet = inet_sk(sk);
+
+	if (!net_eq(sock_net(sk), net) ||
+	    udp_sk(sk)->udp_port_hash != hnum ||
+	    (inet->inet_daddr && inet->inet_daddr != rmt_addr) ||
+	    (inet->inet_dport != rmt_port && inet->inet_dport) ||
+	    (inet->inet_rcv_saddr && inet->inet_rcv_saddr != loc_addr) ||
+	    ipv6_only_sock(sk) ||
+	    (sk->sk_bound_dev_if && sk->sk_bound_dev_if != dif))
+		return false;
+	if (!ip_mc_sf_allow(sk, loc_addr, rmt_addr, dif))
+		return false;
+	return true;
+}
+
 static inline struct sock *udp_v4_mcast_next(struct net *net, struct sock *sk,
 					     __be16 loc_port, __be32 loc_addr,
 					     __be16 rmt_port, __be32 rmt_addr,
@@ -575,20 +596,11 @@ static inline struct sock *udp_v4_mcast_next(struct net *net, struct sock *sk,
 	unsigned short hnum = ntohs(loc_port);
 
 	sk_nulls_for_each_from(s, node) {
-		struct inet_sock *inet = inet_sk(s);
-
-		if (!net_eq(sock_net(s), net) ||
-		    udp_sk(s)->udp_port_hash != hnum ||
-		    (inet->inet_daddr && inet->inet_daddr != rmt_addr) ||
-		    (inet->inet_dport != rmt_port && inet->inet_dport) ||
-		    (inet->inet_rcv_saddr &&
-		     inet->inet_rcv_saddr != loc_addr) ||
-		    ipv6_only_sock(s) ||
-		    (s->sk_bound_dev_if && s->sk_bound_dev_if != dif))
-			continue;
-		if (!ip_mc_sf_allow(s, loc_addr, rmt_addr, dif))
-			continue;
-		goto found;
+		if (__udp_is_mcast_sock(net, s,
+					loc_port, loc_addr,
+					rmt_port, rmt_addr,
+					dif, hnum))
+			goto found;
 	}
 	s = NULL;
 found:
@@ -1581,6 +1593,14 @@ static void flush_stack(struct sock **stack, unsigned int count,
 		kfree_skb(skb1);
 }
 
+static void udp_sk_rx_dst_set(struct sock *sk, const struct sk_buff *skb)
+{
+	struct dst_entry *dst = skb_dst(skb);
+
+	dst_hold(dst);
+	sk->sk_rx_dst = dst;
+}
+
 /*
  *	Multicasts and broadcasts go to each listener.
  *
@@ -1709,11 +1729,28 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
 	if (udp4_csum_init(skb, uh, proto))
 		goto csum_error;
 
-	if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST))
-		return __udp4_lib_mcast_deliver(net, skb, uh,
-				saddr, daddr, udptable);
+	if (skb->sk) {
+		int ret;
+		sk = skb->sk;
 
-	sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
+		if (unlikely(sk->sk_rx_dst == NULL))
+			udp_sk_rx_dst_set(sk, skb);
+
+		ret = udp_queue_rcv_skb(sk, skb);
+
+		/* a return value > 0 means to resubmit the input, but
+		 * it wants the return to be -protocol, or 0
+		 */
+		if (ret > 0)
+			return -ret;
+		return 0;
+	} else {
+		if (rt->rt_flags & (RTCF_BROADCAST|RTCF_MULTICAST))
+			return __udp4_lib_mcast_deliver(net, skb, uh,
+					saddr, daddr, udptable);
+
+		sk = __udp4_lib_lookup_skb(skb, uh->source, uh->dest, udptable);
+	}
 
 	if (sk != NULL) {
 		int ret;
@@ -1771,6 +1808,135 @@ int __udp4_lib_rcv(struct sk_buff *skb, struct udp_table *udptable,
 	return 0;
 }
 
+/* We can only early demux multicast if there is a single matching socket.
+ * If more than one socket found returns NULL
+ */
+static struct sock *__udp4_lib_mcast_demux_lookup(struct net *net,
+						  __be16 loc_port, __be32 loc_addr,
+						  __be16 rmt_port, __be32 rmt_addr,
+						  int dif)
+{
+	struct sock *sk, *result;
+	struct hlist_nulls_node *node;
+	unsigned short hnum = ntohs(loc_port);
+	unsigned int count, slot = udp_hashfn(net, hnum, udp_table.mask);
+	struct udp_hslot *hslot = &udp_table.hash[slot];
+
+	rcu_read_lock();
+begin:
+	count = 0;
+	result = NULL;
+	sk_nulls_for_each_rcu(sk, node, &hslot->head) {
+		if (__udp_is_mcast_sock(net, sk,
+					loc_port, loc_addr,
+					rmt_port, rmt_addr,
+					dif, hnum)) {
+			result = sk;
+			++count;
+		}
+	}
+	/*
+	 * if the nulls value we got at the end of this lookup is
+	 * not the expected one, we must restart lookup.
+	 * We probably met an item that was moved to another chain.
+	 */
+	if (get_nulls_value(node) != slot)
+		goto begin;
+
+	if (result) {
+		if (count != 1 ||
+		    unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
+			result = NULL;
+		else if (unlikely(!__udp_is_mcast_sock(net, sk,
+						       loc_port, loc_addr,
+						       rmt_port, rmt_addr,
+						       dif, hnum))) {
+			sock_put(result);
+			result = NULL;
+		}
+	}
+	rcu_read_unlock();
+	return result;
+}
+
+/* For unicast we should only early demux connected sockets or we can
+ * break forwarding setups.  The chains here can be long so only check
+ * if the first socket is an exact match and if not move on.
+ */
+static struct sock *__udp4_lib_demux_lookup(struct net *net,
+					    __be16 loc_port, __be32 loc_addr,
+					    __be16 rmt_port, __be32 rmt_addr,
+					    int dif)
+{
+	struct sock *sk, *result;
+	struct hlist_nulls_node *node;
+	unsigned short hnum = ntohs(loc_port);
+	unsigned int hash2 = udp4_portaddr_hash(net, loc_addr, hnum);
+	unsigned int slot2 = hash2 & udp_table.mask;
+	struct udp_hslot *hslot2 = &udp_table.hash2[slot2];
+	INET_ADDR_COOKIE(acookie, rmt_addr, loc_addr)
+	const __portpair ports = INET_COMBINED_PORTS(rmt_port, hnum);
+
+	rcu_read_lock();
+	result = NULL;
+	udp_portaddr_for_each_entry_rcu(sk, node, &hslot2->head) {
+		if (INET_MATCH(sk, net, acookie,
+			       rmt_addr, loc_addr, ports, dif))
+			result = sk;
+		/* Only check first socket in chain */
+		break;
+	}
+
+	if (result) {
+		if (unlikely(!atomic_inc_not_zero_hint(&result->sk_refcnt, 2)))
+			result = NULL;
+		else if (unlikely(!INET_MATCH(sk, net, acookie,
+					      rmt_addr, loc_addr,
+					      ports, dif))) {
+			sock_put(result);
+			result = NULL;
+		}
+	}
+	rcu_read_unlock();
+	return result;
+}
+
+void udp_v4_early_demux(struct sk_buff *skb)
+{
+	const struct iphdr *iph = ip_hdr(skb);
+	const struct udphdr *uh = udp_hdr(skb);
+	struct sock *sk;
+	struct dst_entry *dst;
+	struct net *net = dev_net(skb->dev);
+	int dif = skb->dev->ifindex;
+
+	/* validate the packet */
+	if (!pskb_may_pull(skb, skb_transport_offset(skb) + sizeof(struct udphdr)))
+		return;
+
+	if (skb->pkt_type == PACKET_BROADCAST ||
+	    skb->pkt_type == PACKET_MULTICAST)
+		sk = __udp4_lib_mcast_demux_lookup(net, uh->dest, iph->daddr,
+						   uh->source, iph->saddr, dif);
+	else if (skb->pkt_type == PACKET_HOST)
+		sk = __udp4_lib_demux_lookup(net, uh->dest, iph->daddr,
+					     uh->source, iph->saddr, dif);
+	else
+		return;
+
+	if (!sk)
+		return;
+
+	skb->sk = sk;
+	skb->destructor = sock_edemux;
+	dst = sk->sk_rx_dst;
+
+	if (dst)
+		dst = dst_check(dst, 0);
+	if (dst)
+		skb_dst_set_noref(skb, dst);
+}
+
 int udp_rcv(struct sk_buff *skb)
 {
 	return __udp4_lib_rcv(skb, &udp_table, IPPROTO_UDP);