diff --git a/mm/memcontrol.c b/mm/memcontrol.c
index 8cc617ede7e2ec408970e0a6c9b0a30cd54ab90a..24892a14cc75cc3b701ce4225c5e88a1a80a149d 100644
--- a/mm/memcontrol.c
+++ b/mm/memcontrol.c
@@ -3416,7 +3416,8 @@ static int memcg_online_kmem(struct mem_cgroup *memcg)
 	if (memcg_id < 0)
 		return memcg_id;
 
-	static_branch_inc(&memcg_kmem_enabled_key);
+	static_branch_enable(&memcg_kmem_enabled_key);
+
 	/*
 	 * A memory cgroup is considered kmem-online as soon as it gets
 	 * kmemcg_id. Setting the id after enabling static branching will
@@ -3486,11 +3487,6 @@ static void memcg_free_kmem(struct mem_cgroup *memcg)
 	/* css_alloc() failed, offlining didn't happen */
 	if (unlikely(memcg->kmem_state == KMEM_ONLINE))
 		memcg_offline_kmem(memcg);
-
-	if (memcg->kmem_state == KMEM_ALLOCATED) {
-		WARN_ON(!list_empty(&memcg->kmem_caches));
-		static_branch_dec(&memcg_kmem_enabled_key);
-	}
 }
 #else
 static int memcg_online_kmem(struct mem_cgroup *memcg)