diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 089ef36141555c415062b6ebb8787c6d8774ea8d..ff53e348c4bbf7d0b55f9cd8b42c18375688275f 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -5505,15 +5505,18 @@ void mem_cgroup_cancel_charge(struct page *page, struct mem_cgroup *memcg,
 
 static void uncharge_batch(struct mem_cgroup *memcg, unsigned long pgpgout,
 			   unsigned long nr_anon, unsigned long nr_file,
-			   unsigned long nr_huge, struct page *dummy_page)
+			   unsigned long nr_huge, unsigned long nr_kmem,
+			   struct page *dummy_page)
 {
-	unsigned long nr_pages = nr_anon + nr_file;
+	unsigned long nr_pages = nr_anon + nr_file + nr_kmem;
 	unsigned long flags;
 
 	if (!mem_cgroup_is_root(memcg)) {
 		page_counter_uncharge(&memcg->memory, nr_pages);
 		if (do_memsw_account())
 			page_counter_uncharge(&memcg->memsw, nr_pages);
+		if (!cgroup_subsys_on_dfl(memory_cgrp_subsys) && nr_kmem)
+			page_counter_uncharge(&memcg->kmem, nr_kmem);
 		memcg_oom_recover(memcg);
 	}
 
@@ -5536,6 +5539,7 @@ static void uncharge_list(struct list_head *page_list)
 	unsigned long nr_anon = 0;
 	unsigned long nr_file = 0;
 	unsigned long nr_huge = 0;
+	unsigned long nr_kmem = 0;
 	unsigned long pgpgout = 0;
 	struct list_head *next;
 	struct page *page;
@@ -5546,8 +5550,6 @@ static void uncharge_list(struct list_head *page_list)
 	 */
 	next = page_list->next;
 	do {
-		unsigned int nr_pages = 1;
-
 		page = list_entry(next, struct page, lru);
 		next = page->lru.next;
 
@@ -5566,31 +5568,35 @@ static void uncharge_list(struct list_head *page_list)
 		if (memcg != page->mem_cgroup) {
 			if (memcg) {
 				uncharge_batch(memcg, pgpgout, nr_anon, nr_file,
-					       nr_huge, page);
-				pgpgout = nr_anon = nr_file = nr_huge = 0;
+					       nr_huge, nr_kmem, page);
+				pgpgout = nr_anon = nr_file =
+					nr_huge = nr_kmem = 0;
 			}
 			memcg = page->mem_cgroup;
 		}
 
-		if (PageTransHuge(page)) {
-			nr_pages <<= compound_order(page);
-			VM_BUG_ON_PAGE(!PageTransHuge(page), page);
-			nr_huge += nr_pages;
-		}
+		if (!PageKmemcg(page)) {
+			unsigned int nr_pages = 1;
 
-		if (PageAnon(page))
-			nr_anon += nr_pages;
-		else
-			nr_file += nr_pages;
+			if (PageTransHuge(page)) {
+				nr_pages <<= compound_order(page);
+				VM_BUG_ON_PAGE(!PageTransHuge(page), page);
+				nr_huge += nr_pages;
+			}
+			if (PageAnon(page))
+				nr_anon += nr_pages;
+			else
+				nr_file += nr_pages;
+			pgpgout++;
+		} else
+			nr_kmem += 1 << compound_order(page);
 
 		page->mem_cgroup = NULL;
-
-		pgpgout++;
 	} while (next != page_list);
 
 	if (memcg)
 		uncharge_batch(memcg, pgpgout, nr_anon, nr_file,
-			       nr_huge, page);
+			       nr_huge, nr_kmem, page);
 }
 
 /**