diff --git a/drivers/dax/super.c b/drivers/dax/super.c
index 938eb4868f7f78c7264cae58b15a0dcc0565d5c4..8b458f1b30c786fb8768e81f1153eaf06873c94d 100644
--- a/drivers/dax/super.c
+++ b/drivers/dax/super.c
@@ -189,8 +189,10 @@ static umode_t dax_visible(struct kobject *kobj, struct attribute *a, int n)
 	if (!dax_dev)
 		return 0;
 
-	if (a == &dev_attr_write_cache.attr && !dax_dev->ops->flush)
+#ifndef CONFIG_ARCH_HAS_PMEM_API
+	if (a == &dev_attr_write_cache.attr)
 		return 0;
+#endif
 	return a->mode;
 }
 
@@ -255,18 +257,23 @@ size_t dax_copy_from_iter(struct dax_device *dax_dev, pgoff_t pgoff, void *addr,
 }
 EXPORT_SYMBOL_GPL(dax_copy_from_iter);
 
-void dax_flush(struct dax_device *dax_dev, pgoff_t pgoff, void *addr,
-		size_t size)
+#ifdef CONFIG_ARCH_HAS_PMEM_API
+void arch_wb_cache_pmem(void *addr, size_t size);
+void dax_flush(struct dax_device *dax_dev, void *addr, size_t size)
 {
-	if (!dax_alive(dax_dev))
+	if (unlikely(!dax_alive(dax_dev)))
 		return;
 
-	if (!test_bit(DAXDEV_WRITE_CACHE, &dax_dev->flags))
+	if (unlikely(!test_bit(DAXDEV_WRITE_CACHE, &dax_dev->flags)))
 		return;
 
-	if (dax_dev->ops->flush)
-		dax_dev->ops->flush(dax_dev, pgoff, addr, size);
+	arch_wb_cache_pmem(addr, size);
 }
+#else
+void dax_flush(struct dax_device *dax_dev, void *addr, size_t size)
+{
+}
+#endif
 EXPORT_SYMBOL_GPL(dax_flush);
 
 void dax_write_cache(struct dax_device *dax_dev, bool wc)
diff --git a/drivers/md/dm-linear.c b/drivers/md/dm-linear.c
index 41971a090e34d02ea26a348af0c66fc9615b44a2..208800610af83b2c656dfc651b3056a7cd48631b 100644
--- a/drivers/md/dm-linear.c
+++ b/drivers/md/dm-linear.c
@@ -184,20 +184,6 @@ static size_t linear_dax_copy_from_iter(struct dm_target *ti, pgoff_t pgoff,
 	return dax_copy_from_iter(dax_dev, pgoff, addr, bytes, i);
 }
 
