diff --git a/drivers/block/virtio_blk.c b/drivers/block/virtio_blk.c
index 655e570b9b3170339b47b77046290f832f857001..5ea2f0bbbc7c3d6e62a016bcce22ff5a8ce46c1b 100644
--- a/drivers/block/virtio_blk.c
+++ b/drivers/block/virtio_blk.c
@@ -342,7 +342,7 @@ static void virtblk_config_changed_work(struct work_struct *work)
 	struct request_queue *q = vblk->disk->queue;
 	char cap_str_2[10], cap_str_10[10];
 	char *envp[] = { "RESIZE=1", NULL };
-	u64 capacity, size;
+	u64 capacity;
 
 	/* Host must always specify the capacity. */
 	virtio_cread(vdev, struct virtio_blk_config, capacity, &capacity);
@@ -354,9 +354,10 @@ static void virtblk_config_changed_work(struct work_struct *work)
 		capacity = (sector_t)-1;
 	}
 
-	size = capacity * queue_logical_block_size(q);
-	string_get_size(size, STRING_UNITS_2, cap_str_2, sizeof(cap_str_2));
-	string_get_size(size, STRING_UNITS_10, cap_str_10, sizeof(cap_str_10));
+	string_get_size(capacity, queue_logical_block_size(q),
+			STRING_UNITS_2, cap_str_2, sizeof(cap_str_2));
+	string_get_size(capacity, queue_logical_block_size(q),
+			STRING_UNITS_10, cap_str_10, sizeof(cap_str_10));
 
 	dev_notice(&vdev->dev,
 		  "new size: %llu %d-byte logical blocks (%s/%s)\n",
diff --git a/drivers/mmc/card/block.c b/drivers/mmc/card/block.c
index c69afb5e264e53976a3bf85a1bbd31d127668381..2fc426926574ec0b092485313b852158272d96ae 100644
--- a/drivers/mmc/card/block.c
+++ b/drivers/mmc/card/block.c
@@ -2230,7 +2230,7 @@ static int mmc_blk_alloc_part(struct mmc_card *card,
 	part_md->part_type = part_type;
 	list_add(&part_md->part, &md->part);
 
-	string_get_size((u64)get_capacity(part_md->disk) << 9, STRING_UNITS_2,
+	string_get_size((u64)get_capacity(part_md->disk), 512, STRING_UNITS_2,
 			cap_str, sizeof(cap_str));
 	pr_info("%s: %s %s partition %u %s\n",
 	       part_md->disk->disk_name, mmc_card_id(card),
@@ -2436,7 +2436,7 @@ static int mmc_blk_probe(struct device *dev)
 	if (IS_ERR(md))
 		return PTR_ERR(md);
 
-	string_get_size((u64)get_capacity(md->disk) << 9, STRING_UNITS_2,
+	string_get_size((u64)get_capacity(md->disk), 512, STRING_UNITS_2,
 			cap_str, sizeof(cap_str));
 	pr_info("%s: %s %s %s %s\n",
 		md->disk->disk_name, mmc_card_id(card), mmc_card_name(card),
diff --git a/drivers/scsi/sd.c b/drivers/scsi/sd.c
index 9e0c63e57affd837c37ebe0d4b4cbda60e7862f5..dcc42446f58a5e9c21ee72e7fca34d9ff6235c7b 100644
--- a/drivers/scsi/sd.c
+++ b/drivers/scsi/sd.c
@@ -2211,11 +2211,11 @@ sd_read_capacity(struct scsi_disk *sdkp, unsigned char *buffer)
 
 	{
 		char cap_str_2[10], cap_str_10[10];
-		u64 sz = (u64)sdkp->capacity << ilog2(sector_size);
 
-		string_get_size(sz, STRING_UNITS_2, cap_str_2,
-				sizeof(cap_str_2));
-		string_get_size(sz, STRING_UNITS_10, cap_str_10,
+		string_get_size(sdkp->capacity, sector_size,
+				STRING_UNITS_2, cap_str_2, sizeof(cap_str_2));
+		string_get_size(sdkp->capacity, sector_size,
+				STRING_UNITS_10, cap_str_10,
 				sizeof(cap_str_10));
 
 		if (sdkp->first_scan || old_capacity != sdkp->capacity) {
diff --git a/include/linux/string_helpers.h b/include/linux/string_helpers.h
index 657571817260ab9b22b76606f81492074f240d49..2633280637307f52a81adfc6b7dac126e0628e72 100644
--- a/include/linux/string_helpers.h
+++ b/include/linux/string_helpers.h
@@ -10,7 +10,7 @@ enum string_size_units {
 	STRING_UNITS_2,		/* use binary powers of 2^10 */
 };
 
-void string_get_size(u64 size, enum string_size_units units,
+void string_get_size(u64 size, u64 blk_size, enum string_size_units units,
 		     char *buf, int len);
 
 #define UNESCAPE_SPACE		0x01
diff --git a/lib/string_helpers.c b/lib/string_helpers.c
index 8f8c4417f22865d9ffb274700540b37f1a17bd93..4a913ec3acf90dedae354b62bdebdfb78fb9fd0d 100644
--- a/lib/string_helpers.c
+++ b/lib/string_helpers.c
@@ -4,6 +4,7 @@
  * Copyright 31 August 2008 James Bottomley
  * Copyright (C) 2013, Intel Corporation
  */
+#include <linux/bug.h>
 #include <linux/kernel.h>
 #include <linux/math64.h>
 #include <linux/export.h>
@@ -14,7 +15,8 @@
 
 /**
  * string_get_size - get the size in the specified units
- * @size:	The size to be converted
+ * @size:	The size to be converted in blocks
+ * @blk_size:	Size of the block (use 1 for size in bytes)
  * @units:	units to use (powers of 1000 or 1024)
  * @buf:	buffer to format to
  * @len:	length of buffer
@@ -24,14 +26,14 @@
  * at least 9 bytes and will always be zero terminated.
  *
  */
-void string_get_size(u64 size, const enum string_size_units units,
+void string_get_size(u64 size, u64 blk_size, const enum string_size_units units,
 		     char *buf, int len)
 {
 	static const char *const units_10[] = {
-		"B", "kB", "MB", "GB", "TB", "PB", "EB"
+		"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"
 	};
 	static const char *const units_2[] = {
-		"B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB"
+		"B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"
 	};
 	static const char *const *const units_str[] = {
 		[STRING_UNITS_10] = units_10,
@@ -42,31 +44,57 @@ void string_get_size(u64 size, const enum string_size_units units,
 		[STRING_UNITS_2] = 1024,
 	};
 	int i, j;
-	u32 remainder = 0, sf_cap;
+	u32 remainder = 0, sf_cap, exp;
 	char tmp[8];
+	const char *unit;
 
 	tmp[0] = '\0';
 	i = 0;
-	if (size >= divisor[units]) {
-		while (size >= divisor[units]) {
-			remainder = do_div(size, divisor[units]);
-			i++;
-		}
+	if (!size)
+		goto out;
 
-		sf_cap = size;
-		for (j = 0; sf_cap*10 < 1000; j++)
-			sf_cap *= 10;
+	while (blk_size >= divisor[units]) {
+		remainder = do_div(blk_size, divisor[units]);
+		i++;
+	}
 
-		if (j) {
-			remainder *= 1000;
-			remainder /= divisor[units];
-			snprintf(tmp, sizeof(tmp), ".%03u", remainder);
-			tmp[j+1] = '\0';
-		}
+	exp = divisor[units] / (u32)blk_size;
+	if (size >= exp) {
+		remainder = do_div(size, divisor[units]);
+		remainder *= blk_size;
+		i++;
+	} else {
+		remainder *= size;
+	}
+
+	size *= blk_size;
+	size += remainder / divisor[units];
+	remainder %= divisor[units];
+
+	while (size >= divisor[units]) {
+		remainder = do_div(size, divisor[units]);
+		i++;
 	}
 
+	sf_cap = size;
+	for (j = 0; sf_cap*10 < 1000; j++)
+		sf_cap *= 10;
+
+	if (j) {
+		remainder *= 1000;
+		remainder /= divisor[units];
+		snprintf(tmp, sizeof(tmp), ".%03u", remainder);
+		tmp[j+1] = '\0';
+	}
+
+ out:
+	if (i >= ARRAY_SIZE(units_2))
+		unit = "UNK";
+	else
+		unit = units_str[units][i];
+
 	snprintf(buf, len, "%u%s %s", (u32)size,
-		 tmp, units_str[units][i]);
+		 tmp, unit);
 }
 EXPORT_SYMBOL(string_get_size);