diff --git a/drivers/pci/pci-driver.c b/drivers/pci/pci-driver.c
index ac6c9e493f4cb051eb5a9d73b4bf0e861a51dde3..93eac14235854c648e6e895c51f063b8af68e152 100644
--- a/drivers/pci/pci-driver.c
+++ b/drivers/pci/pci-driver.c
@@ -430,39 +430,22 @@ static void pci_pm_default_resume_noirq(struct pci_dev *pci_dev)
 	pci_fixup_device(pci_fixup_resume_early, pci_dev);
 }
 
-static int pci_pm_default_resume(struct pci_dev *pci_dev)
+static void pci_pm_default_resume(struct pci_dev *pci_dev)
 {
 	pci_fixup_device(pci_fixup_resume, pci_dev);
 
-	if (pci_is_bridge(pci_dev))
-		return 0;
-
-	pci_enable_wake(pci_dev, PCI_D0, false);
-	return pci_pm_reenable_device(pci_dev);
+	if (!pci_is_bridge(pci_dev))
+		pci_enable_wake(pci_dev, PCI_D0, false);
 }
 
-static void pci_pm_default_suspend_generic(struct pci_dev *pci_dev)
+static void pci_pm_default_suspend(struct pci_dev *pci_dev)
 {
-	/* If a non-bridge device is enabled at this point, disable it */
+	/* Disable non-bridge devices without PM support */
 	if (!pci_is_bridge(pci_dev))
 		pci_disable_enabled_device(pci_dev);
-	/*
-	 * Save state with interrupts enabled, because in principle the bus the
-	 * device is on may be put into a low power state after this code runs.
-	 */
 	pci_save_state(pci_dev);
 }
 
