diff --git a/mm/internal.h b/mm/internal.h
index ccfc2a2969f4402bdbfb27e0b48df151f4da68b7..266efaeaa370a46debcc5b6b614a72e33833ac4d 100644
--- a/mm/internal.h
+++ b/mm/internal.h
@@ -481,6 +481,13 @@ unsigned long reclaim_clean_pages_from_list(struct zone *zone,
 enum ttu_flags;
 struct tlbflush_unmap_batch;
 
+
+/*
+ * only for MM internal work items which do not depend on
+ * any allocations or locks which might depend on allocations
+ */
+extern struct workqueue_struct *mm_percpu_wq;
+
 #ifdef CONFIG_ARCH_WANT_BATCHED_UNMAP_TLB_FLUSH
 void try_to_unmap_flush(void);
 void try_to_unmap_flush_dirty(void);
diff --git a/mm/page_alloc.c b/mm/page_alloc.c
index d6a665057d6100fd6c62f10a98558f70e48ca4e2..f3d603cef2c0c0e5aef09540dd2f8d50da5a808c 100644
--- a/mm/page_alloc.c
+++ b/mm/page_alloc.c
@@ -2373,6 +2373,13 @@ void drain_all_pages(struct zone *zone)
 	 */
 	static cpumask_t cpus_with_pcps;
 
+	/*
+	 * Make sure nobody triggers this path before mm_percpu_wq is fully
+	 * initialized.
+	 */
+	if (WARN_ON_ONCE(!mm_percpu_wq))
+		return;
+
 	/* Workqueues cannot recurse */
 	if (current->flags & PF_WQ_WORKER)
 		return;
@@ -2422,7 +2429,7 @@ void drain_all_pages(struct zone *zone)
 	for_each_cpu(cpu, &cpus_with_pcps) {
 		struct work_struct *work = per_cpu_ptr(&pcpu_drain, cpu);
 		INIT_WORK(work, drain_local_pages_wq);
-		schedule_work_on(cpu, work);
+		queue_work_on(cpu, mm_percpu_wq, work);
 	}
 	for_each_cpu(cpu, &cpus_with_pcps)
 		flush_work(per_cpu_ptr(&pcpu_drain, cpu));
diff --git a/mm/swap.c b/mm/swap.c
index c4910f14f9579ef1d8b165355f9294715968bf2d..5dabf444d724db98595567b0f7daed7d53fc877e 100644
--- a/mm/swap.c
+++ b/mm/swap.c
@@ -670,30 +670,19 @@ static void lru_add_drain_per_cpu(struct work_struct *dummy)
 
 static DEFINE_PER_CPU(struct work_struct, lru_add_drain_work);
 
-/*
- * lru_add_drain_wq is used to do lru_add_drain_all() from a WQ_MEM_RECLAIM
- * workqueue, aiding in getting memory freed.
- */
-static struct workqueue_struct *lru_add_drain_wq;
-
-static int __init lru_init(void)
-{
-	lru_add_drain_wq = alloc_workqueue("lru-add-drain", WQ_MEM_RECLAIM, 0);
-
-	if (WARN(!lru_add_drain_wq,
-		"Failed to create workqueue lru_add_drain_wq"))
-		return -ENOMEM;
-
-	return 0;
-}
-early_initcall(lru_init);
-
 void lru_add_drain_all(void)
 {
 	static DEFINE_MUTEX(lock);
 	static struct cpumask has_work;
 	int cpu;
 
+	/*
+	 * Make sure nobody triggers this path before mm_percpu_wq is fully
+	 * initialized.
+	 */
+	if (WARN_ON(!mm_percpu_wq))
+		return;
+
 	mutex_lock(&lock);
 	get_online_cpus();
 	cpumask_clear(&has_work);
@@ -707,7 +696,7 @@ void lru_add_drain_all(void)
 		    pagevec_count(&per_cpu(lru_deactivate_pvecs, cpu)) ||
 		    need_activate_page_drain(cpu)) {
 			INIT_WORK(work, lru_add_drain_per_cpu);
-			queue_work_on(cpu, lru_add_drain_wq, work);
+			queue_work_on(cpu, mm_percpu_wq, work);
 			cpumask_set_cpu(cpu, &has_work);
 		}
 	}
diff --git a/mm/vmstat.c b/mm/vmstat.c
index 89f95396ec46be64055f1a658c9c0f7bdad90d5c..809025ed97ea0eee97573a32ba2764c63ee2dffd 100644
--- a/mm/vmstat.c
+++ b/mm/vmstat.c
@@ -1552,7 +1552,6 @@ static const struct file_operations proc_vmstat_file_operations = {
 #endif /* CONFIG_PROC_FS */
 
 #ifdef CONFIG_SMP
-static struct workqueue_struct *vmstat_wq;
 static DEFINE_PER_CPU(struct delayed_work, vmstat_work);
 int sysctl_stat_interval __read_mostly = HZ;
 
@@ -1623,7 +1622,7 @@ static void vmstat_update(struct work_struct *w)
 		 * to occur in the future. Keep on running the
 		 * update worker thread.
 		 */
-		queue_delayed_work_on(smp_processor_id(), vmstat_wq,
+		queue_delayed_work_on(smp_processor_id(), mm_percpu_wq,
 				this_cpu_ptr(&vmstat_work),
 				round_jiffies_relative(sysctl_stat_interval));
 	}
@@ -1702,7 +1701,7 @@ static void vmstat_shepherd(struct work_struct *w)
 		struct delayed_work *dw = &per_cpu(vmstat_work, cpu);
 
 		if (!delayed_work_pending(dw) && need_update(cpu))
-			queue_delayed_work_on(cpu, vmstat_wq, dw, 0);
+			queue_delayed_work_on(cpu, mm_percpu_wq, dw, 0);
 	}
 	put_online_cpus();
 
@@ -1718,7 +1717,6 @@ static void __init start_shepherd_timer(void)
 		INIT_DEFERRABLE_WORK(per_cpu_ptr(&vmstat_work, cpu),
 			vmstat_update);
 
-	vmstat_wq = alloc_workqueue("vmstat", WQ_FREEZABLE|WQ_MEM_RECLAIM, 0);
 	schedule_delayed_work(&shepherd,
 		round_jiffies_relative(sysctl_stat_interval));
 }
@@ -1764,11 +1762,16 @@ static int vmstat_cpu_dead(unsigned int cpu)
 
 #endif
 
+struct workqueue_struct *mm_percpu_wq;
+
 void __init init_mm_internals(void)
 {
-#ifdef CONFIG_SMP
-	int ret;
+	int ret __maybe_unused;
 
+	mm_percpu_wq = alloc_workqueue("mm_percpu_wq",
+				       WQ_FREEZABLE|WQ_MEM_RECLAIM, 0);
+
+#ifdef CONFIG_SMP
 	ret = cpuhp_setup_state_nocalls(CPUHP_MM_VMSTAT_DEAD, "mm/vmstat:dead",
 					NULL, vmstat_cpu_dead);
 	if (ret < 0)