-static void linear_dax_flush(struct dm_target *ti, pgoff_t pgoff, void *addr,
-		size_t size)
-{
-	struct linear_c *lc = ti->private;
-	struct block_device *bdev = lc->dev->bdev;
-	struct dax_device *dax_dev = lc->dev->dax_dev;
-	sector_t dev_sector, sector = pgoff * PAGE_SECTORS;
-
-	dev_sector = linear_map_sector(ti, sector);
-	if (bdev_dax_pgoff(bdev, dev_sector, ALIGN(size, PAGE_SIZE), &pgoff))
-		return;
-	dax_flush(dax_dev, pgoff, addr, size);
-}
-
 static struct target_type linear_target = {
 	.name   = "linear",
 	.version = {1, 4, 0},
@@ -212,7 +198,6 @@ static struct target_type linear_target = {
 	.iterate_devices = linear_iterate_devices,
 	.direct_access = linear_dax_direct_access,
 	.dax_copy_from_iter = linear_dax_copy_from_iter,
-	.dax_flush = linear_dax_flush,
 };
 
 int __init dm_linear_init(void)
diff --git a/drivers/md/dm-stripe.c b/drivers/md/dm-stripe.c
index a0375530b07f6afc7713f4ac2904e84490027d67..1690bb299b3f77881444137edfc9f9f1b1356db4 100644
--- a/drivers/md/dm-stripe.c
+++ b/drivers/md/dm-stripe.c
@@ -351,25 +351,6 @@ static size_t stripe_dax_copy_from_iter(struct dm_target *ti, pgoff_t pgoff,
 	return dax_copy_from_iter(dax_dev, pgoff, addr, bytes, i);
 }
 
-static void stripe_dax_flush(struct dm_target *ti, pgoff_t pgoff, void *addr,
-		size_t size)
-{
-	sector_t dev_sector, sector = pgoff * PAGE_SECTORS;
-	struct stripe_c *sc = ti->private;
-	struct dax_device *dax_dev;
-	struct block_device *bdev;
-	uint32_t stripe;
-
-	stripe_map_sector(sc, sector, &stripe, &dev_sector);
-	dev_sector += sc->stripe[stripe].physical_start;
-	dax_dev = sc->stripe[stripe].dev->dax_dev;
-	bdev = sc->stripe[stripe].dev->bdev;
-
-	if (bdev_dax_pgoff(bdev, dev_sector, ALIGN(size, PAGE_SIZE), &pgoff))
-		return;
-	dax_flush(dax_dev, pgoff, addr, size);
-}
-
 /*
  * Stripe status:
  *
@@ -491,7 +472,6 @@ static struct target_type stripe_target = {
 	.io_hints = stripe_io_hints,
 	.direct_access = stripe_dax_direct_access,
 	.dax_copy_from_iter = stripe_dax_copy_from_iter,
-	.dax_flush = stripe_dax_flush,
 };
 
 int __init dm_stripe_init(void)
diff --git a/drivers/md/dm.c b/drivers/md/dm.c
index d669fddd9290d027a39fa2e65b01b0a1b24afc0a..825eaffc24da9706152e092cf074678e743fc743 100644
--- a/drivers/md/dm.c
+++ b/drivers/md/dm.c
@@ -987,24 +987,6 @@ static size_t dm_dax_copy_from_iter(struct dax_device *dax_dev, pgoff_t pgoff,
 	return ret;
 }
 
-static void dm_dax_flush(struct dax_device *dax_dev, pgoff_t pgoff, void *addr,
-		size_t size)
-{
-	struct mapped_device *md = dax_get_private(dax_dev);
-	sector_t sector = pgoff * PAGE_SECTORS;
-	struct dm_target *ti;
-	int srcu_idx;
-
-	ti = dm_dax_get_live_target(md, sector, &srcu_idx);
-
-	if (!ti)
-		goto out;
-	if (ti->type->dax_flush)
-		ti->type->dax_flush(ti, pgoff, addr, size);
- out:
-	dm_put_live_table(md, srcu_idx);
-}
-
 /*
  * A target may call dm_accept_partial_bio only from the map routine.  It is
  * allowed for all bio types except REQ_PREFLUSH.
@@ -2992,7 +2974,6 @@ static const struct block_device_operations dm_blk_dops = {
 static const struct dax_operations dm_dax_ops = {
 	.direct_access = dm_dax_direct_access,
 	.copy_from_iter = dm_dax_copy_from_iter,
-	.flush = dm_dax_flush,
 };
 
 /*
diff --git a/drivers/nvdimm/pmem.c b/drivers/nvdimm/pmem.c
index f7099adaabc0b643e703cf8b160b5c5debda2bde..88c1282587603122995e178e1e77d481b3e30dfa 100644
--- a/drivers/nvdimm/pmem.c
+++ b/drivers/nvdimm/pmem.c
@@ -243,16 +243,9 @@ static size_t pmem_copy_from_iter(struct dax_device *dax_dev, pgoff_t pgoff,
 	return copy_from_iter_flushcache(addr, bytes, i);
 }
 
-static void pmem_dax_flush(struct dax_device *dax_dev, pgoff_t pgoff,
-		void *addr, size_t size)
-{
-	arch_wb_cache_pmem(addr, size);
-}
-
 static const struct dax_operations pmem_dax_ops = {
 	.direct_access = pmem_dax_direct_access,
 	.copy_from_iter = pmem_copy_from_iter,
-	.flush = pmem_dax_flush,
 };
 
 static const struct attribute_group *pmem_attribute_groups[] = {
diff --git a/fs/dax.c b/fs/dax.c
index 865d42c63e23e4746c6658fbe2746fd17ade11f9..18d970fb0e09f0b8c45f0468f0e5eb1377304969 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -783,7 +783,7 @@ static int dax_writeback_one(struct block_device *bdev,
 	}
 
 	dax_mapping_entry_mkclean(mapping, index, pfn_t_to_pfn(pfn));
-	dax_flush(dax_dev, pgoff, kaddr, size);
+	dax_flush(dax_dev, kaddr, size);
 	/*
 	 * After we have flushed the cache, we can clear the dirty tag. There
 	 * cannot be new dirty data in the pfn after the flush has completed as
@@ -978,7 +978,7 @@ int __dax_zero_page_range(struct block_device *bdev,
 			return rc;
 		}
 		memset(kaddr + offset, 0, size);
-		dax_flush(dax_dev, pgoff, kaddr + offset, size);
+		dax_flush(dax_dev, kaddr + offset, size);
 		dax_read_unlock(id);
 	}
 	return 0;
diff --git a/include/linux/dax.h b/include/linux/dax.h
index df97b7af7e2c7263c4ae9fb1608b19aad1ae02ba..0d8f35f6c53dce846863b25eed27660f96d090c9 100644
--- a/include/linux/dax.h
+++ b/include/linux/dax.h
@@ -19,8 +19,6 @@ struct dax_operations {
 	/* copy_from_iter: required operation for fs-dax direct-i/o */
 	size_t (*copy_from_iter)(struct dax_device *, pgoff_t, void *, size_t,
 			struct iov_iter *);
-	/* flush: optional driver-specific cache management after writes */
-	void (*flush)(struct dax_device *, pgoff_t, void *, size_t);
 };
 
 extern struct attribute_group dax_attribute_group;
