diff --git a/include/linux/mmzone.h b/include/linux/mmzone.h
index 150c6049f961407b82c75d5b7fe3bef0976f0993..cfcd7723edb6fb5838eaa90c625c7636b048bdb9 100644
--- a/include/linux/mmzone.h
+++ b/include/linux/mmzone.h
@@ -919,6 +919,10 @@ static inline int zonelist_node_idx(struct zoneref *zoneref)
 #endif /* CONFIG_NUMA */
 }
 
+struct zoneref *__next_zones_zonelist(struct zoneref *z,
+					enum zone_type highest_zoneidx,
+					nodemask_t *nodes);
+
 /**
  * next_zones_zonelist - Returns the next zone at or below highest_zoneidx within the allowed nodemask using a cursor within a zonelist as a starting point
  * @z - The cursor used as a starting point for the search
@@ -931,9 +935,14 @@ static inline int zonelist_node_idx(struct zoneref *zoneref)
  * being examined. It should be advanced by one before calling
  * next_zones_zonelist again.
  */
-struct zoneref *next_zones_zonelist(struct zoneref *z,
+static __always_inline struct zoneref *next_zones_zonelist(struct zoneref *z,
 					enum zone_type highest_zoneidx,
-					nodemask_t *nodes);
+					nodemask_t *nodes)
+{
+	if (likely(!nodes && zonelist_zone_idx(z) <= highest_zoneidx))
+		return z;
+	return __next_zones_zonelist(z, highest_zoneidx, nodes);
+}
 
 /**
  * first_zones_zonelist - Returns the first zone at or below highest_zoneidx within the allowed nodemask in a zonelist
diff --git a/mm/mmzone.c b/mm/mmzone.c
index 52687fb4de6f46ebcb095987b9474e3957b5eafd..5652be858e5e320c7a748d4e6a407c66baaa62a0 100644
--- a/mm/mmzone.c
+++ b/mm/mmzone.c
@@ -52,7 +52,7 @@ static inline int zref_in_nodemask(struct zoneref *zref, nodemask_t *nodes)
 }
 
 /* Returns the next zone at or below highest_zoneidx in a zonelist */
-struct zoneref *next_zones_zonelist(struct zoneref *z,
+struct zoneref *__next_zones_zonelist(struct zoneref *z,
 					enum zone_type highest_zoneidx,
 					nodemask_t *nodes)
 {
diff --git a/mm/page_alloc.c b/mm/page_alloc.c
index 36384baa74e18dbe0828a4d759c92f06f243abe4..789e5f065e8d49ac351cdf29d123b5e06dae9239 100644
--- a/mm/page_alloc.c
+++ b/mm/page_alloc.c
@@ -3192,17 +3192,6 @@ __alloc_pages_slowpath(gfp_t gfp_mask, unsigned int order,
 	 */
 	alloc_flags = gfp_to_alloc_flags(gfp_mask);
 
-	/*
-	 * Find the true preferred zone if the allocation is unconstrained by
-	 * cpusets.
-	 */
-	if (!(alloc_flags & ALLOC_CPUSET) && !ac->nodemask) {
-		struct zoneref *preferred_zoneref;
-		preferred_zoneref = first_zones_zonelist(ac->zonelist,
-				ac->high_zoneidx, NULL, &ac->preferred_zone);
-		ac->classzone_idx = zonelist_zone_idx(preferred_zoneref);
-	}
-
 	/* This is the last chance, in general, before the goto nopage. */
 	page = get_page_from_freelist(gfp_mask, order,
 				alloc_flags & ~ALLOC_NO_WATERMARKS, ac);
@@ -3358,14 +3347,21 @@ __alloc_pages_nodemask(gfp_t gfp_mask, unsigned int order,
 	struct zoneref *preferred_zoneref;
 	struct page *page = NULL;
 	unsigned int cpuset_mems_cookie;
-	int alloc_flags = ALLOC_WMARK_LOW|ALLOC_CPUSET|ALLOC_FAIR;
+	int alloc_flags = ALLOC_WMARK_LOW|ALLOC_FAIR;
 	gfp_t alloc_mask; /* The gfp_t that was actually used for allocation */
 	struct alloc_context ac = {
 		.high_zoneidx = gfp_zone(gfp_mask),
+		.zonelist = zonelist,
 		.nodemask = nodemask,
 		.migratetype = gfpflags_to_migratetype(gfp_mask),
 	};
 
+	if (cpusets_enabled()) {
+		alloc_flags |= ALLOC_CPUSET;
+		if (!ac.nodemask)
+			ac.nodemask = &cpuset_current_mems_allowed;
+	}
+
 	gfp_mask &= gfp_allowed_mask;
 
 	lockdep_trace_alloc(gfp_mask);
@@ -3389,16 +3385,12 @@ __alloc_pages_nodemask(gfp_t gfp_mask, unsigned int order,
 retry_cpuset:
 	cpuset_mems_cookie = read_mems_allowed_begin();
 
-	/* We set it here, as __alloc_pages_slowpath might have changed it */
-	ac.zonelist = zonelist;
-
 	/* Dirty zone balancing only done in the fast path */
 	ac.spread_dirty_pages = (gfp_mask & __GFP_WRITE);
 
 	/* The preferred zone is used for statistics later */
 	preferred_zoneref = first_zones_zonelist(ac.zonelist, ac.high_zoneidx,
-				ac.nodemask ? : &cpuset_current_mems_allowed,
-				&ac.preferred_zone);
+				ac.nodemask, &ac.preferred_zone);
 	if (!ac.preferred_zone)
 		goto out;
 	ac.classzone_idx = zonelist_zone_idx(preferred_zoneref);