diff --git a/Documentation/vm/hmm.rst b/Documentation/vm/hmm.rst
index 945d5fb6d14acfd792a37377a6c5dca4f7487457..ec1efa32af3cafb0b309c2653ebc008cbdccd860 100644
--- a/Documentation/vm/hmm.rst
+++ b/Documentation/vm/hmm.rst
@@ -276,6 +276,41 @@ report commands as executed is serialized (there is no point in doing this
 concurrently).
 
 
+Leverage default_flags and pfn_flags_mask
+=========================================
+
+The hmm_range struct has 2 fields default_flags and pfn_flags_mask that allows
+to set fault or snapshot policy for a whole range instead of having to set them
+for each entries in the range.
+
+For instance if the device flags for device entries are:
+    VALID (1 << 63)
+    WRITE (1 << 62)
+
+Now let say that device driver wants to fault with at least read a range then
+it does set:
+    range->default_flags = (1 << 63)
+    range->pfn_flags_mask = 0;
+
+and calls hmm_range_fault() as described above. This will fill fault all page
+in the range with at least read permission.
+
+Now let say driver wants to do the same except for one page in the range for
+which its want to have write. Now driver set:
+    range->default_flags = (1 << 63);
+    range->pfn_flags_mask = (1 << 62);
+    range->pfns[index_of_write] = (1 << 62);
+
+With this HMM will fault in all page with at least read (ie valid) and for the
+address == range->start + (index_of_write << PAGE_SHIFT) it will fault with
+write permission ie if the CPU pte does not have write permission set then HMM
+will call handle_mm_fault().
+
+Note that HMM will populate the pfns array with write permission for any entry
+that have write permission within the CPU pte no matter what are the values set
+in default_flags or pfn_flags_mask.
+
+
 Represent and manage device memory from core kernel point of view
 =================================================================
 
diff --git a/include/linux/hmm.h b/include/linux/hmm.h
index ec4bfa91648f16f8ce74bb5cb8ae5627e962f388..dee2f8953b2e93b07ab4e82bf63f49867c77f1c7 100644
--- a/include/linux/hmm.h
+++ b/include/linux/hmm.h
@@ -165,6 +165,8 @@ enum hmm_pfn_value_e {
  * @pfns: array of pfns (big enough for the range)
  * @flags: pfn flags to match device driver page table
  * @values: pfn value for some special case (none, special, error, ...)
+ * @default_flags: default flags for the range (write, read, ... see hmm doc)
+ * @pfn_flags_mask: allows to mask pfn flags so that only default_flags matter
  * @pfn_shifts: pfn shift value (should be <= PAGE_SHIFT)
  * @valid: pfns array did not change since it has been fill by an HMM function
  */
@@ -177,6 +179,8 @@ struct hmm_range {
 	uint64_t		*pfns;
 	const uint64_t		*flags;
 	const uint64_t		*values;
+	uint64_t		default_flags;
+	uint64_t		pfn_flags_mask;
 	uint8_t			pfn_shift;
 	bool			valid;
 };
@@ -448,6 +452,15 @@ static inline int hmm_vma_fault(struct hmm_range *range, bool block)
 {
 	long ret;
 
+	/*
+	 * With the old API the driver must set each individual entries with
+	 * the requested flags (valid, write, ...). So here we set the mask to
+	 * keep intact the entries provided by the driver and zero out the
+	 * default_flags.
+	 */
+	range->default_flags = 0;
+	range->pfn_flags_mask = -1UL;
+
 	ret = hmm_range_register(range, range->vma->vm_mm,
 				 range->start, range->end);
 	if (ret)
diff --git a/mm/hmm.c b/mm/hmm.c
index 3e07f32b94f85f5406cb31615f38d1a73f45060f..0e21d3594ab6783e8e09f15c8c7e7d69348715a3 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -419,6 +419,18 @@ static inline void hmm_pte_need_fault(const struct hmm_vma_walk *hmm_vma_walk,
 	if (!hmm_vma_walk->fault)
 		return;
 
+	/*
+	 * So we not only consider the individual per page request we also
+	 * consider the default flags requested for the range. The API can
+	 * be use in 2 fashions. The first one where the HMM user coalesce
+	 * multiple page fault into one request and set flags per pfns for
+	 * of those faults. The second one where the HMM user want to pre-
+	 * fault a range with specific flags. For the latter one it is a
+	 * waste to have the user pre-fill the pfn arrays with a default
+	 * flags value.
+	 */
+	pfns = (pfns & range->pfn_flags_mask) | range->default_flags;
+
 	/* We aren't ask to do anything ... */
 	if (!(pfns & range->flags[HMM_PFN_VALID]))
 		return;