diff --git a/include/linux/xarray.h b/include/linux/xarray.h
index 78eede109b1a5ceed91964ad7459755d22d618d0..9988e86787305b4e0438c8a6b2ac3262865eb045 100644
--- a/include/linux/xarray.h
+++ b/include/linux/xarray.h
@@ -1407,16 +1407,44 @@ struct xa_state {
 			order - (order % XA_CHUNK_SHIFT),	\
 			(1U << (order % XA_CHUNK_SHIFT)) - 1)
 
+/**
+ * xas_invalid() - Is the xas in a retry or error state?
+ * @xas: XArray operation state.
+ *
+ * Return: %true if the xas cannot be used for operations.
+ */
+static inline bool xas_invalid(const struct xa_state *xas)
+{
+	return (unsigned long)xas->xa_node & 3;
+}
+
+/**
+ * xas_valid() - Is the xas a valid cursor into the array?
+ * @xas: XArray operation state.
+ *
+ * Return: %true if the xas can be used for operations.
+ */
+static inline bool xas_valid(const struct xa_state *xas)
+{
+	return !xas_invalid(xas);
+}
+
+static inline struct xa_state *XAS_INVALID(struct xa_state *xas)
+{
+	XA_NODE_BUG_ON(xas->xa_node, xas_valid(xas));
+	return xas;
+}
+
 #define xas_marked(xas, mark)	xa_marked((xas)->xa, (mark))
-#define xas_trylock(xas)	xa_trylock((xas)->xa)
-#define xas_lock(xas)		xa_lock((xas)->xa)
+#define xas_trylock(xas)	xa_trylock(XAS_INVALID(xas)->xa)
+#define xas_lock(xas)		xa_lock(XAS_INVALID(xas)->xa)
 #define xas_unlock(xas)		xa_unlock((xas)->xa)
-#define xas_lock_bh(xas)	xa_lock_bh((xas)->xa)
+#define xas_lock_bh(xas)	xa_lock_bh(XAS_INVALID(xas)->xa)
 #define xas_unlock_bh(xas)	xa_unlock_bh((xas)->xa)
-#define xas_lock_irq(xas)	xa_lock_irq((xas)->xa)
+#define xas_lock_irq(xas)	xa_lock_irq(XAS_INVALID(xas)->xa)
 #define xas_unlock_irq(xas)	xa_unlock_irq((xas)->xa)
 #define xas_lock_irqsave(xas, flags) \
-				xa_lock_irqsave((xas)->xa, flags)
+				xa_lock_irqsave(XAS_INVALID(xas)->xa, flags)
 #define xas_unlock_irqrestore(xas, flags) \
 				xa_unlock_irqrestore((xas)->xa, flags)
 
@@ -1445,28 +1473,6 @@ static inline void xas_set_err(struct xa_state *xas, long err)
 	xas->xa_node = XA_ERROR(err);
 }
 
-/**
- * xas_invalid() - Is the xas in a retry or error state?
- * @xas: XArray operation state.
- *
- * Return: %true if the xas cannot be used for operations.
- */
-static inline bool xas_invalid(const struct xa_state *xas)
-{
-	return (unsigned long)xas->xa_node & 3;
-}
-
-/**
- * xas_valid() - Is the xas a valid cursor into the array?
- * @xas: XArray operation state.
- *
- * Return: %true if the xas can be used for operations.
- */
-static inline bool xas_valid(const struct xa_state *xas)
-{
-	return !xas_invalid(xas);
-}
-
 /**
  * xas_is_node() - Does the xas point to a node?
  * @xas: XArray operation state.
diff --git a/lib/xarray.c b/lib/xarray.c
index 9644b18af18d1739ab6461f87312c6a2c432a6b7..261814d170d80fe435260a38d234e6026af0e7dc 100644
--- a/lib/xarray.c
+++ b/lib/xarray.c
@@ -368,7 +368,7 @@ static void *xas_alloc(struct xa_state *xas, unsigned int shift)
 		return NULL;
 
 	if (node) {
-		xas->xa_alloc = NULL;
+		xas->xa_alloc = rcu_dereference_raw(node->parent);
 	} else {
 		gfp_t gfp = GFP_NOWAIT | __GFP_NOWARN;
 
@@ -2380,7 +2380,6 @@ void xa_destroy(struct xarray *xa)
 	unsigned long flags;
 	void *entry;
 
-	xas.xa_node = NULL;
 	xas_lock_irqsave(&xas, flags);
 	entry = xa_head_locked(xa);
 	RCU_INIT_POINTER(xa->xa_head, NULL);