diff --git a/include/linux/memcontrol.h b/include/linux/memcontrol.h
index 5b86287fa069321f10db8963612960699622f1b5..a7a0a1a5c8d5916c7a2946d369d11f61076cbac7 100644
--- a/include/linux/memcontrol.h
+++ b/include/linux/memcontrol.h
@@ -901,6 +901,11 @@ static inline struct lruvec *mem_cgroup_page_lruvec(struct page *page,
 	return &pgdat->__lruvec;
 }
 
+static inline struct mem_cgroup *parent_mem_cgroup(struct mem_cgroup *memcg)
+{
+	return NULL;
+}
+
 static inline bool mm_match_cgroup(struct mm_struct *mm,
 		struct mem_cgroup *memcg)
 {
diff --git a/include/linux/swap.h b/include/linux/swap.h
index 063c0c1e112bdc9c9edb81ecbcbefd45bdc7ee57..1e99f7ac1d7e3a12bf8378992c4ba1e3d015abf4 100644
--- a/include/linux/swap.h
+++ b/include/linux/swap.h
@@ -307,7 +307,7 @@ struct vma_swap_readahead {
 };
 
 /* linux/mm/workingset.c */
-void *workingset_eviction(struct page *page);
+void *workingset_eviction(struct page *page, struct mem_cgroup *target_memcg);
 void workingset_refault(struct page *page, void *shadow);
 void workingset_activation(struct page *page);
 
diff --git a/mm/vmscan.c b/mm/vmscan.c
index 725b5d4784f793d19963124c6b8e64ec5eca5d8d..39657012e2d849f5af3093f769ff384f4b6e3e73 100644
--- a/mm/vmscan.c
+++ b/mm/vmscan.c
@@ -853,7 +853,7 @@ static pageout_t pageout(struct page *page, struct address_space *mapping)
  * gets returned with a refcount of 0.
  */
 static int __remove_mapping(struct address_space *mapping, struct page *page,
-			    bool reclaimed)
+			    bool reclaimed, struct mem_cgroup *target_memcg)
 {
 	unsigned long flags;
 	int refcount;
@@ -925,7 +925,7 @@ static int __remove_mapping(struct address_space *mapping, struct page *page,
 		 */
 		if (reclaimed && page_is_file_cache(page) &&
 		    !mapping_exiting(mapping) && !dax_mapping(mapping))
-			shadow = workingset_eviction(page);
+			shadow = workingset_eviction(page, target_memcg);
 		__delete_from_page_cache(page, shadow);
 		xa_unlock_irqrestore(&mapping->i_pages, flags);
 
@@ -948,7 +948,7 @@ static int __remove_mapping(struct address_space *mapping, struct page *page,
  */
 int remove_mapping(struct address_space *mapping, struct page *page)
 {
-	if (__remove_mapping(mapping, page, false)) {
+	if (__remove_mapping(mapping, page, false, NULL)) {
 		/*
 		 * Unfreezing the refcount with 1 rather than 2 effectively
 		 * drops the pagecache ref for us without requiring another
@@ -1426,7 +1426,8 @@ static unsigned long shrink_page_list(struct list_head *page_list,
 
 			count_vm_event(PGLAZYFREED);
 			count_memcg_page_event(page, PGLAZYFREED);
-		} else if (!mapping || !__remove_mapping(mapping, page, true))
+		} else if (!mapping || !__remove_mapping(mapping, page, true,
+							 sc->target_mem_cgroup))
 			goto keep_locked;
 
 		unlock_page(page);
@@ -2189,6 +2190,7 @@ static bool inactive_list_is_low(struct lruvec *lruvec, bool file,
 	enum lru_list inactive_lru = file * LRU_FILE;
 	unsigned long inactive, active;
 	unsigned long inactive_ratio;
+	struct lruvec *target_lruvec;
 	unsigned long refaults;
 	unsigned long gb;
 
@@ -2200,8 +2202,9 @@ static bool inactive_list_is_low(struct lruvec *lruvec, bool file,
 	 * is being established. Disable active list protection to get
 	 * rid of the stale workingset quickly.
 	 */
-	refaults = lruvec_page_state_local(lruvec, WORKINGSET_ACTIVATE);
-	if (file && lruvec->refaults != refaults) {
+	target_lruvec = mem_cgroup_lruvec(sc->target_mem_cgroup, pgdat);
+	refaults = lruvec_page_state(target_lruvec, WORKINGSET_ACTIVATE);
+	if (file && target_lruvec->refaults != refaults) {
 		inactive_ratio = 0;
 	} else {
 		gb = (inactive + active) >> (30 - PAGE_SHIFT);
@@ -2973,19 +2976,14 @@ static void shrink_zones(struct zonelist *zonelist, struct scan_control *sc)
 	sc->gfp_mask = orig_mask;
 }
 
-static void snapshot_refaults(struct mem_cgroup *root_memcg, pg_data_t *pgdat)
+static void snapshot_refaults(struct mem_cgroup *target_memcg, pg_data_t *pgdat)
 {
-	struct mem_cgroup *memcg;
-
-	memcg = mem_cgroup_iter(root_memcg, NULL, NULL);
-	do {
-		unsigned long refaults;
-		struct lruvec *lruvec;
+	struct lruvec *target_lruvec;
+	unsigned long refaults;
 
-		lruvec = mem_cgroup_lruvec(memcg, pgdat);
-		refaults = lruvec_page_state_local(lruvec, WORKINGSET_ACTIVATE);
-		lruvec->refaults = refaults;
-	} while ((memcg = mem_cgroup_iter(root_memcg, memcg, NULL)));
+	target_lruvec = mem_cgroup_lruvec(target_memcg, pgdat);
+	refaults = lruvec_page_state(target_lruvec, WORKINGSET_ACTIVATE);
+	target_lruvec->refaults = refaults;
 }
 
 /*
diff --git a/mm/workingset.c b/mm/workingset.c
index e8212123c1c3d3aa6c2a09a2ce74155b64f3a663..474186b76ceda30d04d0ca8fe4fd70f72ed76a4a 100644
--- a/mm/workingset.c
+++ b/mm/workingset.c
@@ -213,28 +213,53 @@ static void unpack_shadow(void *shadow, int *memcgidp, pg_data_t **pgdat,
 	*workingsetp = workingset;
 }
 
+static void advance_inactive_age(struct mem_cgroup *memcg, pg_data_t *pgdat)
+{
+	/*
+	 * Reclaiming a cgroup means reclaiming all its children in a
+	 * round-robin fashion. That means that each cgroup has an LRU
+	 * order that is composed of the LRU orders of its child
+	 * cgroups; and every page has an LRU position not just in the
+	 * cgroup that owns it, but in all of that group's ancestors.
+	 *
+	 * So when the physical inactive list of a leaf cgroup ages,
+	 * the virtual inactive lists of all its parents, including
+	 * the root cgroup's, age as well.
+	 */
+	do {
+		struct lruvec *lruvec;
+
+		lruvec = mem_cgroup_lruvec(memcg, pgdat);
+		atomic_long_inc(&lruvec->inactive_age);
+	} while (memcg && (memcg = parent_mem_cgroup(memcg)));
+}
+
 /**
  * workingset_eviction - note the eviction of a page from memory
+ * @target_memcg: the cgroup that is causing the reclaim
  * @page: the page being evicted
  *
  * Returns a shadow entry to be stored in @page->mapping->i_pages in place
  * of the evicted @page so that a later refault can be detected.
  */
-void *workingset_eviction(struct page *page)
+void *workingset_eviction(struct page *page, struct mem_cgroup *target_memcg)
 {
 	struct pglist_data *pgdat = page_pgdat(page);
-	struct mem_cgroup *memcg = page_memcg(page);
-	int memcgid = mem_cgroup_id(memcg);
 	unsigned long eviction;
 	struct lruvec *lruvec;
+	int memcgid;
 
 	/* Page is fully exclusive and pins page->mem_cgroup */
 	VM_BUG_ON_PAGE(PageLRU(page), page);
 	VM_BUG_ON_PAGE(page_count(page), page);
 	VM_BUG_ON_PAGE(!PageLocked(page), page);
 
-	lruvec = mem_cgroup_lruvec(memcg, pgdat);
-	eviction = atomic_long_inc_return(&lruvec->inactive_age);
+	advance_inactive_age(page_memcg(page), pgdat);
+
+	lruvec = mem_cgroup_lruvec(target_memcg, pgdat);
+	/* XXX: target_memcg can be NULL, go through lruvec */
+	memcgid = mem_cgroup_id(lruvec_memcg(lruvec));
+	eviction = atomic_long_read(&lruvec->inactive_age);
 	return pack_shadow(memcgid, pgdat, eviction, PageWorkingset(page));
 }
 
@@ -244,10 +269,13 @@ void *workingset_eviction(struct page *page)
  * @shadow: shadow entry of the evicted page
  *
  * Calculates and evaluates the refault distance of the previously
- * evicted page in the context of the node it was allocated in.
+ * evicted page in the context of the node and the memcg whose memory
+ * pressure caused the eviction.
  */
 void workingset_refault(struct page *page, void *shadow)
 {
+	struct mem_cgroup *eviction_memcg;
+	struct lruvec *eviction_lruvec;
 	unsigned long refault_distance;
 	struct pglist_data *pgdat;
 	unsigned long active_file;
@@ -277,12 +305,12 @@ void workingset_refault(struct page *page, void *shadow)
 	 * would be better if the root_mem_cgroup existed in all
 	 * configurations instead.
 	 */
-	memcg = mem_cgroup_from_id(memcgid);
-	if (!mem_cgroup_disabled() && !memcg)
+	eviction_memcg = mem_cgroup_from_id(memcgid);
+	if (!mem_cgroup_disabled() && !eviction_memcg)
 		goto out;
-	lruvec = mem_cgroup_lruvec(memcg, pgdat);
-	refault = atomic_long_read(&lruvec->inactive_age);
-	active_file = lruvec_lru_size(lruvec, LRU_ACTIVE_FILE, MAX_NR_ZONES);
+	eviction_lruvec = mem_cgroup_lruvec(eviction_memcg, pgdat);
+	refault = atomic_long_read(&eviction_lruvec->inactive_age);
+	active_file = lruvec_page_state(eviction_lruvec, NR_ACTIVE_FILE);
 
 	/*
 	 * Calculate the refault distance
@@ -302,6 +330,17 @@ void workingset_refault(struct page *page, void *shadow)
 	 */
 	refault_distance = (refault - eviction) & EVICTION_MASK;
 
+	/*
+	 * The activation decision for this page is made at the level
+	 * where the eviction occurred, as that is where the LRU order
+	 * during page reclaim is being determined.
+	 *
+	 * However, the cgroup that will own the page is the one that
+	 * is actually experiencing the refault event.
+	 */
+	memcg = page_memcg(page);
+	lruvec = mem_cgroup_lruvec(memcg, pgdat);
+
 	inc_lruvec_state(lruvec, WORKINGSET_REFAULT);
 
 	/*
@@ -313,7 +352,7 @@ void workingset_refault(struct page *page, void *shadow)
 		goto out;
 
 	SetPageActive(page);
-	atomic_long_inc(&lruvec->inactive_age);
+	advance_inactive_age(memcg, pgdat);
 	inc_lruvec_state(lruvec, WORKINGSET_ACTIVATE);
 
 	/* Page was active prior to eviction */
@@ -332,7 +371,6 @@ void workingset_refault(struct page *page, void *shadow)
 void workingset_activation(struct page *page)
 {
 	struct mem_cgroup *memcg;
-	struct lruvec *lruvec;
 
 	rcu_read_lock();
 	/*
@@ -345,8 +383,7 @@ void workingset_activation(struct page *page)
 	memcg = page_memcg_rcu(page);
 	if (!mem_cgroup_disabled() && !memcg)
 		goto out;
-	lruvec = mem_cgroup_lruvec(memcg, page_pgdat(page));
-	atomic_long_inc(&lruvec->inactive_age);
+	advance_inactive_age(memcg, page_pgdat(page));
 out:
 	rcu_read_unlock();
 }