diff --git a/drivers/net/ethernet/mellanox/mlx5/core/main.c b/drivers/net/ethernet/mellanox/mlx5/core/main.c
index 23c12f1aaa39719852a781099d540d860cc98c26..0b49739eadd3dd89c8fb4cfa6cddf0475b2cc8c2 100644
--- a/drivers/net/ethernet/mellanox/mlx5/core/main.c
+++ b/drivers/net/ethernet/mellanox/mlx5/core/main.c
@@ -1196,6 +1196,8 @@ static int mlx5_unload_one(struct mlx5_core_dev *dev, struct mlx5_priv *priv,
 {
 	int err = 0;
 
+	mlx5_drain_health_wq(dev);
+
 	mutex_lock(&dev->intf_state_mutex);
 	if (test_bit(MLX5_INTERFACE_STATE_DOWN, &dev->intf_state)) {
 		dev_warn(&dev->pdev->dev, "%s: interface is down, NOP\n",
@@ -1358,10 +1360,9 @@ static pci_ers_result_t mlx5_pci_err_detected(struct pci_dev *pdev,
 
 	mlx5_enter_error_state(dev);
 	mlx5_unload_one(dev, priv, false);
-	/* In case of kernel call save the pci state and drain health wq */
+	/* In case of kernel call save the pci state */
 	if (state) {
 		pci_save_state(pdev);
-		mlx5_drain_health_wq(dev);
 		mlx5_pci_disable_device(dev);
 	}