diff --git a/drivers/scsi/cxlflash/common.h b/drivers/scsi/cxlflash/common.h
index c11cd193f8964ff1741453e309a85a0f6d5d4402..5ada9268a450db224663f2fa28e43fc79701c886 100644
--- a/drivers/scsi/cxlflash/common.h
+++ b/drivers/scsi/cxlflash/common.h
@@ -165,6 +165,8 @@ struct afu {
 	struct sisl_host_map __iomem *host_map;		/* MC host map */
 	struct sisl_ctrl_map __iomem *ctrl_map;		/* MC control map */
 
+	struct kref mapcount;
+
 	ctx_hndl_t ctx_hndl;	/* master's context handle */
 	u64 *hrrq_start;
 	u64 *hrrq_end;
diff --git a/drivers/scsi/cxlflash/main.c b/drivers/scsi/cxlflash/main.c
index ac39856a74b49f02d9b0fa57d8628a93871a8191..30542ca9415b02cfeab503a313bb72609db8d35c 100644
--- a/drivers/scsi/cxlflash/main.c
+++ b/drivers/scsi/cxlflash/main.c
@@ -368,6 +368,7 @@ static int send_cmd(struct afu *afu, struct afu_cmd *cmd)
 
 no_room:
 	afu->read_room = true;
+	kref_get(&cfg->afu->mapcount);
 	schedule_work(&cfg->work_q);
 	rc = SCSI_MLQUEUE_HOST_BUSY;
 	goto out;
@@ -473,6 +474,16 @@ static int send_tmf(struct afu *afu, struct scsi_cmnd *scp, u64 tmfcmd)
 	return rc;
 }
 
+static void afu_unmap(struct kref *ref)
+{
+	struct afu *afu = container_of(ref, struct afu, mapcount);
+
+	if (likely(afu->afu_map)) {
+		cxl_psa_unmap((void __iomem *)afu->afu_map);
+		afu->afu_map = NULL;
+	}
+}
+
 /**
  * cxlflash_driver_info() - information handler for this host driver
  * @host:	SCSI host associated with device.
@@ -503,6 +514,7 @@ static int cxlflash_queuecommand(struct Scsi_Host *host, struct scsi_cmnd *scp)
 	ulong lock_flags;
 	short lflag = 0;
 	int rc = 0;
+	int kref_got = 0;
 
 	dev_dbg_ratelimited(dev, "%s: (scp=%p) %d/%d/%d/%llu "
 			    "cdb=(%08X-%08X-%08X-%08X)\n",
@@ -547,6 +559,9 @@ static int cxlflash_queuecommand(struct Scsi_Host *host, struct scsi_cmnd *scp)
 		goto out;
 	}
 
+	kref_get(&cfg->afu->mapcount);
+	kref_got = 1;
+
 	cmd->rcb.ctx_id = afu->ctx_hndl;
 	cmd->rcb.port_sel = port_sel;
 	cmd->rcb.lun_id = lun_to_lunid(scp->device->lun);
@@ -587,6 +602,8 @@ static int cxlflash_queuecommand(struct Scsi_Host *host, struct scsi_cmnd *scp)
 	}
 
 out:
+	if (kref_got)
+		kref_put(&afu->mapcount, afu_unmap);
 	pr_devel("%s: returning rc=%d\n", __func__, rc);
 	return rc;
 }
@@ -661,6 +678,7 @@ static void stop_afu(struct cxlflash_cfg *cfg)
 			cxl_psa_unmap((void __iomem *)afu->afu_map);
 			afu->afu_map = NULL;
 		}
+		kref_put(&afu->mapcount, afu_unmap);
 	}
 }
 
@@ -746,8 +764,8 @@ static void cxlflash_remove(struct pci_dev *pdev)
 		scsi_remove_host(cfg->host);
 		/* fall through */
 	case INIT_STATE_AFU:
-		term_afu(cfg);
 		cancel_work_sync(&cfg->work_q);
+		term_afu(cfg);
 	case INIT_STATE_PCI:
 		pci_release_regions(cfg->dev);
 		pci_disable_device(pdev);
@@ -1331,6 +1349,7 @@ static irqreturn_t cxlflash_async_err_irq(int irq, void *data)
 				__func__, port);
 			cfg->lr_state = LINK_RESET_REQUIRED;
 			cfg->lr_port = port;
+			kref_get(&cfg->afu->mapcount);
 			schedule_work(&cfg->work_q);
 		}
 
@@ -1351,6 +1370,7 @@ static irqreturn_t cxlflash_async_err_irq(int irq, void *data)
 
 		if (info->action & SCAN_HOST) {
 			atomic_inc(&cfg->scan_host_needed);
+			kref_get(&cfg->afu->mapcount);
 			schedule_work(&cfg->work_q);
 		}
 	}
@@ -1746,6 +1766,7 @@ static int init_afu(struct cxlflash_cfg *cfg)
 		rc = -ENOMEM;
 		goto err1;
 	}
+	kref_init(&afu->mapcount);
 
 	/* No byte reverse on reading afu_version or string will be backwards */
 	reg = readq(&afu->afu_map->global.regs.afu_version);
@@ -1780,8 +1801,7 @@ static int init_afu(struct cxlflash_cfg *cfg)
 	return rc;
 
 err2:
-	cxl_psa_unmap((void __iomem *)afu->afu_map);
-	afu->afu_map = NULL;
+	kref_put(&afu->mapcount, afu_unmap);
 err1:
 	term_mc(cfg, UNDO_START);
 	goto out;
@@ -2354,6 +2374,7 @@ static void cxlflash_worker_thread(struct work_struct *work)
 
 	if (atomic_dec_if_positive(&cfg->scan_host_needed) >= 0)
 		scsi_scan_host(cfg->host);
+	kref_put(&afu->mapcount, afu_unmap);
 }
 
 /**