diff --git a/Documentation/kernel-parameters.txt b/Documentation/kernel-parameters.txt
index b7d44871effc2c49f194a3c3c986733f2a497b2c..168fd79dc697a06648b28a1125a07265880b58aa 100644
--- a/Documentation/kernel-parameters.txt
+++ b/Documentation/kernel-parameters.txt
@@ -608,6 +608,10 @@ bytes respectively. Such letter suffixes can also be entirely omitted.
 			cut the overhead, others just disable the usage. So
 			only cgroup_disable=memory is actually worthy}
 
+	cgroup.memory=	[KNL] Pass options to the cgroup memory controller.
+			Format: <string>
+			nosocket -- Disable socket memory accounting.
+
 	checkreqprot	[SELINUX] Set initial checkreqprot flag value.
 			Format: { "0" | "1" }
 			See security/selinux/Kconfig help text.
diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 03090e8e7fffbb882c39ccf7aa3b653ccaf06ae3..a355f61a2ed39ba68cb8b13dc40659072f2a490b 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -170,6 +170,9 @@ struct mem_cgroup {
 	unsigned long low;
 	unsigned long high;
 
+	/* Range enforcement for interrupt charges */
+	struct work_struct high_work;
+
 	unsigned long soft_limit;
 
 	/* vmpressure notifications */
@@ -680,12 +683,16 @@ void sock_update_memcg(struct sock *sk);
 void sock_release_memcg(struct sock *sk);
 bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages);
 void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages);
-#if defined(CONFIG_MEMCG_KMEM) && defined(CONFIG_INET)
+#if defined(CONFIG_MEMCG) && defined(CONFIG_INET)
 extern struct static_key memcg_sockets_enabled_key;
 #define mem_cgroup_sockets_enabled static_key_false(&memcg_sockets_enabled_key)
 static inline bool mem_cgroup_under_socket_pressure(struct mem_cgroup *memcg)
 {
+#ifdef CONFIG_MEMCG_KMEM
 	return memcg->tcp_mem.memory_pressure;
+#else
+	return false;
+#endif
 }
 #else
 #define mem_cgroup_sockets_enabled 0
diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 6aac8d2e31d75cefdf1c05c225de6dae35ad85fd..60ebc486c2aa6ff559c3ac297518acb21d1f2229 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -80,6 +80,9 @@ struct mem_cgroup *root_mem_cgroup __read_mostly;
 
 #define MEM_CGROUP_RECLAIM_RETRIES	5
 
+/* Socket memory accounting disabled? */
+static bool cgroup_memory_nosocket;
+
 /* Whether the swap controller is active */
 #ifdef CONFIG_MEMCG_SWAP
 int do_swap_account __read_mostly;
@@ -1945,6 +1948,26 @@ static int memcg_cpu_hotplug_callback(struct notifier_block *nb,
 	return NOTIFY_OK;
 }
 
