diff --git a/arch/arm/kvm/mmu.c b/arch/arm/kvm/mmu.c
index f2e2e0c6d6fdfcc0d0926134f19d32826efe3b60..13b9c1fa8961b42ca672ffcb752fb3943b2f6916 100644
--- a/arch/arm/kvm/mmu.c
+++ b/arch/arm/kvm/mmu.c
@@ -1803,6 +1803,7 @@ int kvm_arch_prepare_memory_region(struct kvm *kvm,
 	    (KVM_PHYS_SIZE >> PAGE_SHIFT))
 		return -EFAULT;
 
+	down_read(&current->mm->mmap_sem);
 	/*
 	 * A memory region could potentially cover multiple VMAs, and any holes
 	 * between them, so iterate over all of them to find out if we can map
@@ -1846,8 +1847,10 @@ int kvm_arch_prepare_memory_region(struct kvm *kvm,
 			pa += vm_start - vma->vm_start;
 
 			/* IO region dirty page logging not allowed */
-			if (memslot->flags & KVM_MEM_LOG_DIRTY_PAGES)
-				return -EINVAL;
+			if (memslot->flags & KVM_MEM_LOG_DIRTY_PAGES) {
+				ret = -EINVAL;
+				goto out;
+			}
 
 			ret = kvm_phys_addr_ioremap(kvm, gpa, pa,
 						    vm_end - vm_start,
@@ -1859,7 +1862,7 @@ int kvm_arch_prepare_memory_region(struct kvm *kvm,
 	} while (hva < reg_end);
 
 	if (change == KVM_MR_FLAGS_ONLY)
-		return ret;
+		goto out;
 
 	spin_lock(&kvm->mmu_lock);
 	if (ret)
@@ -1867,6 +1870,8 @@ int kvm_arch_prepare_memory_region(struct kvm *kvm,
 	else
 		stage2_flush_memslot(kvm, memslot);
 	spin_unlock(&kvm->mmu_lock);
+out:
+	up_read(&current->mm->mmap_sem);
 	return ret;
 }
 
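
Why the lock is needed: kvm_arch_prepare_memory_region() walks the
process address space with find_vma(), and the VMA pointers it gets
back are only stable while current->mm->mmap_sem is held for reading.
The patch therefore takes the semaphore before the walk and, just as
importantly, turns the early "return" statements inside the walk into
"goto out" so that every exit path releases it.

Below is a minimal userspace sketch of that lock-then-goto-out pattern,
using a pthread rwlock in place of mmap_sem (build with cc file.c
-pthread). The region structure, its field names, and the failure
condition are hypothetical stand-ins for illustration, not the kernel
API:

	#include <errno.h>
	#include <pthread.h>
	#include <stdio.h>

	static pthread_rwlock_t map_lock = PTHREAD_RWLOCK_INITIALIZER;

	struct region {
		int is_io;	/* stand-in for a VM_PFNMAP vma */
		int log_dirty;	/* stand-in for KVM_MEM_LOG_DIRTY_PAGES */
	};

	static int prepare_regions(const struct region *r, int nr)
	{
		int i, ret = 0;

		pthread_rwlock_rdlock(&map_lock);	/* down_read(&mm->mmap_sem) */
		for (i = 0; i < nr; i++) {
			/* Dirty page logging is not allowed on IO regions. */
			if (r[i].is_io && r[i].log_dirty) {
				ret = -EINVAL;
				goto out;	/* never return with the lock held */
			}
		}
	out:
		pthread_rwlock_unlock(&map_lock);	/* up_read(&mm->mmap_sem) */
		return ret;
	}

	int main(void)
	{
		const struct region regions[] = {
			{ .is_io = 1, .log_dirty = 0 },
			{ .is_io = 1, .log_dirty = 1 },	/* triggers -EINVAL */
		};

		printf("ret = %d\n", prepare_regions(regions, 2));
		return 0;
	}

The design point mirrored from the patch: once a function takes a lock
near its top, funnelling every error path through a single exit label
keeps the lock/unlock pair visibly matched, which is far easier to
audit than checking each individual return.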