diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index f7dff0457846474c6b1af5822674ffb7a64d2c9b..4a74a7cf0a8bf835800c0cb23d72b17bcb931efa 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -7305,8 +7305,9 @@ static void vcpu_load_eoi_exitmap(struct kvm_vcpu *vcpu)
 	kvm_x86_ops->load_eoi_exitmap(vcpu, eoi_exit_bitmap);
 }
 
-void kvm_arch_mmu_notifier_invalidate_range(struct kvm *kvm,
-		unsigned long start, unsigned long end)
+int kvm_arch_mmu_notifier_invalidate_range(struct kvm *kvm,
+		unsigned long start, unsigned long end,
+		bool blockable)
 {
 	unsigned long apic_address;
 
@@ -7317,6 +7318,8 @@ void kvm_arch_mmu_notifier_invalidate_range(struct kvm *kvm,
 	apic_address = gfn_to_hva(kvm, APIC_DEFAULT_PHYS_BASE >> PAGE_SHIFT);
 	if (start <= apic_address && apic_address < end)
 		kvm_make_all_cpus_request(kvm, KVM_REQ_APIC_PAGE_RELOAD);
+
+	return 0;
 }
 
 void kvm_vcpu_reload_apic_access_page(struct kvm_vcpu *vcpu)
diff --git a/drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c b/drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c
index a365ea2383d18c137df14d1fa116fad954d9ca45..e55508b394962d8d77e6fe285cee1a9544214714 100644
--- a/drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c
+++ b/drivers/gpu/drm/amd/amdgpu/amdgpu_mn.c
@@ -178,12 +178,18 @@ void amdgpu_mn_unlock(struct amdgpu_mn *mn)
  *
  * @amn: our notifier
  */
-static void amdgpu_mn_read_lock(struct amdgpu_mn *amn)
+static int amdgpu_mn_read_lock(struct amdgpu_mn *amn, bool blockable)
 {
-	mutex_lock(&amn->read_lock);
+	if (blockable)
+		mutex_lock(&amn->read_lock);
+	else if (!mutex_trylock(&amn->read_lock))
+		return -EAGAIN;
+
 	if (atomic_inc_return(&amn->recursion) == 1)
 		down_read_non_owner(&amn->lock);
 	mutex_unlock(&amn->read_lock);
+
+	return 0;
 }
 
 /**
@@ -239,10 +245,11 @@ static void amdgpu_mn_invalidate_node(struct amdgpu_mn_node *node,
  * Block for operations on BOs to finish and mark pages as accessed and
  * potentially dirty.
  */
-static void amdgpu_mn_invalidate_range_start_gfx(struct mmu_notifier *mn,
+static int amdgpu_mn_invalidate_range_start_gfx(struct mmu_notifier *mn,
 						 struct mm_struct *mm,
 						 unsigned long start,
-						 unsigned long end)
+						 unsigned long end,
+						 bool blockable)
 {
 	struct amdgpu_mn *amn = container_of(mn, struct amdgpu_mn, mn);
 	struct interval_tree_node *it;
@@ -250,17 +257,28 @@ static void amdgpu_mn_invalidate_range_start_gfx(struct mmu_notifier *mn,
 	/* notification is exclusive, but interval is inclusive */
 	end -= 1;
 
-	amdgpu_mn_read_lock(amn);
+	/* TODO we should be able to split locking for interval tree and
+	 * amdgpu_mn_invalidate_node
+	 */
+	if (amdgpu_mn_read_lock(amn, blockable))
+		return -EAGAIN;
 
 	it = interval_tree_iter_first(&amn->objects, start, end);
 	while (it) {
 		struct amdgpu_mn_node *node;
 
+		if (!blockable) {
+			amdgpu_mn_read_unlock(amn);
+			return -EAGAIN;
+		}
+
 		node = container_of(it, struct amdgpu_mn_node, it);
 		it = interval_tree_iter_next(it, start, end);
 
 		amdgpu_mn_invalidate_node(node, start, end);
 	}
+
+	return 0;
 }
 
 /**
@@ -275,10 +293,11 @@ static void amdgpu_mn_invalidate_range_start_gfx(struct mmu_notifier *mn,
  * necessitates evicting all user-mode queues of the process. The BOs
  * are restorted in amdgpu_mn_invalidate_range_end_hsa.
  */
-static void amdgpu_mn_invalidate_range_start_hsa(struct mmu_notifier *mn,
+static int amdgpu_mn_invalidate_range_start_hsa(struct mmu_notifier *mn,
 						 struct mm_struct *mm,
 						 unsigned long start,
-						 unsigned long end)
+						 unsigned long end,
+						 bool blockable)
 {
 	struct amdgpu_mn *amn = container_of(mn, struct amdgpu_mn, mn);
 	struct interval_tree_node *it;
@@ -286,13 +305,19 @@ static void amdgpu_mn_invalidate_range_start_hsa(struct mmu_notifier *mn,
 	/* notification is exclusive, but interval is inclusive */
 	end -= 1;
 
-	amdgpu_mn_read_lock(amn);
+	if (amdgpu_mn_read_lock(amn, blockable))
+		return -EAGAIN;
 
 	it = interval_tree_iter_first(&amn->objects, start, end);
 	while (it) {
 		struct amdgpu_mn_node *node;
 		struct amdgpu_bo *bo;
 
+		if (!blockable) {
+			amdgpu_mn_read_unlock(amn);
+			return -EAGAIN;
+		}
+
 		node = container_of(it, struct amdgpu_mn_node, it);
 		it = interval_tree_iter_next(it, start, end);
 
@@ -304,6 +329,8 @@ static void amdgpu_mn_invalidate_range_start_hsa(struct mmu_notifier *mn,
 				amdgpu_amdkfd_evict_userptr(mem, mm);
 		}
 	}
+
+	return 0;
 }
 
 /**
diff --git a/drivers/gpu/drm/i915/i915_gem_userptr.c b/drivers/gpu/drm/i915/i915_gem_userptr.c
index dcd6e230d16aa7905c1ae075fab469e333f7ef8b..2c9b284036d10217a013e146eba8704123006d65 100644
--- a/drivers/gpu/drm/i915/i915_gem_userptr.c
+++ b/drivers/gpu/drm/i915/i915_gem_userptr.c
@@ -112,10 +112,11 @@ static void del_object(struct i915_mmu_object *mo)
 	mo->attached = false;
 }
 
-static void i915_gem_userptr_mn_invalidate_range_start(struct mmu_notifier *_mn,
+static int i915_gem_userptr_mn_invalidate_range_start(struct mmu_notifier *_mn,
 						       struct mm_struct *mm,
 						       unsigned long start,
-						       unsigned long end)
+						       unsigned long end,
+						       bool blockable)
 {
 	struct i915_mmu_notifier *mn =
 		container_of(_mn, struct i915_mmu_notifier, mn);
@@ -124,7 +125,7 @@ static void i915_gem_userptr_mn_invalidate_range_start(struct mmu_notifier *_mn,
 	LIST_HEAD(cancelled);
 
 	if (RB_EMPTY_ROOT(&mn->objects.rb_root))
-		return;
+		return 0;
 
 	/* interval ranges are inclusive, but invalidate range is exclusive */
 	end--;
@@ -132,6 +133,10 @@ static void i915_gem_userptr_mn_invalidate_range_start(struct mmu_notifier *_mn,
 	spin_lock(&mn->lock);
 	it = interval_tree_iter_first(&mn->objects, start, end);
 	while (it) {
+		if (!blockable) {
+			spin_unlock(&mn->lock);
+			return -EAGAIN;
+		}
 		/* The mmu_object is released late when destroying the
 		 * GEM object so it is entirely possible to gain a
 		 * reference on an object in the process of being freed
@@ -154,6 +159,8 @@ static void i915_gem_userptr_mn_invalidate_range_start(struct mmu_notifier *_mn,
 
 	if (!list_empty(&cancelled))
 		flush_workqueue(mn->wq);
+
+	return 0;
 }
 
 static const struct mmu_notifier_ops i915_gem_userptr_notifier = {
diff --git a/drivers/gpu/drm/radeon/radeon_mn.c b/drivers/gpu/drm/radeon/radeon_mn.c
index abd24975c9b1d946cec7c85ba7ebe3b820d1c8e4..f8b35df44c60eba57e7b2ec76c068190dd941c79 100644
--- a/drivers/gpu/drm/radeon/radeon_mn.c
+++ b/drivers/gpu/drm/radeon/radeon_mn.c
@@ -118,19 +118,27 @@ static void radeon_mn_release(struct mmu_notifier *mn,
  * We block for all BOs between start and end to be idle and
  * unmap them by move them into system domain again.
  */
-static void radeon_mn_invalidate_range_start(struct mmu_notifier *mn,
+static int radeon_mn_invalidate_range_start(struct mmu_notifier *mn,
 					     struct mm_struct *mm,
 					     unsigned long start,
-					     unsigned long end)
+					     unsigned long end,
+					     bool blockable)
 {
 	struct radeon_mn *rmn = container_of(mn, struct radeon_mn, mn);
 	struct ttm_operation_ctx ctx = { false, false };
 	struct interval_tree_node *it;
+	int ret = 0;
 
 	/* notification is exclusive, but interval is inclusive */
 	end -= 1;
 
-	mutex_lock(&rmn->lock);
+	/* TODO we should be able to split locking for interval tree and
+	 * the tear down.
+	 */
+	if (blockable)
+		mutex_lock(&rmn->lock);
+	else if (!mutex_trylock(&rmn->lock))
+		return -EAGAIN;
 
 	it = interval_tree_iter_first(&rmn->objects, start, end);
 	while (it) {
@@ -138,6 +146,11 @@ static void radeon_mn_invalidate_range_start(struct mmu_notifier *mn,
 		struct radeon_bo *bo;
 		long r;
 
+		if (!blockable) {
+			ret = -EAGAIN;
+			goto out_unlock;
+		}
+
 		node = container_of(it, struct radeon_mn_node, it);
 		it = interval_tree_iter_next(it, start, end);
 
@@ -166,7 +179,10 @@ static void radeon_mn_invalidate_range_start(struct mmu_notifier *mn,
 		}
 	}
 	
+out_unlock:
 	mutex_unlock(&rmn->lock);
+
+	return ret;
 }
 
 static const struct mmu_notifier_ops radeon_mn_ops = {
diff --git a/drivers/infiniband/core/umem_odp.c b/drivers/infiniband/core/umem_odp.c
index 182436b92ba93a95ca6e119108c30d50706b7212..6ec748eccff7e87d7691b92a6d50534d89d4625e 100644
--- a/drivers/infiniband/core/umem_odp.c
+++ b/drivers/infiniband/core/umem_odp.c
@@ -186,6 +186,7 @@ static void ib_umem_notifier_release(struct mmu_notifier *mn,
 	rbt_ib_umem_for_each_in_range(&context->umem_tree, 0,
 				      ULLONG_MAX,
 				      ib_umem_notifier_release_trampoline,
+				      true,
 				      NULL);
 	up_read(&context->umem_rwsem);
 }
@@ -207,22 +208,31 @@ static int invalidate_range_start_trampoline(struct ib_umem *item, u64 start,
 	return 0;
 }
 
-static void ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
+static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
 						    struct mm_struct *mm,
 						    unsigned long start,
-						    unsigned long end)
+						    unsigned long end,
+						    bool blockable)
 {
 	struct ib_ucontext *context = container_of(mn, struct ib_ucontext, mn);
+	int ret;
 
 	if (!context->invalidate_range)
-		return;
+		return 0;
+
+	if (blockable)
+		down_read(&context->umem_rwsem);
+	else if (!down_read_trylock(&context->umem_rwsem))
+		return -EAGAIN;
 
 	ib_ucontext_notifier_start_account(context);
-	down_read(&context->umem_rwsem);
-	rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
+	ret = rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
 				      end,
-				      invalidate_range_start_trampoline, NULL);
+				      invalidate_range_start_trampoline,
+				      blockable, NULL);
 	up_read(&context->umem_rwsem);
+
+	return ret;
 }
 
 static int invalidate_range_end_trampoline(struct ib_umem *item, u64 start,
@@ -242,10 +252,15 @@ static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
 	if (!context->invalidate_range)
 		return;
 
+	/*
+	 * TODO: we currently bail out if there is any sleepable work to be done
+	 * in ib_umem_notifier_invalidate_range_start so we shouldn't really block
+	 * here. But this is ugly and fragile.
+	 */
 	down_read(&context->umem_rwsem);
 	rbt_ib_umem_for_each_in_range(&context->umem_tree, start,
 				      end,
-				      invalidate_range_end_trampoline, NULL);
+				      invalidate_range_end_trampoline, true, NULL);
 	up_read(&context->umem_rwsem);
 	ib_ucontext_notifier_end_account(context);
 }
@@ -798,6 +813,7 @@ EXPORT_SYMBOL(ib_umem_odp_unmap_dma_pages);
 int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
 				  u64 start, u64 last,
 				  umem_call_back cb,
+				  bool blockable,
 				  void *cookie)
 {
 	int ret_val = 0;
@@ -809,6 +825,9 @@ int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
 
 	for (node = rbt_ib_umem_iter_first(root, start, last - 1);
 			node; node = next) {
+		/* TODO move the blockable decision up to the callback */
+		if (!blockable)
+			return -EAGAIN;
 		next = rbt_ib_umem_iter_next(node, start, last - 1);
 		umem = container_of(node, struct ib_umem_odp, interval_tree);
 		ret_val = cb(umem->umem, start, last, cookie) || ret_val;
diff --git a/drivers/infiniband/hw/hfi1/mmu_rb.c b/drivers/infiniband/hw/hfi1/mmu_rb.c
index 70aceefe14d5fa306a3ce78408b9866518030104..e1c7996c018efe688dbe081cc0928f05c5f0bbdd 100644
--- a/drivers/infiniband/hw/hfi1/mmu_rb.c
+++ b/drivers/infiniband/hw/hfi1/mmu_rb.c
@@ -67,9 +67,9 @@ struct mmu_rb_handler {
 
 static unsigned long mmu_node_start(struct mmu_rb_node *);
 static unsigned long mmu_node_last(struct mmu_rb_node *);
-static void mmu_notifier_range_start(struct mmu_notifier *,
+static int mmu_notifier_range_start(struct mmu_notifier *,
 				     struct mm_struct *,
-				     unsigned long, unsigned long);
+				     unsigned long, unsigned long, bool);
 static struct mmu_rb_node *__mmu_rb_search(struct mmu_rb_handler *,
 					   unsigned long, unsigned long);
 static void do_remove(struct mmu_rb_handler *handler,
@@ -284,10 +284,11 @@ void hfi1_mmu_rb_remove(struct mmu_rb_handler *handler,
 	handler->ops->remove(handler->ops_arg, node);
 }
 
-static void mmu_notifier_range_start(struct mmu_notifier *mn,
+static int mmu_notifier_range_start(struct mmu_notifier *mn,
 				     struct mm_struct *mm,
 				     unsigned long start,
-				     unsigned long end)
+				     unsigned long end,
+				     bool blockable)
 {
 	struct mmu_rb_handler *handler =
 		container_of(mn, struct mmu_rb_handler, mn);
@@ -313,6 +314,8 @@ static void mmu_notifier_range_start(struct mmu_notifier *mn,
 
 	if (added)
 		queue_work(handler->wq, &handler->del_work);
+
+	return 0;
 }
 
 /*
diff --git a/drivers/infiniband/hw/mlx5/odp.c b/drivers/infiniband/hw/mlx5/odp.c
index f1a87a690a4cd0555a1bb0ccbd585bb114a498b9..d216e0d2921dafc28b59e9a56dfbd82c660cc0d8 100644
--- a/drivers/infiniband/hw/mlx5/odp.c
+++ b/drivers/infiniband/hw/mlx5/odp.c
@@ -488,7 +488,7 @@ void mlx5_ib_free_implicit_mr(struct mlx5_ib_mr *imr)
 
 	down_read(&ctx->umem_rwsem);
 	rbt_ib_umem_for_each_in_range(&ctx->umem_tree, 0, ULLONG_MAX,
-				      mr_leaf_free, imr);
+				      mr_leaf_free, true, imr);
 	up_read(&ctx->umem_rwsem);
 
 	wait_event(imr->q_leaf_free, !atomic_read(&imr->num_leaf_free));
diff --git a/drivers/misc/mic/scif/scif_dma.c b/drivers/misc/mic/scif/scif_dma.c
index 63d6246d6dff3caf8b6f7099257f902d50d2de20..6369aeaa70562e92325813857e2ba5c59aadcecb 100644
--- a/drivers/misc/mic/scif/scif_dma.c
+++ b/drivers/misc/mic/scif/scif_dma.c
@@ -200,15 +200,18 @@ static void scif_mmu_notifier_release(struct mmu_notifier *mn,
 	schedule_work(&scif_info.misc_work);
 }
 
-static void scif_mmu_notifier_invalidate_range_start(struct mmu_notifier *mn,
+static int scif_mmu_notifier_invalidate_range_start(struct mmu_notifier *mn,
 						     struct mm_struct *mm,
 						     unsigned long start,
-						     unsigned long end)
+						     unsigned long end,
+						     bool blockable)
 {
 	struct scif_mmu_notif	*mmn;
 
 	mmn = container_of(mn, struct scif_mmu_notif, ep_mmu_notifier);
 	scif_rma_destroy_tcw(mmn, start, end - start);
+
+	return 0;
 }
 
 static void scif_mmu_notifier_invalidate_range_end(struct mmu_notifier *mn,
diff --git a/drivers/misc/sgi-gru/grutlbpurge.c b/drivers/misc/sgi-gru/grutlbpurge.c
index a3454eb56fbf57e3a3868c514cdd72affd578cc4..be28f05bfafa9ee41a53c642f0e384812ea88adf 100644
--- a/drivers/misc/sgi-gru/grutlbpurge.c
+++ b/drivers/misc/sgi-gru/grutlbpurge.c
@@ -219,9 +219,10 @@ void gru_flush_all_tlb(struct gru_state *gru)
 /*
  * MMUOPS notifier callout functions
  */
-static void gru_invalidate_range_start(struct mmu_notifier *mn,
+static int gru_invalidate_range_start(struct mmu_notifier *mn,
 				       struct mm_struct *mm,
-				       unsigned long start, unsigned long end)
+				       unsigned long start, unsigned long end,
+				       bool blockable)
 {
 	struct gru_mm_struct *gms = container_of(mn, struct gru_mm_struct,
 						 ms_notifier);
@@ -231,6 +232,8 @@ static void gru_invalidate_range_start(struct mmu_notifier *mn,
 	gru_dbg(grudev, "gms %p, start 0x%lx, end 0x%lx, act %d\n", gms,
 		start, end, atomic_read(&gms->ms_range_active));
 	gru_flush_tlb_range(gms, start, end - start);
+
+	return 0;
 }
 
 static void gru_invalidate_range_end(struct mmu_notifier *mn,
diff --git a/drivers/xen/gntdev.c b/drivers/xen/gntdev.c
index c866a62f766d4ee4bfd5e5bc57e701ef8766159a..57390c7666e5dd8d44bfe9bdf1e503afb13de189 100644
--- a/drivers/xen/gntdev.c
+++ b/drivers/xen/gntdev.c
@@ -479,18 +479,25 @@ static const struct vm_operations_struct gntdev_vmops = {
 
 /* ------------------------------------------------------------------ */
 
+static bool in_range(struct gntdev_grant_map *map,
+			      unsigned long start, unsigned long end)
+{
+	if (!map->vma)
+		return false;
+	if (map->vma->vm_start >= end)
+		return false;
+	if (map->vma->vm_end <= start)
+		return false;
+
+	return true;
+}
+
 static void unmap_if_in_range(struct gntdev_grant_map *map,
 			      unsigned long start, unsigned long end)
 {
 	unsigned long mstart, mend;
 	int err;
 
-	if (!map->vma)
-		return;
-	if (map->vma->vm_start >= end)
-		return;
-	if (map->vma->vm_end <= start)
-		return;
 	mstart = max(start, map->vma->vm_start);
 	mend   = min(end,   map->vma->vm_end);
 	pr_debug("map %d+%d (%lx %lx), range %lx %lx, mrange %lx %lx\n",
@@ -503,21 +510,40 @@ static void unmap_if_in_range(struct gntdev_grant_map *map,
 	WARN_ON(err);
 }
 
-static void mn_invl_range_start(struct mmu_notifier *mn,
+static int mn_invl_range_start(struct mmu_notifier *mn,
 				struct mm_struct *mm,
-				unsigned long start, unsigned long end)
+				unsigned long start, unsigned long end,
+				bool blockable)
 {
 	struct gntdev_priv *priv = container_of(mn, struct gntdev_priv, mn);
 	struct gntdev_grant_map *map;
+	int ret = 0;
+
+	/* TODO do we really need a mutex here? */
+	if (blockable)
+		mutex_lock(&priv->lock);
+	else if (!mutex_trylock(&priv->lock))
+		return -EAGAIN;
 
-	mutex_lock(&priv->lock);
 	list_for_each_entry(map, &priv->maps, next) {
+		if (in_range(map, start, end)) {
+			ret = -EAGAIN;
+			goto out_unlock;
+		}
 		unmap_if_in_range(map, start, end);
 	}
 	list_for_each_entry(map, &priv->freeable_maps, next) {
+		if (in_range(map, start, end)) {
+			ret = -EAGAIN;
+			goto out_unlock;
+		}
 		unmap_if_in_range(map, start, end);
 	}
+
+out_unlock:
 	mutex_unlock(&priv->lock);
+
+	return ret;
 }
 
 static void mn_release(struct mmu_notifier *mn,
diff --git a/include/linux/kvm_host.h b/include/linux/kvm_host.h
index 7c7362dd2faa9c549500756713ea5512c973cfcf..0205aee44dedd8522be52ffe7d3237dd3819656a 100644
--- a/include/linux/kvm_host.h
+++ b/include/linux/kvm_host.h
@@ -1289,8 +1289,8 @@ static inline long kvm_arch_vcpu_async_ioctl(struct file *filp,
 }
 #endif /* CONFIG_HAVE_KVM_VCPU_ASYNC_IOCTL */
 
-void kvm_arch_mmu_notifier_invalidate_range(struct kvm *kvm,
-		unsigned long start, unsigned long end);
+int kvm_arch_mmu_notifier_invalidate_range(struct kvm *kvm,
+		unsigned long start, unsigned long end, bool blockable);
 
 #ifdef CONFIG_HAVE_KVM_VCPU_RUN_PID_CHANGE
 int kvm_arch_vcpu_run_pid_change(struct kvm_vcpu *vcpu);
diff --git a/include/linux/mmu_notifier.h b/include/linux/mmu_notifier.h
index 392e6af827016cb48d8bfffde0b096463378aaad..133ba78820ee5f2da69b865df43689c0a736ba3e 100644
--- a/include/linux/mmu_notifier.h
+++ b/include/linux/mmu_notifier.h
@@ -151,13 +151,15 @@ struct mmu_notifier_ops {
 	 * address space but may still be referenced by sptes until
 	 * the last refcount is dropped.
 	 *
-	 * If both of these callbacks cannot block, and invalidate_range
-	 * cannot block, mmu_notifier_ops.flags should have
-	 * MMU_INVALIDATE_DOES_NOT_BLOCK set.
+	 * If blockable argument is set to false then the callback cannot
+	 * sleep and has to return with -EAGAIN. 0 should be returned
+	 * otherwise.
+	 *
 	 */
-	void (*invalidate_range_start)(struct mmu_notifier *mn,
+	int (*invalidate_range_start)(struct mmu_notifier *mn,
 				       struct mm_struct *mm,
-				       unsigned long start, unsigned long end);
+				       unsigned long start, unsigned long end,
+				       bool blockable);
 	void (*invalidate_range_end)(struct mmu_notifier *mn,
 				     struct mm_struct *mm,
 				     unsigned long start, unsigned long end);
@@ -229,8 +231,9 @@ extern int __mmu_notifier_test_young(struct mm_struct *mm,
 				     unsigned long address);
 extern void __mmu_notifier_change_pte(struct mm_struct *mm,
 				      unsigned long address, pte_t pte);
-extern void __mmu_notifier_invalidate_range_start(struct mm_struct *mm,
-				  unsigned long start, unsigned long end);
+extern int __mmu_notifier_invalidate_range_start(struct mm_struct *mm,
+				  unsigned long start, unsigned long end,
+				  bool blockable);
 extern void __mmu_notifier_invalidate_range_end(struct mm_struct *mm,
 				  unsigned long start, unsigned long end,
 				  bool only_end);
@@ -281,7 +284,15 @@ static inline void mmu_notifier_invalidate_range_start(struct mm_struct *mm,
 				  unsigned long start, unsigned long end)
 {
 	if (mm_has_notifiers(mm))
-		__mmu_notifier_invalidate_range_start(mm, start, end);
+		__mmu_notifier_invalidate_range_start(mm, start, end, true);
+}
+
+static inline int mmu_notifier_invalidate_range_start_nonblock(struct mm_struct *mm,
+				  unsigned long start, unsigned long end)
+{
+	if (mm_has_notifiers(mm))
+		return __mmu_notifier_invalidate_range_start(mm, start, end, false);
+	return 0;
 }
 
 static inline void mmu_notifier_invalidate_range_end(struct mm_struct *mm,
@@ -461,6 +472,12 @@ static inline void mmu_notifier_invalidate_range_start(struct mm_struct *mm,
 {
 }
 
+static inline int mmu_notifier_invalidate_range_start_nonblock(struct mm_struct *mm,
+				  unsigned long start, unsigned long end)
+{
+	return 0;
+}
+
 static inline void mmu_notifier_invalidate_range_end(struct mm_struct *mm,
 				  unsigned long start, unsigned long end)
 {
diff --git a/include/linux/oom.h b/include/linux/oom.h
index 6adac113e96d29b5059ed65ac0522237c9b94388..92f70e4c62529fff53e592f99be9ec12d6372ce2 100644
--- a/include/linux/oom.h
+++ b/include/linux/oom.h
@@ -95,7 +95,7 @@ static inline int check_stable_address_space(struct mm_struct *mm)
 	return 0;
 }
 
-void __oom_reap_task_mm(struct mm_struct *mm);
+bool __oom_reap_task_mm(struct mm_struct *mm);
 
 extern unsigned long oom_badness(struct task_struct *p,
 		struct mem_cgroup *memcg, const nodemask_t *nodemask,
diff --git a/include/rdma/ib_umem_odp.h b/include/rdma/ib_umem_odp.h
index 6a17f856f8418b199f78108833d576e18daa142f..381cdf5a9bd1ea5da114e41c068f81c92c879c1e 100644
--- a/include/rdma/ib_umem_odp.h
+++ b/include/rdma/ib_umem_odp.h
@@ -119,7 +119,8 @@ typedef int (*umem_call_back)(struct ib_umem *item, u64 start, u64 end,
  */
 int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
 				  u64 start, u64 end,
-				  umem_call_back cb, void *cookie);
+				  umem_call_back cb,
+				  bool blockable, void *cookie);
 
 /*
  * Find first region intersecting with address range.
diff --git a/mm/hmm.c b/mm/hmm.c
index 76e7a058b32fc2c4dfb675d262dde082dced208d..0b05545916106cad2f5bd490af19514107a6b9c5 100644
--- a/mm/hmm.c
+++ b/mm/hmm.c
@@ -177,16 +177,19 @@ static void hmm_release(struct mmu_notifier *mn, struct mm_struct *mm)
 	up_write(&hmm->mirrors_sem);
 }
 
-static void hmm_invalidate_range_start(struct mmu_notifier *mn,
+static int hmm_invalidate_range_start(struct mmu_notifier *mn,
 				       struct mm_struct *mm,
 				       unsigned long start,
-				       unsigned long end)
+				       unsigned long end,
+				       bool blockable)
 {
 	struct hmm *hmm = mm->hmm;
 
 	VM_BUG_ON(!hmm);
 
 	atomic_inc(&hmm->sequence);
+
+	return 0;
 }
 
 static void hmm_invalidate_range_end(struct mmu_notifier *mn,
diff --git a/mm/mmap.c b/mm/mmap.c
index 8d6449e74431155edcb21308743bd52d449bd8f2..bb2a7e097c7d32d4b86372727005d296335ba72c 100644
--- a/mm/mmap.c
+++ b/mm/mmap.c
@@ -3064,7 +3064,7 @@ void exit_mmap(struct mm_struct *mm)
 		 * reliably test it.
 		 */
 		mutex_lock(&oom_lock);
-		__oom_reap_task_mm(mm);
+		(void)__oom_reap_task_mm(mm);
 		mutex_unlock(&oom_lock);
 
 		set_bit(MMF_OOM_SKIP, &mm->flags);
diff --git a/mm/mmu_notifier.c b/mm/mmu_notifier.c
index eff6b88a993f2491b8fb96a2ad553196448966da..82bb1a939c0e496affc2849c1e1af795fe3d5d96 100644
--- a/mm/mmu_notifier.c
+++ b/mm/mmu_notifier.c
@@ -174,18 +174,29 @@ void __mmu_notifier_change_pte(struct mm_struct *mm, unsigned long address,
 	srcu_read_unlock(&srcu, id);
 }
 
-void __mmu_notifier_invalidate_range_start(struct mm_struct *mm,
-				  unsigned long start, unsigned long end)
+int __mmu_notifier_invalidate_range_start(struct mm_struct *mm,
+				  unsigned long start, unsigned long end,
+				  bool blockable)
 {
 	struct mmu_notifier *mn;
+	int ret = 0;
 	int id;
 
 	id = srcu_read_lock(&srcu);
 	hlist_for_each_entry_rcu(mn, &mm->mmu_notifier_mm->list, hlist) {
-		if (mn->ops->invalidate_range_start)
-			mn->ops->invalidate_range_start(mn, mm, start, end);
+		if (mn->ops->invalidate_range_start) {
+			int _ret = mn->ops->invalidate_range_start(mn, mm, start, end, blockable);
+			if (_ret) {
+				pr_info("%pS callback failed with %d in %sblockable context.\n",
+						mn->ops->invalidate_range_start, _ret,
+						!blockable ? "non-" : "");
+				ret = _ret;
+			}
+		}
 	}
 	srcu_read_unlock(&srcu, id);
+
+	return ret;
 }
 EXPORT_SYMBOL_GPL(__mmu_notifier_invalidate_range_start);
 
diff --git a/mm/oom_kill.c b/mm/oom_kill.c
index 412f43453a68ff5260ec46c6211c9249cf9d9cc6..be31a1e0fe78ff23319441d71ad92d3744972f78 100644
--- a/mm/oom_kill.c
+++ b/mm/oom_kill.c
@@ -487,9 +487,10 @@ static DECLARE_WAIT_QUEUE_HEAD(oom_reaper_wait);
 static struct task_struct *oom_reaper_list;
 static DEFINE_SPINLOCK(oom_reaper_lock);
 
-void __oom_reap_task_mm(struct mm_struct *mm)
+bool __oom_reap_task_mm(struct mm_struct *mm)
 {
 	struct vm_area_struct *vma;
+	bool ret = true;
 
 	/*
 	 * Tell all users of get_user/copy_from_user etc... that the content
@@ -519,12 +520,17 @@ void __oom_reap_task_mm(struct mm_struct *mm)
 			struct mmu_gather tlb;
 
 			tlb_gather_mmu(&tlb, mm, start, end);
-			mmu_notifier_invalidate_range_start(mm, start, end);
+			if (mmu_notifier_invalidate_range_start_nonblock(mm, start, end)) {
+				ret = false;
+				continue;
+			}
 			unmap_page_range(&tlb, vma, start, end, NULL);
 			mmu_notifier_invalidate_range_end(mm, start, end);
 			tlb_finish_mmu(&tlb, start, end);
 		}
 	}
+
+	return ret;
 }
 
 static bool oom_reap_task_mm(struct task_struct *tsk, struct mm_struct *mm)
@@ -553,18 +559,6 @@ static bool oom_reap_task_mm(struct task_struct *tsk, struct mm_struct *mm)
 		goto unlock_oom;
 	}
 
-	/*
-	 * If the mm has invalidate_{start,end}() notifiers that could block,
-	 * sleep to give the oom victim some more time.
-	 * TODO: we really want to get rid of this ugly hack and make sure that
-	 * notifiers cannot block for unbounded amount of time
-	 */
-	if (mm_has_blockable_invalidate_notifiers(mm)) {
-		up_read(&mm->mmap_sem);
-		schedule_timeout_idle(HZ);
-		goto unlock_oom;
-	}
-
 	/*
 	 * MMF_OOM_SKIP is set by exit_mmap when the OOM reaper can't
 	 * work on the mm anymore. The check for MMF_OOM_SKIP must run
@@ -579,7 +573,12 @@ static bool oom_reap_task_mm(struct task_struct *tsk, struct mm_struct *mm)
 
 	trace_start_task_reaping(tsk->pid);
 
-	__oom_reap_task_mm(mm);
+	/* failed to reap part of the address space. Try again later */
+	if (!__oom_reap_task_mm(mm)) {
+		up_read(&mm->mmap_sem);
+		ret = false;
+		goto unlock_oom;
+	}
 
 	pr_info("oom_reaper: reaped process %d (%s), now anon-rss:%lukB, file-rss:%lukB, shmem-rss:%lukB\n",
 			task_pid_nr(tsk), tsk->comm,
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index 9263ead9fd32506250d5559feaf008d10a5e5846..0116b449b99346752e542b52dc70ceb5c469e0b7 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -140,9 +140,10 @@ static void kvm_uevent_notify_change(unsigned int type, struct kvm *kvm);
 static unsigned long long kvm_createvm_count;
 static unsigned long long kvm_active_vms;
 
-__weak void kvm_arch_mmu_notifier_invalidate_range(struct kvm *kvm,
-		unsigned long start, unsigned long end)
+__weak int kvm_arch_mmu_notifier_invalidate_range(struct kvm *kvm,
+		unsigned long start, unsigned long end, bool blockable)
 {
+	return 0;
 }
 
 bool kvm_is_reserved_pfn(kvm_pfn_t pfn)
@@ -360,13 +361,15 @@ static void kvm_mmu_notifier_change_pte(struct mmu_notifier *mn,
 	srcu_read_unlock(&kvm->srcu, idx);
 }
 
-static void kvm_mmu_notifier_invalidate_range_start(struct mmu_notifier *mn,
+static int kvm_mmu_notifier_invalidate_range_start(struct mmu_notifier *mn,
 						    struct mm_struct *mm,
 						    unsigned long start,
-						    unsigned long end)
+						    unsigned long end,
+						    bool blockable)
 {
 	struct kvm *kvm = mmu_notifier_to_kvm(mn);
 	int need_tlb_flush = 0, idx;
+	int ret;
 
 	idx = srcu_read_lock(&kvm->srcu);
 	spin_lock(&kvm->mmu_lock);
@@ -384,9 +387,11 @@ static void kvm_mmu_notifier_invalidate_range_start(struct mmu_notifier *mn,
 
 	spin_unlock(&kvm->mmu_lock);
 
-	kvm_arch_mmu_notifier_invalidate_range(kvm, start, end);
+	ret = kvm_arch_mmu_notifier_invalidate_range(kvm, start, end, blockable);
 
 	srcu_read_unlock(&kvm->srcu, idx);
+
+	return ret;
 }
 
 static void kvm_mmu_notifier_invalidate_range_end(struct mmu_notifier *mn,