+static void reclaim_high(struct mem_cgroup *memcg,
+			 unsigned int nr_pages,
+			 gfp_t gfp_mask)
+{
+	do {
+		if (page_counter_read(&memcg->memory) <= memcg->high)
+			continue;
+		mem_cgroup_events(memcg, MEMCG_HIGH, 1);
+		try_to_free_mem_cgroup_pages(memcg, nr_pages, gfp_mask, true);
+	} while ((memcg = parent_mem_cgroup(memcg)));
+}
+
+static void high_work_func(struct work_struct *work)
+{
+	struct mem_cgroup *memcg;
+
+	memcg = container_of(work, struct mem_cgroup, high_work);
+	reclaim_high(memcg, CHARGE_BATCH, GFP_KERNEL);
+}
+
 /*
  * Scheduled by try_charge() to be executed from the userland return path
  * and reclaims memory over the high limit.
@@ -1952,20 +1975,13 @@ static int memcg_cpu_hotplug_callback(struct notifier_block *nb,
 void mem_cgroup_handle_over_high(void)
 {
 	unsigned int nr_pages = current->memcg_nr_pages_over_high;
-	struct mem_cgroup *memcg, *pos;
+	struct mem_cgroup *memcg;
 
 	if (likely(!nr_pages))
 		return;
 
-	pos = memcg = get_mem_cgroup_from_mm(current->mm);
-
-	do {
-		if (page_counter_read(&pos->memory) <= pos->high)
-			continue;
-		mem_cgroup_events(pos, MEMCG_HIGH, 1);
-		try_to_free_mem_cgroup_pages(pos, nr_pages, GFP_KERNEL, true);
-	} while ((pos = parent_mem_cgroup(pos)));
-
+	memcg = get_mem_cgroup_from_mm(current->mm);
+	reclaim_high(memcg, nr_pages, GFP_KERNEL);
 	css_put(&memcg->css);
 	current->memcg_nr_pages_over_high = 0;
 }
@@ -2100,6 +2116,11 @@ static int try_charge(struct mem_cgroup *memcg, gfp_t gfp_mask,
 	 */
 	do {
 		if (page_counter_read(&memcg->memory) > memcg->high) {
+			/* Don't bother a random interrupted task */
+			if (in_interrupt()) {
+				schedule_work(&memcg->high_work);
+				break;
+			}
 			current->memcg_nr_pages_over_high += batch;
 			set_notify_resume(current);
 			break;
@@ -4150,6 +4171,8 @@ static void __mem_cgroup_free(struct mem_cgroup *memcg)
 {
 	int node;
 
+	cancel_work_sync(&memcg->high_work);
+
 	mem_cgroup_remove_from_trees(memcg);
 
 	for_each_node(node)
@@ -4196,6 +4219,7 @@ mem_cgroup_css_alloc(struct cgroup_subsys_state *parent_css)
 		page_counter_init(&memcg->kmem, NULL);
 	}
 
+	INIT_WORK(&memcg->high_work, high_work_func);
 	memcg->last_scanned_node = MAX_NUMNODES;
 	INIT_LIST_HEAD(&memcg->oom_notify);
 	memcg->move_charge_at_immigrate = 0;
@@ -4267,6 +4291,11 @@ mem_cgroup_css_online(struct cgroup_subsys_state *css)
 	if (ret)
 		return ret;
 
+#ifdef CONFIG_INET
+	if (cgroup_subsys_on_dfl(memory_cgrp_subsys) && !cgroup_memory_nosocket)
+		static_key_slow_inc(&memcg_sockets_enabled_key);
+#endif
+
 	/*
 	 * Make sure the memcg is initialized: mem_cgroup_iter()
 	 * orders reading memcg->initialized against its callers
@@ -4313,6 +4342,10 @@ static void mem_cgroup_css_free(struct cgroup_subsys_state *css)
 	struct mem_cgroup *memcg = mem_cgroup_from_css(css);
 
 	memcg_destroy_kmem(memcg);
+#ifdef CONFIG_INET
+	if (cgroup_subsys_on_dfl(memory_cgrp_subsys) && !cgroup_memory_nosocket)
+		static_key_slow_dec(&memcg_sockets_enabled_key);
+#endif
 	__mem_cgroup_free(memcg);
 }
 
@@ -5533,8 +5566,7 @@ void mem_cgroup_replace_page(struct page *oldpage, struct page *newpage)
 	commit_charge(newpage, memcg, true);
 }
 
-/* Writing them here to avoid exposing memcg's inner layout */
-#if defined(CONFIG_INET) && defined(CONFIG_MEMCG_KMEM)
+#ifdef CONFIG_INET
 
 struct static_key memcg_sockets_enabled_key;
 EXPORT_SYMBOL(memcg_sockets_enabled_key);
@@ -5559,10 +5591,15 @@ void sock_update_memcg(struct sock *sk)
 
 	rcu_read_lock();
 	memcg = mem_cgroup_from_task(current);
-	if (memcg != root_mem_cgroup &&
-	    memcg->tcp_mem.active &&
-	    css_tryget_online(&memcg->css))
+	if (memcg == root_mem_cgroup)
+		goto out;
+#ifdef CONFIG_MEMCG_KMEM
+	if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) && !memcg->tcp_mem.active)
+		goto out;
+#endif
+	if (css_tryget_online(&memcg->css))
 		sk->sk_memcg = memcg;
+out:
 	rcu_read_unlock();
 }
 EXPORT_SYMBOL(sock_update_memcg);
@@ -5583,15 +5620,30 @@ void sock_release_memcg(struct sock *sk)
  */
 bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
-	struct page_counter *counter;
+	gfp_t gfp_mask = GFP_KERNEL;
 
-	if (page_counter_try_charge(&memcg->tcp_mem.memory_allocated,
-				    nr_pages, &counter)) {
-		memcg->tcp_mem.memory_pressure = 0;
-		return true;
+#ifdef CONFIG_MEMCG_KMEM
+	if (!cgroup_subsys_on_dfl(memory_cgrp_subsys)) {
+		struct page_counter *counter;
+
+		if (page_counter_try_charge(&memcg->tcp_mem.memory_allocated,
+					    nr_pages, &counter)) {
+			memcg->tcp_mem.memory_pressure = 0;
+			return true;
+		}
+		page_counter_charge(&memcg->tcp_mem.memory_allocated, nr_pages);
+		memcg->tcp_mem.memory_pressure = 1;
+		return false;
 	}
-	page_counter_charge(&memcg->tcp_mem.memory_allocated, nr_pages);
-	memcg->tcp_mem.memory_pressure = 1;
+#endif
+	/* Don't block in the packet receive path */
+	if (in_softirq())
+		gfp_mask = GFP_NOWAIT;
+
+	if (try_charge(memcg, gfp_mask, nr_pages) == 0)
+		return true;
+
+	try_charge(memcg, gfp_mask|__GFP_NOFAIL, nr_pages);
 	return false;
 }
 
@@ -5602,10 +5654,32 @@ bool mem_cgroup_charge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages)
  */
 void mem_cgroup_uncharge_skmem(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
-	page_counter_uncharge(&memcg->tcp_mem.memory_allocated, nr_pages);
+#ifdef CONFIG_MEMCG_KMEM
+	if (!cgroup_subsys_on_dfl(memory_cgrp_subsys)) {
+		page_counter_uncharge(&memcg->tcp_mem.memory_allocated,
+				      nr_pages);
+		return;
+	}
+#endif
+	page_counter_uncharge(&memcg->memory, nr_pages);
+	css_put_many(&memcg->css, nr_pages);
 }
 
-#endif
+#endif /* CONFIG_INET */
+
+static int __init cgroup_memory(char *s)
+{
+	char *token;
+
+	while ((token = strsep(&s, ",")) != NULL) {
+		if (!*token)
+			continue;
+		if (!strcmp(token, "nosocket"))
+			cgroup_memory_nosocket = true;
+	}
+	return 0;
+}
+__setup("cgroup.memory=", cgroup_memory);
 
 /*
  * subsys_initcall() for memory controller.