-static void pci_pm_default_suspend(struct pci_dev *pci_dev, bool prepare)
-{
-	pci_pm_default_suspend_generic(pci_dev);
-
-	if (prepare && !pci_is_bridge(pci_dev))
-		pci_prepare_to_sleep(pci_dev);
-
-	pci_fixup_device(pci_fixup_suspend, pci_dev);
-}
-
 static bool pci_has_legacy_pm_support(struct pci_dev *pci_dev)
 {
 	struct pci_driver *drv = pci_dev->driver;
@@ -506,20 +489,48 @@ static int pci_pm_suspend(struct device *dev)
 {
 	struct pci_dev *pci_dev = to_pci_dev(dev);
 	struct dev_pm_ops *pm = dev->driver ? dev->driver->pm : NULL;
-	int error = 0;
 
 	if (pci_has_legacy_pm_support(pci_dev))
 		return pci_legacy_suspend(dev, PMSG_SUSPEND);
 
-	if (pm && pm->suspend) {
+	if (!pm) {
+		pci_pm_default_suspend(pci_dev);
+		goto Fixup;
+	}
+
+	pci_dev->state_saved = false;
+
+	if (pm->suspend) {
+		pci_power_t prev = pci_dev->current_state;
+		int error;
+
 		error = pm->suspend(dev);
 		suspend_report_result(pm->suspend, error);
+		if (error)
+			return error;
+
+		if (pci_dev->state_saved)
+			goto Fixup;
+
+		if (pci_dev->current_state != PCI_D0
+		    && pci_dev->current_state != PCI_UNKNOWN) {
+			WARN_ONCE(pci_dev->current_state != prev,
+				"PCI PM: State of device not saved by %pF\n",
+				pm->suspend);
+			goto Fixup;
+		}
 	}
 
-	if (!error)
-		pci_pm_default_suspend(pci_dev, !!pm);
+	if (!pci_dev->state_saved) {
+		pci_save_state(pci_dev);
+		if (!pci_is_bridge(pci_dev))
+			pci_prepare_to_sleep(pci_dev);
+	}
 
-	return error;
+ Fixup:
+	pci_fixup_device(pci_fixup_suspend, pci_dev);
+
+	return 0;
 }
 
 static int pci_pm_suspend_noirq(struct device *dev)
@@ -562,7 +573,7 @@ static int pci_pm_resume_noirq(struct device *dev)
 static int pci_pm_resume(struct device *dev)
 {
 	struct pci_dev *pci_dev = to_pci_dev(dev);
-	struct device_driver *drv = dev->driver;
+	struct dev_pm_ops *pm = dev->driver ? dev->driver->pm : NULL;
 	int error = 0;
 
 	/*
@@ -575,12 +586,16 @@ static int pci_pm_resume(struct device *dev)
 	if (pci_has_legacy_pm_support(pci_dev))
 		return pci_legacy_resume(dev);
 
-	error = pci_pm_default_resume(pci_dev);
+	pci_pm_default_resume(pci_dev);
 
-	if (!error && drv && drv->pm && drv->pm->resume)
-		error = drv->pm->resume(dev);
+	if (pm) {
+		if (pm->resume)
+			error = pm->resume(dev);
+	} else {
+		pci_pm_reenable_device(pci_dev);
+	}
 
-	return error;
+	return 0;
 }
 
 #else /* !CONFIG_SUSPEND */
@@ -597,21 +612,31 @@ static int pci_pm_resume(struct device *dev)
 static int pci_pm_freeze(struct device *dev)
 {
 	struct pci_dev *pci_dev = to_pci_dev(dev);
-	struct device_driver *drv = dev->driver;
-	int error = 0;
+	struct dev_pm_ops *pm = dev->driver ? dev->driver->pm : NULL;
 
 	if (pci_has_legacy_pm_support(pci_dev))
 		return pci_legacy_suspend(dev, PMSG_FREEZE);
 
-	if (drv && drv->pm && drv->pm->freeze) {
-		error = drv->pm->freeze(dev);
-		suspend_report_result(drv->pm->freeze, error);
+	if (!pm) {
+		pci_pm_default_suspend(pci_dev);
+		return 0;
 	}
 
-	if (!error)
-		pci_pm_default_suspend_generic(pci_dev);
+	pci_dev->state_saved = false;
 
-	return error;
+	if (pm->freeze) {
+		int error;
+
+		error = pm->freeze(dev);
+		suspend_report_result(pm->freeze, error);
+		if (error)
+			return error;
+	}
+
+	if (!pci_dev->state_saved)
+		pci_save_state(pci_dev);
+
+	return 0;
 }
 
 static int pci_pm_freeze_noirq(struct device *dev)
@@ -654,16 +679,18 @@ static int pci_pm_thaw_noirq(struct device *dev)
 static int pci_pm_thaw(struct device *dev)
 {
 	struct pci_dev *pci_dev = to_pci_dev(dev);
-	struct device_driver *drv = dev->driver;
+	struct dev_pm_ops *pm = dev->driver ? dev->driver->pm : NULL;
 	int error = 0;
 
 	if (pci_has_legacy_pm_support(pci_dev))
 		return pci_legacy_resume(dev);
 
-	pci_pm_reenable_device(pci_dev);
-
-	if (drv && drv->pm && drv->pm->thaw)
-		error =  drv->pm->thaw(dev);
+	if (pm) {
+		if (pm->thaw)
+			error = pm->thaw(dev);
+	} else {
+		pci_pm_reenable_device(pci_dev);
+	}
 
 	return error;
 }
@@ -677,13 +704,23 @@ static int pci_pm_poweroff(struct device *dev)
 	if (pci_has_legacy_pm_support(pci_dev))
 		return pci_legacy_suspend(dev, PMSG_HIBERNATE);
 
-	if (pm && pm->poweroff) {
+	if (!pm) {
+		pci_pm_default_suspend(pci_dev);
+		goto Fixup;
+	}
+
+	pci_dev->state_saved = false;
+
+	if (pm->poweroff) {
 		error = pm->poweroff(dev);
 		suspend_report_result(pm->poweroff, error);
 	}
 
-	if (!error)
-		pci_pm_default_suspend(pci_dev, !!pm);
+	if (!pci_dev->state_saved && !pci_is_bridge(pci_dev))
+		pci_prepare_to_sleep(pci_dev);
+
+ Fixup:
+	pci_fixup_device(pci_fixup_suspend, pci_dev);
 
 	return error;
 }
@@ -724,7 +761,7 @@ static int pci_pm_restore_noirq(struct device *dev)
 static int pci_pm_restore(struct device *dev)
 {
 	struct pci_dev *pci_dev = to_pci_dev(dev);
-	struct device_driver *drv = dev->driver;
+	struct dev_pm_ops *pm = dev->driver ? dev->driver->pm : NULL;
 	int error = 0;
 
 	/*
@@ -737,10 +774,14 @@ static int pci_pm_restore(struct device *dev)
 	if (pci_has_legacy_pm_support(pci_dev))
 		return pci_legacy_resume(dev);
 
-	error = pci_pm_default_resume(pci_dev);
+	pci_pm_default_resume(pci_dev);
 
-	if (!error && drv && drv->pm && drv->pm->restore)
-		error = drv->pm->restore(dev);
+	if (pm) {
+		if (pm->restore)
+			error = pm->restore(dev);
+	} else {
+		pci_pm_reenable_device(pci_dev);
+	}
 
 	return error;
 }