diff --git a/drivers/virtio/virtio_pci_common.c b/drivers/virtio/virtio_pci_common.c
index 4608fd9aaa6c3491967c690c44e7fa8ad0a365cb..3921b0a2439ed775a6a10323cb350d4b6974c47c 100644
--- a/drivers/virtio/virtio_pci_common.c
+++ b/drivers/virtio/virtio_pci_common.c
@@ -125,7 +125,7 @@ void vp_del_vqs(struct virtio_device *vdev)
 
 	vp_remove_vqs(vdev);
 
-	if (vp_dev->pci_dev->msix_enabled) {
+	if (vp_dev->msix_enabled) {
 		for (i = 0; i < vp_dev->msix_vectors; i++)
 			free_cpumask_var(vp_dev->msix_affinity_masks[i]);
 
@@ -249,6 +249,7 @@ static int vp_find_vqs_msix(struct virtio_device *vdev, unsigned nvqs,
 			allocated_vectors++;
 	}
 
+	vp_dev->msix_enabled = 1;
 	return 0;
 
 out_remove_vqs:
@@ -343,7 +344,7 @@ int vp_set_vq_affinity(struct virtqueue *vq, int cpu)
 	if (!vq->callback)
 		return -EINVAL;
 
-	if (vp_dev->pci_dev->msix_enabled) {
+	if (vp_dev->msix_enabled) {
 		int vec = vp_dev->msix_vector_map[vq->index];
 		struct cpumask *mask = vp_dev->msix_affinity_masks[vec];
 		unsigned int irq = pci_irq_vector(vp_dev->pci_dev, vec);
diff --git a/drivers/virtio/virtio_pci_common.h b/drivers/virtio/virtio_pci_common.h
index ac8c9d7889646ab3cc28bb51accd0d3840d5a34f..c8074997fd28a00dc535f3382f242700ca69e384 100644
--- a/drivers/virtio/virtio_pci_common.h
+++ b/drivers/virtio/virtio_pci_common.h
@@ -64,6 +64,8 @@ struct virtio_pci_device {
 	/* the IO mapping for the PCI config space */
 	void __iomem *ioaddr;
 
+	/* MSI-X support */
+	int msix_enabled;
 	cpumask_var_t *msix_affinity_masks;
 	/* Name strings for interrupts. This size should be enough,
 	 * and I'm too lazy to allocate each name separately. */
diff --git a/drivers/virtio/virtio_pci_legacy.c b/drivers/virtio/virtio_pci_legacy.c
index f7362c5fe18a96a902bc81138b8ea796e03e79d9..5dd01f09608bc852f2cd43413cc9cb1e28831e2d 100644
--- a/drivers/virtio/virtio_pci_legacy.c
+++ b/drivers/virtio/virtio_pci_legacy.c
@@ -165,7 +165,7 @@ static void del_vq(struct virtqueue *vq)
 
 	iowrite16(vq->index, vp_dev->ioaddr + VIRTIO_PCI_QUEUE_SEL);
 
-	if (vp_dev->pci_dev->msix_enabled) {
+	if (vp_dev->msix_enabled) {
 		iowrite16(VIRTIO_MSI_NO_VECTOR,
 			  vp_dev->ioaddr + VIRTIO_MSI_QUEUE_VECTOR);
 		/* Flush the write out to device */
diff --git a/drivers/virtio/virtio_pci_modern.c b/drivers/virtio/virtio_pci_modern.c
index 7bc3004b840ef3e3dabb5c2e24af8a935f552eba..7ce36daccc31953fb108d3a786ee67b8f7e5e191 100644
--- a/drivers/virtio/virtio_pci_modern.c
+++ b/drivers/virtio/virtio_pci_modern.c
@@ -411,7 +411,7 @@ static void del_vq(struct virtqueue *vq)
 
 	vp_iowrite16(vq->index, &vp_dev->common->queue_select);
 
-	if (vp_dev->pci_dev->msix_enabled) {
+	if (vp_dev->msix_enabled) {
 		vp_iowrite16(VIRTIO_MSI_NO_VECTOR,
 			     &vp_dev->common->queue_msix_vector);
 		/* Flush the write out to device */
diff --git a/include/uapi/linux/virtio_pci.h b/include/uapi/linux/virtio_pci.h
index 15b4385a2be169e9b99221f31a444820feadae3d..90007a1abcab144ac3d6ac7d6e6f4001d58abb14 100644
--- a/include/uapi/linux/virtio_pci.h
+++ b/include/uapi/linux/virtio_pci.h
@@ -79,7 +79,7 @@
  * configuration space */
 #define VIRTIO_PCI_CONFIG_OFF(msix_enabled)	((msix_enabled) ? 24 : 20)
 /* Deprecated: please use VIRTIO_PCI_CONFIG_OFF instead */
-#define VIRTIO_PCI_CONFIG(dev)	VIRTIO_PCI_CONFIG_OFF((dev)->pci_dev->msix_enabled)
+#define VIRTIO_PCI_CONFIG(dev)	VIRTIO_PCI_CONFIG_OFF((dev)->msix_enabled)
 
 /* Virtio ABI version, this must match exactly */
 #define VIRTIO_PCI_ABI_VERSION		0