@@ -84,8 +82,7 @@ long dax_direct_access(struct dax_device *dax_dev, pgoff_t pgoff, long nr_pages,
 		void **kaddr, pfn_t *pfn);
 size_t dax_copy_from_iter(struct dax_device *dax_dev, pgoff_t pgoff, void *addr,
 		size_t bytes, struct iov_iter *i);
-void dax_flush(struct dax_device *dax_dev, pgoff_t pgoff, void *addr,
-		size_t size);
+void dax_flush(struct dax_device *dax_dev, void *addr, size_t size);
 void dax_write_cache(struct dax_device *dax_dev, bool wc);
 bool dax_write_cache_enabled(struct dax_device *dax_dev);
 
diff --git a/include/linux/device-mapper.h b/include/linux/device-mapper.h
index 3b5fdf308148410fd684da30b30cc7c215c3b47c..a5538433c927abebad75ec6763a3e6ab8f1cc0e2 100644
--- a/include/linux/device-mapper.h
+++ b/include/linux/device-mapper.h
@@ -134,8 +134,6 @@ typedef long (*dm_dax_direct_access_fn) (struct dm_target *ti, pgoff_t pgoff,
 		long nr_pages, void **kaddr, pfn_t *pfn);
 typedef size_t (*dm_dax_copy_from_iter_fn)(struct dm_target *ti, pgoff_t pgoff,
 		void *addr, size_t bytes, struct iov_iter *i);
-typedef void (*dm_dax_flush_fn)(struct dm_target *ti, pgoff_t pgoff, void *addr,
-		size_t size);
 #define PAGE_SECTORS (PAGE_SIZE / 512)
 
 void dm_error(const char *message);
@@ -186,7 +184,6 @@ struct target_type {
 	dm_io_hints_fn io_hints;
 	dm_dax_direct_access_fn direct_access;
 	dm_dax_copy_from_iter_fn dax_copy_from_iter;
-	dm_dax_flush_fn dax_flush;
 
 	/* For internal device-mapper use. */
 	struct list_head list;